]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-140947: fix contextvars handling for server tasks in asyncio (#141158)
authorKumar Aditya <kumaraditya@python.org>
Sat, 21 Mar 2026 12:14:08 +0000 (17:44 +0530)
committerGitHub <noreply@github.com>
Sat, 21 Mar 2026 12:14:08 +0000 (17:44 +0530)
Lib/asyncio/base_events.py
Lib/asyncio/proactor_events.py
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_server_context.py [new file with mode: 0644]
Lib/test/test_asyncio/utils.py
Misc/NEWS.d/next/Library/2026-03-21-08-23-26.gh-issue-140947.owZ4r_.rst [new file with mode: 0644]

index 77c70aaa7b986e15402342a571fb14c9664d13f1..7a6837546d930f35498026511da8701819ffb5c0 100644 (file)
@@ -14,6 +14,7 @@ to modify the meaning of the API call itself.
 """
 
 import collections
+import contextvars
 import collections.abc
 import concurrent.futures
 import errno
@@ -290,6 +291,7 @@ class Server(events.AbstractServer):
         self._ssl_shutdown_timeout = ssl_shutdown_timeout
         self._serving = False
         self._serving_forever_fut = None
+        self._context = contextvars.copy_context()
 
     def __repr__(self):
         return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
@@ -319,7 +321,7 @@ class Server(events.AbstractServer):
             self._loop._start_serving(
                 self._protocol_factory, sock, self._ssl_context,
                 self, self._backlog, self._ssl_handshake_timeout,
-                self._ssl_shutdown_timeout)
+                self._ssl_shutdown_timeout, context=self._context)
 
     def get_loop(self):
         return self._loop
@@ -509,7 +511,8 @@ class BaseEventLoop(events.AbstractEventLoop):
             extra=None, server=None,
             ssl_handshake_timeout=None,
             ssl_shutdown_timeout=None,
-            call_connection_made=True):
+            call_connection_made=True,
+            context=None):
         """Create SSL transport."""
         raise NotImplementedError
 
@@ -1213,9 +1216,10 @@ class BaseEventLoop(events.AbstractEventLoop):
             self, sock, protocol_factory, ssl,
             server_hostname, server_side=False,
             ssl_handshake_timeout=None,
-            ssl_shutdown_timeout=None):
+            ssl_shutdown_timeout=None, context=None):
 
         sock.setblocking(False)
+        context = context if context is not None else contextvars.copy_context()
 
         protocol = protocol_factory()
         waiter = self.create_future()
@@ -1225,9 +1229,10 @@ class BaseEventLoop(events.AbstractEventLoop):
                 sock, protocol, sslcontext, waiter,
                 server_side=server_side, server_hostname=server_hostname,
                 ssl_handshake_timeout=ssl_handshake_timeout,
-                ssl_shutdown_timeout=ssl_shutdown_timeout)
+                ssl_shutdown_timeout=ssl_shutdown_timeout,
+                context=context)
         else:
-            transport = self._make_socket_transport(sock, protocol, waiter)
+            transport = self._make_socket_transport(sock, protocol, waiter, context=context)
 
         try:
             await waiter
index 3fa93b14a6787ffe8372cf56f79d89845c39e0f5..2dc1569d7807911736bd0eeb76a5918e3a07ab5d 100644 (file)
@@ -642,7 +642,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
             signal.set_wakeup_fd(self._csock.fileno())
 
     def _make_socket_transport(self, sock, protocol, waiter=None,
-                               extra=None, server=None):
+                               extra=None, server=None, context=None):
         return _ProactorSocketTransport(self, sock, protocol, waiter,
                                         extra, server)
 
@@ -651,7 +651,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
             *, server_side=False, server_hostname=None,
             extra=None, server=None,
             ssl_handshake_timeout=None,
-            ssl_shutdown_timeout=None):
+            ssl_shutdown_timeout=None, context=None):
         ssl_protocol = sslproto.SSLProtocol(
                 self, protocol, sslcontext, waiter,
                 server_side, server_hostname,
@@ -837,7 +837,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
     def _start_serving(self, protocol_factory, sock,
                        sslcontext=None, server=None, backlog=100,
                        ssl_handshake_timeout=None,
-                       ssl_shutdown_timeout=None):
+                       ssl_shutdown_timeout=None, context=None):
 
         def loop(f=None):
             try:
index ff7e16df3c62737a01cba22017ed3eefcc93fac8..9685e7fc05d241f077c1a979a7b35e9b92dae35a 100644 (file)
@@ -67,10 +67,10 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
         self._transports = weakref.WeakValueDictionary()
 
     def _make_socket_transport(self, sock, protocol, waiter=None, *,
-                               extra=None, server=None):
+                               extra=None, server=None, context=None):
         self._ensure_fd_no_transport(sock)
         return _SelectorSocketTransport(self, sock, protocol, waiter,
-                                        extra, server)
+                                        extra, server, context=context)
 
     def _make_ssl_transport(
             self, rawsock, protocol, sslcontext, waiter=None,
@@ -78,16 +78,17 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
             extra=None, server=None,
             ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
             ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
+            context=None,
     ):
         self._ensure_fd_no_transport(rawsock)
         ssl_protocol = sslproto.SSLProtocol(
             self, protocol, sslcontext, waiter,
             server_side, server_hostname,
             ssl_handshake_timeout=ssl_handshake_timeout,
-            ssl_shutdown_timeout=ssl_shutdown_timeout
+            ssl_shutdown_timeout=ssl_shutdown_timeout,
         )
         _SelectorSocketTransport(self, rawsock, ssl_protocol,
-                                 extra=extra, server=server)
+                                 extra=extra, server=server, context=context)
         return ssl_protocol._app_transport
 
     def _make_datagram_transport(self, sock, protocol,
@@ -159,16 +160,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
     def _start_serving(self, protocol_factory, sock,
                        sslcontext=None, server=None, backlog=100,
                        ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
-                       ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+                       ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
         self._add_reader(sock.fileno(), self._accept_connection,
                          protocol_factory, sock, sslcontext, server, backlog,
-                         ssl_handshake_timeout, ssl_shutdown_timeout)
+                         ssl_handshake_timeout, ssl_shutdown_timeout, context)
 
     def _accept_connection(
             self, protocol_factory, sock,
             sslcontext=None, server=None, backlog=100,
             ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
-            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
         # This method is only called once for each event loop tick where the
         # listening socket has triggered an EVENT_READ. There may be multiple
         # connections waiting for an .accept() so it is called in a loop.
@@ -204,21 +205,22 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                                     self._start_serving,
                                     protocol_factory, sock, sslcontext, server,
                                     backlog, ssl_handshake_timeout,
-                                    ssl_shutdown_timeout)
+                                    ssl_shutdown_timeout, context)
                 else:
                     raise  # The event loop will catch, log and ignore it.
             else:
                 extra = {'peername': addr}
+                conn_context = context.copy() if context is not None else None
                 accept = self._accept_connection2(
                     protocol_factory, conn, extra, sslcontext, server,
-                    ssl_handshake_timeout, ssl_shutdown_timeout)
-                self.create_task(accept)
+                    ssl_handshake_timeout, ssl_shutdown_timeout, context=conn_context)
+                self.create_task(accept, context=conn_context)
 
     async def _accept_connection2(
             self, protocol_factory, conn, extra,
             sslcontext=None, server=None,
             ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
-            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
+            ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None):
         protocol = None
         transport = None
         try:
@@ -229,11 +231,12 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                     conn, protocol, sslcontext, waiter=waiter,
                     server_side=True, extra=extra, server=server,
                     ssl_handshake_timeout=ssl_handshake_timeout,
-                    ssl_shutdown_timeout=ssl_shutdown_timeout)
+                    ssl_shutdown_timeout=ssl_shutdown_timeout,
+                    context=context)
             else:
                 transport = self._make_socket_transport(
                     conn, protocol, waiter=waiter, extra=extra,
-                    server=server)
+                    server=server, context=context)
 
             try:
                 await waiter
@@ -275,9 +278,9 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                 f'File descriptor {fd!r} is used by transport '
                 f'{transport!r}')
 
-    def _add_reader(self, fd, callback, *args):
+    def _add_reader(self, fd, callback, *args, context=None):
         self._check_closed()
-        handle = events.Handle(callback, args, self, None)
+        handle = events.Handle(callback, args, self, context=context)
         key = self._selector.get_map().get(fd)
         if key is None:
             self._selector.register(fd, selectors.EVENT_READ,
@@ -309,9 +312,9 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
         else:
             return False
 
-    def _add_writer(self, fd, callback, *args):
+    def _add_writer(self, fd, callback, *args, context=None):
         self._check_closed()
-        handle = events.Handle(callback, args, self, None)
+        handle = events.Handle(callback, args, self, context=context)
         key = self._selector.get_map().get(fd)
         if key is None:
             self._selector.register(fd, selectors.EVENT_WRITE,
@@ -770,7 +773,7 @@ class _SelectorTransport(transports._FlowControlMixin,
     # exception)
     _sock = None
 
-    def __init__(self, loop, sock, protocol, extra=None, server=None):
+    def __init__(self, loop, sock, protocol, extra=None, server=None, context=None):
         super().__init__(extra, loop)
         self._extra['socket'] = trsock.TransportSocket(sock)
         try:
@@ -784,7 +787,7 @@ class _SelectorTransport(transports._FlowControlMixin,
                 self._extra['peername'] = None
         self._sock = sock
         self._sock_fd = sock.fileno()
-
+        self._context = context
         self._protocol_connected = False
         self.set_protocol(protocol)
 
@@ -866,7 +869,7 @@ class _SelectorTransport(transports._FlowControlMixin,
         if not self._buffer:
             self._conn_lost += 1
             self._loop._remove_writer(self._sock_fd)
-            self._loop.call_soon(self._call_connection_lost, None)
+            self._call_soon(self._call_connection_lost, None)
 
     def __del__(self, _warn=warnings.warn):
         if self._sock is not None:
@@ -899,7 +902,7 @@ class _SelectorTransport(transports._FlowControlMixin,
             self._closing = True
             self._loop._remove_reader(self._sock_fd)
         self._conn_lost += 1
-        self._loop.call_soon(self._call_connection_lost, exc)
+        self._call_soon(self._call_connection_lost, exc)
 
     def _call_connection_lost(self, exc):
         try:
@@ -921,8 +924,13 @@ class _SelectorTransport(transports._FlowControlMixin,
     def _add_reader(self, fd, callback, *args):
         if not self.is_reading():
             return
-        self._loop._add_reader(fd, callback, *args)
+        self._loop._add_reader(fd, callback, *args, context=self._context)
 
+    def _add_writer(self, fd, callback, *args):
+        self._loop._add_writer(fd, callback, *args, context=self._context)
+
+    def _call_soon(self, callback, *args):
+        self._loop.call_soon(callback, *args, context=self._context)
 
 class _SelectorSocketTransport(_SelectorTransport):
 
@@ -930,10 +938,9 @@ class _SelectorSocketTransport(_SelectorTransport):
     _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
 
     def __init__(self, loop, sock, protocol, waiter=None,
-                 extra=None, server=None):
-
+                 extra=None, server=None, context=None):
         self._read_ready_cb = None
-        super().__init__(loop, sock, protocol, extra, server)
+        super().__init__(loop, sock, protocol, extra, server, context)
         self._eof = False
         self._empty_waiter = None
         if _HAS_SENDMSG:
@@ -945,14 +952,12 @@ class _SelectorSocketTransport(_SelectorTransport):
         # decreases the latency (in some cases significantly.)
         base_events._set_nodelay(self._sock)
 
-        self._loop.call_soon(self._protocol.connection_made, self)
+        self._call_soon(self._protocol.connection_made, self)
         # only start reading when connection_made() has been called
-        self._loop.call_soon(self._add_reader,
-                             self._sock_fd, self._read_ready)
+        self._call_soon(self._add_reader, self._sock_fd, self._read_ready)
         if waiter is not None:
             # only wake up the waiter when connection_made() has been called
-            self._loop.call_soon(futures._set_result_unless_cancelled,
-                                 waiter, None)
+            self._call_soon(futures._set_result_unless_cancelled, waiter, None)
 
     def set_protocol(self, protocol):
         if isinstance(protocol, protocols.BufferedProtocol):
@@ -1081,7 +1086,7 @@ class _SelectorSocketTransport(_SelectorTransport):
                 if not data:
                     return
             # Not all was written; register write handler.
-            self._loop._add_writer(self._sock_fd, self._write_ready)
+            self._add_writer(self._sock_fd, self._write_ready)
 
         # Add it to the buffer.
         self._buffer.append(data)
@@ -1185,7 +1190,7 @@ class _SelectorSocketTransport(_SelectorTransport):
         self._write_ready()
         # If the entire buffer couldn't be written, register a write handler
         if self._buffer:
-            self._loop._add_writer(self._sock_fd, self._write_ready)
+            self._add_writer(self._sock_fd, self._write_ready)
             self._maybe_pause_protocol()
 
     def can_write_eof(self):
@@ -1226,14 +1231,12 @@ class _SelectorDatagramTransport(_SelectorTransport, transports.DatagramTranspor
         super().__init__(loop, sock, protocol, extra)
         self._address = address
         self._buffer_size = 0
-        self._loop.call_soon(self._protocol.connection_made, self)
+        self._call_soon(self._protocol.connection_made, self)
         # only start reading when connection_made() has been called
-        self._loop.call_soon(self._add_reader,
-                             self._sock_fd, self._read_ready)
+        self._call_soon(self._add_reader, self._sock_fd, self._read_ready)
         if waiter is not None:
             # only wake up the waiter when connection_made() has been called
-            self._loop.call_soon(futures._set_result_unless_cancelled,
-                                 waiter, None)
+            self._call_soon(futures._set_result_unless_cancelled, waiter, None)
 
     def get_write_buffer_size(self):
         return self._buffer_size
@@ -1280,7 +1283,7 @@ class _SelectorDatagramTransport(_SelectorTransport, transports.DatagramTranspor
                     self._sock.sendto(data, addr)
                 return
             except (BlockingIOError, InterruptedError):
-                self._loop._add_writer(self._sock_fd, self._sendto_ready)
+                self._add_writer(self._sock_fd, self._sendto_ready)
             except OSError as exc:
                 self._protocol.error_received(exc)
                 return
index 8c02de77c2474044b2f5c5f7d42287af35304f7e..e59bc25668b4cba09d10edeae1d68c7347a36f7d 100644 (file)
@@ -1696,7 +1696,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
             server_side=False,
             server_hostname='python.org',
             ssl_handshake_timeout=handshake_timeout,
-            ssl_shutdown_timeout=shutdown_timeout)
+            ssl_shutdown_timeout=shutdown_timeout,
+            context=ANY)
         # Next try an explicit server_hostname.
         self.loop._make_ssl_transport.reset_mock()
         coro = self.loop.create_connection(
@@ -1711,7 +1712,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
             server_side=False,
             server_hostname='perl.com',
             ssl_handshake_timeout=handshake_timeout,
-            ssl_shutdown_timeout=shutdown_timeout)
+            ssl_shutdown_timeout=shutdown_timeout,
+            context=ANY)
         # Finally try an explicit empty server_hostname.
         self.loop._make_ssl_transport.reset_mock()
         coro = self.loop.create_connection(
@@ -1726,7 +1728,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
                 server_side=False,
                 server_hostname='',
                 ssl_handshake_timeout=handshake_timeout,
-                ssl_shutdown_timeout=shutdown_timeout)
+                ssl_shutdown_timeout=shutdown_timeout,
+                context=ANY)
 
     def test_create_connection_no_ssl_server_hostname_errors(self):
         # When not using ssl, server_hostname must be None.
@@ -2104,7 +2107,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
             constants.ACCEPT_RETRY_DELAY,
             # self.loop._start_serving
             mock.ANY,
-            MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
+            MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY, mock.ANY)
 
     def test_call_coroutine(self):
         async def simple_coroutine():
diff --git a/Lib/test/test_asyncio/test_server_context.py b/Lib/test/test_asyncio/test_server_context.py
new file mode 100644 (file)
index 0000000..3f15654
--- /dev/null
@@ -0,0 +1,314 @@
+import asyncio
+import contextvars
+import unittest
+import sys
+
+from unittest import TestCase
+
+try:
+    import ssl
+except ImportError:
+    ssl = None
+
+from test.test_asyncio import utils as test_utils
+
+def tearDownModule():
+    asyncio.events._set_event_loop_policy(None)
+
+class ServerContextvarsTestCase:
+    loop_factory = None  # To be defined in subclasses
+    server_ssl_context = None  # To be defined in subclasses for SSL tests
+    client_ssl_context = None  # To be defined in subclasses for SSL tests
+
+    def run_coro(self, coro):
+        return asyncio.run(coro, loop_factory=self.loop_factory)
+
+    def test_start_server1(self):
+        # Test that asyncio.start_server captures the context at the time of server creation
+        async def test():
+            var = contextvars.ContextVar("var", default="default")
+
+            async def handle_client(reader, writer):
+                value = var.get()
+                writer.write(value.encode())
+                await writer.drain()
+                writer.close()
+
+            server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
+                                               ssl=self.server_ssl_context)
+            # change the value
+            var.set("after_server")
+
+            async def client(addr):
+                reader, writer = await asyncio.open_connection(*addr,
+                                                               ssl=self.client_ssl_context)
+                data = await reader.read(100)
+                writer.close()
+                await writer.wait_closed()
+                return data.decode()
+
+            async with server:
+                addr = server.sockets[0].getsockname()
+                self.assertEqual(await client(addr), "default")
+
+            self.assertEqual(var.get(), "after_server")
+
+        self.run_coro(test())
+
+    def test_start_server2(self):
+        # Test that mutations to the context in one handler don't affect other handlers or the server's context
+        async def test():
+            var = contextvars.ContextVar("var", default="default")
+
+            async def handle_client(reader, writer):
+                value = var.get()
+                writer.write(value.encode())
+                var.set("in_handler")
+                await writer.drain()
+                writer.close()
+
+            server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
+                                               ssl=self.server_ssl_context)
+            var.set("after_server")
+
+            async def client(addr):
+                reader, writer = await asyncio.open_connection(*addr,
+                                                               ssl=self.client_ssl_context)
+                data = await reader.read(100)
+                writer.close()
+                await writer.wait_closed()
+                return data.decode()
+
+            async with server:
+                addr = server.sockets[0].getsockname()
+                self.assertEqual(await client(addr), "default")
+                self.assertEqual(await client(addr), "default")
+                self.assertEqual(await client(addr), "default")
+
+            self.assertEqual(var.get(), "after_server")
+
+        self.run_coro(test())
+
+    def test_start_server3(self):
+        # Test that mutations to context in concurrent handlers don't affect each other or the server's context
+        async def test():
+            var = contextvars.ContextVar("var", default="default")
+            var.set("before_server")
+
+            async def handle_client(reader, writer):
+                writer.write(var.get().encode())
+                await writer.drain()
+                writer.close()
+
+            server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
+                                               ssl=self.server_ssl_context)
+            var.set("after_server")
+
+            async def client(addr):
+                reader, writer = await asyncio.open_connection(*addr,
+                                                               ssl=self.client_ssl_context)
+                data = await reader.read(100)
+                self.assertEqual(data.decode(), "before_server")
+                writer.close()
+                await writer.wait_closed()
+
+            async with server:
+                addr = server.sockets[0].getsockname()
+                async with asyncio.TaskGroup() as tg:
+                    for _ in range(100):
+                        tg.create_task(client(addr))
+
+            self.assertEqual(var.get(), "after_server")
+
+        self.run_coro(test())
+
+    def test_create_server1(self):
+        # Test that loop.create_server captures the context at the time of server creation
+        # and that mutations to the context in protocol callbacks don't affect the server's context
+        async def test():
+            var = contextvars.ContextVar("var", default="default")
+
+            class EchoProtocol(asyncio.Protocol):
+                def connection_made(self, transport):
+                    self.transport = transport
+                    value = var.get()
+                    var.set("in_handler")
+                    self.transport.write(value.encode())
+                    self.transport.close()
+
+            server = await asyncio.get_running_loop().create_server(
+                lambda: EchoProtocol(), '127.0.0.1', 0,
+                ssl=self.server_ssl_context)
+            var.set("after_server")
+
+            async def client(addr):
+                reader, writer = await asyncio.open_connection(*addr,
+                                                               ssl=self.client_ssl_context)
+                data = await reader.read(100)
+                self.assertEqual(data.decode(), "default")
+                writer.close()
+                await writer.wait_closed()
+
+            async with server:
+                addr = server.sockets[0].getsockname()
+                await client(addr)
+
+            self.assertEqual(var.get(), "after_server")
+
+        self.run_coro(test())
+
+    def test_create_server2(self):
+        # Test that mutations to context in one protocol instance don't affect other instances or the server's context
+        async def test():
+            var = contextvars.ContextVar("var", default="default")
+
+            class EchoProtocol(asyncio.Protocol):
+                def __init__(self):
+                    super().__init__()
+                    assert var.get() == "default", var.get()
+                def connection_made(self, transport):
+                    self.transport = transport
+                    value = var.get()
+                    var.set("in_handler")
+                    self.transport.write(value.encode())
+                    self.transport.close()
+
+            server = await asyncio.get_running_loop().create_server(
+                lambda: EchoProtocol(), '127.0.0.1', 0,
+                ssl=self.server_ssl_context)
+
+            var.set("after_server")
+
+            async def client(addr, expected):
+                reader, writer = await asyncio.open_connection(*addr,
+                                                               ssl=self.client_ssl_context)
+                data = await reader.read(100)
+                self.assertEqual(data.decode(), expected)
+                writer.close()
+                await writer.wait_closed()
+
+            async with server:
+                addr = server.sockets[0].getsockname()
+                await client(addr, "default")
+                await client(addr, "default")
+
+            self.assertEqual(var.get(), "after_server")
+
+        self.run_coro(test())
+
+    def test_gh140947(self):
+        # See https://github.com/python/cpython/issues/140947
+
+        cvar1 = contextvars.ContextVar("cvar1")
+        cvar2 = contextvars.ContextVar("cvar2")
+        cvar3 = contextvars.ContextVar("cvar3")
+        results = {}
+        is_ssl = self.server_ssl_context is not None
+
+        def capture_context(meth):
+            result = []
+            for k,v in contextvars.copy_context().items():
+                if k.name.startswith("cvar"):
+                    result.append((k.name, v))
+            results[meth] = sorted(result)
+
+        class DemoProtocol(asyncio.Protocol):
+            def __init__(self, on_conn_lost):
+                self.transport = None
+                self.on_conn_lost = on_conn_lost
+                self.tasks = set()
+
+            def connection_made(self, transport):
+                capture_context("connection_made")
+                self.transport = transport
+
+            def data_received(self, data):
+                capture_context("data_received")
+
+                task = asyncio.create_task(self.asgi())
+                self.tasks.add(task)
+                task.add_done_callback(self.tasks.discard)
+
+                self.transport.pause_reading()
+
+            def connection_lost(self, exc):
+                capture_context("connection_lost")
+                if not self.on_conn_lost.done():
+                    self.on_conn_lost.set_result(True)
+
+            async def asgi(self):
+                capture_context("asgi start")
+                cvar1.set(True)
+                # make sure that we only resume after the pause
+                # otherwise the resume does nothing
+                if is_ssl:
+                    while not self.transport._ssl_protocol._app_reading_paused:
+                        await asyncio.sleep(0.01)
+                else:
+                    while not self.transport._paused:
+                        await asyncio.sleep(0.01)
+                cvar2.set(True)
+                self.transport.resume_reading()
+                cvar3.set(True)
+                capture_context("asgi end")
+
+        async def main():
+            loop = asyncio.get_running_loop()
+            on_conn_lost = loop.create_future()
+
+            server = await loop.create_server(
+                lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
+                ssl=self.server_ssl_context)
+            async with server:
+                addr = server.sockets[0].getsockname()
+                reader, writer = await asyncio.open_connection(*addr,
+                                                               ssl=self.client_ssl_context)
+                writer.write(b"anything")
+                await writer.drain()
+                writer.close()
+                await writer.wait_closed()
+                await on_conn_lost
+
+        self.run_coro(main())
+        self.assertDictEqual(results, {
+            "connection_made": [],
+            "data_received": [],
+            "asgi start": [],
+            "asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
+            "connection_lost": [],
+        })
+
+
+class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
+    loop_factory = staticmethod(asyncio.new_event_loop)
+
+@unittest.skipUnless(ssl, "SSL not available")
+class AsyncioEventLoopSSLTests(AsyncioEventLoopTests):
+    def setUp(self):
+        super().setUp()
+        self.server_ssl_context = test_utils.simple_server_sslcontext()
+        self.client_ssl_context = test_utils.simple_client_sslcontext()
+
+if sys.platform == "win32":
+    class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
+        loop_factory = asyncio.ProactorEventLoop
+
+    class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
+        loop_factory = asyncio.SelectorEventLoop
+
+    @unittest.skipUnless(ssl, "SSL not available")
+    class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests):
+        def setUp(self):
+            super().setUp()
+            self.server_ssl_context = test_utils.simple_server_sslcontext()
+            self.client_ssl_context = test_utils.simple_client_sslcontext()
+
+    @unittest.skipUnless(ssl, "SSL not available")
+    class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests):
+        def setUp(self):
+            super().setUp()
+            self.server_ssl_context = test_utils.simple_server_sslcontext()
+            self.client_ssl_context = test_utils.simple_client_sslcontext()
+
+if __name__ == "__main__":
+    unittest.main()
index a480e16e81bb91aea01822bd7192d493547eff9f..62cfcf8ceb5f2a8bffb531aa6f89e6a1d8681b02 100644 (file)
@@ -388,8 +388,8 @@ class TestLoop(base_events.BaseEventLoop):
             else:  # pragma: no cover
                 raise AssertionError("Time generator is not finished")
 
-    def _add_reader(self, fd, callback, *args):
-        self.readers[fd] = events.Handle(callback, args, self, None)
+    def _add_reader(self, fd, callback, *args, context=None):
+        self.readers[fd] = events.Handle(callback, args, self, context)
 
     def _remove_reader(self, fd):
         self.remove_reader_count[fd] += 1
@@ -414,8 +414,8 @@ class TestLoop(base_events.BaseEventLoop):
         if fd in self.readers:
             raise AssertionError(f'fd {fd} is registered')
 
-    def _add_writer(self, fd, callback, *args):
-        self.writers[fd] = events.Handle(callback, args, self, None)
+    def _add_writer(self, fd, callback, *args, context=None):
+        self.writers[fd] = events.Handle(callback, args, self, context)
 
     def _remove_writer(self, fd):
         self.remove_writer_count[fd] += 1
diff --git a/Misc/NEWS.d/next/Library/2026-03-21-08-23-26.gh-issue-140947.owZ4r_.rst b/Misc/NEWS.d/next/Library/2026-03-21-08-23-26.gh-issue-140947.owZ4r_.rst
new file mode 100644 (file)
index 0000000..88e787e
--- /dev/null
@@ -0,0 +1 @@
+Fix incorrect contextvars handling in server tasks created by :mod:`asyncio`. Patch by Kumar Aditya.