From 9a7201231b543fad9f732c0125fec031cb50c200 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 8 Mar 2021 03:25:57 +0100 Subject: [PATCH] Add AsyncConnectionPool.check() --- psycopg3/psycopg3/pool/async_pool.py | 20 ++++++++++++++++++++ tests/pool/test_pool_async.py | 21 +++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index ca41c21ff..592e5bc63 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -319,6 +319,26 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): for i in range(ngrow): self.run_task(AddConnection(self)) + async def check(self) -> None: + """Verify the state of the connections currently in the pool. + + Test each connection: if it works return it to the pool, otherwise + dispose of it and create a new one. + """ + async with self._lock: + conns = list(self._pool) + self._pool.clear() + + while conns: + conn = conns.pop() + try: + await conn.execute("select 1") + except Exception: + logger.warning("discarding broken connection: %s", conn) + self.run_task(AddConnection(self)) + else: + await self._add_to_pool(conn) + async def configure(self, conn: AsyncConnection) -> None: """Configure a connection after creation.""" if self._configure: diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 24eb38e67..090530487 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -722,6 +722,27 @@ async def test_max_lifetime(dsn): assert pids[0] == pids[1] != pids[4], pids +async def test_check(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3.pool") + async with pool.AsyncConnectionPool(dsn, minconn=4) as p: + await p.wait_ready(1.0) + async with p.connection() as conn: + pid = conn.pgconn.backend_pid + + await p.wait_ready(1.0) + pids = set(conn.pgconn.backend_pid for conn in p._pool) + assert pid in pids + await conn.close() + + assert len(caplog.records) == 0 + await p.check() + assert len(caplog.records) == 1 + await p.wait_ready(1.0) + pids2 = set(conn.pgconn.backend_pid for conn in p._pool) + assert len(pids & pids2) == 3 + assert pid not in pids2 + + def delay_connection(monkeypatch, sec): """ Return a _connect_gen function delayed by the amount of seconds -- 2.47.2