]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
gen: Expliclty track contextvars, fixing contextvars.reset 2938/head
authorBen Darnell <ben@cockroachlabs.com>
Fri, 16 Oct 2020 19:29:20 +0000 (15:29 -0400)
committerBen Darnell <ben@cockroachlabs.com>
Sun, 18 Oct 2020 00:11:34 +0000 (20:11 -0400)
The asyncio event loop provides enough contextvars support out of the
box for basic contextvars functionality to work in tornado coroutines,
but not `contextvars.reset`. Prior to this change, each yield created
a new "level" of context, when an entire coroutine should be on the
same level. This is necessary for the reset method to work.

Fixes #2731

tornado/gen.py
tornado/test/gen_test.py

index 7f41931a3955d2e6b45767cd8d5b7873b528fb25..cab9689375043d5671d38d60b169186665a31d12 100644 (file)
@@ -90,6 +90,11 @@ from tornado.ioloop import IOLoop
 from tornado.log import app_log
 from tornado.util import TimeoutError
 
+try:
+    import contextvars
+except ImportError:
+    contextvars = None  # type: ignore
+
 import typing
 from typing import Union, Any, Callable, List, Type, Tuple, Awaitable, Dict, overload
 
@@ -153,6 +158,10 @@ def _create_future() -> Future:
     return future
 
 
+def _fake_ctx_run(f: Callable[..., _T], *args: Any, **kw: Any) -> _T:
+    return f(*args, **kw)
+
+
 @overload
 def coroutine(
     func: Callable[..., "Generator[Any, Any, _T]"]
@@ -199,8 +208,12 @@ def coroutine(
         # This function is type-annotated with a comment to work around
         # https://bitbucket.org/pypy/pypy/issues/2868/segfault-with-args-type-annotation-in
         future = _create_future()
+        if contextvars is not None:
+            ctx_run = contextvars.copy_context().run  # type: Callable
+        else:
+            ctx_run = _fake_ctx_run
         try:
-            result = func(*args, **kwargs)
+            result = ctx_run(func, *args, **kwargs)
         except (Return, StopIteration) as e:
             result = _value_from_stopiteration(e)
         except Exception:
@@ -218,7 +231,7 @@ def coroutine(
                 # use "optional" coroutines in critical path code without
                 # performance penalty for the synchronous case.
                 try:
-                    yielded = next(result)
+                    yielded = ctx_run(next, result)
                 except (StopIteration, Return) as e:
                     future_set_result_unless_cancelled(
                         future, _value_from_stopiteration(e)
@@ -234,7 +247,7 @@ def coroutine(
                     # add_done_callback() instead of putting a private
                     # attribute on the Future.
                     # (GitHub issues #1769, #2229).
-                    runner = Runner(result, future, yielded)
+                    runner = Runner(ctx_run, result, future, yielded)
                     future.add_done_callback(lambda _: runner)
                 yielded = None
                 try:
@@ -711,10 +724,12 @@ class Runner(object):
 
     def __init__(
         self,
+        ctx_run: Callable,
         gen: "Generator[_Yieldable, Any, _T]",
         result_future: "Future[_T]",
         first_yielded: _Yieldable,
     ) -> None:
+        self.ctx_run = ctx_run
         self.gen = gen
         self.result_future = result_future
         self.future = _null_future  # type: Union[None, Future]
@@ -723,7 +738,7 @@ class Runner(object):
         self.io_loop = IOLoop.current()
         if self.handle_yield(first_yielded):
             gen = result_future = first_yielded = None  # type: ignore
-            self.run()
+            self.ctx_run(self.run)
 
     def run(self) -> None:
         """Starts or resumes the generator, running until it reaches a
@@ -787,7 +802,7 @@ class Runner(object):
             future_set_exc_info(self.future, sys.exc_info())
 
         if self.future is moment:
-            self.io_loop.add_callback(self.run)
+            self.io_loop.add_callback(self.ctx_run, self.run)
             return False
         elif self.future is None:
             raise Exception("no pending future")
@@ -796,7 +811,7 @@ class Runner(object):
             def inner(f: Any) -> None:
                 # Break a reference cycle to speed GC.
                 f = None  # noqa: F841
-                self.run()
+                self.ctx_run(self.run)
 
             self.io_loop.add_future(self.future, inner)
             return False
@@ -808,7 +823,7 @@ class Runner(object):
         if not self.running and not self.finished:
             self.future = Future()
             future_set_exc_info(self.future, (typ, value, tb))
-            self.run()
+            self.ctx_run(self.run)
             return True
         else:
             return False
index 0d091c2bb6a44c8c9c350a1302b68232c51ea0f0..73c33878030bd2d63b9003402e842d38b3afc54b 100644 (file)
@@ -1106,6 +1106,14 @@ class ContextVarsTest(AsyncTestCase):
             self.gen_root(4),
         ]
 
+    @gen_test
+    def test_reset(self):
+        token = ctx_var.set(1)
+        yield
+        # reset asserts that we are still at the same level of the context tree,
+        # so we must make sure that we maintain that property across yield.
+        ctx_var.reset(token)
+
 
 if __name__ == "__main__":
     unittest.main()