]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make Connection generic on Row
authorDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 20 Apr 2021 08:30:03 +0000 (10:30 +0200)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 28 Apr 2021 13:08:48 +0000 (15:08 +0200)
We use a type variable 'RowConn' for Connection that is distinct from
'Row' that is used on Cursor side in order to reflect the possibility to
have distinct row factories on connection and cursor sides.

In order to avoid "propagation" of the Row type variable of Connection
classes, we use Any everywhere that variable is not used (namely,
everywhere outside the connection module).

The typing_example.py almost works: connect(row_factory=...) returns a
typed Connection, but only connect() still returns a Connection[Any].

13 files changed:
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/_typeinfo.py
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/pool.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/server_cursor.py
psycopg3/psycopg3/transaction.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
tests/typing_example.py

index ec0119790f673b616acad4d154f2ee3bf58857c8..b6ed255c7a3a78a33cc0faf37f7f6edf41d9d097 100644 (file)
@@ -66,7 +66,7 @@ class Transformer(AdaptContext):
         self._row_loaders: List[LoadFunc] = []
 
     @property
-    def connection(self) -> Optional["BaseConnection"]:
+    def connection(self) -> Optional["BaseConnection[Any]"]:
         return self._conn
 
     @property
index 338b9d57df94aa84317e7cc3f5938e3e13859a6d..8743d90f5879388a6566a31f9f4f734e95fa86c5 100644 (file)
@@ -55,7 +55,7 @@ class TypeInfo:
 
     @classmethod
     def fetch(
-        cls: Type[T], conn: "Connection", name: Union[str, "Identifier"]
+        cls: Type[T], conn: "Connection[Any]", name: Union[str, "Identifier"]
     ) -> Optional[T]:
         """
         Query a system catalog to read information about a type.
@@ -77,7 +77,9 @@ class TypeInfo:
 
     @classmethod
     async def fetch_async(
-        cls: Type[T], conn: "AsyncConnection", name: Union[str, "Identifier"]
+        cls: Type[T],
+        conn: "AsyncConnection[Any]",
+        name: Union[str, "Identifier"],
     ) -> Optional[T]:
         """
         Query a system catalog to read information about a type.
index 62f08e2637ea224280df76dcf47f5d94637f294d..6d6f5c189833d86b2d1453c5d75121c8c5d0f5d6 100644 (file)
@@ -35,7 +35,7 @@ class Dumper(ABC):
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         self.cls = cls
-        self.connection: Optional["BaseConnection"] = (
+        self.connection: Optional["BaseConnection[Any]"] = (
             context.connection if context else None
         )
 
@@ -109,7 +109,7 @@ class Loader(ABC):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         self.oid = oid
-        self.connection: Optional["BaseConnection"] = (
+        self.connection: Optional["BaseConnection[Any]"] = (
             context.connection if context else None
         )
 
@@ -170,7 +170,7 @@ class AdaptersMap(AdaptContext):
         return self
 
     @property
-    def connection(self) -> Optional["BaseConnection"]:
+    def connection(self) -> Optional["BaseConnection[Any]"]:
         return None
 
     def register_dumper(
index 96440bc7180f8177a0b601e671aa9d9d01224a71..7c3135fc065da9c198ddb3833a676929fb82b870 100644 (file)
@@ -9,8 +9,8 @@ import logging
 import warnings
 import threading
 from types import TracebackType
-from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, overload, Type, Union, TYPE_CHECKING
+from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
+from typing import NamedTuple, Optional, Type, Union, TYPE_CHECKING, overload
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -23,8 +23,8 @@ from . import encodings
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .sql import Composable
 from .rows import tuple_row
-from .proto import PQGen, PQGenConn, RV, Row, RowFactory, Query, Params
-from .proto import AdaptContext, ConnectionType
+from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn
+from .proto import Query, Row, RowConn, RowFactory, RV
 from .cursor import Cursor, AsyncCursor
 from .conninfo import make_conninfo, ConnectionInfo
 from .generators import notifies
@@ -74,7 +74,7 @@ NoticeHandler = Callable[[e.Diagnostic], None]
 NotifyHandler = Callable[[Notify], None]
 
 
-class BaseConnection(AdaptContext):
+class BaseConnection(AdaptContext, Generic[RowConn]):
     """
     Base class for different types of connections.
 
@@ -98,7 +98,7 @@ class BaseConnection(AdaptContext):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Any]):
+    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
         self.pgconn = pgconn  # TODO: document this
         self._row_factory = row_factory
         self._autocommit = False
@@ -224,17 +224,17 @@ class BaseConnection(AdaptContext):
         return self._adapters
 
     @property
-    def connection(self) -> "BaseConnection":
+    def connection(self) -> "BaseConnection[RowConn]":
         # implement the AdaptContext protocol
         return self
 
     @property
-    def row_factory(self) -> RowFactory[Any]:
+    def row_factory(self) -> RowFactory[RowConn]:
         """Writable attribute to control how result rows are formed."""
         return self._row_factory
 
     @row_factory.setter
-    def row_factory(self, row_factory: RowFactory[Any]) -> None:
+    def row_factory(self, row_factory: RowFactory[RowConn]) -> None:
         self._row_factory = row_factory
 
     def fileno(self) -> int:
@@ -265,7 +265,7 @@ class BaseConnection(AdaptContext):
 
     @staticmethod
     def _notice_handler(
-        wself: "ReferenceType[BaseConnection]", res: "PGresult"
+        wself: "ReferenceType[BaseConnection[RowConn]]", res: "PGresult"
     ) -> None:
         self = wself()
         if not (self and self._notice_handler):
@@ -294,7 +294,7 @@ class BaseConnection(AdaptContext):
 
     @staticmethod
     def _notify_handler(
-        wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify
+        wself: "ReferenceType[BaseConnection[RowConn]]", pgn: pq.PGnotify
     ) -> None:
         self = wself()
         if not (self and self._notify_handlers):
@@ -435,14 +435,14 @@ class BaseConnection(AdaptContext):
         yield from self._exec_command(b"rollback")
 
 
-class Connection(BaseConnection):
+class Connection(BaseConnection[RowConn]):
     """
     Wrapper for a connection to the database.
     """
 
     __module__ = "psycopg3"
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Any]):
+    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
         super().__init__(pgconn, row_factory)
         self.lock = threading.Lock()
 
@@ -452,9 +452,9 @@ class Connection(BaseConnection):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: Optional[RowFactory[Any]] = None,
+        row_factory: Optional[RowFactory[RowConn]] = None,
         **kwargs: Any,
-    ) -> "Connection":
+    ) -> "Connection[RowConn]":
         """
         Connect to a database server and return a new `Connection` instance.
 
@@ -469,7 +469,7 @@ class Connection(BaseConnection):
             )
         )
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "Connection[RowConn]":
         return self
 
     def __exit__(
@@ -506,7 +506,7 @@ class Connection(BaseConnection):
         self.pgconn.finish()
 
     @overload
-    def cursor(self, *, binary: bool = False) -> Cursor[Any]:
+    def cursor(self, *, binary: bool = False) -> Cursor[RowConn]:
         ...
 
     @overload
@@ -516,7 +516,9 @@ class Connection(BaseConnection):
         ...
 
     @overload
-    def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Any]:
+    def cursor(
+        self, name: str, *, binary: bool = False
+    ) -> ServerCursor[RowConn]:
         ...
 
     @overload
@@ -551,9 +553,9 @@ class Connection(BaseConnection):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> Cursor[Any]:
+    ) -> Cursor[RowConn]:
         """Execute a query and return a cursor to read its results."""
-        cur: Cursor[Any] = self.cursor()
+        cur = self.cursor()
         try:
             return cur.execute(query, params, prepare=prepare)
         except e.Error as ex:
@@ -626,14 +628,14 @@ class Connection(BaseConnection):
             self.wait(self._set_client_encoding_gen(name))
 
 
-class AsyncConnection(BaseConnection):
+class AsyncConnection(BaseConnection[RowConn]):
     """
     Asynchronous wrapper for a connection to the database.
     """
 
     __module__ = "psycopg3"
 
-    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[Any]):
+    def __init__(self, pgconn: "PGconn", row_factory: RowFactory[RowConn]):
         super().__init__(pgconn, row_factory)
         self.lock = asyncio.Lock()
 
@@ -643,9 +645,9 @@ class AsyncConnection(BaseConnection):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: Optional[RowFactory[Any]] = None,
+        row_factory: Optional[RowFactory[RowConn]] = None,
         **kwargs: Any,
