From: Daniele Varrazzo Date: Wed, 4 Oct 2023 22:01:24 +0000 (+0200) Subject: refactor(pool): add spawn/gather async compat functions X-Git-Tag: pool-3.2.0~12^2~25 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=08be9904586c64cbdbc8bcb152097d588f1b6612;p=thirdparty%2Fpsycopg.git refactor(pool): add spawn/gather async compat functions --- diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py index 9f59a64f9..ffc2f1a3b 100644 --- a/psycopg_pool/psycopg_pool/_acompat.py +++ b/psycopg_pool/psycopg_pool/_acompat.py @@ -9,13 +9,17 @@ when generating the sync version. # Copyright (C) 2023 The Psycopg Team import asyncio +import logging import threading +from typing import Any, Callable, Coroutine Event = threading.Event Condition = threading.Condition Lock = threading.RLock ALock = asyncio.Lock +logger = logging.getLogger("psycopg.pool") + class AEvent(asyncio.Event): """ @@ -45,3 +49,61 @@ class ACondition(asyncio.Condition): return True except asyncio.TimeoutError: return False + + +def aspawn( + f: Callable[..., Coroutine[Any, Any, None]], + args: tuple[Any, ...] = (), + name: str | None = None, +) -> asyncio.Task[None]: + """ + Equivalent to asyncio.create_task. + """ + return asyncio.create_task(f(*args), name=name) + + +def spawn( + f: Callable[..., Any], + args: tuple[Any, ...] = (), + name: str | None = None, +) -> threading.Thread: + """ + Equivalent to creating and running a daemon thread. + """ + t = threading.Thread(target=f, args=args, name=name, daemon=True) + t.start() + return t + + +async def agather(*tasks: asyncio.Task[Any], timeout: float | None = None) -> None: + """ + Equivalent to asyncio.gather or Thread.join() + """ + wait = asyncio.gather(*tasks) + try: + if timeout is not None: + await asyncio.wait_for(asyncio.shield(wait), timeout=timeout) + else: + await wait + except asyncio.TimeoutError: + pass + else: + return + + for t in tasks: + if t.done(): + continue + logger.warning("couldn't stop task %r within %s seconds", t.get_name(), timeout) + + +def gather(*tasks: threading.Thread, timeout: float | None = None) -> None: + """ + Equivalent to asyncio.gather or Thread.join() + """ + for t in tasks: + if not t.is_alive(): + continue + t.join(timeout) + if not t.is_alive(): + continue + logger.warning("couldn't stop thread %r within %s seconds", t.name, timeout) diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 0cde8ea77..4d55f2317 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -25,7 +25,7 @@ from .base import ConnectionAttempt, BasePool from .sched import Scheduler from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Deque -from ._acompat import Condition, Event, Lock +from ._acompat import Condition, Event, Lock, spawn, gather logger = logging.getLogger("psycopg.pool") @@ -340,24 +340,12 @@ class ConnectionPool(Generic[CT], BasePool): self._start_initial_tasks() def _start_workers(self) -> None: - self._sched_runner = threading.Thread( - target=self._sched.run, name=f"{self.name}-scheduler", daemon=True - ) + self._sched_runner = spawn(self._sched.run, name=f"{self.name}-scheduler") assert not self._workers for i in range(self.num_workers): - t = threading.Thread( - target=self.worker, - args=(self._tasks,), - name=f"{self.name}-worker-{i}", - daemon=True, - ) + t = spawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}") self._workers.append(t) - # The object state is complete. Start the worker threads - self._sched_runner.start() - for t in self._workers: - t.start() - def _start_initial_tasks(self) -> None: # populate the pool with initial min_size connections in background for i in range(self._nconns): @@ -399,7 +387,7 @@ class ConnectionPool(Generic[CT], BasePool): self, waiting_clients: Sequence["WaitingClient[CT]"] = (), connections: Sequence[CT] = (), - timeout: float = 0.0, + timeout: float | None = None, ) -> None: # Stop the scheduler self._sched.enter(0, None) @@ -420,18 +408,7 @@ class ConnectionPool(Generic[CT], BasePool): # Wait for the worker threads to terminate assert self._sched_runner is not None sched_runner, self._sched_runner = self._sched_runner, None - if timeout > 0: - for t in [sched_runner] + workers: - if not t.is_alive(): - continue - t.join(timeout) - if t.is_alive(): - logger.warning( - "couldn't stop thread %s in pool %r within %s seconds", - t, - self.name, - timeout, - ) + gather(sched_runner, *workers, timeout=timeout) def __enter__(self: _Self) -> _Self: self.open() diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index fe46fec24..fa2428e09 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -11,7 +11,7 @@ from time import monotonic from types import TracebackType from typing import Any, AsyncIterator, cast, Generic from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar -from asyncio import create_task, Task +from asyncio import Task from weakref import ref from contextlib import asynccontextmanager @@ -24,7 +24,7 @@ from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB from .base import ConnectionAttempt, BasePool from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Deque -from ._acompat import ACondition, AEvent, ALock +from ._acompat import ACondition, AEvent, ALock, aspawn, agather from .sched_async import AsyncScheduler logger = logging.getLogger("psycopg.pool") @@ -351,14 +351,9 @@ class AsyncConnectionPool(Generic[ACT], BasePool): self._start_initial_tasks() def _start_workers(self) -> None: - self._sched_runner = create_task( - self._sched.run(), name=f"{self.name}-scheduler" - ) + self._sched_runner = aspawn(self._sched.run, name=f"{self.name}-scheduler") for i in range(self.num_workers): - t = create_task( - self.worker(self._tasks), - name=f"{self.name}-worker-{i}", - ) + t = aspawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}") self._workers.append(t) def _start_initial_tasks(self) -> None: @@ -402,7 +397,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): self, waiting_clients: Sequence["AsyncClient[ACT]"] = (), connections: Sequence[ACT] = (), - timeout: float = 0.0, + timeout: float | None = None, ) -> None: # Stop the scheduler await self._sched.enter(0, None) @@ -423,18 +418,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Wait for the worker tasks to terminate assert self._sched_runner is not None sched_runner, self._sched_runner = self._sched_runner, None - wait = asyncio.gather(sched_runner, *workers) - try: - if timeout > 0: - await asyncio.wait_for(asyncio.shield(wait), timeout=timeout) - else: - await wait - except asyncio.TimeoutError: - logger.warning( - "couldn't stop pool %r tasks within %s seconds", - self.name, - timeout, - ) + await agather(sched_runner, *workers, timeout=timeout) async def __aenter__(self: _Self) -> _Self: await self.open()