]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add pool.check()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Mar 2021 02:38:45 +0000 (03:38 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py

index 5154f2c845ba6f7df747fdeb2ad0329a0dab7783..bbc89776a4f699bf6b6efd1db0415dd4777d1fc6 100644 (file)
@@ -246,6 +246,26 @@ class ConnectionPool(BasePool[Connection]):
         for i in range(ngrow):
             self.run_task(tasks.AddConnection(self))
 
+    def check(self) -> None:
+        """Verify the state of the connections currently in the pool.
+
+        Test each connection: if it works return it to the pool, otherwise
+        dispose of it and create a new one.
+        """
+        with self._lock:
+            conns = list(self._pool)
+            self._pool.clear()
+
+        while conns:
+            conn = conns.pop()
+            try:
+                conn.execute("select 1")
+            except Exception:
+                logger.warning("discarding broken connection: %s", conn)
+                self.run_task(tasks.AddConnection(self))
+            else:
+                self._add_to_pool(conn)
+
     def configure(self, conn: Connection) -> None:
         """Configure a connection after creation."""
         self._configure(conn)
index b043b86c1784176b4657402d7bf2d4678f76367f..6c400ac2220cc522dfa7a36eea4fac679b10bef7 100644 (file)
@@ -659,6 +659,27 @@ def test_max_lifetime(dsn, retries):
             assert pids[0] == pids[1] != pids[2] == pids[3] != pids[4], pids
 
 
+def test_check(dsn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+    with pool.ConnectionPool(dsn, minconn=4) as p:
+        p.wait_ready(1.0)
+        with p.connection() as conn:
+            pid = conn.pgconn.backend_pid
+
+        p.wait_ready(1.0)
+        pids = set(conn.pgconn.backend_pid for conn in p._pool)
+        assert pid in pids
+        conn.close()
+
+        assert len(caplog.records) == 0
+        p.check()
+        assert len(caplog.records) == 1
+        p.wait_ready(1.0)
+        pids2 = set(conn.pgconn.backend_pid for conn in p._pool)
+        assert len(pids & pids2) == 3
+        assert pid not in pids2
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds