]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): add Queue/AQueue compatibility objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Oct 2023 22:21:58 +0000 (00:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg_pool/psycopg_pool/_acompat.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py

index ffc2f1a3b13d43374dea5672064511d873da1c69..620fad687bc3d787cd20fb71ef7843e98257c172 100644 (file)
@@ -8,17 +8,32 @@ when generating the sync version.
 
 # Copyright (C) 2023 The Psycopg Team
 
+import queue
 import asyncio
 import logging
 import threading
-from typing import Any, Callable, Coroutine
+from typing import Any, Callable, Coroutine, TypeVar
 
+logger = logging.getLogger("psycopg.pool")
+T = TypeVar("T")
+
+# Re-exports
 Event = threading.Event
 Condition = threading.Condition
 Lock = threading.RLock
 ALock = asyncio.Lock
 
-logger = logging.getLogger("psycopg.pool")
+
+class Queue(queue.Queue[T]):
+    """
+    A Queue subclass with an interruptible get() method.
+    """
+
+    def get(self, block: bool = True, timeout: float | None = None) -> T:
+        # Always specify a timeout to make the wait interruptible.
+        if timeout is None:
+            timeout = 24.0 * 60.0 * 60.0
+        return super().get(block, timeout)
 
 
 class AEvent(asyncio.Event):
@@ -51,6 +66,10 @@ class ACondition(asyncio.Condition):
             return False
 
 
+class AQueue(asyncio.Queue[T]):
+    pass
+
+
 def aspawn(
     f: Callable[..., Coroutine[Any, Any, None]],
     args: tuple[Any, ...] = (),
index 4d55f2317c1c97b877f2ddffff8331866c2146a8..ccea754eec572ee0e1c925d661d006b8363b9af1 100644 (file)
@@ -8,7 +8,6 @@ import logging
 import threading
 from abc import ABC, abstractmethod
 from time import monotonic
-from queue import Queue, Empty
 from types import TracebackType
 from typing import Any, cast, Dict, Generic, Iterator, List
 from typing import Optional, overload, Sequence, Type, TypeVar
@@ -25,7 +24,7 @@ from .base import ConnectionAttempt, BasePool
 from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
-from ._acompat import Condition, Event, Lock, spawn, gather
+from ._acompat import Condition, Event, Lock, Queue, spawn, gather
 
 logger = logging.getLogger("psycopg.pool")
 
@@ -496,8 +495,6 @@ class ConnectionPool(Generic[CT], BasePool):
         """Run a maintenance task in a worker thread in the future."""
         self._sched.enter(delay, task.tick)
 
-    _WORKER_TIMEOUT = 60.0
-
     @classmethod
     def worker(cls, q: "Queue[MaintenanceTask]") -> None:
         """Runner to execute pending maintenance task.
@@ -507,14 +504,8 @@ class ConnectionPool(Generic[CT], BasePool):
         Block on the queue *q*, run a task received. Finish running if a
         StopWorker is received.
         """
-        # Don't make all the workers time out at the same moment
-        timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1)
         while True:
-            # Use a timeout to make the wait interruptible
-            try:
-                task = q.get(timeout=timeout)
-            except Empty:
-                continue
+            task = q.get()
 
             if isinstance(task, StopWorker):
                 logger.debug(
index fa2428e09eac4700e997e4e3eedb68bd3a06bc67..bf3a6f5f3db394c6ff983a2ec397bf60965c84e8 100644 (file)
@@ -24,7 +24,7 @@ from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB
 from .base import ConnectionAttempt, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
-from ._acompat import ACondition, AEvent, ALock, aspawn, agather
+from ._acompat import ACondition, AEvent, ALock, AQueue, aspawn, agather
 from .sched_async import AsyncScheduler
 
 logger = logging.getLogger("psycopg.pool")
@@ -108,7 +108,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         # asyncio objects, created on open to attach them to the right loop.
         self._lock: ALock
         self._sched: AsyncScheduler
-        self._tasks: "asyncio.Queue[MaintenanceTask]"
+        self._tasks: AQueue["MaintenanceTask"]
 
         self._waiting = Deque["AsyncClient[ACT]"]()
 
@@ -336,7 +336,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         # Create these objects now to attach them to the right loop.
         # See #219
-        self._tasks = asyncio.Queue()
+        self._tasks = AQueue()
         self._sched = AsyncScheduler()
         # This has been most likely, but not necessarily, created in `open()`.
         try:
@@ -510,7 +510,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         await self._sched.enter(delay, task.tick)
 
     @classmethod
-    async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None:
+    async def worker(cls, q: AQueue["MaintenanceTask"]) -> None:
         """Runner to execute pending maintenance task.
 
         The function is designed to run as a task.