git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): generate pool sync module from async counterpart
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Oct 2023 01:40:28 +0000 (03:40 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:39 +0000 (23:45 +0200)
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

index 2185a71fb1f4734421aaddeb002a46f5ccab5d53..f96eab00e86ef9bee871cec919ea67c020884c16 100644 (file)
@@ -1,5 +1,8 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'pool_async.py'
+# DO NOT CHANGE! Change the original file instead.
 """
-psycopg synchronous connection pool
+Psycopg connection pool module.
 """
 
 # Copyright (C) 2021 The Psycopg Team
@@ -10,7 +13,7 @@ import logging
 from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
-from typing import Any, cast, Dict, Generic, Iterator, List
+from typing import Any, Iterator, cast, Dict, Generic, List
 from typing import Optional, overload, Sequence, Type, TypeVar
 from weakref import ref
 from contextlib import contextmanager
@@ -22,11 +25,12 @@ from psycopg.rows import TupleRow
 
 from .abc import CT, ConnectionCB, ConnectFailedCB
 from .base import ConnectionAttempt, BasePool
-from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Deque
 from ._acompat import Condition, Event, Lock, Queue, Worker, spawn, gather
 from ._acompat import current_thread_name
+from .sched import Scheduler
+
 
 logger = logging.getLogger("psycopg.pool")
 
@@ -85,7 +89,7 @@ class ConnectionPool(Generic[CT], BasePool):
         conninfo: str = "",
         *,
         open: bool = True,
-        connection_class: Type[CT] = cast(Type[CT], Connection[TupleRow]),
+        connection_class: Type[CT] = cast(Type[CT], Connection),
         configure: Optional[ConnectionCB[CT]] = None,
         reset: Optional[ConnectionCB[CT]] = None,
         kwargs: Optional[Dict[str, Any]] = None,
@@ -106,6 +110,8 @@ class ConnectionPool(Generic[CT], BasePool):
 
         self._reconnect_failed = reconnect_failed
 
+        # If these are asyncio objects, make sure to create them on open
+        # to attach them to the right loop.
         self._lock: Lock
         self._sched: Scheduler
         self._tasks: Queue[MaintenanceTask]
@@ -134,7 +140,7 @@ class ConnectionPool(Generic[CT], BasePool):
         )
 
         if open:
-            self.open()
+            self._open()
 
     def __del__(self) -> None:
         # If the '_closed' property is not set we probably failed in __init__.
@@ -167,10 +173,8 @@ class ConnectionPool(Generic[CT], BasePool):
 
         logger.info("waiting for pool %r initialization", self.name)
         if not self._pool_full_event.wait(timeout):
-            self.close()  # stop all the threads
-            raise PoolTimeout(
-                f"pool initialization incomplete after {timeout} sec"
-            ) from None
+            self.close()  # stop all the tasks
+            raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
 
         with self._lock:
             assert self._pool_full_event
@@ -263,12 +267,12 @@ class ConnectionPool(Generic[CT], BasePool):
             self._stats[self._REQUESTS_ERRORS] += 1
             raise TooManyRequests(
                 f"the pool {self.name!r} has already"
-                f" {len(self._waiting)} requests waiting"
+                f" {len(self._waiting)} requests waiting"
             )
         return conn
 
     def _maybe_grow_pool(self) -> None:
-        # Allow only one thread at time to grow the pool (or returning
+        # Allow only one task at time to grow the pool (or returning
         # connections might be starved).
         if self._nconns >= self._max_size or self._growing:
             return
@@ -290,7 +294,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if self._maybe_close_connection(conn):
             return
 
-        # Use a worker to perform eventual maintenance work in a separate thread
+        # Use a worker to perform eventual maintenance work in a separate task
         if self._reset:
             self.run_task(ReturnConnection(self, conn))
         else:
@@ -323,6 +327,7 @@ class ConnectionPool(Generic[CT], BasePool):
         because the pool was initialized with *open* = `!True`) but you cannot
         currently re-open a closed pool.
         """
+        # Make sure the lock is created after there is an event loop
         self._ensure_lock()
 
         with self._lock:
@@ -340,6 +345,8 @@ class ConnectionPool(Generic[CT], BasePool):
         # A lock has been most likely, but not necessarily, created in `open()`.
         self._ensure_lock()
 
+        # Create these objects now to attach them to the right loop.
+        # See #219
         self._tasks = Queue()
         self._sched = Scheduler()
 
@@ -350,6 +357,11 @@ class ConnectionPool(Generic[CT], BasePool):
         self._start_initial_tasks()
 
     def _ensure_lock(self) -> None:
+        """Make sure the pool lock is created.
+
+        In async code, also make sure that the loop is running.
+        """
+
         try:
             self._lock
         except AttributeError:
@@ -369,7 +381,7 @@ class ConnectionPool(Generic[CT], BasePool):
 
         # Schedule a task to shrink the pool if connections over min_size have
         # remained unused.
-        self.schedule_task(ShrinkPool(self), self.max_idle)
+        self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
 
     def close(self, timeout: float = 5.0) -> None:
         """Close the pool and make it unavailable to new clients.
@@ -408,9 +420,9 @@ class ConnectionPool(Generic[CT], BasePool):
         # Stop the scheduler
         self._sched.enter(0, None)
 
-        # Stop the worker threads
-        workers, self._workers = self._workers[:], []
-        for i in range(len(workers)):
+        # Stop the worker tasks
+        (workers, self._workers) = (self._workers[:], [])
+        for _ in workers:
             self.run_task(StopWorker(self))
 
         # Signal to eventual clients in the queue that business is closed.
@@ -421,9 +433,9 @@ class ConnectionPool(Generic[CT], BasePool):
         for conn in connections:
             conn.close()
 
-        # Wait for the worker threads to terminate
+        # Wait for the worker tasks to terminate
         assert self._sched_runner is not None
-        sched_runner, self._sched_runner = self._sched_runner, None
+        (sched_runner, self._sched_runner) = (self._sched_runner, None)
         gather(sched_runner, *workers, timeout=timeout)
 
     def __enter__(self: _Self) -> _Self:
@@ -440,15 +452,12 @@ class ConnectionPool(Generic[CT], BasePool):
 
     def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
         """Change the size of the pool during runtime."""
-        min_size, max_size = self._check_size(min_size, max_size)
+        (min_size, max_size) = self._check_size(min_size, max_size)
 
         ngrow = max(0, min_size - self._min_size)
 
         logger.info(
-            "resizing %r to min_size=%s max_size=%s",
-            self.name,
-            min_size,
-            max_size,
+            "resizing %r to min_size=%s max_size=%s", self.name, min_size, max_size
         )
         with self._lock:
             self._min_size = min_size
@@ -505,18 +514,18 @@ class ConnectionPool(Generic[CT], BasePool):
         self._reconnect_failed(self)
 
     def run_task(self, task: MaintenanceTask) -> None:
-        """Run a maintenance task in a worker thread."""
+        """Run a maintenance task in a worker."""
         self._tasks.put_nowait(task)
 
     def schedule_task(self, task: MaintenanceTask, delay: float) -> None:
-        """Run a maintenance task in a worker thread in the future."""
+        """Run a maintenance task in a worker in the future."""
         self._sched.enter(delay, task.tick)
 
     @classmethod
     def worker(cls, q: Queue[MaintenanceTask]) -> None:
         """Runner to execute pending maintenance task.
 
-        The function is designed to run as a separate thread.
+        The function is designed to run as a task.
 
         Block on the queue *q*, run a task received. Finish running if a
         StopWorker is received.
@@ -525,7 +534,7 @@ class ConnectionPool(Generic[CT], BasePool):
             task = q.get()
 
             if isinstance(task, StopWorker):
-                logger.debug("terminating working thread %s", current_thread_name())
+                logger.debug("terminating working task %s", current_thread_name())
                 return
 
             # Run the task. Make sure don't die in the attempt.
@@ -533,10 +542,7 @@ class ConnectionPool(Generic[CT], BasePool):
                 task.run()
             except Exception as ex:
                 logger.warning(
-                    "task run %s failed: %s: %s",
-                    task,
-                    ex.__class__.__name__,
-                    ex,
+                    "task run %s failed: %s: %s", task, ex.__class__.__name__, ex
                 )
 
     def _connect(self, timeout: Optional[float] = None) -> CT:
@@ -565,7 +571,7 @@ class ConnectionPool(Generic[CT], BasePool):
                 sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
                     f"connection left in status {sname} by configure function"
-                    f" {self._configure}: discarded"
+                    f" {self._configure}: discarded"
                 )
 
         # Set an expiry date, with some randomness to avoid mass reconnection
@@ -605,8 +611,7 @@ class ConnectionPool(Generic[CT], BasePool):
             else:
                 attempt.update_delay(now)
                 self.schedule_task(
-                    AddConnection(self, attempt, growing=growing),
-                    attempt.delay,
+                    AddConnection(self, attempt, growing=growing), attempt.delay
                 )
             return
 
@@ -672,7 +677,6 @@ class ConnectionPool(Generic[CT], BasePool):
             else:
                 # No client waiting for a connection: put it back into the pool
                 self._pool.append(conn)
-
                 # If we have been asked to wait for pool init, notify the
                 # waiter if the pool is full.
                 if self._pool_full_event and len(self._pool) >= self._min_size:
@@ -685,7 +689,6 @@ class ConnectionPool(Generic[CT], BasePool):
         status = conn.pgconn.transaction_status
         if status == TransactionStatus.IDLE:
             pass
-
         elif status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
             # Connection returned with an active transaction
             logger.warning("rolling back returned connection: %s", conn)
@@ -699,7 +702,6 @@ class ConnectionPool(Generic[CT], BasePool):
                     conn,
                 )
                 conn.close()
-
         elif status == TransactionStatus.ACTIVE:
             # Connection returned during an operation. Bad... just close it.
             logger.warning("closing returned connection: %s", conn)
@@ -713,7 +715,7 @@ class ConnectionPool(Generic[CT], BasePool):
                     sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
                         f"connection left in status {sname} by reset function"
-                        f" {self._reset}: discarded"
+                        f" {self._reset}: discarded"
                     )
             except Exception as ex:
                 logger.warning(f"error resetting connection: {ex}")
@@ -736,7 +738,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if to_close:
             logger.info(
                 "shrinking pool %r to %s because %s unused connections"
-                " in the last %s sec",
+                " in the last %s sec",
                 self.name,
                 self._nconns,
                 nconns_min,
@@ -830,7 +832,7 @@ class MaintenanceTask(ABC):
     def run(self) -> None:
         """Run the task.
 
-        This usually happens in a worker thread. Call the concrete _run()
+        This usually happens in a worker. Call the concrete _run()
         implementation, if the pool is still alive.
         """
         pool = self.pool()
@@ -845,7 +847,7 @@ class MaintenanceTask(ABC):
     def tick(self) -> None:
         """Run the scheduled task
 
-        This function is called by the scheduler thread. Use a worker to
+        This function is called by the scheduler task. Use a worker to
         run the task for real in order to free the scheduler immediately.
         """
         pool = self.pool()
@@ -862,7 +864,7 @@ class MaintenanceTask(ABC):
 
 
 class StopWorker(MaintenanceTask):
-    """Signal the maintenance thread to terminate."""
+    """Signal the maintenance worker to terminate."""
 
     def _run(self, pool: ConnectionPool[Any]) -> None:
         pass
@@ -906,3 +908,20 @@ class ShrinkPool(MaintenanceTask):
         # the periodic run.
         pool.schedule_task(self, pool.max_idle)
         pool._shrink_pool()
+
+
+class Schedule(MaintenanceTask):
+    """Schedule a task in the pool scheduler.
+
+    This task is a trampoline to allow to use a sync call (pool.run_task)
+    to execute an async one (pool.schedule_task). It is pretty much no-op
+    in sync code.
+    """
+
+    def __init__(self, pool: ConnectionPool[Any], task: MaintenanceTask, delay: float):
+        super().__init__(pool)
+        self.task = task
+        self.delay = delay
+
+    def _run(self, pool: ConnectionPool[Any]) -> None:
+        pool.schedule_task(self.task, self.delay)
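
The Schedule trampoline added above is worth a note: in the async pool, schedule_task() is a coroutine, so a plain (non-async) call site such as _start_initial_tasks() cannot await it directly; wrapping the scheduling in a maintenance task and enqueuing it through the synchronous run_task() keeps the call site identical in both sources, which is what lets the sync module be generated mechanically. A minimal sketch of the idea, using simplified stand-in classes rather than the real psycopg_pool API:

# Minimal sketch of the trampoline behind the Schedule task
# (stand-in classes, not the real psycopg_pool implementation).
import asyncio

class TinyPool:
    def __init__(self) -> None:
        self._tasks: asyncio.Queue = asyncio.Queue()

    def run_task(self, task) -> None:
        # Plain synchronous call in both pool flavours: just enqueue.
        self._tasks.put_nowait(task)

    async def schedule_task(self, task, delay: float) -> None:
        # Coroutine in the async pool: only a worker can await it.
        print(f"scheduling {task!r} to run in {delay}s")

class Schedule:
    """Trampoline: a worker awaits schedule_task() on behalf of sync code."""
    def __init__(self, pool, task, delay):
        self.pool, self.task, self.delay = pool, task, delay

    async def run(self) -> None:
        await self.pool.schedule_task(self.task, self.delay)

async def worker(pool: TinyPool) -> None:
    task = await pool._tasks.get()
    await task.run()

async def main() -> None:
    pool = TinyPool()
    # The call site stays synchronous, like in _start_initial_tasks():
    pool.run_task(Schedule(pool, "ShrinkPool", delay=2.0))
    await worker(pool)

asyncio.run(main())
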
index c283196fcc9b7ee973121310adc020aa90653d60..aa34df36bb7568f97e270b1c9b78961d0ffb8fc9 100644 (file)
@@ -1,18 +1,17 @@
 """
-psycopg asynchronous connection pool
+Psycopg connection pool module.
 """
 
 # Copyright (C) 2021 The Psycopg Team
 
 from __future__ import annotations
 
-import asyncio
 import logging
 from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
-from typing import Any, AsyncIterator, cast, Generic
-from typing import Dict, List, Optional, overload, Sequence, Type, TypeVar
+from typing import Any, AsyncIterator, cast, Dict, Generic, List
+from typing import Optional, overload, Sequence, Type, TypeVar
 from weakref import ref
 from contextlib import asynccontextmanager
 
@@ -29,6 +28,9 @@ from ._acompat import ACondition, AEvent, ALock, AQueue, AWorker, aspawn, agathe
 from ._acompat import current_task_name
 from .sched_async import AsyncScheduler
 
+if True:  # ASYNC
+    import asyncio
+
 logger = logging.getLogger("psycopg.pool")
 
 
@@ -107,12 +109,13 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         self._reconnect_failed = reconnect_failed
 
-        # asyncio objects, created on open to attach them to the right loop.
+        # If these are asyncio objects, make sure to create them on open
+        # to attach them to the right loop.
         self._lock: ALock
         self._sched: AsyncScheduler
         self._tasks: AQueue[MaintenanceTask]
 
-        self._waiting = Deque[AsyncClient[ACT]]()
+        self._waiting = Deque[WaitingClient[ACT]]()
 
         # to notify that the pool is full
         self._pool_full_event: Optional[AEvent] = None
@@ -138,6 +141,16 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if open:
             self._open()
 
+    if False:  # ASYNC
+
+        def __del__(self) -> None:
+            # If the '_closed' property is not set we probably failed in __init__.
+            # Don't try anything complicated as probably it won't work.
+            if getattr(self, "_closed", True):
+                return
+
+            self._stop_workers()
+
     async def wait(self, timeout: float = 30.0) -> None:
         """
         Wait for the pool to be full (with `min_size` connections) after creation.
@@ -215,7 +228,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             if not conn:
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
-                pos: AsyncClient[ACT] = AsyncClient()
+                pos: WaitingClient[ACT] = WaitingClient()
                 self._waiting.append(pos)
                 self._stats[self._REQUESTS_QUEUED] += 1
 
@@ -255,7 +268,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             self._stats[self._REQUESTS_ERRORS] += 1
             raise TooManyRequests(
                 f"the pool {self.name!r} has already"
-                f" {len(self._waiting)} requests waiting"
+                f" {len(self._waiting)} requests waiting"
             )
         return conn
 
@@ -349,9 +362,9 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         In async code, also make sure that the loop is running.
         """
-
-        # Throw a RuntimeError if the pool is open outside a running loop.
-        asyncio.get_running_loop()
+        if True:  # ASYNC
+            # Throw a RuntimeError if the pool is open outside a running loop.
+            asyncio.get_running_loop()
 
         try:
             self._lock
@@ -360,6 +373,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
     def _start_workers(self) -> None:
         self._sched_runner = aspawn(self._sched.run, name=f"{self.name}-scheduler")
+        assert not self._workers
         for i in range(self.num_workers):
             t = aspawn(self.worker, args=(self._tasks,), name=f"{self.name}-worker-{i}")
             self._workers.append(t)
@@ -403,7 +417,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
     async def _stop_workers(
         self,
-        waiting_clients: Sequence[AsyncClient[ACT]] = (),
+        waiting_clients: Sequence[WaitingClient[ACT]] = (),
         connections: Sequence[ACT] = (),
         timeout: float | None = None,
     ) -> None:
@@ -412,7 +426,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         # Stop the worker tasks
         workers, self._workers = self._workers[:], []
-        for w in workers:
+        for _ in workers:
             self.run_task(StopWorker(self))
 
         # Signal to eventual clients in the queue that business is closed.
@@ -447,10 +461,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         ngrow = max(0, min_size - self._min_size)
 
         logger.info(
-            "resizing %r to min_size=%s max_size=%s",
-            self.name,
-            min_size,
-            max_size,
+            "resizing %r to min_size=%s max_size=%s", self.name, min_size, max_size
         )
         async with self._lock:
             self._min_size = min_size
@@ -504,8 +515,11 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if not self._reconnect_failed:
             return
 
-        if asyncio.iscoroutinefunction(self._reconnect_failed):
-            await self._reconnect_failed(self)
+        if True:  # ASYNC
+            if asyncio.iscoroutinefunction(self._reconnect_failed):
+                await self._reconnect_failed(self)
+            else:
+                self._reconnect_failed(self)
         else:
             self._reconnect_failed(self)
 
@@ -538,10 +552,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                 await task.run()
             except Exception as ex:
                 logger.warning(
-                    "task run %s failed: %s: %s",
-                    task,
-                    ex.__class__.__name__,
-                    ex,
+                    "task run %s failed: %s: %s", task, ex.__class__.__name__, ex
                 )
 
     async def _connect(self, timeout: Optional[float] = None) -> ACT:
@@ -572,7 +583,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                 sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
                     f"connection left in status {sname} by configure function"
-                    f" {self._configure}: discarded"
+                    f" {self._configure}: discarded"
                 )
 
         # Set an expiry date, with some randomness to avoid mass reconnection
@@ -612,8 +623,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             else:
                 attempt.update_delay(now)
                 await self.schedule_task(
-                    AddConnection(self, attempt, growing=growing),
-                    attempt.delay,
+                    AddConnection(self, attempt, growing=growing), attempt.delay
                 )
             return
 
@@ -720,7 +730,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
                     sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
                         f"connection left in status {sname} by reset function"
-                        f" {self._reset}: discarded"
+                        f" {self._reset}: discarded"
                     )
             except Exception as ex:
                 logger.warning(f"error resetting connection: {ex}")
@@ -743,7 +753,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if to_close:
             logger.info(
                 "shrinking pool %r to %s because %s unused connections"
-                " in the last %s sec",
+                " in the last %s sec",
                 self.name,
                 self._nconns,
                 nconns_min,
@@ -757,7 +767,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         return rv
 
 
-class AsyncClient(Generic[ACT]):
+class WaitingClient(Generic[ACT]):
     """A position in a queue for a client waiting for a connection."""
 
     __slots__ = ("conn", "error", "_cond")
@@ -766,7 +776,7 @@ class AsyncClient(Generic[ACT]):
         self.conn: Optional[ACT] = None
         self.error: Optional[BaseException] = None
 
-        # The AsyncClient behaves in a way similar to an Event, but we need
+        # The WaitingClient behaves in a way similar to an Event, but we need
         # to notify reliably the flagger that the waiter has "accepted" the
         # message and it hasn't timed out yet, otherwise the pool may give a
         # connection to a client that has already timed out getconn(), which
@@ -919,7 +929,8 @@ class Schedule(MaintenanceTask):
     """Schedule a task in the pool scheduler.
 
     This task is a trampoline to allow to use a sync call (pool.run_task)
-    to execute an async one (pool.schedule_task).
+    to execute an async one (pool.schedule_task). It is pretty much no-op
+    in sync code.
     """
 
     def __init__(
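
A note on the `if True:  # ASYNC` / `if False:  # ASYNC` guards introduced above: they mark blocks for the generator, which presumably keeps or drops the guarded code when emitting pool.py, so the asyncio-only pieces (the import, the get_running_loop() check, the coroutine test in _reconnect_failed()) disappear from the sync output while the sync-only __del__ is emitted. A rough sketch of how such a transform could look; this illustrates the pattern rather than the actual tool, and note that plain `ast` discards comments, so the real script needs a comment-aware parse to recognize the ASYNC marker:

import ast

class DropAsyncOnlyBlocks(ast.NodeTransformer):
    """Resolve boolean-guarded blocks when generating the sync variant.

    Simplifying assumption: every `if` with a bare bool constant test is an
    ASYNC marker (the real tool must check for the '# ASYNC' comment, which
    requires a comment-preserving parser).
    """

    def visit_If(self, node: ast.If):
        self.generic_visit(node)
        if isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool):
            # In the async source: True = async-only, False = sync-only.
            kept = node.orelse if node.test.value else node.body
            return kept if kept else ast.Pass()  # keep suites non-empty
        return node

src = """
if True:  # ASYNC
    import asyncio
if False:  # ASYNC
    def __del__(self):
        self._stop_workers()
"""
tree = DropAsyncOnlyBlocks().visit(ast.parse(src))
print(ast.unparse(ast.fix_missing_locations(tree)))
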
index dd854f0148d7e23aa5ed7829052e536522f1f8a8..f0bb56bab1db7f819abb5fd24fbf0831c4212fbe 100755 (executable)
@@ -152,10 +152,16 @@ class AsyncToSync(ast.NodeTransformer):
 
 class RenameAsyncToSync(ast.NodeTransformer):
     names_map = {
+        "ACT": "CT",
+        "ACondition": "Condition",
         "AEvent": "Event",
         "ALock": "Lock",
+        "AQueue": "Queue",
+        "AWorker": "Worker",
         "AsyncClientCursor": "ClientCursor",
+        "AsyncConnectFailedCB": "ConnectFailedCB",
         "AsyncConnection": "Connection",
+        "AsyncConnectionCB": "ConnectionCB",
         "AsyncConnectionPool": "ConnectionPool",
         "AsyncCopy": "Copy",
         "AsyncCopyWriter": "CopyWriter",
@@ -181,17 +187,21 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "acommands": "commands",
         "aconn": "conn",
         "aconn_cls": "conn_cls",
+        "agather": "gather",
         "alist": "list",
         "anext": "next",
         "apipeline": "pipeline",
         "asleep": "sleep",
+        "aspawn": "spawn",
         "asynccontextmanager": "contextmanager",
         "connection_async": "connection",
+        "current_task_name": "current_thread_name",
         "cursor_async": "cursor",
         "ensure_table_async": "ensure_table",
         "find_insert_problem_async": "find_insert_problem",
         "psycopg_pool.pool_async": "psycopg_pool.pool",
         "psycopg_pool.sched_async": "psycopg_pool.sched",
+        "sched_async": "sched",
         "test_pool_common_async": "test_pool_common",
         "wait_async": "wait",
         "wait_conn_async": "wait_conn",
index 2017d21cafcf7b60fe1c4e450c555a72ac2963b3..e7976d3d2ec29f33f6170a201322a3e1e44af1d3 100755 (executable)
@@ -20,6 +20,7 @@ outputs=""
 for async in \
     psycopg/psycopg/connection_async.py \
     psycopg/psycopg/cursor_async.py \
+    psycopg_pool/psycopg_pool/pool_async.py \
     psycopg_pool/psycopg_pool/sched_async.py \
     tests/pool/test_pool_async.py \
     tests/pool/test_pool_common_async.py \
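
Taken together, the pieces above sketch the generation flow that convert_async_to_sync.sh now drives for pool_async.py as well: parse the async source, run the transformer passes, unparse, and prepend the DO-NOT-EDIT header seen at the top of the generated pool.py. A hedged end-to-end sketch; the function below is illustrative only, and the real async_to_sync.py applies its own passes (the RenameAsyncToSync map above among them) plus the async/await stripping not shown here:

import ast
from pathlib import Path

HEADER = (
    "# WARNING: this file is auto-generated by 'async_to_sync.py'\n"
    "# from the original file '{source}'\n"
    "# DO NOT CHANGE! Change the original file instead.\n"
)

def convert(source: Path, dest: Path, passes) -> None:
    """Illustrative driver: parse, transform, unparse, prepend the header."""
    tree = ast.parse(source.read_text())
    for transformer in passes:
        tree = transformer.visit(tree)
    ast.fix_missing_locations(tree)
    dest.write_text(HEADER.format(source=source.name) + ast.unparse(tree) + "\n")

# Hypothetical usage mirroring one entry of the script's file list:
# convert(Path("psycopg_pool/psycopg_pool/pool_async.py"),
#         Path("psycopg_pool/psycopg_pool/pool.py"),
#         passes=[DropAsyncOnlyBlocks(), RenameIdentifiers()])
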