]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make sure that the pool config function leaves connections in idle state
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Mar 2021 03:18:25 +0000 (04:18 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 579f5ed31fd8ab33546bfa9b7887c19ab1b0f442..ea6a38c93caedc4eacae6d03c9f5997dbd220b36 100644 (file)
@@ -103,9 +103,12 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             # Run the task. Make sure don't die in the attempt.
             try:
                 await task.run()
-            except Exception as e:
+            except Exception as ex:
                 logger.warning(
-                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
+                    "task run %s failed: %s: %s",
+                    task,
+                    ex.__class__.__name__,
+                    ex,
                 )
 
     async def wait(self, timeout: float = 30.0) -> None:
@@ -340,11 +343,6 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             else:
                 await self._add_to_pool(conn)
 
-    async def configure(self, conn: AsyncConnection) -> None:
-        """Configure a connection after creation."""
-        if self._configure:
-            await self._configure(conn)
-
     def reconnect_failed(self) -> None:
         """
         Called when reconnection failed for longer than `reconnect_timeout`.
@@ -354,8 +352,18 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
     async def _connect(self) -> AsyncConnection:
         """Return a new connection configured for the pool."""
         conn = await AsyncConnection.connect(self.conninfo, **self.kwargs)
-        await self.configure(conn)
         conn._pool = self
+
+        if self._configure:
+            await self._configure(conn)
+            status = conn.pgconn.transaction_status
+            if status != TransactionStatus.IDLE:
+                nstatus = TransactionStatus(status).name
+                raise e.ProgrammingError(
+                    f"connection left in status {nstatus} by configure function"
+                    f" {self._configure}: discarded"
+                )
+
         # Set an expiry date, with some randomness to avoid mass reconnection
         conn._expire_at = monotonic() + self._jitter(
             self.max_lifetime, -0.05, 0.0
@@ -380,8 +388,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
 
         try:
             conn = await self._connect()
-        except Exception as e:
-            logger.warning(f"error connecting in {self.name!r}: {e}")
+        except Exception as ex:
+            logger.warning(f"error connecting in {self.name!r}: {ex}")
             if attempt.time_to_give_up(now):
                 logger.warning(
                     "reconnection attempt in pool %r failed after %s sec",
@@ -466,11 +474,11 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             logger.warning("rolling back returned connection: %s", conn)
             try:
                 await conn.rollback()
-            except Exception as e:
+            except Exception as ex:
                 logger.warning(
                     "rollback failed: %s: %s. Discarding connection %s",
-                    e.__class__.__name__,
-                    e,
+                    ex.__class__.__name__,
+                    ex,
                     conn,
                 )
                 await conn.close()
index ece4a7c32805dcf741691cce0cf13e5dff790132..5f0a8e0a54353dc98e8b3d449bcff7b127adc540 100644 (file)
@@ -15,6 +15,7 @@ from weakref import ref
 from contextlib import contextmanager
 from collections import deque
 
+from .. import errors as e
 from ..pq import TransactionStatus
 from ..connection import Connection
 
@@ -32,8 +33,7 @@ class ConnectionPool(BasePool[Connection]):
         configure: Optional[Callable[[Connection], None]] = None,
         **kwargs: Any,
     ):
-        self._configure: Callable[[Connection], None]
-        self._configure = configure or (lambda conn: None)
+        self._configure = configure
 
         self._lock = threading.RLock()
         self._waiting: Deque["WaitingClient"] = deque()
@@ -316,10 +316,6 @@ class ConnectionPool(BasePool[Connection]):
             else:
                 self._add_to_pool(conn)
 
-    def configure(self, conn: Connection) -> None:
-        """Configure a connection after creation."""
-        self._configure(conn)
-
     def reconnect_failed(self) -> None:
         """
         Called when reconnection failed for longer than `reconnect_timeout`.
@@ -364,16 +360,29 @@ class ConnectionPool(BasePool[Connection]):
             # Run the task. Make sure don't die in the attempt.
             try:
                 task.run()
-            except Exception as e:
+            except Exception as ex:
                 logger.warning(
-                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
+                    "task run %s failed: %s: %s",
+                    task,
+                    ex.__class__.__name__,
+                    ex,
                 )
 
     def _connect(self) -> Connection:
         """Return a new connection configured for the pool."""
         conn = Connection.connect(self.conninfo, **self.kwargs)
-        self.configure(conn)
         conn._pool = self
+
+        if self._configure:
+            self._configure(conn)
+            status = conn.pgconn.transaction_status
+            if status != TransactionStatus.IDLE:
+                nstatus = TransactionStatus(status).name
+                raise e.ProgrammingError(
+                    f"connection left in status {nstatus} by configure function"
+                    f" {self._configure}: discarded"
+                )
+
         # Set an expiry date, with some randomness to avoid mass reconnection
         conn._expire_at = monotonic() + self._jitter(
             self.max_lifetime, -0.05, 0.0
@@ -396,8 +405,8 @@ class ConnectionPool(BasePool[Connection]):
 
         try:
             conn = self._connect()
-        except Exception as e:
-            logger.warning(f"error connecting in {self.name!r}: {e}")
+        except Exception as ex:
+            logger.warning(f"error connecting in {self.name!r}: {ex}")
             if attempt.time_to_give_up(now):
                 logger.warning(
                     "reconnection attempt in pool %r failed after %s sec",
@@ -473,18 +482,18 @@ class ConnectionPool(BasePool[Connection]):
         """
         status = conn.pgconn.transaction_status
         if status == TransactionStatus.IDLE:
-            return
+            pass
 
-        if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+        elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
             # Connection returned with an active transaction
             logger.warning("rolling back returned connection: %s", conn)
             try:
                 conn.rollback()
-            except Exception as e:
+            except Exception as ex:
                 logger.warning(
                     "rollback failed: %s: %s. Discarding connection %s",
-                    e.__class__.__name__,
-                    e,
+                    ex.__class__.__name__,
+                    ex,
                     conn,
                 )
                 conn.close()
index a1fe95301383d296af0e10048e314862d109aae5..e0786fed3a30f583307aaeb3472f3ddc3ff3a220 100644 (file)
@@ -132,9 +132,11 @@ def test_configure(dsn):
     def configure(conn):
         nonlocal inits
         inits += 1
-        conn.execute("set default_transaction_read_only to on")
+        with conn.transaction():
+            conn.execute("set default_transaction_read_only to on")
 
     with pool.ConnectionPool(minconn=1, configure=configure) as p:
+        p.wait(timeout=1.0)
         with p.connection() as conn:
             assert inits == 1
             res = conn.execute("show default_transaction_read_only")
@@ -152,12 +154,28 @@ def test_configure(dsn):
             assert res.fetchone()[0] == "on"
 
 
+@pytest.mark.slow
+def test_configure_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    def configure(conn):
+        conn.execute("select 1")
+
+    with pool.ConnectionPool(minconn=1, configure=configure) as p:
+        with pytest.raises(pool.PoolTimeout):
+            p.wait(timeout=0.5)
+
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
 @pytest.mark.slow
 def test_configure_broken(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
 
     def configure(conn):
-        conn.execute("WAT")
+        with conn.transaction():
+            conn.execute("WAT")
 
     with pool.ConnectionPool(minconn=1, configure=configure) as p:
         with pytest.raises(pool.PoolTimeout):
index 501f99da29e80a52d881cf47add9dd3273e7890d..829f723ed965e74a40fbb6a5b6524f54b1065990 100644 (file)
@@ -148,9 +148,11 @@ async def test_configure(dsn):
     async def configure(conn):
         nonlocal inits
         inits += 1
-        await conn.execute("set default_transaction_read_only to on")
+        async with conn.transaction():
+            await conn.execute("set default_transaction_read_only to on")
 
     async with pool.AsyncConnectionPool(minconn=1, configure=configure) as p:
+        await p.wait(timeout=1.0)
         async with p.connection() as conn:
             assert inits == 1
             res = await conn.execute("show default_transaction_read_only")
@@ -168,12 +170,28 @@ async def test_configure(dsn):
             assert (await res.fetchone())[0] == "on"
 
 
+@pytest.mark.slow
+async def test_configure_badstate(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    async def configure(conn):
+        await conn.execute("select 1")
+
+    async with pool.AsyncConnectionPool(minconn=1, configure=configure) as p:
+        with pytest.raises(pool.PoolTimeout):
+            await p.wait(timeout=0.5)
+
+    assert caplog.records
+    assert "INTRANS" in caplog.records[0].message
+
+
 @pytest.mark.slow
 async def test_configure_broken(dsn, caplog):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
 
     async def configure(conn):
-        await conn.execute("WAT")
+        async with conn.transaction():
+            await conn.execute("WAT")
 
     async with pool.AsyncConnectionPool(minconn=1, configure=configure) as p:
         with pytest.raises(pool.PoolTimeout):