]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add AsyncConnectionPool.check()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Mar 2021 02:25:57 +0000 (03:25 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/async_pool.py
tests/pool/test_pool_async.py

index ca41c21ff06c6e9cce8fa7897db56339a696fce6..592e5bc6314d40f316d96a14e984a279425c2d44 100644 (file)
@@ -319,6 +319,26 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         for i in range(ngrow):
             self.run_task(AddConnection(self))
 
+    async def check(self) -> None:
+        """Verify the state of the connections currently in the pool.
+
+        Test each connection: if it works return it to the pool, otherwise
+        dispose of it and create a new one.
+        """
+        async with self._lock:
+            conns = list(self._pool)
+            self._pool.clear()
+
+        while conns:
+            conn = conns.pop()
+            try:
+                await conn.execute("select 1")
+            except Exception:
+                logger.warning("discarding broken connection: %s", conn)
+                self.run_task(AddConnection(self))
+            else:
+                await self._add_to_pool(conn)
+
     async def configure(self, conn: AsyncConnection) -> None:
         """Configure a connection after creation."""
         if self._configure:
index 24eb38e67b8ee679e3c5e212e05c0076e46397e9..090530487cd298813a6a29d79a497e40fe1e8af7 100644 (file)
@@ -722,6 +722,27 @@ async def test_max_lifetime(dsn):
     assert pids[0] == pids[1] != pids[4], pids
 
 
+async def test_check(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    async with pool.AsyncConnectionPool(dsn, minconn=4) as p:
+        await p.wait_ready(1.0)
+        async with p.connection() as conn:
+            pid = conn.pgconn.backend_pid
+
+        await p.wait_ready(1.0)
+        pids = set(conn.pgconn.backend_pid for conn in p._pool)
+        assert pid in pids
+        await conn.close()
+
+        assert len(caplog.records) == 0
+        await p.check()
+        assert len(caplog.records) == 1
+        await p.wait_ready(1.0)
+        pids2 = set(conn.pgconn.backend_pid for conn in p._pool)
+        assert len(pids & pids2) == 3
+        assert pid not in pids2
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds