- Fix the return type annotation of `!NullConnectionPool.__enter__()`
(:ticket:`#540`).
+- Make connection pool classes generic on the connection type (:ticket:`#559`).
Current release
# Attribute is only set if the connection is from a pool so we can tell
# apart a connection in the pool too (when _pool = None)
- self._pool: Optional["BasePool[Any]"]
+ self._pool: Optional["BasePool"]
self._pipeline: Optional[BasePipeline] = None
from time import monotonic
from random import random
-from typing import Any, Dict, Generic, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
from psycopg import errors as e
-from psycopg.abc import ConnectionType
from .errors import PoolClosed
from ._compat import Counter, Deque
+if TYPE_CHECKING:
+ from psycopg.connection import BaseConnection
-class BasePool(Generic[ConnectionType]):
+
+class BasePool:
# Used to generate pool names
_num_pool = 0
_CONNECTIONS_ERRORS = "connections_errors"
_CONNECTIONS_LOST = "connections_lost"
+ _pool: Deque["Any"]
+
def __init__(
self,
conninfo: str = "",
self.num_workers = num_workers
self._nconns = min_size # currently in the pool, out, being prepared
- self._pool = Deque[ConnectionType]()
+ self._pool = Deque()
self._stats = Counter[str]()
# Min number of connections in the pool in a max_idle unit of time.
else:
raise PoolClosed(f"the pool {self.name!r} is not open yet")
- def _check_pool_putconn(self, conn: ConnectionType) -> None:
+ def _check_pool_putconn(self, conn: "BaseConnection[Any]") -> None:
pool = getattr(conn, "_pool", None)
if pool is self:
return
"""
return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
- def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+ def _set_connection_expiry_date(self, conn: "BaseConnection[Any]") -> None:
"""Set an expiry date on a connection.
Add some randomness to avoid mass reconnection.
import logging
import threading
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, cast, Dict, Optional, overload, Tuple, Type
from psycopg import Connection
from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
-from .pool import ConnectionPool, AddConnection, ConnectFailedCB
+from .pool import ConnectionPool, CT, AddConnection, ConnectFailedCB
from .errors import PoolTimeout, TooManyRequests
from ._compat import ConnectionTimeout
pass
-class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
+ @overload
+ def __init__(
+ self: "NullConnectionPool[Connection[TupleRow]]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ configure: Optional[Callable[[CT], None]] = ...,
+ reset: Optional[Callable[[CT], None]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[ConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "NullConnectionPool[CT]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ connection_class: Type[CT],
+ configure: Optional[Callable[[CT], None]] = ...,
+ reset: Optional[Callable[[CT], None]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[ConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
+
def __init__(
self,
conninfo: str = "",
*,
open: bool = True,
- connection_class: Type[Connection[Any]] = Connection,
- configure: Optional[Callable[[Connection[Any]], None]] = None,
- reset: Optional[Callable[[Connection[Any]], None]] = None,
+ connection_class: Type[CT] = cast(Type[CT], Connection),
+ configure: Optional[Callable[[CT], None]] = None,
+ reset: Optional[Callable[[CT], None]] = None,
kwargs: Optional[Dict[str, Any]] = None,
# Note: default value changed to 0.
min_size: int = 0,
logger.info("pool %r is ready to use", self.name)
- def _get_ready_connection(
- self, timeout: Optional[float]
- ) -> Optional[Connection[Any]]:
- conn: Optional[Connection[Any]] = None
+ def _get_ready_connection(self, timeout: Optional[float]) -> Optional[CT]:
+ conn: Optional[CT] = None
if self.max_size == 0 or self._nconns < self.max_size:
# Create a new connection for the client
try:
)
return conn
- def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+ def _maybe_close_connection(self, conn: CT) -> bool:
with self._lock:
if not self._closed and self._waiting:
return False
"""No-op, as the pool doesn't have connections in its state."""
pass
- def _add_to_pool(self, conn: Connection[Any]) -> None:
+ def _add_to_pool(self, conn: CT) -> None:
# Remove the pool reference from the connection before returning it
# to the state, to avoid to create a reference loop.
# Also disable the warning for open connection in conn.__del__
import asyncio
import logging
-from typing import Any, Awaitable, Callable, Dict, Optional, Type
+from typing import Any, Awaitable, Callable, cast, Dict, Optional, overload, Type
from psycopg import AsyncConnection
from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
from .errors import PoolTimeout, TooManyRequests
from ._compat import ConnectionTimeout
from .null_pool import _BaseNullConnectionPool
-from .pool_async import AsyncConnectionPool, AddConnection, AsyncConnectFailedCB
+from .pool_async import AsyncConnectionPool, ACT, AddConnection, AsyncConnectFailedCB
logger = logging.getLogger("psycopg.pool")
-class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]):
+ @overload
+ def __init__(
+ self: "AsyncNullConnectionPool[AsyncConnection[TupleRow]]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncNullConnectionPool[ACT]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ connection_class: Type[ACT],
+ configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
+
def __init__(
self,
conninfo: str = "",
*,
open: bool = True,
- connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
- configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
- reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ connection_class: Type[ACT] = cast(Type[ACT], AsyncConnection),
+ configure: Optional[Callable[[ACT], Awaitable[None]]] = None,
+ reset: Optional[Callable[[ACT], Awaitable[None]]] = None,
kwargs: Optional[Dict[str, Any]] = None,
# Note: default value changed to 0.
min_size: int = 0,
logger.info("pool %r is ready to use", self.name)
- async def _get_ready_connection(
- self, timeout: Optional[float]
- ) -> Optional[AsyncConnection[Any]]:
- conn: Optional[AsyncConnection[Any]] = None
+ async def _get_ready_connection(self, timeout: Optional[float]) -> Optional[ACT]:
+ conn: Optional[ACT] = None
if self.max_size == 0 or self._nconns < self.max_size:
# Create a new connection for the client
try:
)
return conn
- async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+ async def _maybe_close_connection(self, conn: ACT) -> bool:
# Close the connection if no client is waiting for it, or if the pool
# is closed. For extra refcare remove the pool reference from it.
# Maintain the stats.
async def check(self) -> None:
pass
- async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ async def _add_to_pool(self, conn: ACT) -> None:
# Remove the pool reference from the connection before returning it
# to the state, to avoid to create a reference loop.
# Also disable the warning for open connection in conn.__del__
from time import monotonic
from queue import Queue, Empty
from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List
-from typing import Optional, Sequence, Type, TypeVar
+from typing import Any, Callable, cast, Dict, Generic, Iterator, List
+from typing import Optional, overload, Sequence, Type, TypeVar
from typing_extensions import TypeAlias
from weakref import ref
from contextlib import contextmanager
from psycopg import errors as e
from psycopg import Connection
from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
from .base import ConnectionAttempt, BasePool
from .sched import Scheduler
ConnectFailedCB: TypeAlias = Callable[["ConnectionPool"], None]
+CT = TypeVar("CT", bound="Connection[Any]")
-class ConnectionPool(BasePool[Connection[Any]]):
- _Self = TypeVar("_Self", bound="ConnectionPool")
+
+class ConnectionPool(Generic[CT], BasePool):
+ _Self = TypeVar("_Self", bound="ConnectionPool[CT]")
+ _pool: Deque[CT]
+
+ @overload
+ def __init__(
+ self: "ConnectionPool[Connection[TupleRow]]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ configure: Optional[Callable[[CT], None]] = ...,
+ reset: Optional[Callable[[CT], None]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[ConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "ConnectionPool[CT]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ connection_class: Type[CT],
+ configure: Optional[Callable[[CT], None]] = ...,
+ reset: Optional[Callable[[CT], None]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[ConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
def __init__(
self,
conninfo: str = "",
*,
open: bool = True,
- connection_class: Type[Connection[Any]] = Connection,
- configure: Optional[Callable[[Connection[Any]], None]] = None,
- reset: Optional[Callable[[Connection[Any]], None]] = None,
+ connection_class: Type[CT] = cast(Type[CT], Connection[TupleRow]),
+ configure: Optional[Callable[[CT], None]] = None,
+ reset: Optional[Callable[[CT], None]] = None,
kwargs: Optional[Dict[str, Any]] = None,
min_size: int = 4,
max_size: Optional[int] = None,
self._reconnect_failed = reconnect_failed or (lambda pool: None)
self._lock = threading.RLock()
- self._waiting = Deque["WaitingClient"]()
+ self._waiting = Deque["WaitingClient[CT]"]()
# to notify that the pool is full
self._pool_full_event: Optional[threading.Event] = None
logger.info("pool %r is ready to use", self.name)
@contextmanager
- def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]:
+ def connection(self, timeout: Optional[float] = None) -> Iterator[CT]:
"""Context manager to obtain a connection from the pool.
Return the connection immediately if available, otherwise wait up to
self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
self.putconn(conn)
- def getconn(self, timeout: Optional[float] = None) -> Connection[Any]:
+ def getconn(self, timeout: Optional[float] = None) -> CT:
"""Obtain a connection from the pool.
You should preferably use `connection()`. Use this function only if
if not conn:
# No connection available: put the client in the waiting queue
t0 = monotonic()
- pos = WaitingClient()
+ pos: WaitingClient[CT] = WaitingClient()
self._waiting.append(pos)
self._stats[self._REQUESTS_QUEUED] += 1
logger.info("connection given by %r", self.name)
return conn
- def _get_ready_connection(
- self, timeout: Optional[float]
- ) -> Optional[Connection[Any]]:
+ def _get_ready_connection(self, timeout: Optional[float]) -> Optional[CT]:
"""Return a connection, if the client deserves one."""
- conn: Optional[Connection[Any]] = None
+ conn: Optional[CT] = None
if self._pool:
# Take a connection ready out of the pool
conn = self._pool.popleft()
self._growing = True
self.run_task(AddConnection(self, growing=True))
- def putconn(self, conn: Connection[Any]) -> None:
+ def putconn(self, conn: CT) -> None:
"""Return a connection to the loving hands of its pool.
Use this function only paired with a `getconn()`. You don't need to use
else:
self._return_connection(conn)
- def _maybe_close_connection(self, conn: Connection[Any]) -> bool:
+ def _maybe_close_connection(self, conn: CT) -> bool:
"""Close a returned connection if necessary.
Return `!True if the connection was closed.
def _stop_workers(
self,
- waiting_clients: Sequence["WaitingClient"] = (),
- connections: Sequence[Connection[Any]] = (),
+ waiting_clients: Sequence["WaitingClient[CT]"] = (),
+ connections: Sequence[CT] = (),
timeout: float = 0.0,
) -> None:
# Stop the scheduler
ex,
)
- def _connect(self, timeout: Optional[float] = None) -> Connection[Any]:
+ def _connect(self, timeout: Optional[float] = None) -> CT:
"""Return a new connection configured for the pool."""
self._stats[self._CONNECTIONS_NUM] += 1
kwargs = self.kwargs
kwargs["connect_timeout"] = max(round(timeout), 1)
t0 = monotonic()
try:
- conn: Connection[Any]
- conn = self.connection_class.connect(self.conninfo, **kwargs)
+ conn: CT = self.connection_class.connect( # type: ignore
+ self.conninfo, **kwargs
+ )
except Exception:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
else:
self._growing = False
- def _return_connection(self, conn: Connection[Any]) -> None:
+ def _return_connection(self, conn: CT) -> None:
"""
Return a connection to the pool after usage.
"""
self._add_to_pool(conn)
- def _add_to_pool(self, conn: Connection[Any]) -> None:
+ def _add_to_pool(self, conn: CT) -> None:
"""
Add a connection to the pool.
if self._pool_full_event and len(self._pool) >= self._min_size:
self._pool_full_event.set()
- def _reset_connection(self, conn: Connection[Any]) -> None:
+ def _reset_connection(self, conn: CT) -> None:
"""
Bring a connection to IDLE state or close it.
"""
conn.close()
def _shrink_pool(self) -> None:
- to_close: Optional[Connection[Any]] = None
+ to_close: Optional[CT] = None
with self._lock:
# Reset the min number of connections used
return rv
-class WaitingClient:
+class WaitingClient(Generic[CT]):
"""A position in a queue for a client waiting for a connection."""
__slots__ = ("conn", "error", "_cond")
def __init__(self) -> None:
- self.conn: Optional[Connection[Any]] = None
+ self.conn: Optional[CT] = None
self.error: Optional[BaseException] = None
# The WaitingClient behaves in a way similar to an Event, but we need
# will be lost.
self._cond = threading.Condition()
- def wait(self, timeout: float) -> Connection[Any]:
+ def wait(self, timeout: float) -> CT:
"""Wait for a connection to be set and return it.
Raise an exception if the wait times out or if fail() is called.
assert self.error
raise self.error
- def set(self, conn: Connection[Any]) -> bool:
+ def set(self, conn: CT) -> bool:
"""Signal the client waiting that a connection is ready.
Return True if the client has "accepted" the connection, False
class MaintenanceTask(ABC):
"""A task to run asynchronously to maintain the pool state."""
- def __init__(self, pool: "ConnectionPool"):
+ def __init__(self, pool: "ConnectionPool[Any]"):
self.pool = ref(pool)
def __repr__(self) -> str:
pool.run_task(self)
@abstractmethod
- def _run(self, pool: "ConnectionPool") -> None:
+ def _run(self, pool: "ConnectionPool[Any]") -> None:
...
class StopWorker(MaintenanceTask):
"""Signal the maintenance thread to terminate."""
- def _run(self, pool: "ConnectionPool") -> None:
+ def _run(self, pool: "ConnectionPool[Any]") -> None:
pass
class AddConnection(MaintenanceTask):
def __init__(
self,
- pool: "ConnectionPool",
+ pool: "ConnectionPool[Any]",
attempt: Optional["ConnectionAttempt"] = None,
growing: bool = False,
):
self.attempt = attempt
self.growing = growing
- def _run(self, pool: "ConnectionPool") -> None:
+ def _run(self, pool: "ConnectionPool[Any]") -> None:
pool._add_connection(self.attempt, growing=self.growing)
class ReturnConnection(MaintenanceTask):
"""Clean up and return a connection to the pool."""
- def __init__(self, pool: "ConnectionPool", conn: "Connection[Any]"):
+ def __init__(self, pool: "ConnectionPool[Any]", conn: CT):
super().__init__(pool)
self.conn = conn
- def _run(self, pool: "ConnectionPool") -> None:
+ def _run(self, pool: "ConnectionPool[Any]") -> None:
pool._return_connection(self.conn)
in the pool.
"""
- def _run(self, pool: "ConnectionPool") -> None:
+ def _run(self, pool: "ConnectionPool[Any]") -> None:
# Reschedule the task now so that in case of any error we don't lose
# the periodic run.
pool.schedule_task(self, pool.max_idle)
from abc import ABC, abstractmethod
from time import monotonic
from types import TracebackType
-from typing import Any, AsyncIterator, Awaitable, Callable
-from typing import Dict, List, Optional, Sequence, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, cast, Generic
+from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar, Union
from typing_extensions import TypeAlias
from weakref import ref
from contextlib import asynccontextmanager
from psycopg import errors as e
from psycopg import AsyncConnection
from psycopg.pq import TransactionStatus
+from psycopg.rows import TupleRow
from .base import ConnectionAttempt, BasePool
from .sched import AsyncScheduler
Callable[["AsyncConnectionPool"], Awaitable[None]],
]
+ACT = TypeVar("ACT", bound="AsyncConnection[Any]")
-class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
- _Self = TypeVar("_Self", bound="AsyncConnectionPool")
+
+class AsyncConnectionPool(Generic[ACT], BasePool):
+ _Self = TypeVar("_Self", bound="AsyncConnectionPool[ACT]")
+ _pool: Deque[ACT]
+
+ @overload
+ def __init__(
+ self: "AsyncConnectionPool[AsyncConnection[TupleRow]]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self: "AsyncConnectionPool[ACT]",
+ conninfo: str = "",
+ *,
+ open: bool = ...,
+ connection_class: Type[ACT],
+ configure: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ reset: Optional[Callable[[ACT], Awaitable[None]]] = ...,
+ kwargs: Optional[Dict[str, Any]] = ...,
+ min_size: int = ...,
+ max_size: Optional[int] = ...,
+ name: Optional[str] = ...,
+ timeout: float = ...,
+ max_waiting: int = ...,
+ max_lifetime: float = ...,
+ max_idle: float = ...,
+ reconnect_timeout: float = ...,
+ reconnect_failed: Optional[AsyncConnectFailedCB] = ...,
+ num_workers: int = ...,
+ ):
+ ...
def __init__(
self,
conninfo: str = "",
*,
open: bool = True,
- connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
- configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
- reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ connection_class: Type[ACT] = cast(Type[ACT], AsyncConnection),
+ configure: Optional[Callable[[ACT], Awaitable[None]]] = None,
+ reset: Optional[Callable[[ACT], Awaitable[None]]] = None,
kwargs: Optional[Dict[str, Any]] = None,
min_size: int = 4,
max_size: Optional[int] = None,
self._sched: AsyncScheduler
self._tasks: "asyncio.Queue[MaintenanceTask]"
- self._waiting = Deque["AsyncClient"]()
+ self._waiting = Deque["AsyncClient[ACT]"]()
# to notify that the pool is full
self._pool_full_event: Optional[asyncio.Event] = None
logger.info("pool %r is ready to use", self.name)
@asynccontextmanager
- async def connection(
- self, timeout: Optional[float] = None
- ) -> AsyncIterator[AsyncConnection[Any]]:
+ async def connection(self, timeout: Optional[float] = None) -> AsyncIterator[ACT]:
conn = await self.getconn(timeout=timeout)
t0 = monotonic()
try:
self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
await self.putconn(conn)
- async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+ async def getconn(self, timeout: Optional[float] = None) -> ACT:
logger.info("connection requested from %r", self.name)
self._stats[self._REQUESTS_NUM] += 1
if not conn:
# No connection available: put the client in the waiting queue
t0 = monotonic()
- pos = AsyncClient()
+ pos: AsyncClient[ACT] = AsyncClient()
self._waiting.append(pos)
self._stats[self._REQUESTS_QUEUED] += 1
logger.info("connection given by %r", self.name)
return conn
- async def _get_ready_connection(
- self, timeout: Optional[float]
- ) -> Optional[AsyncConnection[Any]]:
- conn: Optional[AsyncConnection[Any]] = None
+ async def _get_ready_connection(self, timeout: Optional[float]) -> Optional[ACT]:
+ conn: Optional[ACT] = None
if self._pool:
# Take a connection ready out of the pool
conn = self._pool.popleft()
self._growing = True
self.run_task(AddConnection(self, growing=True))
- async def putconn(self, conn: AsyncConnection[Any]) -> None:
+ async def putconn(self, conn: ACT) -> None:
self._check_pool_putconn(conn)
logger.info("returning connection to %r", self.name)
else:
await self._return_connection(conn)
- async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
+ async def _maybe_close_connection(self, conn: ACT) -> bool:
# If the pool is closed just close the connection instead of returning
# it to the pool. For extra refcare remove the pool reference from it.
if not self._closed:
async def _stop_workers(
self,
- waiting_clients: Sequence["AsyncClient"] = (),
- connections: Sequence[AsyncConnection[Any]] = (),
+ waiting_clients: Sequence["AsyncClient[ACT]"] = (),
+ connections: Sequence[ACT] = (),
timeout: float = 0.0,
) -> None:
# Stop the scheduler
ex,
)
- async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
+ async def _connect(self, timeout: Optional[float] = None) -> ACT:
self._stats[self._CONNECTIONS_NUM] += 1
kwargs = self.kwargs
if timeout:
kwargs["connect_timeout"] = max(round(timeout), 1)
t0 = monotonic()
try:
- conn: AsyncConnection[Any]
- conn = await self.connection_class.connect(self.conninfo, **kwargs)
+ conn: ACT = await self.connection_class.connect(self.conninfo, **kwargs)
except Exception:
self._stats[self._CONNECTIONS_ERRORS] += 1
raise
else:
self._growing = False
- async def _return_connection(self, conn: AsyncConnection[Any]) -> None:
+ async def _return_connection(self, conn: ACT) -> None:
"""
Return a connection to the pool after usage.
"""
await self._add_to_pool(conn)
- async def _add_to_pool(self, conn: AsyncConnection[Any]) -> None:
+ async def _add_to_pool(self, conn: ACT) -> None:
"""
Add a connection to the pool.
if self._pool_full_event and len(self._pool) >= self._min_size:
self._pool_full_event.set()
- async def _reset_connection(self, conn: AsyncConnection[Any]) -> None:
+ async def _reset_connection(self, conn: ACT) -> None:
"""
Bring a connection to IDLE state or close it.
"""
await conn.close()
async def _shrink_pool(self) -> None:
- to_close: Optional[AsyncConnection[Any]] = None
+ to_close: Optional[ACT] = None
async with self._lock:
# Reset the min number of connections used
return rv
-class AsyncClient:
+class AsyncClient(Generic[ACT]):
"""A position in a queue for a client waiting for a connection."""
__slots__ = ("conn", "error", "_cond")
def __init__(self) -> None:
- self.conn: Optional[AsyncConnection[Any]] = None
+ self.conn: Optional[ACT] = None
self.error: Optional[BaseException] = None
# The AsyncClient behaves in a way similar to an Event, but we need
# will be lost.
self._cond = asyncio.Condition()
- async def wait(self, timeout: float) -> AsyncConnection[Any]:
+ async def wait(self, timeout: float) -> ACT:
"""Wait for a connection to be set and return it.
Raise an exception if the wait times out or if fail() is called.
assert self.error
raise self.error
- async def set(self, conn: AsyncConnection[Any]) -> bool:
+ async def set(self, conn: ACT) -> bool:
"""Signal the client waiting that a connection is ready.
Return True if the client has "accepted" the connection, False
class MaintenanceTask(ABC):
"""A task to run asynchronously to maintain the pool state."""
- def __init__(self, pool: "AsyncConnectionPool"):
+ def __init__(self, pool: "AsyncConnectionPool[Any]"):
self.pool = ref(pool)
def __repr__(self) -> str:
pool.run_task(self)
@abstractmethod
- async def _run(self, pool: "AsyncConnectionPool") -> None:
+ async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
...
class StopWorker(MaintenanceTask):
"""Signal the maintenance worker to terminate."""
- async def _run(self, pool: "AsyncConnectionPool") -> None:
+ async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
pass
class AddConnection(MaintenanceTask):
def __init__(
self,
- pool: "AsyncConnectionPool",
+ pool: "AsyncConnectionPool[Any]",
attempt: Optional["ConnectionAttempt"] = None,
growing: bool = False,
):
self.attempt = attempt
self.growing = growing
- async def _run(self, pool: "AsyncConnectionPool") -> None:
+ async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
await pool._add_connection(self.attempt, growing=self.growing)
class ReturnConnection(MaintenanceTask):
"""Clean up and return a connection to the pool."""
- def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"):
+ def __init__(self, pool: "AsyncConnectionPool[Any]", conn: ACT):
super().__init__(pool)
self.conn = conn
- async def _run(self, pool: "AsyncConnectionPool") -> None:
+ async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
await pool._return_connection(self.conn)
in the pool.
"""
- async def _run(self, pool: "AsyncConnectionPool") -> None:
+ async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
# Reschedule the task now so that in case of any error we don't lose
# the periodic run.
await pool.schedule_task(self, pool.max_idle)
def __init__(
self,
- pool: "AsyncConnectionPool",
+ pool: "AsyncConnectionPool[Any]",
task: MaintenanceTask,
delay: float,
):
self.task = task
self.delay = delay
- async def _run(self, pool: "AsyncConnectionPool") -> None:
+ async def _run(self, pool: "AsyncConnectionPool[Any]") -> None:
await pool.schedule_task(self.task, self.delay)
import logging
from time import sleep, time
from threading import Thread, Event
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
import pytest
from packaging.version import parse as ver # noqa: F401 # used in skipif
import psycopg
from psycopg.pq import TransactionStatus
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type
from .test_pool import delay_connection, ensure_waiting
assert conn.autocommit
+class MyRow(Dict[str, Any]):
+ ...
+
+
+def test_generic_connection_type(dsn):
+ def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+ conn.autocommit = True
+
+ class MyConnection(psycopg.Connection[Row]):
+ pass
+
+ with NullConnectionPool(
+ dsn,
+ connection_class=MyConnection[MyRow],
+ kwargs={"row_factory": class_row(MyRow)},
+ configure=set_autocommit,
+ ) as p1:
+ with p1.connection() as conn1:
+ cur1 = conn1.execute("select 1 as x")
+ (row1,) = cur1.fetchall()
+
+ assert_type(p1, NullConnectionPool[MyConnection[MyRow]])
+ assert_type(conn1, MyConnection[MyRow])
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+ with NullConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
+ with p2.connection() as conn2:
+ (row2,) = conn2.execute("select 2 as y").fetchall()
+ assert_type(p2, NullConnectionPool[MyConnection[TupleRow]])
+ assert_type(conn2, MyConnection[TupleRow])
+ assert_type(row2, TupleRow)
+ assert row2 == (2,)
+
+
+def test_non_generic_connection_type(dsn):
+ def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+ conn.autocommit = True
+
+ class MyConnection(psycopg.Connection[MyRow]):
+ def __init__(self, *args: Any, **kwargs: Any):
+ kwargs["row_factory"] = class_row(MyRow)
+ super().__init__(*args, **kwargs)
+
+ with NullConnectionPool(
+ dsn, connection_class=MyConnection, configure=set_autocommit
+ ) as p1:
+ with p1.connection() as conn1:
+ (row1,) = conn1.execute("select 1 as x").fetchall()
+ assert_type(p1, NullConnectionPool[MyConnection])
+ assert_type(conn1, MyConnection)
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+
@pytest.mark.crdb_skip("backend pid")
def test_its_no_pool_at_all(dsn):
with NullConnectionPool(dsn, max_size=2) as p:
import asyncio
import logging
from time import time
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
import pytest
from packaging.version import parse as ver # noqa: F401 # used in skipif
import psycopg
from psycopg.pq import TransactionStatus
-from psycopg._compat import create_task
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type, create_task
from .test_pool_async import delay_connection, ensure_waiting
pytestmark = [pytest.mark.anyio]
assert conn.autocommit
+class MyRow(Dict[str, Any]):
+ ...
+
+
+async def test_generic_connection_type(dsn):
+ async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+ await conn.set_autocommit(True)
+
+ class MyConnection(psycopg.AsyncConnection[Row]):
+ pass
+
+ async with AsyncNullConnectionPool(
+ dsn,
+ connection_class=MyConnection[MyRow],
+ kwargs={"row_factory": class_row(MyRow)},
+ configure=set_autocommit,
+ ) as p1:
+ async with p1.connection() as conn1:
+ cur1 = await conn1.execute("select 1 as x")
+ (row1,) = await cur1.fetchall()
+ assert_type(p1, AsyncNullConnectionPool[MyConnection[MyRow]])
+ assert_type(conn1, MyConnection[MyRow])
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+ async with AsyncNullConnectionPool(
+ dsn, connection_class=MyConnection[TupleRow]
+ ) as p2:
+ async with p2.connection() as conn2:
+ cur2 = await conn2.execute("select 2 as y")
+ (row2,) = await cur2.fetchall()
+ assert_type(p2, AsyncNullConnectionPool[MyConnection[TupleRow]])
+ assert_type(conn2, MyConnection[TupleRow])
+ assert_type(row2, TupleRow)
+ assert row2 == (2,)
+
+
+async def test_non_generic_connection_type(dsn):
+ async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+ await conn.set_autocommit(True)
+
+ class MyConnection(psycopg.AsyncConnection[MyRow]):
+ def __init__(self, *args: Any, **kwargs: Any):
+ kwargs["row_factory"] = class_row(MyRow)
+ super().__init__(*args, **kwargs)
+
+ async with AsyncNullConnectionPool(
+ dsn, connection_class=MyConnection, configure=set_autocommit
+ ) as p1:
+ async with p1.connection() as conn1:
+ (row1,) = await (await conn1.execute("select 1 as x")).fetchall()
+ assert_type(p1, AsyncNullConnectionPool[MyConnection])
+ assert_type(conn1, MyConnection)
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+
@pytest.mark.crdb_skip("backend pid")
async def test_its_no_pool_at_all(dsn):
async with AsyncNullConnectionPool(dsn, max_size=2) as p:
import weakref
from time import sleep, time
from threading import Thread, Event
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
import pytest
import psycopg
from psycopg.pq import TransactionStatus
-from psycopg._compat import Counter
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import Counter, assert_type
try:
import psycopg_pool as pool
assert conn.autocommit
+class MyRow(Dict[str, Any]):
+ ...
+
+
+def test_generic_connection_type(dsn):
+ def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+ conn.autocommit = True
+
+ class MyConnection(psycopg.Connection[Row]):
+ pass
+
+ with pool.ConnectionPool(
+ dsn,
+ connection_class=MyConnection[MyRow],
+ kwargs=dict(row_factory=class_row(MyRow)),
+ configure=set_autocommit,
+ ) as p1, p1.connection() as conn1:
+ (row1,) = conn1.execute("select 1 as x").fetchall()
+ assert_type(p1, pool.ConnectionPool[MyConnection[MyRow]])
+ assert_type(conn1, MyConnection[MyRow])
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+ with pool.ConnectionPool(dsn, connection_class=MyConnection[TupleRow]) as p2:
+ with p2.connection() as conn2:
+ (row2,) = conn2.execute("select 2 as y").fetchall()
+ assert_type(p2, pool.ConnectionPool[MyConnection[TupleRow]])
+ assert_type(conn2, MyConnection[TupleRow])
+ assert_type(row2, TupleRow)
+ assert row2 == (2,)
+
+
+def test_non_generic_connection_type(dsn):
+ def set_autocommit(conn: psycopg.Connection[Any]) -> None:
+ conn.autocommit = True
+
+ class MyConnection(psycopg.Connection[MyRow]):
+ def __init__(self, *args: Any, **kwargs: Any):
+ kwargs["row_factory"] = class_row(MyRow)
+ super().__init__(*args, **kwargs)
+
+ with pool.ConnectionPool(
+ dsn, connection_class=MyConnection, configure=set_autocommit
+ ) as p1:
+ with p1.connection() as conn1:
+ (row1,) = conn1.execute("select 1 as x").fetchall()
+ assert_type(p1, pool.ConnectionPool[MyConnection])
+ assert_type(conn1, MyConnection)
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+
@pytest.mark.crdb_skip("backend pid")
def test_its_really_a_pool(dsn):
with pool.ConnectionPool(dsn, min_size=2) as p:
import asyncio
import logging
from time import time
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
import pytest
import psycopg
from psycopg.pq import TransactionStatus
-from psycopg._compat import create_task, Counter
+from psycopg.rows import class_row, Row, TupleRow
+from psycopg._compat import assert_type, create_task, Counter
try:
import psycopg_pool as pool
assert conn.autocommit
+class MyRow(Dict[str, Any]):
+ ...
+
+
+async def test_generic_connection_type(dsn):
+ async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+ await conn.set_autocommit(True)
+
+ class MyConnection(psycopg.AsyncConnection[Row]):
+ pass
+
+ async with pool.AsyncConnectionPool(
+ dsn,
+ connection_class=MyConnection[MyRow],
+ kwargs=dict(row_factory=class_row(MyRow)),
+ configure=set_autocommit,
+ ) as p1:
+ async with p1.connection() as conn1:
+ cur1 = await conn1.execute("select 1 as x")
+ (row1,) = await cur1.fetchall()
+ assert_type(p1, pool.AsyncConnectionPool[MyConnection[MyRow]])
+ assert_type(conn1, MyConnection[MyRow])
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+ async with pool.AsyncConnectionPool(
+ dsn, connection_class=MyConnection[TupleRow]
+ ) as p2:
+ async with p2.connection() as conn2:
+ cur2 = await conn2.execute("select 2 as y")
+ (row2,) = await cur2.fetchall()
+ assert_type(p2, pool.AsyncConnectionPool[MyConnection[TupleRow]])
+ assert_type(conn2, MyConnection[TupleRow])
+ assert_type(row2, TupleRow)
+ assert row2 == (2,)
+
+
+async def test_non_generic_connection_type(dsn):
+ async def set_autocommit(conn: psycopg.AsyncConnection[Any]) -> None:
+ await conn.set_autocommit(True)
+
+ class MyConnection(psycopg.AsyncConnection[MyRow]):
+ def __init__(self, *args: Any, **kwargs: Any):
+ kwargs["row_factory"] = class_row(MyRow)
+ super().__init__(*args, **kwargs)
+
+ async with pool.AsyncConnectionPool(
+ dsn,
+ connection_class=MyConnection,
+ configure=set_autocommit,
+ ) as p1:
+ async with p1.connection() as conn1:
+ cur1 = await conn1.execute("select 1 as x")
+ (row1,) = await cur1.fetchall()
+ assert_type(p1, pool.AsyncConnectionPool[MyConnection])
+ assert_type(conn1, MyConnection)
+ assert_type(row1, MyRow)
+ assert conn1.autocommit
+ assert row1 == {"x": 1}
+
+
@pytest.mark.crdb_skip("backend pid")
async def test_its_really_a_pool(dsn):
async with pool.AsyncConnectionPool(dsn, min_size=2) as p: