]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow contextvars to be set in events when using asyncio
authorFederico Caselli <cfederico87@gmail.com>
Thu, 14 Apr 2022 22:29:01 +0000 (00:29 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sun, 17 Apr 2022 08:44:42 +0000 (10:44 +0200)
Allow setting contextvar values inside async adapted event handlers.
Previously the value set to the contextvar would not be properly
propagated.

Fixes: #7937
Change-Id: I787aa869f8d057579e13e32c749f05f184ffd02a

doc/build/changelog/unreleased_14/7937.rst [new file with mode: 0644]
lib/sqlalchemy/util/_concurrency_py3k.py
test/base/test_concurrency_py3k.py

diff --git a/doc/build/changelog/unreleased_14/7937.rst b/doc/build/changelog/unreleased_14/7937.rst
new file mode 100644 (file)
index 0000000..96d80d6
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 7937
+
+    Allow setting contextvar values inside async adapted event handlers.
+    Previously the value set to the contextvar would not be properly
+    propagated.
+
index 6ad099eefcb6d19b971663ebcfddd4282cebc162..167c4214016cda4e864439e11c881bc195269721 100644 (file)
@@ -7,26 +7,28 @@
 from __future__ import annotations
 
 import asyncio
-from contextvars import copy_context as _copy_context
+from contextvars import Context
 import sys
 import typing
 from typing import Any
 from typing import Awaitable
 from typing import Callable
 from typing import Coroutine
+from typing import Optional
 from typing import TypeVar
 
 from .langhelpers import memoized_property
 from .. import exc
 from ..util.typing import Protocol
 
-_T = TypeVar("_T", bound=Any)
+_T = TypeVar("_T")
 
 if typing.TYPE_CHECKING:
 
     class greenlet(Protocol):
 
         dead: bool
+        gr_context: Optional[Context]
 
         def __init__(self, fn: Callable[..., Any], driver: "greenlet"):
             ...
@@ -45,15 +47,10 @@ else:
     from greenlet import greenlet
 
 
-if not typing.TYPE_CHECKING:
-    try:
-
-        # If greenlet.gr_context is present in current version of greenlet,
-        # it will be set with a copy of the current context on creation.
-        # Refs: https://github.com/python-greenlet/greenlet/pull/198
-        getattr(greenlet, "gr_context")
-    except (ImportError, AttributeError):
-        _copy_context = None  # noqa
+# If greenlet.gr_context is present in current version of greenlet,
+# it will be set with the current context on creation.
+# Refs: https://github.com/python-greenlet/greenlet/pull/198
+_has_gr_context = hasattr(getcurrent(), "gr_context")
 
 
 def is_exit_exception(e: BaseException) -> bool:
@@ -75,15 +72,15 @@ class _AsyncIoGreenlet(greenlet):  # type: ignore
     def __init__(self, fn: Callable[..., Any], driver: greenlet):
         greenlet.__init__(self, fn, driver)
         self.driver = driver
-        if _copy_context is not None:
-            self.gr_context = _copy_context()
+        if _has_gr_context:
+            self.gr_context = driver.gr_context
 
 
 def await_only(awaitable: Awaitable[_T]) -> _T:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_` calls cannot be nested.
+    :func:`await_only` calls cannot be nested.
 
     :param awaitable: The coroutine to call.
 
@@ -92,8 +89,8 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
     current = getcurrent()
     if not isinstance(current, _AsyncIoGreenlet):
         raise exc.MissingGreenlet(
-            "greenlet_spawn has not been called; can't call await_() here. "
-            "Was IO attempted in an unexpected place?"
+            "greenlet_spawn has not been called; can't call await_only() "
+            "here. Was IO attempted in an unexpected place?"
         )
 
     # returns the control to the driver greenlet passing it
@@ -107,7 +104,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_` calls cannot be nested.
+    :func:`await_fallback` calls cannot be nested.
 
     :param awaitable: The coroutine to call.
 
@@ -120,7 +117,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
         if loop.is_running():
             raise exc.MissingGreenlet(
                 "greenlet_spawn has not been called and asyncio event "
-                "loop is already running; can't call await_() here. "
+                "loop is already running; can't call await_fallback() here. "
                 "Was IO attempted in an unexpected place?"
             )
         return loop.run_until_complete(awaitable)  # type: ignore[no-any-return]  # noqa: E501
@@ -136,7 +133,7 @@ async def greenlet_spawn(
 ) -> _T:
     """Runs a sync function ``fn`` in a new greenlet.
 
-    The sync function can then use :func:`await_` to wait for async
+    The sync function can then use :func:`await_only` to wait for async
     functions.
 
     :param fn: The sync callable to call.
@@ -144,10 +141,10 @@ async def greenlet_spawn(
     :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
     """
 
-    result: _T
+    result: Any
     context = _AsyncIoGreenlet(fn, getcurrent())
     # runs the function synchronously in gl greenlet. If the execution
-    # is interrupted by await_, context is not dead and result is a
+    # is interrupted by await_only, context is not dead and result is a
     # coroutine to wait. If the context is dead the function has
     # returned, and its result can be returned.
     switch_occurred = False
@@ -156,7 +153,7 @@ async def greenlet_spawn(
         while not context.dead:
             switch_occurred = True
             try:
-                # wait for a coroutine from await_ and then return its
+                # wait for a coroutine from await_only and then return its
                 # result back to it.
                 value = await result
             except BaseException:
@@ -175,7 +172,7 @@ async def greenlet_spawn(
             "detected. This will usually happen when using a non compatible "
             "DBAPI driver. Please ensure that an async DBAPI is used."
         )
-    return result
+    return result  # type: ignore[no-any-return]
 
 
 class AsyncAdaptedLock:
index 79601019e850485b7cdafd514a1d20ef3d617717..6a3098a6a36705bd4bac91d02a6bfe21143d7964 100644 (file)
@@ -1,4 +1,6 @@
 import asyncio
+import contextvars
+import random
 import threading
 
 from sqlalchemy import exc
@@ -88,7 +90,8 @@ class TestAsyncioCompat(fixtures.TestBase):
         to_await = run1()
         with expect_raises_message(
             exc.MissingGreenlet,
-            r"greenlet_spawn has not been called; can't call await_\(\) here.",
+            "greenlet_spawn has not been called; "
+            r"can't call await_only\(\) here.",
         ):
             await_only(to_await)
 
@@ -133,7 +136,8 @@ class TestAsyncioCompat(fixtures.TestBase):
 
         with expect_raises_message(
             exc.InvalidRequestError,
-            r"greenlet_spawn has not been called; can't call await_\(\) here.",
+            "greenlet_spawn has not been called; "
+            r"can't call await_only\(\) here.",
         ):
             await greenlet_spawn(go)
 
@@ -141,24 +145,44 @@ class TestAsyncioCompat(fixtures.TestBase):
 
     @async_test
     async def test_contextvars(self):
-        import asyncio
-        import contextvars
-
         var = contextvars.ContextVar("var")
-        concurrency = 5
+        concurrency = 500
 
+        # NOTE: sleep here is not necessary. It's used to simulate IO
+        # ensuring that task are not run sequentially
         async def async_inner(val):
+            await asyncio.sleep(random.uniform(0.005, 0.015))
             eq_(val, var.get())
             return var.get()
 
+        async def async_set(val):
+            await asyncio.sleep(random.uniform(0.005, 0.015))
+            var.set(val)
+
         def inner(val):
             retval = await_only(async_inner(val))
             eq_(val, var.get())
             eq_(retval, val)
+
+            # set the value in a sync function
+            newval = val + concurrency
+            var.set(newval)
+            syncset = await_only(async_inner(newval))
+            eq_(newval, var.get())
+            eq_(syncset, newval)
+
+            # set the value in an async function
+            retval = val + 2 * concurrency
+            await_only(async_set(retval))
+            eq_(var.get(), retval)
+            eq_(await_only(async_inner(retval)), retval)
+
             return retval
 
         async def task(val):
+            await asyncio.sleep(random.uniform(0.005, 0.015))
             var.set(val)
+            await asyncio.sleep(random.uniform(0.005, 0.015))
             return await greenlet_spawn(inner, val)
 
         values = {
@@ -167,7 +191,7 @@ class TestAsyncioCompat(fixtures.TestBase):
                 [task(i) for i in range(concurrency)]
             )
         }
-        eq_(values, set(range(concurrency)))
+        eq_(values, set(range(concurrency * 2, concurrency * 3)))
 
     @async_test
     async def test_require_await(self):