#!/usr/bin/env python
import ctypes
import ctypes.util
import errno
import sys
import os

# Namespace selection flags from <sched.h>, passed to unshare().
CLONE_NEWNS = 1 << 17    # 0x00020000
CLONE_NEWPID = 1 << 29   # 0x20000000

# Mount flags from <sys/mount.h>, passed to mount().
MS_REC = 1 << 14         # 16384
MS_PRIVATE = 1 << 18
MS_SLAVE = 1 << 19

# Load libc; use_errno=True makes the C errno of each call available
# through ctypes.get_errno().
_libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)


def _bind_libc(name, argtypes):
    """
    Look up libc function `name` and declare its C prototype.

    Every function bound here returns a C int. Raises OSError(EINVAL)
    when libc does not export the symbol.
    """
    try:
        func = getattr(_libc, name)
    except AttributeError:
        raise OSError(errno.EINVAL,
                      "%s is not supported on this platform" % name)
    func.argtypes = argtypes
    func.restype = ctypes.c_int
    return func


# Bound in the same order as before: unshare, setns, mount.
_unshare = _bind_libc("unshare", [ctypes.c_int])
_setns = _bind_libc("setns", [ctypes.c_int, ctypes.c_int])
_mount = _bind_libc("mount", [
    ctypes.c_char_p,
    ctypes.c_char_p,
    ctypes.c_char_p,
    ctypes.c_ulong,
    ctypes.c_void_p,
])


def _mount_new_proc():
    """
    Remount "/" with MS_SLAVE | MS_REC, then mount a fresh proc
    filesystem on /proc.

    Raises OSError carrying the libc errno as soon as either mount()
    call returns non-zero.
    """
    mount_calls = (
        # Change propagation of the whole tree under "/" first ...
        (None, b"/", None, MS_SLAVE | MS_REC, None),
        # ... then put a new procfs instance on /proc.
        (b'proc', b'/proc', b'proc', 0, None),
    )
    for call_args in mount_calls:
        if _mount(*call_args) == 0:
            continue
        code = ctypes.get_errno()
        raise OSError(code, errno.errorcode[code])


def _wait_for_process_status(criu_pid):
    """
    Wait for CRIU to exit and report the status back.

    Reaps children with os.wait() until the one whose pid equals
    `criu_pid` is collected; other children are reaped and ignored.

    Returns:
      * the child's exit status when it exited normally;
      * 128 + signal number when it was terminated by a signal
        (shell convention), so a killed CRIU is never reported as 0;
      * -251 when os.wait() raises OSError (for example because there
        is no child left to wait for).
    """
    while True:
        # Keep the try body minimal: only os.wait() is expected to fail.
        try:
            (pid, status) = os.wait()
        except OSError:
            return -251

        if pid != criu_pid:
            continue

        # WEXITSTATUS is only meaningful for a normal exit; for a
        # signal-terminated child it would yield a misleading 0.
        if os.WIFSIGNALED(status):
            return 128 + os.WTERMSIG(status)
        return os.WEXITSTATUS(status)


def run_criu(args):
    """
    Replace the current process with the CRIU binary.

    `args` is the list of command-line arguments handed to `criu`
    (without the program name). The binary is looked up via PATH.

    This function does not return: os.execlp() either replaces the
    process image or raises OSError (e.g. when `criu` cannot be found).
    """
    print(sys.argv)
    # exec*() does not flush Python-level buffers; flush explicitly so
    # the line printed above is not lost when stdout is a pipe or file.
    sys.stdout.flush()
    sys.stderr.flush()
    os.execlp('criu', 'criu', *args)


def wrap_restore():
    """
    Run `criu restore` in freshly unshared PID and mount namespaces.

    The arguments are taken from sys.argv[1:]. '-d' and
    '--restore-detached' are stripped before CRIU is started; when one
    of them was given, 0 is returned right after forking instead of
    waiting for CRIU.

    Returns 0 in detached mode, otherwise the value reported by
    _wait_for_process_status(). Raises OSError when
    '--restore-sibling' is requested or unshare() fails.
    """
    criu_args = sys.argv[1:]
    if '--restore-sibling' in criu_args:
        raise OSError(errno.EINVAL, "--restore-sibling is not supported")

    # New pid and mount namespaces; the pid one takes effect for the
    # child forked below.
    rc = _unshare(CLONE_NEWNS | CLONE_NEWPID)
    if rc != 0:
        err = ctypes.get_errno()
        raise OSError(err, errno.errorcode[err])

    detached = False
    for flag in ('-d', '--restore-detached'):
        if flag in criu_args:
            detached = True
            criu_args.remove(flag)

    child = os.fork()
    if child == 0:
        _mount_new_proc()
        run_criu(criu_args)

    return 0 if detached else _wait_for_process_status(child)


def get_varg(args):
    """
    Find the value following the first occurrence of an option in sys.argv.

    `args` is a collection of option spellings (e.g. ('-t', '--tree')).
    sys.argv[0] is never considered.

    Returns (value, index_of_value), or (None, None) when no listed
    option is present or the first match is the last word on the
    command line.
    """
    argv = sys.argv
    for pos, word in enumerate(argv[1:], 1):
        if word not in args:
            continue

        value_pos = pos + 1
        if value_pos < len(argv):
            return (argv[value_pos], value_pos)
        # Option is the final word: it has no value.
        break

    return (None, None)


def _set_namespace(fd):
    """
    Join the namespace that the open descriptor `fd` refers to.

    Calls setns(fd, 0) and raises OSError carrying the libc errno
    when the call returns non-zero.
    """
    if _setns(fd, 0) == 0:
        return
    err = ctypes.get_errno()
    raise OSError(err, errno.errorcode[err])


def is_my_namespace(fd, ns='pid'):
    """
    Tell whether `fd` refers to a namespace other than the caller's.

    NOTE: despite its name, this returns True when the namespace behind
    `fd` is *different* from the calling process' own namespace, i.e.
    when a setns() is needed; existing callers rely on this polarity.

    fd -- open descriptor of a /proc/<pid>/ns/<ns> file
    ns -- namespace type to compare against, as named under
          /proc/self/ns ('pid', 'mnt', ...). Defaults to 'pid'.

    The comparison is done on inode numbers of the two namespace files.
    """
    own_ino = os.stat('/proc/self/ns/%s' % ns).st_ino
    return own_ino != os.fstat(fd).st_ino


def set_pidns(tpid, pid_idx):
    """
    Join pid namespace. Note, that the given pid should
    be changed in -t option, as task lives in different
    pid namespace.

    tpid    -- target pid as given on the command line (a string)
    pid_idx -- index in sys.argv holding that pid; it is overwritten
               with the pid the task has inside its own pid namespace

    Nothing is done when the target already shares our pid namespace.

    Raises OSError(ENOENT) when /proc/<tpid>/status has no NSpid line,
    OSError(ESRCH) when the first NSpid entry does not match `tpid`,
    and whatever os.open()/open()/setns() raise.
    """
    ns_fd = os.open('/proc/%s/ns/pid' % tpid, os.O_RDONLY)
    try:
        # is_my_namespace() is true when the namespaces differ.
        if not is_my_namespace(ns_fd):
            return

        # Close the status file deterministically instead of leaking it.
        with open('/proc/%s/status' % tpid) as status:
            for line in status:
                if line.startswith('NSpid:'):
                    ns_pids = line.split()[1:]
                    break
            else:
                raise OSError(errno.ENOENT,
                              'Cannot find NSpid field in proc')

        if not ns_pids or ns_pids[0] != tpid:
            raise OSError(errno.ESRCH, 'No such pid')

        # NSpid lists the pid per nesting level, outermost first. The
        # namespace we are about to join is the task's own (innermost)
        # one, so take the last entry; this also avoids an IndexError
        # when only a single level is listed.
        inner_pid = ns_pids[-1]
        print('Replace pid {} with {}'.format(tpid, inner_pid))
        sys.argv[pid_idx] = inner_pid

        _set_namespace(ns_fd)
    finally:
        # Always release the descriptor, including on error paths.
        os.close(ns_fd)


def set_mntns(tpid):
    """
    Join mount namespace. Trick here too -- check / and .
    will be the same in target mntns.

    tpid -- pid of the task whose mount namespace is joined

    Nothing is done when the target already shares our mount namespace.
    After joining, the previous working directory path is re-entered and
    both "/" and "." must resolve to the same (st_dev, st_ino) as before.

    Raises OSError(EXDEV) when "/" or "." differ in the target
    namespace, and whatever os.open()/os.stat()/os.chdir()/setns() raise.
    """
    ns_fd = os.open('/proc/%s/ns/mnt' % tpid, os.O_RDONLY)
    try:
        # Compare against our own *mount* namespace; comparing with the
        # pid namespace inode would always report a difference.
        own_ino = os.stat('/proc/self/ns/mnt').st_ino
        if own_ino == os.fstat(ns_fd).st_ino:
            return

        root_st = os.stat('/')
        cwd_st = os.stat('.')
        cwd_path = os.path.realpath('.')

        _set_namespace(ns_fd)

        # Go back to the same path inside the new mount namespace.
        os.chdir(cwd_path)
        root_nst = os.stat('/')
        cwd_nst = os.stat('.')

        def steq(st, nst):
            return (st.st_dev, st.st_ino) == (nst.st_dev, nst.st_ino)

        if not steq(root_st, root_nst):
            raise OSError(errno.EXDEV, 'Target ns / is not as current')
        if not steq(cwd_st, cwd_nst):
            raise OSError(errno.EXDEV, 'Target ns . is not as current')
    finally:
        # Always release the descriptor, including on error paths.
        os.close(ns_fd)


def wrap_dump():
    """
    Run `criu dump` / `criu pre-dump` from inside the target's namespaces.

    The target pid is taken from the -t/--tree option in sys.argv. The
    pid and mount namespaces of that task are joined first (set_pidns()
    may rewrite the pid in sys.argv), then CRIU is started in a forked
    child with sys.argv[1:].

    Returns the value reported by _wait_for_process_status() for the
    CRIU child. Raises OSError(EINVAL) when no -t/--tree value is given.
    """
    (pid, pid_idx) = get_varg(('-t', '--tree'))
    if pid is None:
        raise OSError(errno.EINVAL, 'No --tree option given')

    set_pidns(pid, pid_idx)
    set_mntns(pid)

    criu_pid = os.fork()
    if criu_pid == 0:
        run_criu(sys.argv[1:])
    # Wait for the forked CRIU process, not for the dump target: `pid`
    # is the target's pid string from argv and never matches os.wait().
    return _wait_for_process_status(criu_pid)


def show_usage():
    """Print the usage summary, with sys.argv[0] as the program name."""
    prog = sys.argv[0]
    lines = [
        "",
        "Usage:",
        "  {0} dump|pre-dump -t PID [<options>]".format(prog),
        "  {0} restore [<options>]".format(prog),
        "",
        "Commands:",
        "  dump           checkpoint a process/tree identified by pid",
        "  pre-dump       pre-dump task(s) minimizing their frozen time",
        "  restore        restore a process/tree",
        "  check          checks whether the kernel support is up-to-date",
        "",
    ]
    print("\n".join(lines))


if __name__ == "__main__":
    # Script entry point: dispatch on the first command-line word and
    # exit with the status produced by the selected wrapper.
    if len(sys.argv) == 1:
        show_usage()
        # sys.exit() rather than the site-provided exit() builtin, which
        # is not available when the interpreter runs without `site`.
        sys.exit(1)

    action = sys.argv[1]
    if action == 'restore':
        res = wrap_restore()
    elif action in ('dump', 'pre-dump'):
        res = wrap_dump()
    elif action == 'check':
        # run_criu() exec()s criu: it replaces this process or raises,
        # so control never reaches the sys.exit() below on this branch.
        run_criu(sys.argv[1:])
    else:
        print('Unsupported action {} for criu-ns'.format(action))
        res = -1

    sys.exit(res)
