]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add connection pool close()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 14 Feb 2021 01:22:20 +0000 (02:22 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
When the pool is closed, raise an exception in the thread of the clients
already waiting and refuse new requests. Let the current request finish
anyway.

psycopg3/psycopg3/pool.py
tests/test_pool.py

index 769b008ddba6342c80352b6e9f473e460c87fd9f..5755145ddb11792818996797261073b554a97b5f 100644 (file)
@@ -25,6 +25,10 @@ class PoolTimeout(e.OperationalError):
     pass
 
 
+class PoolClosed(e.OperationalError):
+    pass
+
+
 class ConnectionPool:
 
     _num_pool = 0
@@ -65,6 +69,7 @@ class ConnectionPool:
         self._pool: List[Connection] = []
         self._waiting: Deque["WaitingClient"] = deque()
         self._lock = threading.Lock()
+        self._closed = False
 
         self._wqueue: "Queue[MaintenanceTask]" = Queue()
         self._workers: List[threading.Thread] = []
@@ -98,6 +103,9 @@ class ConnectionPool:
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         with self._lock:
+            if self._closed:
+                raise PoolClosed(f"the pool {self.name!r} is closed")
+
             pos: Optional[WaitingClient] = None
             if self._pool:
                 # Take a connection ready out of the pool
@@ -121,6 +129,7 @@ class ConnectionPool:
         return conn
 
     def putconn(self, conn: Connection) -> None:
+        # Quick check to discard the wrong connection
         pool = getattr(conn, "_pool", None)
         if pool is not self:
             if pool:
@@ -131,6 +140,13 @@ class ConnectionPool:
                 f"can't return connection to pool {self.name!r}, {msg}: {conn}"
             )
 
+        # If the pool is closed just close the connection instead of returning
+        # it to the poo. For extra refcare remove the pool reference from it.
+        if self._closed:
+            conn._pool = None
+            conn.close()
+            return
+
         # Use a worker to perform eventual maintenance work in a separate thread
         self.add_task(ReturnConnection(self, conn))
 
@@ -189,6 +205,32 @@ class ConnectionPool:
             logger.warning("closing returned connection: %s", conn)
             conn.close()
 
+    @property
+    def closed(self) -> bool:
+        return self._closed
+
+    def close(self) -> None:
+        """Close the pool connections and make it unavailable to new clients."""
+        with self._lock:
+            self._closed = True
+
+        # Now that the flag _closed is set, getconn will fail immediately,
+        # putconn will just close the returned connection.
+
+        # Signal to eventual clients in the queue that business is closed.
+        while self._waiting:
+            pos = self._waiting.popleft()
+            pos.fail(PoolClosed(f"the pool {self.name!r} is closed"))
+
+        # Close the connections still in the pool
+        while self._pool:
+            conn = self._pool.pop(-1)
+            conn.close()
+
+        # Stop the worker threads
+        for i in range(len(self._workers)):
+            self.add_task(StopWorker(self))
+
     def add_task(self, task: "MaintenanceTask") -> None:
         """Add a task to the queue of tasts to perform."""
         self._wqueue.put(task)
@@ -240,22 +282,31 @@ class ConnectionPool:
 class WaitingClient:
     """An position in a queue for a client waiting for a connection."""
 
-    __slots__ = ("event", "conn")
+    __slots__ = ("event", "conn", "error")
 
     def __init__(self) -> None:
         self.event = threading.Event()
         self.conn: Connection
+        self.error: Optional[Exception] = None
 
     def wait(self, timeout: float) -> Connection:
         """Wait for the event to be set and return the connection."""
         if not self.event.wait(timeout):
             raise PoolTimeout(f"couldn't get a connection after {timeout} sec")
+        if self.error:
+            raise self.error
         return self.conn
 
     def set(self, conn: Connection) -> None:
+        """Signal the client waiting that a connection is ready."""
         self.conn = conn
         self.event.set()
 
+    def fail(self, error: Exception) -> None:
+        """Signal the client waiting that, alas, they won't have a connection."""
+        self.error = error
+        self.event.set()
+
 
 class MaintenanceTask:
     def __init__(self, pool: ConnectionPool):
index c4ec9e388487c537fee3c18d377098d1caad0636..a7feed78ec6e0a3bfd3e28e88afc8e437209c42e 100644 (file)
@@ -281,3 +281,58 @@ def test_del_no_warning(dsn, recwarn):
     del p
     assert not ref()
     assert not recwarn
+
+
+def test_closed_getconn(dsn):
+    p = pool.ConnectionPool(dsn, minconn=1)
+    assert not p.closed
+    with p.connection():
+        pass
+
+    p.close()
+    assert p.closed
+
+    with pytest.raises(pool.PoolClosed):
+        with p.connection():
+            pass
+
+
+def test_closed_putconn(dsn):
+    p = pool.ConnectionPool(dsn, minconn=1)
+
+    with p.connection() as conn:
+        pass
+    assert not conn.closed
+
+    with p.connection() as conn:
+        p.close()
+    assert conn.closed
+
+
+@pytest.mark.slow
+def test_closed_queue(dsn):
+    p = pool.ConnectionPool(dsn, minconn=1)
+    success = []
+
+    def w1():
+        with p.connection() as conn:
+            assert (
+                conn.execute("select 1 from pg_sleep(0.2)").fetchone()[0] == 1
+            )
+        success.append("w1")
+
+    def w2():
+        with pytest.raises(pool.PoolClosed):
+            with p.connection():
+                pass
+        success.append("w2")
+
+    t1 = Thread(target=w1)
+    t2 = Thread(target=w2)
+    t1.start()
+    sleep(0.1)
+    t2.start()
+    p.close()
+    t1.join()
+    t2.join()
+    assert len(success) == 2