From: Daniele Varrazzo Date: Fri, 30 Apr 2021 00:36:33 +0000 (+0200) Subject: Drop RowConn from proto X-Git-Tag: 3.0.dev0~63^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7d20123cc4e2bfc612a1d5998eea664d053564dd;p=thirdparty%2Fpsycopg.git Drop RowConn from proto Use a more local CursorRow definition. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 5d77e69d0..fff67965c 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -10,7 +10,8 @@ import warnings import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List -from typing import NamedTuple, Optional, Type, Union, TYPE_CHECKING, overload +from typing import NamedTuple, Optional, Type, TypeVar, Union +from typing import overload, TYPE_CHECKING from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager @@ -24,7 +25,7 @@ from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable from .rows import tuple_row, TupleRow from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn -from .proto import Query, Row, RowConn, RowFactory, RV +from .proto import Query, Row, RowFactory, RV from .cursor import Cursor, AsyncCursor from .conninfo import make_conninfo, ConnectionInfo from .generators import notifies @@ -38,6 +39,10 @@ logger = logging.getLogger("psycopg3") connect: Callable[[str], PQGenConn["PGconn"]] execute: Callable[["PGconn"], PQGen[List["PGresult"]]] +# Row Type variable for Cursor (when it needs to be distinguished from the +# connection's one) +CursorRow = TypeVar("CursorRow") + if TYPE_CHECKING: from .pq.proto import PGconn, PGresult from .pool.base import BasePool @@ -74,7 +79,7 @@ NoticeHandler = Callable[[e.Diagnostic], None] NotifyHandler = Callable[[Notify], None] -class BaseConnection(AdaptContext, Generic[RowConn]): +class BaseConnection(AdaptContext, Generic[Row]): """ Base class for different types of connections. @@ -98,7 +103,7 @@ class BaseConnection(AdaptContext, Generic[RowConn]): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]): + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): self.pgconn = pgconn # TODO: document this self._row_factory = row_factory self._autocommit = False @@ -224,17 +229,17 @@ class BaseConnection(AdaptContext, Generic[RowConn]): return self._adapters @property - def connection(self) -> "BaseConnection[RowConn]": + def connection(self) -> "BaseConnection[Row]": # implement the AdaptContext protocol return self @property - def row_factory(self) -> RowFactory[RowConn]: + def row_factory(self) -> RowFactory[Row]: """Writable attribute to control how result rows are formed.""" return self._row_factory @row_factory.setter - def row_factory(self, row_factory: RowFactory[RowConn]) -> None: + def row_factory(self, row_factory: RowFactory[Row]) -> None: self._row_factory = row_factory def fileno(self) -> int: @@ -265,7 +270,7 @@ class BaseConnection(AdaptContext, Generic[RowConn]): @staticmethod def _notice_handler( - wself: "ReferenceType[BaseConnection[RowConn]]", res: "PGresult" + wself: "ReferenceType[BaseConnection[Row]]", res: "PGresult" ) -> None: self = wself() if not (self and self._notice_handler): @@ -294,7 +299,7 @@ class BaseConnection(AdaptContext, Generic[RowConn]): @staticmethod def _notify_handler( - wself: "ReferenceType[BaseConnection[RowConn]]", pgn: pq.PGnotify + wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify ) -> None: self = wself() if not (self and self._notify_handlers): @@ -435,14 +440,14 @@ class BaseConnection(AdaptContext, Generic[RowConn]): yield from self._exec_command(b"rollback") -class Connection(BaseConnection[RowConn]): +class Connection(BaseConnection[Row]): """ Wrapper for a connection to the database. """ __module__ = "psycopg3" - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]): + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): super().__init__(pgconn, row_factory) self.lock = threading.Lock() @@ -453,9 +458,9 @@ class Connection(BaseConnection[RowConn]): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory[RowConn], + row_factory: RowFactory[Row], **kwargs: Union[None, int, str], - ) -> "Connection[RowConn]": + ) -> "Connection[Row]": ... @overload @@ -475,7 +480,7 @@ class Connection(BaseConnection[RowConn]): conninfo: str = "", *, autocommit: bool = False, - row_factory: Optional[RowFactory[RowConn]] = None, + row_factory: Optional[RowFactory[Row]] = None, **kwargs: Any, ) -> "Connection[Any]": """ @@ -492,7 +497,7 @@ class Connection(BaseConnection[RowConn]): ) ) - def __enter__(self) -> "Connection[RowConn]": + def __enter__(self) -> "Connection[Row]": return self def __exit__( @@ -529,25 +534,27 @@ class Connection(BaseConnection[RowConn]): self.pgconn.finish() @overload - def cursor(self, *, binary: bool = False) -> Cursor[RowConn]: + def cursor(self, *, binary: bool = False) -> Cursor[Row]: ... @overload def cursor( - self, *, binary: bool = False, row_factory: RowFactory[Row] - ) -> Cursor[Row]: + self, *, binary: bool = False, row_factory: RowFactory[CursorRow] + ) -> Cursor[CursorRow]: ... @overload - def cursor( - self, name: str, *, binary: bool = False - ) -> ServerCursor[RowConn]: + def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Row]: ... @overload def cursor( - self, name: str, *, binary: bool = False, row_factory: RowFactory[Row] - ) -> ServerCursor[Row]: + self, + name: str, + *, + binary: bool = False, + row_factory: RowFactory[CursorRow], + ) -> ServerCursor[CursorRow]: ... def cursor( @@ -576,7 +583,7 @@ class Connection(BaseConnection[RowConn]): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> Cursor[RowConn]: + ) -> Cursor[Row]: """Execute a query and return a cursor to read its results.""" cur = self.cursor() try: @@ -651,14 +658,14 @@ class Connection(BaseConnection[RowConn]): self.wait(self._set_client_encoding_gen(name)) -class AsyncConnection(BaseConnection[RowConn]): +class AsyncConnection(BaseConnection[Row]): """ Asynchronous wrapper for a connection to the database. """ __module__ = "psycopg3" - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]): + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Row]): super().__init__(pgconn, row_factory) self.lock = asyncio.Lock() @@ -669,9 +676,9 @@ class AsyncConnection(BaseConnection[RowConn]): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory[RowConn], + row_factory: RowFactory[Row], **kwargs: Union[None, int, str], - ) -> "AsyncConnection[RowConn]": + ) -> "AsyncConnection[Row]": ... @overload @@ -691,7 +698,7 @@ class AsyncConnection(BaseConnection[RowConn]): conninfo: str = "", *, autocommit: bool = False, - row_factory: Optional[RowFactory[RowConn]] = None, + row_factory: Optional[RowFactory[Row]] = None, **kwargs: Any, ) -> "AsyncConnection[Any]": return await cls._wait_conn( @@ -703,7 +710,7 @@ class AsyncConnection(BaseConnection[RowConn]): ) ) - async def __aenter__(self) -> "AsyncConnection[RowConn]": + async def __aenter__(self) -> "AsyncConnection[Row]": return self async def __aexit__( @@ -739,25 +746,29 @@ class AsyncConnection(BaseConnection[RowConn]): self.pgconn.finish() @overload - def cursor(self, *, binary: bool = False) -> AsyncCursor[RowConn]: + def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ... @overload def cursor( - self, *, binary: bool = False, row_factory: RowFactory[Row] - ) -> AsyncCursor[Row]: + self, *, binary: bool = False, row_factory: RowFactory[CursorRow] + ) -> AsyncCursor[CursorRow]: ... @overload def cursor( self, name: str, *, binary: bool = False - ) -> AsyncServerCursor[RowConn]: + ) -> AsyncServerCursor[Row]: ... @overload def cursor( - self, name: str, *, binary: bool = False, row_factory: RowFactory[Row] - ) -> AsyncServerCursor[Row]: + self, + name: str, + *, + binary: bool = False, + row_factory: RowFactory[CursorRow], + ) -> AsyncServerCursor[CursorRow]: ... def cursor( @@ -786,7 +797,7 @@ class AsyncConnection(BaseConnection[RowConn]): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> AsyncCursor[RowConn]: + ) -> AsyncCursor[Row]: cur = self.cursor() try: return await cur.execute(query, params, prepare=prepare) diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index 657780967..6f6652a98 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -49,8 +49,6 @@ Wait states. Row = TypeVar("Row") Row_co = TypeVar("Row_co", covariant=True) -# Type variable for Connection (other are for Cursor). -RowConn = TypeVar("RowConn") class RowMaker(Protocol[Row_co]):