]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(pool): add close_returns
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 4 May 2025 19:21:19 +0000 (21:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 6 Aug 2025 13:13:12 +0000 (15:13 +0200)
Behaviour implemented via subclassing Psycopg 3.2.

Close #1046

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tools/async_to_sync.py

index bb39c54608a3005d7a92fdb4f85a5e118d608e7d..66672be30baa158f2ca494b8b0b13a9178e25484 100644 (file)
@@ -170,6 +170,12 @@ class Connection(BaseConnection[Row]):
         """Close the database connection."""
         if self.closed:
             return
+
+        pool = getattr(self, "_pool", None)
+        if pool and getattr(pool, "close_returns", False):
+            pool.putconn(self)
+            return
+
         self._closed = True
 
         # TODO: maybe send a cancel on close, if the connection is ACTIVE?
index 8755648a5dd2acdc94fd2b69e26928ec1ac317c7..06898905e1dd20b3dbc3ff42674d2e3a4783b6a6 100644 (file)
@@ -186,6 +186,12 @@ class AsyncConnection(BaseConnection[Row]):
         """Close the database connection."""
         if self.closed:
             return
+
+        pool = getattr(self, "_pool", None)
+        if pool and getattr(pool, "close_returns", False):
+            await pool.putconn(self)
+            return
+
         self._closed = True
 
         # TODO: maybe send a cancel on close, if the connection is ACTIVE?
index e708a4c6034af783e91a41a6bd5a74be7b08baa5..ad109a0d4f1ff1f8be1b4a84f540f10ad2490ef7 100644 (file)
@@ -50,6 +50,7 @@ class BasePool:
         min_size: int,
         max_size: int | None,
         name: str | None,
+        close_returns: bool,
         timeout: float,
         max_waiting: int,
         max_lifetime: float,
@@ -69,6 +70,7 @@ class BasePool:
         self.conninfo = conninfo
         self.kwargs: dict[str, Any] = kwargs or {}
         self.name = name
+        self.close_returns = close_returns
         self._min_size = min_size
         self._max_size = max_size
         self.timeout = timeout
index 8626c6c0989600621abeb117305e7972a7f4d2e3..d5fd6eb0f55f5cb3f24fa6983438a60b312a0dae 100644 (file)
@@ -28,7 +28,7 @@ from .abc import CT, ConnectFailedCB, ConnectionCB
 from .base import AttemptWithBackoff, BasePool
 from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
-from ._compat import Self
+from ._compat import PSYCOPG_VERSION, PoolConnection, Self
 from ._acompat import Condition, Event, Lock, Queue, Worker, current_thread_name
 from ._acompat import gather, sleep, spawn
 
@@ -51,6 +51,7 @@ class ConnectionPool(Generic[CT], BasePool):
         check: ConnectionCB[CT] | None = None,
         reset: ConnectionCB[CT] | None = None,
         name: str | None = None,
+        close_returns: bool = False,
         timeout: float = 30.0,
         max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
@@ -59,6 +60,14 @@ class ConnectionPool(Generic[CT], BasePool):
         reconnect_failed: ConnectFailedCB | None = None,
         num_workers: int = 3,
     ):
+        if close_returns and PSYCOPG_VERSION < (3, 3):
+            if connection_class is Connection:
+                connection_class = cast(type[CT], PoolConnection)
+            else:
+                raise TypeError(
+                    "Using 'close_returns=True' and a non-standard 'connection_class' requires psycopg 3.3 or newer."
+                )
+
         self.connection_class = connection_class
         self._check = check
         self._configure = configure
@@ -86,6 +95,7 @@ class ConnectionPool(Generic[CT], BasePool):
             min_size=min_size,
             max_size=max_size,
             name=name,
+            close_returns=close_returns,
             timeout=timeout,
             max_waiting=max_waiting,
             max_lifetime=max_lifetime,
index 5f882d4eea436d06be8c8ed35587c268fea550d2..ac4dfe83c7801bece5f77f67906e8bfc44840495 100644 (file)
@@ -24,7 +24,7 @@ from psycopg.pq import TransactionStatus
 from .abc import ACT, AsyncConnectFailedCB, AsyncConnectionCB
 from .base import AttemptWithBackoff, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
-from ._compat import Self
+from ._compat import PSYCOPG_VERSION, AsyncPoolConnection, Self
 from ._acompat import ACondition, AEvent, ALock, AQueue, AWorker, agather, asleep
 from ._acompat import aspawn, current_task_name, ensure_async
 from .sched_async import AsyncScheduler
@@ -51,6 +51,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         check: AsyncConnectionCB[ACT] | None = None,
         reset: AsyncConnectionCB[ACT] | None = None,
         name: str | None = None,
+        close_returns: bool = False,
         timeout: float = 30.0,
         max_waiting: int = 0,
         max_lifetime: float = 60 * 60.0,
@@ -59,6 +60,15 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         reconnect_failed: AsyncConnectFailedCB | None = None,
         num_workers: int = 3,
     ):
+        if close_returns and PSYCOPG_VERSION < (3, 3):
+            if connection_class is AsyncConnection:
+                connection_class = cast(type[ACT], AsyncPoolConnection)
+            else:
+                raise TypeError(
+                    "Using 'close_returns=True' and a non-standard 'connection_class'"
+                    " requires psycopg 3.3 or newer."
+                )
+
         self.connection_class = connection_class
         self._check = check
         self._configure = configure
@@ -86,6 +96,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             min_size=min_size,
             max_size=max_size,
             name=name,
+            close_returns=close_returns,
             timeout=timeout,
             max_waiting=max_waiting,
             max_lifetime=max_lifetime,
index c5a9ae3469a371876a6a8d5ee420762d915c3186..d02695fc82193e7ea088dcd8c22e4caf593f4360 100644 (file)
@@ -26,6 +26,9 @@ except ImportError:
     pass
 
 
+PSYCOPG_VERSION = tuple(map(int, psycopg.__version__.split(".", 2)[:2]))
+
+
 def test_default_sizes(dsn):
     with pool.ConnectionPool(dsn) as p:
         assert p.min_size == p.max_size == 4
@@ -1074,3 +1077,47 @@ def test_override_close(dsn):
         assert len(p._pool) == 2
 
     assert conn.closed
+
+
+def test_close_returns(dsn):
+    with pool.ConnectionPool(dsn, min_size=2, close_returns=True) as p:
+        p.wait()
+        assert len(p._pool) == 2
+        conn = p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION < (3, 3), reason="psycopg >= 3.3 behaviour")
+def test_close_returns_custom_class(dsn):
+
+    class MyConnection(psycopg.Connection):
+        pass
+
+    with pool.ConnectionPool(
+        dsn, min_size=2, connection_class=MyConnection, close_returns=True
+    ) as p:
+        p.wait()
+        conn = p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION >= (3, 3), reason="psycopg < 3.3 behaviour")
+def test_close_returns_custom_class_old(dsn):
+
+    class MyConnection(psycopg.Connection):
+        pass
+
+    with pytest.raises(TypeError, match="close_returns=True"):
+        pool.ConnectionPool(dsn, connection_class=MyConnection, close_returns=True)
index e288a27e96ec8efd0bc2be1173899900151b8dab..44e6397fe26054217e6f3719345a3a429373fe2b 100644 (file)
@@ -25,6 +25,8 @@ except ImportError:
 if True:  # ASYNC
     pytestmark = [pytest.mark.anyio]
 
+PSYCOPG_VERSION = tuple(map(int, psycopg.__version__.split(".", 2)[:2]))
+
 
 async def test_default_sizes(dsn):
     async with pool.AsyncConnectionPool(dsn) as p:
@@ -1077,3 +1079,46 @@ async def test_override_close(dsn):
         assert len(p._pool) == 2
 
     assert conn.closed
+
+
+async def test_close_returns(dsn):
+
+    async with pool.AsyncConnectionPool(dsn, min_size=2, close_returns=True) as p:
+        await p.wait()
+        assert len(p._pool) == 2
+        conn = await p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        await conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION < (3, 3), reason="psycopg >= 3.3 behaviour")
+async def test_close_returns_custom_class(dsn):
+    class MyConnection(psycopg.AsyncConnection):
+        pass
+
+    async with pool.AsyncConnectionPool(
+        dsn, min_size=2, connection_class=MyConnection, close_returns=True
+    ) as p:
+        await p.wait()
+        conn = await p.getconn()
+        assert not conn.closed
+        assert len(p._pool) == 1
+        await conn.close()
+        assert not conn.closed
+        assert len(p._pool) == 2
+
+    assert conn.closed
+
+
+@pytest.mark.skipif(PSYCOPG_VERSION >= (3, 3), reason="psycopg < 3.3 behaviour")
+async def test_close_returns_custom_class_old(dsn):
+    class MyConnection(psycopg.AsyncConnection):
+        pass
+
+    with pytest.raises(TypeError, match="close_returns=True"):
+        pool.AsyncConnectionPool(dsn, connection_class=MyConnection, close_returns=True)
index 888417abe8aa19f17a073395bcb9d074367cfe95..05c95472f8685fa338a4d87c49691b9ca4b4a497 100755 (executable)
@@ -298,6 +298,7 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
         "AsyncLibpqWriter": "LibpqWriter",
         "AsyncNullConnectionPool": "NullConnectionPool",
         "AsyncPipeline": "Pipeline",
+        "AsyncPoolConnection": "PoolConnection",
         "AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
         "AsyncRawCursor": "RawCursor",
         "AsyncRawServerCursor": "RawServerCursor",