]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): reset connection transaction status after failed check
authorKanav Kalucha <kanav@openai.com>
Tue, 25 Feb 2025 18:28:38 +0000 (10:28 -0800)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 25 Feb 2025 20:47:31 +0000 (21:47 +0100)
Close #1014

psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 8f479a7216ed8afcd24ff798b1756f4610bace03..e739d57055563c4ccfd422f4f21a7782dc7885c1 100644 (file)
@@ -675,6 +675,7 @@ class ConnectionPool(Generic[CT], BasePool):
         """
         Return a connection to the pool after usage.
         """
+        self._reset_connection(conn)
         if from_getconn:
             if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
                 self._stats[self._CONNECTIONS_LOST] += 1
@@ -682,14 +683,12 @@ class ConnectionPool(Generic[CT], BasePool):
                 self.run_task(AddConnection(self))
                 logger.info("not serving connection found broken")
                 return
-        else:
-            self._reset_connection(conn)
-            if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
-                self._stats[self._RETURNS_BAD] += 1
-                # Connection no more in working state: create a new one.
-                self.run_task(AddConnection(self))
-                logger.warning("discarding closed connection: %s", conn)
-                return
+        elif conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+            self._stats[self._RETURNS_BAD] += 1
+            # Connection no more in working state: create a new one.
+            self.run_task(AddConnection(self))
+            logger.warning("discarding closed connection: %s", conn)
+            return
 
         # Check if the connection is past its best before date
         if conn._expire_at <= monotonic():
index f066a9223759a355ed5a23014427c026c99bb569..1b682a21b70721937757db9d4c9a33b009848ce7 100644 (file)
@@ -727,6 +727,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         """
         Return a connection to the pool after usage.
         """
+        await self._reset_connection(conn)
         if from_getconn:
             if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
                 self._stats[self._CONNECTIONS_LOST] += 1
@@ -736,7 +737,6 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                 return
 
         else:
-            await self._reset_connection(conn)
             if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
                 self._stats[self._RETURNS_BAD] += 1
                 # Connection no more in working state: create a new one.
index 9aeed8f5368b45a00ef11eba3538ee1cb5bfd1bc..fc8782e808a7d30cf9d9c6a5107a46079cc23e58 100644 (file)
@@ -1015,3 +1015,25 @@ def test_check_backoff(dsn, caplog, monkeypatch):
     for delta in deltas:
         assert delta == pytest.approx(want, 0.05), deltas
         want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("status", ["ERROR", "INTRANS"])
+def test_check_returns_an_ok_connection(dsn, status):
+
+    def check(conn):
+        if status == "ERROR":
+            conn.execute("wat")
+        elif status == "INTRANS":
+            conn.execute("select 1")
+            1 / 0
+        else:
+            assert False
+
+    with pool.ConnectionPool(dsn, min_size=1, check=check) as p:
+        p.wait(1.0)
+        with pytest.raises(pool.PoolTimeout):
+            conn = p.getconn(0.5)
+
+        conn = list(p._pool)[0]
+        assert conn.info.transaction_status == TransactionStatus.IDLE
index 525905a158b03f62820ffc6374fdf9e73e838f1d..0c38f0298adb2258c3f4ad7ba76f1f8ab0f4cac8 100644 (file)
@@ -1018,3 +1018,24 @@ async def test_check_backoff(dsn, caplog, monkeypatch):
     for delta in deltas:
         assert delta == pytest.approx(want, 0.05), deltas
         want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("status", ["ERROR", "INTRANS"])
+async def test_check_returns_an_ok_connection(dsn, status):
+    async def check(conn):
+        if status == "ERROR":
+            await conn.execute("wat")
+        elif status == "INTRANS":
+            await conn.execute("select 1")
+            1 / 0
+        else:
+            assert False
+
+    async with pool.AsyncConnectionPool(dsn, min_size=1, check=check) as p:
+        await p.wait(1.0)
+        with pytest.raises(pool.PoolTimeout):
+            conn = await p.getconn(0.5)
+
+        conn = list(p._pool)[0]
+        assert conn.info.transaction_status == TransactionStatus.IDLE