"""
import collections
+import contextvars
import collections.abc
import concurrent.futures
import errno
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None
+ self._context = contextvars.copy_context()
def __repr__(self):
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout,
- self._ssl_shutdown_timeout)
+ self._ssl_shutdown_timeout, context=self._context)
def get_loop(self):
return self._loop
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
- call_connection_made=True):
+ call_connection_made=True,
+ context=None):
"""Create SSL transport."""
raise NotImplementedError
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_shutdown_timeout=None, context=None):
sock.setblocking(False)
+ context = context if context is not None else contextvars.copy_context()
protocol = protocol_factory()
waiter = self.create_future()
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_shutdown_timeout=ssl_shutdown_timeout,
+ context=context)
else:
- transport = self._make_socket_transport(sock, protocol, waiter)
+ transport = self._make_socket_transport(sock, protocol, waiter, context=context)
try:
await waiter
signal.set_wakeup_fd(self._csock.fileno())
def _make_socket_transport(self, sock, protocol, waiter=None,
- extra=None, server=None):
+ extra=None, server=None, context=None):
return _ProactorSocketTransport(self, sock, protocol, waiter,
extra, server)
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_shutdown_timeout=None, context=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=None,
- ssl_shutdown_timeout=None):
+ ssl_shutdown_timeout=None, context=None):
def loop(f=None):
try:
self._transports = weakref.WeakValueDictionary()
def _make_socket_transport(self, sock, protocol, waiter=None, *,
- extra=None, server=None):
+ extra=None, server=None, context=None):
self._ensure_fd_no_transport(sock)
return _SelectorSocketTransport(self, sock, protocol, waiter,
- extra, server)
+ extra, server, context=context)
def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
+ context=None,
):
self._ensure_fd_no_transport(rawsock)
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout
+ ssl_shutdown_timeout=ssl_shutdown_timeout,
)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
- extra=extra, server=server)
+ extra=extra, server=server, context=context)
return ssl_protocol._app_transport
def _make_datagram_transport(self, sock, protocol,
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
- ssl_handshake_timeout, ssl_shutdown_timeout)
+ ssl_handshake_timeout, ssl_shutdown_timeout, context)
def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout,
- ssl_shutdown_timeout)
+ ssl_shutdown_timeout, context)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
+ conn_context = context.copy() if context is not None else None
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
- ssl_handshake_timeout, ssl_shutdown_timeout)
- self.create_task(accept)
+ ssl_handshake_timeout, ssl_shutdown_timeout, context=conn_context)
+ self.create_task(accept, context=conn_context)
async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
- ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+ ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
protocol = None
transport = None
try:
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout,
- ssl_shutdown_timeout=ssl_shutdown_timeout)
+ ssl_shutdown_timeout=ssl_shutdown_timeout,
+ context=context)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
- server=server)
+ server=server, context=context)
try:
await waiter
f'File descriptor {fd!r} is used by transport '
f'{transport!r}')
- def _add_reader(self, fd, callback, *args):
+ def _add_reader(self, fd, callback, *args, context=None):
self._check_closed()
- handle = events.Handle(callback, args, self, None)
+ handle = events.Handle(callback, args, self, context=context)
key = self._selector.get_map().get(fd)
if key is None:
self._selector.register(fd, selectors.EVENT_READ,
else:
return False
- def _add_writer(self, fd, callback, *args):
+ def _add_writer(self, fd, callback, *args, context=None):
self._check_closed()
- handle = events.Handle(callback, args, self, None)
+ handle = events.Handle(callback, args, self, context=context)
key = self._selector.get_map().get(fd)
if key is None:
self._selector.register(fd, selectors.EVENT_WRITE,
# exception)
_sock = None
- def __init__(self, loop, sock, protocol, extra=None, server=None):
+ def __init__(self, loop, sock, protocol, extra=None, server=None, context=None):
super().__init__(extra, loop)
self._extra['socket'] = trsock.TransportSocket(sock)
try:
self._extra['peername'] = None
self._sock = sock
self._sock_fd = sock.fileno()
-
+ self._context = context
self._protocol_connected = False
self.set_protocol(protocol)
if not self._buffer:
self._conn_lost += 1
self._loop._remove_writer(self._sock_fd)
- self._loop.call_soon(self._call_connection_lost, None)
+ self._call_soon(self._call_connection_lost, None)
def __del__(self, _warn=warnings.warn):
if self._sock is not None:
self._closing = True
self._loop._remove_reader(self._sock_fd)
self._conn_lost += 1
- self._loop.call_soon(self._call_connection_lost, exc)
+ self._call_soon(self._call_connection_lost, exc)
def _call_connection_lost(self, exc):
try:
def _add_reader(self, fd, callback, *args):
if not self.is_reading():
return
- self._loop._add_reader(fd, callback, *args)
+ self._loop._add_reader(fd, callback, *args, context=self._context)
+ def _add_writer(self, fd, callback, *args):
+ self._loop._add_writer(fd, callback, *args, context=self._context)
+
+ def _call_soon(self, callback, *args):
+ self._loop.call_soon(callback, *args, context=self._context)
class _SelectorSocketTransport(_SelectorTransport):
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE
def __init__(self, loop, sock, protocol, waiter=None,
- extra=None, server=None):
-
+ extra=None, server=None, context=None):
self._read_ready_cb = None
- super().__init__(loop, sock, protocol, extra, server)
+ super().__init__(loop, sock, protocol, extra, server, context)
self._eof = False
self._empty_waiter = None
if _HAS_SENDMSG:
# decreases the latency (in some cases significantly.)
base_events._set_nodelay(self._sock)
- self._loop.call_soon(self._protocol.connection_made, self)
+ self._call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
- self._loop.call_soon(self._add_reader,
- self._sock_fd, self._read_ready)
+ self._call_soon(self._add_reader, self._sock_fd, self._read_ready)
if waiter is not None:
# only wake up the waiter when connection_made() has been called
- self._loop.call_soon(futures._set_result_unless_cancelled,
- waiter, None)
+ self._call_soon(futures._set_result_unless_cancelled, waiter, None)
def set_protocol(self, protocol):
if isinstance(protocol, protocols.BufferedProtocol):
if not data:
return
# Not all was written; register write handler.
- self._loop._add_writer(self._sock_fd, self._write_ready)
+ self._add_writer(self._sock_fd, self._write_ready)
# Add it to the buffer.
self._buffer.append(data)
self._write_ready()
# If the entire buffer couldn't be written, register a write handler
if self._buffer:
- self._loop._add_writer(self._sock_fd, self._write_ready)
+ self._add_writer(self._sock_fd, self._write_ready)
self._maybe_pause_protocol()
def can_write_eof(self):
super().__init__(loop, sock, protocol, extra)
self._address = address
self._buffer_size = 0
- self._loop.call_soon(self._protocol.connection_made, self)
+ self._call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
- self._loop.call_soon(self._add_reader,
- self._sock_fd, self._read_ready)
+ self._call_soon(self._add_reader, self._sock_fd, self._read_ready)
if waiter is not None:
# only wake up the waiter when connection_made() has been called
- self._loop.call_soon(futures._set_result_unless_cancelled,
- waiter, None)
+ self._call_soon(futures._set_result_unless_cancelled, waiter, None)
def get_write_buffer_size(self):
return self._buffer_size
self._sock.sendto(data, addr)
return
except (BlockingIOError, InterruptedError):
- self._loop._add_writer(self._sock_fd, self._sendto_ready)
+ self._add_writer(self._sock_fd, self._sendto_ready)
except OSError as exc:
self._protocol.error_received(exc)
return
server_side=False,
server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_shutdown_timeout=shutdown_timeout,
+ context=ANY)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
server_side=False,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_shutdown_timeout=shutdown_timeout,
+ context=ANY)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
server_side=False,
server_hostname='',
ssl_handshake_timeout=handshake_timeout,
- ssl_shutdown_timeout=shutdown_timeout)
+ ssl_shutdown_timeout=shutdown_timeout,
+ context=ANY)
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
- MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
+ MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY, mock.ANY)
def test_call_coroutine(self):
async def simple_coroutine():
--- /dev/null
+import asyncio
+import contextvars
+import unittest
+import sys
+
+from unittest import TestCase
+
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from test.test_asyncio import utils as test_utils
+
+def tearDownModule():
+ asyncio.events._set_event_loop_policy(None)
+
+class ServerContextvarsTestCase:
+ loop_factory = None # To be defined in subclasses
+ server_ssl_context = None # To be defined in subclasses for SSL tests
+ client_ssl_context = None # To be defined in subclasses for SSL tests
+
+ def run_coro(self, coro):
+ return asyncio.run(coro, loop_factory=self.loop_factory)
+
+ def test_start_server1(self):
+ # Test that asyncio.start_server captures the context at the time of server creation
+ async def test():
+ var = contextvars.ContextVar("var", default="default")
+
+ async def handle_client(reader, writer):
+ value = var.get()
+ writer.write(value.encode())
+ await writer.drain()
+ writer.close()
+
+ server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
+ ssl=self.server_ssl_context)
+ # change the value
+ var.set("after_server")
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr,
+ ssl=self.client_ssl_context)
+ data = await reader.read(100)
+ writer.close()
+ await writer.wait_closed()
+ return data.decode()
+
+ async with server:
+ addr = server.sockets[0].getsockname()
+ self.assertEqual(await client(addr), "default")
+
+ self.assertEqual(var.get(), "after_server")
+
+ self.run_coro(test())
+
+ def test_start_server2(self):
+ # Test that mutations to the context in one handler don't affect other handlers or the server's context
+ async def test():
+ var = contextvars.ContextVar("var", default="default")
+
+ async def handle_client(reader, writer):
+ value = var.get()
+ writer.write(value.encode())
+ var.set("in_handler")
+ await writer.drain()
+ writer.close()
+
+ server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
+ ssl=self.server_ssl_context)
+ var.set("after_server")
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr,
+ ssl=self.client_ssl_context)
+ data = await reader.read(100)
+ writer.close()
+ await writer.wait_closed()
+ return data.decode()
+
+ async with server:
+ addr = server.sockets[0].getsockname()
+ self.assertEqual(await client(addr), "default")
+ self.assertEqual(await client(addr), "default")
+ self.assertEqual(await client(addr), "default")
+
+ self.assertEqual(var.get(), "after_server")
+
+ self.run_coro(test())
+
+ def test_start_server3(self):
+ # Test that mutations to context in concurrent handlers don't affect each other or the server's context
+ async def test():
+ var = contextvars.ContextVar("var", default="default")
+ var.set("before_server")
+
+ async def handle_client(reader, writer):
+ writer.write(var.get().encode())
+ await writer.drain()
+ writer.close()
+
+ server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
+ ssl=self.server_ssl_context)
+ var.set("after_server")
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr,
+ ssl=self.client_ssl_context)
+ data = await reader.read(100)
+ self.assertEqual(data.decode(), "before_server")
+ writer.close()
+ await writer.wait_closed()
+
+ async with server:
+ addr = server.sockets[0].getsockname()
+ async with asyncio.TaskGroup() as tg:
+ for _ in range(100):
+ tg.create_task(client(addr))
+
+ self.assertEqual(var.get(), "after_server")
+
+ self.run_coro(test())
+
+ def test_create_server1(self):
+ # Test that loop.create_server captures the context at the time of server creation
+ # and that mutations to the context in protocol callbacks don't affect the server's context
+ async def test():
+ var = contextvars.ContextVar("var", default="default")
+
+ class EchoProtocol(asyncio.Protocol):
+ def connection_made(self, transport):
+ self.transport = transport
+ value = var.get()
+ var.set("in_handler")
+ self.transport.write(value.encode())
+ self.transport.close()
+
+ server = await asyncio.get_running_loop().create_server(
+ lambda: EchoProtocol(), '127.0.0.1', 0,
+ ssl=self.server_ssl_context)
+ var.set("after_server")
+
+ async def client(addr):
+ reader, writer = await asyncio.open_connection(*addr,
+ ssl=self.client_ssl_context)
+ data = await reader.read(100)
+ self.assertEqual(data.decode(), "default")
+ writer.close()
+ await writer.wait_closed()
+
+ async with server:
+ addr = server.sockets[0].getsockname()
+ await client(addr)
+
+ self.assertEqual(var.get(), "after_server")
+
+ self.run_coro(test())
+
+ def test_create_server2(self):
+ # Test that mutations to context in one protocol instance don't affect other instances or the server's context
+ async def test():
+ var = contextvars.ContextVar("var", default="default")
+
+ class EchoProtocol(asyncio.Protocol):
+ def __init__(self):
+ super().__init__()
+ assert var.get() == "default", var.get()
+ def connection_made(self, transport):
+ self.transport = transport
+ value = var.get()
+ var.set("in_handler")
+ self.transport.write(value.encode())
+ self.transport.close()
+
+ server = await asyncio.get_running_loop().create_server(
+ lambda: EchoProtocol(), '127.0.0.1', 0,
+ ssl=self.server_ssl_context)
+
+ var.set("after_server")
+
+ async def client(addr, expected):
+ reader, writer = await asyncio.open_connection(*addr,
+ ssl=self.client_ssl_context)
+ data = await reader.read(100)
+ self.assertEqual(data.decode(), expected)
+ writer.close()
+ await writer.wait_closed()
+
+ async with server:
+ addr = server.sockets[0].getsockname()
+ await client(addr, "default")
+ await client(addr, "default")
+
+ self.assertEqual(var.get(), "after_server")
+
+ self.run_coro(test())
+
+ def test_gh140947(self):
+ # See https://github.com/python/cpython/issues/140947
+
+ cvar1 = contextvars.ContextVar("cvar1")
+ cvar2 = contextvars.ContextVar("cvar2")
+ cvar3 = contextvars.ContextVar("cvar3")
+ results = {}
+ is_ssl = self.server_ssl_context is not None
+
+ def capture_context(meth):
+ result = []
+ for k,v in contextvars.copy_context().items():
+ if k.name.startswith("cvar"):
+ result.append((k.name, v))
+ results[meth] = sorted(result)
+
+ class DemoProtocol(asyncio.Protocol):
+ def __init__(self, on_conn_lost):
+ self.transport = None
+ self.on_conn_lost = on_conn_lost
+ self.tasks = set()
+
+ def connection_made(self, transport):
+ capture_context("connection_made")
+ self.transport = transport
+
+ def data_received(self, data):
+ capture_context("data_received")
+
+ task = asyncio.create_task(self.asgi())
+ self.tasks.add(task)
+ task.add_done_callback(self.tasks.discard)
+
+ self.transport.pause_reading()
+
+ def connection_lost(self, exc):
+ capture_context("connection_lost")
+ if not self.on_conn_lost.done():
+ self.on_conn_lost.set_result(True)
+
+ async def asgi(self):
+ capture_context("asgi start")
+ cvar1.set(True)
+ # make sure that we only resume after the pause
+ # otherwise the resume does nothing
+ if is_ssl:
+ while not self.transport._ssl_protocol._app_reading_paused:
+ await asyncio.sleep(0.01)
+ else:
+ while not self.transport._paused:
+ await asyncio.sleep(0.01)
+ cvar2.set(True)
+ self.transport.resume_reading()
+ cvar3.set(True)
+ capture_context("asgi end")
+
+ async def main():
+ loop = asyncio.get_running_loop()
+ on_conn_lost = loop.create_future()
+
+ server = await loop.create_server(
+ lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
+ ssl=self.server_ssl_context)
+ async with server:
+ addr = server.sockets[0].getsockname()
+ reader, writer = await asyncio.open_connection(*addr,
+ ssl=self.client_ssl_context)
+ writer.write(b"anything")
+ await writer.drain()
+ writer.close()
+ await writer.wait_closed()
+ await on_conn_lost
+
+ self.run_coro(main())
+ self.assertDictEqual(results, {
+ "connection_made": [],
+ "data_received": [],
+ "asgi start": [],
+ "asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
+ "connection_lost": [],
+ })
+
+
+class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
+ loop_factory = staticmethod(asyncio.new_event_loop)
+
+@unittest.skipUnless(ssl, "SSL not available")
+class AsyncioEventLoopSSLTests(AsyncioEventLoopTests):
+ def setUp(self):
+ super().setUp()
+ self.server_ssl_context = test_utils.simple_server_sslcontext()
+ self.client_ssl_context = test_utils.simple_client_sslcontext()
+
+if sys.platform == "win32":
+ class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
+ loop_factory = asyncio.ProactorEventLoop
+
+ class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
+ loop_factory = asyncio.SelectorEventLoop
+
+ @unittest.skipUnless(ssl, "SSL not available")
+ class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests):
+ def setUp(self):
+ super().setUp()
+ self.server_ssl_context = test_utils.simple_server_sslcontext()
+ self.client_ssl_context = test_utils.simple_client_sslcontext()
+
+ @unittest.skipUnless(ssl, "SSL not available")
+ class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests):
+ def setUp(self):
+ super().setUp()
+ self.server_ssl_context = test_utils.simple_server_sslcontext()
+ self.client_ssl_context = test_utils.simple_client_sslcontext()
+
+if __name__ == "__main__":
+ unittest.main()
else: # pragma: no cover
raise AssertionError("Time generator is not finished")
- def _add_reader(self, fd, callback, *args):
- self.readers[fd] = events.Handle(callback, args, self, None)
+ def _add_reader(self, fd, callback, *args, context=None):
+ self.readers[fd] = events.Handle(callback, args, self, context)
def _remove_reader(self, fd):
self.remove_reader_count[fd] += 1
if fd in self.readers:
raise AssertionError(f'fd {fd} is registered')
- def _add_writer(self, fd, callback, *args):
- self.writers[fd] = events.Handle(callback, args, self, None)
+ def _add_writer(self, fd, callback, *args, context=None):
+ self.writers[fd] = events.Handle(callback, args, self, context)
def _remove_writer(self, fd):
self.remove_writer_count[fd] += 1
--- /dev/null
+Fix incorrect contextvars handling in server tasks created by :mod:`asyncio`. Patch by Kumar Aditya.