]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(pool): add 'drain()' method pool-drain
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 11 Jul 2025 20:02:11 +0000 (22:02 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 11 Jul 2025 20:02:11 +0000 (22:02 +0200)
This changeset introduces a mapping from connection ids to connections
for the connections currently out of the pool. It might be useful for
further refactoring.

psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py

index e708a4c6034af783e91a41a6bd5a74be7b08baa5..fb222e44a930a598decba2a68766f4ef79088629 100644 (file)
@@ -41,6 +41,7 @@ class BasePool:
     _CONNECTIONS_LOST = "connections_lost"
 
     _pool: deque[Any]
+    _given: dict[int, Any]
 
     def __init__(
         self,
@@ -80,6 +81,7 @@ class BasePool:
 
         self._nconns = min_size  # currently in the pool, out, being prepared
         self._pool = deque()
+        self._given = {}
         self._stats = Counter[str]()
 
         # Min number of connections in the pool in a max_idle unit of time.
index 8626c6c0989600621abeb117305e7972a7f4d2e3..a1ccf04c2d186a9bc2dd989280a00823691f5b43 100644 (file)
@@ -37,6 +37,7 @@ logger = logging.getLogger("psycopg.pool")
 
 class ConnectionPool(Generic[CT], BasePool):
     _pool: deque[CT]
+    _given: dict[int, CT]
 
     def __init__(
         self,
@@ -271,6 +272,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if self._pool:
             # Take a connection ready out of the pool
             conn = self._pool.popleft()
+            self._given[id(conn)] = conn
             if len(self._pool) < self._nconns_min:
                 self._nconns_min = len(self._pool)
         elif self.max_waiting and len(self._waiting) >= self.max_waiting:
@@ -314,6 +316,31 @@ class ConnectionPool(Generic[CT], BasePool):
 
         self._putconn(conn, from_getconn=False)
 
+    def drain(self) -> None:
+        """
+        Remove all the connections from the pool and create new ones.
+
+        If a connection is currently out of the pool it will be closed when
+        returned to the pool and replaced with a new one.
+
+        This method is useful to force a connection re-configuration, for
+        example when the adapters map changes after the pool was created.
+        """
+        with self._lock:
+            conns = list(self._pool)
+            self._pool.clear()
+
+            # Mark the currently given connections as already expired,
+            # so they will be closed as soon as returned.
+            earlier = monotonic() - 1.0
+            for conn in self._given.values():
+                conn._expire_at = earlier
+
+        # Close the connection already in the pool, open new ones.
+        for conn in conns:
+            conn.close()
+            self.run_task(AddConnection(self))
+
     def _putconn(self, conn: CT, from_getconn: bool) -> None:
         # Use a worker to perform eventual maintenance work in a separate task
         if self._reset:
@@ -736,6 +763,7 @@ class ConnectionPool(Generic[CT], BasePool):
                     break
             else:
                 # No client waiting for a connection: put it back into the pool
+                self._given.pop(id(conn), None)
                 self._pool.append(conn)
                 # If we have been asked to wait for pool init, notify the
                 # waiter if the pool is full.
index 5f882d4eea436d06be8c8ed35587c268fea550d2..11ed143fb6df5f71b193d5670dea763bdf9a5cc6 100644 (file)
@@ -37,6 +37,7 @@ logger = logging.getLogger("psycopg.pool")
 
 class AsyncConnectionPool(Generic[ACT], BasePool):
     _pool: deque[ACT]
+    _given: dict[int, ACT]
 
     def __init__(
         self,
@@ -299,6 +300,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if self._pool:
             # Take a connection ready out of the pool
             conn = self._pool.popleft()
+            self._given[id(conn)] = conn
             if len(self._pool) < self._nconns_min:
                 self._nconns_min = len(self._pool)
         elif self.max_waiting and len(self._waiting) >= self.max_waiting:
@@ -343,6 +345,31 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         await self._putconn(conn, from_getconn=False)
 
+    async def drain(self) -> None:
+        """
+        Remove all the connections from the pool and create new ones.
+
+        If a connection is currently out of the pool it will be closed when
+        returned to the pool and replaced with a new one.
+
+        This method is useful to force a connection re-configuration, for
+        example when the adapters map changes after the pool was created.
+        """
+        async with self._lock:
+            conns = list(self._pool)
+            self._pool.clear()
+
+            # Mark the currently given connections as already expired,
+            # so they will be closed as soon as returned.
+            earlier = monotonic() - 1.0
+            for conn in self._given.values():
+                conn._expire_at = earlier
+
+        # Close the connection already in the pool, open new ones.
+        for conn in conns:
+            await conn.close()
+            self.run_task(AddConnection(self))
+
     async def _putconn(self, conn: ACT, from_getconn: bool) -> None:
         # Use a worker to perform eventual maintenance work in a separate task
         if self._reset:
@@ -786,6 +813,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                     break
             else:
                 # No client waiting for a connection: put it back into the pool
+                self._given.pop(id(conn), None)
                 self._pool.append(conn)
                 # If we have been asked to wait for pool init, notify the
                 # waiter if the pool is full.
index 3c21c098009e5a3c2cd698588d23871768cad2a9..ddfbb8162ac4d5edfe783327a21ef55eea48cce6 100644 (file)
@@ -638,6 +638,30 @@ def test_check_timeout(pool_cls, dsn):
     assert time() - t0 <= 1.5
 
 
+@pytest.mark.crdb_skip("backend pid")
+def test_drain(pool_cls, dsn):
+    pids1 = set()
+    pids2 = set()
+    with pool_cls(dsn, min_size=min_size(pool_cls, 2)) as p:
+        p.wait()
+
+        with p.connection() as conn:
+            pids1.add(conn.info.backend_pid)
+            with p.connection() as conn2:
+                pids1.add(conn2.info.backend_pid)
+                p.drain()
+        assert len(pids1) == 2
+
+        with p.connection() as conn:
+            pids2.add(conn.info.backend_pid)
+            with p.connection() as conn2:
+                pids2.add(conn2.info.backend_pid)
+
+        assert len(pids2) == 2
+
+    assert not pids1 & pids2
+
+
 @skip_sync
 def test_cancellation_in_queue(pool_cls, dsn):
     # https://github.com/psycopg/psycopg/issues/509
index 068064d684e9bb5bd012bffc65dd5834bd768df4..872054830a8970f4ebc6ebe51bd0a545ee1bac3d 100644 (file)
@@ -648,6 +648,30 @@ async def test_check_timeout(pool_cls, dsn):
     assert time() - t0 <= 1.5
 
 
+@pytest.mark.crdb_skip("backend pid")
+async def test_drain(pool_cls, dsn):
+    pids1 = set()
+    pids2 = set()
+    async with pool_cls(dsn, min_size=min_size(pool_cls, 2)) as p:
+        await p.wait()
+
+        async with p.connection() as conn:
+            pids1.add(conn.info.backend_pid)
+            async with p.connection() as conn2:
+                pids1.add(conn2.info.backend_pid)
+                await p.drain()
+        assert len(pids1) == 2
+
+        async with p.connection() as conn:
+            pids2.add(conn.info.backend_pid)
+            async with p.connection() as conn2:
+                pids2.add(conn2.info.backend_pid)
+
+        assert len(pids2) == 2
+
+    assert not pids1 & pids2
+
+
 @skip_sync
 async def test_cancellation_in_queue(pool_cls, dsn):
     # https://github.com/psycopg/psycopg/issues/509