]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move some common checks to the pool base class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 3 Jan 2022 18:40:11 +0000 (19:40 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 5 Jan 2022 22:09:05 +0000 (23:09 +0100)
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py

index 160221507353cea71a987770139115bee4f920e6..bdfc326123cf34cd9ae024cb3f9b2275120cb47e 100644 (file)
@@ -4,11 +4,13 @@ psycopg connection pool base class and functionalities.
 
 # Copyright (C) 2021 The Psycopg Team
 
+from time import monotonic
 from random import random
-from typing import Any, Callable, Dict, Generic, Optional
+from typing import Any, Callable, Dict, Generic, Optional, Tuple
 
 from psycopg.abc import ConnectionType
 
+from .errors import PoolClosed
 from ._compat import Counter, Deque
 
 
@@ -52,10 +54,8 @@ class BasePool(Generic[ConnectionType]):
         ] = None,
         num_workers: int = 3,
     ):
-        if max_size is None:
-            max_size = min_size
-        if max_size < min_size:
-            raise ValueError("max_size must be greater or equal than min_size")
+        min_size, max_size = self._check_size(min_size, max_size)
+
         if not name:
             num = BasePool._num_pool = BasePool._num_pool + 1
             name = f"pool-{num}"
@@ -117,6 +117,33 @@ class BasePool(Generic[ConnectionType]):
         """`!True` if the pool is closed."""
         return self._closed
 
+    def _check_size(
+        self, min_size: int, max_size: Optional[int]
+    ) -> Tuple[int, int]:
+        if max_size is None:
+            max_size = min_size
+        if max_size < min_size:
+            raise ValueError("max_size must be greater or equal than min_size")
+
+        return min_size, max_size
+
+    def _check_open_getconn(self) -> None:
+        if self._closed:
+            raise PoolClosed(f"the pool {self.name!r} is closed")
+
+    def _check_pool_putconn(self, conn: ConnectionType) -> None:
+        pool = getattr(conn, "_pool", None)
+        if pool is self:
+            return
+
+        if pool:
+            msg = f"it comes from pool {pool.name!r}"
+        else:
+            msg = "it doesn't come from any pool"
+        raise ValueError(
+            f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+        )
+
     def get_stats(self) -> Dict[str, int]:
         """
         Return current stats about the pool usage.
@@ -154,6 +181,15 @@ class BasePool(Generic[ConnectionType]):
         """
         return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
 
+    def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+        """Set an expiry date on a connection.
+
+        Add some randomness to avoid mass reconnection.
+        """
+        conn._expire_at = monotonic() + self._jitter(
+            self.max_lifetime, -0.05, 0.0
+        )
+
 
 class ConnectionAttempt:
     """Keep the state of a connection attempt."""
index dcbcde02ac547a3af78ec60753cf8a02b572d814..70d9e7d67557f6f492320f597583816c0364cabe 100644 (file)
@@ -164,11 +164,11 @@ class ConnectionPool(BasePool[Connection[Any]]):
         """
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
+
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         with self._lock:
-            if self._closed:
-                raise PoolClosed(f"the pool {self.name!r} is closed")
+            self._check_open_getconn()
 
             pos: Optional[WaitingClient] = None
             if self._pool:
@@ -229,15 +229,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         it if you use the much more comfortable `connection()` context manager.
         """
         # Quick check to discard the wrong connection
-        pool = getattr(conn, "_pool", None)
-        if pool is not self:
-            if pool:
-                msg = f"it comes from pool {pool.name!r}"
-            else:
-                msg = "it doesn't come from any pool"
-            raise ValueError(
-                f"can't return connection to pool {self.name!r}, {msg}: {conn}"
-            )
+        self._check_pool_putconn(conn)
 
         logger.info("returning connection to %r", self.name)
 
@@ -323,10 +315,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
     def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
         """Change the size of the pool during runtime."""
-        if max_size is None:
-            max_size = min_size
-        if max_size < min_size:
-            raise ValueError("max_size must be greater or equal than min_size")
+        min_size, max_size = self._check_size(min_size, max_size)
 
         ngrow = max(0, min_size - self._min_size)
 
@@ -439,16 +428,14 @@ class ConnectionPool(BasePool[Connection[Any]]):
             self._configure(conn)
             status = conn.pgconn.transaction_status
             if status != TransactionStatus.IDLE:
-                nstatus = TransactionStatus(status).name
+                sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
-                    f"connection left in status {nstatus} by configure function"
+                    f"connection left in status {sname} by configure function"
                     f" {self._configure}: discarded"
                 )
 
         # Set an expiry date, with some randomness to avoid mass reconnection
-        conn._expire_at = monotonic() + self._jitter(
-            self.max_lifetime, -0.05, 0.0
-        )
+        self._set_connection_expiry_date(conn)
         return conn
 
     def _add_connection(
@@ -586,9 +573,9 @@ class ConnectionPool(BasePool[Connection[Any]]):
                 self._reset(conn)
                 status = conn.pgconn.transaction_status
                 if status != TransactionStatus.IDLE:
-                    nstatus = TransactionStatus(status).name
+                    sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
-                        f"connection left in status {nstatus} by reset function"
+                        f"connection left in status {sname} by reset function"
                         f" {self._reset}: discarded"
                     )
             except Exception as ex:
index 597f306aa13ee4c4b841ceb3ef74bfa3268d928f..787a0e7e33c1283483ccb09915f7dcdc8b616610 100644 (file)
@@ -121,11 +121,11 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
     ) -> AsyncConnection[Any]:
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
+
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         async with self._lock:
-            if self._closed:
-                raise PoolClosed(f"the pool {self.name!r} is closed")
+            self._check_open_getconn()
 
             pos: Optional[AsyncClient] = None
             if self._pool:
@@ -179,16 +179,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         return conn
 
     async def putconn(self, conn: AsyncConnection[Any]) -> None:
-        # Quick check to discard the wrong connection
-        pool = getattr(conn, "_pool", None)
-        if pool is not self:
-            if pool:
-                msg = f"it comes from pool {pool.name!r}"
-            else:
-                msg = "it doesn't come from any pool"
-            raise ValueError(
-                f"can't return connection to pool {self.name!r}, {msg}: {conn}"
-            )
+        self._check_pool_putconn(conn)
 
         logger.info("returning connection to %r", self.name)
 
@@ -265,10 +256,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
     async def resize(
         self, min_size: int, max_size: Optional[int] = None
     ) -> None:
-        if max_size is None:
-            max_size = min_size
-        if max_size < min_size:
-            raise ValueError("max_size must be greater or equal than min_size")
+        min_size, max_size = self._check_size(min_size, max_size)
 
         ngrow = max(0, min_size - self._min_size)
 
@@ -369,16 +357,14 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             await self._configure(conn)
             status = conn.pgconn.transaction_status
             if status != TransactionStatus.IDLE:
-                nstatus = TransactionStatus(status).name
+                sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
-                    f"connection left in status {nstatus} by configure function"
+                    f"connection left in status {sname} by configure function"
                     f" {self._configure}: discarded"
                 )
 
         # Set an expiry date, with some randomness to avoid mass reconnection
-        conn._expire_at = monotonic() + self._jitter(
-            self.max_lifetime, -0.05, 0.0
-        )
+        self._set_connection_expiry_date(conn)
         return conn
 
     async def _add_connection(
@@ -516,9 +502,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 await self._reset(conn)
                 status = conn.pgconn.transaction_status
                 if status != TransactionStatus.IDLE:
-                    nstatus = TransactionStatus(status).name
+                    sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
-                        f"connection left in status {nstatus} by reset function"
+                        f"connection left in status {sname} by reset function"
                         f" {self._reset}: discarded"
                     )
             except Exception as ex: