]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
open up async greenlet for third parties
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Jun 2024 16:42:29 +0000 (12:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 13 Jun 2024 13:36:22 +0000 (09:36 -0400)
Modified the internal representation used for adapting asyncio calls to
greenlets to allow for duck-typed compatibility with third party libraries
that implement SQLAlchemy's "greenlet-to-asyncio" pattern directly.
Running code within a greenlet that features the attribute
``__sqlalchemy_greenlet_provider__ = True`` will allow calls to
:func:`sqlalchemy.util.await_only` directly.

Change-Id: I79c67264e1a642b9a80d3b46dc64bdda80acf0aa
(cherry picked from commit c1e2d9180a14c74495b712e08d8156b92f907ac0)

doc/build/changelog/unreleased_14/greenlet_compat.rst [new file with mode: 0644]
lib/sqlalchemy/util/_concurrency_py3k.py

diff --git a/doc/build/changelog/unreleased_14/greenlet_compat.rst b/doc/build/changelog/unreleased_14/greenlet_compat.rst
new file mode 100644 (file)
index 0000000..d9eb51c
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, engine
+
+    Modified the internal representation used for adapting asyncio calls to
+    greenlets to allow for duck-typed compatibility with third party libraries
+    that implement SQLAlchemy's "greenlet-to-asyncio" pattern directly.
+    Running code within a greenlet that features the attribute
+    ``__sqlalchemy_greenlet_provider__ = True`` will allow calls to
+    :func:`sqlalchemy.util.await_only` directly.
+
index 5717d9706173e1c9292f30b2e0642fe560412602..a19607cd01c2acd62bca7e7f15c0d7e344adeb8d 100644 (file)
@@ -74,9 +74,10 @@ def is_exit_exception(e: BaseException) -> bool:
 class _AsyncIoGreenlet(greenlet):
     dead: bool
 
+    __sqlalchemy_greenlet_provider__ = True
+
     def __init__(self, fn: Callable[..., Any], driver: greenlet):
         greenlet.__init__(self, fn, driver)
-        self.driver = driver
         if _has_gr_context:
             self.gr_context = driver.gr_context
 
@@ -102,7 +103,7 @@ def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
 
 def in_greenlet() -> bool:
     current = getcurrent()
-    return isinstance(current, _AsyncIoGreenlet)
+    return getattr(current, "__sqlalchemy_greenlet_provider__", False)
 
 
 def await_only(awaitable: Awaitable[_T]) -> _T:
@@ -116,7 +117,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
     """
     # this is called in the context greenlet while running fn
     current = getcurrent()
-    if not isinstance(current, _AsyncIoGreenlet):
+    if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
         _safe_cancel_awaitable(awaitable)
 
         raise exc.MissingGreenlet(
@@ -128,7 +129,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
     # a coroutine to run. Once the awaitable is done, the driver greenlet
     # switches back to this greenlet with the result of awaitable that is
     # then returned to the caller (or raised as error)
-    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
+    return current.parent.switch(awaitable)  # type: ignore[no-any-return,attr-defined] # noqa: E501
 
 
 def await_fallback(awaitable: Awaitable[_T]) -> _T:
@@ -148,7 +149,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
 
     # this is called in the context greenlet while running fn
     current = getcurrent()
-    if not isinstance(current, _AsyncIoGreenlet):
+    if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
         loop = get_event_loop()
         if loop.is_running():
             _safe_cancel_awaitable(awaitable)
@@ -160,7 +161,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
             )
         return loop.run_until_complete(awaitable)
 
-    return current.driver.switch(awaitable)  # type: ignore[no-any-return]
+    return current.parent.switch(awaitable)  # type: ignore[no-any-return,attr-defined] # noqa: E501
 
 
 async def greenlet_spawn(
@@ -186,24 +187,21 @@ async def greenlet_spawn(
     # coroutine to wait. If the context is dead the function has
     # returned, and its result can be returned.
     switch_occurred = False
-    try:
-        result = context.switch(*args, **kwargs)
-        while not context.dead:
-            switch_occurred = True
-            try:
-                # wait for a coroutine from await_only and then return its
-                # result back to it.
-                value = await result
-            except BaseException:
-                # this allows an exception to be raised within
-                # the moderated greenlet so that it can continue
-                # its expected flow.
-                result = context.throw(*sys.exc_info())
-            else:
-                result = context.switch(value)
-    finally:
-        # clean up to avoid cycle resolution by gc
-        del context.driver
+    result = context.switch(*args, **kwargs)
+    while not context.dead:
+        switch_occurred = True
+        try:
+            # wait for a coroutine from await_only and then return its
+            # result back to it.
+            value = await result
+        except BaseException:
+            # this allows an exception to be raised within
+            # the moderated greenlet so that it can continue
+            # its expected flow.
+            result = context.throw(*sys.exc_info())
+        else:
+            result = context.switch(value)
+
     if _require_await and not switch_occurred:
         raise exc.AwaitRequired(
             "The current operation required an async execution but none was "