Add max_lifetime to pool connections
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sun, 28 Feb 2021 03:35:56 +0000 (04:35 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/base.py
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

psycopg3/psycopg3/connection.py
index 3997061e354998be6d5c7eba5b173b7d1b62a7bb..d8ca8975017c44785f04a94e29aeda42f388acf7 100644
@@ -129,6 +129,9 @@ class BaseConnection(AdaptContext):
         # apart a connection in the pool too (when _pool = None)
         self._pool: Optional["BasePool[Any]"]
 
+        # Time after which the connection should be closed
+        self._expire_at: float
+
     def __del__(self) -> None:
         # If fails on connection we might not have this attribute yet
         if not hasattr(self, "pgconn"):
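
The `_expire_at` attribute added above holds a deadline on the monotonic clock: the pool fills it in when it creates the connection (see the `_connect()` hunks below) and compares it against `monotonic()` when the connection is returned. A minimal sketch of that deadline pattern in isolation, with illustrative names that are not part of this diff:

    from time import monotonic

    max_lifetime = 60 * 60.0                 # seconds; the BasePool default below
    expire_at = monotonic() + max_lifetime   # set once, when the resource is created

    def is_expired() -> bool:
        # monotonic() is unaffected by wall-clock adjustments, so the deadline
        # cannot be skipped or postponed by a system clock change.
        return expire_at <= monotonic()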
psycopg3/psycopg3/pool/async_pool.py
index c024214d4b2c0db85da6adb4bb1a8c3cc83b0f80..c3494dd4154230985208b432fd2ce5f32063141f 100644
@@ -283,6 +283,10 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         conn = await AsyncConnection.connect(self.conninfo, **self.kwargs)
         await self.configure(conn)
         conn._pool = self
+        # Set an expiry date, with some randomness to avoid mass reconnection
+        conn._expire_at = monotonic() + self._jitter(
+            self.max_lifetime, -0.05, 0.0
+        )
         return conn
 
     async def _add_connection(
@@ -329,10 +333,18 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         await self._reset_connection(conn)
         if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
             # Connection no more in working state: create a new one.
+            self.run_task(tasks.AddConnection(self))
             logger.warning("discarding closed connection: %s", conn)
+            return
+
+        # Check if the connection is past its best before date
+        if conn._expire_at <= monotonic():
             self.run_task(tasks.AddConnection(self))
-        else:
-            await self._add_to_pool(conn)
+            logger.info("discarding expired connection")
+            await conn.close()
+            return
+
+        await self._add_to_pool(conn)
 
     async def _add_to_pool(self, conn: AsyncConnection) -> None:
         """
psycopg3/psycopg3/pool/base.py
index c56b116c58e35e3aaf2e975aa5884f44a1db1d66..c342e69d353411966bcf59945e9ed975c30f9e3e 100644
@@ -35,6 +35,7 @@ class BasePool(Generic[ConnectionType]):
         maxconn: Optional[int] = None,
         name: Optional[str] = None,
         timeout: float = 30.0,
+        max_lifetime: float = 60 * 60.0,
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
         reconnect_failed: Optional[
@@ -62,6 +63,7 @@ class BasePool(Generic[ConnectionType]):
         self._maxconn = maxconn
         self.timeout = timeout
         self.reconnect_timeout = reconnect_timeout
+        self.max_lifetime = max_lifetime
         self.max_idle = max_idle
         self.num_workers = num_workers
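
`BasePool._jitter()` itself is not touched by this diff; going by how `_connect()` calls it below and by the bounds asserted in `test_jitter`, it presumably scales its argument by a random percentage within the given range. A hedged sketch of such a helper (not the actual library code):

    import random

    def jitter(value: float, min_pc: float, max_pc: float) -> float:
        # Return value altered by a random fraction chosen in [min_pc, max_pc].
        # Called as jitter(max_lifetime, -0.05, 0.0) it yields 95-100% of the
        # lifetime, so connections created together don't all expire together.
        return value * (1.0 + min_pc + random.random() * (max_pc - min_pc))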
 
psycopg3/psycopg3/pool/pool.py
index 71c07ebc7d1c117c1eb5bce4fbe7913e0edcc3ea..5154f2c845ba6f7df747fdeb2ad0329a0dab7783 100644
@@ -261,6 +261,10 @@ class ConnectionPool(BasePool[Connection]):
         conn = Connection.connect(self.conninfo, **self.kwargs)
         self.configure(conn)
         conn._pool = self
+        # Set an expiry date, with some randomness to avoid mass reconnection
+        conn._expire_at = monotonic() + self._jitter(
+            self.max_lifetime, -0.05, 0.0
+        )
         return conn
 
     def _add_connection(self, attempt: Optional[ConnectionAttempt]) -> None:
@@ -305,10 +309,18 @@ class ConnectionPool(BasePool[Connection]):
         self._reset_connection(conn)
         if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
             # Connection no more in working state: create a new one.
+            self.run_task(tasks.AddConnection(self))
             logger.warning("discarding closed connection: %s", conn)
+            return
+
+        # Check if the connection is past its best before date
+        if conn._expire_at <= monotonic():
             self.run_task(tasks.AddConnection(self))
-        else:
-            self._add_to_pool(conn)
+            logger.info("discarding expired connection")
+            conn.close()
+            return
+
+        self._add_to_pool(conn)
 
     def _add_to_pool(self, conn: Connection) -> None:
         """
tests/pool/test_pool.py
index 49792724a47247708909827090dc99c9c94b0a19..7cfc9435dcdb48590ee2ab40ae770dc6345c083f 100644
@@ -15,7 +15,8 @@ def test_defaults(dsn):
     with pool.ConnectionPool(dsn) as p:
         assert p.minconn == p.maxconn == 4
         assert p.timeout == 30
-        assert p.max_idle == 600
+        assert p.max_idle == 10 * 60
+        assert p.max_lifetime == 60 * 60
         assert p.num_workers == 3
 
 
@@ -643,6 +644,19 @@ def test_jitter():
     assert 35 < max(rnds) < 36
 
 
+@pytest.mark.slow
+def test_max_lifetime(dsn):
+    with pool.ConnectionPool(dsn, minconn=1, max_lifetime=0.2) as p:
+        sleep(0.1)
+        pids = []
+        for i in range(5):
+            with p.connection() as conn:
+                pids.append(conn.pgconn.backend_pid)
+            sleep(0.2)
+
+    assert pids[0] == pids[1] != pids[2] == pids[3] != pids[4], pids
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds
tests/pool/test_pool_async.py
index 7a19e036b314880fbad74d647cf939aa23096c38..60c91c126cbd936d53e19902ff403a78218d7473 100644
@@ -24,7 +24,8 @@ async def test_defaults(dsn):
     async with pool.AsyncConnectionPool(dsn) as p:
         assert p.minconn == p.maxconn == 4
         assert p.timeout == 30
-        assert p.max_idle == 600
+        assert p.max_idle == 10 * 60
+        assert p.max_lifetime == 60 * 60
         assert p.num_workers == 3
 
 
@@ -677,6 +678,19 @@ def test_jitter():
     assert 35 < max(rnds) < 36
 
 
+@pytest.mark.slow
+async def test_max_lifetime(dsn):
+    async with pool.AsyncConnectionPool(dsn, minconn=1, max_lifetime=0.2) as p:
+        await asyncio.sleep(0.1)
+        pids = []
+        for i in range(5):
+            async with p.connection() as conn:
+                pids.append(conn.pgconn.backend_pid)
+            await asyncio.sleep(0.2)
+
+    assert pids[0] == pids[1] != pids[2] == pids[3] != pids[4], pids
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds