From: Daniele Varrazzo Date: Sun, 15 Nov 2020 04:00:27 +0000 (+0000) Subject: Added asynccontextmanager for Python 3.6 X-Git-Tag: 3.0.dev0~351^2~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e932c2e039b5838178e2280c187a62ca945c7697;p=thirdparty%2Fpsycopg.git Added asynccontextmanager for Python 3.6 --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 4cef7efae..f3fdddbe0 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -4,11 +4,12 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team +import sys from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List from typing import Mapping, Optional, Sequence, Type, TYPE_CHECKING, Union from operator import attrgetter -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from . import errors as e from . import pq @@ -18,6 +19,11 @@ from .copy import Copy, AsyncCopy from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen from .utils.queries import PostgresQuery +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager +else: + from .utils.context import asynccontextmanager + if TYPE_CHECKING: from .proto import Transformer from .pq.proto import PGconn, PGresult diff --git a/psycopg3/psycopg3/utils/context.py b/psycopg3/psycopg3/utils/context.py new file mode 100644 index 000000000..906fa0f0d --- /dev/null +++ b/psycopg3/psycopg3/utils/context.py @@ -0,0 +1,59 @@ +# type: ignore +""" +asynccontextmanager implementation for Python < 3.7 +""" + +from functools import wraps + + +def asynccontextmanager(func): + @wraps(func) + def helper(*args, **kwds): + return _AsyncGeneratorContextManager(func, args, kwds) + + return helper + + +class _AsyncGeneratorContextManager: + """Helper for @asynccontextmanager.""" + + def __init__(self, func, args, kwds): + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = func, args, kwds + doc = getattr(func, "__doc__", None) + if doc is None: + doc = type(self).__doc__ + self.__doc__ = doc + + async def __aenter__(self): + try: + return await self.gen.__anext__() + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + async def __aexit__(self, typ, value, traceback): + if typ is None: + try: + await self.gen.__anext__() + except StopAsyncIteration: + return + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + value = typ() + try: + await self.gen.athrow(typ, value, traceback) + raise RuntimeError("generator didn't stop after athrow()") + except StopAsyncIteration as exc: + return exc is not value + except RuntimeError as exc: + if exc is value: + return False + if isinstance(value, (StopIteration, StopAsyncIteration)): + if exc.__cause__ is value: + return False + raise + except BaseException as exc: + if exc is not value: + raise