]>
Commit | Line | Data |
---|---|---|
e0a84778 VB |
1 | import contextlib |
2 | import ctypes | |
3 | import errno | |
4 | import os | |
5 | import pyroute2 | |
6 | import pytest | |
7 | import signal | |
8 | ||
# All allowed namespace types, mapped to the clone(2) flag creating each one.
NAMESPACE_FLAGS = {
    'mnt': 0x00020000,   # CLONE_NEWNS
    'uts': 0x04000000,   # CLONE_NEWUTS
    'ipc': 0x08000000,   # CLONE_NEWIPC
    'user': 0x10000000,  # CLONE_NEWUSER
    'pid': 0x20000000,   # CLONE_NEWPID
    'net': 0x40000000,   # CLONE_NEWNET
}

# Size in bytes (1 MiB) of the stack buffer handed to the cloned child.
STACKSIZE = 1 << 20

# The C library; use_errno=True lets ctypes.get_errno() report errno
# after each failing call.
libc = ctypes.CDLL('libc.so.6', use_errno=True)
19 | ||
20 | ||
@contextlib.contextmanager
def keep_directory():
    """Run the ``with`` body, then chdir back to the directory current at entry.

    The directory is restored whether the body completes or raises.
    """
    with contextlib.ExitStack() as cleanup:
        # Capture the cwd now; chdir back to it when the stack unwinds.
        cleanup.callback(os.chdir, os.getcwd())
        yield
29 | ||
30 | ||
08e05799 VB |
def mount_sys(target="/sys"):
    """Mount a fresh sysfs on *target*, then change its propagation type.

    Three mount(2) calls are issued in order. The first mounts sysfs with
    MS_NOSUID | MS_NODEV | MS_NOEXEC. The next two pass only MS_PRIVATE,
    then only MS_SLAVE, which mount(2) treats as a propagation change on
    the existing mount at *target*.

    Raises:
        OSError: carrying the errno of the first call that fails.
    """
    mount_steps = (
        2 | 4 | 8,  # MS_NOSUID | MS_NODEV | MS_NOEXEC
        1 << 18,    # MS_PRIVATE
        1 << 19,    # MS_SLAVE
    )
    mountpoint = target.encode('ascii')
    for mountflags in mount_steps:
        if libc.mount(b"none", mountpoint, b"sysfs", mountflags, None) != -1:
            continue
        err = ctypes.get_errno()
        raise OSError(err, os.strerror(err))
