]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(pool): add close_returns
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 4 May 2025 19:21:19 +0000 (21:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 19 Oct 2025 01:32:08 +0000 (03:32 +0200)
Behaviour implemented via subclassing Psycopg 3.2.

Close #1046

12 files changed:
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pool/test_pool_null.py
tests/pool/test_pool_null_async.py
tools/async_to_sync.py

index b38326d832cd76254c3a286e70db71f01447a267..0d74066f329d5fb2414774946aa50cee6587d5d9 100644 (file)
@@ -183,6 +183,12 @@ class Connection(BaseConnection[Row]):
         """Close the database connection."""
         if self.closed:
             return
+
+        pool = getattr(self, "_pool", None)
+        if pool and getattr(pool, "close_returns", False):
+            pool.putconn(self)
+            return
+
         self._closed = True
 
         # TODO: maybe send a cancel on close, if the connection is ACTIVE?
index a12527ab208ebca3dfaeb42331b67cdb2373f44b..0a10ae24a183f58f7fdf167e1b0b1e71b5268942 100644 (file)
@@ -204,6 +204,12 @@ class AsyncConnection(BaseConnection[Row]):
         """Close the database connection."""
         if self.closed:
             return
+
+        pool = getattr(self, "_pool", None)
+        if pool and getattr(pool, "close_returns", False):
+            await pool.putconn(self)
+            return
+
         self._closed = True
 
         # TODO: maybe send a cancel on close, if the connection is ACTIVE?
index e708a4c6034af783e91a41a6bd5a74be7b08baa5..ad109a0d4f1ff1f8be1b4a84f540f10ad2490ef7 100644 (file)
@@ -50,6 +50,7 @@ class BasePool:
         min_size: int,
         max_size: int | None,
         name: str | None,
+        close_returns: bool,
         timeout: float,
         max_waiting: int,
         max_lifetime: float,
@@ -69,6 +70,7 @@ class BasePool:
         self.conninfo = conninfo
         self.kwargs: dict[str, Any] = kwargs or {}
         self.name = name
+        self.close_returns = close_returns
         self._min_size = min_size
         self._max_size = max_size
         self.timeout = timeout
index 884eab850997cc1a615f598af2a1be77d73a98ab..bf027e9dd4783204db1ebf0a54d9e41c27084a39 100644 (file)
@@ -40,6 +40,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
         check: ConnectionCB[CT] | None = None,
         reset: ConnectionCB[CT] | None = None,
         name: str | None = None,
+        close_returns: bool = False,
         timeout: float = 30.0,
         max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
@@ -49,6 +50,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
         num_workers: int = 3,
     ):  # Note: min_size default value changed to 0.
 
+        # close_returns=True makes no sense
         super().__init__(
             conninfo,
             open=open,
@@ -60,6 +62,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
             min_size=min_size,
             max_size=max_size,
             name=name,
+            close_returns=False,
             timeout=timeout,
             max_waiting=max_waiting,
             max_lifetime=max_lifetime,
index a037597fa045e2fecb89ef446b7a250570f279aa..a18c59c3fa8867cc1494c499556b9f350c6111c3 100644 (file)
@@ -36,6 +36,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT])
         check: AsyncConnectionCB[ACT] | None = None,
         reset: AsyncConnectionCB[ACT] | None = None,
         name: str | None = None,
+        close_returns: bool = False,
         timeout: float = 30.0,
         max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
@@ -55,6 +56,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT])
             min_size=min_size,
             max_size=max_size,
             name=name,
+            close_returns=False,  # close_returns=True makes no sense
             timeout=timeout,
             max_waiting=max_waiting,
             max_lifetime=max_lifetime,
index 8626c6c0989600621abeb117305e7972a7f4d2e3..d5fd6eb0f55f5cb3f24fa6983438a60b312a0dae 100644 (file)
@@ -28,7 +28,7 @@ from .abc import CT, ConnectFailedCB, ConnectionCB
 from .base import AttemptWithBackoff, BasePool
 from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
-from ._compat import Self
+from ._compat import PSYCOPG_VERSION, PoolConnection, Self
 from ._acompat import Condition, Event, Lock, Queue, Worker, current_thread_name
 from ._acompat import gather, sleep, spawn
 
@@ -51,6 +51,7 @@ class ConnectionPool(Generic[CT], BasePool):
         check: ConnectionCB[CT] | None = None,
         reset: ConnectionCB[CT] | None = None,
         name: str | None = None,
+        close_returns: bool = False,
         timeout: float = 30.0,
         max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
@@ -59,6 +60,14 @@ class ConnectionPool(Generic[CT], BasePool):
         reconnect_failed: ConnectFailedCB | None = None,
         num_workers: int = 3,
     ):
+        if close_returns and PSYCOPG_VERSION < (3, 3):
+            if connection_class is Connection:
+                connection_class = cast(type[CT], PoolConnection)
+            else:
+                raise TypeError(
+                    "Using 'close_returns=True' and a non-standard 'connection_class' requires psycopg 3.3 or newer."
+                )
+
         self.connection_class = connection_class
         self._check = check
         self._configure = configure
@@ -86,6 +95,7 @@ class ConnectionPool(Generic[CT], BasePool):
             min_size=min_size,
             max_size=max_size,
             name=name,
+            close_returns=close_returns,
             timeout=timeout,
             max_waiting=max_waiting,
             max_lifetime=max_lifetime,
index 5f882d4eea436d06be8c8ed35587c268fea550d2..ac4dfe83c7801bece5f77f67906e8bfc44840495 100644 (file)
@@ -24,7 +24,7 @@ from psycopg.pq import TransactionStatus
 from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB
 from .base import AttemptWithBackoff, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
-from ._compat import Self
+from ._compat import PSYCOPG_VERSION, AsyncPoolConnection, Self
 from ._acompat import ACondition, AEvent, ALock, AQueue, AWorker, agather, asleep
 from ._acompat import aspawn, current_task_name, ensure_async
 from .sched_async import AsyncScheduler
@@ -51,6 +51,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         check: AsyncConnectionCB[ACT] | None = None,
         reset: AsyncConnectionCB[ACT] | None = None,
         name: str | None = None,
+        close_returns: bool = False,
         timeout: float = 30.0,
         max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
@@ -59,6 +60,15 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         reconnect_failed: AsyncConnectFailedCB | None = None,
         num_workers: int = 3,
     ):
+        if close_returns and PSYCOPG_VERSION < (3, 3):
+            if connection_class is AsyncConnection:
+                connection_class = cast(type[ACT], AsyncPoolConnection)
+            else:
+                raise TypeError(
+                    "Using 'close_returns=True' and a non-standard 'connection_class'"
+                    " requires psycopg 3.3 or newer."
+                )
+
         self.connection_class = connection_class
         self._check = check
         self._configure = configure
@@ -86,6 +96,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             min_size=min_size,
             max_size=max_size,
             name=name,
+            close_returns=close_returns,
             timeout=timeout,
             max_waiting=max_waiting,
             max_lifetime=max_lifetime,
index 68064a014ab60471c3c9152d7eec10ee42eebf3b..df90c1c3e7c08d11c67dc5dfc4ffaa4e50431c3a 100644 (file)
@@ -26,6 +26,9 @@ except ImportError:
     pass
 
 
+PSYCOPG_VERSION = tuple(map(int, psycopg.__version__.split(".", 2)[:2]))
+
+
 def test_default_sizes(dsn):
     with pool.ConnectionPool(dsn) as p:
         assert p.min_size == p.max_size == 4
@@ -1074,3 +1077,47 @@ def test_override_close(dsn):
         assert len(p._pool) == 2
 
     assert conn.closed
+
+
+def test_close_returns(dsn):
+    with pool.ConnectionPool(dsn, min_size=2, close_returns=True) as p:
+        p.wait()
+        assert len(p._pool) == 2
+        conn = p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION < (3, 3), reason="psycopg >= 3.3 behaviour")
+def test_close_returns_custom_class(dsn):
+
+    class MyConnection(psycopg.Connection):
+        pass
+
+    with pool.ConnectionPool(
+        dsn, min_size=2, connection_class=MyConnection, close_returns=True
+    ) as p:
+        p.wait()
+        conn = p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION >= (3, 3), reason="psycopg < 3.3 behaviour")
+def test_close_returns_custom_class_old(dsn):
+
+    class MyConnection(psycopg.Connection):
+        pass
+
+    with pytest.raises(TypeError, match="close_returns=True"):
+        pool.ConnectionPool(dsn, connection_class=MyConnection, close_returns=True)
index d4a2d277af771c56036c23e903596d1bca57bbfd..6cc20e67217aacade168fd311b022987514ec474 100644 (file)
@@ -25,6 +25,8 @@ except ImportError:
 if True:  # ASYNC
     pytestmark = [pytest.mark.anyio]
 
+PSYCOPG_VERSION = tuple(map(int, psycopg.__version__.split(".", 2)[:2]))
+
 
 async def test_default_sizes(dsn):
     async with pool.AsyncConnectionPool(dsn) as p:
@@ -1077,3 +1079,46 @@ async def test_override_close(dsn):
         assert len(p._pool) == 2
 
     assert conn.closed
+
+
+async def test_close_returns(dsn):
+
+    async with pool.AsyncConnectionPool(dsn, min_size=2, close_returns=True) as p:
+        await p.wait()
+        assert len(p._pool) == 2
+        conn = await p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        await conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION < (3, 3), reason="psycopg >= 3.3 behaviour")
+async def test_close_returns_custom_class(dsn):
+    class MyConnection(psycopg.AsyncConnection):
+        pass
+
+    async with pool.AsyncConnectionPool(
+        dsn, min_size=2, connection_class=MyConnection, close_returns=True
+    ) as p:
+        await p.wait()
+        conn = await p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        await conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION >= (3, 3), reason="psycopg < 3.3 behaviour")
+async def test_close_returns_custom_class_old(dsn):
+    class MyConnection(psycopg.AsyncConnection):
+        pass
+
+    with pytest.raises(TypeError, match="close_returns=True"):
+        pool.AsyncConnectionPool(dsn, connection_class=MyConnection, close_returns=True)
index ac0ab05f861526c9f9e4de39475fb22c825834b1..332dfd2de4ff1984cd16ed92911481d8e36a63bd 100644 (file)
@@ -497,3 +497,12 @@ def test_cancellation_in_queue(dsn):
         with p.connection() as conn:
             cur = conn.execute("select 1")
             assert cur.fetchone() == (1,)
+
+
+def test_close_returns(dsn):
+    # Mostly test the interface; close is close even if it goes via putconn().
+    with pool.NullConnectionPool(dsn, close_returns=True) as p:
+        conn = p.getconn()
+        assert not conn.closed
+        conn.close()
+        assert conn.closed
index 0a7fbc7c405a88a0d3a51a3782b3d25e37415527..0e71112249c413bcbd11196afd7a69ad07c87a0f 100644 (file)
@@ -499,3 +499,12 @@ async def test_cancellation_in_queue(dsn):
         async with p.connection() as conn:
             cur = await conn.execute("select 1")
             assert await cur.fetchone() == (1,)
+
+
+async def test_close_returns(dsn):
+    # Mostly test the interface; close is close even if it goes via putconn().
+    async with pool.AsyncNullConnectionPool(dsn, close_returns=True) as p:
+        conn = await p.getconn()
+        assert not conn.closed
+        await conn.close()
+        assert conn.closed
index 7a78b2551c1a9c2f17357c7fdc64bb021ba9dd46..7aee1a81e40873d85cf8d683bf9d539648eb7030 100755 (executable)
@@ -300,6 +300,7 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
         "AsyncLibpqWriter": "LibpqWriter",
         "AsyncNullConnectionPool": "NullConnectionPool",
         "AsyncPipeline": "Pipeline",
+        "AsyncPoolConnection": "PoolConnection",
         "AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
         "AsyncRawCursor": "RawCursor",
         "AsyncRawServerCursor": "RawServerCursor",