# Copyright (C) 2023 The Psycopg Team
import asyncio
+import logging
import threading
+from typing import Any, Callable, Coroutine
Event = threading.Event
Condition = threading.Condition
Lock = threading.RLock
ALock = asyncio.Lock
+logger = logging.getLogger("psycopg.pool")
+
class AEvent(asyncio.Event):
    """
    An asyncio.Event with a threading.Event-like wait with timeout.
    """
    async def wait_timeout(self, timeout: float) -> bool:
        try:
            await asyncio.wait_for(self.wait(), timeout)
            return True
        except asyncio.TimeoutError:
            return False
+
+
+def aspawn(
+ f: Callable[..., Coroutine[Any, Any, None]],
+ args: tuple[Any, ...] = (),
+ name: str | None = None,
+) -> asyncio.Task[None]:
+ """
+ Equivalent to asyncio.create_task.
+ """
+ return asyncio.create_task(f(*args), name=name)
+
+
+def spawn(
+ f: Callable[..., Any],
+ args: tuple[Any, ...] = (),
+ name: str | None = None,
+) -> threading.Thread:
+ """
+ Equivalent to creating and running a daemon thread.
+ """
+ t = threading.Thread(target=f, args=args, name=name, daemon=True)
+ t.start()
+ return t
+
+
+async def agather(*tasks: asyncio.Task[Any], timeout: float | None = None) -> None:
+ """
+ Equivalent to asyncio.gather or Thread.join()
+ """
+ wait = asyncio.gather(*tasks)
+ try:
+ if timeout is not None:
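+            # shield() stops wait_for() from cancelling the gathered tasks on
+            # timeout, so the loop below can report the ones still running.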
+ await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+ else:
+ await wait
+ except asyncio.TimeoutError:
+ pass
+ else:
+ return
+
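+    # Timed out: report which tasks didn't terminate.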
+ for t in tasks:
+ if t.done():
+ continue
+ logger.warning("couldn't stop task %r within %s seconds", t.get_name(), timeout)
+
+
+def gather(*tasks: threading.Thread, timeout: float | None = None) -> None:
+ """
+ Equivalent to asyncio.gather or Thread.join()
+ """
+ for t in tasks:
+ if not t.is_alive():
+ continue
+ t.join(timeout)
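+        # join() returns silently on timeout; re-check is_alive() to know
+        # whether the thread actually terminated.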
+ if not t.is_alive():
+ continue
+ logger.warning("couldn't stop thread %r within %s seconds", t.name, timeout)
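+
+
+# Illustrative sketch only (not part of the patch): the sync and async helpers
+# share the same call shape, so the pool code below keeps one structure in
+# both flavours. do_work/do_work_async and q are hypothetical placeholders.
+#
+#   t = spawn(do_work, args=(q,), name="worker-0")
+#   gather(t, timeout=5.0)
+#
+#   task = aspawn(do_work_async, args=(q,), name="worker-0")
+#   await agather(task, timeout=5.0)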
from .sched import Scheduler
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from ._compat import Deque
-from ._acompat import Condition, Event, Lock
+from ._acompat import Condition, Event, Lock, spawn, gather
logger = logging.getLogger("psycopg.pool")
self._start_initial_tasks()
def _start_workers(self) -> None:
- self._sched_runner = threading.Thread(
- target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
- )
+ self._sched_runner = spawn(self._sched.run, name=f"{self.name}-scheduler")
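+        # spawn() creates and starts the daemon thread in a single step, so no
+        # separate Thread.start() call is needed any more.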
assert not self._workers
for i in range(self.num_workers):
- t = threading.Thread(
- target=self.worker,
- args=(self._tasks,),
- name=f"{self.name}-worker-{i}",
- daemon=True,
- )
+ t = spawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}")
self._workers.append(t)
- # The object state is complete. Start the worker threads
- self._sched_runner.start()
- for t in self._workers:
- t.start()
-
def _start_initial_tasks(self) -> None:
# populate the pool with initial min_size connections in background
for i in range(self._nconns):
self,
waiting_clients: Sequence["WaitingClient[CT]"] = (),
connections: Sequence[CT] = (),
- timeout: float = 0.0,
+ timeout: float | None = None,
) -> None:
# Stop the scheduler
self._sched.enter(0, None)
# Wait for the worker threads to terminate
assert self._sched_runner is not None
sched_runner, self._sched_runner = self._sched_runner, None
- if timeout > 0:
- for t in [sched_runner] + workers:
- if not t.is_alive():
- continue
- t.join(timeout)
- if t.is_alive():
- logger.warning(
- "couldn't stop thread %s in pool %r within %s seconds",
- t,
- self.name,
- timeout,
- )
+ gather(sched_runner, *workers, timeout=timeout)
def __enter__(self: _Self) -> _Self:
self.open()
from types import TracebackType
from typing import Any, AsyncIterator, cast, Generic
from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar
-from asyncio import create_task, Task
+from asyncio import Task
from weakref import ref
from contextlib import asynccontextmanager
from .base import ConnectionAttempt, BasePool
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from ._compat import Deque
-from ._acompat import ACondition, AEvent, ALock
+from ._acompat import ACondition, AEvent, ALock, aspawn, agather
from .sched_async import AsyncScheduler
logger = logging.getLogger("psycopg.pool")
self._start_initial_tasks()
def _start_workers(self) -> None:
- self._sched_runner = create_task(
- self._sched.run(), name=f"{self.name}-scheduler"
- )
+ self._sched_runner = aspawn(self._sched.run, name=f"{self.name}-scheduler")
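+        # aspawn() wraps the coroutine in a Task scheduled on the running loop.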
for i in range(self.num_workers):
- t = create_task(
- self.worker(self._tasks),
- name=f"{self.name}-worker-{i}",
- )
+ t = aspawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}")
self._workers.append(t)
def _start_initial_tasks(self) -> None:
self,
waiting_clients: Sequence["AsyncClient[ACT]"] = (),
connections: Sequence[ACT] = (),
- timeout: float = 0.0,
+ timeout: float | None = None,
) -> None:
# Stop the scheduler
await self._sched.enter(0, None)
# Wait for the worker tasks to terminate
assert self._sched_runner is not None
sched_runner, self._sched_runner = self._sched_runner, None
- wait = asyncio.gather(sched_runner, *workers)
- try:
- if timeout > 0:
- await asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
- else:
- await wait
- except asyncio.TimeoutError:
- logger.warning(
- "couldn't stop pool %r tasks within %s seconds",
- self.name,
- timeout,
- )
+ await agather(sched_runner, *workers, timeout=timeout)
async def __aenter__(self: _Self) -> _Self:
await self.open()