]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move some common checks to the pool base class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 3 Jan 2022 18:40:11 +0000 (19:40 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 5 Jan 2022 22:12:49 +0000 (23:12 +0100)
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py

index 380508b20fb0f7615c11256e13b4c2cfa84e5cea..70ecd0ed1ea1251a21c292ece4b694bdbf7cb183 100644 (file)
@@ -4,8 +4,9 @@ psycopg connection pool base class and functionalities.
 
 # Copyright (C) 2021 The Psycopg Team
 
+from time import monotonic
 from random import random
-from typing import Any, Callable, Dict, Generic, Optional
+from typing import Any, Callable, Dict, Generic, Optional, Tuple
 
 from psycopg.abc import ConnectionType
 from psycopg import errors as e
@@ -55,10 +56,8 @@ class BasePool(Generic[ConnectionType]):
         ] = None,
         num_workers: int = 3,
     ):
-        if max_size is None:
-            max_size = min_size
-        if max_size < min_size:
-            raise ValueError("max_size must be greater or equal than min_size")
+        min_size, max_size = self._check_size(min_size, max_size)
+
         if not name:
             num = BasePool._num_pool = BasePool._num_pool + 1
             name = f"pool-{num}"
@@ -119,6 +118,16 @@ class BasePool(Generic[ConnectionType]):
         """`!True` if the pool is closed."""
         return self._closed
 
+    def _check_size(
+        self, min_size: int, max_size: Optional[int]
+    ) -> Tuple[int, int]:
+        if max_size is None:
+            max_size = min_size
+        if max_size < min_size:
+            raise ValueError("max_size must be greater or equal than min_size")
+
+        return min_size, max_size
+
     def _check_open(self) -> None:
         if self._closed and self._opened:
             raise e.OperationalError(
@@ -132,6 +141,19 @@ class BasePool(Generic[ConnectionType]):
             else:
                 raise PoolClosed(f"the pool {self.name!r} is not open yet")
 
+    def _check_pool_putconn(self, conn: ConnectionType) -> None:
+        pool = getattr(conn, "_pool", None)
+        if pool is self:
+            return
+
+        if pool:
+            msg = f"it comes from pool {pool.name!r}"
+        else:
+            msg = "it doesn't come from any pool"
+        raise ValueError(
+            f"can't return connection to pool {self.name!r}, {msg}: {conn}"
+        )
+
     def get_stats(self) -> Dict[str, int]:
         """
         Return current stats about the pool usage.
@@ -169,6 +191,15 @@ class BasePool(Generic[ConnectionType]):
         """
         return value * (1.0 + ((max_pc - min_pc) * random()) + min_pc)
 
+    def _set_connection_expiry_date(self, conn: ConnectionType) -> None:
+        """Set an expiry date on a connection.
+
+        Add some randomness to avoid mass reconnection.
+        """
+        conn._expire_at = monotonic() + self._jitter(
+            self.max_lifetime, -0.05, 0.0
+        )
+
 
 class ConnectionAttempt:
     """Keep the state of a connection attempt."""
index beacc0fb43681f649c866bf26db2246f83b724f3..895067c56982b13783ee81d7882669626c69076b 100644 (file)
@@ -134,6 +134,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         """
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
+
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         with self._lock:
@@ -198,15 +199,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         it if you use the much more comfortable `connection()` context manager.
         """
         # Quick check to discard the wrong connection
-        pool = getattr(conn, "_pool", None)
-        if pool is not self:
-            if pool:
-                msg = f"it comes from pool {pool.name!r}"
-            else:
-                msg = "it doesn't come from any pool"
-            raise ValueError(
-                f"can't return connection to pool {self.name!r}, {msg}: {conn}"
-            )
+        self._check_pool_putconn(conn)
 
         logger.info("returning connection to %r", self.name)
 
@@ -237,6 +230,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
             self._check_open()
 
             self._start_workers()
+            self._start_initial_tasks()
 
             self._closed = False
             self._opened = True
@@ -262,6 +256,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         for t in self._workers:
             t.start()
 
+    def _start_initial_tasks(self) -> None:
         # populate the pool with initial min_size connections in background
         for i in range(self._nconns):
             self.run_task(AddConnection(self))
@@ -302,7 +297,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self,
         waiting_clients: Sequence["WaitingClient"] = (),
         connections: Sequence[Connection[Any]] = (),
-        timeout: float = 0,
+        timeout: float = 0.0,
     ) -> None:
 
         # Stop the scheduler
