git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(pool): add `drain()` method 1215/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 11 Jul 2025 20:02:11 +0000 (22:02 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 24 Nov 2025 21:02:13 +0000 (22:02 +0100)
Add the creation timestamp to the connection to verify that it should be
immediately discarded on return.

docs/api/pool.rst
docs/news_pool.rst
psycopg/psycopg/_connection_base.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py

index f06fa12e2c5a3499c6c5ff9d35f7d6a2b864014b..ca07b1e87ec0030b15e1adf5b09b2ca7d5af2d08 100644 (file)
@@ -248,6 +248,10 @@ The `!ConnectionPool` class
 
       .. versionadded:: 3.2
 
+   .. automethod:: drain
+
+      .. versionadded:: 3.3
+
    .. automethod:: get_stats
    .. automethod:: pop_stats
 
index 9136e6d65491929c130694ebf11d5627bdd95213..657cf1b34bc9556034e4f741984fe519970a1134 100644 (file)
@@ -13,10 +13,11 @@ Future releases
 psycopg_pool 3.3.0 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- Add `!close_returns` for :ref:`integration with SQLAlchemy <pool-sqlalchemy>`
-  (:ticket:`#1046`).
-- Allow `!conninfo` and `!kwargs` to be callable to allow connection
-  parameters update (:ticket:`#851`).
+- Add `~ConnectionPool.drain()` method (:ticket:`#1215`).
+- Allow the `!conninfo` and `!kwargs` `ConnectionPool` parameters to be callable
+  to allow connection parameters update (:ticket:`#851`).
+- Add `!close_returns` `ConnectionPool` parameter for :ref:`integration with
+  SQLAlchemy <pool-sqlalchemy>` (:ticket:`#1046`).
 
 
 Current release
index 07b7e523adfe1e3f9c0c3f979725265b98d62413..566ee4ab1f2231b820fe2c96e1d3ac67171c54c6 100644 (file)
@@ -127,6 +127,8 @@ class BaseConnection(Generic[Row]):
 
         self._pipeline: BasePipeline | None = None
 
+        # Time when the connection was created (currently only used by the pool)
+        self._created_at: float
         # Time after which the connection should be closed
         self._expire_at: float
 
index 745c995c78136c57b9c7c32115d4538fae737835..9f68098fb0292e206353c980e4ba8a5861ed054c 100644 (file)
@@ -79,6 +79,7 @@ class BasePool:
         self._nconns = min_size  # currently in the pool, out, being prepared
         self._pool = deque()
         self._stats = Counter[str]()
+        self._drained_at = 0.0
 
         # Min number of connections in the pool in a max_idle unit of time.
         # It is reset periodically by the ShrinkPool scheduled task.
@@ -196,7 +197,8 @@ class BasePool:
 
         Add some randomness to avoid mass reconnection.
         """
-        conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0)
+        conn._created_at = t = monotonic()
+        conn._expire_at = t + self._jitter(self.max_lifetime, -0.05, 0.0)
 
 
 class AttemptWithBackoff:
index a5f062fdaf384bde451e2d3cf6f8a47c58f9357e..245907f67a05450978b99fe3981ea4e6024e6cd0 100644 (file)
@@ -325,6 +325,26 @@ class ConnectionPool(Generic[CT], BasePool):
 
         self._putconn(conn, from_getconn=False)
 
+    def drain(self) -> None:
+        """
+        Remove all the connections from the pool and create new ones.
+
+        If a connection is currently out of the pool it will be closed when
+        returned to the pool and replaced with a new one.
+
+        This method is useful to force a connection re-configuration, for
+        example when the adapters map changes after the pool was created.
+        """
+        with self._lock:
+            conns = list(self._pool)
+            self._pool.clear()
+            self._drained_at = monotonic()
+
+        # Close the connections already in the pool, open new ones.
+        for conn in conns:
+            self._close_connection(conn)
+            self.run_task(AddConnection(self))
+
     def _putconn(self, conn: CT, from_getconn: bool) -> None:
         # Use a worker to perform eventual maintenance work in a separate task
         if self._reset:
@@ -715,7 +735,7 @@ class ConnectionPool(Generic[CT], BasePool):
             return
 
         # Check if the connection is past its best before date
-        if conn._expire_at <= monotonic():
+        if conn._created_at <= self._drained_at or conn._expire_at <= monotonic():
             logger.info("discarding expired connection")
             self._close_connection(conn)
             self.run_task(AddConnection(self))
index 6ea1b3c8850533635d21bfb52be4136250d55d74..254e1367d17eb95c38a7f1c81f5ff97c5818b857 100644 (file)
@@ -363,6 +363,26 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         await self._putconn(conn, from_getconn=False)
 
+    async def drain(self) -> None:
+        """
+        Remove all the connections from the pool and create new ones.
+
+        If a connection is currently out of the pool it will be closed when
+        returned to the pool and replaced with a new one.
+
+        This method is useful to force a connection re-configuration, for
+        example when the adapters map changes after the pool was created.
+        """
+        async with self._lock:
+            conns = list(self._pool)
+            self._pool.clear()
+            self._drained_at = monotonic()
+
+        # Close the connections already in the pool, open new ones.
+        for conn in conns:
+            await self._close_connection(conn)
+            self.run_task(AddConnection(self))
+
     async def _putconn(self, conn: ACT, from_getconn: bool) -> None:
         # Use a worker to perform eventual maintenance work in a separate task
         if self._reset:
@@ -773,7 +793,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                 return
 
         # Check if the connection is past its best before date
-        if conn._expire_at <= monotonic():
+        if conn._created_at <= self._drained_at or conn._expire_at <= monotonic():
             logger.info("discarding expired connection")
             await self._close_connection(conn)
             self.run_task(AddConnection(self))
index 259e7d3ac6dc76ad7e23d0696b2aa5afe1db5f96..1d0732427bb81b9a5c2bc1b3d3e336f3c3d4e00e 100644 (file)
@@ -746,6 +746,40 @@ def test_cancel_on_rollback(pool_cls, dsn, monkeypatch):
             assert cur.fetchone() == (3,)
 
 
+@pytest.mark.crdb_skip("backend pid")
+def test_drain(pool_cls, dsn):
+    pids1 = set()
+    pids2 = set()
+    pids3 = set()
+    with pool_cls(dsn, min_size=min_size(pool_cls, 2)) as p:
+        p.wait()
+
+        with p.connection() as conn:
+            pids1.add(conn.info.backend_pid)
+            with p.connection() as conn2:
+                pids1.add(conn2.info.backend_pid)
+                p.drain()
+        assert len(pids1) == 2
+
+        with p.connection() as conn:
+            pids2.add(conn.info.backend_pid)
+            with p.connection() as conn2:
+                pids2.add(conn2.info.backend_pid)
+
+        assert len(pids2) == 2
+
+        assert not pids1 & pids2
+
+        with p.connection() as conn:
+            pids3.add(conn.info.backend_pid)
+            with p.connection() as conn2:
+                pids3.add(conn2.info.backend_pid)
+
+        assert len(pids3) == 2
+        if pool_cls is not pool.NullConnectionPool:
+            assert pids2 == pids3
+
+
 def min_size(pool_cls, num=1):
     """Return the minimum min_size supported by the pool class."""
     if pool_cls is pool.ConnectionPool:
index 3dc01a238010e30e1bad1c69ffb232224062ebd8..22e56ab19fa312872d30c9a92b773627795e3b5b 100644 (file)
@@ -759,6 +759,40 @@ async def test_cancel_on_rollback(pool_cls, dsn, monkeypatch):
             assert (await cur.fetchone()) == (3,)
 
 
+@pytest.mark.crdb_skip("backend pid")
+async def test_drain(pool_cls, dsn):
+    pids1 = set()
+    pids2 = set()
+    pids3 = set()
+    async with pool_cls(dsn, min_size=min_size(pool_cls, 2)) as p:
+        await p.wait()
+
+        async with p.connection() as conn:
+            pids1.add(conn.info.backend_pid)
+            async with p.connection() as conn2:
+                pids1.add(conn2.info.backend_pid)
+                await p.drain()
+        assert len(pids1) == 2
+
+        async with p.connection() as conn:
+            pids2.add(conn.info.backend_pid)
+            async with p.connection() as conn2:
+                pids2.add(conn2.info.backend_pid)
+
+        assert len(pids2) == 2
+
+        assert not pids1 & pids2
+
+        async with p.connection() as conn:
+            pids3.add(conn.info.backend_pid)
+            async with p.connection() as conn2:
+                pids3.add(conn2.info.backend_pid)
+
+        assert len(pids3) == 2
+        if pool_cls is not pool.AsyncNullConnectionPool:
+            assert pids2 == pids3
+
+
 def min_size(pool_cls, num=1):
     """Return the minimum min_size supported by the pool class."""
     if pool_cls is pool.AsyncConnectionPool: