import threading
import warnings
+from . import AuthenticationError
from . import connection
from . import process
from .context import reduction
MAXFDS_TO_SEND = 256
SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t
+_AUTHKEY_LEN = 32 # <= PIPEBUF so it fits a single write to an empty pipe.
#
# Forkserver class
class ForkServer(object):
def __init__(self):
+ self._forkserver_authkey = None
self._forkserver_address = None
self._forkserver_alive_fd = None
self._forkserver_pid = None
if not util.is_abstract_socket_namespace(self._forkserver_address):
os.unlink(self._forkserver_address)
self._forkserver_address = None
+ self._forkserver_authkey = None
def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.'''
process data.
'''
self.ensure_running()
+ assert self._forkserver_authkey
if len(fds) + 4 >= MAXFDS_TO_SEND:
raise ValueError('too many fds')
with socket.socket(socket.AF_UNIX) as client:
resource_tracker.getfd()]
allfds += fds
try:
+ client.setblocking(True)
+ wrapped_client = connection.Connection(client.fileno())
+ # The other side of this exchange happens in the child as
+ # implemented in main().
+ try:
+ connection.answer_challenge(
+ wrapped_client, self._forkserver_authkey)
+ connection.deliver_challenge(
+ wrapped_client, self._forkserver_authkey)
+ finally:
+ wrapped_client._detach()
+ del wrapped_client
reduction.sendfds(client, allfds)
return parent_r, parent_w
except:
return
# dead, launch it again
os.close(self._forkserver_alive_fd)
+ self._forkserver_authkey = None
self._forkserver_address = None
self._forkserver_alive_fd = None
self._forkserver_pid = None
if self._preload_modules:
desired_keys = {'main_path', 'sys_path'}
data = spawn.get_preparation_data('ignore')
- data = {x: y for x, y in data.items() if x in desired_keys}
+ main_kws = {x: y for x, y in data.items() if x in desired_keys}
else:
- data = {}
+ main_kws = {}
with socket.socket(socket.AF_UNIX) as listener:
address = connection.arbitrary_address('AF_UNIX')
# all client processes own the write end of the "alive" pipe;
# when they all terminate the read end becomes ready.
alive_r, alive_w = os.pipe()
+ # A short lived pipe to initialize the forkserver authkey.
+ authkey_r, authkey_w = os.pipe()
try:
- fds_to_pass = [listener.fileno(), alive_r]
+ fds_to_pass = [listener.fileno(), alive_r, authkey_r]
+ main_kws['authkey_r'] = authkey_r
cmd %= (listener.fileno(), alive_r, self._preload_modules,
- data)
+ main_kws)
exe = spawn.get_executable()
args = [exe] + util._args_from_interpreter_flags()
args += ['-c', cmd]
pid = util.spawnv_passfds(exe, args, fds_to_pass)
except:
os.close(alive_w)
+ os.close(authkey_w)
raise
finally:
os.close(alive_r)
+ os.close(authkey_r)
+ # Authenticate our control socket to prevent access from
+ # processes we have not shared this key with.
+ try:
+ self._forkserver_authkey = os.urandom(_AUTHKEY_LEN)
+ os.write(authkey_w, self._forkserver_authkey)
+ finally:
+ os.close(authkey_w)
self._forkserver_address = address
self._forkserver_alive_fd = alive_w
self._forkserver_pid = pid
#
#
-def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
- '''Run forkserver.'''
+def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
+ *, authkey_r=None):
+ """Run forkserver."""
+ if authkey_r is not None:
+ try:
+ authkey = os.read(authkey_r, _AUTHKEY_LEN)
+ assert len(authkey) == _AUTHKEY_LEN, f'{len(authkey)} < {_AUTHKEY_LEN}'
+ finally:
+ os.close(authkey_r)
+ else:
+ authkey = b''
+
if preload:
if sys_path is not None:
sys.path[:] = sys_path
if listener in rfds:
# Incoming fork request
with listener.accept()[0] as s:
- # Receive fds from client
- fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
+ try:
+ if authkey:
+ wrapped_s = connection.Connection(s.fileno())
+ # The other side of this exchange happens in
+ # in connect_to_new_process().
+ try:
+ connection.deliver_challenge(
+ wrapped_s, authkey)
+ connection.answer_challenge(
+ wrapped_s, authkey)
+ finally:
+ wrapped_s._detach()
+ del wrapped_s
+ # Receive fds from client
+ fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
+ except (EOFError, BrokenPipeError, AuthenticationError):
+ s.close()
+ continue
if len(fds) > MAXFDS_TO_SEND:
raise RuntimeError(
"Too many ({0:n}) fds to send".format(
__all__ += ['DupFd', 'sendfds', 'recvfds']
import array
- # On MacOSX we should acknowledge receipt of fds -- see Issue14669
- ACKNOWLEDGE = sys.platform == 'darwin'
-
def sendfds(sock, fds):
'''Send an array of fds over an AF_UNIX socket.'''
fds = array.array('i', fds)
msg = bytes([len(fds) % 256])
sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
- if ACKNOWLEDGE and sock.recv(1) != b'A':
+ if sock.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd')
def recvfds(sock, size):
if not msg and not ancdata:
raise EOFError
try:
- if ACKNOWLEDGE:
- sock.send(b'A')
+ # We send/recv an Ack byte after the fds to work around an old
+ # macOS bug; it isn't clear if this is still required but it
+ # makes unit testing fd sending easier.
+ # See: https://github.com/python/cpython/issues/58874
+ sock.send(b'A') # Acknowledge
if len(ancdata) != 1:
raise RuntimeError('received %d items of ancdata' %
len(ancdata))
finally:
setattr(sys, stream_name, old_stream)
- @classmethod
- def _sleep_and_set_event(self, evt, delay=0.0):
+ @staticmethod
+ def _sleep_and_set_event(evt, delay=0.0):
time.sleep(delay)
evt.set()
if os.name != 'nt':
self.check_forkserver_death(signal.SIGKILL)
+ def test_forkserver_auth_is_enabled(self):
+ if self.TYPE == "threads":
+ self.skipTest(f"test not appropriate for {self.TYPE}")
+ if multiprocessing.get_start_method() != "forkserver":
+ self.skipTest("forkserver start method specific")
+
+ forkserver = multiprocessing.forkserver._forkserver
+ forkserver.ensure_running()
+ self.assertTrue(forkserver._forkserver_pid)
+ authkey = forkserver._forkserver_authkey
+ self.assertTrue(authkey)
+ self.assertGreater(len(authkey), 15)
+ addr = forkserver._forkserver_address
+ self.assertTrue(addr)
+
+ # Demonstrate that a raw auth handshake, as Client performs, does not
+ # raise an error.
+ client = multiprocessing.connection.Client(addr, authkey=authkey)
+ client.close()
+
+ # That worked, now launch a quick process.
+ proc = self.Process(target=sys.exit)
+ proc.start()
+ proc.join()
+ self.assertEqual(proc.exitcode, 0)
+
+ def test_forkserver_without_auth_fails(self):
+ if self.TYPE == "threads":
+ self.skipTest(f"test not appropriate for {self.TYPE}")
+ if multiprocessing.get_start_method() != "forkserver":
+ self.skipTest("forkserver start method specific")
+
+ forkserver = multiprocessing.forkserver._forkserver
+ forkserver.ensure_running()
+ self.assertTrue(forkserver._forkserver_pid)
+ authkey_len = len(forkserver._forkserver_authkey)
+ with unittest.mock.patch.object(
+ forkserver, '_forkserver_authkey', None):
+ # With an incorrect authkey we should get an auth rejection
+ # rather than the above protocol error.
+ forkserver._forkserver_authkey = b'T' * authkey_len
+ proc = self.Process(target=sys.exit)
+ with self.assertRaises(multiprocessing.AuthenticationError):
+ proc.start()
+ del proc
+
+ # authkey restored, launching processes should work again.
+ proc = self.Process(target=sys.exit)
+ proc.start()
+ proc.join()
#
#