From 70e1e9509a43e5d4b0c9943cbde30201b963d36b Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 27 Sep 2023 01:16:55 +0200 Subject: [PATCH] feat(pool): make pool generic on connection type --- docs/news_pool.rst | 1 + psycopg/psycopg/connection.py | 2 +- psycopg_pool/psycopg_pool/base.py | 16 ++- psycopg_pool/psycopg_pool/null_pool.py | 68 ++++++++-- psycopg_pool/psycopg_pool/null_pool_async.py | 68 ++++++++-- psycopg_pool/psycopg_pool/pool.py | 124 ++++++++++++------ psycopg_pool/psycopg_pool/pool_async.py | 128 +++++++++++++------ tests/pool/test_null_pool.py | 61 ++++++++- tests/pool/test_null_pool_async.py | 64 +++++++++- tests/pool/test_pool.py | 59 ++++++++- tests/pool/test_pool_async.py | 67 +++++++++- 11 files changed, 540 insertions(+), 118 deletions(-) diff --git a/docs/news_pool.rst b/docs/news_pool.rst index 1f2aeb8fd..043181d5d 100644 --- a/docs/news_pool.rst +++ b/docs/news_pool.rst @@ -22,6 +22,7 @@ psycopg_pool 3.1.9 (unreleased) - Fix the return type annotation of `!NullConnectionPool.__enter__()` (:ticket:`#540`). +- Make connection pool classes generic on the connection type (:ticket:`#559`). Current release diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index ca9305394..a6571d5d7 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -129,7 +129,7 @@ class BaseConnection(Generic[Row]): # Attribute is only set if the connection is from a pool so we can tell # apart a connection in the pool too (when _pool = None) - self._pool: Optional["BasePool[Any]"] + self._pool: Optional["BasePool"] self._pipeline: Optional[BasePipeline] = None diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index 6d66f5206..13823bccd 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -6,16 +6,18 @@ psycopg connection pool base class and functionalities. from time import monotonic from random import random -from typing import Any, Dict, Generic, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING from psycopg import errors as e -from psycopg.abc import ConnectionType from .errors import PoolClosed from ._compat import Counter, Deque +if TYPE_CHECKING: + from psycopg.connection import BaseConnection -class BasePool(Generic[ConnectionType]): + +class BasePool: # Used to generate pool names _num_pool = 0 @@ -36,6 +38,8 @@ class BasePool(Generic[ConnectionType]): _CONNECTIONS_ERRORS = "connections_errors" _CONNECTIONS_LOST = "connections_lost" + _pool: Deque["Any"] + def __init__( self, conninfo: str = "", @@ -74,7 +78,7 @@ class BasePool(Generic[ConnectionType]): self.num_workers = num_workers self._nconns = min_size # currently in the pool, out, being prepared - self._pool = Deque[ConnectionType]() + self._pool = Deque() self._stats = Counter[str]() # Min number of connections in the pool in a max_idle unit of time. @@ -138,7 +142,7 @@ class BasePool(Generic[ConnectionType]): else: raise PoolClosed(f"the pool {self.name!r} is not open yet") - def _check_pool_putconn(self, conn: ConnectionType) -> None: + def _check_pool_putconn(self, conn: "BaseConnection[Any]") -> None: pool = getattr(conn, "_pool", None) if pool is self: return @@ -188,7 +192,7 @@ class BasePool(Generic[ConnectionType]): """ return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc) - def _set_connection_expiry_date(self, conn: ConnectionType) -> None: + def _set_connection_expiry_date(self, conn: "BaseConnection[Any]") -> None: """Set an expiry date on a connection. Add some randomness to avoid mass reconnection. diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py index 9c796ea71..b4e640b74 100644 --- a/psycopg_pool/psycopg_pool/null_pool.py +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -6,12 +6,13 @@ Psycopg null connection pools import logging import threading -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Callable, cast, Dict, Optional, overload, Tuple, Type from psycopg import Connection from psycopg.pq import TransactionStatus +from psycopg.rows import TupleRow -from .pool import ConnectionPool, AddConnection, ConnectFailedCB +from .pool import ConnectionPool, CT, AddConnection, ConnectFailedCB from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout @@ -40,15 +41,60 @@ class _BaseNullConnectionPool: pass -class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): +class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]): + @overload + def __init__( + self: "NullConnectionPool[Connection[TupleRow]]", + conninfo: str = "", + *, + open: bool = ..., + configure: Optional[Callable[[CT], None]] = ..., + reset: Optional[Callable[[CT], None]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[ConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... + + @overload + def __init__( + self: "NullConnectionPool[CT]", + conninfo: str = "", + *, + open: bool = ..., + connection_class: Type[CT], + configure: Optional[Callable[[CT], None]] = ..., + reset: Optional[Callable[[CT], None]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[ConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... + def __init__( self, conninfo: str = "", *, open: bool = True, - connection_class: Type[Connection[Any]] = Connection, - configure: Optional[Callable[[Connection[Any]], None]] = None, - reset: Optional[Callable[[Connection[Any]], None]] = None, + connection_class: Type[CT] = cast(Type[CT], Connection), + configure: Optional[Callable[[CT], None]] = None, + reset: Optional[Callable[[CT], None]] = None, kwargs: Optional[Dict[str, Any]] = None, # Note: default value changed to 0. min_size: int = 0, @@ -109,10 +155,8 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): logger.info("pool %r is ready to use", self.name) - def _get_ready_connection( - self, timeout: Optional[float] - ) -> Optional[Connection[Any]]: - conn: Optional[Connection[Any]] = None + def _get_ready_connection(self, timeout: Optional[float]) -> Optional[CT]: + conn: Optional[CT] = None if self.max_size == 0 or self._nconns < self.max_size: # Create a new connection for the client try: @@ -129,7 +173,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): ) return conn - def _maybe_close_connection(self, conn: Connection[Any]) -> bool: + def _maybe_close_connection(self, conn: CT) -> bool: with self._lock: if not self._closed and self._waiting: return False @@ -162,7 +206,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool): """No-op, as the pool doesn't have connections in its state.""" pass - def _add_to_pool(self, conn: Connection[Any]) -> None: + def _add_to_pool(self, conn: CT) -> None: # Remove the pool reference from the connection before returning it # to the state, to avoid to create a reference loop. # Also disable the warning for open connection in conn.__del__ diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py index 83254cf38..ca9db8bcd 100644 --- a/psycopg_pool/psycopg_pool/null_pool_async.py +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -6,28 +6,74 @@ psycopg asynchronous null connection pool import asyncio import logging -from typing import Any, Awaitable, Callable, Dict, Optional, Type +from typing import Any, Awaitable, Callable, cast, Dict, Optional, overload, Type from psycopg import AsyncConnection from psycopg.pq import TransactionStatus +from psycopg.rows import TupleRow from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout from .null_pool import _BaseNullConnectionPool -from .pool_async import AsyncConnectionPool, AddConnection, AsyncConnectFailedCB +from .pool_async import AsyncConnectionPool, ACT, AddConnection, AsyncConnectFailedCB logger = logging.getLogger("psycopg.pool") -class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): +class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]): + @overload + def __init__( + self: "AsyncNullConnectionPool[AsyncConnection[TupleRow]]", + conninfo: str = "", + *, + open: bool = ..., + configure: Optional[Callable[[ACT], Awaitable[None]]] = ..., + reset: Optional[Callable[[ACT], Awaitable[None]]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[AsyncConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... + + @overload + def __init__( + self: "AsyncNullConnectionPool[ACT]", + conninfo: str = "", + *, + open: bool = ..., + connection_class: Type[ACT], + configure: Optional[Callable[[ACT], Awaitable[None]]] = ..., + reset: Optional[Callable[[ACT], Awaitable[None]]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[AsyncConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... + def __init__( self, conninfo: str = "", *, open: bool = True, - connection_class: Type[AsyncConnection[Any]] = AsyncConnection, - configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, - reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, + connection_class: Type[ACT] = cast(Type[ACT], AsyncConnection), + configure: Optional[Callable[[ACT], Awaitable[None]]] = None, + reset: Optional[Callable[[ACT], Awaitable[None]]] = None, kwargs: Optional[Dict[str, Any]] = None, # Note: default value changed to 0. min_size: int = 0, @@ -82,10 +128,8 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): logger.info("pool %r is ready to use", self.name) - async def _get_ready_connection( - self, timeout: Optional[float] - ) -> Optional[AsyncConnection[Any]]: - conn: Optional[AsyncConnection[Any]] = None + async def _get_ready_connection(self, timeout: Optional[float]) -> Optional[ACT]: + conn: Optional[ACT] = None if self.max_size == 0 or self._nconns < self.max_size: # Create a new connection for the client try: @@ -101,7 +145,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): ) return conn - async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool: + async def _maybe_close_connection(self, conn: ACT) -> bool: # Close the connection if no client is waiting for it, or if the pool # is closed. For extra refcare remove the pool reference from it. # Maintain the stats. @@ -132,7 +176,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool): async def check(self) -> None: pass - async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + async def _add_to_pool(self, conn: ACT) -> None: # Remove the pool reference from the connection before returning it # to the state, to avoid to create a reference loop. # Also disable the warning for open connection in conn.__del__ diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 47ecf084e..a7bc5b0de 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -10,8 +10,8 @@ from abc import ABC, abstractmethod from time import monotonic from queue import Queue, Empty from types import TracebackType -from typing import Any, Callable, Dict, Iterator, List -from typing import Optional, Sequence, Type, TypeVar +from typing import Any, Callable, cast, Dict, Generic, Iterator, List +from typing import Optional, overload, Sequence, Type, TypeVar from typing_extensions import TypeAlias from weakref import ref from contextlib import contextmanager @@ -19,6 +19,7 @@ from contextlib import contextmanager from psycopg import errors as e from psycopg import Connection from psycopg.pq import TransactionStatus +from psycopg.rows import TupleRow from .base import ConnectionAttempt, BasePool from .sched import Scheduler @@ -29,18 +30,66 @@ logger = logging.getLogger("psycopg.pool") ConnectFailedCB: TypeAlias = Callable[["ConnectionPool"], None] +CT = TypeVar("CT", bound="Connection[Any]") -class ConnectionPool(BasePool[Connection[Any]]): - _Self = TypeVar("_Self", bound="ConnectionPool") + +class ConnectionPool(Generic[CT], BasePool): + _Self = TypeVar("_Self", bound="ConnectionPool[CT]") + _pool: Deque[CT] + + @overload + def __init__( + self: "ConnectionPool[Connection[TupleRow]]", + conninfo: str = "", + *, + open: bool = ..., + configure: Optional[Callable[[CT], None]] = ..., + reset: Optional[Callable[[CT], None]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[ConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... + + @overload + def __init__( + self: "ConnectionPool[CT]", + conninfo: str = "", + *, + open: bool = ..., + connection_class: Type[CT], + configure: Optional[Callable[[CT], None]] = ..., + reset: Optional[Callable[[CT], None]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[ConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... def __init__( self, conninfo: str = "", *, open: bool = True, - connection_class: Type[Connection[Any]] = Connection, - configure: Optional[Callable[[Connection[Any]], None]] = None, - reset: Optional[Callable[[Connection[Any]], None]] = None, + connection_class: Type[CT] = cast(Type[CT], Connection[TupleRow]), + configure: Optional[Callable[[CT], None]] = None, + reset: Optional[Callable[[CT], None]] = None, kwargs: Optional[Dict[str, Any]] = None, min_size: int = 4, max_size: Optional[int] = None, @@ -61,7 +110,7 @@ class ConnectionPool(BasePool[Connection[Any]]): self._reconnect_failed = reconnect_failed or (lambda pool: None) self._lock = threading.RLock() - self._waiting = Deque["WaitingClient"]() + self._waiting = Deque["WaitingClient[CT]"]() # to notify that the pool is full self._pool_full_event: Optional[threading.Event] = None @@ -130,7 +179,7 @@ class ConnectionPool(BasePool[Connection[Any]]): logger.info("pool %r is ready to use", self.name) @contextmanager - def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]: + def connection(self, timeout: Optional[float] = None) -> Iterator[CT]: """Context manager to obtain a connection from the pool. Return the connection immediately if available, otherwise wait up to @@ -152,7 +201,7 @@ class ConnectionPool(BasePool[Connection[Any]]): self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) self.putconn(conn) - def getconn(self, timeout: Optional[float] = None) -> Connection[Any]: + def getconn(self, timeout: Optional[float] = None) -> CT: """Obtain a connection from the pool. You should preferably use `connection()`. Use this function only if @@ -173,7 +222,7 @@ class ConnectionPool(BasePool[Connection[Any]]): if not conn: # No connection available: put the client in the waiting queue t0 = monotonic() - pos = WaitingClient() + pos: WaitingClient[CT] = WaitingClient() self._waiting.append(pos) self._stats[self._REQUESTS_QUEUED] += 1 @@ -201,11 +250,9 @@ class ConnectionPool(BasePool[Connection[Any]]): logger.info("connection given by %r", self.name) return conn - def _get_ready_connection( - self, timeout: Optional[float] - ) -> Optional[Connection[Any]]: + def _get_ready_connection(self, timeout: Optional[float]) -> Optional[CT]: """Return a connection, if the client deserves one.""" - conn: Optional[Connection[Any]] = None + conn: Optional[CT] = None if self._pool: # Take a connection ready out of the pool conn = self._pool.popleft() @@ -229,7 +276,7 @@ class ConnectionPool(BasePool[Connection[Any]]): self._growing = True self.run_task(AddConnection(self, growing=True)) - def putconn(self, conn: Connection[Any]) -> None: + def putconn(self, conn: CT) -> None: """Return a connection to the loving hands of its pool. Use this function only paired with a `getconn()`. You don't need to use @@ -249,7 +296,7 @@ class ConnectionPool(BasePool[Connection[Any]]): else: self._return_connection(conn) - def _maybe_close_connection(self, conn: Connection[Any]) -> bool: + def _maybe_close_connection(self, conn: CT) -> bool: """Close a returned connection if necessary. Return `!True if the connection was closed. @@ -354,8 +401,8 @@ class ConnectionPool(BasePool[Connection[Any]]): def _stop_workers( self, - waiting_clients: Sequence["WaitingClient"] = (), - connections: Sequence[Connection[Any]] = (), + waiting_clients: Sequence["WaitingClient[CT]"] = (), + connections: Sequence[CT] = (), timeout: float = 0.0, ) -> None: # Stop the scheduler @@ -511,7 +558,7 @@ class ConnectionPool(BasePool[Connection[Any]]): ex, ) - def _connect(self, timeout: Optional[float] = None) -> Connection[Any]: + def _connect(self, timeout: Optional[float] = None) -> CT: """Return a new connection configured for the pool.""" self._stats[self._CONNECTIONS_NUM] += 1 kwargs = self.kwargs @@ -520,8 +567,9 @@ class ConnectionPool(BasePool[Connection[Any]]): kwargs["connect_timeout"] = max(round(timeout), 1) t0 = monotonic() try: - conn: Connection[Any] - conn = self.connection_class.connect(self.conninfo, **kwargs) + conn: CT = self.connection_class.connect( # type: ignore + self.conninfo, **kwargs + ) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise @@ -598,7 +646,7 @@ class ConnectionPool(BasePool[Connection[Any]]): else: self._growing = False - def _return_connection(self, conn: Connection[Any]) -> None: + def _return_connection(self, conn: CT) -> None: """ Return a connection to the pool after usage. """ @@ -619,7 +667,7 @@ class ConnectionPool(BasePool[Connection[Any]]): self._add_to_pool(conn) - def _add_to_pool(self, conn: Connection[Any]) -> None: + def _add_to_pool(self, conn: CT) -> None: """ Add a connection to the pool. @@ -651,7 +699,7 @@ class ConnectionPool(BasePool[Connection[Any]]): if self._pool_full_event and len(self._pool) >= self._min_size: self._pool_full_event.set() - def _reset_connection(self, conn: Connection[Any]) -> None: + def _reset_connection(self, conn: CT) -> None: """ Bring a connection to IDLE state or close it. """ @@ -693,7 +741,7 @@ class ConnectionPool(BasePool[Connection[Any]]): conn.close() def _shrink_pool(self) -> None: - to_close: Optional[Connection[Any]] = None + to_close: Optional[CT] = None with self._lock: # Reset the min number of connections used @@ -723,13 +771,13 @@ class ConnectionPool(BasePool[Connection[Any]]): return rv -class WaitingClient: +class WaitingClient(Generic[CT]): """A position in a queue for a client waiting for a connection.""" __slots__ = ("conn", "error", "_cond") def __init__(self) -> None: - self.conn: Optional[Connection[Any]] = None + self.conn: Optional[CT] = None self.error: Optional[BaseException] = None # The WaitingClient behaves in a way similar to an Event, but we need @@ -739,7 +787,7 @@ class WaitingClient: # will be lost. self._cond = threading.Condition() - def wait(self, timeout: float) -> Connection[Any]: + def wait(self, timeout: float) -> CT: """Wait for a connection to be set and return it. Raise an exception if the wait times out or if fail() is called. @@ -760,7 +808,7 @@ class WaitingClient: assert self.error raise self.error - def set(self, conn: Connection[Any]) -> bool: + def set(self, conn: CT) -> bool: """Signal the client waiting that a connection is ready. Return True if the client has "accepted" the connection, False @@ -792,7 +840,7 @@ class WaitingClient: class MaintenanceTask(ABC): """A task to run asynchronously to maintain the pool state.""" - def __init__(self, pool: "ConnectionPool"): + def __init__(self, pool: "ConnectionPool[Any]"): self.pool = ref(pool) def __repr__(self) -> str: @@ -830,21 +878,21 @@ class MaintenanceTask(ABC): pool.run_task(self) @abstractmethod - def _run(self, pool: "ConnectionPool") -> None: + def _run(self, pool: "ConnectionPool[Any]") -> None: ... class StopWorker(MaintenanceTask): """Signal the maintenance thread to terminate.""" - def _run(self, pool: "ConnectionPool") -> None: + def _run(self, pool: "ConnectionPool[Any]") -> None: pass class AddConnection(MaintenanceTask): def __init__( self, - pool: "ConnectionPool", + pool: "ConnectionPool[Any]", attempt: Optional["ConnectionAttempt"] = None, growing: bool = False, ): @@ -852,18 +900,18 @@ class AddConnection(MaintenanceTask): self.attempt = attempt self.growing = growing - def _run(self, pool: "ConnectionPool") -> None: + def _run(self, pool: "ConnectionPool[Any]") -> None: pool._add_connection(self.attempt, growing=self.growing) class ReturnConnection(MaintenanceTask): """Clean up and return a connection to the pool.""" - def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"): + def __init__(self, pool: "ConnectionPool[Any]", conn: CT): super().__init__(pool) self.conn = conn - def _run(self, pool: "ConnectionPool") -> None: + def _run(self, pool: "ConnectionPool[Any]") -> None: pool._return_connection(self.conn) @@ -874,7 +922,7 @@ class ShrinkPool(MaintenanceTask): in the pool. """ - def _run(self, pool: "ConnectionPool") -> None: + def _run(self, pool: "ConnectionPool[Any]") -> None: # Reschedule the task now so that in case of any error we don't lose # the periodic run. pool.schedule_task(self, pool.max_idle) diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index b57098e4e..9ee8e8554 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -9,8 +9,8 @@ import logging from abc import ABC, abstractmethod from time import monotonic from types import TracebackType -from typing import Any, AsyncIterator, Awaitable, Callable -from typing import Dict, List, Optional, Sequence, Type, TypeVar, Union +from typing import Any, AsyncIterator, Awaitable, Callable, cast, Generic +from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar, Union from typing_extensions import TypeAlias from weakref import ref from contextlib import asynccontextmanager @@ -18,6 +18,7 @@ from contextlib import asynccontextmanager from psycopg import errors as e from psycopg import AsyncConnection from psycopg.pq import TransactionStatus +from psycopg.rows import TupleRow from .base import ConnectionAttempt, BasePool from .sched import AsyncScheduler @@ -31,18 +32,66 @@ AsyncConnectFailedCB: TypeAlias = Union[ Callable[["AsyncConnectionPool"], Awaitable[None]], ] +ACT = TypeVar("ACT", bound="AsyncConnection[Any]") -class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): - _Self = TypeVar("_Self", bound="AsyncConnectionPool") + +class AsyncConnectionPool(Generic[ACT], BasePool): + _Self = TypeVar("_Self", bound="AsyncConnectionPool[ACT]") + _pool: Deque[ACT] + + @overload + def __init__( + self: "AsyncConnectionPool[AsyncConnection[TupleRow]]", + conninfo: str = "", + *, + open: bool = ..., + configure: Optional[Callable[[ACT], Awaitable[None]]] = ..., + reset: Optional[Callable[[ACT], Awaitable[None]]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[AsyncConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... + + @overload + def __init__( + self: "AsyncConnectionPool[ACT]", + conninfo: str = "", + *, + open: bool = ..., + connection_class: Type[ACT], + configure: Optional[Callable[[ACT], Awaitable[None]]] = ..., + reset: Optional[Callable[[ACT], Awaitable[None]]] = ..., + kwargs: Optional[Dict[str, Any]] = ..., + min_size: int = ..., + max_size: Optional[int] = ..., + name: Optional[str] = ..., + timeout: float = ..., + max_waiting: int = ..., + max_lifetime: float = ..., + max_idle: float = ..., + reconnect_timeout: float = ..., + reconnect_failed: Optional[AsyncConnectFailedCB] = ..., + num_workers: int = ..., + ): + ... def __init__( self, conninfo: str = "", *, open: bool = True, - connection_class: Type[AsyncConnection[Any]] = AsyncConnection, - configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, - reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None, + connection_class: Type[ACT] = cast(Type[ACT], AsyncConnection), + configure: Optional[Callable[[ACT], Awaitable[None]]] = None, + reset: Optional[Callable[[ACT], Awaitable[None]]] = None, kwargs: Optional[Dict[str, Any]] = None, min_size: int = 4, max_size: Optional[int] = None, @@ -67,7 +116,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): self._sched: AsyncScheduler self._tasks: "asyncio.Queue[MaintenanceTask]" - self._waiting = Deque["AsyncClient"]() + self._waiting = Deque["AsyncClient[ACT]"]() # to notify that the pool is full self._pool_full_event: Optional[asyncio.Event] = None @@ -118,9 +167,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): logger.info("pool %r is ready to use", self.name) @asynccontextmanager - async def connection( - self, timeout: Optional[float] = None - ) -> AsyncIterator[AsyncConnection[Any]]: + async def connection(self, timeout: Optional[float] = None) -> AsyncIterator[ACT]: conn = await self.getconn(timeout=timeout) t0 = monotonic() try: @@ -131,7 +178,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0)) await self.putconn(conn) - async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]: + async def getconn(self, timeout: Optional[float] = None) -> ACT: logger.info("connection requested from %r", self.name) self._stats[self._REQUESTS_NUM] += 1 @@ -144,7 +191,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): if not conn: # No connection available: put the client in the waiting queue t0 = monotonic() - pos = AsyncClient() + pos: AsyncClient[ACT] = AsyncClient() self._waiting.append(pos) self._stats[self._REQUESTS_QUEUED] += 1 @@ -172,10 +219,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): logger.info("connection given by %r", self.name) return conn - async def _get_ready_connection( - self, timeout: Optional[float] - ) -> Optional[AsyncConnection[Any]]: - conn: Optional[AsyncConnection[Any]] = None + async def _get_ready_connection(self, timeout: Optional[float]) -> Optional[ACT]: + conn: Optional[ACT] = None if self._pool: # Take a connection ready out of the pool conn = self._pool.popleft() @@ -198,7 +243,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): self._growing = True self.run_task(AddConnection(self, growing=True)) - async def putconn(self, conn: AsyncConnection[Any]) -> None: + async def putconn(self, conn: ACT) -> None: self._check_pool_putconn(conn) logger.info("returning connection to %r", self.name) @@ -211,7 +256,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): else: await self._return_connection(conn) - async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool: + async def _maybe_close_connection(self, conn: ACT) -> bool: # If the pool is closed just close the connection instead of returning # it to the pool. For extra refcare remove the pool reference from it. if not self._closed: @@ -299,8 +344,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): async def _stop_workers( self, - waiting_clients: Sequence["AsyncClient"] = (), - connections: Sequence[AsyncConnection[Any]] = (), + waiting_clients: Sequence["AsyncClient[ACT]"] = (), + connections: Sequence[ACT] = (), timeout: float = 0.0, ) -> None: # Stop the scheduler @@ -442,7 +487,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): ex, ) - async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]: + async def _connect(self, timeout: Optional[float] = None) -> ACT: self._stats[self._CONNECTIONS_NUM] += 1 kwargs = self.kwargs if timeout: @@ -450,8 +495,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): kwargs["connect_timeout"] = max(round(timeout), 1) t0 = monotonic() try: - conn: AsyncConnection[Any] - conn = await self.connection_class.connect(self.conninfo, **kwargs) + conn: ACT = await self.connection_class.connect(self.conninfo, **kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 raise @@ -528,7 +572,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): else: self._growing = False - async def _return_connection(self, conn: AsyncConnection[Any]) -> None: + async def _return_connection(self, conn: ACT) -> None: """ Return a connection to the pool after usage. """ @@ -549,7 +593,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): await self._add_to_pool(conn) - async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None: + async def _add_to_pool(self, conn: ACT) -> None: """ Add a connection to the pool. @@ -581,7 +625,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): if self._pool_full_event and len(self._pool) >= self._min_size: self._pool_full_event.set() - async def _reset_connection(self, conn: AsyncConnection[Any]) -> None: + async def _reset_connection(self, conn: ACT) -> None: """ Bring a connection to IDLE state or close it. """ @@ -623,7 +667,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): await conn.close() async def _shrink_pool(self) -> None: - to_close: Optional[AsyncConnection[Any]] = None + to_close: Optional[ACT] = None async with self._lock: # Reset the min number of connections used @@ -653,13 +697,13 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): return rv -class AsyncClient: +class AsyncClient(Generic[ACT]): """A position in a queue for a client waiting for a connection.""" __slots__ = ("conn", "error", "_cond") def __init__(self) -> None: - self.conn: Optional[AsyncConnection[Any]] = None + self.conn: Optional[ACT] = None self.error: Optional[BaseException] = None # The AsyncClient behaves in a way similar to an Event, but we need @@ -669,7 +713,7 @@ class AsyncClient: # will be lost. self._cond = asyncio.Condition() - async def wait(self, timeout: float) -> AsyncConnection[Any]: + async def wait(self, timeout: float) -> ACT: """Wait for a connection to be set and return it. Raise an exception if the wait times out or if fail() is called. @@ -691,7 +735,7 @@ class AsyncClient: assert self.error raise self.error - async def set(self, conn: AsyncConnection[Any]) -> bool: + async def set(self, conn: ACT) -> bool: """Signal the client waiting that a connection is ready. Return True if the client has "accepted" the connection, False @@ -723,7 +767,7 @@ class AsyncClient: class MaintenanceTask(ABC): """A task to run asynchronously to maintain the pool state.""" - def __init__(self, pool: "AsyncConnectionPool"): + def __init__(self, pool: "AsyncConnectionPool[Any]"): self.pool = ref(pool) def __repr__(self) -> str: @@ -760,21 +804,21 @@ class MaintenanceTask(ABC): pool.run_task(self) @abstractmethod - async def _run(self, pool: "AsyncConnectionPool") -> None: + async def _run(self, pool: "AsyncConnectionPool[Any]") -> None: ... class StopWorker(MaintenanceTask): """Signal the maintenance worker to terminate.""" - async def _run(self, pool: "AsyncConnectionPool") -> None: + async def _run(self, pool: "AsyncConnectionPool[Any]") -> None: pass class AddConnection(MaintenanceTask): def __init__( self, - pool: "AsyncConnectionPool", + pool: "AsyncConnectionPool[Any]", attempt: Optional["ConnectionAttempt"] = None, growing: bool = False, ): @@ -782,18 +826,18 @@ class AddConnection(MaintenanceTask): self.attempt = attempt self.growing = growing - async def _run(self, pool: "AsyncConnectionPool") -> None: + async def _run(self, pool: "AsyncConnectionPool[Any]") -> None: await pool._add_connection(self.attempt, growing=self.growing) class ReturnConnection(MaintenanceTask): """Clean up and return a connection to the pool.""" - def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"): + def __init__(self, pool: "AsyncConnectionPool[Any]", conn: ACT): super().__init__(pool) self.conn = conn - async def _run(self, pool: "AsyncConnectionPool") -> None: + async def _run(self, pool: "AsyncConnectionPool[Any]") -> None: await pool._return_connection(self.conn) @@ -804,7 +848,7 @@ class ShrinkPool(MaintenanceTask): in the pool. """ - async def _run(self, pool: "AsyncConnectionPool") -> None: + async def _run(self, pool: "AsyncConnectionPool[Any]") -> None: # Reschedule the task now so that in case of any error we don't lose # the periodic run. await pool.schedule_task(self, pool.max_idle) @@ -820,7 +864,7 @@ class Schedule(MaintenanceTask): def __init__( self, - pool: "AsyncConnectionPool", + pool: "AsyncConnectionPool[Any]", task: MaintenanceTask, delay: float, ): @@ -828,5 +872,5 @@ class Schedule(MaintenanceTask): self.task = task self.delay = delay - async def _run(self, pool: "AsyncConnectionPool") -> None: + async def _run(self, pool: "AsyncConnectionPool[Any]") -> None: await pool.schedule_task(self.task, self.delay) diff --git a/tests/pool/test_null_pool.py b/tests/pool/test_null_pool.py index 51cc67ad9..a1b271529 100644 --- a/tests/pool/test_null_pool.py +++ b/tests/pool/test_null_pool.py @@ -1,13 +1,15 @@ import logging from time import sleep, time from threading import Thread, Event -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import pytest from packaging.version import parse as ver # noqa: F401 # used in skipif import psycopg from psycopg.pq import TransactionStatus +from psycopg.rows import class_row, Row, TupleRow +from psycopg._compat import assert_type from .test_pool import delay_connection, ensure_waiting @@ -54,6 +56,63 @@ def test_kwargs(dsn): assert conn.autocommit +class MyRow(Dict[str, Any]): + ... + + +def test_generic_connection_type(dsn): + def set_autocommit(conn: psycopg.Connection[Any]) -> None: + conn.autocommit = True + + class MyConnection(psycopg.Connection[Row]): + pass + + with NullConnectionPool( + dsn, + connection_class=MyConnection[MyRow], + kwargs={"row_factory": class_row(MyRow)}, + configure=set_autocommit, + ) as p1: + with p1.connection() as conn1: + cur1 = conn1.execute("select 1 as x") + (row1,) = cur1.fetchall() + + assert_type(p1, NullConnectionPool[MyConnection[MyRow]]) + assert_type(conn1, MyConnection[MyRow]) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + with NullConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2: + with p2.connection() as conn2: + (row2,) = conn2.execute("select 2 as y").fetchall() + assert_type(p2, NullConnectionPool[MyConnection[TupleRow]]) + assert_type(conn2, MyConnection[TupleRow]) + assert_type(row2, TupleRow) + assert row2 == (2,) + + +def test_non_generic_connection_type(dsn): + def set_autocommit(conn: psycopg.Connection[Any]) -> None: + conn.autocommit = True + + class MyConnection(psycopg.Connection[MyRow]): + def __init__(self, *args: Any, **kwargs: Any): + kwargs["row_factory"] = class_row(MyRow) + super().__init__(*args, **kwargs) + + with NullConnectionPool( + dsn, connection_class=MyConnection, configure=set_autocommit + ) as p1: + with p1.connection() as conn1: + (row1,) = conn1.execute("select 1 as x").fetchall() + assert_type(p1, NullConnectionPool[MyConnection]) + assert_type(conn1, MyConnection) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + @pytest.mark.crdb_skip("backend pid") def test_its_no_pool_at_all(dsn): with NullConnectionPool(dsn, max_size=2) as p: diff --git a/tests/pool/test_null_pool_async.py b/tests/pool/test_null_pool_async.py index be29b3dcb..47b0c8820 100644 --- a/tests/pool/test_null_pool_async.py +++ b/tests/pool/test_null_pool_async.py @@ -1,14 +1,15 @@ import asyncio import logging from time import time -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import pytest from packaging.version import parse as ver # noqa: F401 # used in skipif import psycopg from psycopg.pq import TransactionStatus -from psycopg._compat import create_task +from psycopg.rows import class_row, Row, TupleRow +from psycopg._compat import assert_type, create_task from .test_pool_async import delay_connection, ensure_waiting pytestmark = [pytest.mark.anyio] @@ -56,6 +57,65 @@ async def test_kwargs(dsn): assert conn.autocommit +class MyRow(Dict[str, Any]): + ... + + +async def test_generic_connection_type(dsn): + async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None: + await conn.set_autocommit(True) + + class MyConnection(psycopg.AsyncConnection[Row]): + pass + + async with AsyncNullConnectionPool( + dsn, + connection_class=MyConnection[MyRow], + kwargs={"row_factory": class_row(MyRow)}, + configure=set_autocommit, + ) as p1: + async with p1.connection() as conn1: + cur1 = await conn1.execute("select 1 as x") + (row1,) = await cur1.fetchall() + assert_type(p1, AsyncNullConnectionPool[MyConnection[MyRow]]) + assert_type(conn1, MyConnection[MyRow]) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + async with AsyncNullConnectionPool( + dsn, connection_class=MyConnection[TupleRow] + ) as p2: + async with p2.connection() as conn2: + cur2 = await conn2.execute("select 2 as y") + (row2,) = await cur2.fetchall() + assert_type(p2, AsyncNullConnectionPool[MyConnection[TupleRow]]) + assert_type(conn2, MyConnection[TupleRow]) + assert_type(row2, TupleRow) + assert row2 == (2,) + + +async def test_non_generic_connection_type(dsn): + async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None: + await conn.set_autocommit(True) + + class MyConnection(psycopg.AsyncConnection[MyRow]): + def __init__(self, *args: Any, **kwargs: Any): + kwargs["row_factory"] = class_row(MyRow) + super().__init__(*args, **kwargs) + + async with AsyncNullConnectionPool( + dsn, connection_class=MyConnection, configure=set_autocommit + ) as p1: + async with p1.connection() as conn1: + (row1,) = await (await conn1.execute("select 1 as x")).fetchall() + assert_type(p1, AsyncNullConnectionPool[MyConnection]) + assert_type(conn1, MyConnection) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + @pytest.mark.crdb_skip("backend pid") async def test_its_no_pool_at_all(dsn): async with AsyncNullConnectionPool(dsn, max_size=2) as p: diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 5b4c435d7..7234a3c84 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -2,13 +2,14 @@ import logging import weakref from time import sleep, time from threading import Thread, Event -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import pytest import psycopg from psycopg.pq import TransactionStatus -from psycopg._compat import Counter +from psycopg.rows import class_row, Row, TupleRow +from psycopg._compat import Counter, assert_type try: import psycopg_pool as pool @@ -64,6 +65,60 @@ def test_kwargs(dsn): assert conn.autocommit +class MyRow(Dict[str, Any]): + ... + + +def test_generic_connection_type(dsn): + def set_autocommit(conn: psycopg.Connection[Any]) -> None: + conn.autocommit = True + + class MyConnection(psycopg.Connection[Row]): + pass + + with pool.ConnectionPool( + dsn, + connection_class=MyConnection[MyRow], + kwargs=dict(row_factory=class_row(MyRow)), + configure=set_autocommit, + ) as p1, p1.connection() as conn1: + (row1,) = conn1.execute("select 1 as x").fetchall() + assert_type(p1, pool.ConnectionPool[MyConnection[MyRow]]) + assert_type(conn1, MyConnection[MyRow]) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + with pool.ConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2: + with p2.connection() as conn2: + (row2,) = conn2.execute("select 2 as y").fetchall() + assert_type(p2, pool.ConnectionPool[MyConnection[TupleRow]]) + assert_type(conn2, MyConnection[TupleRow]) + assert_type(row2, TupleRow) + assert row2 == (2,) + + +def test_non_generic_connection_type(dsn): + def set_autocommit(conn: psycopg.Connection[Any]) -> None: + conn.autocommit = True + + class MyConnection(psycopg.Connection[MyRow]): + def __init__(self, *args: Any, **kwargs: Any): + kwargs["row_factory"] = class_row(MyRow) + super().__init__(*args, **kwargs) + + with pool.ConnectionPool( + dsn, connection_class=MyConnection, configure=set_autocommit + ) as p1: + with p1.connection() as conn1: + (row1,) = conn1.execute("select 1 as x").fetchall() + assert_type(p1, pool.ConnectionPool[MyConnection]) + assert_type(conn1, MyConnection) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + @pytest.mark.crdb_skip("backend pid") def test_its_really_a_pool(dsn): with pool.ConnectionPool(dsn, min_size=2) as p: diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 67382535e..31861df91 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -1,13 +1,14 @@ import asyncio import logging from time import time -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import pytest import psycopg from psycopg.pq import TransactionStatus -from psycopg._compat import create_task, Counter +from psycopg.rows import class_row, Row, TupleRow +from psycopg._compat import assert_type, create_task, Counter try: import psycopg_pool as pool @@ -57,6 +58,68 @@ async def test_kwargs(dsn): assert conn.autocommit +class MyRow(Dict[str, Any]): + ... + + +async def test_generic_connection_type(dsn): + async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None: + await conn.set_autocommit(True) + + class MyConnection(psycopg.AsyncConnection[Row]): + pass + + async with pool.AsyncConnectionPool( + dsn, + connection_class=MyConnection[MyRow], + kwargs=dict(row_factory=class_row(MyRow)), + configure=set_autocommit, + ) as p1: + async with p1.connection() as conn1: + cur1 = await conn1.execute("select 1 as x") + (row1,) = await cur1.fetchall() + assert_type(p1, pool.AsyncConnectionPool[MyConnection[MyRow]]) + assert_type(conn1, MyConnection[MyRow]) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + async with pool.AsyncConnectionPool( + dsn, connection_class=MyConnection[TupleRow] + ) as p2: + async with p2.connection() as conn2: + cur2 = await conn2.execute("select 2 as y") + (row2,) = await cur2.fetchall() + assert_type(p2, pool.AsyncConnectionPool[MyConnection[TupleRow]]) + assert_type(conn2, MyConnection[TupleRow]) + assert_type(row2, TupleRow) + assert row2 == (2,) + + +async def test_non_generic_connection_type(dsn): + async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None: + await conn.set_autocommit(True) + + class MyConnection(psycopg.AsyncConnection[MyRow]): + def __init__(self, *args: Any, **kwargs: Any): + kwargs["row_factory"] = class_row(MyRow) + super().__init__(*args, **kwargs) + + async with pool.AsyncConnectionPool( + dsn, + connection_class=MyConnection, + configure=set_autocommit, + ) as p1: + async with p1.connection() as conn1: + cur1 = await conn1.execute("select 1 as x") + (row1,) = await cur1.fetchall() + assert_type(p1, pool.AsyncConnectionPool[MyConnection]) + assert_type(conn1, MyConnection) + assert_type(row1, MyRow) + assert conn1.autocommit + assert row1 == {"x": 1} + + @pytest.mark.crdb_skip("backend pid") async def test_its_really_a_pool(dsn): async with pool.AsyncConnectionPool(dsn, min_size=2) as p: -- 2.47.3