]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move task implementations to the pool object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 25 Feb 2021 19:11:56 +0000 (20:11 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/pool.py
tests/pool/test_pool.py

index 6d43a009cd6bbeaa6e2ff82b2e5c03a090e46561..c544c0662875ab7ec897fb854c60a11b684f9b94 100644 (file)
@@ -95,11 +95,11 @@ class ConnectionPool:
         # max_idle interval they weren't all used.
         self._nconns_min = minconn
 
-        self._wqueue: "Queue[MaintenanceTask]" = Queue()
+        self._tasks: "Queue[MaintenanceTask]" = Queue()
         self._workers: List[threading.Thread] = []
         for i in range(num_workers):
             t = threading.Thread(
-                target=self.worker, args=(self._wqueue,), daemon=True
+                target=self.worker, args=(self._tasks,), daemon=True
             )
             self._workers.append(t)
 
@@ -251,71 +251,6 @@ class ConnectionPool:
         # Use a worker to perform eventual maintenance work in a separate thread
         self.run_task(ReturnConnection(self, conn))
 
-    def _add_to_pool(self, conn: Connection) -> None:
-        """
-        Add a connection to the pool.
-
-        The connection can be a fresh one or one already used in the pool.
-        """
-        # Remove the pool reference from the connection before returning it
-        # to the state, to avoid to create a reference loop.
-        # Also disable the warning for open connection in conn.__del__
-        conn._pool = None
-
-        self._reset_transaction_status(conn)
-        if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
-            # Connection no more in working state: create a new one.
-            logger.warning("discarding closed connection: %s", conn)
-            self.run_task(AddConnection(self))
-            return
-
-        pos: Optional[WaitingClient] = None
-        to_close: Optional[Connection] = None
-
-        # Critical section: if there is a client waiting give it the connection
-        # otherwise put it back into the pool.
-        with self._lock:
-            while self._waiting:
-                # If there is a client waiting (which is still waiting and
-                # hasn't timed out), give it the connection and notify it.
-                pos = self._waiting.popleft()
-                if pos.set(conn):
-                    break
-
-            else:
-                # No client waiting for a connection: put it back into the pool
-                self._pool.append(conn)
-
-        if to_close:
-            to_close.close()
-
-    def _reset_transaction_status(self, conn: Connection) -> None:
-        """
-        Bring a connection to IDLE state or close it.
-        """
-        status = conn.pgconn.transaction_status
-        if status == TransactionStatus.IDLE:
-            return
-
-        if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
-            # Connection returned with an active transaction
-            logger.warning("rolling back returned connection: %s", conn)
-            try:
-                conn.rollback()
-            except Exception as e:
-                logger.warning(
-                    "rollback failed: %s: %s. Discarding connection %s",
-                    e.__class__.__name__,
-                    e,
-                    conn,
-                )
-                conn.close()
-
-        elif status == TransactionStatus.ACTIVE:
-            # Connection returned during an operation. Bad... just close it.
-            logger.warning("closing returned connection: %s", conn)
-            conn.close()
-
     @property
     def closed(self) -> bool:
         """`!True` if the pool is closed."""
@@ -376,16 +311,11 @@ class ConnectionPool:
 
     def run_task(self, task: "MaintenanceTask") -> None:
         """Run a maintenance task in a worker thread."""
-        self._wqueue.put(task)
+        self._tasks.put(task)
 
-    def schedule_task(
-        self, task: "MaintenanceTask", delay: float, absolute: bool = False
-    ) -> None:
+    def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
         """Run a maintenance task in a worker thread in the future."""
-        if absolute:
-            self._sched.enterabs(delay, task.tick)
-        else:
-            self._sched.enter(delay, task.tick)
+        self._sched.enter(delay, task.tick)
 
     @classmethod
     def worker(cls, q: "Queue[MaintenanceTask]") -> None:
@@ -416,6 +346,16 @@ class ConnectionPool:
             if isinstance(task, StopWorker):
                 return
 
+    def configure(self, conn: Connection) -> None:
+        """Configure a connection after creation."""
+        self._configure(conn)
+
+    def reconnect_failed(self) -> None:
+        """
+        Called when reconnection failed for longer than `reconnect_timeout`.
+        """
+        self._reconnect_failed(self)
+
     def _connect(self) -> Connection:
         """Return a new connection configured for the pool."""
         conn = Connection.connect(self.conninfo, **self.kwargs)
@@ -423,15 +363,149 @@ class ConnectionPool:
         conn._pool = self
         return conn
 
-    def configure(self, conn: Connection) -> None:
-        """Configure a connection after creation."""
-        self._configure(conn)
+    def _add_initial_connection(self, event: threading.Event) -> None:
+        """Create a new connection at the beginning of the pool life.
 
-    def reconnect_failed(self) -> None:
+        Trigger *event* if all the connections necessary have been added.
         """
-        Called when reconnection failed for longer than `reconnect_timeout`.
+        conn = self._connect()
+        conn._pool = None  # avoid a reference loop
+
+        with self._lock:
+            assert (
+                not self._waiting
+            ), "clients waiting in a pool being initialised"
+            self._pool.append(conn)
+            trigger_event = len(self._pool) >= self._nconns
+
+        if trigger_event:
+            event.set()
+
+    def _add_connection(self, attempt: Optional["ConnectionAttempt"]) -> None:
+        """Try to connect and add the connection to the pool.
+
+        If failed, reschedule a new attempt in the future for a few times, then
+        give up, decrease the pool connections number and call
+        `self.reconnect_failed()`.
+
         """
-        self._reconnect_failed(self)
+        now = time.monotonic()
+        if not attempt:
+            attempt = ConnectionAttempt(
+                reconnect_timeout=self.reconnect_timeout
+            )
+
+        try:
+            conn = self._connect()
+        except Exception as e:
+            logger.warning(f"error connecting in {self.name!r}: {e}")
+            if attempt.time_to_give_up(now):
+                logger.warning(
+                    "reconnection attempt in pool %r failed after %s sec",
+                    self.name,
+                    self.reconnect_timeout,
+                )
+                with self._lock:
+                    self._nconns -= 1
+                self.reconnect_failed()
+            else:
+                attempt.update_delay(now)
+                self.schedule_task(AddConnection(self, attempt), attempt.delay)
+        else:
+            self._add_to_pool(conn)
+
+    def _return_connection(self, conn: Connection) -> None:
+        """
+        Return a connection to the pool after usage.
+        """
+        self._reset_connection(conn)
+        if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
+            # Connection no more in working state: create a new one.
+            logger.warning("discarding closed connection: %s", conn)
+            self.run_task(AddConnection(self))
+        else:
+            self._add_to_pool(conn)
+
+    def _add_to_pool(self, conn: Connection) -> None:
+        """
+        Add a connection to the pool.
+
+        The connection can be a fresh one or one already used in the pool.
+
+        If a client is already waiting for a connection pass it on, otherwise
+        put it back into the pool
+        """
+        # Remove the pool reference from the connection before returning it
+        # to the state, to avoid to create a reference loop.
+        # Also disable the warning for open connection in conn.__del__
+        conn._pool = None
+
+        pos: Optional[WaitingClient] = None
+
+        # Critical section: if there is a client waiting give it the connection
+        # otherwise put it back into the pool.
+        with self._lock:
+            while self._waiting:
+                # If there is a client waiting (which is still waiting and
+                # hasn't timed out), give it the connection and notify it.
+                pos = self._waiting.popleft()
+                if pos.set(conn):
+                    break
+
+            else:
+                # No client waiting for a connection: put it back into the pool
+                self._pool.append(conn)
+
+    def _reset_connection(self, conn: Connection) -> None:
+        """
+        Bring a connection to IDLE state or close it.
+        """
+        status = conn.pgconn.transaction_status
+        if status == TransactionStatus.IDLE:
+            return
+
+        if status in (TransactionStatus.INTRANS, TransactionStatus.INERROR):
+            # Connection returned with an active transaction
+            logger.warning("rolling back returned connection: %s", conn)
+            try:
+                conn.rollback()
+            except Exception as e:
+                logger.warning(
+                    "rollback failed: %s: %s. Discarding connection %s",
+                    e.__class__.__name__,
+                    e,
+                    conn,
+                )
+                conn.close()
+
+        elif status == TransactionStatus.ACTIVE:
+            # Connection returned during an operation. Bad... just close it.
+            logger.warning("closing returned connection: %s", conn)
+            conn.close()
+
+    def _shrink_if_possible(self) -> None:
+        to_close: Optional[Connection] = None
+
+        with self._lock:
+            # Reset the min number of connections used
+            nconns_min = self._nconns_min
+            self._nconns_min = len(self._pool)
+
+            # If the pool can shrink and connections were unused, drop one
+            if self._nconns > self.minconn and nconns_min > 0:
+                to_close = self._pool.popleft()
+                self._nconns -= 1
+
+        if to_close:
+            logger.info(
+                "shrinking pool %r to %s because %s unused connections"
+                " in the last %s sec",
+                self.name,
+                self._nconns,
+                nconns_min,
+                self.max_idle,
+            )
+            to_close.close()
 
 
 class WaitingClient:
@@ -497,6 +571,38 @@ class WaitingClient:
             return True
 
 
+class ConnectionAttempt:
+    """Keep the state of a connection attempt."""
+
+    INITIAL_DELAY = 1.0
+    DELAY_JITTER = 0.1
+    DELAY_BACKOFF = 2.0
+
+    def __init__(self, *, reconnect_timeout: float):
+        self.reconnect_timeout = reconnect_timeout
+        self.delay = 0.0
+        self.give_up_at = 0.0
+
+    def update_delay(self, now: float) -> None:
+        """Calculate how long to wait for a new connection attempt"""
+        if self.delay == 0.0:
+            self.give_up_at = now + self.reconnect_timeout
+            # +/- 10% of the initial delay
+            jitter = self.INITIAL_DELAY * (
+                (2.0 * self.DELAY_JITTER * random.random()) - self.DELAY_JITTER
+            )
+            self.delay = self.INITIAL_DELAY + jitter
+        else:
+            self.delay *= self.DELAY_BACKOFF
+
+        if self.delay + now > self.give_up_at:
+            self.delay = max(0.0, self.give_up_at - now)
+
+    def time_to_give_up(self, now: float) -> bool:
+        """Return True if we are tired of trying to connect. Meh."""
+        return self.give_up_at > 0.0 and now >= self.give_up_at
+
+
 class MaintenanceTask(ABC):
     """A task to run asynchronously to maintain the pool state."""
 
@@ -559,67 +665,18 @@ class AddInitialConnection(MaintenanceTask):
         self.event = event
 
     def _run(self, pool: ConnectionPool) -> None:
-        conn = pool._connect()
-        pool._add_to_pool(conn)
-        if len(pool._pool) >= pool._nconns:
-            self.event.set()
+        pool._add_initial_connection(self.event)
 
 
 class AddConnection(MaintenanceTask):
-    INITIAL_DELAY = 1.0
-    DELAY_JITTER = 0.1
-    DELAY_BACKOFF = 2.0
-
-    def __init__(self, pool: ConnectionPool):
+    def __init__(
+        self, pool: ConnectionPool, attempt: Optional[ConnectionAttempt] = None
+    ):
         super().__init__(pool)
-        self.delay = 0.0
-        self.give_up_at = 0.0
+        self.attempt = attempt
 
     def _run(self, pool: ConnectionPool) -> None:
-        try:
-            conn = pool._connect()
-        except Exception as e:
-            logger.warning(f"error reconnecting in {pool.name!r}: {e}")
-            self._handle_error(pool)
-        else:
-            pool._add_to_pool(conn)
-
-    def _handle_error(self, pool: ConnectionPool) -> None:
-        """Called after a connection failure.
-
-        Calculate the new time for a new reconnection attempt and schedule a
-        retry in the future. If too many attempts were performed, give up, by
-        decreasing the pool connection number and calling
-        `pool.reconnect_failed()`.
-        """
-        now = time.monotonic()
-        if self.give_up_at and now >= self.give_up_at:
-            logger.warning(
-                "reconnection attempt in pool %r failed after %s sec",
-                pool.name,
-                pool.reconnect_timeout,
-            )
-            with pool._lock:
-                pool._nconns -= 1
-            pool.reconnect_failed()
-            return
-
-        # Calculate how long to wait for a new connection attempt
-        if self.delay == 0.0:
-            self.give_up_at = now + pool.reconnect_timeout
-            # +/- 10% of the initial delay
-            jitter = self.INITIAL_DELAY * (
-                (2.0 * self.DELAY_JITTER * random.random()) - self.DELAY_JITTER
-            )
-            self.delay = self.INITIAL_DELAY + jitter
-        else:
-            self.delay *= self.DELAY_BACKOFF
-
-        # Schedule a run of self.tick() some time in the future
-        if now + self.delay < self.give_up_at:
-            pool.schedule_task(self, self.delay)
-        else:
-            pool.schedule_task(self, self.give_up_at, absolute=True)
+        pool._add_connection(self.attempt)
 
 
 class ReturnConnection(MaintenanceTask):
@@ -630,7 +687,7 @@ class ReturnConnection(MaintenanceTask):
         self.conn = conn
 
     def _run(self, pool: ConnectionPool) -> None:
-        pool._add_to_pool(self.conn)
+        pool._return_connection(self.conn)
 
 
 class ShrinkPool(MaintenanceTask):
@@ -645,25 +702,4 @@ class ShrinkPool(MaintenanceTask):
         # the periodic run.
         pool.schedule_task(self, pool.max_idle)
 
-        to_close: Optional[Connection] = None
-
-        with pool._lock:
-            # Reset the min number of connections used
-            nconns_min = pool._nconns_min
-            pool._nconns_min = len(pool._pool)
-
-            # If the pool can shrink and connections were unused, drop one
-            if pool._nconns > pool.minconn and nconns_min > 0:
-                to_close = pool._pool.popleft()
-                pool._nconns -= 1
-
-        if to_close:
-            logger.info(
-                "shrinking pool %r to %s because %s unused connections"
-                " in the last %s sec",
-                pool.name,
-                pool._nconns,
-                nconns_min,
-                pool.max_idle,
-            )
-            to_close.close()
+        pool._shrink_if_possible()
index 2dc185468c367f6f65ab951fbe6be2781472a75f..a80c8fb1d95249af7c8cd79564bcdb64cfa39238 100644 (file)
@@ -68,13 +68,15 @@ def test_concurrent_filling(dsn, monkeypatch):
     t0 = time()
     times = []
 
-    add_to_pool_orig = pool.ConnectionPool._add_to_pool
+    add_orig = pool.ConnectionPool._add_initial_connection
 
-    def _add_to_pool_time(self, conn):
+    def add_time(self, event):
+        add_orig(self, event)
         times.append(time() - t0)
-        add_to_pool_orig(self, conn)
 
-    monkeypatch.setattr(pool.ConnectionPool, "_add_to_pool", _add_to_pool_time)
+    monkeypatch.setattr(
+        pool.ConnectionPool, "_add_initial_connection", add_time
+    )
 
     pool.ConnectionPool(dsn, minconn=5, num_workers=2)
     want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
@@ -519,10 +521,10 @@ def test_shrink(dsn, monkeypatch):
 def test_reconnect(proxy, caplog, monkeypatch):
     caplog.set_level(logging.WARNING, logger="psycopg3.pool")
 
-    assert pool.pool.AddConnection.INITIAL_DELAY == 1.0
-    assert pool.pool.AddConnection.DELAY_JITTER == 0.1
-    monkeypatch.setattr(pool.pool.AddConnection, "INITIAL_DELAY", 0.1)
-    monkeypatch.setattr(pool.pool.AddConnection, "DELAY_JITTER", 0.0)
+    assert pool.pool.ConnectionAttempt.INITIAL_DELAY == 1.0
+    assert pool.pool.ConnectionAttempt.DELAY_JITTER == 0.1
+    monkeypatch.setattr(pool.pool.ConnectionAttempt, "INITIAL_DELAY", 0.1)
+    monkeypatch.setattr(pool.pool.ConnectionAttempt, "DELAY_JITTER", 0.0)
 
     proxy.start()
     p = pool.ConnectionPool(proxy.client_dsn, minconn=1, setup_timeout=2.0)