From: Daniele Varrazzo Date: Thu, 12 Nov 2020 19:09:42 +0000 (+0000) Subject: Using generics to describe sync/async types X-Git-Tag: 3.0.dev0~371 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4d857acc2d72bf85d5fb4da14a9fcd475b76c004;p=thirdparty%2Fpsycopg.git Using generics to describe sync/async types --- diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 93f26a774..7055232d3 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -4,7 +4,7 @@ Entry point into the adaptation system. # Copyright (C) 2020 The Psycopg Team -from typing import Any, Callable, Optional, Type, Union +from typing import Any, cast, Callable, Optional, Type, Union from . import pq from . import proto @@ -139,7 +139,7 @@ def _connection_from_context( elif isinstance(context, BaseConnection): return context elif isinstance(context, BaseCursor): - return context.connection + return cast(BaseConnection, context.connection) elif isinstance(context, Transformer): return context.connection else: diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 7737098e2..0ebbf00c4 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -9,16 +9,16 @@ import asyncio import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple -from typing import Optional, Type, cast +from typing import Optional, Type, TYPE_CHECKING, Union from weakref import ref, ReferenceType from functools import partial from . import pq -from . import proto from . import cursor from . import errors as e from . import encodings from .pq import TransactionStatus, ExecStatus +from .proto import DumpersMap, LoadersMap, PQGen, RV from .waiting import wait, wait_async from .conninfo import make_conninfo from .generators import notifies @@ -26,8 +26,12 @@ from .generators import notifies logger = logging.getLogger(__name__) package_logger = logging.getLogger("psycopg3") -connect: Callable[[str], proto.PQGen[pq.proto.PGconn]] -execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]] +connect: Callable[[str], PQGen["PGconn"]] +execute: Callable[["PGconn"], PQGen[List["PGresult"]]] + +if TYPE_CHECKING: + from .pq.proto import PGconn, PGresult + from .cursor import Cursor, AsyncCursor if pq.__impl__ == "c": from psycopg3_c import _psycopg3 @@ -83,12 +87,13 @@ class BaseConnection: ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - def __init__(self, pgconn: pq.proto.PGconn): + cursor_factory: Union[Type["Cursor"], Type["AsyncCursor"]] + + def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn # TODO: document this - self.cursor_factory = cursor.BaseCursor self._autocommit = False - self.dumpers: proto.DumpersMap = {} - self.loaders: proto.LoadersMap = {} + self.dumpers: DumpersMap = {} + self.loaders: LoadersMap = {} self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] @@ -122,13 +127,6 @@ class BaseConnection: ) self._autocommit = value - def _cursor( - self, name: str = "", format: pq.Format = pq.Format.TEXT - ) -> cursor.BaseCursor: - if name: - raise NotImplementedError - return self.cursor_factory(self, format=format) - @property def client_encoding(self) -> str: """The Python codec name of the connection's client encoding.""" @@ -161,7 +159,7 @@ class BaseConnection: @staticmethod def _notice_handler( - wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult + wself: "ReferenceType[BaseConnection]", res: "PGresult" ) -> None: self = wself() if not (self and self._notice_handler): @@ -209,7 +207,7 @@ class Connection(BaseConnection): cursor_factory: Type[cursor.Cursor] - def __init__(self, pgconn: pq.proto.PGconn): + def __init__(self, pgconn: "PGconn"): super().__init__(pgconn) self.lock = threading.Lock() self.cursor_factory = cursor.Cursor @@ -257,8 +255,10 @@ class Connection(BaseConnection): """ Return a new `Cursor` to send commands and queries to the connection. """ - cur = self._cursor(name, format=format) - return cast(cursor.Cursor, cur) + if name: + raise NotImplementedError + + return self.cursor_factory(self, format=format) def _start_query(self) -> None: # the function is meant to be called by a cursor once the lock is taken @@ -301,9 +301,7 @@ class Connection(BaseConnection): ) @classmethod - def wait( - cls, gen: proto.PQGen[proto.RV], timeout: Optional[float] = 0.1 - ) -> proto.RV: + def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: return wait(gen, timeout=timeout) def _set_client_encoding(self, name: str) -> None: @@ -345,7 +343,7 @@ class AsyncConnection(BaseConnection): cursor_factory: Type[cursor.AsyncCursor] - def __init__(self, pgconn: pq.proto.PGconn): + def __init__(self, pgconn: "PGconn"): super().__init__(pgconn) self.lock = asyncio.Lock() self.cursor_factory = cursor.AsyncCursor @@ -386,8 +384,10 @@ class AsyncConnection(BaseConnection): """ Return a new `AsyncCursor` to send commands and queries to the connection. """ - cur = self._cursor(name, format=format) - return cast(cursor.AsyncCursor, cur) + if name: + raise NotImplementedError + + return self.cursor_factory(self, format=format) async def _start_query(self) -> None: # the function is meant to be called by a cursor once the lock is taken @@ -428,7 +428,7 @@ class AsyncConnection(BaseConnection): ) @classmethod - async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV: + async def wait(cls, gen: PQGen[RV]) -> RV: return await wait_async(gen) def _set_client_encoding(self, name: str) -> None: diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index f80900787..c62012870 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -6,36 +6,35 @@ psycopg3 copy support import re import struct -from typing import TYPE_CHECKING, AsyncIterator, Iterator +from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic from typing import Any, Dict, List, Match, Optional, Sequence, Type, Union from types import TracebackType -from . import pq -from .proto import AdaptContext +from .pq import Format +from .proto import ConnectionType, Transformer from .generators import copy_from, copy_to, copy_end if TYPE_CHECKING: - from .connection import BaseConnection, Connection, AsyncConnection + from .pq.proto import PGresult + from .connection import Connection, AsyncConnection # noqa: F401 -class BaseCopy: +class BaseCopy(Generic[ConnectionType]): def __init__( self, - context: AdaptContext, - result: Optional[pq.proto.PGresult], - format: pq.Format = pq.Format.TEXT, + connection: ConnectionType, + transformer: Transformer, + result: "PGresult", ): - from .adapt import Transformer - - self._connection: Optional["BaseConnection"] = None - self._transformer = Transformer(context) - self.format = format + self.connection = connection + self._transformer = transformer self.pgresult = result + self.format = result.binary_tuples self._first_row = True self._finished = False self._encoding: str = "" - if format == pq.Format.TEXT: + if self.format == Format.TEXT: self._format_row = self._format_row_text else: self._format_row = self._format_row_binary @@ -45,22 +44,11 @@ class BaseCopy: return self._finished @property - def connection(self) -> "BaseConnection": - if self._connection: - return self._connection - - self._connection = conn = self._transformer.connection - if conn: - return conn - - raise ValueError("no connection available") - - @property - def pgresult(self) -> Optional[pq.proto.PGresult]: + def pgresult(self) -> Optional["PGresult"]: return self._pgresult @pgresult.setter - def pgresult(self, result: Optional[pq.proto.PGresult]) -> None: + def pgresult(self, result: Optional["PGresult"]) -> None: self._pgresult = result self._transformer.pgresult = result @@ -74,7 +62,7 @@ class BaseCopy: if ( self.pgresult is None - or self.pgresult.binary_tuples == pq.Format.BINARY + or self.pgresult.binary_tuples == Format.BINARY ): raise TypeError( "cannot copy str data in binary mode: use bytes instead" @@ -151,15 +139,7 @@ def _bsrepl_sub( _bsrepl_re = re.compile(b"[\b\t\n\v\f\r\\\\]") -class Copy(BaseCopy): - _connection: Optional["Connection"] - - @property - def connection(self) -> "Connection": - # TODO: mypy error: "Callable[[BaseCopy], BaseConnection]" has no - # attribute "fget" - return BaseCopy.connection.fget(self) # type: ignore - +class Copy(BaseCopy["Connection"]): def read(self) -> Optional[bytes]: if self._finished: return None @@ -195,7 +175,7 @@ class Copy(BaseCopy): exc_tb: Optional[TracebackType], ) -> None: if exc_val is None: - if self.format == pq.Format.BINARY and not self._first_row: + if self.format == Format.BINARY and not self._first_row: # send EOF only if we copied binary rows (_first_row is False) self.write(b"\xff\xff") self.finish() @@ -210,13 +190,7 @@ class Copy(BaseCopy): yield data -class AsyncCopy(BaseCopy): - _connection: Optional["AsyncConnection"] - - @property - def connection(self) -> "AsyncConnection": - return BaseCopy.connection.fget(self) # type: ignore - +class AsyncCopy(BaseCopy["AsyncConnection"]): async def read(self) -> Optional[bytes]: if self._finished: return None @@ -252,7 +226,7 @@ class AsyncCopy(BaseCopy): exc_tb: Optional[TracebackType], ) -> None: if exc_val is None: - if self.format == pq.Format.BINARY and not self._first_row: + if self.format == Format.BINARY and not self._first_row: # send EOF only if we copied binary rows (_first_row is False) await self.write(b"\xff\xff") await self.finish() diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 18c76afd8..66a490f78 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -5,23 +5,32 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team from types import TracebackType -from typing import Any, AsyncIterator, Callable, Iterator, List, Mapping +from typing import ( + Any, + AsyncIterator, + Callable, + Generic, + Iterator, + List, + Mapping, +) from typing import Optional, Sequence, Type, TYPE_CHECKING, Union from operator import attrgetter from . import errors as e from . import pq from . import sql -from . import proto from .oids import builtins from .copy import Copy, AsyncCopy -from .proto import Query, Params, DumpersMap, LoadersMap, PQGen +from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen from .utils.queries import PostgresQuery if TYPE_CHECKING: - from .connection import BaseConnection, Connection, AsyncConnection + from .proto import Transformer + from .pq.proto import PGconn, PGresult + from .connection import Connection, AsyncConnection # noqa: F401 -execute: Callable[[pq.proto.PGconn], PQGen[List[pq.proto.PGresult]]] +execute: Callable[["PGconn"], PQGen[List["PGresult"]]] if pq.__impl__ == "c": from psycopg3_c import _psycopg3 @@ -35,7 +44,7 @@ else: class Column(Sequence[Any]): - def __init__(self, pgresult: pq.proto.PGresult, index: int, encoding: str): + def __init__(self, pgresult: "PGresult", index: int, encoding: str): self._pgresult = pgresult self._index = index self._encoding = encoding @@ -150,13 +159,15 @@ class Column(Sequence[Any]): return None -class BaseCursor: +class BaseCursor(Generic[ConnectionType]): ExecStatus = pq.ExecStatus - _transformer: proto.Transformer + _transformer: "Transformer" def __init__( - self, connection: "BaseConnection", format: pq.Format = pq.Format.TEXT + self, + connection: ConnectionType, + format: pq.Format = pq.Format.TEXT, ): self.connection = connection self.format = format @@ -167,7 +178,7 @@ class BaseCursor: self._closed = False def _reset(self) -> None: - self._results: List[pq.proto.PGresult] = [] + self._results: List["PGresult"] = [] self.pgresult = None self._pos = 0 self._iresult = 0 @@ -185,12 +196,12 @@ class BaseCursor: return res.status if res else None @property - def pgresult(self) -> Optional[pq.proto.PGresult]: + def pgresult(self) -> Optional["PGresult"]: """The `~psycopg3.pq.PGresult` exposed by the cursor.""" return self._pgresult @pgresult.setter - def pgresult(self, result: Optional[pq.proto.PGresult]) -> None: + def pgresult(self, result: Optional["PGresult"]) -> None: self._pgresult = result if result and self._transformer: self._transformer.pgresult = result @@ -236,7 +247,7 @@ class BaseCursor: return None def _start_query(self) -> None: - from .adapt import Transformer + from . import adapt if self.closed: raise e.InterfaceError("the cursor is closed") @@ -251,7 +262,7 @@ class BaseCursor: ) self._reset() - self._transformer = Transformer(self) + self._transformer = adapt.Transformer(self) def _execute_send( self, query: Query, vars: Optional[Params], no_pqexec: bool = False @@ -275,7 +286,7 @@ class BaseCursor: # one query in one go self.connection.pgconn.send_query(pgq.query) - def _execute_results(self, results: Sequence[pq.proto.PGresult]) -> None: + def _execute_results(self, results: Sequence["PGresult"]) -> None: """ Implement part of execute() after waiting common to sync and async """ @@ -393,9 +404,7 @@ class BaseCursor: qparts.append(sql.SQL(")")) return sql.Composed(qparts) - def _check_copy_results( - self, results: Sequence[pq.proto.PGresult] - ) -> None: + def _check_copy_results(self, results: Sequence["PGresult"]) -> None: """ Check that the value returned in a copy() operation is a legit COPY. """ @@ -419,14 +428,7 @@ class BaseCursor: ) -class Cursor(BaseCursor): - connection: "Connection" - - def __init__( - self, connection: "Connection", format: pq.Format = pq.Format.TEXT - ): - super().__init__(connection, format=format) - +class Cursor(BaseCursor["Connection"]): def __enter__(self) -> "Cursor": return self @@ -563,22 +565,16 @@ class Cursor(BaseCursor): self._execute_send(statement, vars, no_pqexec=True) gen = execute(self.connection.pgconn) results = self.connection.wait(gen) - tx = self._transformer self._check_copy_results(results) return Copy( - context=tx, result=results[0], format=results[0].binary_tuples + connection=self.connection, + transformer=self._transformer, + result=results[0], ) -class AsyncCursor(BaseCursor): - connection: "AsyncConnection" - - def __init__( - self, connection: "AsyncConnection", format: pq.Format = pq.Format.TEXT - ): - super().__init__(connection, format=format) - +class AsyncCursor(BaseCursor["AsyncConnection"]): async def __aenter__(self) -> "AsyncCursor": return self @@ -700,11 +696,12 @@ class AsyncCursor(BaseCursor): self._execute_send(statement, vars, no_pqexec=True) gen = execute(self.connection.pgconn) results = await self.connection.wait(gen) - tx = self._transformer self._check_copy_results(results) return AsyncCopy( - context=tx, result=results[0], format=results[0].binary_tuples + connection=self.connection, + transformer=self._transformer, + result=results[0], ) diff --git a/psycopg3/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/psycopg3/pq/_pq_ctypes.pyi index ea443727b..8fb7a1350 100644 --- a/psycopg3/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/psycopg3/pq/_pq_ctypes.pyi @@ -4,7 +4,7 @@ types stub for ctypes functions # Copyright (C) 2020 The Psycopg Team -from typing import Any, Callable, Optional, Sequence, NewType +from typing import Any, Callable, Optional, Sequence from ctypes import Array, pointer from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index c0bdfd664..1e26e528a 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: Query = Union[str, bytes, "Composable"] Params = Union[Sequence[Any], Mapping[str, Any]] +ConnectionType = TypeVar("ConnectionType", bound="BaseConnection") # Waiting protocol types