From: Daniele Varrazzo Date: Mon, 22 Feb 2021 01:05:14 +0000 (+0100) Subject: Don't lose pool connections giving them to a clients already timed out X-Git-Tag: 3.0.dev0~87^2~52 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a4fac57efc1d1a5768760fb9609682fd3cd85ee;p=thirdparty%2Fpsycopg.git Don't lose pool connections giving them to a clients already timed out --- diff --git a/psycopg3/psycopg3/pool.py b/psycopg3/psycopg3/pool.py index df333e650..8eca7a203 100644 --- a/psycopg3/psycopg3/pool.py +++ b/psycopg3/psycopg3/pool.py @@ -142,8 +142,8 @@ class ConnectionPool: """Context manager to obtain a connection from the pool. Returned the connection immediately if available, otherwise wait up to - *timeout* or `self.timeout` and throw `PoolTimeout` if a - connection is available in time. + *timeout* or `self.timeout` and throw `PoolTimeout` if a connection is + not available in time. Upon context exit, return the connection to the pool. Apply the normal connection context behaviour (commit/rollback the transaction in case @@ -256,9 +256,13 @@ class ConnectionPool: # Critical section: if there is a client waiting give it the connection # otherwise put it back into the pool. with self._lock: - if self._waiting: - # Extract the first client from the queue + while self._waiting: + # If there is a client waiting (which is still waiting and + # hasn't timed out), give it the connection and notify it. 
pos = self._waiting.popleft() + if pos.set(conn): + break + else: now = time.monotonic() @@ -278,10 +282,6 @@ class ConnectionPool: ) self._nconns -= 1 - # If we found a client in queue, give it the connection and notify it - if pos: - pos.set(conn) - if to_close: to_close.close() @@ -420,30 +420,64 @@ class ConnectionPool: class WaitingClient: """An position in a queue for a client waiting for a connection.""" - __slots__ = ("event", "conn", "error") + __slots__ = ("conn", "error", "_cond") def __init__(self) -> None: - self.event = threading.Event() - self.conn: Connection + self.conn: Optional[Connection] = None self.error: Optional[Exception] = None + # The WaitingClient behaves in a way similar to an Event, but we need + # to notify reliably the flagger that the waiter has "accepted" the + # message and it hasn't timed out yet, otherwise the pool may give a + # connection to a client that has already timed out getconn(), which + # will be lost. + self._cond = threading.Condition(threading.Lock()) + def wait(self, timeout: float) -> Connection: - """Wait for the event to be set and return the connection.""" - if not self.event.wait(timeout): - raise PoolTimeout(f"couldn't get a connection after {timeout} sec") - if self.error: + """Wait for a connection to be set and return it. + + Raise an exception if the wait times out or if fail() is called. + """ + with self._cond: + if not (self.conn or self.error): + if not self._cond.wait(timeout): + self.error = PoolTimeout( + f"couldn't get a connection after {timeout} sec" + ) + + if self.conn: + return self.conn + else: + assert self.error raise self.error - return self.conn - def set(self, conn: Connection) -> None: - """Signal the client waiting that a connection is ready.""" - self.conn = conn - self.event.set() + def set(self, conn: Connection) -> bool: + """Signal the client waiting that a connection is ready. 
+ + Return True if the client has "accepted" the connection, False + otherwise (typically because wait() has timed out). + """ + with self._cond: + if self.conn or self.error: + return False + + self.conn = conn + self._cond.notify_all() + return True + + def fail(self, error: Exception) -> bool: + """Signal the client that, alas, they won't have a connection today. + + Return True if the client has "accepted" the error, False otherwise + (typically because wait() has timed out). + """ + with self._cond: + if self.conn or self.error: + return False - def fail(self, error: Exception) -> None: - """Signal the client that, alas, they won't have a connection today.""" - self.error = error - self.event.set() + self.error = error + self._cond.notify_all() + return True class MaintenanceTask(ABC): diff --git a/tests/test_pool.py b/tests/test_pool.py index b07bc9c5a..f80f0a0c8 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -168,6 +168,35 @@ def test_queue_timeout(dsn): assert 0.1 < e[1] < 0.15 +@pytest.mark.slow +def test_dead_client(dsn): + p = pool.ConnectionPool(dsn, minconn=2) + + results = [] + + def worker(i, timeout): + try: + with p.connection(timeout=timeout) as conn: + conn.execute("select pg_sleep(0.3)") + results.append(i) + except pool.PoolTimeout: + if timeout > 0.2: + raise + + ts = [] + for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]): + t = Thread(target=worker, args=(i, timeout)) + t.start() + ts.append(t) + + for t in ts: + t.join() + + sleep(0.2) + assert set(results) == set([0, 1, 3, 4]) + assert len(p._pool) == 2 # no connection was lost + + @pytest.mark.slow def test_queue_timeout_override(dsn): p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1)