]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
open up async greenlet for third parties
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Jun 2024 16:42:29 +0000 (12:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 13 Jun 2024 13:29:52 +0000 (09:29 -0400)
Modified the internal representation used for adapting asyncio calls to
greenlets to allow for duck-typed compatibility with third party libraries
that implement SQLAlchemy's "greenlet-to-asyncio" pattern directly.
Running code within a greenlet that features the attribute
``__sqlalchemy_greenlet_provider__ = True`` will allow calls to
:func:`sqlalchemy.util.await_only` directly.

Change-Id: I79c67264e1a642b9a80d3b46dc64bdda80acf0aa

doc/build/changelog/unreleased_14/greenlet_compat.rst [new file with mode: 0644]
lib/sqlalchemy/util/concurrency.py

diff --git a/doc/build/changelog/unreleased_14/greenlet_compat.rst b/doc/build/changelog/unreleased_14/greenlet_compat.rst
new file mode 100644 (file)
index 0000000..d9eb51c
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, engine
+
+    Modified the internal representation used for adapting asyncio calls to
+    greenlets to allow for duck-typed compatibility with third party libraries
+    that implement SQLAlchemy's "greenlet-to-asyncio" pattern directly.
+    Running code within a greenlet that features the attribute
+    ``__sqlalchemy_greenlet_provider__ = True`` will allow calls to
+    :func:`sqlalchemy.util.await_only` directly.
+
index 25ea27ea8c4480920b69c9ab33b0ab51238dc30a..aa3eb45139b5f99a06cb55da631765920d43782a 100644 (file)
@@ -93,9 +93,10 @@ class _concurrency_shim_cls:
             class _AsyncIoGreenlet(greenlet):
                 dead: bool
 
+                __sqlalchemy_greenlet_provider__ = True
+
                 def __init__(self, fn: Callable[..., Any], driver: greenlet):
                     greenlet.__init__(self, fn, driver)
-                    self.driver = driver
                     if _has_gr_context:
                         self.gr_context = driver.gr_context
 
@@ -138,7 +139,7 @@ def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
 
 def in_greenlet() -> bool:
     current = _concurrency_shim.getcurrent()
-    return isinstance(current, _concurrency_shim._AsyncIoGreenlet)
+    return getattr(current, "__sqlalchemy_greenlet_provider__", False)
 
 
 def await_(awaitable: Awaitable[_T]) -> _T:
@@ -152,7 +153,7 @@ def await_(awaitable: Awaitable[_T]) -> _T:
     """
     # this is called in the context greenlet while running fn
     current = _concurrency_shim.getcurrent()
-    if not isinstance(current, _concurrency_shim._AsyncIoGreenlet):
+    if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
         _safe_cancel_awaitable(awaitable)
 
         raise exc.MissingGreenlet(
@@ -164,7 +165,8 @@ def await_(awaitable: Awaitable[_T]) -> _T:
     # a coroutine to run. Once the awaitable is done, the driver greenlet
     # switches back to this greenlet with the result of awaitable that is
     # then returned to the caller (or raised as error)
-    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
+    assert current.parent
+    return current.parent.switch(awaitable)  # type: ignore[no-any-return]
 
 
 await_only = await_  # old name. deprecated on 2.2
@@ -195,24 +197,22 @@ async def greenlet_spawn(
     # coroutine to wait. If the context is dead the function has
     # returned, and its result can be returned.
     switch_occurred = False
-    try:
-        result = context.switch(*args, **kwargs)
-        while not context.dead:
-            switch_occurred = True
-            try:
-                # wait for a coroutine from await_ and then return its
-                # result back to it.
-                value = await result
-            except BaseException:
-                # this allows an exception to be raised within
-                # the moderated greenlet so that it can continue
-                # its expected flow.
-                result = context.throw(*sys.exc_info())
-            else:
-                result = context.switch(value)
-    finally:
-        # clean up to avoid cycle resolution by gc
-        del context.driver
+
+    result = context.switch(*args, **kwargs)
+    while not context.dead:
+        switch_occurred = True
+        try:
+            # wait for a coroutine from await_ and then return its
+            # result back to it.
+            value = await result
+        except BaseException:
+            # this allows an exception to be raised within
+            # the moderated greenlet so that it can continue
+            # its expected flow.
+            result = context.throw(*sys.exc_info())
+        else:
+            result = context.switch(value)
+
     if _require_await and not switch_occurred:
         raise exc.AwaitRequired(
             "The current operation required an async execution but none was "
@@ -309,10 +309,7 @@ class _AsyncUtil:
         if _concurrency_shim._has_greenlet:
             if self.runner.get_loop().is_running():
                 # allow for a wrapped test function to call another
-                assert isinstance(
-                    _concurrency_shim.getcurrent(),
-                    _concurrency_shim._AsyncIoGreenlet,
-                )
+                assert in_greenlet()
                 return fn(*args, **kwargs)
             else:
                 return self.runner.run(greenlet_spawn(fn, *args, **kwargs))