]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow contextvars to be set in events when using asyncio
authorFederico Caselli <cfederico87@gmail.com>
Thu, 14 Apr 2022 22:29:01 +0000 (00:29 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sun, 17 Apr 2022 18:23:38 +0000 (20:23 +0200)
Allow setting contextvar values inside async adapted event handlers.
Previously the value set to the contextvar would not be properly
propagated.

Fixes: #7937
Change-Id: I787aa869f8d057579e13e32c749f05f184ffd02a
(cherry picked from commit 640d163bd8bf61e87790255558b6f704a0d06174)

doc/build/changelog/unreleased_14/7937.rst [new file with mode: 0644]
lib/sqlalchemy/util/_concurrency_py3k.py
test/base/test_concurrency_py3k.py

diff --git a/doc/build/changelog/unreleased_14/7937.rst b/doc/build/changelog/unreleased_14/7937.rst
new file mode 100644 (file)
index 0000000..96d80d6
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 7937
+
+    Allow setting contextvar values inside async adapted event handlers.
+    Previously the value set to the contextvar would not be properly
+    propagated.
+
index e3c5dac580093eb6a85a298bd927a90a4a0d744a..0b128344d645c666c2ad40f78f3ad700ddc979e8 100644 (file)
@@ -17,18 +17,10 @@ from . import compat
 from .langhelpers import memoized_property
 from .. import exc
 
-if compat.py37:
-    try:
-        from contextvars import copy_context as _copy_context
-
-        # If greenlet.gr_context is present in current version of greenlet,
-        # it will be set with a copy of the current context on creation.
-        # Refs: https://github.com/python-greenlet/greenlet/pull/198
-        getattr(greenlet.greenlet, "gr_context")
-    except (ImportError, AttributeError):
-        _copy_context = None
-else:
-    _copy_context = None
+# If greenlet.gr_context is present in current version of greenlet,
+# it will be set with the current context on creation.
+# Refs: https://github.com/python-greenlet/greenlet/pull/198
+_has_gr_context = hasattr(greenlet.getcurrent(), "gr_context")
 
 
 def is_exit_exception(e):
@@ -48,15 +40,15 @@ class _AsyncIoGreenlet(greenlet.greenlet):
     def __init__(self, fn, driver):
         greenlet.greenlet.__init__(self, fn, driver)
         self.driver = driver
-        if _copy_context is not None:
-            self.gr_context = _copy_context()
+        if _has_gr_context:
+            self.gr_context = driver.gr_context
 
 
 def await_only(awaitable: Coroutine) -> Any:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_` calls cannot be nested.
+    :func:`await_only` calls cannot be nested.
 
     :param awaitable: The coroutine to call.
 
@@ -65,8 +57,8 @@ def await_only(awaitable: Coroutine) -> Any:
     current = greenlet.getcurrent()
     if not isinstance(current, _AsyncIoGreenlet):
         raise exc.MissingGreenlet(
-            "greenlet_spawn has not been called; can't call await_() here. "
-            "Was IO attempted in an unexpected place?"
+            "greenlet_spawn has not been called; can't call await_only() "
+            "here. Was IO attempted in an unexpected place?"
         )
 
     # returns the control to the driver greenlet passing it
@@ -80,7 +72,7 @@ def await_fallback(awaitable: Coroutine) -> Any:
     """Awaits an async function in a sync method.
 
     The sync method must be inside a :func:`greenlet_spawn` context.
-    :func:`await_` calls cannot be nested.
+    :func:`await_fallback` calls cannot be nested.
 
     :param awaitable: The coroutine to call.
 
@@ -92,7 +84,7 @@ def await_fallback(awaitable: Coroutine) -> Any:
         if loop.is_running():
             raise exc.MissingGreenlet(
                 "greenlet_spawn has not been called and asyncio event "
-                "loop is already running; can't call await_() here. "
+                "loop is already running; can't call await_fallback() here. "
                 "Was IO attempted in an unexpected place?"
             )
         return loop.run_until_complete(awaitable)
@@ -105,7 +97,7 @@ async def greenlet_spawn(
 ) -> Any:
     """Runs a sync function ``fn`` in a new greenlet.
 
-    The sync function can then use :func:`await_` to wait for async
+    The sync function can then use :func:`await_only` to wait for async
     functions.
 
     :param fn: The sync callable to call.
@@ -115,7 +107,7 @@ async def greenlet_spawn(
 
     context = _AsyncIoGreenlet(fn, greenlet.getcurrent())
     # runs the function synchronously in gl greenlet. If the execution
-    # is interrupted by await_, context is not dead and result is a
+    # is interrupted by await_only, context is not dead and result is a
     # coroutine to wait. If the context is dead the function has
     # returned, and its result can be returned.
     switch_occurred = False
@@ -124,7 +116,7 @@ async def greenlet_spawn(
         while not context.dead:
             switch_occurred = True
             try:
-                # wait for a coroutine from await_ and then return its
+                # wait for a coroutine from await_only and then return its
                 # result back to it.
                 value = await result
             except BaseException:
index 0b648aa30bdd8ba741351db57ba9f92a1fead99e..de7157c7889b750cac0f6abb39af8f3befb5d794 100644 (file)
@@ -1,3 +1,5 @@
+import asyncio
+import random
 import threading
 
 from sqlalchemy import exc
@@ -8,7 +10,6 @@ from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_true
-from sqlalchemy.util import asyncio
 from sqlalchemy.util import await_fallback
 from sqlalchemy.util import await_only
 from sqlalchemy.util import greenlet_spawn
@@ -89,7 +90,8 @@ class TestAsyncioCompat(fixtures.TestBase):
         to_await = run1()
         with expect_raises_message(
             exc.MissingGreenlet,
-            r"greenlet_spawn has not been called; can't call await_\(\) here.",
+            "greenlet_spawn has not been called; "
+            r"can't call await_only\(\) here.",
         ):
             await_only(to_await)
 
@@ -134,7 +136,8 @@ class TestAsyncioCompat(fixtures.TestBase):
 
         with expect_raises_message(
             exc.InvalidRequestError,
-            r"greenlet_spawn has not been called; can't call await_\(\) here.",
+            "greenlet_spawn has not been called; "
+            r"can't call await_only\(\) here.",
         ):
             await greenlet_spawn(go)
 
@@ -147,20 +150,43 @@ class TestAsyncioCompat(fixtures.TestBase):
         import contextvars
 
         var = contextvars.ContextVar("var")
-        concurrency = 5
+        concurrency = 500
 
+        # NOTE: sleep here is not necessary. It's used to simulate IO
+        # ensuring that task are not run sequentially
         async def async_inner(val):
+            await asyncio.sleep(random.uniform(0.005, 0.015))
             eq_(val, var.get())
             return var.get()
 
+        async def async_set(val):
+            await asyncio.sleep(random.uniform(0.005, 0.015))
+            var.set(val)
+
         def inner(val):
             retval = await_only(async_inner(val))
             eq_(val, var.get())
             eq_(retval, val)
+
+            # set the value in a sync function
+            newval = val + concurrency
+            var.set(newval)
+            syncset = await_only(async_inner(newval))
+            eq_(newval, var.get())
+            eq_(syncset, newval)
+
+            # set the value in an async function
+            retval = val + 2 * concurrency
+            await_only(async_set(retval))
+            eq_(var.get(), retval)
+            eq_(await_only(async_inner(retval)), retval)
+
             return retval
 
         async def task(val):
+            await asyncio.sleep(random.uniform(0.005, 0.015))
             var.set(val)
+            await asyncio.sleep(random.uniform(0.005, 0.015))
             return await greenlet_spawn(inner, val)
 
         values = {
@@ -169,7 +195,7 @@ class TestAsyncioCompat(fixtures.TestBase):
                 [task(i) for i in range(concurrency)]
             )
         }
-        eq_(values, set(range(concurrency)))
+        eq_(values, set(range(concurrency * 2, concurrency * 3)))
 
     @async_test
     async def test_require_await(self):