From: Daniele Varrazzo Date: Wed, 4 Oct 2023 22:21:58 +0000 (+0200) Subject: refactor(pool): add Queue/AQueue compatibility objects X-Git-Tag: pool-3.2.0~12^2~24 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=245f3d363380537ddcd1ea4f7e44f3a31c3e8925;p=thirdparty%2Fpsycopg.git refactor(pool): add Queue/AQueue compatibility objects --- diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py index ffc2f1a3b..620fad687 100644 --- a/psycopg_pool/psycopg_pool/_acompat.py +++ b/psycopg_pool/psycopg_pool/_acompat.py @@ -8,17 +8,32 @@ when generating the sync version. # Copyright (C) 2023 The Psycopg Team +import queue import asyncio import logging import threading -from typing import Any, Callable, Coroutine +from typing import Any, Callable, Coroutine, TypeVar +logger = logging.getLogger("psycopg.pool") +T = TypeVar("T") + +# Re-exports Event = threading.Event Condition = threading.Condition Lock = threading.RLock ALock = asyncio.Lock -logger = logging.getLogger("psycopg.pool") + +class Queue(queue.Queue[T]): + """ + A Queue subclass with an interruptible get() method. + """ + + def get(self, block: bool = True, timeout: float | None = None) -> T: + # Always specify a timeout to make the wait interruptible. + if timeout is None: + timeout = 24.0 * 60.0 * 60.0 + return super().get(block, timeout) class AEvent(asyncio.Event): @@ -51,6 +66,10 @@ class ACondition(asyncio.Condition): return False +class AQueue(asyncio.Queue[T]): + pass + + def aspawn( f: Callable[..., Coroutine[Any, Any, None]], args: tuple[Any, ...] = (), diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 4d55f2317..ccea754ee 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -8,7 +8,6 @@ import logging import threading from abc import ABC, abstractmethod from time import monotonic -from queue import Queue, Empty from types import TracebackType from typing import Any, cast, Dict, Generic, Iterator, List from typing import Optional, overload, Sequence, Type, TypeVar @@ -25,7 +24,7 @@ from .base import ConnectionAttempt, BasePool from .sched import Scheduler from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Deque -from ._acompat import Condition, Event, Lock, spawn, gather +from ._acompat import Condition, Event, Lock, Queue, spawn, gather logger = logging.getLogger("psycopg.pool") @@ -496,8 +495,6 @@ class ConnectionPool(Generic[CT], BasePool): """Run a maintenance task in a worker thread in the future.""" self._sched.enter(delay, task.tick) - _WORKER_TIMEOUT = 60.0 - @classmethod def worker(cls, q: "Queue[MaintenanceTask]") -> None: """Runner to execute pending maintenance task. @@ -507,14 +504,8 @@ class ConnectionPool(Generic[CT], BasePool): Block on the queue *q*, run a task received. Finish running if a StopWorker is received. """ - # Don't make all the workers time out at the same moment - timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1) while True: - # Use a timeout to make the wait interruptible - try: - task = q.get(timeout=timeout) - except Empty: - continue + task = q.get() if isinstance(task, StopWorker): logger.debug( diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index fa2428e09..bf3a6f5f3 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -24,7 +24,7 @@ from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB from .base import ConnectionAttempt, BasePool from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Deque -from ._acompat import ACondition, AEvent, ALock, aspawn, agather +from ._acompat import ACondition, AEvent, ALock, AQueue, aspawn, agather from .sched_async import AsyncScheduler logger = logging.getLogger("psycopg.pool") @@ -108,7 +108,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # asyncio objects, created on open to attach them to the right loop. self._lock: ALock self._sched: AsyncScheduler - self._tasks: "asyncio.Queue[MaintenanceTask]" + self._tasks: AQueue["MaintenanceTask"] self._waiting = Deque["AsyncClient[ACT]"]() @@ -336,7 +336,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Create these objects now to attach them to the right loop. # See #219 - self._tasks = asyncio.Queue() + self._tasks = AQueue() self._sched = AsyncScheduler() # This has been most likely, but not necessarily, created in `open()`. try: @@ -510,7 +510,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): await self._sched.enter(delay, task.tick) @classmethod - async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None: + async def worker(cls, q: AQueue["MaintenanceTask"]) -> None: """Runner to execute pending maintenance task. The function is designed to run as a task.