From: Daniele Varrazzo Date: Wed, 4 Oct 2023 20:44:48 +0000 (+0200) Subject: refactor(pool): add psycopg_pool._acompat to ease async/sync differences X-Git-Tag: pool-3.2.0~12^2~26 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=1e1cad963c46f93fd35eef517d2ca348e31750d9;p=thirdparty%2Fpsycopg.git refactor(pool): add psycopg_pool._acompat to ease async/sync differences --- diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py new file mode 100644 index 000000000..9f59a64f9 --- /dev/null +++ b/psycopg_pool/psycopg_pool/_acompat.py @@ -0,0 +1,47 @@ +""" +Utilities to ease the differences between async and sync code. + +These object offer a similar interface between sync and async versions; the +script async_to_sync.py will replace the async names with the sync names +when generating the sync version. +""" + +# Copyright (C) 2023 The Psycopg Team + +import asyncio +import threading + +Event = threading.Event +Condition = threading.Condition +Lock = threading.RLock +ALock = asyncio.Lock + + +class AEvent(asyncio.Event): + """ + Subclass of asyncio.Event adding a wait with timeout like threading.Event. + + wait_timeout() is converted to wait() by async_to_sync. + """ + + async def wait_timeout(self, timeout: float) -> bool: + try: + await asyncio.wait_for(self.wait(), timeout) + return True + except asyncio.TimeoutError: + return False + + +class ACondition(asyncio.Condition): + """ + Subclass of asyncio.Condition adding a wait with timeout like threading.Condition. + + wait_timeout() is converted to wait() by async_to_sync. + """ + + async def wait_timeout(self, timeout: float) -> bool: + try: + await asyncio.wait_for(self.wait(), timeout) + return True + except asyncio.TimeoutError: + return False diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py index d0b33931e..6d39e82d9 100644 --- a/psycopg_pool/psycopg_pool/null_pool.py +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -5,7 +5,6 @@ Psycopg null connection pools # Copyright (C) 2022 The Psycopg Team import logging -import threading from typing import Any, cast, Dict, Optional, overload, Tuple, Type from psycopg import Connection @@ -16,6 +15,7 @@ from .abc import CT, ConnectionCB, ConnectFailedCB from .pool import ConnectionPool, AddConnection from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout +from ._acompat import Event logger = logging.getLogger("psycopg.pool") @@ -142,7 +142,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]): with self._lock: assert not self._pool_full_event - self._pool_full_event = threading.Event() + self._pool_full_event = Event() logger.info("waiting for pool %r initialization", self.name) self.run_task(AddConnection(self)) diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py index b1553ee5b..5cf4b0997 100644 --- a/psycopg_pool/psycopg_pool/null_pool_async.py +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -4,7 +4,6 @@ psycopg asynchronous null connection pool # Copyright (C) 2022 The Psycopg Team -import asyncio import logging from typing import Any, cast, Dict, Optional, overload, Type @@ -15,6 +14,7 @@ from psycopg.rows import TupleRow from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB from .errors import PoolTimeout, TooManyRequests from ._compat import ConnectionTimeout +from ._acompat import AEvent from .null_pool import _BaseNullConnectionPool from .pool_async import AsyncConnectionPool, AddConnection @@ -111,17 +111,13 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]) async with self._lock: assert not self._pool_full_event - self._pool_full_event = asyncio.Event() + self._pool_full_event = AEvent() logger.info("waiting for pool %r initialization", self.name) self.run_task(AddConnection(self)) - try: - await asyncio.wait_for(self._pool_full_event.wait(), timeout) - except asyncio.TimeoutError: + if not await self._pool_full_event.wait_timeout(timeout): await self.close() # stop all the tasks - raise PoolTimeout( - f"pool initialization incomplete after {timeout} sec" - ) from None + raise PoolTimeout(f"pool initialization incomplete after {timeout} sec") async with self._lock: assert self._pool_full_event diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index cbf3520fb..0cde8ea77 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -25,6 +25,7 @@ from .base import ConnectionAttempt, BasePool from .sched import Scheduler from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Deque +from ._acompat import Condition, Event, Lock logger = logging.getLogger("psycopg.pool") @@ -104,11 +105,11 @@ class ConnectionPool(Generic[CT], BasePool): self._reconnect_failed = reconnect_failed - self._lock = threading.RLock() + self._lock = Lock() self._waiting = Deque["WaitingClient[CT]"]() # to notify that the pool is full - self._pool_full_event: Optional[threading.Event] = None + self._pool_full_event: Optional[Event] = None self._sched = Scheduler() self._sched_runner: Optional[threading.Thread] = None @@ -160,7 +161,7 @@ class ConnectionPool(Generic[CT], BasePool): assert not self._pool_full_event if len(self._pool) >= self._min_size: return - self._pool_full_event = threading.Event() + self._pool_full_event = Event() logger.info("waiting for pool %r initialization", self.name) if not self._pool_full_event.wait(timeout): @@ -780,7 +781,7 @@ class WaitingClient(Generic[CT]): # message and it hasn't timed out yet, otherwise the pool may give a # connection to a client that has already timed out getconn(), which # will be lost. - self._cond = threading.Condition() + self._cond = Condition() def wait(self, timeout: float) -> CT: """Wait for a connection to be set and return it. diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 82ad0e562..fe46fec24 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -24,6 +24,7 @@ from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB from .base import ConnectionAttempt, BasePool from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Deque +from ._acompat import ACondition, AEvent, ALock from .sched_async import AsyncScheduler logger = logging.getLogger("psycopg.pool") @@ -105,14 +106,14 @@ class AsyncConnectionPool(Generic[ACT], BasePool): self._reconnect_failed = reconnect_failed # asyncio objects, created on open to attach them to the right loop. - self._lock: asyncio.Lock + self._lock: ALock self._sched: AsyncScheduler self._tasks: "asyncio.Queue[MaintenanceTask]" self._waiting = Deque["AsyncClient[ACT]"]() # to notify that the pool is full - self._pool_full_event: Optional[asyncio.Event] = None + self._pool_full_event: Optional[AEvent] = None self._sched_runner: Optional[Task[None]] = None self._workers: List[Task[None]] = [] @@ -154,16 +155,12 @@ class AsyncConnectionPool(Generic[ACT], BasePool): assert not self._pool_full_event if len(self._pool) >= self._min_size: return - self._pool_full_event = asyncio.Event() + self._pool_full_event = AEvent() logger.info("waiting for pool %r initialization", self.name) - try: - await asyncio.wait_for(self._pool_full_event.wait(), timeout) - except asyncio.TimeoutError: + if not await self._pool_full_event.wait_timeout(timeout): await self.close() # stop all the tasks - raise PoolTimeout( - f"pool initialization incomplete after {timeout} sec" - ) from None + raise PoolTimeout(f"pool initialization incomplete after {timeout} sec") async with self._lock: assert self._pool_full_event @@ -320,7 +317,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): try: self._lock except AttributeError: - self._lock = asyncio.Lock() + self._lock = ALock() async with self._lock: self._open() @@ -345,7 +342,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): try: self._lock except AttributeError: - self._lock = asyncio.Lock() + self._lock = ALock() self._closed = False self._opened = True @@ -782,7 +779,7 @@ class AsyncClient(Generic[ACT]): # message and it hasn't timed out yet, otherwise the pool may give a # connection to a client that has already timed out getconn(), which # will be lost. - self._cond = asyncio.Condition() + self._cond = ACondition() async def wait(self, timeout: float) -> ACT: """Wait for a connection to be set and return it. @@ -792,11 +789,10 @@ class AsyncClient(Generic[ACT]): async with self._cond: if not (self.conn or self.error): try: - await asyncio.wait_for(self._cond.wait(), timeout) - except asyncio.TimeoutError: - self.error = PoolTimeout( - f"couldn't get a connection after {timeout:.2f} sec" - ) + if not await self._cond.wait_timeout(timeout): + self.error = PoolTimeout( + f"couldn't get a connection after {timeout:.2f} sec" + ) except BaseException as ex: self.error = ex diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py index 2c6f3c0e8..4a2848615 100644 --- a/psycopg_pool/psycopg_pool/sched.py +++ b/psycopg_pool/psycopg_pool/sched.py @@ -20,9 +20,8 @@ from time import monotonic from heapq import heappush, heappop from typing import Any, Callable, List, Optional -from threading import RLock as Lock, Event - from ._task import Task +from ._acompat import Lock, Event logger = logging.getLogger(__name__) @@ -90,4 +89,4 @@ class Scheduler: ) else: # Block for the expected timeout or until a new task scheduled - self._event.wait(timeout=delay) + self._event.wait(delay) diff --git a/psycopg_pool/psycopg_pool/sched_async.py b/psycopg_pool/psycopg_pool/sched_async.py index fe9e443ff..db273dd1f 100644 --- a/psycopg_pool/psycopg_pool/sched_async.py +++ b/psycopg_pool/psycopg_pool/sched_async.py @@ -17,13 +17,8 @@ from time import monotonic from heapq import heappush, heappop from typing import Any, Callable, List, Optional -if True: # ASYNC - from asyncio import Event, Lock, TimeoutError, wait_for -else: - from threading import RLock as Lock, Event - - from ._task import Task +from ._acompat import ALock, AEvent logger = logging.getLogger(__name__) @@ -32,8 +27,8 @@ class AsyncScheduler: def __init__(self) -> None: """Initialize a new instance, passing the time and delay functions.""" self._queue: List[Task] = [] - self._lock = Lock() - self._event = Event() + self._lock = ALock() + self._event = AEvent() EMPTY_QUEUE_TIMEOUT = 600.0 @@ -91,10 +86,4 @@ class AsyncScheduler: ) else: # Block for the expected timeout or until a new task scheduled - if True: # ASYNC - try: - await wait_for(self._event.wait(), delay) - except TimeoutError: - pass - else: - self._event.wait(timeout=delay) + await self._event.wait_timeout(delay) diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 782a7fa23..dd854f014 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -153,6 +153,7 @@ class AsyncToSync(ast.NodeTransformer): class RenameAsyncToSync(ast.NodeTransformer): names_map = { "AEvent": "Event", + "ALock": "Lock", "AsyncClientCursor": "ClientCursor", "AsyncConnection": "Connection", "AsyncConnectionPool": "ConnectionPool",