From: Daniele Varrazzo Date: Mon, 8 Mar 2021 03:43:33 +0000 (+0100) Subject: Add reset callback to connection pool X-Git-Tag: 3.0.dev0~87^2~16 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=60b27edf2428184f6195b0d1f71ec86a36b020a0;p=thirdparty%2Fpsycopg.git Add reset callback to connection pool --- diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index ea6a38c93..7ec69ad6e 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -34,6 +34,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): configure: Optional[ Callable[[AsyncConnection], Awaitable[None]] ] = None, + reset: Optional[Callable[[AsyncConnection], Awaitable[None]]] = None, **kwargs: Any, ): # https://bugs.python.org/issue42600 @@ -43,6 +44,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): ) self._configure = configure + self._reset = reset self._lock = asyncio.Lock() self._waiting: Deque["AsyncClient"] = deque() @@ -467,9 +469,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): """ status = conn.pgconn.transaction_status if status == TransactionStatus.IDLE: - return + pass - if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): + elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR): # Connection returned with an active transaction logger.warning("rolling back returned connection: %s", conn) try: @@ -488,6 +490,20 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): logger.warning("closing returned connection: %s", conn) await conn.close() + if not conn.closed and self._reset: + try: + await self._reset(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + nstatus = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {nstatus} by reset function" + f" {self._reset}: discarded" + ) + except Exception as ex: + logger.warning(f"error resetting connection: {ex}") + await conn.close() + async def _shrink_pool(self) -> None: to_close: Optional[AsyncConnection] = None diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index 5f0a8e0a5..65ad83857 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -31,9 +31,11 @@ class ConnectionPool(BasePool[Connection]): self, conninfo: str = "", configure: Optional[Callable[[Connection], None]] = None, + reset: Optional[Callable[[Connection], None]] = None, **kwargs: Any, ): self._configure = configure + self._reset = reset self._lock = threading.RLock() self._waiting: Deque["WaitingClient"] = deque() @@ -503,6 +505,20 @@ class ConnectionPool(BasePool[Connection]): logger.warning("closing returned connection: %s", conn) conn.close() + if not conn.closed and self._reset: + try: + self._reset(conn) + status = conn.pgconn.transaction_status + if status != TransactionStatus.IDLE: + nstatus = TransactionStatus(status).name + raise e.ProgrammingError( + f"connection left in status {nstatus} by reset function" + f" {self._reset}: discarded" + ) + except Exception as ex: + logger.warning(f"error resetting connection: {ex}") + conn.close() + def _shrink_pool(self) -> None: to_close: Optional[Connection] = None diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index e0786fed3..212776d47 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -185,6 +185,76 @@ def test_configure_broken(dsn, caplog): assert "WAT" in caplog.records[0].message +def test_reset(dsn): + resets = 0 + + def setup(conn): + with conn.transaction(): + conn.execute("set timezone to '+1:00'") + + def reset(conn): + nonlocal resets + resets += 1 + with conn.transaction(): + conn.execute("set timezone to utc") + + with pool.ConnectionPool(minconn=1, reset=reset) as p: + with p.connection() as conn: + assert resets == 0 + conn.execute("set timezone to '+2:00'") + + p.wait() + assert resets == 1 + + with p.connection() as conn: + with conn.execute("show timezone") as cur: + assert cur.fetchone() == ("UTC",) + + p.wait() + assert resets == 2 + + +def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3.pool") + + def reset(conn): + conn.execute("reset all") + + with pool.ConnectionPool(minconn=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.pgconn.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.pgconn.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3.pool") + + def reset(conn): + with conn.transaction(): + conn.execute("WAT") + + with pool.ConnectionPool(minconn=1, reset=reset) as p: + with p.connection() as conn: + conn.execute("select 1") + pid1 = conn.pgconn.backend_pid + + with p.connection() as conn: + conn.execute("select 1") + pid2 = conn.pgconn.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + @pytest.mark.slow def test_queue(dsn, retries): def worker(n): diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 829f723ed..d9436c937 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -201,6 +201,76 @@ async def test_configure_broken(dsn, caplog): assert "WAT" in caplog.records[0].message +async def test_reset(dsn): + resets = 0 + + async def setup(conn): + async with conn.transaction(): + await conn.execute("set timezone to '+1:00'") + + async def reset(conn): + nonlocal resets + resets += 1 + async with conn.transaction(): + await conn.execute("set timezone to utc") + + async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p: + async with p.connection() as conn: + assert resets == 0 + await conn.execute("set timezone to '+2:00'") + + await p.wait() + assert resets == 1 + + async with p.connection() as conn: + cur = await conn.execute("show timezone") + assert (await cur.fetchone()) == ("UTC",) + + await p.wait() + assert resets == 2 + + +async def test_reset_badstate(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3.pool") + + async def reset(conn): + await conn.execute("reset all") + + async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.pgconn.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.pgconn.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "INTRANS" in caplog.records[0].message + + +async def test_reset_broken(dsn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3.pool") + + async def reset(conn): + async with conn.transaction(): + await conn.execute("WAT") + + async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p: + async with p.connection() as conn: + await conn.execute("select 1") + pid1 = conn.pgconn.backend_pid + + async with p.connection() as conn: + await conn.execute("select 1") + pid2 = conn.pgconn.backend_pid + + assert pid1 != pid2 + assert caplog.records + assert "WAT" in caplog.records[0].message + + @pytest.mark.slow async def test_queue(dsn, retries): async def worker(n):