@@ -351,10 +346,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
     def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
         """Change the size of the pool during runtime."""
-        if max_size is None:
-            max_size = min_size
-        if max_size < min_size:
-            raise ValueError("max_size must be greater or equal than min_size")
+        min_size, max_size = self._check_size(min_size, max_size)
 
         ngrow = max(0, min_size - self._min_size)
 
@@ -467,16 +459,14 @@ class ConnectionPool(BasePool[Connection[Any]]):
             self._configure(conn)
             status = conn.pgconn.transaction_status
             if status != TransactionStatus.IDLE:
-                nstatus = TransactionStatus(status).name
+                sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
-                    f"connection left in status {nstatus} by configure function"
+                    f"connection left in status {sname} by configure function"
                     f" {self._configure}: discarded"
                 )
 
         # Set an expiry date, with some randomness to avoid mass reconnection
-        conn._expire_at = monotonic() + self._jitter(
-            self.max_lifetime, -0.05, 0.0
-        )
+        self._set_connection_expiry_date(conn)
         return conn
 
     def _add_connection(
@@ -614,9 +604,9 @@ class ConnectionPool(BasePool[Connection[Any]]):
                 self._reset(conn)
                 status = conn.pgconn.transaction_status
                 if status != TransactionStatus.IDLE:
-                    nstatus = TransactionStatus(status).name
+                    sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
-                        f"connection left in status {nstatus} by reset function"
+                        f"connection left in status {sname} by reset function"
                         f" {self._reset}: discarded"
                     )
             except Exception as ex:
index d52551b3caa475da43f044b27de99ad4bdffef42..297505e057d54b7d28472bfc255c247d06da5aae 100644 (file)
@@ -108,6 +108,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
     ) -> AsyncConnection[Any]:
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
+
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         async with self._lock:
@@ -165,16 +166,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         return conn
 
     async def putconn(self, conn: AsyncConnection[Any]) -> None:
-        # Quick check to discard the wrong connection
-        pool = getattr(conn, "_pool", None)
-        if pool is not self:
-            if pool:
-                msg = f"it comes from pool {pool.name!r}"
-            else:
-                msg = "it doesn't come from any pool"
-            raise ValueError(
-                f"can't return connection to pool {self.name!r}, {msg}: {conn}"
-            )
+        self._check_pool_putconn(conn)
 
         logger.info("returning connection to %r", self.name)
 
@@ -202,6 +194,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self._check_open()
 
         self._start_workers()
+        self._start_initial_tasks()
 
         self._closed = False
         self._opened = True
@@ -217,6 +210,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             )
             self._workers.append(t)
 
+    def _start_initial_tasks(self) -> None:
         # populate the pool with initial min_size connections in background
         for i in range(self._nconns):
             self.run_task(AddConnection(self))
@@ -247,7 +241,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self,
         waiting_clients: Sequence["AsyncClient"] = (),
         connections: Sequence[AsyncConnection[Any]] = (),
-        timeout: float = 0,
+        timeout: float = 0.0,
     ) -> None:
         # Stop the scheduler
         await self._sched.enter(0, None)
@@ -296,10 +290,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
     async def resize(
         self, min_size: int, max_size: Optional[int] = None
     ) -> None:
-        if max_size is None:
-            max_size = min_size
-        if max_size < min_size:
-            raise ValueError("max_size must be greater or equal than min_size")
+        min_size, max_size = self._check_size(min_size, max_size)
 
         ngrow = max(0, min_size - self._min_size)
 
@@ -400,16 +391,14 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             await self._configure(conn)
             status = conn.pgconn.transaction_status
             if status != TransactionStatus.IDLE:
-                nstatus = TransactionStatus(status).name
+                sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
-                    f"connection left in status {nstatus} by configure function"
+                    f"connection left in status {sname} by configure function"
                     f" {self._configure}: discarded"
                 )
 
         # Set an expiry date, with some randomness to avoid mass reconnection
-        conn._expire_at = monotonic() + self._jitter(
-            self.max_lifetime, -0.05, 0.0
-        )
+        self._set_connection_expiry_date(conn)
         return conn
 
     async def _add_connection(
@@ -547,9 +536,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 await self._reset(conn)
                 status = conn.pgconn.transaction_status
                 if status != TransactionStatus.IDLE:
-                    nstatus = TransactionStatus(status).name
+                    sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
-                        f"connection left in status {nstatus} by reset function"
+                        f"connection left in status {sname} by reset function"
                         f" {self._reset}: discarded"
                     )
             except Exception as ex: