]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
WIP namespaces.py
authorNicki Křížek <nicki@isc.org>
Mon, 24 Jun 2024 14:30:09 +0000 (16:30 +0200)
committerNicki Křížek <nicki@isc.org>
Mon, 24 Jun 2024 14:30:09 +0000 (16:30 +0200)
bin/tests/system/namespaces.py [new file with mode: 0644]

diff --git a/bin/tests/system/namespaces.py b/bin/tests/system/namespaces.py
new file mode 100644 (file)
index 0000000..f71be58
--- /dev/null
@@ -0,0 +1,248 @@
+import contextlib
+import ctypes
+import errno
+import os
+import pyroute2
+import pytest
+import signal
+import multiprocessing
+
+# TODO move to contrib
+
+# All allowed namespace types
+NAMESPACE_FLAGS = dict(
+    mnt=0x00020000,
+    uts=0x04000000,
+    ipc=0x08000000,
+    user=0x10000000,
+    pid=0x20000000,
+    net=0x40000000,
+)
+STACKSIZE = 1024 * 1024
+
+libc = ctypes.CDLL("libc.so.6", use_errno=True)
+
+
+@contextlib.contextmanager
+def keep_directory():
+    """Restore the current directory on exit."""
+    pwd = os.getcwd()
+    try:
+        yield
+    finally:
+        os.chdir(pwd)
+
+
+def mount_sys(target="/sys"):
+    flags = [2 | 4 | 8]  # MS_NOSUID | MS_NODEV | MS_NOEXEC
+    flags.append(1 << 18)  # MS_PRIVATE
+    flags.append(1 << 19)  # MS_SLAVE
+    for fl in flags:
+        ret = libc.mount(b"none", target.encode("ascii"), b"sysfs", fl, None)
+        if ret == -1:
+            e = ctypes.get_errno()
+            raise OSError(e, os.strerror(e))
+
+
+def mount_tmpfs(target, private=False):
+    flags = [0]
+    if private:
+        flags.append(1 << 18)  # MS_PRIVATE
+        flags.append(1 << 19)  # MS_SLAVE
+    for fl in flags:
+        ret = libc.mount(b"none", target.encode("ascii"), b"tmpfs", fl, None)
+        if ret == -1:
+            e = ctypes.get_errno()
+            raise OSError(e, os.strerror(e))
+
+
+def _mount_proc(target):
+    flags = [2 | 4 | 8]  # MS_NOSUID | MS_NODEV | MS_NOEXEC
+    flags.append(1 << 18)  # MS_PRIVATE
+    flags.append(1 << 19)  # MS_SLAVE
+    for fl in flags:
+        ret = libc.mount(b"proc", target.encode("ascii"), b"proc", fl, None)
+        if ret == -1:
+            e = ctypes.get_errno()
+            raise OSError(e, os.strerror(e))
+
+
+def mount_proc(target="/proc"):
+    # We need to be sure /proc is correct. We do that in another
+    # process as this doesn't play well with setns().
+    if not os.path.isdir(target):
+        os.mkdir(target)
+    p = multiprocessing.Process(target=_mount_proc, args=(target,))
+    p.start()
+    p.join()
+
+
+class Namespace(object):
+    """Combine several namespaces into one.
+
+    This gets a list of namespace types to create and combine into one. The
+    combined namespace can be used as a context manager to enter all the
+    created namespaces and exit them at the end.
+    """
+
+    def __init__(self, *namespaces):
+        self.next = []
+        self.namespaces = namespaces
+        for ns in namespaces:
+            assert ns in NAMESPACE_FLAGS
+
+        # Get a pipe to signal the future child to exit
+        self.pipe = os.pipe()
+
+        # First, create a child in the given namespaces
+        child = ctypes.CFUNCTYPE(ctypes.c_int)(self.child)
+        child_stack = ctypes.create_string_buffer(STACKSIZE)
+        child_stack_pointer = ctypes.c_void_p(
+            ctypes.cast(child_stack, ctypes.c_void_p).value + STACKSIZE
+        )
+        flags = signal.SIGCHLD
+        for ns in namespaces:
+            flags |= NAMESPACE_FLAGS[ns]
+        self.pid = libc.clone(child, child_stack_pointer, flags)
+        if self.pid == -1:
+            e = ctypes.get_errno()
+            raise OSError(e, os.strerror(e))
+
+        # If a user namespace, map UID 0 to the current one
+        if "user" in namespaces:
+            uid_map = "0 {} 1".format(os.getuid())
+            gid_map = "0 {} 1".format(os.getgid())
+            print(uid_map)
+            with open("/proc/{}/uid_map".format(self.pid), "w") as f:
+                f.write(uid_map)
+            with open("/proc/{}/setgroups".format(self.pid), "w") as f:
+                f.write("deny")
+            with open("/proc/{}/gid_map".format(self.pid), "w") as f:
+                f.write(gid_map)
+
+        # Retrieve a file descriptor to this new namespace
+        self.next = [
+            os.open("/proc/{}/ns/{}".format(self.pid, x), os.O_RDONLY)
+            for x in namespaces
+        ]
+
+        # Keep a file descriptor to our old namespaces
+        self.previous = [
+            os.open("/proc/self/ns/{}".format(x), os.O_RDONLY) for x in namespaces
+        ]
+
+        # Tell the child all is done and let it die
+        os.close(self.pipe[0])
+        if "pid" not in self.namespaces:
+            os.close(self.pipe[1])
+            self.pipe = None
+            os.waitpid(self.pid, 0)
+
+    def __del__(self):
+        for fd in self.next:
+            os.close(fd)
+        for fd in self.previous:
+            os.close(fd)
+        if self.pipe is not None:
+            os.close(self.pipe[1])
+
+    def child(self):
+        """Cloned child.
+
+        Just be here until our parent extract the file descriptor from
+        us.
+
+        """
+        os.close(self.pipe[1])
+
+        # For a network namespace, enable lo
+        if "net" in self.namespaces:
+            with pyroute2.IPRoute() as ipr:
+                lo = ipr.link_lookup(ifname="lo")[0]
+                ipr.link("set", index=lo, state="up")
+        # For a mount namespace, make it private
+        if "mnt" in self.namespaces:
+            libc.mount(
+                b"none",
+                b"/",
+                None,
+                # MS_REC | MS_PRIVATE
+                16384 | (1 << 18),
+                None,
+            )
+
+        while True:
+            try:
+                os.read(self.pipe[0], 1)
+            except OSError as e:
+                if e.errno in [errno.EAGAIN, errno.EINTR]:
+                    continue
+            break
+
+        os._exit(0)
+
+    def fd(self, namespace):
+        """Return the file descriptor associated to a namespace"""
+        assert namespace in self.namespaces
+        return self.next[self.namespaces.index(namespace)]
+
+    def __enter__(self):
+        with keep_directory():
+            for n in self.next:
+                if libc.setns(n, 0) == -1:
+                    ns = self.namespaces[self.next.index(n)]  # NOQA
+                    e = ctypes.get_errno()
+                    raise OSError(e, os.strerror(e))
+
+    def __exit__(self, *exc):
+        # TODO remove -- we can't exit the namespaces anyway (yay user namespace!)
+        # with keep_directory():
+        #     err = None
+        #     for p in reversed(self.previous):
+        #         if libc.setns(p, 0) == -1 and err is None:
+        #             ns = self.namespaces[self.previous.index(p)]  # NOQA
+        #             e = ctypes.get_errno()
+        #             err = OSError(e, os.strerror(e))
+        #     if err:
+        #         raise err
+        pass
+
+    def __repr__(self):
+        return "Namespace({})".format(", ".join(self.namespaces))
+
+
+class NamespaceFactory(object):
+    """Dynamically create namespaces as they are created.
+
+    Those namespaces are namespaces for IPC, net, mount and UTS. PID
+    is a bit special as we have to keep a process for that. We don't
+    do that to ensure that everything is cleaned
+    automatically. Therefore, the child process is killed as soon as
+    we got a file descriptor to the namespace. We don't use a user
+    namespace either because we are unlikely to be able to exit it.
+
+    """
+
+    def __init__(self, tmpdir):
+        self.namespaces = {}
+        self.tmpdir = tmpdir
+
+    def __call__(self, ns):
+        """Return a namespace. Create it if it doesn't exist."""
+        if ns in self.namespaces:
+            return self.namespaces[ns]
+
+        self.namespaces[ns] = Namespace("ipc", "net", "mnt", "uts")
+        with self.namespaces[ns]:
+            mount_proc()
+            mount_sys()
+            # Also setup the "namespace-dependant" directory
+            self.tmpdir.join("ns").ensure(dir=True)
+            mount_tmpfs(str(self.tmpdir.join("ns")), private=True)
+
+        return self.namespaces[ns]
+
+
+@pytest.fixture
+def namespaces(tmpdir):
+    return NamespaceFactory(tmpdir)