git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): add spawn/gather async compat functions
author     Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Wed, 4 Oct 2023 22:01:24 +0000 (00:01 +0200)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg_pool/psycopg_pool/_acompat.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py

index 9f59a64f92cf97f241c8d98a8bd0fc869f5f66b4..ffc2f1a3b13d43374dea5672064511d873da1c69 100644 (file)
@@ -9,13 +9,17 @@ when generating the sync version.
 # Copyright (C) 2023 The Psycopg Team
 
 import asyncio
+import logging
 import threading
+from typing import Any, Callable, Coroutine
 
 Event = threading.Event
 Condition = threading.Condition
 Lock = threading.RLock
 ALock = asyncio.Lock
 
+logger = logging.getLogger("psycopg.pool")
+
 
 class AEvent(asyncio.Event):
     """
@@ -45,3 +49,61 @@ class ACondition(asyncio.Condition):
             return True
         except asyncio.TimeoutError:
             return False
+
+
+def aspawn(
+    f: Callable[..., Coroutine[Any, Any, None]],
+    args: tuple[Any, ...] = (),
+    name: str | None = None,
+) -> asyncio.Task[None]:
+    """
+    Equivalent to asyncio.create_task.
+    """
+    return asyncio.create_task(f(*args), name=name)
+
+
+def spawn(
+    f: Callable[..., Any],
+    args: tuple[Any, ...] = (),
+    name: str | None = None,
+) -> threading.Thread:
+    """
+    Equivalent to creating and running a daemon thread.
+    """
+    t = threading.Thread(target=f, args=args, name=name, daemon=True)
+    t.start()
+    return t
+
+
+async def agather(*tasks: asyncio.Task[Any], timeout: float | None = None) -> None:
+    """
+    Equivalent to asyncio.gather() or Thread.join().
+    """
+    wait = asyncio.gather(*tasks)
+    try:
+        if timeout is not None:
+            await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+        else:
+            await wait
+    except asyncio.TimeoutError:
+        pass
+    else:
+        return
+
+    for t in tasks:
+        if t.done():
+            continue
+        logger.warning("couldn't stop task %r within %s seconds", t.get_name(), timeout)
+
+
+def gather(*tasks: threading.Thread, timeout: float | None = None) -> None:
+    """
+    Equivalent to asyncio.gather() or Thread.join().
+    """
+    for t in tasks:
+        if not t.is_alive():
+            continue
+        t.join(timeout)
+        if not t.is_alive():
+            continue
+        logger.warning("couldn't stop thread %r within %s seconds", t.name, timeout)
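
The hunk above is the whole compat layer added by this commit: spawn()/gather() wrap daemon threads, aspawn()/agather() wrap asyncio tasks, and both pairs share the same call shape so the sync pool can be generated from the async source (see the hunk context "when generating the sync version"). A minimal usage sketch, not part of the commit; the work/awork functions, the worker names, and the 5-second timeout are invented for illustration:

    # Hypothetical demo of the new helpers; not part of the commit.
    import asyncio
    from psycopg_pool._acompat import spawn, gather, aspawn, agather

    def work(n: int) -> None:
        print(f"thread worker {n} running")

    async def awork(n: int) -> None:
        print(f"task worker {n} running")

    # Sync: spawn() starts a daemon thread immediately; gather() joins with an
    # optional timeout and logs a warning for any thread still alive afterwards.
    threads = [spawn(work, args=(i,), name=f"demo-worker-{i}") for i in range(2)]
    gather(*threads, timeout=5.0)

    # Async: same call shape, with asyncio tasks instead of threads.
    async def main() -> None:
        tasks = [aspawn(awork, args=(i,), name=f"demo-worker-{i}") for i in range(2)]
        await agather(*tasks, timeout=5.0)

    asyncio.run(main())
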
index 0cde8ea770b24c81b9d79f95b5500b63beb6ad3e..4d55f2317c1c97b877f2ddffff8331866c2146a8 100644 (file)
@@ -25,7 +25,7 @@ from .base import ConnectionAttempt, BasePool
 from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
-from ._acompat import Condition, Event, Lock
+from ._acompat import Condition, Event, Lock, spawn, gather
 
 logger = logging.getLogger("psycopg.pool")
 
@@ -340,24 +340,12 @@ class ConnectionPool(Generic[CT], BasePool):
         self._start_initial_tasks()
 
     def _start_workers(self) -> None:
-        self._sched_runner = threading.Thread(
-            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
-        )
+        self._sched_runner = spawn(self._sched.run, name=f"{self.name}-scheduler")
         assert not self._workers
         for i in range(self.num_workers):
-            t = threading.Thread(
-                target=self.worker,
-                args=(self._tasks,),
-                name=f"{self.name}-worker-{i}",
-                daemon=True,
-            )
+            t = spawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}")
             self._workers.append(t)
 
-        # The object state is complete. Start the worker threads
-        self._sched_runner.start()
-        for t in self._workers:
-            t.start()
-
     def _start_initial_tasks(self) -> None:
         # populate the pool with initial min_size connections in background
         for i in range(self._nconns):
@@ -399,7 +387,7 @@ class ConnectionPool(Generic[CT], BasePool):
         self,
         waiting_clients: Sequence["WaitingClient[CT]"] = (),
         connections: Sequence[CT] = (),
-        timeout: float = 0.0,
+        timeout: float | None = None,
     ) -> None:
         # Stop the scheduler
         self._sched.enter(0, None)
@@ -420,18 +408,7 @@ class ConnectionPool(Generic[CT], BasePool):
         # Wait for the worker threads to terminate
         assert self._sched_runner is not None
         sched_runner, self._sched_runner = self._sched_runner, None
-        if timeout > 0:
-            for t in [sched_runner] + workers:
-                if not t.is_alive():
-                    continue
-                t.join(timeout)
-                if t.is_alive():
-                    logger.warning(
-                        "couldn't stop thread %s in pool %r within %s seconds",
-                        t,
-                        self.name,
-                        timeout,
-                    )
+        gather(sched_runner, *workers, timeout=timeout)
 
     def __enter__(self: _Self) -> _Self:
         self.open()
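
Note the semantic shift in the _stop_workers() hunk above: the old code skipped joining the threads entirely unless timeout > 0, whereas the new default timeout=None makes gather() join each thread with no deadline; with a finite timeout the per-thread warning is still emitted, now without the pool name. A rough sketch of that warning path, with a made-up stubborn() worker and a 0.1-second timeout:

    # Hypothetical illustration of gather()'s warning path; not part of the commit.
    import time
    from psycopg_pool._acompat import spawn, gather

    def stubborn() -> None:
        time.sleep(10)  # stands in for a worker that ignores the stop request

    t = spawn(stubborn, name="demo-stubborn")
    gather(t, timeout=0.1)
    # logs: couldn't stop thread 'demo-stubborn' within 0.1 seconds
    # (the thread is a daemon, so the interpreter can still exit)
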
index fe46fec24758da311ab06c76e437f4f163474896..fa2428e09eac4700e997e4e3eedb68bd3a06bc67 100644 (file)
@@ -11,7 +11,7 @@ from time import monotonic
 from types import TracebackType
 from typing import Any, AsyncIterator, cast, Generic
 from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar
-from asyncio import create_task, Task
+from asyncio import Task
 from weakref import ref
 from contextlib import asynccontextmanager
 
@@ -24,7 +24,7 @@ from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB
 from .base import ConnectionAttempt, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
-from ._acompat import ACondition, AEvent, ALock
+from ._acompat import ACondition, AEvent, ALock, aspawn, agather
 from .sched_async import AsyncScheduler
 
 logger = logging.getLogger("psycopg.pool")
@@ -351,14 +351,9 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         self._start_initial_tasks()
 
     def _start_workers(self) -> None:
-        self._sched_runner = create_task(
-            self._sched.run(), name=f"{self.name}-scheduler"
-        )
+        self._sched_runner = aspawn(self._sched.run, name=f"{self.name}-scheduler")
         for i in range(self.num_workers):
-            t = create_task(
-                self.worker(self._tasks),
-                name=f"{self.name}-worker-{i}",
-            )
+            t = aspawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}")
             self._workers.append(t)
 
     def _start_initial_tasks(self) -> None:
@@ -402,7 +397,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         self,
         waiting_clients: Sequence["AsyncClient[ACT]"] = (),
         connections: Sequence[ACT] = (),
-        timeout: float = 0.0,
+        timeout: float | None = None,
     ) -> None:
         # Stop the scheduler
         await self._sched.enter(0, None)
@@ -423,18 +418,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         # Wait for the worker tasks to terminate
         assert self._sched_runner is not None
         sched_runner, self._sched_runner = self._sched_runner, None
-        wait = asyncio.gather(sched_runner, *workers)
-        try:
-            if timeout > 0:
-                await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
-            else:
-                await wait
-        except asyncio.TimeoutError:
-            logger.warning(
-                "couldn't stop pool %r tasks within %s seconds",
-                self.name,
-                timeout,
-            )
+        await agather(sched_runner, *workers, timeout=timeout)
 
     async def __aenter__(self: _Self) -> _Self:
         await self.open()
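
As in the sync pool, the inline shutdown logic moves into the compat helper. agather() keeps the old shield-then-wait_for shape, so a timeout abandons the worker tasks rather than cancelling them, but the warning is now emitted per task (via Task.get_name()) instead of once per pool. A rough sketch of that behaviour, with a made-up slow() coroutine and a 0.1-second timeout:

    # Hypothetical illustration of agather() timing out; not part of the commit.
    import asyncio
    from psycopg_pool._acompat import aspawn, agather

    async def slow() -> None:
        await asyncio.sleep(10)

    async def main() -> None:
        t = aspawn(slow, name="demo-slow")
        await agather(t, timeout=0.1)
        # logs: couldn't stop task 'demo-slow' within 0.1 seconds
        print(t.done())  # False: the gather was shielded, so the task was not cancelled
        t.cancel()       # cancel the leftover demo task before the loop closes

    asyncio.run(main())
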