]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): reset connection transaction status after failed check
authorKanav Kalucha <kanav@openai.com>
Tue, 25 Feb 2025 18:28:38 +0000 (10:28 -0800)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 25 Feb 2025 20:36:25 +0000 (21:36 +0100)
Close #1014

psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 9da5890f86eb8d4a322db1ec0838e22d03137ff3..f6a7139f347517716818f468fcf8dae853f4f805 100644 (file)
@@ -675,6 +675,7 @@ class ConnectionPool(Generic[CT], BasePool):
         """
         Return a connection to the pool after usage.
         """
+        self._reset_connection(conn)
         if from_getconn:
             if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
                 self._stats[self._CONNECTIONS_LOST] += 1
@@ -682,14 +683,12 @@ class ConnectionPool(Generic[CT], BasePool):
                 self.run_task(AddConnection(self))
                 logger.info("not serving connection found broken")
                 return
-        else:
-            self._reset_connection(conn)
-            if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
-                self._stats[self._RETURNS_BAD] += 1
-                # Connection no more in working state: create a new one.
-                self.run_task(AddConnection(self))
-                logger.warning("discarding closed connection: %s", conn)
-                return
+        elif conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+            self._stats[self._RETURNS_BAD] += 1
+            # Connection no more in working state: create a new one.
+            self.run_task(AddConnection(self))
+            logger.warning("discarding closed connection: %s", conn)
+            return
 
         # Check if the connection is past its best before date
         if conn._expire_at <= monotonic():
index c17f5f54264d83093fd32e4dd8fa9e547a5037e4..a8258a4da8524029082424889a3860d284dbabfd 100644 (file)
@@ -728,6 +728,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         """
         Return a connection to the pool after usage.
         """
+        await self._reset_connection(conn)
         if from_getconn:
             if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
                 self._stats[self._CONNECTIONS_LOST] += 1
@@ -737,7 +738,6 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                 return
 
         else:
-            await self._reset_connection(conn)
             if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
                 self._stats[self._RETURNS_BAD] += 1
                 # Connection no more in working state: create a new one.
index 3baf19ec94b1fec1633ea8312c4897f3e0ee8296..076a976dc137a28bbc71e44c698cdefd5eaaaeb4 100644 (file)
@@ -1018,3 +1018,25 @@ def test_check_backoff(dsn, caplog, monkeypatch):
     for delta in deltas:
         assert delta == pytest.approx(want, 0.05), deltas
         want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("status", ["ERROR", "INTRANS"])
+def test_check_returns_an_ok_connection(dsn, status):
+
+    def check(conn):
+        if status == "ERROR":
+            conn.execute("wat")
+        elif status == "INTRANS":
+            conn.execute("select 1")
+            1 / 0
+        else:
+            assert False
+
+    with pool.ConnectionPool(dsn, min_size=1, check=check) as p:
+        p.wait(1.0)
+        with pytest.raises(pool.PoolTimeout):
+            conn = p.getconn(0.5)
+
+        conn = list(p._pool)[0]
+        assert conn.info.transaction_status == TransactionStatus.IDLE
index 100e067c43d6d2cc60f6e93e06bbe71039d3f339..aaf780994b805c79b94f17c5f608de0b5b3f9969 100644 (file)
@@ -1021,3 +1021,24 @@ async def test_check_backoff(dsn, caplog, monkeypatch):
     for delta in deltas:
         assert delta == pytest.approx(want, 0.05), deltas
         want *= 2
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("status", ["ERROR", "INTRANS"])
+async def test_check_returns_an_ok_connection(dsn, status):
+    async def check(conn):
+        if status == "ERROR":
+            await conn.execute("wat")
+        elif status == "INTRANS":
+            await conn.execute("select 1")
+            1 / 0
+        else:
+            assert False
+
+    async with pool.AsyncConnectionPool(dsn, min_size=1, check=check) as p:
+        await p.wait(1.0)
+        with pytest.raises(pool.PoolTimeout):
+            conn = await p.getconn(0.5)
+
+        conn = list(p._pool)[0]
+        assert conn.info.transaction_status == TransactionStatus.IDLE