From 9e4a835449ae40d9ae9b08cfd3584262ccf576ea Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Tue, 20 Apr 2021 10:30:03 +0200 Subject: [PATCH] Make Connection generic on Row We use a type variable 'RowConn' for Connection that is distinct from 'Row' that is used on Cursor side in order to reflect the possibility to have distinct row factories on connection and cursor sides. In order to avoid "propagation" of the Row type variable of Connection classes, we use Any everywhere that variable is not used (namely, everywhere outside the connection module). The typing_example.py almost works: connect(row_factory=...) returns a typed Connection, but only connect() still returns a Connection[Any]. --- psycopg3/psycopg3/_transform.py | 2 +- psycopg3/psycopg3/_typeinfo.py | 6 ++- psycopg3/psycopg3/adapt.py | 6 +-- psycopg3/psycopg3/connection.py | 60 ++++++++++++++-------------- psycopg3/psycopg3/copy.py | 4 +- psycopg3/psycopg3/cursor.py | 4 +- psycopg3/psycopg3/pool/async_pool.py | 41 +++++++++++-------- psycopg3/psycopg3/pool/pool.py | 35 ++++++++-------- psycopg3/psycopg3/proto.py | 8 ++-- psycopg3/psycopg3/server_cursor.py | 13 +++--- psycopg3/psycopg3/transaction.py | 9 +++-- psycopg3_c/psycopg3_c/_psycopg3.pyi | 2 +- tests/typing_example.py | 25 ++++++------ 13 files changed, 115 insertions(+), 100 deletions(-) diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py index ec0119790..b6ed255c7 100644 --- a/psycopg3/psycopg3/_transform.py +++ b/psycopg3/psycopg3/_transform.py @@ -66,7 +66,7 @@ class Transformer(AdaptContext): self._row_loaders: List[LoadFunc] = [] @property - def connection(self) -> Optional["BaseConnection"]: + def connection(self) -> Optional["BaseConnection[Any]"]: return self._conn @property diff --git a/psycopg3/psycopg3/_typeinfo.py b/psycopg3/psycopg3/_typeinfo.py index 338b9d57d..8743d90f5 100644 --- a/psycopg3/psycopg3/_typeinfo.py +++ b/psycopg3/psycopg3/_typeinfo.py @@ -55,7 +55,7 @@ class TypeInfo: @classmethod def fetch( - cls: Type[T], conn: "Connection", name: Union[str, "Identifier"] + cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"] ) -> Optional[T]: """ Query a system catalog to read information about a type. @@ -77,7 +77,9 @@ class TypeInfo: @classmethod async def fetch_async( - cls: Type[T], conn: "AsyncConnection", name: Union[str, "Identifier"] + cls: Type[T], + conn: "AsyncConnection[Any]", + name: Union[str, "Identifier"], ) -> Optional[T]: """ Query a system catalog to read information about a type. diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 62f08e263..6d6f5c189 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -35,7 +35,7 @@ class Dumper(ABC): def __init__(self, cls: type, context: Optional[AdaptContext] = None): self.cls = cls - self.connection: Optional["BaseConnection"] = ( + self.connection: Optional["BaseConnection[Any]"] = ( context.connection if context else None ) @@ -109,7 +109,7 @@ class Loader(ABC): def __init__(self, oid: int, context: Optional[AdaptContext] = None): self.oid = oid - self.connection: Optional["BaseConnection"] = ( + self.connection: Optional["BaseConnection[Any]"] = ( context.connection if context else None ) @@ -170,7 +170,7 @@ class AdaptersMap(AdaptContext): return self @property - def connection(self) -> Optional["BaseConnection"]: + def connection(self) -> Optional["BaseConnection[Any]"]: return None def register_dumper( diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 96440bc71..7c3135fc0 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -9,8 +9,8 @@ import logging import warnings import threading from types import TracebackType -from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple -from typing import Optional, overload, Type, Union, TYPE_CHECKING +from typing import Any, AsyncIterator, Callable, Generic, Iterator, List +from typing import NamedTuple, Optional, Type, Union, TYPE_CHECKING, overload from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager @@ -23,8 +23,8 @@ from . import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable from .rows import tuple_row -from .proto import PQGen, PQGenConn, RV, Row, RowFactory, Query, Params -from .proto import AdaptContext, ConnectionType +from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn +from .proto import Query, Row, RowConn, RowFactory, RV from .cursor import Cursor, AsyncCursor from .conninfo import make_conninfo, ConnectionInfo from .generators import notifies @@ -74,7 +74,7 @@ NoticeHandler = Callable[[e.Diagnostic], None] NotifyHandler = Callable[[Notify], None] -class BaseConnection(AdaptContext): +class BaseConnection(AdaptContext, Generic[RowConn]): """ Base class for different types of connections. @@ -98,7 +98,7 @@ class BaseConnection(AdaptContext): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Any]): + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]): self.pgconn = pgconn # TODO: document this self._row_factory = row_factory self._autocommit = False @@ -224,17 +224,17 @@ class BaseConnection(AdaptContext): return self._adapters @property - def connection(self) -> "BaseConnection": + def connection(self) -> "BaseConnection[RowConn]": # implement the AdaptContext protocol return self @property - def row_factory(self) -> RowFactory[Any]: + def row_factory(self) -> RowFactory[RowConn]: """Writable attribute to control how result rows are formed.""" return self._row_factory @row_factory.setter - def row_factory(self, row_factory: RowFactory[Any]) -> None: + def row_factory(self, row_factory: RowFactory[RowConn]) -> None: self._row_factory = row_factory def fileno(self) -> int: @@ -265,7 +265,7 @@ class BaseConnection(AdaptContext): @staticmethod def _notice_handler( - wself: "ReferenceType[BaseConnection]", res: "PGresult" + wself: "ReferenceType[BaseConnection[RowConn]]", res: "PGresult" ) -> None: self = wself() if not (self and self._notice_handler): @@ -294,7 +294,7 @@ class BaseConnection(AdaptContext): @staticmethod def _notify_handler( - wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify + wself: "ReferenceType[BaseConnection[RowConn]]", pgn: pq.PGnotify ) -> None: self = wself() if not (self and self._notify_handlers): @@ -435,14 +435,14 @@ class BaseConnection(AdaptContext): yield from self._exec_command(b"rollback") -class Connection(BaseConnection): +class Connection(BaseConnection[RowConn]): """ Wrapper for a connection to the database. """ __module__ = "psycopg3" - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Any]): + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]): super().__init__(pgconn, row_factory) self.lock = threading.Lock() @@ -452,9 +452,9 @@ class Connection(BaseConnection): conninfo: str = "", *, autocommit: bool = False, - row_factory: Optional[RowFactory[Any]] = None, + row_factory: Optional[RowFactory[RowConn]] = None, **kwargs: Any, - ) -> "Connection": + ) -> "Connection[RowConn]": """ Connect to a database server and return a new `Connection` instance. @@ -469,7 +469,7 @@ class Connection(BaseConnection): ) ) - def __enter__(self) -> "Connection": + def __enter__(self) -> "Connection[RowConn]": return self def __exit__( @@ -506,7 +506,7 @@ class Connection(BaseConnection): self.pgconn.finish() @overload - def cursor(self, *, binary: bool = False) -> Cursor[Any]: + def cursor(self, *, binary: bool = False) -> Cursor[RowConn]: ... @overload @@ -516,7 +516,9 @@ class Connection(BaseConnection): ... @overload - def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Any]: + def cursor( + self, name: str, *, binary: bool = False + ) -> ServerCursor[RowConn]: ... @overload @@ -551,9 +553,9 @@ class Connection(BaseConnection): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> Cursor[Any]: + ) -> Cursor[RowConn]: """Execute a query and return a cursor to read its results.""" - cur: Cursor[Any] = self.cursor() + cur = self.cursor() try: return cur.execute(query, params, prepare=prepare) except e.Error as ex: @@ -626,14 +628,14 @@ class Connection(BaseConnection): self.wait(self._set_client_encoding_gen(name)) -class AsyncConnection(BaseConnection): +class AsyncConnection(BaseConnection[RowConn]): """ Asynchronous wrapper for a connection to the database. """ __module__ = "psycopg3" - def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Any]): + def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]): super().__init__(pgconn, row_factory) self.lock = asyncio.Lock() @@ -643,9 +645,9 @@ class AsyncConnection(BaseConnection): conninfo: str = "", *, autocommit: bool = False, - row_factory: Optional[RowFactory[Any]] = None, + row_factory: Optional[RowFactory[RowConn]] = None, **kwargs: Any, - ) -> "AsyncConnection": + ) -> "AsyncConnection[RowConn]": return await cls._wait_conn( cls._connect_gen( conninfo, @@ -655,7 +657,7 @@ class AsyncConnection(BaseConnection): ) ) - async def __aenter__(self) -> "AsyncConnection": + async def __aenter__(self) -> "AsyncConnection[RowConn]": return self async def __aexit__( @@ -691,7 +693,7 @@ class AsyncConnection(BaseConnection): self.pgconn.finish() @overload - def cursor(self, *, binary: bool = False) -> AsyncCursor[Any]: + def cursor(self, *, binary: bool = False) -> AsyncCursor[RowConn]: ... @overload @@ -703,7 +705,7 @@ class AsyncConnection(BaseConnection): @overload def cursor( self, name: str, *, binary: bool = False - ) -> AsyncServerCursor[Any]: + ) -> AsyncServerCursor[RowConn]: ... @overload @@ -738,8 +740,8 @@ class AsyncConnection(BaseConnection): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> AsyncCursor[Any]: - cur: AsyncCursor[Any] = self.cursor() + ) -> AsyncCursor[RowConn]: + cur = self.cursor() try: return await cur.execute(query, params, prepare=prepare) except e.Error as ex: diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index ed9fede9e..69bf06e8a 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -148,7 +148,7 @@ class BaseCopy(Generic[ConnectionType]): self._finished = True -class Copy(BaseCopy["Connection"]): +class Copy(BaseCopy["Connection[Any]"]): """Manage a :sql:`COPY` operation.""" __module__ = "psycopg3" @@ -280,7 +280,7 @@ class Copy(BaseCopy["Connection"]): self._worker = None # break the loop -class AsyncCopy(BaseCopy["AsyncConnection"]): +class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" __module__ = "psycopg3" diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 1e200ce68..ac5cccac7 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -471,7 +471,7 @@ class BaseCursor(Generic[ConnectionType, Row]): self._pgq = pgq -class Cursor(BaseCursor["Connection", Row]): +class Cursor(BaseCursor["Connection[Any]", Row]): __module__ = "psycopg3" __slots__ = () @@ -622,7 +622,7 @@ class Cursor(BaseCursor["Connection", Row]): yield copy -class AsyncCursor(BaseCursor["AsyncConnection", Row]): +class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg3" __slots__ = () diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index 7c950c1c2..6c877b46d 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -27,16 +27,18 @@ from .errors import PoolClosed, PoolTimeout, TooManyRequests logger = logging.getLogger("psycopg3.pool") -class AsyncConnectionPool(BasePool[AsyncConnection]): +class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): def __init__( self, conninfo: str = "", *, - connection_class: Type[AsyncConnection] = AsyncConnection, + connection_class: Type[AsyncConnection[Any]] = AsyncConnection, configure: Optional[ - Callable[[AsyncConnection], Awaitable[None]] + Callable[[AsyncConnection[Any]], Awaitable[None]] + ] = None, + reset: Optional[ + Callable[[AsyncConnection[Any]], Awaitable[None]] ] = None, - reset: Optional[Callable[[AsyncConnection], Awaitable[None]]] = None, **kwargs: Any, ): # https://bugs.python.org/issue42600 @@ -104,7 +106,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): @asynccontextmanager async def connection( self, timeout: Optional[float] = None - ) -> AsyncIterator[AsyncConnection]: + ) -> AsyncIterator[AsyncConnection[Any]]: conn = await self.getconn(timeout=timeout) t0 = monotonic() try: @@ -117,7 +119,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): async def getconn( self, timeout: Optional[float] = None - ) -> AsyncConnection: + ) -> AsyncConnection[Any]: logger.info("connection requested from %r", self.name) self._stats[self._REQUESTS_NUM] += 1 # Critical section: decide here if there's a connection ready @@ -177,7 +179,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): logger.info("connection given by %r", self.name) return conn - async def putconn(self, conn: AsyncConnection) -> None: + async def putconn(self, conn: AsyncConnection[Any]) -> None: # Quick check to discard the wrong connection pool = getattr(conn, "_pool", None) if pool is not self: @@ -343,7 +345,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): ex, ) - async def _connect(self) -> AsyncConnection: + async def _connect(self) -> AsyncConnection[Any]: """Return a new connection configured for the pool.""" self._stats[self._CONNECTIONS_NUM] += 1 t0 = monotonic() @@ -426,7 +428,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): else: self._growing = False - async def _return_connection(self, conn: AsyncConnection) -> None: + async def _return_connection(self, conn: AsyncConnection[Any]) -> None: """ Return a connection to the pool after usage. """ @@ -447,7 +449,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): await self._add_to_pool(conn) - async def _add_to_pool(self, conn: AsyncConnection) -> None: + async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: """ Add a connection to the pool. @@ -481,7 +483,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): if self._pool_full_event and len(self._pool) >= self._nconns: self._pool_full_event.set() - async def _reset_connection(self, conn: AsyncConnection) -> None: + async def _reset_connection(self, conn: AsyncConnection[Any]) -> None: """ Bring a connection to IDLE state or close it. """ @@ -523,7 +525,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): await conn.close() async def _shrink_pool(self) -> None: - to_close: Optional[AsyncConnection] = None + to_close: Optional[AsyncConnection[Any]] = None async with self._lock: # Reset the min number of connections used @@ -559,7 +561,7 @@ class AsyncClient: __slots__ = ("conn", "error", "_cond") def __init__(self) -> None: - self.conn: Optional[AsyncConnection] = None + self.conn: Optional[AsyncConnection[Any]] = None self.error: Optional[Exception] = None # The AsyncClient behaves in a way similar to an Event, but we need @@ -569,7 +571,7 @@ class AsyncClient: # will be lost. self._cond = asyncio.Condition() - async def wait(self, timeout: float) -> AsyncConnection: + async def wait(self, timeout: float) -> AsyncConnection[Any]: """Wait for a connection to be set and return it. Raise an exception if the wait times out or if fail() is called. @@ -589,7 +591,7 @@ class AsyncClient: assert self.error raise self.error - async def set(self, conn: AsyncConnection) -> bool: + async def set(self, conn: AsyncConnection[Any]) -> bool: """Signal the client waiting that a connection is ready. Return True if the client has "accepted" the connection, False @@ -685,7 +687,9 @@ class AddConnection(MaintenanceTask): class ReturnConnection(MaintenanceTask): """Clean up and return a connection to the pool.""" - def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection"): + def __init__( + self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]" + ): super().__init__(pool) self.conn = conn @@ -715,7 +719,10 @@ class Schedule(MaintenanceTask): """ def __init__( - self, pool: "AsyncConnectionPool", task: MaintenanceTask, delay: float + self, + pool: "AsyncConnectionPool", + task: MaintenanceTask, + delay: float, ): super().__init__(pool) self.task = task diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index c75b8edab..2f311a6de 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -10,7 +10,8 @@ from abc import ABC, abstractmethod from time import monotonic from queue import Queue, Empty from types import TracebackType -from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Type +from typing import Any, Callable, Deque, Dict, Iterator, List +from typing import Optional, Type from weakref import ref from contextlib import contextmanager from collections import deque @@ -26,14 +27,14 @@ from .errors import PoolClosed, PoolTimeout, TooManyRequests logger = logging.getLogger("psycopg3.pool") -class ConnectionPool(BasePool[Connection]): +class ConnectionPool(BasePool[Connection[Any]]): def __init__( self, conninfo: str = "", *, - connection_class: Type[Connection] = Connection, - configure: Optional[Callable[[Connection], None]] = None, - reset: Optional[Callable[[Connection], None]] = None, + connection_class: Type[Connection[Any]] = Connection, + configure: Optional[Callable[[Connection[Any]], None]] = None, + reset: Optional[Callable[[Connection[Any]], None]] = None, **kwargs: Any, ): self.connection_class = connection_class @@ -128,7 +129,7 @@ class ConnectionPool(BasePool[Connection]): @contextmanager def connection( self, timeout: Optional[float] = None - ) -> Iterator[Connection]: + ) -> Iterator[Connection[Any]]: """Context manager to obtain a connection from the pool. Returned the connection immediately if available, otherwise wait up to @@ -151,7 +152,7 @@ class ConnectionPool(BasePool[Connection]): self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) self.putconn(conn) - def getconn(self, timeout: Optional[float] = None) -> Connection: + def getconn(self, timeout: Optional[float] = None) -> Connection[Any]: """Obtain a contection from the pool. You should preferrably use `connection()`. Use this function only if @@ -221,7 +222,7 @@ class ConnectionPool(BasePool[Connection]): logger.info("connection given by %r", self.name) return conn - def putconn(self, conn: Connection) -> None: + def putconn(self, conn: Connection[Any]) -> None: """Return a connection to the loving hands of its pool. Use this function only paired with a `getconn()`. You don't need to use @@ -416,7 +417,7 @@ class ConnectionPool(BasePool[Connection]): ex, ) - def _connect(self) -> Connection: + def _connect(self) -> Connection[Any]: """Return a new connection configured for the pool.""" self._stats[self._CONNECTIONS_NUM] += 1 t0 = monotonic() @@ -497,7 +498,7 @@ class ConnectionPool(BasePool[Connection]): else: self._growing = False - def _return_connection(self, conn: Connection) -> None: + def _return_connection(self, conn: Connection[Any]) -> None: """ Return a connection to the pool after usage. """ @@ -518,7 +519,7 @@ class ConnectionPool(BasePool[Connection]): self._add_to_pool(conn) - def _add_to_pool(self, conn: Connection) -> None: + def _add_to_pool(self, conn: Connection[Any]) -> None: """ Add a connection to the pool. @@ -552,7 +553,7 @@ class ConnectionPool(BasePool[Connection]): if self._pool_full_event and len(self._pool) >= self._nconns: self._pool_full_event.set() - def _reset_connection(self, conn: Connection) -> None: + def _reset_connection(self, conn: Connection[Any]) -> None: """ Bring a connection to IDLE state or close it. """ @@ -594,7 +595,7 @@ class ConnectionPool(BasePool[Connection]): conn.close() def _shrink_pool(self) -> None: - to_close: Optional[Connection] = None + to_close: Optional[Connection[Any]] = None with self._lock: # Reset the min number of connections used @@ -630,7 +631,7 @@ class WaitingClient: __slots__ = ("conn", "error", "_cond") def __init__(self) -> None: - self.conn: Optional[Connection] = None + self.conn: Optional[Connection[Any]] = None self.error: Optional[Exception] = None # The WaitingClient behaves in a way similar to an Event, but we need @@ -640,7 +641,7 @@ class WaitingClient: # will be lost. self._cond = threading.Condition() - def wait(self, timeout: float) -> Connection: + def wait(self, timeout: float) -> Connection[Any]: """Wait for a connection to be set and return it. Raise an exception if the wait times out or if fail() is called. @@ -658,7 +659,7 @@ class WaitingClient: assert self.error raise self.error - def set(self, conn: Connection) -> bool: + def set(self, conn: Connection[Any]) -> bool: """Signal the client waiting that a connection is ready. Return True if the client has "accepted" the connection, False @@ -760,7 +761,7 @@ class AddConnection(MaintenanceTask): class ReturnConnection(MaintenanceTask): """Clean up and return a connection to the pool.""" - def __init__(self, pool: "ConnectionPool", conn: "Connection"): + def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"): super().__init__(pool) self.conn = conn diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index a00f20d7e..f83982f65 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -24,7 +24,7 @@ Buffer = Union[bytes, bytearray, memoryview] Query = Union[str, bytes, "Composable"] Params = Union[Sequence[Any], Mapping[str, Any]] -ConnectionType = TypeVar("ConnectionType", bound="BaseConnection") +ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") # Waiting protocol types @@ -49,6 +49,8 @@ Wait states. Row = TypeVar("Row") Row_co = TypeVar("Row_co", covariant=True) +# Type variable for Connection (other are for Cursor). +RowConn = TypeVar("RowConn") class RowMaker(Protocol[Row_co]): @@ -82,7 +84,7 @@ class AdaptContext(Protocol): ... @property - def connection(self) -> Optional["BaseConnection"]: + def connection(self) -> Optional["BaseConnection[Any]"]: ... @@ -91,7 +93,7 @@ class Transformer(Protocol): ... @property - def connection(self) -> Optional["BaseConnection"]: + def connection(self) -> Optional["BaseConnection[Any]"]: ... @property diff --git a/psycopg3/psycopg3/server_cursor.py b/psycopg3/psycopg3/server_cursor.py index 9b4665431..0723b991a 100644 --- a/psycopg3/psycopg3/server_cursor.py +++ b/psycopg3/psycopg3/server_cursor.py @@ -16,6 +16,7 @@ from .cursor import BaseCursor, execute from .proto import ConnectionType, Query, Params, PQGen, Row, RowFactory if TYPE_CHECKING: + from typing import Any # noqa: F401 from .connection import BaseConnection # noqa: F401 from .connection import Connection, AsyncConnection # noqa: F401 @@ -165,20 +166,20 @@ class ServerCursorHelper(Generic[ConnectionType, Row]): return sql.SQL(" ").join(parts) -class ServerCursor(BaseCursor["Connection", Row]): +class ServerCursor(BaseCursor["Connection[Any]", Row]): __module__ = "psycopg3" __slots__ = ("_helper", "itersize") def __init__( self, - connection: "Connection", + connection: "Connection[Any]", name: str, *, format: pq.Format = pq.Format.TEXT, row_factory: RowFactory[Row], ): super().__init__(connection, format=format, row_factory=row_factory) - self._helper: ServerCursorHelper["Connection", Row] + self._helper: ServerCursorHelper["Connection[Any]", Row] self._helper = ServerCursorHelper(name) self.itersize: int = DEFAULT_ITERSIZE @@ -286,20 +287,20 @@ class ServerCursor(BaseCursor["Connection", Row]): self._pos = value -class AsyncServerCursor(BaseCursor["AsyncConnection", Row]): +class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg3" __slots__ = ("_helper", "itersize") def __init__( self, - connection: "AsyncConnection", + connection: "AsyncConnection[Any]", name: str, *, format: pq.Format = pq.Format.TEXT, row_factory: RowFactory[Row], ): super().__init__(connection, format=format, row_factory=row_factory) - self._helper: ServerCursorHelper["AsyncConnection", Row] + self._helper: ServerCursorHelper["AsyncConnection[Any]", Row] self._helper = ServerCursorHelper(name) self.itersize: int = DEFAULT_ITERSIZE diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index ef05a921a..47a8f6da0 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -16,6 +16,7 @@ from .proto import ConnectionType, PQGen from .pq.proto import PGresult if TYPE_CHECKING: + from typing import Any # noqa: F401 from .connection import Connection, AsyncConnection # noqa: F401 logger = logging.getLogger(__name__) @@ -171,7 +172,7 @@ class BaseTransaction(Generic[ConnectionType]): return False -class Transaction(BaseTransaction["Connection"]): +class Transaction(BaseTransaction["Connection[Any]"]): """ Returned by `Connection.transaction()` to handle a transaction block. """ @@ -179,7 +180,7 @@ class Transaction(BaseTransaction["Connection"]): __module__ = "psycopg3" @property - def connection(self) -> "Connection": + def connection(self) -> "Connection[Any]": """The connection the object is managing.""" return self._conn @@ -198,7 +199,7 @@ class Transaction(BaseTransaction["Connection"]): return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) -class AsyncTransaction(BaseTransaction["AsyncConnection"]): +class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): """ Returned by `AsyncConnection.transaction()` to handle a transaction block. """ @@ -206,7 +207,7 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): __module__ = "psycopg3" @property - def connection(self) -> "AsyncConnection": + def connection(self) -> "AsyncConnection[Any]": return self._conn async def __aenter__(self) -> "AsyncTransaction": diff --git a/psycopg3_c/psycopg3_c/_psycopg3.pyi b/psycopg3_c/psycopg3_c/_psycopg3.pyi index 66e3b4a4c..899feadfd 100644 --- a/psycopg3_c/psycopg3_c/_psycopg3.pyi +++ b/psycopg3_c/psycopg3_c/_psycopg3.pyi @@ -18,7 +18,7 @@ from psycopg3.pq.proto import PGconn, PGresult class Transformer(proto.AdaptContext): def __init__(self, context: Optional[proto.AdaptContext] = None): ... @property - def connection(self) -> Optional[BaseConnection]: ... + def connection(self) -> Optional[BaseConnection[Any]]: ... @property def adapters(self) -> AdaptersMap: ... @property diff --git a/tests/typing_example.py b/tests/typing_example.py index 63cba24bc..fa59a2152 100644 --- a/tests/typing_example.py +++ b/tests/typing_example.py @@ -3,9 +3,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, Tuple -from psycopg3 import BaseCursor, Cursor, ServerCursor, connect +from psycopg3 import BaseCursor, Connection, Cursor, ServerCursor, connect def int_row_factory( @@ -32,7 +32,7 @@ class Person: def check_row_factory_cursor() -> None: """Type-check connection.cursor(..., row_factory=) case.""" - conn = connect() + conn = connect() # type: ignore[var-annotated] # Connection[Any] cur1: Cursor[Any] cur1 = conn.cursor() @@ -58,12 +58,10 @@ def check_row_factory_cursor() -> None: def check_row_factory_connection() -> None: """Type-check connect(..., row_factory=) or Connection.row_factory cases. - - This example is incomplete because Connection is not generic on Row, hence - all the Any, which we aim at getting rid of. """ - cur1: Cursor[Any] - r1: Any + conn1: Connection[int] + cur1: Cursor[int] + r1: Optional[int] conn1 = connect(row_factory=int_row_factory) cur1 = conn1.execute("select 1") r1 = cur1.fetchone() @@ -71,8 +69,9 @@ def check_row_factory_connection() -> None: with conn1.cursor() as cur1: cur1.execute("select 2") - cur2: Cursor[Any] - r2: Any + conn2: Connection[Person] + cur2: Cursor[Person] + r2: Optional[Person] conn2 = connect(row_factory=Person.row_factory) cur2 = conn2.execute("select * from persons") r2 = cur2.fetchone() @@ -80,9 +79,9 @@ def check_row_factory_connection() -> None: with conn2.cursor() as cur2: cur2.execute("select 2") - cur3: Cursor[Any] - r3: Optional[Any] - conn3 = connect() + cur3: Cursor[Tuple[Any, ...]] + r3: Optional[Tuple[Any, ...]] + conn3 = connect() # type: ignore[var-annotated] cur3 = conn3.execute("select 3") with conn3.cursor() as cur3: cur3.execute("select 42") -- 2.47.2