--- /dev/null
+"""
+Utilities to ease the differences between async and sync code.
+
+These object offer a similar interface between sync and async versions; the
+script async_to_sync.py will replace the async names with the sync names
+when generating the sync version.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+import asyncio
+import threading
+
+Event = threading.Event
+Condition = threading.Condition
+Lock = threading.RLock
+ALock = asyncio.Lock
+
+
+class AEvent(asyncio.Event):
+ """
+ Subclass of asyncio.Event adding a wait with timeout like threading.Event.
+
+ wait_timeout() is converted to wait() by async_to_sync.
+ """
+
+ async def wait_timeout(self, timeout: float) -> bool:
+ try:
+ await asyncio.wait_for(self.wait(), timeout)
+ return True
+ except asyncio.TimeoutError:
+ return False
+
+
+class ACondition(asyncio.Condition):
+ """
+ Subclass of asyncio.Condition adding a wait with timeout like threading.Condition.
+
+ wait_timeout() is converted to wait() by async_to_sync.
+ """
+
+ async def wait_timeout(self, timeout: float) -> bool:
+ try:
+ await asyncio.wait_for(self.wait(), timeout)
+ return True
+ except asyncio.TimeoutError:
+ return False
# Copyright (C) 2022 The Psycopg Team
import logging
-import threading
from typing import Any, cast, Dict, Optional, overload, Tuple, Type
from psycopg import Connection
from .pool import ConnectionPool, AddConnection
from .errors import PoolTimeout, TooManyRequests
from ._compat import ConnectionTimeout
+from ._acompat import Event
logger = logging.getLogger("psycopg.pool")
with self._lock:
assert not self._pool_full_event
- self._pool_full_event = threading.Event()
+ self._pool_full_event = Event()
logger.info("waiting for pool %r initialization", self.name)
self.run_task(AddConnection(self))
# Copyright (C) 2022 The Psycopg Team
-import asyncio
import logging
from typing import Any, cast, Dict, Optional, overload, Type
from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB
from .errors import PoolTimeout, TooManyRequests
from ._compat import ConnectionTimeout
+from ._acompat import AEvent
from .null_pool import _BaseNullConnectionPool
from .pool_async import AsyncConnectionPool, AddConnection
async with self._lock:
assert not self._pool_full_event
- self._pool_full_event = asyncio.Event()
+ self._pool_full_event = AEvent()
logger.info("waiting for pool %r initialization", self.name)
self.run_task(AddConnection(self))
- try:
- await asyncio.wait_for(self._pool_full_event.wait(), timeout)
- except asyncio.TimeoutError:
+ if not await self._pool_full_event.wait_timeout(timeout):
await self.close() # stop all the tasks
- raise PoolTimeout(
- f"pool initialization incomplete after {timeout} sec"
- ) from None
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
async with self._lock:
assert self._pool_full_event
from .sched import Scheduler
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from ._compat import Deque
+from ._acompat import Condition, Event, Lock
logger = logging.getLogger("psycopg.pool")
self._reconnect_failed = reconnect_failed
- self._lock = threading.RLock()
+ self._lock = Lock()
self._waiting = Deque["WaitingClient[CT]"]()
# to notify that the pool is full
- self._pool_full_event: Optional[threading.Event] = None
+ self._pool_full_event: Optional[Event] = None
self._sched = Scheduler()
self._sched_runner: Optional[threading.Thread] = None
assert not self._pool_full_event
if len(self._pool) >= self._min_size:
return
- self._pool_full_event = threading.Event()
+ self._pool_full_event = Event()
logger.info("waiting for pool %r initialization", self.name)
if not self._pool_full_event.wait(timeout):
# message and it hasn't timed out yet, otherwise the pool may give a
# connection to a client that has already timed out getconn(), which
# will be lost.
- self._cond = threading.Condition()
+ self._cond = Condition()
def wait(self, timeout: float) -> CT:
"""Wait for a connection to be set and return it.
from .base import ConnectionAttempt, BasePool
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from ._compat import Deque
+from ._acompat import ACondition, AEvent, ALock
from .sched_async import AsyncScheduler
logger = logging.getLogger("psycopg.pool")
self._reconnect_failed = reconnect_failed
# asyncio objects, created on open to attach them to the right loop.
- self._lock: asyncio.Lock
+ self._lock: ALock
self._sched: AsyncScheduler
self._tasks: "asyncio.Queue[MaintenanceTask]"
self._waiting = Deque["AsyncClient[ACT]"]()
# to notify that the pool is full
- self._pool_full_event: Optional[asyncio.Event] = None
+ self._pool_full_event: Optional[AEvent] = None
self._sched_runner: Optional[Task[None]] = None
self._workers: List[Task[None]] = []
assert not self._pool_full_event
if len(self._pool) >= self._min_size:
return
- self._pool_full_event = asyncio.Event()
+ self._pool_full_event = AEvent()
logger.info("waiting for pool %r initialization", self.name)
- try:
- await asyncio.wait_for(self._pool_full_event.wait(), timeout)
- except asyncio.TimeoutError:
+ if not await self._pool_full_event.wait_timeout(timeout):
await self.close() # stop all the tasks
- raise PoolTimeout(
- f"pool initialization incomplete after {timeout} sec"
- ) from None
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
async with self._lock:
assert self._pool_full_event
try:
self._lock
except AttributeError:
- self._lock = asyncio.Lock()
+ self._lock = ALock()
async with self._lock:
self._open()
try:
self._lock
except AttributeError:
- self._lock = asyncio.Lock()
+ self._lock = ALock()
self._closed = False
self._opened = True
# message and it hasn't timed out yet, otherwise the pool may give a
# connection to a client that has already timed out getconn(), which
# will be lost.
- self._cond = asyncio.Condition()
+ self._cond = ACondition()
async def wait(self, timeout: float) -> ACT:
"""Wait for a connection to be set and return it.
async with self._cond:
if not (self.conn or self.error):
try:
- await asyncio.wait_for(self._cond.wait(), timeout)
- except asyncio.TimeoutError:
- self.error = PoolTimeout(
- f"couldn't get a connection after {timeout:.2f} sec"
- )
+ if not await self._cond.wait_timeout(timeout):
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout:.2f} sec"
+ )
except BaseException as ex:
self.error = ex
from heapq import heappush, heappop
from typing import Any, Callable, List, Optional
-from threading import RLock as Lock, Event
-
from ._task import Task
+from ._acompat import Lock, Event
logger = logging.getLogger(__name__)
)
else:
# Block for the expected timeout or until a new task scheduled
- self._event.wait(timeout=delay)
+ self._event.wait(delay)
from heapq import heappush, heappop
from typing import Any, Callable, List, Optional
-if True: # ASYNC
- from asyncio import Event, Lock, TimeoutError, wait_for
-else:
- from threading import RLock as Lock, Event
-
-
from ._task import Task
+from ._acompat import ALock, AEvent
logger = logging.getLogger(__name__)
def __init__(self) -> None:
"""Initialize a new instance, passing the time and delay functions."""
self._queue: List[Task] = []
- self._lock = Lock()
- self._event = Event()
+ self._lock = ALock()
+ self._event = AEvent()
EMPTY_QUEUE_TIMEOUT = 600.0
)
else:
# Block for the expected timeout or until a new task scheduled
- if True: # ASYNC
- try:
- await wait_for(self._event.wait(), delay)
- except TimeoutError:
- pass
- else:
- self._event.wait(timeout=delay)
+ await self._event.wait_timeout(delay)
class RenameAsyncToSync(ast.NodeTransformer):
names_map = {
"AEvent": "Event",
+ "ALock": "Lock",
"AsyncClientCursor": "ClientCursor",
"AsyncConnection": "Connection",
"AsyncConnectionPool": "ConnectionPool",