From: Kanav Kalucha Date: Tue, 25 Feb 2025 18:28:38 +0000 (-0800) Subject: fix(pool): reset connection transaction status after failed check X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=3d464f90f2e310709ca397680da9103e6fa3fe0c;p=thirdparty%2Fpsycopg.git fix(pool): reset connection transaction status after failed check Close #1014 --- diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 9da5890f8..f6a7139f3 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -675,6 +675,7 @@ class ConnectionPool(Generic[CT], BasePool): """ Return a connection to the pool after usage. """ + self._reset_connection(conn) if from_getconn: if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: self._stats[self._CONNECTIONS_LOST] += 1 @@ -682,14 +683,12 @@ class ConnectionPool(Generic[CT], BasePool): self.run_task(AddConnection(self)) logger.info("not serving connection found broken") return - else: - self._reset_connection(conn) - if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: - self._stats[self._RETURNS_BAD] += 1 - # Connection no more in working state: create a new one. - self.run_task(AddConnection(self)) - logger.warning("discarding closed connection: %s", conn) - return + elif conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: + self._stats[self._RETURNS_BAD] += 1 + # Connection no more in working state: create a new one. + self.run_task(AddConnection(self)) + logger.warning("discarding closed connection: %s", conn) + return # Check if the connection is past its best before date if conn._expire_at <= monotonic(): diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index c17f5f542..a8258a4da 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -728,6 +728,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): """ Return a connection to the pool after usage. """ + await self._reset_connection(conn) if from_getconn: if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: self._stats[self._CONNECTIONS_LOST] += 1 @@ -737,7 +738,6 @@ class AsyncConnectionPool(Generic[ACT], BasePool): return else: - await self._reset_connection(conn) if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: self._stats[self._RETURNS_BAD] += 1 # Connection no more in working state: create a new one. diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 3baf19ec9..076a976dc 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -1018,3 +1018,25 @@ def test_check_backoff(dsn, caplog, monkeypatch): for delta in deltas: assert delta == pytest.approx(want, 0.05), deltas want *= 2 + + +@pytest.mark.slow +@pytest.mark.parametrize("status", ["ERROR", "INTRANS"]) +def test_check_returns_an_ok_connection(dsn, status): + + def check(conn): + if status == "ERROR": + conn.execute("wat") + elif status == "INTRANS": + conn.execute("select 1") + 1 / 0 + else: + assert False + + with pool.ConnectionPool(dsn, min_size=1, check=check) as p: + p.wait(1.0) + with pytest.raises(pool.PoolTimeout): + conn = p.getconn(0.5) + + conn = list(p._pool)[0] + assert conn.info.transaction_status == TransactionStatus.IDLE diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 100e067c4..aaf780994 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -1021,3 +1021,24 @@ async def test_check_backoff(dsn, caplog, monkeypatch): for delta in deltas: assert delta == pytest.approx(want, 0.05), deltas want *= 2 + + +@pytest.mark.slow +@pytest.mark.parametrize("status", ["ERROR", "INTRANS"]) +async def test_check_returns_an_ok_connection(dsn, status): + async def check(conn): + if status == "ERROR": + await conn.execute("wat") + elif status == "INTRANS": + await conn.execute("select 1") + 1 / 0 + else: + assert False + + async with pool.AsyncConnectionPool(dsn, min_size=1, check=check) as p: + await p.wait(1.0) + with pytest.raises(pool.PoolTimeout): + conn = await p.getconn(0.5) + + conn = list(p._pool)[0] + assert conn.info.transaction_status == TransactionStatus.IDLE