]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added asynccontextmanager for Python 3.6
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 04:00:27 +0000 (04:00 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 04:00:27 +0000 (04:00 +0000)
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/utils/context.py [new file with mode: 0644]

index 4cef7efaed341d89951f9c72fca9d3997f7484ed..f3fdddbe0ca8f800c42a335475e6b2c1cb19369a 100644 (file)
@@ -4,11 +4,12 @@ psycopg3 cursor objects
 
 # Copyright (C) 2020 The Psycopg Team
 
+import sys
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
 from typing import Mapping, Optional, Sequence, Type, TYPE_CHECKING, Union
 from operator import attrgetter
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import contextmanager
 
 from . import errors as e
 from . import pq
@@ -18,6 +19,11 @@ from .copy import Copy, AsyncCopy
 from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
 from .utils.queries import PostgresQuery
 
+if sys.version_info >= (3, 7):
+    from contextlib import asynccontextmanager
+else:
+    from .utils.context import asynccontextmanager
+
 if TYPE_CHECKING:
     from .proto import Transformer
     from .pq.proto import PGconn, PGresult
diff --git a/psycopg3/psycopg3/utils/context.py b/psycopg3/psycopg3/utils/context.py
new file mode 100644 (file)
index 0000000..906fa0f
--- /dev/null
@@ -0,0 +1,59 @@
+# type: ignore
+"""
+asynccontextmanager implementation for Python < 3.7
+"""
+
+from functools import wraps
+
+
+def asynccontextmanager(func):
+    @wraps(func)
+    def helper(*args, **kwds):
+        return _AsyncGeneratorContextManager(func, args, kwds)
+
+    return helper
+
+
+class _AsyncGeneratorContextManager:
+    """Helper for @asynccontextmanager."""
+
+    def __init__(self, func, args, kwds):
+        self.gen = func(*args, **kwds)
+        self.func, self.args, self.kwds = func, args, kwds
+        doc = getattr(func, "__doc__", None)
+        if doc is None:
+            doc = type(self).__doc__
+        self.__doc__ = doc
+
+    async def __aenter__(self):
+        try:
+            return await self.gen.__anext__()
+        except StopAsyncIteration:
+            raise RuntimeError("generator didn't yield") from None
+
+    async def __aexit__(self, typ, value, traceback):
+        if typ is None:
+            try:
+                await self.gen.__anext__()
+            except StopAsyncIteration:
+                return
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            if value is None:
+                value = typ()
+            try:
+                await self.gen.athrow(typ, value, traceback)
+                raise RuntimeError("generator didn't stop after athrow()")
+            except StopAsyncIteration as exc:
+                return exc is not value
+            except RuntimeError as exc:
+                if exc is value:
+                    return False
+                if isinstance(value, (StopIteration, StopAsyncIteration)):
+                    if exc.__cause__ is value:
+                        return False
+                raise
+            except BaseException as exc:
+                if exc is not value:
+                    raise