From 3f1ddc8986461828fb4864fae3737b3c77c864f6 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Mon, 15 Nov 2021 09:28:36 +0100 Subject: [PATCH] Reset workers state in *ConnectionPool.close() This way, the pools may be re-opened after a close, even if we'll disallow this later on. --- psycopg_pool/psycopg_pool/pool.py | 7 ++++--- psycopg_pool/psycopg_pool/pool_async.py | 7 ++++--- tests/pool/test_pool.py | 22 ++++++++++++++++++++-- tests/pool/test_pool_async.py | 20 ++++++++++++++++++-- 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index a6ec95601..0211424ae 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -299,7 +299,8 @@ class ConnectionPool(BasePool[Connection[Any]]): self._sched.enter(0, None) # Stop the worker threads - for i in range(len(self._workers)): + workers, self._workers = self._workers[:], [] + for i in range(len(workers)): self.run_task(StopWorker(self)) # Signal to eventual clients in the queue that business is closed. @@ -312,8 +313,9 @@ class ConnectionPool(BasePool[Connection[Any]]): # Wait for the worker threads to terminate assert self._sched_runner is not None + sched_runner, self._sched_runner = self._sched_runner, None if timeout > 0: - for t in [self._sched_runner] + self._workers: + for t in [sched_runner] + workers: if not t.is_alive(): continue t.join(timeout) @@ -324,7 +326,6 @@ class ConnectionPool(BasePool[Connection[Any]]): self.name, timeout, ) - self._sched_runner = None def __enter__(self) -> "ConnectionPool": return self diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index cf6887ef4..6c140c09c 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -239,7 +239,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): await self._sched.enter(0, None) # Stop the worker tasks - for w in self._workers: + workers, self._workers = self._workers[:], [] + for w in workers: self.run_task(StopWorker(self)) # Signal to eventual clients in the queue that business is closed. @@ -252,7 +253,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): # Wait for the worker tasks to terminate assert self._sched_runner is not None - wait = asyncio.gather(self._sched_runner, *self._workers) + sched_runner, self._sched_runner = self._sched_runner, None + wait = asyncio.gather(sched_runner, *workers) try: if timeout > 0: await asyncio.wait_for(asyncio.shield(wait), timeout=timeout) @@ -264,7 +266,6 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): self.name, timeout, ) - self._sched_runner = None async def __aenter__(self) -> "AsyncConnectionPool": return self diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index e96e0518e..47f90ee5d 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -562,12 +562,15 @@ def test_fail_rollback_close(dsn, caplog, monkeypatch): def test_close_no_threads(dsn): p = pool.ConnectionPool(dsn) assert p._sched_runner and p._sched_runner.is_alive() - for t in p._workers: + workers = p._workers[:] + assert workers + for t in workers: assert t.is_alive() p.close() assert p._sched_runner is None - for t in p._workers: + assert not p._workers + for t in workers: assert not t.is_alive() @@ -680,6 +683,21 @@ def test_closed_queue(dsn): assert len(success) == 2 +def test_reopen(dsn): + p = pool.ConnectionPool(dsn) + with p.connection() as conn: + conn.execute("select 1") + p.close() + assert p._sched_runner is None + assert not p._workers + p.open() + assert p._sched_runner is not None + assert p._workers + with p.connection() as conn: + conn.execute("select 1") + p.close() + + @pytest.mark.slow @pytest.mark.timing def test_grow(dsn, monkeypatch, retries): diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 2efab352f..fb12ecbf8 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -577,12 +577,15 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch): async def test_close_no_tasks(dsn): p = pool.AsyncConnectionPool(dsn) assert p._sched_runner and not p._sched_runner.done() - for t in p._workers: + assert p._workers + workers = p._workers[:] + for t in workers: assert not t.done() await p.close() assert p._sched_runner is None - for t in p._workers: + assert not p._workers + for t in workers: assert t.done() @@ -669,6 +672,19 @@ async def test_closed_queue(dsn): assert len(success) == 2 +async def test_reopen(dsn): + p = pool.AsyncConnectionPool(dsn) + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + assert p._sched_runner is None + p.open() + assert p._sched_runner is not None + async with p.connection() as conn: + await conn.execute("select 1") + await p.close() + + @pytest.mark.slow @pytest.mark.timing async def test_grow(dsn, monkeypatch, retries): -- 2.47.2