]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add reset callback to connection pool
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Mar 2021 03:43:33 +0000 (04:43 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index ea6a38c93caedc4eacae6d03c9f5997dbd220b36..7ec69ad6e5bb08c6cbff9cd0e63b90799eeb26be 100644 (file)
@@ -34,6 +34,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         configure: Optional[
             Callable[[AsyncConnection], Awaitable[None]]
         ] = None,
+        reset: Optional[Callable[[AsyncConnection], Awaitable[None]]] = None,
         **kwargs: Any,
     ):
         # https://bugs.python.org/issue42600
@@ -43,6 +44,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             )
 
         self._configure = configure
+        self._reset = reset
 
         self._lock = asyncio.Lock()
         self._waiting: Deque["AsyncClient"] = deque()
@@ -467,9 +469,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         """
         status = conn.pgconn.transaction_status
         if status == TransactionStatus.IDLE:
-            return
+            pass
 
-        if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+        elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
             # Connection returned with an active transaction
             logger.warning("rolling back returned connection: %s", conn)
             try:
@@ -488,6 +490,20 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             logger.warning("closing returned connection: %s", conn)
             await conn.close()
 
+        if not conn.closed and self._reset:
+            try:
+                await self._reset(conn)
+                status = conn.pgconn.transaction_status
+                if status != TransactionStatus.IDLE:
+                    nstatus = TransactionStatus(status).name
+                    raise e.ProgrammingError(
+                        f"connection left in status {nstatus} by reset function"
+                        f" {self._reset}: discarded"
+                    )
+            except Exception as ex:
+                logger.warning(f"error resetting connection: {ex}")
+                await conn.close()
+
     async def _shrink_pool(self) -> None:
         to_close: Optional[AsyncConnection] = None
 
index 5f0a8e0a54353dc98e8b3d449bcff7b127adc540..65ad838572a0ba0b58de6ae483b22e89e33b3a6a 100644 (file)
@@ -31,9 +31,11 @@ class ConnectionPool(BasePool[Connection]):
         self,
         conninfo: str = "",
         configure: Optional[Callable[[Connection], None]] = None,
+        reset: Optional[Callable[[Connection], None]] = None,
         **kwargs: Any,
     ):
         self._configure = configure
+        self._reset = reset
 
         self._lock = threading.RLock()
         self._waiting: Deque["WaitingClient"] = deque()
@@ -503,6 +505,20 @@ class ConnectionPool(BasePool[Connection]):
             logger.warning("closing returned connection: %s", conn)
             conn.close()
 
+        if not conn.closed and self._reset:
+            try:
+                self._reset(conn)
+                status = conn.pgconn.transaction_status
+                if status != TransactionStatus.IDLE:
+                    nstatus = TransactionStatus(status).name
+                    raise e.ProgrammingError(
+                        f"connection left in status {nstatus} by reset function"
+                        f" {self._reset}: discarded"
+                    )
+            except Exception as ex:
+                logger.warning(f"error resetting connection: {ex}")
+                conn.close()
+
     def _shrink_pool(self) -> None:
         to_close: Optional[Connection] = None
 
index e0786fed3a30f583307aaeb3472f3ddc3ff3a220..212776d471ce19b216510d2aa4b3edc2d2b03bbf 100644 (file)
@@ -185,6 +185,76 @@ def test_configure_broken(dsn, caplog):
     assert "WAT" in caplog.records[0].message
 
 
+def test_reset(dsn):
+    resets = 0
+
+    def setup(conn):
+        with conn.transaction():
+            conn.execute("set timezone to '+1:00'")
+
+    def reset(conn):
+        nonlocal resets
+        resets += 1
+        with conn.transaction():
+            conn.execute("set timezone to utc")
+
+    with pool.ConnectionPool(minconn=1, reset=reset) as p:
+        with p.connection() as conn:
+            assert resets == 0
+            conn.execute("set timezone to '+2:00'")
+
+        p.wait()
+        assert resets == 1
+
+        with p.connection() as conn:
+            with conn.execute("show timezone") as cur:
+                assert cur.fetchone() == ("UTC",)
+
+        p.wait()
+        assert resets == 2
+
+
+def test_reset_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    def reset(conn):
+        conn.execute("reset all")
+
+    with pool.ConnectionPool(minconn=1, reset=reset) as p:
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pid1 = conn.pgconn.backend_pid
+
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pid2 = conn.pgconn.backend_pid
+
+    assert pid1 != pid2
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+def test_reset_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    def reset(conn):
+        with conn.transaction():
+            conn.execute("WAT")
+
+    with pool.ConnectionPool(minconn=1, reset=reset) as p:
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pid1 = conn.pgconn.backend_pid
+
+        with p.connection() as conn:
+            conn.execute("select 1")
+            pid2 = conn.pgconn.backend_pid
+
+    assert pid1 != pid2
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
 @pytest.mark.slow
 def test_queue(dsn, retries):
     def worker(n):
index 829f723ed965e74a40fbb6a5b6524f54b1065990..d9436c9375d60b5a0d6186b77212a8819a850e09 100644 (file)
@@ -201,6 +201,76 @@ async def test_configure_broken(dsn, caplog):
     assert "WAT" in caplog.records[0].message
 
 
+async def test_reset(dsn):
+    resets = 0
+
+    async def setup(conn):
+        async with conn.transaction():
+            await conn.execute("set timezone to '+1:00'")
+
+    async def reset(conn):
+        nonlocal resets
+        resets += 1
+        async with conn.transaction():
+            await conn.execute("set timezone to utc")
+
+    async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p:
+        async with p.connection() as conn:
+            assert resets == 0
+            await conn.execute("set timezone to '+2:00'")
+
+        await p.wait()
+        assert resets == 1
+
+        async with p.connection() as conn:
+            cur = await conn.execute("show timezone")
+            assert (await cur.fetchone()) == ("UTC",)
+
+        await p.wait()
+        assert resets == 2
+
+
+async def test_reset_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    async def reset(conn):
+        await conn.execute("reset all")
+
+    async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p:
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+            pid1 = conn.pgconn.backend_pid
+
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+            pid2 = conn.pgconn.backend_pid
+
+    assert pid1 != pid2
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
+async def test_reset_broken(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    async def reset(conn):
+        async with conn.transaction():
+            await conn.execute("WAT")
+
+    async with pool.AsyncConnectionPool(minconn=1, reset=reset) as p:
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+            pid1 = conn.pgconn.backend_pid
+
+        async with p.connection() as conn:
+            await conn.execute("select 1")
+            pid2 = conn.pgconn.backend_pid
+
+    assert pid1 != pid2
+    assert caplog.records
+    assert "WAT" in caplog.records[0].message
+
+
 @pytest.mark.slow
 async def test_queue(dsn, retries):
     async def worker(n):