]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't waste a worker thread adding a connection to the pool
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 20 Feb 2021 00:47:00 +0000 (01:47 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool.py
tests/test_pool.py

index 259ce26210ca9d5943a11cd91e32dd9fd9ac37d2..e443801671331ce5427ced7d82e30c1376339b99 100644 (file)
@@ -4,12 +4,13 @@ psycopg3 connection pool
 
 # Copyright (C) 2021 The Psycopg Team
 
+import time
 import random
 import logging
 import threading
 from abc import ABC, abstractmethod
 from queue import Queue, Empty
-from typing import Any, Callable, Deque, Dict, Iterator, List, Optional
+from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Tuple
 from contextlib import contextmanager
 from collections import deque
 
@@ -71,7 +72,7 @@ class ConnectionPool:
         self.num_workers = num_workers
 
         self._nconns = 0  # currently in the pool, out, being prepared
-        self._pool: List[Connection] = []
+        self._pool: List[Tuple[Connection, float]] = []
         self._waiting: Deque["WaitingClient"] = deque()
         self._lock = threading.Lock()
         self._closed = False
@@ -125,6 +126,7 @@ class ConnectionPool:
         failing to do so will deplete the pool. A depleted pool is a sad pool:
         you don't want a depleted pool.
         """
+        logger.debug("connection requested to %r", self.name)
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         with self._lock:
@@ -142,7 +144,7 @@ class ConnectionPool:
 
                 # If there is space for the pool to grow, let's do it
                 if self._nconns < self.maxconn:
-                    logger.debug("growing pool %s", self.name)
+                    logger.debug("growing pool %r", self.name)
                     self._nconns += 1
                     self.add_task(AddConnection(self))
 
@@ -157,6 +159,7 @@ class ConnectionPool:
         # Note that this property shouldn't be set while the connection is in
         # the pool, to avoid to create a reference loop.
         conn._pool = self
+        logger.debug("connection given by %r", self.name)
         return conn
 
     def putconn(self, conn: Connection) -> None:
@@ -176,6 +179,8 @@ class ConnectionPool:
                 f"can't return connection to pool {self.name!r}, {msg}: {conn}"
             )
 
+        logger.debug("returning connection to %r", self.name)
+
         # If the pool is closed just close the connection instead of returning
         # it to the poo. For extra refcare remove the pool reference from it.
         if self._closed:
@@ -186,7 +191,12 @@ class ConnectionPool:
         # Use a worker to perform eventual maintenance work in a separate thread
         self.add_task(ReturnConnection(self, conn))
 
-    def _return_connection(self, conn: Connection) -> None:
+    def _add_to_pool(self, conn: Connection) -> None:
+        """
+        Add a connection to the pool.
+
+        The connection can be a fresh one or one already used in the pool.
+        """
         # Remove the pool reference from the connection before returning it
         # to the state, to avoid to create a reference loop.
         # Also disable the warning for open connection in conn.__del__
@@ -208,7 +218,7 @@ class ConnectionPool:
                 pos = self._waiting.popleft()
             else:
                 # No client waiting for a connection: put it back into the pool
-                self._pool.append(conn)
+                self._pool.append((conn, time.time()))
 
         # If we found a client in queue, give it the connection and notify it
         if pos:
@@ -266,7 +276,7 @@ class ConnectionPool:
 
         # Close the connections still in the pool
         while self._pool:
-            conn = self._pool.pop(-1)
+            conn = self._pool.pop(-1)[0]
             conn.close()
 
         # Stop the worker threads
@@ -404,8 +414,7 @@ class AddConnection(MaintenanceTask):
 
     def _run(self) -> None:
         conn = self.pool._connect()
-        conn._pool = self.pool  # make it accepted by putconn
-        self.pool.putconn(conn)
+        self.pool._add_to_pool(conn)
 
 
 class ReturnConnection(MaintenanceTask):
@@ -416,4 +425,4 @@ class ReturnConnection(MaintenanceTask):
         self.conn = conn
 
     def _run(self) -> None:
-        self.pool._return_connection(self.conn)
+        self.pool._add_to_pool(self.conn)
index db50f85f9234c285c923ec94a0f8216f211bf0b0..8a1bc106e3656c35aa4da0a813fc75d12ba09738 100644 (file)
@@ -50,6 +50,18 @@ def test_connection_not_lost(dsn):
         assert conn2.pgconn.backend_pid == pid
 
 
+@pytest.mark.slow
+def test_concurrent_filling(dsn, monkeypatch):
+    delay_connection(monkeypatch, 0.1)
+    t0 = time()
+    p = pool.ConnectionPool(dsn, minconn=5, num_workers=2)
+    wait_pool_full(p)
+    times = [item[1] - t0 for item in p._pool]
+    want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
+    for got, want in zip(times, want_times):
+        assert got == pytest.approx(want, 0.1), times
+
+
 @pytest.mark.slow
 def test_queue(dsn):
     p = pool.ConnectionPool(dsn, minconn=2)
@@ -73,9 +85,11 @@ def test_queue(dsn):
     for t in ts:
         t.join()
 
-    assert len([r for r in results if 0.2 < r[1] < 0.35]) == 2
-    assert len([r for r in results if 0.4 < r[1] < 0.55]) == 2
-    assert len([r for r in results if 0.5 < r[1] < 0.75]) == 2
+    times = [item[1] for item in results]
+    want_times = [0.2, 0.2, 0.4, 0.4, 0.6, 0.6]
+    for got, want in zip(times, want_times):
+        assert got == pytest.approx(want, 0.15), times
+
     assert len(set(r[2] for r in results)) == 2
 
 
@@ -360,26 +374,26 @@ def test_grow(dsn, monkeypatch):
     for t in ts:
         t.join()
 
-    deltas = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
-    for (_, got), want in zip(results, deltas):
-        assert got == pytest.approx(want, 0.1)
+    want_times = [0.2, 0.2, 0.3, 0.3, 0.4, 0.4]
+    times = [item[1] for item in results]
+    for got, want in zip(times, want_times):
+        assert got == pytest.approx(want, 0.15), times
 
 
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds
     """
-    connect_gen_orig = psycopg3.Connection._connect_gen
+    connect_orig = psycopg3.Connection.connect
 
-    def connect_gen_delayed(*args, **kwargs):
-        psycopg3.pool.logger.debug("delaying connection")
-        sleep(sec)
-        rv = yield from connect_gen_orig(*args, **kwargs)
+    def connect_delay(*args, **kwargs):
+        t0 = time()
+        rv = connect_orig(*args, **kwargs)
+        t1 = time()
+        sleep(sec - (t1 - t0))
         return rv
 
-    monkeypatch.setattr(
-        psycopg3.Connection, "_connect_gen", connect_gen_delayed
-    )
+    monkeypatch.setattr(psycopg3.Connection, "connect", connect_delay)
 
 
 def wait_pool_full(pool):