"""Context manager to obtain a connection from the pool.
Returned the connection immediately if available, otherwise wait up to
- *timeout* or `self.timeout` and throw `PoolTimeout` if a
- connection is available in time.
+ *timeout* or `self.timeout` and throw `PoolTimeout` if a connection is
+ not available in time.
Upon context exit, return the connection to the pool. Apply the normal
connection context behaviour (commit/rollback the transaction in case
# Critical section: if there is a client waiting give it the connection
# otherwise put it back into the pool.
with self._lock:
- if self._waiting:
- # Extract the first client from the queue
+ while self._waiting:
+ # If there is a client waiting (which is still waiting and
+ # hasn't timed out), give it the connection and notify it.
pos = self._waiting.popleft()
+ if pos.set(conn):
+ break
+
else:
now = time.monotonic()
)
self._nconns -= 1
- # If we found a client in queue, give it the connection and notify it
- if pos:
- pos.set(conn)
-
if to_close:
to_close.close()
class WaitingClient:
"""An position in a queue for a client waiting for a connection."""
- __slots__ = ("event", "conn", "error")
+ __slots__ = ("conn", "error", "_cond")
def __init__(self) -> None:
- self.event = threading.Event()
- self.conn: Connection
+ self.conn: Optional[Connection] = None
self.error: Optional[Exception] = None
+ # The WaitingClient behaves in a way similar to an Event, but we need
+ # to notify reliably the flagger that the waiter has "accepted" the
+ # message and it hasn't timed out yet, otherwise the pool may give a
+ # connection to a client that has already timed out getconn(), which
+ # will be lost.
+ self._cond = threading.Condition(threading.Lock())
+
def wait(self, timeout: float) -> Connection:
- """Wait for the event to be set and return the connection."""
- if not self.event.wait(timeout):
- raise PoolTimeout(f"couldn't get a connection after {timeout} sec")
- if self.error:
+ """Wait for a connection to be set and return it.
+
+ Raise an exception if the wait times out or if fail() is called.
+ """
+ with self._cond:
+ if not (self.conn or self.error):
+ if not self._cond.wait(timeout):
+ self.error = PoolTimeout(
+ f"couldn't get a connection after {timeout} sec"
+ )
+
+ if self.conn:
+ return self.conn
+ else:
+ assert self.error
raise self.error
- return self.conn
- def set(self, conn: Connection) -> None:
- """Signal the client waiting that a connection is ready."""
- self.conn = conn
- self.event.set()
+ def set(self, conn: Connection) -> bool:
+ """Signal the client waiting that a connection is ready.
+
+ Return True if the client has "accepted" the connection, False
+ otherwise (typically because wait() has timed out.
+ """
+ with self._cond:
+ if self.conn or self.error:
+ return False
+
+ self.conn = conn
+ self._cond.notify_all()
+ return True
+
+ def fail(self, error: Exception) -> bool:
+ """Signal the client that, alas, they won't have a connection today.
+
+ Return True if the client has "accepted" the error, False otherwise
+ (typically because wait() has timed out.
+ """
+ with self._cond:
+ if self.conn or self.error:
+ return False
- def fail(self, error: Exception) -> None:
- """Signal the client that, alas, they won't have a connection today."""
- self.error = error
- self.event.set()
+ self.error = error
+ self._cond.notify_all()
+ return True
class MaintenanceTask(ABC):
assert 0.1 < e[1] < 0.15
+@pytest.mark.slow
+def test_dead_client(dsn):
+ p = pool.ConnectionPool(dsn, minconn=2)
+
+ results = []
+
+ def worker(i, timeout):
+ try:
+ with p.connection(timeout=timeout) as conn:
+ conn.execute("select pg_sleep(0.3)")
+ results.append(i)
+ except pool.PoolTimeout:
+ if timeout > 0.2:
+ raise
+
+ ts = []
+ for i, timeout in enumerate([0.4, 0.4, 0.1, 0.4, 0.4]):
+ t = Thread(target=worker, args=(i, timeout))
+ t.start()
+ ts.append(t)
+
+ for t in ts:
+ t.join()
+
+ sleep(0.2)
+ assert set(results) == set([0, 1, 3, 4])
+ assert len(p._pool) == 2 # no connection was lost
+
+
@pytest.mark.slow
def test_queue_timeout_override(dsn):
p = pool.ConnectionPool(dsn, minconn=2, timeout=0.1)