44 | ||
45 | ||
e0a84778 VB |
class Namespace(object):
    """Combine several namespaces into one.

    This gets a list of namespace types to create and combine into one. The
    combined namespace can be used as a context manager to enter all the
    created namespaces and exit them at the end.

    A helper child is cloned into fresh namespaces of the requested types.
    File descriptors to its namespaces (``next``) and to ours
    (``previous``) are kept and handed to setns(2). Unless a PID namespace
    is requested, the child is told to exit and is reaped before the
    constructor returns. Otherwise it stays blocked on a pipe until
    ``__del__`` closes the write end.
    """

    def __init__(self, *namespaces):
        # Create everything __del__ releases before anything can fail, so
        # that an error in this constructor never makes __del__ raise too.
        self.next = []
        self.previous = []
        self.pipe = None

        self.namespaces = namespaces
        for ns in namespaces:
            assert ns in NAMESPACE_FLAGS

        # Get a pipe to signal the future child to exit
        self.pipe = os.pipe()

        # First, create a child in the given namespaces. Only the child
        # reads from the pipe, and once clone() has returned it holds its
        # own copy of the read end. So ours is closed right away, whether
        # clone() succeeded or not.
        try:
            pid = self._clone()
        finally:
            os.close(self.pipe[0])

        try:
            # If a user namespace, map UID 0 to the current one
            if 'user' in namespaces:
                self._map_root(pid)

            # Retrieve file descriptors to the new namespaces, then keep
            # file descriptors to our old ones. Append them one at a time
            # so that those opened before an error are still closed by
            # __del__.
            for x in namespaces:
                self.next.append(os.open(
                    '/proc/{}/ns/{}'.format(pid, x), os.O_RDONLY))
            for x in namespaces:
                self.previous.append(os.open(
                    '/proc/self/ns/{}'.format(x), os.O_RDONLY))
        except BaseException:
            # Don't leave the child blocked on the pipe forever: let it
            # exit and reap it before propagating the error.
            self._release_child(pid)
            raise

        # Tell the child all is done and let it die. With a PID namespace
        # the child is that namespace's first process and must stay alive
        # (see pid_namespaces(7)). In that case __del__ closes the pipe.
        if 'pid' not in namespaces:
            self._release_child(pid)

    def _clone(self):
        """Clone a child running :meth:`child` in the requested namespaces.

        Returns the child's PID. Raises OSError if clone(2) fails.
        """
        child = ctypes.CFUNCTYPE(ctypes.c_int)(self.child)
        child_stack = ctypes.create_string_buffer(STACKSIZE)
        # clone(2) expects the topmost address of the stack (stacks grow
        # downward on most architectures Linux runs on).
        child_stack_pointer = ctypes.c_void_p(
            ctypes.cast(child_stack,
                        ctypes.c_void_p).value + STACKSIZE)
        flags = signal.SIGCHLD
        for ns in self.namespaces:
            flags |= NAMESPACE_FLAGS[ns]
        pid = libc.clone(child, child_stack_pointer, flags)
        if pid == -1:
            e = ctypes.get_errno()
            raise OSError(e, os.strerror(e))
        return pid

    @staticmethod
    def _map_root(pid):
        """Map UID/GID 0 in the user namespace of *pid* to our own IDs."""
        uid_map = '0 {} 1'.format(os.getuid())
        gid_map = '0 {} 1'.format(os.getgid())
        with open('/proc/{}/uid_map'.format(pid), 'w') as f:
            f.write(uid_map)
        # setgroups must be denied before gid_map may be written by an
        # unprivileged process (see user_namespaces(7)).
        with open('/proc/{}/setgroups'.format(pid), 'w') as f:
            f.write('deny')
        with open('/proc/{}/gid_map'.format(pid), 'w') as f:
            f.write(gid_map)

    def _release_child(self, pid):
        """Close our end of the pipe so the child exits, then reap it."""
        os.close(self.pipe[1])
        self.pipe = None
        os.waitpid(pid, 0)

    def __del__(self):
        # getattr(): the object may never have been fully initialised.
        fds = getattr(self, 'next', []) + getattr(self, 'previous', [])
        pipe = getattr(self, 'pipe', None)
        # Forget everything first, so that a second call cannot close
        # descriptor numbers that have since been reused elsewhere.
        self.next, self.previous, self.pipe = [], [], None
        for fd in fds:
            os.close(fd)
        if pipe is not None:
            # Still set only for a PID namespace: closing the write end
            # lets the child exit. It is not waited for here.
            os.close(pipe[1])

    def child(self):
        """Cloned child.

        Set up the new namespaces, then stay alive until our parent has
        extracted the namespace file descriptors from us and closes its
        end of the pipe.

        """
        os.close(self.pipe[1])

        # For a network namespace, enable lo
        if 'net' in self.namespaces:
            ipr = pyroute2.IPRoute()
            lo = ipr.link_lookup(ifname='lo')[0]
            ipr.link('set', index=lo, state='up')
        # For a mount namespace, make it private
        if 'mnt' in self.namespaces:
            # MS_REC | MS_PRIVATE; the return value is not checked.
            libc.mount(b"none", b"/", None,
                       16384 | (1 << 18),
                       None)
            mount_sys()

        # Block until EOF (or an unexpected error) on the pipe, retrying
        # reads that fail with EAGAIN/EINTR.
        while True:
            try:
                os.read(self.pipe[0], 1)
            except OSError as e:
                if e.errno in [errno.EAGAIN, errno.EINTR]:
                    continue
            break

        # Exit immediately, without running Python cleanup handlers
        # inherited from the parent.
        os._exit(0)

    def fd(self, namespace):
        """Return the file descriptor associated to a namespace"""
        assert namespace in self.namespaces
        return self.next[self.namespaces.index(namespace)]

    def __enter__(self):
        with keep_directory():
            for i, (ns, n) in enumerate(zip(self.namespaces, self.next)):
                if libc.setns(n, 0) == -1:
                    # Read errno before the rollback below overwrites it.
                    e = ctypes.get_errno()
                    # __exit__ is not called when __enter__ raises: switch
                    # back (best effort) to the namespaces already left so
                    # we don't stay half inside this combined namespace.
                    for p in reversed(self.previous[:i]):
                        libc.setns(p, 0)
                    raise OSError(e, os.strerror(e), ns)

    def __exit__(self, *exc):
        with keep_directory():
            err = None
            # Go back in the reverse order and try every namespace, but
            # report only the first failure.
            for ns, p in reversed(list(zip(self.namespaces,
                                           self.previous))):
                if libc.setns(p, 0) == -1 and err is None:
                    e = ctypes.get_errno()
                    err = OSError(e, os.strerror(e), ns)
            if err is not None:
                raise err

    def __repr__(self):
        return 'Namespace({})'.format(", ".join(self.namespaces))
168 | ||
169 | ||
class NamespaceFactory(object):
    """Create namespaces on demand and hand them out by name.

    Every name maps to one Namespace combining the IPC, network, mount and
    UTS namespace types. PID is a bit special because a process has to be
    kept for it. We don't do that here, so that everything is cleaned up
    automatically: the helper child exits as soon as we have file
    descriptors to its namespaces. A user namespace is not used either,
    because we are unlikely to be able to exit it.
    """

    def __init__(self):
        self.namespaces = {}

    def __call__(self, ns):
        """Return the namespace called *ns*, creating it on first use."""
        try:
            return self.namespaces[ns]
        except KeyError:
            pass
        created = Namespace('ipc', 'net', 'mnt', 'uts')
        self.namespaces[ns] = created
        return created
191 | ||
192 | ||
@pytest.fixture
def namespaces():
    """Provide a fresh NamespaceFactory to each test requesting it."""
    factory = NamespaceFactory()
    return factory