# Copyright (C) 2020 The Psycopg Team
+import sys
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
from typing import Mapping, Optional, Sequence, Type, TYPE_CHECKING, Union
from operator import attrgetter
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import contextmanager
from . import errors as e
from . import pq
from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
from .utils.queries import PostgresQuery
+if sys.version_info >= (3, 7):
+ from contextlib import asynccontextmanager
+else:
+ from .utils.context import asynccontextmanager
+
if TYPE_CHECKING:
from .proto import Transformer
from .pq.proto import PGconn, PGresult
--- /dev/null
+# type: ignore
+"""
+asynccontextmanager implementation for Python < 3.7
+"""
+
+from functools import wraps
+
+
+def asynccontextmanager(func):
+ @wraps(func)
+ def helper(*args, **kwds):
+ return _AsyncGeneratorContextManager(func, args, kwds)
+
+ return helper
+
+
+class _AsyncGeneratorContextManager:
+ """Helper for @asynccontextmanager."""
+
+ def __init__(self, func, args, kwds):
+ self.gen = func(*args, **kwds)
+ self.func, self.args, self.kwds = func, args, kwds
+ doc = getattr(func, "__doc__", None)
+ if doc is None:
+ doc = type(self).__doc__
+ self.__doc__ = doc
+
+ async def __aenter__(self):
+ try:
+ return await self.gen.__anext__()
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ async def __aexit__(self, typ, value, traceback):
+ if typ is None:
+ try:
+ await self.gen.__anext__()
+ except StopAsyncIteration:
+ return
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ value = typ()
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ raise RuntimeError("generator didn't stop after athrow()")
+ except StopAsyncIteration as exc:
+ return exc is not value
+ except RuntimeError as exc:
+ if exc is value:
+ return False
+ if isinstance(value, (StopIteration, StopAsyncIteration)):
+ if exc.__cause__ is value:
+ return False
+ raise
+ except BaseException as exc:
+ if exc is not value:
+ raise