git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Reset workers state in *ConnectionPool.close()
author     Denis Laxalde <denis.laxalde@dalibo.com>
           Mon, 15 Nov 2021 08:28:36 +0000 (09:28 +0100)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Mon, 3 Jan 2022 15:41:10 +0000 (16:41 +0100)
This way, the pools may be re-opened after a close, even if we'll
disallow this later on.

psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
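
The change lets a pool be closed and then opened again. As a usage sketch mirroring the new test_reopen tests added below (not part of the diff; the DSN is a placeholder):

    from psycopg_pool import ConnectionPool

    conninfo = "dbname=test"           # placeholder DSN, adjust for your setup
    p = ConnectionPool(conninfo)
    with p.connection() as conn:
        conn.execute("select 1")
    p.close()                          # worker threads and scheduler are detached
    p.open()                           # rebuilds them; enabled by this commit
    with p.connection() as conn:
        conn.execute("select 1")
    p.close()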

diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py
index a6ec95601016646272b630043b74e070e60a1467..0211424ae594c2400d6eb3ac46dfac16da896ae3 100644
--- a/psycopg_pool/psycopg_pool/pool.py
+++ b/psycopg_pool/psycopg_pool/pool.py
@@ -299,7 +299,8 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self._sched.enter(0, None)
 
         # Stop the worker threads
-        for i in range(len(self._workers)):
+        workers, self._workers = self._workers[:], []
+        for i in range(len(workers)):
             self.run_task(StopWorker(self))
 
         # Signal to eventual clients in the queue that business is closed.
@@ -312,8 +313,9 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
         # Wait for the worker threads to terminate
         assert self._sched_runner is not None
+        sched_runner, self._sched_runner = self._sched_runner, None
         if timeout > 0:
-            for t in [self._sched_runner] + self._workers:
+            for t in [sched_runner] + workers:
                 if not t.is_alive():
                     continue
                 t.join(timeout)
@@ -324,7 +326,6 @@ class ConnectionPool(BasePool[Connection[Any]]):
                         self.name,
                         timeout,
                     )
-        self._sched_runner = None
 
     def __enter__(self) -> "ConnectionPool":
         return self
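
The pattern above moves the worker threads and the scheduler runner into locals and clears the instance attributes before joining, so close() leaves the pool in the same state as an unopened one and open() can rebuild it. A minimal standalone sketch of that detach-then-join idea, with a toy Service class standing in for the pool (names and simplifications are assumptions, not psycopg_pool API):

    import queue
    import threading

    class Service:
        """Toy stand-in for the pool's worker management."""

        def __init__(self, nworkers: int = 2) -> None:
            self._tasks: "queue.Queue[object]" = queue.Queue()
            self._workers: "list[threading.Thread]" = []
            self.open(nworkers)

        def open(self, nworkers: int = 2) -> None:
            # close() leaves self._workers empty, so this rebuilds from scratch.
            self._workers = [
                threading.Thread(target=self._run, daemon=True)
                for _ in range(nworkers)
            ]
            for t in self._workers:
                t.start()

        def _run(self) -> None:
            # Each worker exits when it receives a None "stop" marker.
            while self._tasks.get() is not None:
                pass

        def close(self, timeout: float = 5.0) -> None:
            # Detach the threads first: the instance already looks "closed" and
            # a later open() starts from a clean slate, as in the hunks above.
            workers, self._workers = self._workers[:], []
            for _ in workers:
                self._tasks.put(None)
            for t in workers:
                if t.is_alive():
                    t.join(timeout)

    s = Service()
    s.close()
    s.open()   # allowed again, because close() reset the worker state
    s.close()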
diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py
index cf6887ef4eafcc251937d20de49429cfeb15eab9..6c140c09c3f617841e3453e5a4db76a3723b4d18 100644
--- a/psycopg_pool/psycopg_pool/pool_async.py
+++ b/psycopg_pool/psycopg_pool/pool_async.py
@@ -239,7 +239,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         await self._sched.enter(0, None)
 
         # Stop the worker tasks
-        for w in self._workers:
+        workers, self._workers = self._workers[:], []
+        for w in workers:
             self.run_task(StopWorker(self))
 
         # Signal to eventual clients in the queue that business is closed.
@@ -252,7 +253,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
 
         # Wait for the worker tasks to terminate
         assert self._sched_runner is not None
-        wait = asyncio.gather(self._sched_runner, *self._workers)
+        sched_runner, self._sched_runner = self._sched_runner, None
+        wait = asyncio.gather(sched_runner, *workers)
         try:
             if timeout > 0:
                 await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
@@ -264,7 +266,6 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 self.name,
                 timeout,
             )
-        self._sched_runner = None
 
     async def __aenter__(self) -> "AsyncConnectionPool":
         return self
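
In the async variant the detached tasks are gathered and awaited with a timeout; asyncio.shield() keeps the gathered tasks running even if the timeout cancels the outer wait. A small standalone illustration of that wait pattern, with toy worker tasks (a sketch only, not pool_async.py code):

    import asyncio
    import logging

    async def worker(stop: asyncio.Event) -> None:
        # Stand-in for a pool worker task: run until told to stop.
        await stop.wait()

    async def close(workers: "list[asyncio.Task]", stop: asyncio.Event, timeout: float) -> None:
        stop.set()
        wait = asyncio.gather(*workers)
        try:
            if timeout > 0:
                # shield() keeps the gathered tasks alive if the timeout
                # cancels the outer wait, mirroring the hunk above.
                await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
            else:
                await wait
        except asyncio.TimeoutError:
            logging.warning("worker tasks still alive after %s sec", timeout)

    async def main() -> None:
        stop = asyncio.Event()
        workers = [asyncio.create_task(worker(stop)) for _ in range(2)]
        await close(workers, stop, timeout=5.0)

    asyncio.run(main())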
diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py
index e96e0518e8e0b13d0a4f446f5c5196a5b609fb66..47f90ee5dc39a67e871f4ab7dd8ee52dd532bd8f 100644
--- a/tests/pool/test_pool.py
+++ b/tests/pool/test_pool.py
@@ -562,12 +562,15 @@ def test_fail_rollback_close(dsn, caplog, monkeypatch):
 def test_close_no_threads(dsn):
     p = pool.ConnectionPool(dsn)
     assert p._sched_runner and p._sched_runner.is_alive()
-    for t in p._workers:
+    workers = p._workers[:]
+    assert workers
+    for t in workers:
         assert t.is_alive()
 
     p.close()
     assert p._sched_runner is None
-    for t in p._workers:
+    assert not p._workers
+    for t in workers:
         assert not t.is_alive()
 
 
@@ -680,6 +683,21 @@ def test_closed_queue(dsn):
     assert len(success) == 2
 
 
+def test_reopen(dsn):
+    p = pool.ConnectionPool(dsn)
+    with p.connection() as conn:
+        conn.execute("select 1")
+    p.close()
+    assert p._sched_runner is None
+    assert not p._workers
+    p.open()
+    assert p._sched_runner is not None
+    assert p._workers
+    with p.connection() as conn:
+        conn.execute("select 1")
+    p.close()
+
+
 @pytest.mark.slow
 @pytest.mark.timing
 def test_grow(dsn, monkeypatch, retries):
diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py
index 2efab352f85b176d67a2eaef3d175f1b32ec4bba..fb12ecbf8f125e1c713de43eb7d870f4a3aafbc9 100644
--- a/tests/pool/test_pool_async.py
+++ b/tests/pool/test_pool_async.py
@@ -577,12 +577,15 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
 async def test_close_no_tasks(dsn):
     p = pool.AsyncConnectionPool(dsn)
     assert p._sched_runner and not p._sched_runner.done()
-    for t in p._workers:
+    assert p._workers
+    workers = p._workers[:]
+    for t in workers:
         assert not t.done()
 
     await p.close()
     assert p._sched_runner is None
-    for t in p._workers:
+    assert not p._workers
+    for t in workers:
         assert t.done()
 
 
@@ -669,6 +672,19 @@ async def test_closed_queue(dsn):
     assert len(success) == 2
 
 
+async def test_reopen(dsn):
+    p = pool.AsyncConnectionPool(dsn)
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+    await p.close()
+    assert p._sched_runner is None
+    p.open()
+    assert p._sched_runner is not None
+    async with p.connection() as conn:
+        await conn.execute("select 1")
+    await p.close()
+
+
 @pytest.mark.slow
 @pytest.mark.timing
 async def test_grow(dsn, monkeypatch, retries):