]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(pool): make sure that check() fills an empty pool
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Dec 2022 03:21:38 +0000 (03:21 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Dec 2022 15:29:41 +0000 (15:29 +0000)
Close #438.

docs/news_pool.rst
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 866b04548847f68cd6b6bce3c62030c7629a9038..1477a6199fa45c2360e9c7dc607d77ee5acd9fca 100644 (file)
@@ -13,7 +13,9 @@ Future releases
 psycopg_pool 3.1.5 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- Avoid error in pyright caused by aliasing TypeAlias (:ticket:`#439`).
+- Make sure that `!ConnectionPool.check()` refills an empty pool
+  (:ticket:`#438`).
+- Avoid error in Pyright caused by aliasing `!TypeAlias` (:ticket:`#439`).
 
 
 Current release
index 766c0c64a9174a7c3541bf8442c4c9a8e986a765..609d95dfea0974598b8bb1f81a40fb72e7bc3514 100644 (file)
@@ -402,6 +402,11 @@ class ConnectionPool(BasePool[Connection[Any]]):
             conns = list(self._pool)
             self._pool.clear()
 
+            # Give a chance to the pool to grow if it has no connection.
+            # In case there are enough connection, or the pool is already
+            # growing, this is a no-op.
+            self._maybe_grow_pool()
+
         while conns:
             conn = conns.pop()
             try:
index d411d8bc091a5f783925fe33aaf7bd2273338787..0ea6e9a40a5cd985201cda767daa57c8d37f9d56 100644 (file)
@@ -337,6 +337,11 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             conns = list(self._pool)
             self._pool.clear()
 
+            # Give a chance to the pool to grow if it has no connection.
+            # In case there are enough connection, or the pool is already
+            # growing, this is a no-op.
+            self._maybe_grow_pool()
+
         while conns:
             conn = conns.pop()
             try:
index b83ca08e919a042ad0030d66002b759ffd2d9da8..30c790b10963c240b26b101f874ab6b9e29fe2f4 100644 (file)
@@ -953,6 +953,37 @@ def test_reconnect_after_grow_failed(proxy):
         assert len(p._pool) == p.min_size == 4
 
 
+@pytest.mark.slow
+def test_refill_on_check(proxy):
+    proxy.start()
+    ev = Event()
+
+    def failed(pool):
+        ev.set()
+
+    with pool.ConnectionPool(
+        proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+    ) as p:
+        # The pool is full
+        p.wait(timeout=2)
+
+        # Break all the connection
+        proxy.stop()
+
+        # Checking the pool will empty it
+        p.check()
+        assert ev.wait(timeout=2)
+        assert len(p._pool) == 0
+
+        # Allow to connect again
+        proxy.start()
+
+        # Make sure that check has refilled the pool
+        p.check()
+        p.wait(timeout=2)
+        assert len(p._pool) == 4
+
+
 @pytest.mark.slow
 def test_uniform_use(dsn):
     with pool.ConnectionPool(dsn, min_size=4) as p:
index 7d6ca6d7942dd8c4479c41d4b85d72f963b311d9..286a77524dc4e7dd4e07a13c041f53f96d6118e3 100644 (file)
@@ -905,6 +905,37 @@ async def test_reconnect_after_grow_failed(proxy):
         assert len(p._pool) == p.min_size == 4
 
 
+@pytest.mark.slow
+async def test_refill_on_check(proxy):
+    proxy.start()
+    ev = asyncio.Event()
+
+    def failed(pool):
+        ev.set()
+
+    async with pool.AsyncConnectionPool(
+        proxy.client_dsn, min_size=4, reconnect_timeout=1.0, reconnect_failed=failed
+    ) as p:
+        # The pool is full
+        await p.wait(timeout=2)
+
+        # Break all the connection
+        proxy.stop()
+
+        # Checking the pool will empty it
+        await p.check()
+        await asyncio.wait_for(ev.wait(), 2.0)
+        assert len(p._pool) == 0
+
+        # Allow to connect again
+        proxy.start()
+
+        # Make sure that check has refilled the pool
+        await p.check()
+        await p.wait(timeout=2)
+        assert len(p._pool) == 4
+
+
 @pytest.mark.slow
 async def test_uniform_use(dsn):
     async with pool.AsyncConnectionPool(dsn, min_size=4) as p: