]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): fix infinite loop with close_returns=True sqlalchemy-nullpool 1067/head
authorbash000000 <m2588953@outlook.com>
Sat, 26 Jul 2025 13:44:49 +0000 (21:44 +0800)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 6 Aug 2025 13:32:13 +0000 (15:32 +0200)
Close #1124

psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index e67a32f249b2a01b75cbdf34bad60392ef6f58bb..d8411d070ea93b51e5fa5c7b528092f84d86f336 100644 (file)
@@ -449,6 +449,7 @@ class ConnectionPool(Generic[CT], BasePool):
 
         # Close the connections that were still in the pool
         for conn in connections:
+            conn._pool = None
             conn.close()
 
         # Signal to eventual clients in the queue that business is closed.
@@ -521,6 +522,7 @@ class ConnectionPool(Generic[CT], BasePool):
             # Check for expired connections
             if conn._expire_at <= monotonic():
                 logger.info("discarding expired connection %s", conn)
+                conn._pool = None
                 conn.close()
                 self.run_task(AddConnection(self))
                 continue
@@ -700,6 +702,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if conn._expire_at <= monotonic():
             self.run_task(AddConnection(self))
             logger.info("discarding expired connection")
+            conn._pool = None
             conn.close()
             return
 
@@ -773,10 +776,12 @@ class ConnectionPool(Generic[CT], BasePool):
                     ex,
                     conn,
                 )
+                conn._pool = None
                 conn.close()
         elif status == TransactionStatus.ACTIVE:
             # Connection returned during an operation. Bad... just close it.
             logger.warning("closing returned connection: %s", conn)
+            conn._pool = None
             conn.close()
 
         if self._reset:
@@ -789,6 +794,7 @@ class ConnectionPool(Generic[CT], BasePool):
                     )
             except Exception as ex:
                 logger.warning(f"error resetting connection: {ex}")
+                conn._pool = None
                 conn.close()
 
     def _shrink_pool(self) -> None:
@@ -813,6 +819,7 @@ class ConnectionPool(Generic[CT], BasePool):
                 nconns_min,
                 self.max_idle,
             )
+            to_close._pool = None
             to_close.close()
 
     def _get_measures(self) -> dict[str, int]:
index 5731de9af1cd17a998881b1f63d8e3befbf3bc68..99e6b00bb684c55082c009453ebd8cd4cbc9d2df 100644 (file)
@@ -488,6 +488,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         # Close the connections that were still in the pool
         for conn in connections:
+            conn._pool = None
             await conn.close()
 
         # Signal to eventual clients in the queue that business is closed.
@@ -560,6 +561,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             # Check for expired connections
             if conn._expire_at <= monotonic():
                 logger.info("discarding expired connection %s", conn)
+                conn._pool = None
                 await conn.close()
                 self.run_task(AddConnection(self))
                 continue
@@ -752,6 +754,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if conn._expire_at <= monotonic():
             self.run_task(AddConnection(self))
             logger.info("discarding expired connection")
+            conn._pool = None
             await conn.close()
             return
 
@@ -827,11 +830,13 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                     ex,
                     conn,
                 )
+                conn._pool = None
                 await conn.close()
 
         elif status == TransactionStatus.ACTIVE:
             # Connection returned during an operation. Bad... just close it.
             logger.warning("closing returned connection: %s", conn)
+            conn._pool = None
             await conn.close()
 
         if self._reset:
@@ -845,6 +850,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                     )
             except Exception as ex:
                 logger.warning(f"error resetting connection: {ex}")
+                conn._pool = None
                 await conn.close()
 
     async def _shrink_pool(self) -> None:
@@ -870,6 +876,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                 nconns_min,
                 self.max_idle,
             )
+            to_close._pool = None
             await to_close.close()
 
     def _get_measures(self) -> dict[str, int]:
index d02695fc82193e7ea088dcd8c22e4caf593f4360..4c40da32407d86351fede82ddfde2a757f03b3f1 100644 (file)
@@ -1121,3 +1121,23 @@ def test_close_returns_custom_class_old(dsn):
 
     with pytest.raises(TypeError, match="close_returns=True"):
         pool.ConnectionPool(dsn, connection_class=MyConnection, close_returns=True)
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION < (3, 3), reason="psycopg >= 3.3 behaviour")
+def test_close_returns_no_loop(dsn):
+    with pool.ConnectionPool(
+        dsn, min_size=1, close_returns=True, max_lifetime=0.05
+    ) as p:
+        conn = p.getconn()
+        sleep(0.1)
+        assert len(p._pool) == 0
+        sleep(0.1)  # wait for the connection to expire
+        conn.close()
+        sleep(0.1)
+        assert len(p._pool) == 1
+        conn = p.getconn()
+        sleep(0.1)
+        assert len(p._pool) == 0
+        conn.close()
+        sleep(0.1)
+        assert len(p._pool) == 1
index 44e6397fe26054217e6f3719345a3a429373fe2b..2166e28e762d93c9f5f7d8051ee6d151efdd21e2 100644 (file)
@@ -1122,3 +1122,23 @@ async def test_close_returns_custom_class_old(dsn):
 
     with pytest.raises(TypeError, match="close_returns=True"):
         pool.AsyncConnectionPool(dsn, connection_class=MyConnection, close_returns=True)
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION < (3, 3), reason="psycopg >= 3.3 behaviour")
+async def test_close_returns_no_loop(dsn):
+    async with pool.AsyncConnectionPool(
+        dsn, min_size=1, close_returns=True, max_lifetime=0.05
+    ) as p:
+        conn = await p.getconn()
+        await asleep(0.1)
+        assert len(p._pool) == 0
+        await asleep(0.1)  # wait for the connection to expire
+        await conn.close()
+        await asleep(0.1)
+        assert len(p._pool) == 1
+        conn = await p.getconn()
+        await asleep(0.1)
+        assert len(p._pool) == 0
+        await conn.close()
+        await asleep(0.1)
+        assert len(p._pool) == 1