]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): make constructor configuration parameters type-safe
authorRan Benita <ran@unusedvar.com>
Tue, 24 Jan 2023 20:59:55 +0000 (22:59 +0200)
committerRan Benita <ran@unusedvar.com>
Sat, 28 Jan 2023 15:09:46 +0000 (17:09 +0200)
Previously, arguments to ConnectionPool and friends would be forwarded
using `**kwargs` to `BasePool`. This is however not type-safe, and
prevents code editors from auto-completing the parameters.

Drop the `**kwargs`, duplicate the parameters instead.

Fixes #493.

psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py

index 298ea6837621c3c29ad968304c45b26aa471005d..c081419e702c48e765310de920df5aca80127c91 100644 (file)
@@ -41,18 +41,18 @@ class BasePool(Generic[ConnectionType]):
         self,
         conninfo: str = "",
         *,
-        kwargs: Optional[Dict[str, Any]] = None,
-        min_size: int = 4,
-        max_size: Optional[int] = None,
-        open: bool = True,
-        name: Optional[str] = None,
-        timeout: float = 30.0,
-        max_waiting: int = 0,
-        max_lifetime: float = 60 * 60.0,
-        max_idle: float = 10 * 60.0,
-        reconnect_timeout: float = 5 * 60.0,
-        reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None,
-        num_workers: int = 3,
+        kwargs: Optional[Dict[str, Any]],
+        min_size: int,
+        max_size: Optional[int],
+        open: bool,
+        name: Optional[str],
+        timeout: float,
+        max_waiting: int,
+        max_lifetime: float,
+        max_idle: float,
+        reconnect_timeout: float,
+        reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]],
+        num_workers: int,
     ):
         min_size, max_size = self._check_size(min_size, max_size)
 
index c0a77c24672dd0830842b30342ef2912fdb2a767..20f9811b77a748b5060fb8368e90c2dd708554e7 100644 (file)
@@ -6,11 +6,12 @@ Psycopg null connection pools
 
 import logging
 import threading
-from typing import Any, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 from psycopg import Connection
 from psycopg.pq import TransactionStatus
 
+from .base import BasePool
 from .pool import ConnectionPool, AddConnection
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
@@ -19,13 +20,6 @@ logger = logging.getLogger("psycopg.pool")
 
 
 class _BaseNullConnectionPool:
-    def __init__(
-        self, conninfo: str = "", min_size: int = 0, *args: Any, **kwargs: Any
-    ):
-        super().__init__(  # type: ignore[call-arg]
-            conninfo, *args, min_size=min_size, **kwargs
-        )
-
     def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
         if max_size is None:
             max_size = min_size
@@ -48,6 +42,46 @@ class _BaseNullConnectionPool:
 
 
 class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
+    def __init__(
+        self,
+        conninfo: str = "",
+        *,
+        open: bool = True,
+        connection_class: Type[Connection[Any]] = Connection,
+        configure: Optional[Callable[[Connection[Any]], None]] = None,
+        reset: Optional[Callable[[Connection[Any]], None]] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+        # Note: default value changed to 0.
+        min_size: int = 0,
+        max_size: Optional[int] = None,
+        name: Optional[str] = None,
+        timeout: float = 30.0,
+        max_waiting: int = 0,
+        max_lifetime: float = 60 * 60.0,
+        max_idle: float = 10 * 60.0,
+        reconnect_timeout: float = 5 * 60.0,
+        reconnect_failed: Optional[Callable[[BasePool[Connection[Any]]], None]] = None,
+        num_workers: int = 3,
+    ):
+        super().__init__(
+            conninfo,
+            open=open,
+            connection_class=connection_class,
+            configure=configure,
+            reset=reset,
+            kwargs=kwargs,
+            min_size=min_size,
+            max_size=max_size,
+            name=name,
+            timeout=timeout,
+            max_waiting=max_waiting,
+            max_lifetime=max_lifetime,
+            max_idle=max_idle,
+            reconnect_timeout=reconnect_timeout,
+            reconnect_failed=reconnect_failed,
+            num_workers=num_workers,
+        )
+
     def wait(self, timeout: float = 30.0) -> None:
         """
         Create a connection for test.
index ae9d207bca6ef229e9bf1f0854740d3d30d2b620..9f566c66360347dcf06863bb98c83f89b6b0c14b 100644 (file)
@@ -6,11 +6,12 @@ psycopg asynchronous null connection pool
 
 import asyncio
 import logging
-from typing import Any, Optional
+from typing import Any, Awaitable, Callable, Dict, Optional, Type
 
 from psycopg import AsyncConnection
 from psycopg.pq import TransactionStatus
 
+from .base import BasePool
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
 from .null_pool import _BaseNullConnectionPool
@@ -20,6 +21,48 @@ logger = logging.getLogger("psycopg.pool")
 
 
 class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
+    def __init__(
+        self,
+        conninfo: str = "",
+        *,
+        open: bool = True,
+        connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
+        configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+        reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+        kwargs: Optional[Dict[str, Any]] = None,
+        # Note: default value changed to 0.
+        min_size: int = 0,
+        max_size: Optional[int] = None,
+        name: Optional[str] = None,
+        timeout: float = 30.0,
+        max_waiting: int = 0,
+        max_lifetime: float = 60 * 60.0,
+        max_idle: float = 10 * 60.0,
+        reconnect_timeout: float = 5 * 60.0,
+        reconnect_failed: Optional[
+            Callable[[BasePool[AsyncConnection[None]]], None]
+        ] = None,
+        num_workers: int = 3,
+    ):
+        super().__init__(
+            conninfo,
+            open=open,
+            connection_class=connection_class,
+            configure=configure,
+            reset=reset,
+            kwargs=kwargs,
+            min_size=min_size,
+            max_size=max_size,
+            name=name,
+            timeout=timeout,
+            max_waiting=max_waiting,
+            max_lifetime=max_lifetime,
+            max_idle=max_idle,
+            reconnect_timeout=reconnect_timeout,
+            reconnect_failed=reconnect_failed,
+            num_workers=num_workers,
+        )
+
     async def wait(self, timeout: float = 30.0) -> None:
         self._check_open_getconn()
 
index 609d95dfea0974598b8bb1f81a40fb72e7bc3514..05cfc8fea2ce445f757c988b4dd93a5b0e19800e 100644 (file)
@@ -36,7 +36,17 @@ class ConnectionPool(BasePool[Connection[Any]]):
         connection_class: Type[Connection[Any]] = Connection,
         configure: Optional[Callable[[Connection[Any]], None]] = None,
         reset: Optional[Callable[[Connection[Any]], None]] = None,
-        **kwargs: Any,
+        kwargs: Optional[Dict[str, Any]] = None,
+        min_size: int = 4,
+        max_size: Optional[int] = None,
+        name: Optional[str] = None,
+        timeout: float = 30.0,
+        max_waiting: int = 0,
+        max_lifetime: float = 60 * 60.0,
+        max_idle: float = 10 * 60.0,
+        reconnect_timeout: float = 5 * 60.0,
+        reconnect_failed: Optional[Callable[[BasePool[Connection[Any]]], None]] = None,
+        num_workers: int = 3,
     ):
         self.connection_class = connection_class
         self._configure = configure
@@ -53,7 +63,21 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self._tasks: "Queue[MaintenanceTask]" = Queue()
         self._workers: List[threading.Thread] = []
 
-        super().__init__(conninfo, **kwargs)
+        super().__init__(
+            conninfo,
+            kwargs=kwargs,
+            min_size=min_size,
+            max_size=max_size,
+            open=open,
+            name=name,
+            timeout=timeout,
+            max_waiting=max_waiting,
+            max_lifetime=max_lifetime,
+            max_idle=max_idle,
+            reconnect_timeout=reconnect_timeout,
+            reconnect_failed=reconnect_failed,
+            num_workers=num_workers,
+        )
 
         if open:
             self.open()
index 0ea6e9a40a5cd985201cda767daa57c8d37f9d56..1cffcce68c61e874c3d7823fe61f76c1fee70807 100644 (file)
@@ -35,7 +35,19 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
         configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
         reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
-        **kwargs: Any,
+        kwargs: Optional[Dict[str, Any]] = None,
+        min_size: int = 4,
+        max_size: Optional[int] = None,
+        name: Optional[str] = None,
+        timeout: float = 30.0,
+        max_waiting: int = 0,
+        max_lifetime: float = 60 * 60.0,
+        max_idle: float = 10 * 60.0,
+        reconnect_timeout: float = 5 * 60.0,
+        reconnect_failed: Optional[
+            Callable[[BasePool[AsyncConnection[Any]]], None]
+        ] = None,
+        num_workers: int = 3,
     ):
         self.connection_class = connection_class
         self._configure = configure
@@ -54,7 +66,21 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self._sched_runner: Optional[Task[None]] = None
         self._workers: List[Task[None]] = []
 
-        super().__init__(conninfo, **kwargs)
+        super().__init__(
+            conninfo,
+            kwargs=kwargs,
+            min_size=min_size,
+            max_size=max_size,
+            open=open,
+            name=name,
+            timeout=timeout,
+            max_waiting=max_waiting,
+            max_lifetime=max_lifetime,
+            max_idle=max_idle,
+            reconnect_timeout=reconnect_timeout,
+            reconnect_failed=reconnect_failed,
+            num_workers=num_workers,
+        )
 
         if open:
             self._open()