]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
asyncio: Preserve contextvars across SelectorThread on Windows (#3479)
authorbestcondition <admin@bestcondition.cn>
Wed, 14 May 2025 17:33:06 +0000 (01:33 +0800)
committerGitHub <noreply@github.com>
Wed, 14 May 2025 17:33:06 +0000 (13:33 -0400)
contextvars that were set on the main thread at event loop creation need to be preserved across callbacks that pass through the SelectorThread.

tornado/platform/asyncio.py
tornado/test/asyncio_test.py

index 4635fecb26fa015abed492a2ac0dd0aaa86cd0c2..a7a2e700c1538e915943c8469e696cd13036da90 100644 (file)
@@ -25,6 +25,7 @@ the same event loop.
 import asyncio
 import atexit
 import concurrent.futures
+import contextvars
 import errno
 import functools
 import select
@@ -472,6 +473,8 @@ class SelectorThread:
     _closed = False
 
     def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None:
+        self._main_thread_ctx = contextvars.copy_context()
+
         self._real_loop = real_loop
 
         self._select_cond = threading.Condition()
@@ -491,7 +494,8 @@ class SelectorThread:
         # clean up if we get to this point but the event loop is closed without
         # starting.
         self._real_loop.call_soon(
-            lambda: self._real_loop.create_task(thread_manager_anext())
+            lambda: self._real_loop.create_task(thread_manager_anext()),
+            context=self._main_thread_ctx,
         )
 
         self._readers: Dict[_FileDescriptorLike, Callable] = {}
@@ -618,7 +622,9 @@ class SelectorThread:
                     raise
 
             try:
-                self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws)
+                self._real_loop.call_soon_threadsafe(
+                    self._handle_select, rs, ws, context=self._main_thread_ctx
+                )
             except RuntimeError:
                 # "Event loop is closed". Swallow the exception for
                 # consistency with PollIOLoop (and logical consistency
index 6c355c04fe08cb306ca9cb98ec57fb2eac3e337b..f33c5f53c301deebd22f4a0d90c12890c36f9f45 100644 (file)
@@ -11,6 +11,7 @@
 # under the License.
 
 import asyncio
+import contextvars
 import threading
 import time
 import unittest
@@ -25,8 +26,14 @@ from tornado.platform.asyncio import (
     to_asyncio_future,
     AddThreadSelectorEventLoop,
 )
-from tornado.testing import AsyncTestCase, gen_test, setup_with_context_manager
+from tornado.testing import (
+    AsyncTestCase,
+    gen_test,
+    setup_with_context_manager,
+    AsyncHTTPTestCase,
+)
 from tornado.test.util import ignore_deprecation
+from tornado.web import Application, RequestHandler
 
 
 class AsyncIOLoopTest(AsyncTestCase):
@@ -261,3 +268,31 @@ class AnyThreadEventLoopPolicyTest(unittest.TestCase):
             asyncio.set_event_loop_policy(self.AnyThreadEventLoopPolicy())
             self.assertIsInstance(self.executor.submit(IOLoop.current).result(), IOLoop)
             self.executor.submit(lambda: asyncio.get_event_loop().close()).result()  # type: ignore
+
+
+class SelectorThreadContextvarsTest(AsyncHTTPTestCase):
+    ctx_value = "foo"
+    test_endpoint = "/"
+    tornado_test_ctx = contextvars.ContextVar("tornado_test_ctx", default="default")
+    tornado_test_ctx.set(ctx_value)
+
+    def get_app(self) -> Application:
+        tornado_test_ctx = self.tornado_test_ctx
+
+        class Handler(RequestHandler):
+            async def get(self):
+                # On the Windows platform,
+                # when a asyncio.events.Handle is created
+                # in the SelectorThread without providing a context,
+                # it will copy the current thread's context,
+                # which can lead to the loss of the main thread's context
+                # when executing the handle.
+                # Therefore, it is necessary to
+                # save a copy of the main thread's context in the SelectorThread
+                # for creating the handle.
+                self.write(tornado_test_ctx.get())
+
+        return Application([(self.test_endpoint, Handler)])
+
+    def test_context_vars(self):
+        self.assertEqual(self.ctx_value, self.fetch(self.test_endpoint).body.decode())