]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add ConnectionPool class sketch
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 13 Feb 2021 03:06:54 +0000 (04:06 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
It has a pool of connections alright, workers, clients waiting list.

Only minconn implemented, no maintenance thread etc. It's a start.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/pool.py [new file with mode: 0644]
tests/test_pool.py [new file with mode: 0644]

index 02baea3767e041717916e6a4b5335ca7a07d63c7..8fef6f82b685743afc7e6de8068776204931ec38 100644 (file)
@@ -46,6 +46,7 @@ execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
 if TYPE_CHECKING:
     from .pq.proto import PGconn, PGresult
+    from .pool import ConnectionPool
 
 if pq.__impl__ == "c":
     from psycopg3_c import _psycopg3
@@ -124,6 +125,8 @@ class BaseConnection(AdaptContext):
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
         pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
 
+        self._pool: Optional["ConnectionPool"] = None
+
     def __del__(self) -> None:
         # If fails on connection we might not have this attribute yet
         if not hasattr(self, "pgconn"):
@@ -462,7 +465,9 @@ class Connection(BaseConnection):
         else:
             self.commit()
 
-        self.close()
+        # Close the connection only if it doesn't belong to a pool.
+        if not self._pool:
+            self.close()
 
     def close(self) -> None:
         """Close the database connection."""
diff --git a/psycopg3/psycopg3/pool.py b/psycopg3/psycopg3/pool.py
new file mode 100644 (file)
index 0000000..3a7675a
--- /dev/null
@@ -0,0 +1,248 @@
+"""
+psycopg3 connection pool
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import random
+import logging
+import threading
+from queue import Queue, Empty
+from typing import Any, Callable, Deque, Dict, Iterator, List, Optional
+from contextlib import contextmanager
+from collections import deque
+
+from . import errors as e
+from .connection import Connection
+
+WORKER_TIMEOUT = 60.0
+
+logger = logging.getLogger(__name__)
+
+
+class PoolTimeout(e.OperationalError):
+    pass
+
+
+class ConnectionPool:
+
+    _num_pool = 0
+
+    def __init__(
+        self,
+        conninfo: str = "",
+        kwargs: Optional[Dict[str, Any]] = None,
+        configure: Optional[Callable[[Connection], None]] = None,
+        minconn: int = 4,
+        maxconn: Optional[int] = None,
+        name: Optional[str] = None,
+        timeout_sec: float = 30.0,
+        nworkers: int = 1,
+    ):
+        if maxconn is None:
+            maxconn = minconn
+        if maxconn < minconn:
+            raise ValueError(
+                f"can't create {self.__class__.__name__}"
+                f" with maxconn={maxconn} < minconn={minconn}"
+            )
+        if not name:
+            self.__class__._num_pool += 1
+            name = f"pool-{self._num_pool}"
+
+        self.conninfo = conninfo
+        self.kwargs: Dict[str, Any] = kwargs or {}
+        self._configure: Callable[[Connection], None]
+        self._configure = configure or (lambda conn: None)
+        self.name = name
+        self.minconn = minconn
+        self.maxconn = maxconn
+        self.timeout_sec = timeout_sec
+        self.nworkers = nworkers
+
+        self._nconns = 0  # currently in the pool, out, being prepared
+        self._pool: List[Connection] = []
+        self._waiting: Deque["WaitingClient"] = deque()
+        self._lock = threading.Lock()
+
+        self._wqueue: "Queue[MaintenanceTask]" = Queue()
+        self._workers: List[threading.Thread] = []
+        for i in range(nworkers):
+            t = threading.Thread(target=self.worker, args=(self._wqueue,))
+            t.daemon = True
+            t.start()
+            self._workers.append(t)
+
+        # Run a task to create the connections immediately
+        self.add_task(TopUpConnections(self))
+
+    @contextmanager
+    def connection(
+        self, timeout_sec: Optional[float] = None
+    ) -> Iterator[Connection]:
+        conn = self.getconn(timeout_sec=timeout_sec)
+        try:
+            with conn:
+                yield conn
+        finally:
+            self.putconn(conn)
+
+    def getconn(self, timeout_sec: Optional[float] = None) -> Connection:
+        # Critical section: decide here if there's a connection ready
+        # or if the client needs to wait.
+        with self._lock:
+            pos: Optional[WaitingClient] = None
+            if self._pool:
+                # Take a connection ready out of the pool
+                conn = self._pool.pop(-1)
+            else:
+                # No connection available: put the client in the waiting queue
+                pos = WaitingClient()
+                self._waiting.append(pos)
+
+        # If we are in the waiting queue, wait to be assigned a connection
+        # (outside the critical section, so only the waiting client is locked)
+        if pos:
+            if timeout_sec is None:
+                timeout_sec = self.timeout_sec
+            conn = pos.wait(timeout=timeout_sec)
+
+        # Tell the connection it belongs to a pool to avoid closing on __exit__
+        # Note that this property shouldn't be set while the connection is in
+        # the pool, to avoid to create a reference loop.
+        conn._pool = self
+        return conn
+
+    def putconn(self, conn: Connection) -> None:
+        # TODO: this should happen in a maintenance thread
+        # TODO: add check for broken connections
+
+        if conn._pool is not self:
+            if conn._pool:
+                raise ValueError(f"the connection belongs to {conn._pool}")
+            else:
+                raise ValueError("the connection doesn't belong to a pool")
+
+        # Remove the pool reference from the connection before returning it
+        # to the state, to avoid to create a reference loop.
+        conn._pool = None
+
+        # Critical section: if there is a client waiting give it the connection
+        # otherwise put it back into the pool.
+        with self._lock:
+            if self._waiting:
+                # Give the connection to the client and notify it
+                pos = self._waiting.popleft()
+                pos.set(conn)
+            else:
+                # No client waiting for a connection: put it back into the queue
+                self._pool.append(conn)
+
+    def add_task(self, task: "MaintenanceTask") -> None:
+        """Add a task to the queue of tasts to perform."""
+        self._wqueue.put(task)
+
+    @classmethod
+    def worker(cls, q: "Queue[MaintenanceTask]") -> None:
+        """Runner to execute pending maintenance task.
+
+        The function is designed to run as a separate thread.
+
+        Block on the queue *q*, run a task received. Finish running if a
+        StopWorker is received.
+        """
+        # Don't make all the workers time out at the same moment
+        timeout = WORKER_TIMEOUT * (0.9 + 0.1 * random.random())
+        while True:
+            # Use a timeout to make the wait unterruptable
+            try:
+                task = q.get(timeout=timeout)
+            except Empty:
+                continue
+
+            # Run the task. Make sure don't die in the attempt.
+            try:
+                task()
+            except Exception as e:
+                logger.warning(
+                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
+                )
+            if isinstance(task, StopWorker):
+                return
+
+    def _connect(self) -> Connection:
+        """Return a connection configured for the pool."""
+        conn = Connection.connect(self.conninfo, **self.kwargs)
+        self.configure(conn)
+        conn._pool = self
+        return conn
+
+    def configure(self, conn: Connection) -> None:
+        """Configure a connection after creation."""
+        self._configure(conn)
+
+
+class WaitingClient:
+    """An position in a queue for a client waiting for a connection."""
+
+    __slots__ = ("event", "conn")
+
+    def __init__(self) -> None:
+        self.event = threading.Event()
+        self.conn: Connection
+
+    def wait(self, timeout: float) -> Connection:
+        """Wait for the event to be set and return the connection."""
+        if not self.event.wait(timeout):
+            raise PoolTimeout(f"couldn't get a connection after {timeout} sec")
+        return self.conn
+
+    def set(self, conn: Connection) -> None:
+        self.conn = conn
+        self.event.set()
+
+
+class MaintenanceTask:
+    def __init__(self, pool: ConnectionPool):
+        self.pool = pool
+        logger.debug("task created: %s", self)
+
+    def __call__(self) -> None:
+        logger.debug("task running: %s", self)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.pool.name})"
+
+
+class StopWorker(MaintenanceTask):
+    pass
+
+
+class TopUpConnections(MaintenanceTask):
+    def __call__(self) -> None:
+        super().__call__()
+
+        with self.pool._lock:
+            # Check if there are new connections to create. If there are
+            # update the number of connections managed immediately and in
+            # the same critical section to avoid finding more than owed
+            nconns = self.pool._nconns
+            if nconns < self.pool.minconn:
+                newconns = self.pool.minconn - nconns
+                self.pool._nconns += newconns
+            else:
+                return
+
+        # enqueue connection creations command so that might be picked in
+        # parallel if possible
+        for i in range(newconns):
+            self.pool.add_task(AddConnection(self.pool))
+
+
+class AddConnection(MaintenanceTask):
+    def __call__(self) -> None:
+        super().__call__()
+
+        conn = self.pool._connect()
+        conn._pool = self.pool  # make it acceptable
+        self.pool.putconn(conn)
diff --git a/tests/test_pool.py b/tests/test_pool.py
new file mode 100644 (file)
index 0000000..b716597
--- /dev/null
@@ -0,0 +1,15 @@
+from psycopg3 import pool
+
+
+def test_pool(dsn):
+    p = pool.ConnectionPool(dsn, minconn=2, timeout_sec=1.0)
+    with p.connection() as conn:
+        with conn.execute("select pg_backend_pid()") as cur:
+            (pid1,) = cur.fetchone()
+
+        with p.connection() as conn2:
+            with conn2.execute("select pg_backend_pid()") as cur:
+                (pid2,) = cur.fetchone()
+
+    with p.connection() as conn:
+        assert conn.pgconn.backend_pid in (pid1, pid2)