]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't lose pool connections giving them to a clients already timed out
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 22 Feb 2021 01:05:14 +0000 (02:05 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool.py
tests/test_pool.py

index df333e650c4f6590d2476250b494f101e8e77fa3..8eca7a203e54eacbd94ef4a0c07ad935ff3a2ae3 100644 (file)
@@ -142,8 +142,8 @@ class ConnectionPool:
         """Context manager to obtain a connection from the pool.
 
         Returned the connection immediately if available, otherwise wait up to
-        *timeout* or `self.timeout` and throw `PoolTimeout` if a
-        connection is available in time.
+        *timeout* or `self.timeout` and throw `PoolTimeout` if a connection is
+        not available in time.
 
         Upon context exit, return the connection to the pool. Apply the normal
         connection context behaviour (commit/rollback the transaction in case
@@ -256,9 +256,13 @@ class ConnectionPool:
         # Critical section: if there is a client waiting give it the connection
         # otherwise put it back into the pool.
         with self._lock:
-            if self._waiting:
-                # Extract the first client from the queue
+            while self._waiting:
+                # If there is a client waiting (which is still waiting and
+                # hasn't timed out), give it the connection and notify it.
                 pos = self._waiting.popleft()
+                if pos.set(conn):
+                    break
+
             else:
                 now = time.monotonic()
 
@@ -278,10 +282,6 @@ class ConnectionPool:
                     )
                     self._nconns -= 1
 
-        # If we found a client in queue, give it the connection and notify it
-        if pos:
-            pos.set(conn)
-
         if to_close:
             to_close.close()
 
@@ -420,30 +420,64 @@ class ConnectionPool:
 class WaitingClient:
     """An position in a queue for a client waiting for a connection."""
 
-    __slots__ = ("event", "conn", "error")
+    __slots__ = ("conn", "error", "_cond")
 
     def __init__(self) -> None:
-        self.event = threading.Event()
-        self.conn: Connection
+        self.conn: Optional[Connection] = None
         self.error: Optional[Exception] = None
 
+        # The WaitingClient behaves in a way similar to an Event, but we need
+        # to notify reliably the flagger that the waiter has "accepted" the
+        # message and it hasn't timed out yet, otherwise the pool may give a
+        # connection to a client that has already timed out getconn(), which
+        # will be lost.
+        self._cond = threading.Condition(threading.Lock())
+
     def wait(self, timeout: float) -> Connection:
-        """Wait for the event to be set and return the connection."""
-        if not self.event.wait(timeout):
-            raise PoolTimeout(f"couldn't get a connection after {timeout} sec")
-        if self.error:
+        """Wait for a connection to be set and return it.
+
+        Raise an exception if the wait times out or if fail() is called.
+        """
+        with self._cond:
+            if not (self.conn or self.error):
+                if not self._cond.wait(timeout):
+                    self.error = PoolTimeout(
+                        f"couldn't get a connection after {timeout} sec"
+                    )
+
+        if self.conn:
+            return self.conn
+        else:
+            assert self.error
             raise self.error
-        return self.conn
 
-    def set(self, conn: Connection) -> None:
-        """Signal the client waiting that a connection is ready."""
-        self.conn = conn
-        self.event.set()
+    def set(self, conn: Connection) -> bool:
+        """Signal the client waiting that a connection is ready.
+
+        Return True if the client has "accepted" the connection, False
+        otherwise (typically because wait() has timed out.
+        """
+        with self._cond:
+            if self.conn or self.error:
+                return False
+
+            self.conn = conn
+            self._cond.notify_all()
+            return True
+
+    def fail(self, error: Exception) -> bool:
+        """Signal the client that, alas, they won't have a connection today.
+
+        Return True if the client has "accepted" the error, False otherwise
+        (typically because wait() has timed out.
+        """
+        with self._cond:
+            if self.conn or self.error:
+                return False
 
-    def fail(self, error: Exception) -> None:
-        """Signal the client that, alas, they won't have a connection today."""
-        self.error = error
-        self.event.set()
+            self.error = error
+            self._cond.notify_all()
+            return True
 
 
 class MaintenanceTask(ABC):
index b07bc9c5a94353d17fcfbd51eb2076d05b820eac..f80f0a0c82a99b7f35eb7e674d593694bedeea1b 100644 (file)
@@ -168,6 +168,35 @@ def test_queue_timeout(dsn):
         assert 0.1 < e[1] < 0.15
 
 
+@pytest.mark.slow
+def test_dead_client(dsn):
+    p = pool.ConnectionPool(dsn, minconn=2)
+
+    results = []
+
+    def worker(i, timeout):
+        try:
+            with p.connection(timeout=timeout) as conn:
+                conn.execute("select pg_sleep(0.3)")
+                results.append(i)
+        except pool.PoolTimeout:
+            if timeout > 0.2:
+                raise
+
+    ts = []
+    for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]):
+        t = Thread(target=worker, args=(i, timeout))
+        t.start()
+        ts.append(t)
+
+    for t in ts:
+        t.join()
+
+    sleep(0.2)
+    assert set(results) == set([0, 1, 3, 4])
+    assert len(p._pool) == 2  # no connection was lost
+
+
 @pytest.mark.slow
 def test_queue_timeout_override(dsn):
     p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1)