]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move non-blocking pool functionalities to a generic base class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 25 Feb 2021 21:38:33 +0000 (22:38 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
Note that running maintenance threads is "non-blocking" and will be
thread-based in the async pool.

psycopg3/psycopg3/pool/base.py
psycopg3/psycopg3/pool/pool.py
psycopg3/psycopg3/pool/tasks.py

index ad76dabc9999fbefea73f9a4a12f5de61c276e76..4d279cb8726a8ea976d0966deebb55089acef881 100644 (file)
@@ -5,6 +5,179 @@ psycopg3 connection pool base class and functionalities.
 # Copyright (C) 2021 The Psycopg Team
 
 import random
+import logging
+import threading
+from queue import Queue, Empty
+from typing import Any, Callable, Deque, Dict, Generic, List, Optional
+from collections import deque
+
+from ..proto import ConnectionType
+
+from . import tasks
+from .sched import Scheduler
+
+logger = logging.getLogger(__name__)
+
+WORKER_TIMEOUT = 60.0
+
+
+class BasePool(Generic[ConnectionType]):
+
+    # Used to generate pool names
+    _num_pool = 0
+
+    def __init__(
+        self,
+        conninfo: str = "",
+        *,
+        kwargs: Optional[Dict[str, Any]] = None,
+        configure: Optional[Callable[[ConnectionType], None]] = None,
+        minconn: int = 4,
+        maxconn: Optional[int] = None,
+        name: Optional[str] = None,
+        timeout: float = 30.0,
+        max_idle: float = 10 * 60.0,
+        reconnect_timeout: float = 5 * 60.0,
+        reconnect_failed: Optional[
+            Callable[["BasePool[ConnectionType]"], None]
+        ] = None,
+        num_workers: int = 3,
+    ):
+        if maxconn is None:
+            maxconn = minconn
+        if maxconn < minconn:
+            raise ValueError(
+                f"can't create {self.__class__.__name__}"
+                f" with maxconn={maxconn} < minconn={minconn}"
+            )
+        if not name:
+            num = BasePool._num_pool = BasePool._num_pool + 1
+            name = f"pool-{num - 1}"
+
+        if num_workers < 1:
+            raise ValueError("num_workers must be at least 1")
+
+        self.conninfo = conninfo
+        self.kwargs: Dict[str, Any] = kwargs or {}
+        self._configure: Callable[[ConnectionType], None]
+        self._configure = configure or (lambda conn: None)
+        self._reconnect_failed: Callable[["BasePool[ConnectionType]"], None]
+        self._reconnect_failed = reconnect_failed or (lambda pool: None)
+        self.name = name
+        self.minconn = minconn
+        self.maxconn = maxconn
+        self.timeout = timeout
+        self.reconnect_timeout = reconnect_timeout
+        self.max_idle = max_idle
+        self.num_workers = num_workers
+
+        self._nconns = minconn  # currently in the pool, out, being prepared
+        self._pool: Deque[ConnectionType] = deque()
+        self._sched = Scheduler()
+
+        # Min number of connections in the pool in a max_idle unit of time.
+        # It is reset periodically by the ShrinkPool scheduled task.
+        # It is used to shrink back the pool if maxcon > minconn and extra
+        # connections have been acquired, if we notice that in the last
+        # max_idle interval they weren't all used.
+        self._nconns_min = minconn
+
+        self._tasks: "Queue[tasks.MaintenanceTask[ConnectionType]]" = Queue()
+        self._workers: List[threading.Thread] = []
+        for i in range(num_workers):
+            t = threading.Thread(
+                target=self.worker, args=(self._tasks,), daemon=True
+            )
+            self._workers.append(t)
+
+        self._sched_runner = threading.Thread(
+            target=self._sched.run, daemon=True
+        )
+
+        # _close should be the last property to be set in the state
+        # to avoid warning on __del__ in case __init__ fails.
+        self._closed = False
+
+        # The object state is complete. Start the worker threads
+        self._sched_runner.start()
+        for t in self._workers:
+            t.start()
+
+        # populate the pool with initial minconn connections in background
+        for i in range(self._nconns):
+            self.run_task(tasks.AddConnection(self))
+
+        # Schedule a task to shrink the pool if connections over minconn have
+        # remained unused. However if the pool can't grow don't bother.
+        if maxconn > minconn:
+            self.schedule_task(tasks.ShrinkPool(self), self.max_idle)
+
+    def __repr__(self) -> str:
+        return (
+            f"<{self.__class__.__module__}.{self.__class__.__name__}"
+            f" {self.name!r} at 0x{id(self):x}>"
+        )
+
+    def __del__(self) -> None:
+        # If the '_closed' property is not set we probably failed in __init__.
+        # Don't try anything complicated as probably it won't work.
+        if getattr(self, "_closed", True):
+            return
+
+        # Things we can try to do on a best-effort basis while the world
+        # is crumbling (a-la Eternal Sunshine of the Spotless Mind)
+        # At worse we put an item in a queue that is being deleted.
+
+        # Stop the scheduler
+        self._sched.enter(0, None)
+
+        # Stop the worker threads
+        for i in range(len(self._workers)):
+            self.run_task(tasks.StopWorker(self))
+
+    @property
+    def closed(self) -> bool:
+        """`!True` if the pool is closed."""
+        return self._closed
+
+    def run_task(self, task: tasks.MaintenanceTask[ConnectionType]) -> None:
+        """Run a maintenance task in a worker thread."""
+        self._tasks.put_nowait(task)
+
+    def schedule_task(
+        self, task: tasks.MaintenanceTask[ConnectionType], delay: float
+    ) -> None:
+        """Run a maintenance task in a worker thread in the future."""
+        self._sched.enter(delay, task.tick)
+
+    @classmethod
+    def worker(cls, q: "Queue[tasks.MaintenanceTask[ConnectionType]]") -> None:
+        """Runner to execute pending maintenance task.
+
+        The function is designed to run as a separate thread.
+
+        Block on the queue *q*, run a task received. Finish running if a
+        StopWorker is received.
+        """
+        # Don't make all the workers time out at the same moment
+        timeout = WORKER_TIMEOUT * (0.9 + 0.1 * random.random())
+        while True:
+            # Use a timeout to make the wait unterruptable
+            try:
+                task = q.get(timeout=timeout)
+            except Empty:
+                continue
+
+            # Run the task. Make sure don't die in the attempt.
+            try:
+                task.run()
+            except Exception as e:
+                logger.warning(
+                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
+                )
+
+            if isinstance(task, tasks.StopWorker):
+                return
 
 
 class ConnectionAttempt:
index ce6282ed36929faca657b66bfaf1a6066b511fe4..dfdd7c8dfd6428ddcd906ec4d66eb16eee63db55 100644 (file)
@@ -5,11 +5,9 @@ psycopg3 synchronous connection pool
 # Copyright (C) 2021 The Psycopg Team
 
 import time
-import random
 import logging
 import threading
-from queue import Queue, Empty
-from typing import Any, Callable, Deque, Dict, Iterator, List, Optional
+from typing import Any, Deque, Iterator, Optional
 from contextlib import contextmanager
 from collections import deque
 
@@ -17,119 +15,21 @@ from ..pq import TransactionStatus
 from ..connection import Connection
 
 from . import tasks
-from .base import ConnectionAttempt
-from .sched import Scheduler
+from .base import ConnectionAttempt, BasePool
 from .errors import PoolClosed, PoolTimeout
 
 logger = logging.getLogger(__name__)
 
-WORKER_TIMEOUT = 60.0
-
-
-class ConnectionPool:
-
-    _num_pool = 0
-
-    def __init__(
-        self,
-        conninfo: str = "",
-        kwargs: Optional[Dict[str, Any]] = None,
-        configure: Optional[Callable[[Connection], None]] = None,
-        minconn: int = 4,
-        maxconn: Optional[int] = None,
-        name: Optional[str] = None,
-        timeout: float = 30.0,
-        max_idle: float = 10 * 60.0,
-        reconnect_timeout: float = 5 * 60.0,
-        reconnect_failed: Optional[Callable[["ConnectionPool"], None]] = None,
-        num_workers: int = 3,
-    ):
-        if maxconn is None:
-            maxconn = minconn
-        if maxconn < minconn:
-            raise ValueError(
-                f"can't create {self.__class__.__name__}"
-                f" with maxconn={maxconn} < minconn={minconn}"
-            )
-        if not name:
-            self.__class__._num_pool += 1
-            name = f"pool-{self._num_pool}"
-
-        if num_workers < 1:
-            # TODO: allow num_workers to be 0 - sync pool?
-            raise ValueError("num_workers must be at least 1")
-
-        self.conninfo = conninfo
-        self.kwargs: Dict[str, Any] = kwargs or {}
-        self._configure: Callable[[Connection], None]
-        self._configure = configure or (lambda conn: None)
-        self._reconnect_failed: Callable[["ConnectionPool"], None]
-        self._reconnect_failed = reconnect_failed or (lambda pool: None)
-        self.name = name
-        self.minconn = minconn
-        self.maxconn = maxconn
-        self.timeout = timeout
-        self.reconnect_timeout = reconnect_timeout
-        self.max_idle = max_idle
-        self.num_workers = num_workers
-
-        self._nconns = minconn  # currently in the pool, out, being prepared
-        self._pool: Deque[Connection] = deque()
-        self._waiting: Deque["WaitingClient"] = deque()
-        self._lock = threading.RLock()
-        self._sched = Scheduler()
 
-        # Min number of connections in the pool in a max_idle unit of time.
-        # It is reset periodically by the ShrinkPool scheduled task.
-        # It is used to shrink back the pool if maxcon > minconn and extra
-        # connections have been acquired, if we notice that in the last
-        # max_idle interval they weren't all used.
-        self._nconns_min = minconn
+class ConnectionPool(BasePool[Connection]):
+    def __init__(self, conninfo: str = "", **kwargs: Any):
+        self._lock = threading.RLock()
+        self._waiting: Deque["WaitingClient"] = deque()
 
         # to notify that the pool is full
         self._pool_full_event: Optional[threading.Event] = None
 
-        self._tasks: "Queue[tasks.MaintenanceTask]" = Queue()
-        self._workers: List[threading.Thread] = []
-        for i in range(num_workers):
-            t = threading.Thread(
-                target=self.worker, args=(self._tasks,), daemon=True
-            )
-            self._workers.append(t)
-
-        self._sched_runner = threading.Thread(
-            target=self._sched.run, daemon=True
-        )
-
-        # _close should be the last property to be set in the state
-        # to avoid warning on __del__ in case __init__ fails.
-        self._closed = False
-
-        # The object state is complete. Start the worker threads
-        self._sched_runner.start()
-        for t in self._workers:
-            t.start()
-
-        # Populate the pool with initial minconn connections in background
-        for i in range(self._nconns):
-            self.run_task(tasks.AddConnection(self))
-
-        # Schedule a task to shrink the pool if connections over minconn have
-        # remained unused. However if the pool cannot't grow don't bother.
-        if maxconn > minconn:
-            self.schedule_task(tasks.ShrinkPool(self), self.max_idle)
-
-    def __repr__(self) -> str:
-        return (
-            f"<{self.__class__.__module__}.{self.__class__.__name__}"
-            f" {self.name!r} at 0x{id(self):x}>"
-        )
-
-    def __del__(self) -> None:
-        # If the '_closed' property is not set we probably failed in __init__.
-        # Don't try anything complicated as probably it won't work.
-        if hasattr(self, "_closed"):
-            self.close(timeout=0)
+        super().__init__(conninfo, **kwargs)
 
     def wait_ready(self, timeout: float = 30.0) -> None:
         """
@@ -253,11 +153,6 @@ class ConnectionPool:
         # Use a worker to perform eventual maintenance work in a separate thread
         self.run_task(tasks.ReturnConnection(self, conn))
 
-    @property
-    def closed(self) -> bool:
-        """`!True` if the pool is closed."""
-        return self._closed
-
     def close(self, timeout: float = 1.0) -> None:
         """Close the pool and make it unavailable to new clients.
 
@@ -311,43 +206,6 @@ class ConnectionPool:
                         timeout,
                     )
 
-    def run_task(self, task: tasks.MaintenanceTask) -> None:
-        """Run a maintenance task in a worker thread."""
-        self._tasks.put(task)
-
-    def schedule_task(self, task: tasks.MaintenanceTask, delay: float) -> None:
-        """Run a maintenance task in a worker thread in the future."""
-        self._sched.enter(delay, task.tick)
-
-    @classmethod
-    def worker(cls, q: "Queue[tasks.MaintenanceTask]") -> None:
-        """Runner to execute pending maintenance task.
-
-        The function is designed to run as a separate thread.
-
-        Block on the queue *q*, run a task received. Finish running if a
-        StopWorker is received.
-        """
-        # Don't make all the workers time out at the same moment
-        timeout = WORKER_TIMEOUT * (0.9 + 0.1 * random.random())
-        while True:
-            # Use a timeout to make the wait unterruptable
-            try:
-                task = q.get(timeout=timeout)
-            except Empty:
-                continue
-
-            # Run the task. Make sure don't die in the attempt.
-            try:
-                task.run()
-            except Exception as e:
-                logger.warning(
-                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
-                )
-
-            if isinstance(task, tasks.StopWorker):
-                return
-
     def configure(self, conn: Connection) -> None:
         """Configure a connection after creation."""
         self._configure(conn)
@@ -474,7 +332,7 @@ class ConnectionPool:
             logger.warning("closing returned connection: %s", conn)
             conn.close()
 
-    def _shrink_if_possible(self) -> None:
+    def _shrink_pool(self) -> None:
         to_close: Optional[Connection] = None
 
         with self._lock:
index 5b0690183dea047b20701231c6ef0c8cf140049a..e9fac5ddc2f5f0c089037ae25d3ad9e466452641 100644 (file)
@@ -6,21 +6,22 @@ Maintenance tasks for the connection pools.
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Optional, TYPE_CHECKING
+from typing import Any, Generic, Optional, TYPE_CHECKING
 from weakref import ref
 
+from ..proto import ConnectionType
+
 if TYPE_CHECKING:
-    from .base import ConnectionAttempt
-    from .pool import ConnectionPool
+    from .base import BasePool, ConnectionAttempt
     from ..connection import Connection
 
 logger = logging.getLogger(__name__)
 
 
-class MaintenanceTask(ABC):
+class MaintenanceTask(ABC, Generic[ConnectionType]):
     """A task to run asynchronously to maintain the pool state."""
 
-    def __init__(self, pool: "ConnectionPool"):
+    def __init__(self, pool: "BasePool[Any]"):
         self.pool = ref(pool)
         logger.debug("task created: %s", self)
 
@@ -57,51 +58,66 @@ class MaintenanceTask(ABC):
         pool.run_task(self)
 
     @abstractmethod
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "BasePool[Any]") -> None:
         ...
 
 
-class StopWorker(MaintenanceTask):
+class StopWorker(MaintenanceTask[ConnectionType]):
     """Signal the maintenance thread to terminate."""
 
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "BasePool[Any]") -> None:
         pass
 
 
-class AddConnection(MaintenanceTask):
+class AddConnection(MaintenanceTask[ConnectionType]):
     def __init__(
         self,
-        pool: "ConnectionPool",
+        pool: "BasePool[Any]",
         attempt: Optional["ConnectionAttempt"] = None,
     ):
         super().__init__(pool)
         self.attempt = attempt
 
-    def _run(self, pool: "ConnectionPool") -> None:
-        pool._add_connection(self.attempt)
+    def _run(self, pool: "BasePool[Any]") -> None:
+        from . import ConnectionPool
+
+        if isinstance(pool, ConnectionPool):
+            pool._add_connection(self.attempt)
+        else:
+            assert False
 
 
-class ReturnConnection(MaintenanceTask):
+class ReturnConnection(MaintenanceTask[ConnectionType]):
     """Clean up and return a connection to the pool."""
 
-    def __init__(self, pool: "ConnectionPool", conn: "Connection"):
+    def __init__(self, pool: "BasePool[Any]", conn: "Connection"):
         super().__init__(pool)
         self.conn = conn
 
-    def _run(self, pool: "ConnectionPool") -> None:
-        pool._return_connection(self.conn)
+    def _run(self, pool: "BasePool[Any]") -> None:
+        from . import ConnectionPool
 
+        if isinstance(pool, ConnectionPool):
+            pool._return_connection(self.conn)
+        else:
+            assert False
 
-class ShrinkPool(MaintenanceTask):
+
+class ShrinkPool(MaintenanceTask[ConnectionType]):
     """If the pool can shrink, remove one connection.
 
     Re-schedule periodically and also reset the minimum number of connections
     in the pool.
     """
 
-    def _run(self, pool: "ConnectionPool") -> None:
+    def _run(self, pool: "BasePool[Any]") -> None:
         # Reschedule the task now so that in case of any error we don't lose
         # the periodic run.
         pool.schedule_task(self, pool.max_idle)
 
-        pool._shrink_if_possible()
+        from . import ConnectionPool
+
+        if isinstance(pool, ConnectionPool):
+            pool._shrink_pool()
+        else:
+            assert False