]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(pool): make pool generic on connection type
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Sep 2023 23:16:55 +0000 (01:16 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 27 Sep 2023 00:41:00 +0000 (02:41 +0200)
docs/news_pool.rst
psycopg/psycopg/connection.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_null_pool.py
tests/pool/test_null_pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 1f2aeb8fd32b847e42c0ece22b2d581505236d87..043181d5da5c52a945fc8c63e359b5f0fea4d4e5 100644 (file)
@@ -22,6 +22,7 @@ psycopg_pool 3.1.9 (unreleased)
 
 - Fix the return type annotation of `!NullConnectionPool.__enter__()`
   (:ticket:`#540`).
+- Make connection pool classes generic on the connection type (:ticket:`#559`).
 
 
 Current release
index ca9305394cd7fb94e079fb1b15ea561b0368b08d..a6571d5d7b683321ede9c5c98e3f90c08374a470 100644 (file)
@@ -129,7 +129,7 @@ class BaseConnection(Generic[Row]):
 
         # Attribute is only set if the connection is from a pool so we can tell
         # apart a connection in the pool too (when _pool = None)
-        self._pool: Optional["BasePool[Any]"]
+        self._pool: Optional["BasePool"]
 
         self._pipeline: Optional[BasePipeline] = None
 
index 6d66f5206ec07bc30287e416b059f41dac71b031..13823bccd2ac9ae8da4965b6503bafb461a6c262 100644 (file)
@@ -6,16 +6,18 @@ psycopg connection pool base class and functionalities.
 
 from time import monotonic
 from random import random
-from typing import Any, Dict, Generic, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
 
 from psycopg import errors as e
-from psycopg.abc import ConnectionType
 
 from .errors import PoolClosed
 from ._compat import Counter, Deque
 
+if TYPE_CHECKING:
+    from psycopg.connection import BaseConnection
 
-class BasePool(Generic[ConnectionType]):
+
+class BasePool:
     # Used to generate pool names
     _num_pool = 0
 
@@ -36,6 +38,8 @@ class BasePool(Generic[ConnectionType]):
     _CONNECTIONS_ERRORS = "connections_errors"
     _CONNECTIONS_LOST = "connections_lost"
 
+    _pool: Deque["Any"]
+
     def __init__(
         self,
         conninfo: str = "",
@@ -74,7 +78,7 @@ class BasePool(Generic[ConnectionType]):
         self.num_workers = num_workers
 
         self._nconns = min_size  # currently in the pool, out, being prepared
-        self._pool = Deque[ConnectionType]()
+        self._pool = Deque()
         self._stats = Counter[str]()
 
         # Min number of connections in the pool in a max_idle unit of time.
@@ -138,7 +142,7 @@ class BasePool(Generic[ConnectionType]):
             else:
                 raise PoolClosed(f"the pool {self.name!r} is not open yet")
 
-    def _check_pool_putconn(self, conn: ConnectionType) -> None:
+    def _check_pool_putconn(self, conn: "BaseConnection[Any]") -> None:
         pool = getattr(conn, "_pool", None)
         if pool is self:
             return
@@ -188,7 +192,7 @@ class BasePool(Generic[ConnectionType]):
         """
         return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
 
-    def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+    def _set_connection_expiry_date(self, conn: "BaseConnection[Any]") -> None:
         """Set an expiry date on a connection.
 
         Add some randomness to avoid mass reconnection.
index 9c796ea71d9e46eec8a3e63cbe4950b2ed3a8859..b4e640b74afe45ea9e78cd8a4e827ec13e43063f 100644 (file)
@@ -6,12 +6,13 @@ Psycopg null connection pools
 
 import logging
 import threading
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, cast, Dict, Optional, overload, Tuple, Type
 
 from psycopg import Connection
 from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
 
-from .pool import ConnectionPool, AddConnection, ConnectFailedCB
+from .pool import ConnectionPool, CT, AddConnection, ConnectFailedCB
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
 
@@ -40,15 +41,60 @@ class _BaseNullConnectionPool:
         pass
 
 
-class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
+    @overload
+    def __init__(
+        self: "NullConnectionPool[Connection[TupleRow]]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        configure: Optional[Callable[[CT], None]] = ...,
+        reset: Optional[Callable[[CT], None]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[ConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "NullConnectionPool[CT]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        connection_class: Type[CT],
+        configure: Optional[Callable[[CT], None]] = ...,
+        reset: Optional[Callable[[CT], None]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[ConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
+
     def __init__(
         self,
         conninfo: str = "",
         *,
         open: bool = True,
-        connection_class: Type[Connection[Any]] = Connection,
-        configure: Optional[Callable[[Connection[Any]], None]] = None,
-        reset: Optional[Callable[[Connection[Any]], None]] = None,
+        connection_class: Type[CT] = cast(Type[CT], Connection),
+        configure: Optional[Callable[[CT], None]] = None,
+        reset: Optional[Callable[[CT], None]] = None,
         kwargs: Optional[Dict[str, Any]] = None,
         # Note: default value changed to 0.
         min_size: int = 0,
@@ -109,10 +155,8 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
 
         logger.info("pool %r is ready to use", self.name)
 
-    def _get_ready_connection(
-        self, timeout: Optional[float]
-    ) -> Optional[Connection[Any]]:
-        conn: Optional[Connection[Any]] = None
+    def _get_ready_connection(self, timeout: Optional[float]) -> Optional[CT]:
+        conn: Optional[CT] = None
         if self.max_size == 0 or self._nconns < self.max_size:
             # Create a new connection for the client
             try:
@@ -129,7 +173,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
             )
         return conn
 
-    def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+    def _maybe_close_connection(self, conn: CT) -> bool:
         with self._lock:
             if not self._closed and self._waiting:
                 return False
@@ -162,7 +206,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
         """No-op, as the pool doesn't have connections in its state."""
         pass
 
-    def _add_to_pool(self, conn: Connection[Any]) -> None:
+    def _add_to_pool(self, conn: CT) -> None:
         # Remove the pool reference from the connection before returning it
         # to the state, to avoid to create a reference loop.
         # Also disable the warning for open connection in conn.__del__
index 83254cf38d4e68861d91a9076ffe93688ccb1508..ca9db8bcdbf67f3dd3bfdc100928082d937a6fd8 100644 (file)
@@ -6,28 +6,74 @@ psycopg asynchronous null connection pool
 
 import asyncio
 import logging
-from typing import Any, Awaitable, Callable, Dict, Optional, Type
+from typing import Any, Awaitable, Callable, cast, Dict, Optional, overload, Type
 
 from psycopg import AsyncConnection
 from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
 
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
 from .null_pool import _BaseNullConnectionPool
-from .pool_async import AsyncConnectionPool, AddConnection, AsyncConnectFailedCB
+from .pool_async import AsyncConnectionPool, ACT, AddConnection, AsyncConnectFailedCB
 
 logger = logging.getLogger("psycopg.pool")
 
 
-class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]):
+    @overload
+    def __init__(
+        self: "AsyncNullConnectionPool[AsyncConnection[TupleRow]]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "AsyncNullConnectionPool[ACT]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        connection_class: Type[ACT],
+        configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
+
     def __init__(
         self,
         conninfo: str = "",
         *,
         open: bool = True,
-        connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
-        configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
-        reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+        connection_class: Type[ACT] = cast(Type[ACT], AsyncConnection),
+        configure: Optional[Callable[[ACT], Awaitable[None]]] = None,
+        reset: Optional[Callable[[ACT], Awaitable[None]]] = None,
         kwargs: Optional[Dict[str, Any]] = None,
         # Note: default value changed to 0.
         min_size: int = 0,
@@ -82,10 +128,8 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
 
         logger.info("pool %r is ready to use", self.name)
 
-    async def _get_ready_connection(
-        self, timeout: Optional[float]
-    ) -> Optional[AsyncConnection[Any]]:
-        conn: Optional[AsyncConnection[Any]] = None
+    async def _get_ready_connection(self, timeout: Optional[float]) -> Optional[ACT]:
+        conn: Optional[ACT] = None
         if self.max_size == 0 or self._nconns < self.max_size:
             # Create a new connection for the client
             try:
@@ -101,7 +145,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
             )
         return conn
 
-    async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+    async def _maybe_close_connection(self, conn: ACT) -> bool:
         # Close the connection if no client is waiting for it, or if the pool
         # is closed. For extra refcare remove the pool reference from it.
         # Maintain the stats.
@@ -132,7 +176,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
     async def check(self) -> None:
         pass
 
-    async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+    async def _add_to_pool(self, conn: ACT) -> None:
         # Remove the pool reference from the connection before returning it
         # to the state, to avoid to create a reference loop.
         # Also disable the warning for open connection in conn.__del__
index 47ecf084eb12b6447faf9d308331d50517877cd2..a7bc5b0de5d8911220b789221dae7b192fe2f2bd 100644 (file)
@@ -10,8 +10,8 @@ from abc import ABC, abstractmethod
 from time import monotonic
 from queue import Queue, Empty
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List
-from typing import Optional, Sequence, Type, TypeVar
+from typing import Any, Callable, cast, Dict, Generic, Iterator, List
+from typing import Optional, overload, Sequence, Type, TypeVar
 from typing_extensions import TypeAlias
 from weakref import ref
 from contextlib import contextmanager
@@ -19,6 +19,7 @@ from contextlib import contextmanager
 from psycopg import errors as e
 from psycopg import Connection
 from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
 
 from .base import ConnectionAttempt, BasePool
 from .sched import Scheduler
@@ -29,18 +30,66 @@ logger = logging.getLogger("psycopg.pool")
 
 ConnectFailedCB: TypeAlias = Callable[["ConnectionPool"], None]
 
+CT = TypeVar("CT", bound="Connection[Any]")
 
-class ConnectionPool(BasePool[Connection[Any]]):
-    _Self = TypeVar("_Self", bound="ConnectionPool")
+
+class ConnectionPool(Generic[CT], BasePool):
+    _Self = TypeVar("_Self", bound="ConnectionPool[CT]")
+    _pool: Deque[CT]
+
+    @overload
+    def __init__(
+        self: "ConnectionPool[Connection[TupleRow]]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        configure: Optional[Callable[[CT], None]] = ...,
+        reset: Optional[Callable[[CT], None]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[ConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "ConnectionPool[CT]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        connection_class: Type[CT],
+        configure: Optional[Callable[[CT], None]] = ...,
+        reset: Optional[Callable[[CT], None]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[ConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
 
     def __init__(
         self,
         conninfo: str = "",
         *,
         open: bool = True,
-        connection_class: Type[Connection[Any]] = Connection,
-        configure: Optional[Callable[[Connection[Any]], None]] = None,
-        reset: Optional[Callable[[Connection[Any]], None]] = None,
+        connection_class: Type[CT] = cast(Type[CT], Connection[TupleRow]),
+        configure: Optional[Callable[[CT], None]] = None,
+        reset: Optional[Callable[[CT], None]] = None,
         kwargs: Optional[Dict[str, Any]] = None,
         min_size: int = 4,
         max_size: Optional[int] = None,
@@ -61,7 +110,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self._reconnect_failed = reconnect_failed or (lambda pool: None)
 
         self._lock = threading.RLock()
-        self._waiting = Deque["WaitingClient"]()
+        self._waiting = Deque["WaitingClient[CT]"]()
 
         # to notify that the pool is full
         self._pool_full_event: Optional[threading.Event] = None
@@ -130,7 +179,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         logger.info("pool %r is ready to use", self.name)
 
     @contextmanager
-    def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]:
+    def connection(self, timeout: Optional[float] = None) -> Iterator[CT]:
         """Context manager to obtain a connection from the pool.
 
         Return the connection immediately if available, otherwise wait up to
@@ -152,7 +201,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
             self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
             self.putconn(conn)
 
-    def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
+    def getconn(self, timeout: Optional[float] = None) -> CT:
         """Obtain a connection from the pool.
 
         You should preferably use `connection()`. Use this function only if
@@ -173,7 +222,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
             if not conn:
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
-                pos = WaitingClient()
+                pos: WaitingClient[CT] = WaitingClient()
                 self._waiting.append(pos)
                 self._stats[self._REQUESTS_QUEUED] += 1
 
@@ -201,11 +250,9 @@ class ConnectionPool(BasePool[Connection[Any]]):
         logger.info("connection given by %r", self.name)
         return conn
 
-    def _get_ready_connection(
-        self, timeout: Optional[float]
-    ) -> Optional[Connection[Any]]:
+    def _get_ready_connection(self, timeout: Optional[float]) -> Optional[CT]:
         """Return a connection, if the client deserves one."""
-        conn: Optional[Connection[Any]] = None
+        conn: Optional[CT] = None
         if self._pool:
             # Take a connection ready out of the pool
             conn = self._pool.popleft()
@@ -229,7 +276,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self._growing = True
         self.run_task(AddConnection(self, growing=True))
 
-    def putconn(self, conn: Connection[Any]) -> None:
+    def putconn(self, conn: CT) -> None:
         """Return a connection to the loving hands of its pool.
 
         Use this function only paired with a `getconn()`. You don't need to use
@@ -249,7 +296,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         else:
             self._return_connection(conn)
 
-    def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+    def _maybe_close_connection(self, conn: CT) -> bool:
         """Close a returned connection if necessary.
 
         Return `!True if the connection was closed.
@@ -354,8 +401,8 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
     def _stop_workers(
         self,
-        waiting_clients: Sequence["WaitingClient"] = (),
-        connections: Sequence[Connection[Any]] = (),
+        waiting_clients: Sequence["WaitingClient[CT]"] = (),
+        connections: Sequence[CT] = (),
         timeout: float = 0.0,
     ) -> None:
         # Stop the scheduler
@@ -511,7 +558,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
                     ex,
                 )
 
-    def _connect(self, timeout: Optional[float] = None) -> Connection[Any]:
+    def _connect(self, timeout: Optional[float] = None) -> CT:
         """Return a new connection configured for the pool."""
         self._stats[self._CONNECTIONS_NUM] += 1
         kwargs = self.kwargs
@@ -520,8 +567,9 @@ class ConnectionPool(BasePool[Connection[Any]]):
             kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
-            conn: Connection[Any]
-            conn = self.connection_class.connect(self.conninfo, **kwargs)
+            conn: CT = self.connection_class.connect(  # type: ignore
+                self.conninfo, **kwargs
+            )
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
@@ -598,7 +646,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
                 else:
                     self._growing = False
 
-    def _return_connection(self, conn: Connection[Any]) -> None:
+    def _return_connection(self, conn: CT) -> None:
         """
         Return a connection to the pool after usage.
         """
@@ -619,7 +667,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
         self._add_to_pool(conn)
 
-    def _add_to_pool(self, conn: Connection[Any]) -> None:
+    def _add_to_pool(self, conn: CT) -> None:
         """
         Add a connection to the pool.
 
@@ -651,7 +699,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
                 if self._pool_full_event and len(self._pool) >= self._min_size:
                     self._pool_full_event.set()
 
-    def _reset_connection(self, conn: Connection[Any]) -> None:
+    def _reset_connection(self, conn: CT) -> None:
         """
         Bring a connection to IDLE state or close it.
         """
@@ -693,7 +741,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
                 conn.close()
 
     def _shrink_pool(self) -> None:
-        to_close: Optional[Connection[Any]] = None
+        to_close: Optional[CT] = None
 
         with self._lock:
             # Reset the min number of connections used
@@ -723,13 +771,13 @@ class ConnectionPool(BasePool[Connection[Any]]):
         return rv
 
 
-class WaitingClient:
+class WaitingClient(Generic[CT]):
     """A position in a queue for a client waiting for a connection."""
 
     __slots__ = ("conn", "error", "_cond")
 
     def __init__(self) -> None:
-        self.conn: Optional[Connection[Any]] = None
+        self.conn: Optional[CT] = None
         self.error: Optional[BaseException] = None
 
         # The WaitingClient behaves in a way similar to an Event, but we need
@@ -739,7 +787,7 @@ class WaitingClient:
         # will be lost.
         self._cond = threading.Condition()
 
-    def wait(self, timeout: float) -> Connection[Any]:
+    def wait(self, timeout: float) -> CT:
         """Wait for a connection to be set and return it.
 
         Raise an exception if the wait times out or if fail() is called.
@@ -760,7 +808,7 @@ class WaitingClient:
             assert self.error
             raise self.error
 
-    def set(self, conn: Connection[Any]) -> bool:
+    def set(self, conn: CT) -> bool:
         """Signal the client waiting that a connection is ready.
 
         Return True if the client has "accepted" the connection, False
@@ -792,7 +840,7 @@ class WaitingClient:
 class MaintenanceTask(ABC):
     """A task to run asynchronously to maintain the pool state."""
 
-    def __init__(self, pool: "ConnectionPool"):
+    def __init__(self, pool: "ConnectionPool[Any]"):
         self.pool = ref(pool)
 
     def __repr__(self) -> str:
@@ -830,21 +878,21 @@ class MaintenanceTask(ABC):
         pool.run_task(self)
 
     @abstractmethod
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "ConnectionPool[Any]") -> None:
         ...
 
 
 class StopWorker(MaintenanceTask):
     """Signal the maintenance thread to terminate."""
 
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "ConnectionPool[Any]") -> None:
         pass
 
 
 class AddConnection(MaintenanceTask):
     def __init__(
         self,
-        pool: "ConnectionPool",
+        pool: "ConnectionPool[Any]",
         attempt: Optional["ConnectionAttempt"] = None,
         growing: bool = False,
     ):
@@ -852,18 +900,18 @@ class AddConnection(MaintenanceTask):
         self.attempt = attempt
         self.growing = growing
 
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "ConnectionPool[Any]") -> None:
         pool._add_connection(self.attempt, growing=self.growing)
 
 
 class ReturnConnection(MaintenanceTask):
     """Clean up and return a connection to the pool."""
 
-    def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"):
+    def __init__(self, pool: "ConnectionPool[Any]", conn: CT):
         super().__init__(pool)
         self.conn = conn
 
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "ConnectionPool[Any]") -> None:
         pool._return_connection(self.conn)
 
 
@@ -874,7 +922,7 @@ class ShrinkPool(MaintenanceTask):
     in the pool.
     """
 
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "ConnectionPool[Any]") -> None:
         # Reschedule the task now so that in case of any error we don't lose
         # the periodic run.
         pool.schedule_task(self, pool.max_idle)
index b57098e4e5baaf4bf6bc86aa72b82e4a43c36259..9ee8e85541d42662bdf97edf52a439ca730fa6e6 100644 (file)
@@ -9,8 +9,8 @@ import logging
 from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
-from typing import Any, AsyncIterator, Awaitable, Callable
-from typing import Dict, List, Optional, Sequence, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, cast, Generic
+from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar, Union
 from typing_extensions import TypeAlias
 from weakref import ref
 from contextlib import asynccontextmanager
@@ -18,6 +18,7 @@ from contextlib import asynccontextmanager
 from psycopg import errors as e
 from psycopg import AsyncConnection
 from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
 
 from .base import ConnectionAttempt, BasePool
 from .sched import AsyncScheduler
@@ -31,18 +32,66 @@ AsyncConnectFailedCB: TypeAlias = Union[
     Callable[["AsyncConnectionPool"], Awaitable[None]],
 ]
 
+ACT = TypeVar("ACT", bound="AsyncConnection[Any]")
 
-class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
-    _Self = TypeVar("_Self", bound="AsyncConnectionPool")
+
+class AsyncConnectionPool(Generic[ACT], BasePool):
+    _Self = TypeVar("_Self", bound="AsyncConnectionPool[ACT]")
+    _pool: Deque[ACT]
+
+    @overload
+    def __init__(
+        self: "AsyncConnectionPool[AsyncConnection[TupleRow]]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: "AsyncConnectionPool[ACT]",
+        conninfo: str = "",
+        *,
+        open: bool = ...,
+        connection_class: Type[ACT],
+        configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+        kwargs: Optional[Dict[str, Any]] = ...,
+        min_size: int = ...,
+        max_size: Optional[int] = ...,
+        name: Optional[str] = ...,
+        timeout: float = ...,
+        max_waiting: int = ...,
+        max_lifetime: float = ...,
+        max_idle: float = ...,
+        reconnect_timeout: float = ...,
+        reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+        num_workers: int = ...,
+    ):
+        ...
 
     def __init__(
         self,
         conninfo: str = "",
         *,
         open: bool = True,
-        connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
-        configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
-        reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+        connection_class: Type[ACT] = cast(Type[ACT], AsyncConnection),
+        configure: Optional[Callable[[ACT], Awaitable[None]]] = None,
+        reset: Optional[Callable[[ACT], Awaitable[None]]] = None,
         kwargs: Optional[Dict[str, Any]] = None,
         min_size: int = 4,
         max_size: Optional[int] = None,
@@ -67,7 +116,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self._sched: AsyncScheduler
         self._tasks: "asyncio.Queue[MaintenanceTask]"
 
-        self._waiting = Deque["AsyncClient"]()
+        self._waiting = Deque["AsyncClient[ACT]"]()
 
         # to notify that the pool is full
         self._pool_full_event: Optional[asyncio.Event] = None
@@ -118,9 +167,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         logger.info("pool %r is ready to use", self.name)
 
     @asynccontextmanager
-    async def connection(
-        self, timeout: Optional[float] = None
-    ) -> AsyncIterator[AsyncConnection[Any]]:
+    async def connection(self, timeout: Optional[float] = None) -> AsyncIterator[ACT]:
         conn = await self.getconn(timeout=timeout)
         t0 = monotonic()
         try:
@@ -131,7 +178,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
             await self.putconn(conn)
 
-    async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+    async def getconn(self, timeout: Optional[float] = None) -> ACT:
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
 
@@ -144,7 +191,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             if not conn:
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
-                pos = AsyncClient()
+                pos: AsyncClient[ACT] = AsyncClient()
                 self._waiting.append(pos)
                 self._stats[self._REQUESTS_QUEUED] += 1
 
@@ -172,10 +219,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         logger.info("connection given by %r", self.name)
         return conn
 
-    async def _get_ready_connection(
-        self, timeout: Optional[float]
-    ) -> Optional[AsyncConnection[Any]]:
-        conn: Optional[AsyncConnection[Any]] = None
+    async def _get_ready_connection(self, timeout: Optional[float]) -> Optional[ACT]:
+        conn: Optional[ACT] = None
         if self._pool:
             # Take a connection ready out of the pool
             conn = self._pool.popleft()
@@ -198,7 +243,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             self._growing = True
             self.run_task(AddConnection(self, growing=True))
 
-    async def putconn(self, conn: AsyncConnection[Any]) -> None:
+    async def putconn(self, conn: ACT) -> None:
         self._check_pool_putconn(conn)
 
         logger.info("returning connection to %r", self.name)
@@ -211,7 +256,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         else:
             await self._return_connection(conn)
 
-    async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+    async def _maybe_close_connection(self, conn: ACT) -> bool:
         # If the pool is closed just close the connection instead of returning
         # it to the pool. For extra refcare remove the pool reference from it.
         if not self._closed:
@@ -299,8 +344,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
 
     async def _stop_workers(
         self,
-        waiting_clients: Sequence["AsyncClient"] = (),
-        connections: Sequence[AsyncConnection[Any]] = (),
+        waiting_clients: Sequence["AsyncClient[ACT]"] = (),
+        connections: Sequence[ACT] = (),
         timeout: float = 0.0,
     ) -> None:
         # Stop the scheduler
@@ -442,7 +487,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                     ex,
                 )
 
-    async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+    async def _connect(self, timeout: Optional[float] = None) -> ACT:
         self._stats[self._CONNECTIONS_NUM] += 1
         kwargs = self.kwargs
         if timeout:
@@ -450,8 +495,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
-            conn: AsyncConnection[Any]
-            conn = await self.connection_class.connect(self.conninfo, **kwargs)
+            conn: ACT = await self.connection_class.connect(self.conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
@@ -528,7 +572,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 else:
                     self._growing = False
 
-    async def _return_connection(self, conn: AsyncConnection[Any]) -> None:
+    async def _return_connection(self, conn: ACT) -> None:
         """
         Return a connection to the pool after usage.
         """
@@ -549,7 +593,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
 
         await self._add_to_pool(conn)
 
-    async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+    async def _add_to_pool(self, conn: ACT) -> None:
         """
         Add a connection to the pool.
 
@@ -581,7 +625,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 if self._pool_full_event and len(self._pool) >= self._min_size:
                     self._pool_full_event.set()
 
-    async def _reset_connection(self, conn: AsyncConnection[Any]) -> None:
+    async def _reset_connection(self, conn: ACT) -> None:
         """
         Bring a connection to IDLE state or close it.
         """
@@ -623,7 +667,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 await conn.close()
 
     async def _shrink_pool(self) -> None:
-        to_close: Optional[AsyncConnection[Any]] = None
+        to_close: Optional[ACT] = None
 
         async with self._lock:
             # Reset the min number of connections used
@@ -653,13 +697,13 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         return rv
 
 
-class AsyncClient:
+class AsyncClient(Generic[ACT]):
     """A position in a queue for a client waiting for a connection."""
 
     __slots__ = ("conn", "error", "_cond")
 
     def __init__(self) -> None:
-        self.conn: Optional[AsyncConnection[Any]] = None
+        self.conn: Optional[ACT] = None
         self.error: Optional[BaseException] = None
 
         # The AsyncClient behaves in a way similar to an Event, but we need
@@ -669,7 +713,7 @@ class AsyncClient:
         # will be lost.
         self._cond = asyncio.Condition()
 
-    async def wait(self, timeout: float) -> AsyncConnection[Any]:
+    async def wait(self, timeout: float) -> ACT:
         """Wait for a connection to be set and return it.
 
         Raise an exception if the wait times out or if fail() is called.
@@ -691,7 +735,7 @@ class AsyncClient:
             assert self.error
             raise self.error
 
-    async def set(self, conn: AsyncConnection[Any]) -> bool:
+    async def set(self, conn: ACT) -> bool:
         """Signal the client waiting that a connection is ready.
 
         Return True if the client has "accepted" the connection, False
@@ -723,7 +767,7 @@ class AsyncClient:
 class MaintenanceTask(ABC):
     """A task to run asynchronously to maintain the pool state."""
 
-    def __init__(self, pool: "AsyncConnectionPool"):
+    def __init__(self, pool: "AsyncConnectionPool[Any]"):
         self.pool = ref(pool)
 
     def __repr__(self) -> str:
@@ -760,21 +804,21 @@ class MaintenanceTask(ABC):
         pool.run_task(self)
 
     @abstractmethod
-    async def _run(self, pool: "AsyncConnectionPool") -> None:
+    async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
         ...
 
 
 class StopWorker(MaintenanceTask):
     """Signal the maintenance worker to terminate."""
 
-    async def _run(self, pool: "AsyncConnectionPool") -> None:
+    async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
         pass
 
 
 class AddConnection(MaintenanceTask):
     def __init__(
         self,
-        pool: "AsyncConnectionPool",
+        pool: "AsyncConnectionPool[Any]",
         attempt: Optional["ConnectionAttempt"] = None,
         growing: bool = False,
     ):
@@ -782,18 +826,18 @@ class AddConnection(MaintenanceTask):
         self.attempt = attempt
         self.growing = growing
 
-    async def _run(self, pool: "AsyncConnectionPool") -> None:
+    async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
         await pool._add_connection(self.attempt, growing=self.growing)
 
 
 class ReturnConnection(MaintenanceTask):
     """Clean up and return a connection to the pool."""
 
-    def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"):
+    def __init__(self, pool: "AsyncConnectionPool[Any]", conn: ACT):
         super().__init__(pool)
         self.conn = conn
 
-    async def _run(self, pool: "AsyncConnectionPool") -> None:
+    async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
         await pool._return_connection(self.conn)
 
 
@@ -804,7 +848,7 @@ class ShrinkPool(MaintenanceTask):
     in the pool.
     """
 
-    async def _run(self, pool: "AsyncConnectionPool") -> None:
+    async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
         # Reschedule the task now so that in case of any error we don't lose
         # the periodic run.
         await pool.schedule_task(self, pool.max_idle)
@@ -820,7 +864,7 @@ class Schedule(MaintenanceTask):
 
     def __init__(
         self,
-        pool: "AsyncConnectionPool",
+        pool: "AsyncConnectionPool[Any]",
         task: MaintenanceTask,
         delay: float,
     ):
@@ -828,5 +872,5 @@ class Schedule(MaintenanceTask):
         self.task = task
         self.delay = delay
 
-    async def _run(self, pool: "AsyncConnectionPool") -> None:
+    async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
         await pool.schedule_task(self.task, self.delay)
index 51cc67ad9924233036c92c4d6aae1a8294ec037a..a1b271529f761f7d9b2d6f69d1e53e52c2f912bc 100644 (file)
@@ -1,13 +1,15 @@
 import logging
 from time import sleep, time
 from threading import Thread, Event
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import pytest
 from packaging.version import parse as ver  # noqa: F401  # used in skipif
 
 import psycopg
 from psycopg.pq import TransactionStatus
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type
 
 from .test_pool import delay_connection, ensure_waiting
 
@@ -54,6 +56,63 @@ def test_kwargs(dsn):
             assert conn.autocommit
 
 
+class MyRow(Dict[str, Any]):
+    ...
+
+
+def test_generic_connection_type(dsn):
+    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+        conn.autocommit = True
+
+    class MyConnection(psycopg.Connection[Row]):
+        pass
+
+    with NullConnectionPool(
+        dsn,
+        connection_class=MyConnection[MyRow],
+        kwargs={"row_factory": class_row(MyRow)},
+        configure=set_autocommit,
+    ) as p1:
+        with p1.connection() as conn1:
+            cur1 = conn1.execute("select 1 as x")
+            (row1,) = cur1.fetchall()
+
+    assert_type(p1, NullConnectionPool[MyConnection[MyRow]])
+    assert_type(conn1, MyConnection[MyRow])
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+    with NullConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
+        with p2.connection() as conn2:
+            (row2,) = conn2.execute("select 2 as y").fetchall()
+    assert_type(p2, NullConnectionPool[MyConnection[TupleRow]])
+    assert_type(conn2, MyConnection[TupleRow])
+    assert_type(row2, TupleRow)
+    assert row2 == (2,)
+
+
+def test_non_generic_connection_type(dsn):
+    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+        conn.autocommit = True
+
+    class MyConnection(psycopg.Connection[MyRow]):
+        def __init__(self, *args: Any, **kwargs: Any):
+            kwargs["row_factory"] = class_row(MyRow)
+            super().__init__(*args, **kwargs)
+
+    with NullConnectionPool(
+        dsn, connection_class=MyConnection, configure=set_autocommit
+    ) as p1:
+        with p1.connection() as conn1:
+            (row1,) = conn1.execute("select 1 as x").fetchall()
+    assert_type(p1, NullConnectionPool[MyConnection])
+    assert_type(conn1, MyConnection)
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+
 @pytest.mark.crdb_skip("backend pid")
 def test_its_no_pool_at_all(dsn):
     with NullConnectionPool(dsn, max_size=2) as p:
index be29b3dcb61b14b7e9c5e4ab881330a4e82353f5..47b0c88209b53715893fae18cba6cfd78a93ad8e 100644 (file)
@@ -1,14 +1,15 @@
 import asyncio
 import logging
 from time import time
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import pytest
 from packaging.version import parse as ver  # noqa: F401  # used in skipif
 
 import psycopg
 from psycopg.pq import TransactionStatus
-from psycopg._compat import create_task
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type, create_task
 from .test_pool_async import delay_connection, ensure_waiting
 
 pytestmark = [pytest.mark.anyio]
@@ -56,6 +57,65 @@ async def test_kwargs(dsn):
             assert conn.autocommit
 
 
+class MyRow(Dict[str, Any]):
+    ...
+
+
+async def test_generic_connection_type(dsn):
+    async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+        await conn.set_autocommit(True)
+
+    class MyConnection(psycopg.AsyncConnection[Row]):
+        pass
+
+    async with AsyncNullConnectionPool(
+        dsn,
+        connection_class=MyConnection[MyRow],
+        kwargs={"row_factory": class_row(MyRow)},
+        configure=set_autocommit,
+    ) as p1:
+        async with p1.connection() as conn1:
+            cur1 = await conn1.execute("select 1 as x")
+            (row1,) = await cur1.fetchall()
+    assert_type(p1, AsyncNullConnectionPool[MyConnection[MyRow]])
+    assert_type(conn1, MyConnection[MyRow])
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+    async with AsyncNullConnectionPool(
+        dsn, connection_class=MyConnection[TupleRow]
+    ) as p2:
+        async with p2.connection() as conn2:
+            cur2 = await conn2.execute("select 2 as y")
+            (row2,) = await cur2.fetchall()
+    assert_type(p2, AsyncNullConnectionPool[MyConnection[TupleRow]])
+    assert_type(conn2, MyConnection[TupleRow])
+    assert_type(row2, TupleRow)
+    assert row2 == (2,)
+
+
+async def test_non_generic_connection_type(dsn):
+    async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+        await conn.set_autocommit(True)
+
+    class MyConnection(psycopg.AsyncConnection[MyRow]):
+        def __init__(self, *args: Any, **kwargs: Any):
+            kwargs["row_factory"] = class_row(MyRow)
+            super().__init__(*args, **kwargs)
+
+    async with AsyncNullConnectionPool(
+        dsn, connection_class=MyConnection, configure=set_autocommit
+    ) as p1:
+        async with p1.connection() as conn1:
+            (row1,) = await (await conn1.execute("select 1 as x")).fetchall()
+    assert_type(p1, AsyncNullConnectionPool[MyConnection])
+    assert_type(conn1, MyConnection)
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+
 @pytest.mark.crdb_skip("backend pid")
 async def test_its_no_pool_at_all(dsn):
     async with AsyncNullConnectionPool(dsn, max_size=2) as p:
index 5b4c435d720ce9c7177e47a05e4b5c7f1e68d53a..7234a3c84c7da753e53927e1f2a26a9e72a817f0 100644 (file)
@@ -2,13 +2,14 @@ import logging
 import weakref
 from time import sleep, time
 from threading import Thread, Event
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import pytest
 
 import psycopg
 from psycopg.pq import TransactionStatus
-from psycopg._compat import Counter
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import Counter, assert_type
 
 try:
     import psycopg_pool as pool
@@ -64,6 +65,60 @@ def test_kwargs(dsn):
             assert conn.autocommit
 
 
+class MyRow(Dict[str, Any]):
+    ...
+
+
+def test_generic_connection_type(dsn):
+    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+        conn.autocommit = True
+
+    class MyConnection(psycopg.Connection[Row]):
+        pass
+
+    with pool.ConnectionPool(
+        dsn,
+        connection_class=MyConnection[MyRow],
+        kwargs=dict(row_factory=class_row(MyRow)),
+        configure=set_autocommit,
+    ) as p1, p1.connection() as conn1:
+        (row1,) = conn1.execute("select 1 as x").fetchall()
+    assert_type(p1, pool.ConnectionPool[MyConnection[MyRow]])
+    assert_type(conn1, MyConnection[MyRow])
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+    with pool.ConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
+        with p2.connection() as conn2:
+            (row2,) = conn2.execute("select 2 as y").fetchall()
+    assert_type(p2, pool.ConnectionPool[MyConnection[TupleRow]])
+    assert_type(conn2, MyConnection[TupleRow])
+    assert_type(row2, TupleRow)
+    assert row2 == (2,)
+
+
+def test_non_generic_connection_type(dsn):
+    def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+        conn.autocommit = True
+
+    class MyConnection(psycopg.Connection[MyRow]):
+        def __init__(self, *args: Any, **kwargs: Any):
+            kwargs["row_factory"] = class_row(MyRow)
+            super().__init__(*args, **kwargs)
+
+    with pool.ConnectionPool(
+        dsn, connection_class=MyConnection, configure=set_autocommit
+    ) as p1:
+        with p1.connection() as conn1:
+            (row1,) = conn1.execute("select 1 as x").fetchall()
+    assert_type(p1, pool.ConnectionPool[MyConnection])
+    assert_type(conn1, MyConnection)
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+
 @pytest.mark.crdb_skip("backend pid")
 def test_its_really_a_pool(dsn):
     with pool.ConnectionPool(dsn, min_size=2) as p:
index 67382535e26a5e9a8f0e948838579773d14ed2ce..31861df91570ff7e6d0ada9febd0c556723de957 100644 (file)
@@ -1,13 +1,14 @@
 import asyncio
 import logging
 from time import time
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import pytest
 
 import psycopg
 from psycopg.pq import TransactionStatus
-from psycopg._compat import create_task, Counter
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type, create_task, Counter
 
 try:
     import psycopg_pool as pool
@@ -57,6 +58,68 @@ async def test_kwargs(dsn):
             assert conn.autocommit
 
 
+class MyRow(Dict[str, Any]):
+    ...
+
+
+async def test_generic_connection_type(dsn):
+    async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+        await conn.set_autocommit(True)
+
+    class MyConnection(psycopg.AsyncConnection[Row]):
+        pass
+
+    async with pool.AsyncConnectionPool(
+        dsn,
+        connection_class=MyConnection[MyRow],
+        kwargs=dict(row_factory=class_row(MyRow)),
+        configure=set_autocommit,
+    ) as p1:
+        async with p1.connection() as conn1:
+            cur1 = await conn1.execute("select 1 as x")
+            (row1,) = await cur1.fetchall()
+    assert_type(p1, pool.AsyncConnectionPool[MyConnection[MyRow]])
+    assert_type(conn1, MyConnection[MyRow])
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+    async with pool.AsyncConnectionPool(
+        dsn, connection_class=MyConnection[TupleRow]
+    ) as p2:
+        async with p2.connection() as conn2:
+            cur2 = await conn2.execute("select 2 as y")
+            (row2,) = await cur2.fetchall()
+    assert_type(p2, pool.AsyncConnectionPool[MyConnection[TupleRow]])
+    assert_type(conn2, MyConnection[TupleRow])
+    assert_type(row2, TupleRow)
+    assert row2 == (2,)
+
+
+async def test_non_generic_connection_type(dsn):
+    async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+        await conn.set_autocommit(True)
+
+    class MyConnection(psycopg.AsyncConnection[MyRow]):
+        def __init__(self, *args: Any, **kwargs: Any):
+            kwargs["row_factory"] = class_row(MyRow)
+            super().__init__(*args, **kwargs)
+
+    async with pool.AsyncConnectionPool(
+        dsn,
+        connection_class=MyConnection,
+        configure=set_autocommit,
+    ) as p1:
+        async with p1.connection() as conn1:
+            cur1 = await conn1.execute("select 1 as x")
+            (row1,) = await cur1.fetchall()
+    assert_type(p1, pool.AsyncConnectionPool[MyConnection])
+    assert_type(conn1, MyConnection)
+    assert_type(row1, MyRow)
+    assert conn1.autocommit
+    assert row1 == {"x": 1}
+
+
 @pytest.mark.crdb_skip("backend pid")
 async def test_its_really_a_pool(dsn):
     async with pool.AsyncConnectionPool(dsn, min_size=2) as p: