#!/usr/bin/python3

# This file is part of chiark-utils, a collection of useful programs
# used on chiark.greenend.org.uk.
#
# This file is:
#  Copyright 2018 Citrix Systems Ltd
#
# This is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3, or (at your option) any later version.
#
# This is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, consult the Free Software Foundation's
# website at www.fsf.org, or the GNU Project website at www.gnu.org.

import sys
import fishdescriptor.fish
import optparse
import re
import subprocess
import socket
import os

# The fishdescriptor.fish.Donor for the pid we are currently attached to
# (assigned by set_donor), or None until the first pid has been given.
donor = None

usage = '''fishdescriptor [-p|--pid] <pid> <action>... [-p|--pid <pid> <action>...]

<action>s
  [<here-fd>=]<there-fd>
          fish the openfile referenced by descriptor <there-fd> in
          (the most recent) <pid> and keep a descriptor onto it;
          and, optionally, give it the number <here-fd> for exec
  exec <program> [<arg>...]
          execute a process with each specified <here>
          as an actual fd
  sockinfo
          calls getsockname/getpeername on the most recent
          <there-fd>

  -p|--pid <pid>
          now attach to <pid>, detaching from previous pid
'''

pending = []
# list of (nominal, there) where nominal might be None
# these are fd specs seen on the command line but not yet fished

fdmap = { }
# fdmap[nominal] = [actual, Donor, there]
# (a list, not a tuple: permute_fds_for_exec updates actual in place)

def implement_pending():
    """Fish every queued fd spec from the current donor into fdmap.

    Each (nominal, there) in `pending` is fished from `donor`; the
    resulting local descriptor is recorded as
    fdmap[nominal] = [actual, donor, there].  If fdmap already had an
    entry for that nominal, the descriptor it held is closed first.

    The queue is emptied afterwards, so that a later call (for instance
    after a further sockinfo, or after switching to another pid) does not
    fish the same specs again, possibly from a different process.

    On fishdescriptor.fish.Error the message is printed to stderr and the
    program exits with status 127.
    """
    try:
        actuals = donor.fish([there for (nominal, there) in pending])
    except fishdescriptor.fish.Error as e:
        print('fishdescriptor error: %s' % e, file=sys.stderr)
        sys.exit(127)
    # internal consistency check: one local fd per requested remote fd
    assert len(actuals) == len(pending)
    for (nominal, there), actual in zip(pending, actuals):
        overwriting_info = fdmap.get(nominal)
        if overwriting_info is not None:
            # this nominal is being replaced; don't leak the old descriptor
            os.close(overwriting_info[0])
        fdmap[nominal] = [actual, donor, there]
    # everything queued has now been dealt with (clear in place: `pending'
    # is a module-level list which other functions append to)
    del pending[:]

def implement_sockinfo(nominal):
    """Print one line describing the socket held in fdmap[nominal].

    The line is "[<pid>] <there-fd> sockinfo" followed by tab-separated
    fields: the address family name, the getsockname() result and the
    getpeername() result.  A field whose lookup raises is replaced by the
    repr of the exception.

    If the descriptor is not a socket at all, the line instead ends with a
    single "not a socket" field.  If the perl helper fails for any other
    reason, an error is printed to stderr and the program exits with
    status 127.
    """
    (actual, tdonor, there) = fdmap[nominal]
    # socket.fromfd requires the AF.  But of course we don't know the AF.
    # There isn't a sane way to get it in Python:
    #  https://utcc.utoronto.ca/~cks/space/blog/python/SocketFromFdMistake
    # Rejected options:
    #  https://github.com/tiran/socketfromfd
    #   adds a dependency, not portable due to reliance on SO_DOMAIN
    #  call getsockname using ctypes
    #   no sane way to discover how to unpack sa_family_t
    perl_script = '''
        use strict;
        use Socket;
        use POSIX;
        my $sa = getsockname STDIN;
        exit 0 if !defined $sa and $!==ENOTSOCK;
        my $family = sockaddr_family $sa;
        print $family, "\n" or die $!;
    '''
    # The fished descriptor is handed to perl as its stdin, which is what
    # the script calls getsockname on.
    famp = subprocess.Popen(
        stdin = actual,
        stdout = subprocess.PIPE,
        args = ['perl','-we',perl_script]
    )
    (output, dummy) = famp.communicate()

    if famp.returncode != 0:
        # perl has already reported the details on our stderr
        print('fishdescriptor error: could not determine address family'
              ' of fd %d in pid %s (perl exit status %d)'
              % (there, tdonor.pid, famp.returncode), file=sys.stderr)
        sys.exit(127)

    if not output.strip():
        # The script exits 0 without printing anything on ENOTSOCK.
        print("[%s] %d sockinfo\tnot a socket" % (tdonor.pid, there))
        return

    family = int(output)

    # socket.fromfd duplicates the descriptor; the `with' closes only that
    # duplicate, leaving `actual' open for a later exec.
    with socket.fromfd(actual, family, 0) as sock:
        print("[%s] %d sockinfo" % (tdonor.pid, there), end='')
        for f in (lambda: socket.AddressFamily(family).name,
                  lambda: repr(sock.getsockname()),
                  lambda: repr(sock.getpeername())):
            # best effort: report a failing lookup in place of its value
            try: info = f()
            except Exception as e: info = repr(e)
            print("\t", info, sep='', end='')
        print("")

def permute_fds_for_exec():
    """Move every fished descriptor to its nominal fd number, ready for exec.

    Afterwards each open-file recorded in fdmap under an integer nominal
    is held in exactly that fd number, and the un-nominated one (key None),
    if any, has been closed.
    """
    # Deal with the un-nominated entry first: it only needs closing.
    # Doing this up front matters, because in the lookups below a result
    # of None has to mean "nothing of ours is in the way".  If the None
    # entry stayed in the table, its fd could be overwritten by dup2 for
    # some nominal and then closed when the None entry itself came to be
    # processed, destroying the descriptor just installed.
    unwanted = fdmap.pop(None, None)
    if unwanted is not None:
        os.close(unwanted[0])

    actual2intended = { info[0]: nominal for nominal, info in fdmap.items() }
    # invariant at the start of each loop iteration:
    #     for each intended (aka `nominal') we have processed:
    #         relevant open-file is only held in fd intended
    #     for each intended (aka `nominal') we have NOT processed:
    #         relevant open-file is only held in actual
    #         where  actual = fdmap[nominal][0]
    #         and where  actual2intended[actual] = nominal
    # we can rely on processing each intended only once,
    #  since they're hash keys
    # the post-condition is not really a valid state (fdmap
    #  is nonsense) but we call this function just before exec
    for intended, info in fdmap.items():
        # read actual now, not earlier: a previous iteration may have
        # moved this entry out of its way and updated info[0]
        actual = info[0]
        if intended == actual:
            continue
        inway_intended = actual2intended.get(intended)
        if inway_intended is not None:
            # fd number `intended' currently holds another, not yet
            # processed, open-file of ours: park that one on a fresh fd
            inway_moved = os.dup(intended)
            actual2intended[inway_moved] = inway_intended
            fdmap[inway_intended][0] = inway_moved
        os.dup2(actual, intended)
        os.close(actual)
        del actual2intended[actual]

def implement_exec(argl):
    """Replace this process with the program argl[0], given arguments argl.

    Detaches from the current donor (if there is one), flushes stdout,
    moves the fished descriptors to their nominal fd numbers and then
    calls os.execvp, which searches PATH for the program.
    """
    if donor is not None:
        donor.detach()
    # presumably so that anything already printed (e.g. by sockinfo) is not
    # lost in a buffer when exec replaces the process image
    sys.stdout.flush()
    permute_fds_for_exec()
    program = argl[0]
    os.execvp(program, argl)

def set_donor(pid):
    """Attach to process `pid`, making it the donor for later fd specs.

    If we are already attached to a donor, any fd specs still queued in
    `pending` are fished from it first (they were written against that
    pid, as the usage text says: "the most recent <pid>"), and only then
    is it detached.  Without this, the queued specs would later be fished
    from the new pid instead.
    """
    global donor
    if donor is not None:
        if pending:
            implement_pending()
            # make sure the old pid's specs cannot be fished again from
            # the new donor
            del pending[:]
        donor.detach()
    donor = fishdescriptor.fish.Donor(pid, debug=ov.debug)

def ocb_set_donor(option, opt, value, parser):
    """optparse callback for -p/--pid: attach to the pid given as `value`."""
    # The option is registered with type='int', so `value' arrives already
    # converted; the other callback arguments are not needed.
    set_donor(value)

# Option values, shared by every parse_args() call made in process_args;
# set_donor reads ov.debug when creating a Donor.
ov = optparse.Values()

def process_args():
    """Work through sys.argv, carrying out each action as it is reached.

    Options (-p/--pid, -D/--debug) are handled by optparse; after each run
    of options one non-option argument is consumed and interpreted:

      * while no donor is attached yet, as the pid to attach to;
      * `[<here-fd>=]<there-fd>`: queue an fd to be fished;
      * `exec <program> [<arg>...]`: fish what is queued, then exec;
      * `sockinfo`: fish what is queued, then report on the most
        recently specified fd.

    Anything else, including a malformed pid or fd spec, is reported
    through the option parser's error(), which prints the usage message
    and exits.
    """
    global ov

    op = optparse.OptionParser(usage=usage)

    # Stop option parsing at the first non-option argument, so that the
    # loop below sees actions and options in command-line order.
    op.disable_interspersed_args()
    op.add_option('-p','--pid', type='int', action='callback',
                  callback=ocb_set_donor)
    op.add_option('-D','--debug', action='store_const',
                  dest='debug', const=sys.stderr)
    ov.debug = None

    args = sys.argv[1:]
    last_nominal = None # None or (nominal,) ie None or (None,) or (int,)

    while True:
        (ov, args) = op.parse_args(args=args, values=ov)
        if not args: break

        arg = args.pop(0)

        if donor is None:
            # no -p seen yet: the first plain argument is the pid
            try:
                pid = int(arg)
            except ValueError:
                op.error("pid must be an integer, not `%s'" % arg)
            set_donor(pid)
            continue

        # <there-fd> is mandatory; fullmatch also rejects a trailing newline
        m = re.fullmatch(r'(?:(\d+)=)?(\d+)', arg)
        if m:
            (nominal, there) = m.groups()
            nominal = None if nominal is None else int(nominal)
            there = int(there)
            pending.append((nominal,there))
            last_nominal = (nominal,)
        elif arg == 'exec':
            if not args:
                op.error("exec needs command to run")
            implement_pending()
            implement_exec(args)
        elif arg == 'sockinfo':
            if last_nominal is None:
                op.error('sockinfo needs a prior fd spec')
            implement_pending()
            implement_sockinfo(last_nominal[0])
        else:
            op.error("unknown argument/option `%s'" % arg)

# Script entry point: everything happens while the arguments are processed.
process_args()
