]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
vendor asynccontextmanager
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Sep 2021 18:34:51 +0000 (14:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Sep 2021 18:34:51 +0000 (14:34 -0400)
while we still support python 3.6 vendor a simple version
of this for now in the one place we currently use it.

Change-Id: Ibcfc8b004b17e2ac79f9123ccb76c5eb25243f90

lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/util/_concurrency_py3k.py
lib/sqlalchemy/util/concurrency.py

index f312cf79bd427f838dff4df6398c4ec3cc9c2b72..cde43398d2ae8c475dcf35711a0d1462cde810e7 100644 (file)
@@ -28,11 +28,10 @@ This dialect should normally be used only with the
 
 """  # noqa
 
-import contextlib
-
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
 from ... import util
+from ...util.concurrency import asynccontextmanager
 from ...util.concurrency import asyncio
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
@@ -188,7 +187,7 @@ class AsyncAdapt_asyncmy_connection:
         self._execute_mutex = asyncio.Lock()
         self._ss_cursors = set()
 
-    @contextlib.asynccontextmanager
+    @asynccontextmanager
     async def _mutex_and_adapt_errors(self):
         async with self._execute_mutex:
             try:
index c6d1fa5d3a878a0629d69cf58df8192711772b1e..7a4a2c361b71ccbc8fda71a6aa058c29e7ce4a01 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+from functools import wraps
 import sys
 from typing import Any
 from typing import Callable
@@ -193,3 +194,63 @@ def get_event_loop():
             return asyncio.get_event_loop_policy().get_event_loop()
     else:
         return asyncio.get_event_loop()
+
+
+# vendored from py3.7
+
+
+class _AsyncGeneratorContextManager:
+    """Helper for @asynccontextmanager."""
+
+    def __init__(self, func, args, kwds):
+        self.gen = func(*args, **kwds)
+        self.func, self.args, self.kwds = func, args, kwds
+        doc = getattr(func, "__doc__", None)
+        if doc is None:
+            doc = type(self).__doc__
+        self.__doc__ = doc
+
+    async def __aenter__(self):
+        try:
+            return await self.gen.__anext__()
+        except StopAsyncIteration:
+            raise RuntimeError("generator didn't yield") from None
+
+    async def __aexit__(self, typ, value, traceback):
+        if typ is None:
+            try:
+                await self.gen.__anext__()
+            except StopAsyncIteration:
+                return
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            if value is None:
+                value = typ()
+            # See _GeneratorContextManager.__exit__ for comments on subtleties
+            # in this implementation
+            try:
+                await self.gen.athrow(typ, value, traceback)
+                raise RuntimeError("generator didn't stop after athrow()")
+            except StopAsyncIteration as exc:
+                return exc is not value
+            except RuntimeError as exc:
+                if exc is value:
+                    return False
+                if isinstance(value, (StopIteration, StopAsyncIteration)):
+                    if exc.__cause__ is value:
+                        return False
+                raise
+            except BaseException as exc:
+                if exc is not value:
+                    raise
+
+
+# using the vendored version in all cases at the moment to establish
+# full test coverage
+def asynccontextmanager(func):
+    @wraps(func)
+    def helper(*args, **kwds):
+        return _AsyncGeneratorContextManager(func, args, kwds)
+
+    return helper
index 4635473191fe8cb25b16bc2ce7695c7719761792..5d15e06dec0ce3bc088257a379898cb0910aead0 100644 (file)
@@ -19,6 +19,7 @@ if compat.py3k:
             _util_async_run_coroutine_function,
         )  # noqa F401, E501
         from ._concurrency_py3k import asyncio  # noqa F401
+        from ._concurrency_py3k import asynccontextmanager
 
 if not have_greenlet:
 
@@ -57,3 +58,6 @@ if not have_greenlet:
 
     def _util_async_run_coroutine_function(fn, *arg, **kw):  # noqa F81
         _not_implemented()
+
+    def asynccontextmanager(fn, *arg, **kw):  # noqa F81
+        _not_implemented()