-    ) -> "AsyncConnection":
+    ) -> "AsyncConnection[RowConn]":
         return await cls._wait_conn(
             cls._connect_gen(
                 conninfo,
@@ -655,7 +657,7 @@ class AsyncConnection(BaseConnection):
             )
         )
 
-    async def __aenter__(self) -> "AsyncConnection":
+    async def __aenter__(self) -> "AsyncConnection[RowConn]":
         return self
 
     async def __aexit__(
@@ -691,7 +693,7 @@ class AsyncConnection(BaseConnection):
         self.pgconn.finish()
 
     @overload
-    def cursor(self, *, binary: bool = False) -> AsyncCursor[Any]:
+    def cursor(self, *, binary: bool = False) -> AsyncCursor[RowConn]:
         ...
 
     @overload
@@ -703,7 +705,7 @@ class AsyncConnection(BaseConnection):
     @overload
     def cursor(
         self, name: str, *, binary: bool = False
-    ) -> AsyncServerCursor[Any]:
+    ) -> AsyncServerCursor[RowConn]:
         ...
 
     @overload
@@ -738,8 +740,8 @@ class AsyncConnection(BaseConnection):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> AsyncCursor[Any]:
-        cur: AsyncCursor[Any] = self.cursor()
+    ) -> AsyncCursor[RowConn]:
+        cur = self.cursor()
         try:
             return await cur.execute(query, params, prepare=prepare)
         except e.Error as ex:
index ed9fede9eab8dbe66670a8cb53d2b58593ee1e05..69bf06e8a372b22f41ed11cd1f346574185c564d 100644 (file)
@@ -148,7 +148,7 @@ class BaseCopy(Generic[ConnectionType]):
         self._finished = True
 
 
-class Copy(BaseCopy["Connection"]):
+class Copy(BaseCopy["Connection[Any]"]):
     """Manage a :sql:`COPY` operation."""
 
     __module__ = "psycopg3"
@@ -280,7 +280,7 @@ class Copy(BaseCopy["Connection"]):
             self._worker = None  # break the loop
 
 
-class AsyncCopy(BaseCopy["AsyncConnection"]):
+class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     """Manage an asynchronous :sql:`COPY` operation."""
 
     __module__ = "psycopg3"
index 1e200ce68ef742cf5a6707969340f152258b44fa..ac5cccac7a5a3229e77016f587fc85587b31904c 100644 (file)
@@ -471,7 +471,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._pgq = pgq
 
 
-class Cursor(BaseCursor["Connection", Row]):
+class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg3"
     __slots__ = ()
 
@@ -622,7 +622,7 @@ class Cursor(BaseCursor["Connection", Row]):
             yield copy
 
 
-class AsyncCursor(BaseCursor["AsyncConnection", Row]):
+class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg3"
     __slots__ = ()
 
index 7c950c1c21a878bafa0bd28dbb6045735d10ef00..6c877b46d381ff9574f808b19d7ffa4a79d3d3df 100644 (file)
@@ -27,16 +27,18 @@ from .errors import PoolClosed, PoolTimeout, TooManyRequests
 logger = logging.getLogger("psycopg3.pool")
 
 
-class AsyncConnectionPool(BasePool[AsyncConnection]):
+class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
     def __init__(
         self,
         conninfo: str = "",
         *,
-        connection_class: Type[AsyncConnection] = AsyncConnection,
+        connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
         configure: Optional[
-            Callable[[AsyncConnection], Awaitable[None]]
+            Callable[[AsyncConnection[Any]], Awaitable[None]]
+        ] = None,
+        reset: Optional[
+            Callable[[AsyncConnection[Any]], Awaitable[None]]
         ] = None,
-        reset: Optional[Callable[[AsyncConnection], Awaitable[None]]] = None,
         **kwargs: Any,
     ):
         # https://bugs.python.org/issue42600
@@ -104,7 +106,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
     @asynccontextmanager
     async def connection(
         self, timeout: Optional[float] = None
-    ) -> AsyncIterator[AsyncConnection]:
+    ) -> AsyncIterator[AsyncConnection[Any]]:
         conn = await self.getconn(timeout=timeout)
         t0 = monotonic()
         try:
@@ -117,7 +119,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
 
     async def getconn(
         self, timeout: Optional[float] = None
-    ) -> AsyncConnection:
+    ) -> AsyncConnection[Any]:
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
         # Critical section: decide here if there's a connection ready
@@ -177,7 +179,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         logger.info("connection given by %r", self.name)
         return conn
 
-    async def putconn(self, conn: AsyncConnection) -> None:
+    async def putconn(self, conn: AsyncConnection[Any]) -> None:
         # Quick check to discard the wrong connection
         pool = getattr(conn, "_pool", None)
         if pool is not self:
@@ -343,7 +345,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                     ex,
                 )
 
-    async def _connect(self) -> AsyncConnection:
+    async def _connect(self) -> AsyncConnection[Any]:
         """Return a new connection configured for the pool."""
         self._stats[self._CONNECTIONS_NUM] += 1
         t0 = monotonic()
@@ -426,7 +428,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                 else:
                     self._growing = False
 
-    async def _return_connection(self, conn: AsyncConnection) -> None:
+    async def _return_connection(self, conn: AsyncConnection[Any]) -> None:
         """
         Return a connection to the pool after usage.
         """
@@ -447,7 +449,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
 
         await self._add_to_pool(conn)
 
-    async def _add_to_pool(self, conn: AsyncConnection) -> None:
+    async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
         """
         Add a connection to the pool.
 
@@ -481,7 +483,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                 if self._pool_full_event and len(self._pool) >= self._nconns:
                     self._pool_full_event.set()
 
-    async def _reset_connection(self, conn: AsyncConnection) -> None:
+    async def _reset_connection(self, conn: AsyncConnection[Any]) -> None:
         """
         Bring a connection to IDLE state or close it.
         """
@@ -523,7 +525,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                 await conn.close()
 
     async def _shrink_pool(self) -> None:
-        to_close: Optional[AsyncConnection] = None
+        to_close: Optional[AsyncConnection[Any]] = None
 
         async with self._lock:
             # Reset the min number of connections used
@@ -559,7 +561,7 @@ class AsyncClient:
     __slots__ = ("conn", "error", "_cond")
 
     def __init__(self) -> None:
-        self.conn: Optional[AsyncConnection] = None
+        self.conn: Optional[AsyncConnection[Any]] = None
         self.error: Optional[Exception] = None
 
         # The AsyncClient behaves in a way similar to an Event, but we need
@@ -569,7 +571,7 @@ class AsyncClient:
         # will be lost.
         self._cond = asyncio.Condition()
 
-    async def wait(self, timeout: float) -> AsyncConnection:
+    async def wait(self, timeout: float) -> AsyncConnection[Any]:
         """Wait for a connection to be set and return it.
 
         Raise an exception if the wait times out or if fail() is called.
@@ -589,7 +591,7 @@ class AsyncClient:
             assert self.error
             raise self.error
 
-    async def set(self, conn: AsyncConnection) -> bool:
+    async def set(self, conn: AsyncConnection[Any]) -> bool:
         """Signal the client waiting that a connection is ready.
 
         Return True if the client has "accepted" the connection, False
@@ -685,7 +687,9 @@ class AddConnection(MaintenanceTask):
 class ReturnConnection(MaintenanceTask):
     """Clean up and return a connection to the pool."""
 
-    def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection"):
+    def __init__(
+        self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"
+    ):
         super().__init__(pool)
         self.conn = conn
 
@@ -715,7 +719,10 @@ class Schedule(MaintenanceTask):
     """
 
     def __init__(
-        self, pool: "AsyncConnectionPool", task: MaintenanceTask, delay: float
+        self,
+        pool: "AsyncConnectionPool",
+        task: MaintenanceTask,
+        delay: float,
     ):
         super().__init__(pool)
         self.task = task
index c75b8edab8d405938e2ba8823f1d6c934a9107df..2f311a6de865a2f450feddfd3d80e387713c53c4 100644 (file)
@@ -10,7 +10,8 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from queue import Queue, Empty
 from types import TracebackType
-from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Type
+from typing import Any, Callable, Deque, Dict, Iterator, List
+from typing import Optional, Type
 from weakref import ref
 from contextlib import contextmanager
 from collections import deque
@@ -26,14 +27,14 @@ from .errors import PoolClosed, PoolTimeout, TooManyRequests
 logger = logging.getLogger("psycopg3.pool")
 
 
-class ConnectionPool(BasePool[Connection]):
+class ConnectionPool(BasePool[Connection[Any]]):
     def __init__(
         self,
         conninfo: str = "",
         *,
-        connection_class: Type[Connection] = Connection,
-        configure: Optional[Callable[[Connection], None]] = None,
-        reset: Optional[Callable[[Connection], None]] = None,
+        connection_class: Type[Connection[Any]] = Connection,
+        configure: Optional[Callable[[Connection[Any]], None]] = None,
+        reset: Optional[Callable[[Connection[Any]], None]] = None,
         **kwargs: Any,
     ):
         self.connection_class = connection_class
@@ -128,7 +129,7 @@ class ConnectionPool(BasePool[Connection]):
     @contextmanager
     def connection(
         self, timeout: Optional[float] = None
-    ) -> Iterator[Connection]:
+    ) -> Iterator[Connection[Any]]:
         """Context manager to obtain a connection from the pool.
 
         Returned the connection immediately if available, otherwise wait up to
@@ -151,7 +152,7 @@ class ConnectionPool(BasePool[Connection]):
             self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
             self.putconn(conn)
 
-    def getconn(self, timeout: Optional[float] = None) -> Connection:
+    def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
         """Obtain a contection from the pool.
 
         You should preferrably use `connection()`. Use this function only if
@@ -221,7 +222,7 @@ class ConnectionPool(BasePool[Connection]):
         logger.info("connection given by %r", self.name)
         return conn
 
-    def putconn(self, conn: Connection) -> None:
+    def putconn(self, conn: Connection[Any]) -> None:
         """Return a connection to the loving hands of its pool.
 
         Use this function only paired with a `getconn()`. You don't need to use
@@ -416,7 +417,7 @@ class ConnectionPool(BasePool[Connection]):
                     ex,
                 )
 
-    def _connect(self) -> Connection:
+    def _connect(self) -> Connection[Any]:
         """Return a new connection configured for the pool."""
         self._stats[self._CONNECTIONS_NUM] += 1
         t0 = monotonic()
@@ -497,7 +498,7 @@ class ConnectionPool(BasePool[Connection]):
                 else:
                     self._growing = False
 
-    def _return_connection(self, conn: Connection) -> None:
+    def _return_connection(self, conn: Connection[Any]) -> None:
         """
         Return a connection to the pool after usage.
         """
@@ -518,7 +519,7 @@ class ConnectionPool(BasePool[Connection]):
 
         self._add_to_pool(conn)
 
-    def _add_to_pool(self, conn: Connection) -> None:
+    def _add_to_pool(self, conn: Connection[Any]) -> None:
         """
         Add a connection to the pool.
 
@@ -552,7 +553,7 @@ class ConnectionPool(BasePool[Connection]):
                 if self._pool_full_event and len(self._pool) >= self._nconns:
                     self._pool_full_event.set()
 
-    def _reset_connection(self, conn: Connection) -> None:
+    def _reset_connection(self, conn: Connection[Any]) -> None:
         """
         Bring a connection to IDLE state or close it.
         """
@@ -594,7 +595,7 @@ class ConnectionPool(BasePool[Connection]):
                 conn.close()
 
     def _shrink_pool(self) -> None:
-        to_close: Optional[Connection] = None
+        to_close: Optional[Connection[Any]] = None
 
         with self._lock:
             # Reset the min number of connections used
@@ -630,7 +631,7 @@ class WaitingClient:
     __slots__ = ("conn", "error", "_cond")
 
     def __init__(self) -> None:
-        self.conn: Optional[Connection] = None
+        self.conn: Optional[Connection[Any]] = None
         self.error: Optional[Exception] = None
 
         # The WaitingClient behaves in a way similar to an Event, but we need
@@ -640,7 +641,7 @@ class WaitingClient:
         # will be lost.
         self._cond = threading.Condition()
 
-    def wait(self, timeout: float) -> Connection:
+    def wait(self, timeout: float) -> Connection[Any]:
         """Wait for a connection to be set and return it.
 
         Raise an exception if the wait times out or if fail() is called.
@@ -658,7 +659,7 @@ class WaitingClient:
             assert self.error
             raise self.error
 
-    def set(self, conn: Connection) -> bool:
+    def set(self, conn: Connection[Any]) -> bool:
         """Signal the client waiting that a connection is ready.
 
         Return True if the client has "accepted" the connection, False
@@ -760,7 +761,7 @@ class AddConnection(MaintenanceTask):
 class ReturnConnection(MaintenanceTask):
     """Clean up and return a connection to the pool."""
 
-    def __init__(self, pool: "ConnectionPool", conn: "Connection"):
+    def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"):
         super().__init__(pool)
         self.conn = conn
 
index a00f20d7ea58cf68159a73e0551ceedd7ae2a2df..f83982f65dc3a43f6b4574f2b374d2ea0070178c 100644 (file)
@@ -24,7 +24,7 @@ Buffer = Union[bytes, bytearray, memoryview]
 
 Query = Union[str, bytes, "Composable"]
 Params = Union[Sequence[Any], Mapping[str, Any]]
-ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
+ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
 
 
 # Waiting protocol types
@@ -49,6 +49,8 @@ Wait states.
 
 Row = TypeVar("Row")
 Row_co = TypeVar("Row_co", covariant=True)
+# Type variable for Connection (other are for Cursor).
+RowConn = TypeVar("RowConn")
 
 
 class RowMaker(Protocol[Row_co]):
@@ -82,7 +84,7 @@ class AdaptContext(Protocol):
         ...
 
     @property
-    def connection(self) -> Optional["BaseConnection"]:
+    def connection(self) -> Optional["BaseConnection[Any]"]:
         ...
 
 
@@ -91,7 +93,7 @@ class Transformer(Protocol):
         ...
 
     @property
-    def connection(self) -> Optional["BaseConnection"]:
+    def connection(self) -> Optional["BaseConnection[Any]"]:
         ...
 
     @property
index 9b46654312c67173270e19e75c41af48bcc4505b..0723b991a8b62d05f1e9beb0e2d25c4dd2513d52 100644 (file)
@@ -16,6 +16,7 @@ from .cursor import BaseCursor, execute
 from .proto import ConnectionType, Query, Params, PQGen, Row, RowFactory
 
 if TYPE_CHECKING:
+    from typing import Any  # noqa: F401
     from .connection import BaseConnection  # noqa: F401
     from .connection import Connection, AsyncConnection  # noqa: F401
 
@@ -165,20 +166,20 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         return sql.SQL(" ").join(parts)
 
 
-class ServerCursor(BaseCursor["Connection", Row]):
+class ServerCursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg3"
     __slots__ = ("_helper", "itersize")
 
     def __init__(
         self,
-        connection: "Connection",
+        connection: "Connection[Any]",
         name: str,
         *,
         format: pq.Format = pq.Format.TEXT,
         row_factory: RowFactory[Row],
     ):
         super().__init__(connection, format=format, row_factory=row_factory)
-        self._helper: ServerCursorHelper["Connection", Row]
+        self._helper: ServerCursorHelper["Connection[Any]", Row]
         self._helper = ServerCursorHelper(name)
         self.itersize: int = DEFAULT_ITERSIZE
 
@@ -286,20 +287,20 @@ class ServerCursor(BaseCursor["Connection", Row]):
             self._pos = value
 
 
-class AsyncServerCursor(BaseCursor["AsyncConnection", Row]):
+class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg3"
     __slots__ = ("_helper", "itersize")
 
     def __init__(
         self,
-        connection: "AsyncConnection",
+        connection: "AsyncConnection[Any]",
         name: str,
         *,
         format: pq.Format = pq.Format.TEXT,
         row_factory: RowFactory[Row],
     ):
         super().__init__(connection, format=format, row_factory=row_factory)
-        self._helper: ServerCursorHelper["AsyncConnection", Row]
+        self._helper: ServerCursorHelper["AsyncConnection[Any]", Row]
         self._helper = ServerCursorHelper(name)
         self.itersize: int = DEFAULT_ITERSIZE
 
index ef05a921abc83af635a1c8c37cfb3ea5892987b4..47a8f6da0965b73931a18106a4724009d678bf0c 100644 (file)
@@ -16,6 +16,7 @@ from .proto import ConnectionType, PQGen
 from .pq.proto import PGresult
 
 if TYPE_CHECKING:
+    from typing import Any  # noqa: F401
     from .connection import Connection, AsyncConnection  # noqa: F401
 
 logger = logging.getLogger(__name__)
@@ -171,7 +172,7 @@ class BaseTransaction(Generic[ConnectionType]):
         return False
 
 
-class Transaction(BaseTransaction["Connection"]):
+class Transaction(BaseTransaction["Connection[Any]"]):
     """
     Returned by `Connection.transaction()` to handle a transaction block.
     """
@@ -179,7 +180,7 @@ class Transaction(BaseTransaction["Connection"]):
     __module__ = "psycopg3"
 
     @property
-    def connection(self) -> "Connection":
+    def connection(self) -> "Connection[Any]":
         """The connection the object is managing."""
         return self._conn
 
@@ -198,7 +199,7 @@ class Transaction(BaseTransaction["Connection"]):
             return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
 
 
-class AsyncTransaction(BaseTransaction["AsyncConnection"]):
+class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
     """
     Returned by `AsyncConnection.transaction()` to handle a transaction block.
     """
@@ -206,7 +207,7 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
     __module__ = "psycopg3"
 
     @property
-    def connection(self) -> "AsyncConnection":
+    def connection(self) -> "AsyncConnection[Any]":
         return self._conn
 
     async def __aenter__(self) -> "AsyncTransaction":
index 66e3b4a4c70354ca7e3a25f95080052ec8b2cd33..899feadfde6eb3d0325e359880e4697b8cf57ddb 100644 (file)
@@ -18,7 +18,7 @@ from psycopg3.pq.proto import PGconn, PGresult
 class Transformer(proto.AdaptContext):
     def __init__(self, context: Optional[proto.AdaptContext] = None): ...
     @property
-    def connection(self) -> Optional[BaseConnection]: ...
+    def connection(self) -> Optional[BaseConnection[Any]]: ...
     @property
     def adapters(self) -> AdaptersMap: ...
     @property
index 63cba24bc27554f6c527944fb84116f2077a0796..fa59a2152577ffaabf0030937821ed2f5d3e6c3e 100644 (file)
@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Tuple
 
-from psycopg3 import BaseCursor, Cursor, ServerCursor, connect
+from psycopg3 import BaseCursor, Connection, Cursor, ServerCursor, connect
 
 
 def int_row_factory(
@@ -32,7 +32,7 @@ class Person:
 
 def check_row_factory_cursor() -> None:
     """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
-    conn = connect()
+    conn = connect()  # type: ignore[var-annotated] # Connection[Any]
 
     cur1: Cursor[Any]
     cur1 = conn.cursor()
@@ -58,12 +58,10 @@ def check_row_factory_cursor() -> None:
 def check_row_factory_connection() -> None:
     """Type-check connect(..., row_factory=<MyRowFactory>) or
     Connection.row_factory cases.
-
-    This example is incomplete because Connection is not generic on Row, hence
-    all the Any, which we aim at getting rid of.
     """
-    cur1: Cursor[Any]
-    r1: Any
+    conn1: Connection[int]
+    cur1: Cursor[int]
+    r1: Optional[int]
     conn1 = connect(row_factory=int_row_factory)
     cur1 = conn1.execute("select 1")
     r1 = cur1.fetchone()
@@ -71,8 +69,9 @@ def check_row_factory_connection() -> None:
     with conn1.cursor() as cur1:
         cur1.execute("select 2")
 
-    cur2: Cursor[Any]
-    r2: Any
+    conn2: Connection[Person]
+    cur2: Cursor[Person]
+    r2: Optional[Person]
     conn2 = connect(row_factory=Person.row_factory)
     cur2 = conn2.execute("select * from persons")
     r2 = cur2.fetchone()
@@ -80,9 +79,9 @@ def check_row_factory_connection() -> None:
     with conn2.cursor() as cur2:
         cur2.execute("select 2")
 
-    cur3: Cursor[Any]
-    r3: Optional[Any]
-    conn3 = connect()
+    cur3: Cursor[Tuple[Any, ...]]
+    r3: Optional[Tuple[Any, ...]]
+    conn3 = connect()  # type: ignore[var-annotated]
     cur3 = conn3.execute("select 3")
     with conn3.cursor() as cur3:
         cur3.execute("select 42")