]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): add psycopg_pool._acompat to ease async/sync differences
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Oct 2023 20:44:48 +0000 (22:44 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg_pool/psycopg_pool/_acompat.py [new file with mode: 0644]
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
psycopg_pool/psycopg_pool/sched.py
psycopg_pool/psycopg_pool/sched_async.py
tools/async_to_sync.py

diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py
new file mode 100644 (file)
index 0000000..9f59a64
--- /dev/null
@@ -0,0 +1,47 @@
+"""
+Utilities to ease the differences between async and sync code.
+
+These object offer a similar interface between sync and async versions; the
+script async_to_sync.py will replace the async names with the sync names
+when generating the sync version.
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+import asyncio
+import threading
+
+Event = threading.Event
+Condition = threading.Condition
+Lock = threading.RLock
+ALock = asyncio.Lock
+
+
+class AEvent(asyncio.Event):
+    """
+    Subclass of asyncio.Event adding a wait with timeout like threading.Event.
+
+    wait_timeout() is converted to wait() by async_to_sync.
+    """
+
+    async def wait_timeout(self, timeout: float) -> bool:
+        try:
+            await asyncio.wait_for(self.wait(), timeout)
+            return True
+        except asyncio.TimeoutError:
+            return False
+
+
+class ACondition(asyncio.Condition):
+    """
+    Subclass of asyncio.Condition adding a wait with timeout like threading.Condition.
+
+    wait_timeout() is converted to wait() by async_to_sync.
+    """
+
+    async def wait_timeout(self, timeout: float) -> bool:
+        try:
+            await asyncio.wait_for(self.wait(), timeout)
+            return True
+        except asyncio.TimeoutError:
+            return False
index d0b33931ea3a086ad1e47908f363952bb2df6fab..6d39e82d9b465034f6c4e2c6b68f4b5831dbb0c7 100644 (file)
@@ -5,7 +5,6 @@ Psycopg null connection pools
 # Copyright (C) 2022 The Psycopg Team
 
 import logging
-import threading
 from typing import Any, cast, Dict, Optional, overload, Tuple, Type
 
 from psycopg import Connection
@@ -16,6 +15,7 @@ from .abc import CT, ConnectionCB, ConnectFailedCB
 from .pool import ConnectionPool, AddConnection
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
+from ._acompat import Event
 
 logger = logging.getLogger("psycopg.pool")
 
@@ -142,7 +142,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
 
         with self._lock:
             assert not self._pool_full_event
-            self._pool_full_event = threading.Event()
+            self._pool_full_event = Event()
 
         logger.info("waiting for pool %r initialization", self.name)
         self.run_task(AddConnection(self))
index b1553ee5bb572adf2dce62006c3bba394429af25..5cf4b09975aaa80c6890231f6cda308305b87aea 100644 (file)
@@ -4,7 +4,6 @@ psycopg asynchronous null connection pool
 
 # Copyright (C) 2022 The Psycopg Team
 
-import asyncio
 import logging
 from typing import Any, cast, Dict, Optional, overload, Type
 
@@ -15,6 +14,7 @@ from psycopg.rows import TupleRow
 from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB
 from .errors import PoolTimeout, TooManyRequests
 from ._compat import ConnectionTimeout
+from ._acompat import AEvent
 from .null_pool import _BaseNullConnectionPool
 from .pool_async import AsyncConnectionPool, AddConnection
 
@@ -111,17 +111,13 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT])
 
         async with self._lock:
             assert not self._pool_full_event
-            self._pool_full_event = asyncio.Event()
+            self._pool_full_event = AEvent()
 
         logger.info("waiting for pool %r initialization", self.name)
         self.run_task(AddConnection(self))
-        try:
-            await asyncio.wait_for(self._pool_full_event.wait(), timeout)
-        except asyncio.TimeoutError:
+        if not await self._pool_full_event.wait_timeout(timeout):
             await self.close()  # stop all the tasks
-            raise PoolTimeout(
-                f"pool initialization incomplete after {timeout} sec"
-            ) from None
+            raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
 
         async with self._lock:
             assert self._pool_full_event
index cbf3520fb2374aa5f4f68500f4f6514b8bdc7447..0cde8ea770b24c81b9d79f95b5500b63beb6ad3e 100644 (file)
@@ -25,6 +25,7 @@ from .base import ConnectionAttempt, BasePool
 from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
+from ._acompat import Condition, Event, Lock
 
 logger = logging.getLogger("psycopg.pool")
 
@@ -104,11 +105,11 @@ class ConnectionPool(Generic[CT], BasePool):
 
         self._reconnect_failed = reconnect_failed
 
-        self._lock = threading.RLock()
+        self._lock = Lock()
         self._waiting = Deque["WaitingClient[CT]"]()
 
         # to notify that the pool is full
-        self._pool_full_event: Optional[threading.Event] = None
+        self._pool_full_event: Optional[Event] = None
 
         self._sched = Scheduler()
         self._sched_runner: Optional[threading.Thread] = None
@@ -160,7 +161,7 @@ class ConnectionPool(Generic[CT], BasePool):
             assert not self._pool_full_event
             if len(self._pool) >= self._min_size:
                 return
-            self._pool_full_event = threading.Event()
+            self._pool_full_event = Event()
 
         logger.info("waiting for pool %r initialization", self.name)
         if not self._pool_full_event.wait(timeout):
@@ -780,7 +781,7 @@ class WaitingClient(Generic[CT]):
         # message and it hasn't timed out yet, otherwise the pool may give a
         # connection to a client that has already timed out getconn(), which
         # will be lost.
-        self._cond = threading.Condition()
+        self._cond = Condition()
 
     def wait(self, timeout: float) -> CT:
         """Wait for a connection to be set and return it.
index 82ad0e5625a794c05481bbeb9413b749fb74ef46..fe46fec24758da311ab06c76e437f4f163474896 100644 (file)
@@ -24,6 +24,7 @@ from .abc import ACT, AsyncConnectionCB, AsyncConnectFailedCB
 from .base import ConnectionAttempt, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
+from ._acompat import ACondition, AEvent, ALock
 from .sched_async import AsyncScheduler
 
 logger = logging.getLogger("psycopg.pool")
@@ -105,14 +106,14 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         self._reconnect_failed = reconnect_failed
 
         # asyncio objects, created on open to attach them to the right loop.
-        self._lock: asyncio.Lock
+        self._lock: ALock
         self._sched: AsyncScheduler
         self._tasks: "asyncio.Queue[MaintenanceTask]"
 
         self._waiting = Deque["AsyncClient[ACT]"]()
 
         # to notify that the pool is full
-        self._pool_full_event: Optional[asyncio.Event] = None
+        self._pool_full_event: Optional[AEvent] = None
 
         self._sched_runner: Optional[Task[None]] = None
         self._workers: List[Task[None]] = []
@@ -154,16 +155,12 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             assert not self._pool_full_event
             if len(self._pool) >= self._min_size:
                 return
-            self._pool_full_event = asyncio.Event()
+            self._pool_full_event = AEvent()
 
         logger.info("waiting for pool %r initialization", self.name)
-        try:
-            await asyncio.wait_for(self._pool_full_event.wait(), timeout)
-        except asyncio.TimeoutError:
+        if not await self._pool_full_event.wait_timeout(timeout):
             await self.close()  # stop all the tasks
-            raise PoolTimeout(
-                f"pool initialization incomplete after {timeout} sec"
-            ) from None
+            raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
 
         async with self._lock:
             assert self._pool_full_event
@@ -320,7 +317,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         try:
             self._lock
         except AttributeError:
-            self._lock = asyncio.Lock()
+            self._lock = ALock()
 
         async with self._lock:
             self._open()
@@ -345,7 +342,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         try:
             self._lock
         except AttributeError:
-            self._lock = asyncio.Lock()
+            self._lock = ALock()
 
         self._closed = False
         self._opened = True
@@ -782,7 +779,7 @@ class AsyncClient(Generic[ACT]):
         # message and it hasn't timed out yet, otherwise the pool may give a
         # connection to a client that has already timed out getconn(), which
         # will be lost.
-        self._cond = asyncio.Condition()
+        self._cond = ACondition()
 
     async def wait(self, timeout: float) -> ACT:
         """Wait for a connection to be set and return it.
@@ -792,11 +789,10 @@ class AsyncClient(Generic[ACT]):
         async with self._cond:
             if not (self.conn or self.error):
                 try:
-                    await asyncio.wait_for(self._cond.wait(), timeout)
-                except asyncio.TimeoutError:
-                    self.error = PoolTimeout(
-                        f"couldn't get a connection after {timeout:.2f} sec"
-                    )
+                    if not await self._cond.wait_timeout(timeout):
+                        self.error = PoolTimeout(
+                            f"couldn't get a connection after {timeout:.2f} sec"
+                        )
                 except BaseException as ex:
                     self.error = ex
 
index 2c6f3c0e85314aa6aa74fbb9a7f213b1bb2c06ba..4a284861551f2afc56e6cf501d94796b6a01eeeb 100644 (file)
@@ -20,9 +20,8 @@ from time import monotonic
 from heapq import heappush, heappop
 from typing import Any, Callable, List, Optional
 
-from threading import RLock as Lock, Event
-
 from ._task import Task
+from ._acompat import Lock, Event
 
 logger = logging.getLogger(__name__)
 
@@ -90,4 +89,4 @@ class Scheduler:
                     )
             else:
                 # Block for the expected timeout or until a new task scheduled
-                self._event.wait(timeout=delay)
+                self._event.wait(delay)
index fe9e443ffcadab0b3a722bb038ae04b3ecc11d60..db273dd1f7e333cedbb7dc125be02b4abc7ed70a 100644 (file)
@@ -17,13 +17,8 @@ from time import monotonic
 from heapq import heappush, heappop
 from typing import Any, Callable, List, Optional
 
-if True:  # ASYNC
-    from asyncio import Event, Lock, TimeoutError, wait_for
-else:
-    from threading import RLock as Lock, Event
-
-
 from ._task import Task
+from ._acompat import ALock, AEvent
 
 logger = logging.getLogger(__name__)
 
@@ -32,8 +27,8 @@ class AsyncScheduler:
     def __init__(self) -> None:
         """Initialize a new instance, passing the time and delay functions."""
         self._queue: List[Task] = []
-        self._lock = Lock()
-        self._event = Event()
+        self._lock = ALock()
+        self._event = AEvent()
 
     EMPTY_QUEUE_TIMEOUT = 600.0
 
@@ -91,10 +86,4 @@ class AsyncScheduler:
                     )
             else:
                 # Block for the expected timeout or until a new task scheduled
-                if True:  # ASYNC
-                    try:
-                        await wait_for(self._event.wait(), delay)
-                    except TimeoutError:
-                        pass
-                else:
-                    self._event.wait(timeout=delay)
+                await self._event.wait_timeout(delay)
index 782a7fa23e5096ebb8bab6373aeb20702f448250..dd854f0148d7e23aa5ed7829052e536522f1f8a8 100755 (executable)
@@ -153,6 +153,7 @@ class AsyncToSync(ast.NodeTransformer):
 class RenameAsyncToSync(ast.NodeTransformer):
     names_map = {
         "AEvent": "Event",
+        "ALock": "Lock",
         "AsyncClientCursor": "ClientCursor",
         "AsyncConnection": "Connection",
         "AsyncConnectionPool": "ConnectionPool",