]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make the pool maintenance tasks base class abstract
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 14 Feb 2021 02:33:19 +0000 (03:33 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool.py

index 03fc86b626f5835bbd1532bef210219ce19d4e10..ce225f96ad1b7e1978b9191aedf539fb21ad612c 100644 (file)
@@ -7,6 +7,7 @@ psycopg3 connection pool
 import random
 import logging
 import threading
+from abc import ABC, abstractmethod
 from queue import Queue, Empty
 from typing import Any, Callable, Deque, Dict, Iterator, List, Optional
 from contextlib import contextmanager
@@ -286,7 +287,7 @@ class ConnectionPool:
 
             # Run the task. Make sure don't die in the attempt.
             try:
-                task()
+                task.run()
             except Exception as e:
                 logger.warning(
                     "task run %s failed: %s: %s", task, e.__class__.__name__, e
@@ -339,30 +340,38 @@ class WaitingClient:
         self.event.set()
 
 
-class MaintenanceTask:
+class MaintenanceTask(ABC):
+    """A task run asynchronously to maintain the pool state."""
+
     def __init__(self, pool: ConnectionPool):
         self.pool = pool
         logger.debug("task created: %s", self)
 
-    def __call__(self) -> None:
-        logger.debug("task running: %s", self)
-
     def __repr__(self) -> str:
         return (
             f"<{self.__class__.__name__} {self.pool.name!r} at 0x{id(self):x}>"
         )
 
+    def run(self) -> None:
+        logger.debug("task running: %s", self)
+        self._run()
+
+    @abstractmethod
+    def _run(self) -> None:
+        ...
+
 
 class StopWorker(MaintenanceTask):
     """Signal the maintenance thread to terminate."""
 
+    def _run(self) -> None:
+        pass
+
 
 class TopUpConnections(MaintenanceTask):
     """Increase the number of connections in the pool to the desired number."""
 
-    def __call__(self) -> None:
-        super().__call__()
-
+    def _run(self) -> None:
         with self.pool._lock:
             # Check if there are new connections to create. If there are
             # update the number of connections managed immediately and in
@@ -383,9 +392,7 @@ class TopUpConnections(MaintenanceTask):
 class AddConnection(MaintenanceTask):
     """Add a new connection into to the pool."""
 
-    def __call__(self) -> None:
-        super().__call__()
-
+    def _run(self) -> None:
         conn = self.pool._connect()
         conn._pool = self.pool  # make it accepted by putconn
         self.pool.putconn(conn)
@@ -398,6 +405,5 @@ class ReturnConnection(MaintenanceTask):
         super().__init__(pool)
         self.conn = conn
 
-    def __call__(self) -> None:
-        super().__call__()
+    def _run(self) -> None:
         self.pool._return_connection(self.conn)