]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add pool reconnection retry
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 21 Feb 2021 03:26:00 +0000 (04:26 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool.py
psycopg3/setup.py
tests/conftest.py
tests/fix_proxy.py [new file with mode: 0644]
tests/test_pool.py

index f28c902fd6ff864c8e9dd3a81a22b7c1ed5ca50b..27de902015197831bb513a030d2e484d56558508 100644 (file)
@@ -17,6 +17,7 @@ from collections import deque
 
 from . import errors as e
 from .pq import TransactionStatus
+from ._sched import Scheduler
 from .connection import Connection
 
 WORKER_TIMEOUT = 60.0
@@ -46,6 +47,8 @@ class ConnectionPool:
         name: Optional[str] = None,
         timeout_sec: float = 30.0,
         max_idle_sec: float = 10 * 60.0,
+        reconnect_timeout: float = 5 * 60.0,
+        reconnect_failed: Optional[Callable[["ConnectionPool"], None]] = None,
         num_workers: int = 3,
     ):
         if maxconn is None:
@@ -67,18 +70,22 @@ class ConnectionPool:
         self.kwargs: Dict[str, Any] = kwargs or {}
         self._configure: Callable[[Connection], None]
         self._configure = configure or (lambda conn: None)
+        self._reconnect_failed: Callable[["ConnectionPool"], None]
+        self._reconnect_failed = reconnect_failed or (lambda pool: None)
         self.name = name
         self.minconn = minconn
         self.maxconn = maxconn
         self.timeout_sec = timeout_sec
+        self.reconnect_timeout = reconnect_timeout
         self.max_idle_sec = max_idle_sec
         self.num_workers = num_workers
 
         self._nconns = minconn  # currently in the pool, out, being prepared
         self._pool: Deque[Tuple[Connection, float]] = deque()
         self._waiting: Deque["WaitingClient"] = deque()
-        self._lock = threading.Lock()
+        self._lock = threading.RLock()
         self._closed = False
+        self.sched = Scheduler()
 
         self._wqueue: "Queue[MaintenanceTask]" = Queue()
         self._workers: List[threading.Thread] = []
@@ -88,6 +95,10 @@ class ConnectionPool:
             t.start()
             self._workers.append(t)
 
+        self._sched_runner = threading.Thread(target=self.sched.run)
+        self._sched_runner.daemon = True
+        self._sched_runner.start()
+
         # Populate the pool with initial minconn connections
         event = threading.Event()
         for i in range(self._nconns):
@@ -231,7 +242,7 @@ class ConnectionPool:
                 # Extract the first client from the queue
                 pos = self._waiting.popleft()
             else:
-                now = time.time()
+                now = time.monotonic()
 
                 # No client waiting for a connection: put it back into the pool
                 self._pool.append((conn, now))
@@ -301,6 +312,9 @@ class ConnectionPool:
         # Now that the flag _closed is set, getconn will fail immediately,
         # putconn will just close the returned connection.
 
+        # Stop the scheduler
+        self.sched.enter(0, None)
+
         # Signal to eventual clients in the queue that business is closed.
         while self._waiting:
             pos = self._waiting.popleft()
@@ -359,6 +373,12 @@ class ConnectionPool:
         """Configure a connection after creation."""
         self._configure(conn)
 
+    def reconnect_failed(self) -> None:
+        """
+        Called when reconnection failed for longer than `reconnect_timeout`.
+        """
+        self._reconnect_failed(self)
+
 
 class WaitingClient:
     """An position in a queue for a client waiting for a connection."""
@@ -422,15 +442,7 @@ class StopWorker(MaintenanceTask):
         pass
 
 
-class AddConnection(MaintenanceTask):
-    """Add a new connection into to the pool."""
-
-    def _run(self, pool: ConnectionPool) -> None:
-        conn = pool._connect()
-        pool._add_to_pool(conn)
-
-
-class AddInitialConnection(AddConnection):
+class AddInitialConnection(MaintenanceTask):
     """Add a new connection into to the pool.
 
     If the desired number of connections is reached notify the event.
@@ -441,11 +453,76 @@ class AddInitialConnection(AddConnection):
         self.event = event
 
     def _run(self, pool: ConnectionPool) -> None:
-        super()._run(pool)
+        conn = pool._connect()
+        pool._add_to_pool(conn)
         if len(pool._pool) >= pool._nconns:
             self.event.set()
 
 
+class AddConnection(MaintenanceTask):
+    INITIAL_DELAY = 1.0
+    DELAY_JITTER = 0.1
+    DELAY_BACKOFF = 2.0
+
+    def __init__(self, pool: ConnectionPool):
+        super().__init__(pool)
+        self.delay = 0.0
+        self.give_up_at = 0.0
+
+    def _run(self, pool: ConnectionPool) -> None:
+        try:
+            conn = pool._connect()
+        except Exception as e:
+            logger.warning(f"error reconnecting in {pool.name!r}: {e}")
+            self._handle_error(pool)
+        else:
+            pool._add_to_pool(conn)
+
+    def _handle_error(self, pool: ConnectionPool) -> None:
+        """Called after a connection failure.
+
+        Calculate the new time for a new reconnection attempt and schedule a
+        retry in the future. If too many attempts were performed, give up, by
+        decreasing the pool connection number and calling
+        `pool.reconnect_failed()`.
+        """
+        now = time.monotonic()
+        if self.give_up_at and now >= self.give_up_at:
+            logger.warning(
+                "reconnection attempt in pool %r failed after %s sec",
+                pool.name,
+                pool.reconnect_timeout,
+            )
+            with pool._lock:
+                pool._nconns -= 1
+            pool.reconnect_failed()
+            return
+
+        # Calculate how long to wait for a new connection attempt
+        if self.delay == 0.0:
+            self.give_up_at = now + pool.reconnect_timeout
+            # +/- 10% of the initial delay
+            jitter = self.INITIAL_DELAY * (
+                (2.0 * self.DELAY_JITTER * random.random()) - self.DELAY_JITTER
+            )
+            self.delay = self.INITIAL_DELAY + jitter
+        else:
+            self.delay *= self.DELAY_BACKOFF
+
+        # Schedule a run of self.retry() some time in the future
+        if now + self.delay < self.give_up_at:
+            pool.sched.enter(self.delay, self.retry)
+        else:
+            pool.sched.enterabs(self.give_up_at, self.retry)
+
+    def retry(self) -> None:
+        pool = self.pool()
+        if not pool:
+            return
+
+        pool.add_task(self)
+
+
 class ReturnConnection(MaintenanceTask):
     """Clean up and return a connection to the pool."""
 
index 3d76427352c24894098729fb5e0490db0af28e65..91a39c2b45720696b6e40c908520a1f6621e41d2 100644 (file)
@@ -32,6 +32,7 @@ extras_require = {
         f"psycopg3-binary == {version}",
     ],
     "test": [
+        "pproxy >= 2.7, < 2.8",
         "pytest >= 6, < 6.1",
         "pytest-asyncio >= 0.14.0, < 0.15",
         "pytest-randomly >= 3.5, < 3.6",
index 98d222858f9d388e8aa770413e562f275a0319c0..b30ede9345f3b88e70ff3de199d9cbc8225aa0bb 100644 (file)
@@ -5,6 +5,7 @@ import pytest
 pytest_plugins = (
     "tests.fix_db",
     "tests.fix_pq",
+    "tests.fix_proxy",
     "tests.fix_faker",
 )
 
diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py
new file mode 100644 (file)
index 0000000..fc92651
--- /dev/null
@@ -0,0 +1,84 @@
+import time
+import socket
+import subprocess as sp
+from shutil import which
+
+import pytest
+
+from psycopg3 import conninfo
+
+
+@pytest.fixture
+def proxy(dsn):
+    """Return a proxy to the --test-dsn database"""
+    p = Proxy(dsn)
+    yield p
+    p.stop()
+
+
+class Proxy:
+    """
+    Proxy a Postgres service for testing purpose.
+
+    Allow to lose connectivity and restart it using stop/start.
+    """
+
+    def __init__(self, server_dsn):
+        cdict = conninfo.conninfo_to_dict(server_dsn)
+
+        # Get server params
+        self.server_port = cdict.get("port", "5432")
+        if "host" not in cdict or cdict["host"].startswith("/"):
+            self.server_host = "localhost"
+        else:
+            self.server_host = cdict["host"]
+
+        # Get client params
+        self.client_host = "localhost"
+        self.client_port = self._get_random_port()
+
+        # Make a connection string to the proxy
+        cdict["host"] = self.client_host
+        cdict["port"] = self.client_port
+        self.client_dsn = conninfo.make_conninfo(**cdict)
+
+        # The running proxy process
+        self.proc = None
+
+    def start(self):
+        if self.proc:
+            raise ValueError("proxy already running")
+
+        pproxy = which("pproxy")
+        if not pproxy:
+            raise ValueError("pproxy program not found")
+        cmdline = [pproxy, "--reuse"]
+        cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
+        cmdline.extend(
+            ["-r", f"tunnel://{self.server_host}:{self.server_port}"]
+        )
+
+        self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL)
+        self._wait_listen()
+
+    def stop(self):
+        if not self.proc:
+            return
+
+        self.proc.terminate()
+        self.proc.wait()
+        self.proc = None
+
+    @classmethod
+    def _get_random_port(cls):
+        with socket.socket() as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+    def _wait_listen(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+            for i in range(20):
+                if 0 == sock.connect_ex((self.client_host, self.client_port)):
+                    return
+                time.sleep(0.1)
+            raise ValueError("the proxy didn't start")
index 0463f1eef1e31efd875676fac2d32476e6e3eed7..da96d11b34d826937b60daf6d0fed07f8861e2b1 100644 (file)
@@ -1,6 +1,6 @@
 import logging
 import weakref
-from time import time, sleep
+from time import monotonic, sleep, time
 from threading import Thread
 
 import pytest
@@ -64,7 +64,7 @@ def test_connection_not_lost(dsn):
 @pytest.mark.slow
 def test_concurrent_filling(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
-    t0 = time()
+    t0 = monotonic()
     p = pool.ConnectionPool(dsn, minconn=5, num_workers=2)
     times = [item[1] - t0 for item in p._pool]
     want_times = [0.1, 0.1, 0.2, 0.2, 0.3]
@@ -436,6 +436,80 @@ def test_shrink(dsn, monkeypatch):
     assert t == pytest.approx(0.2, 0.1)
 
 
+@pytest.mark.slow
+def test_reconnect(proxy, caplog, monkeypatch):
+    caplog.set_level(logging.WARNING, logger="psycopg3.pool")
+
+    assert pool.AddConnection.INITIAL_DELAY == 1.0
+    assert pool.AddConnection.DELAY_JITTER == 0.1
+    monkeypatch.setattr(pool.AddConnection, "INITIAL_DELAY", 0.1)
+    monkeypatch.setattr(pool.AddConnection, "DELAY_JITTER", 0.0)
+
+    proxy.start()
+    p = pool.ConnectionPool(proxy.client_dsn, minconn=1, timeout_sec=2)
+    proxy.stop()
+
+    with pytest.raises(psycopg3.OperationalError):
+        with p.connection() as conn:
+            conn.execute("select 1")
+
+    sleep(1.0)
+    proxy.start()
+    wait_pool_full(p)
+
+    with p.connection() as conn:
+        conn.execute("select 1")
+
+    assert "BAD" in caplog.messages[0]
+    times = [rec.created for rec in caplog.records]
+    assert times[1] - times[0] < 0.05
+    deltas = [times[i + 1] - times[i] for i in range(1, len(times) - 1)]
+    assert len(deltas) == 3
+    want = 0.1
+    for delta in deltas:
+        assert delta == pytest.approx(want, 0.05), deltas
+        want *= 2
+
+
+@pytest.mark.slow
+def test_reconnect_failure(proxy):
+    proxy.start()
+
+    t1 = None
+
+    def failed(pool):
+        assert pool.name == "this-one"
+        nonlocal t1
+        t1 = time()
+
+    p = pool.ConnectionPool(
+        proxy.client_dsn,
+        name="this-one",
+        minconn=1,
+        timeout_sec=2,
+        reconnect_timeout=1.0,
+        reconnect_failed=failed,
+    )
+    proxy.stop()
+
+    with pytest.raises(psycopg3.OperationalError):
+        with p.connection() as conn:
+            conn.execute("select 1")
+
+    t0 = time()
+    sleep(1.5)
+    assert t1
+    assert t1 - t0 == pytest.approx(1.0, 0.1)
+    assert p._nconns == 0
+
+    proxy.start()
+    t0 = time()
+    with p.connection() as conn:
+        conn.execute("select 1")
+    t1 = time()
+    assert t1 - t0 < 0.2
+
+
 def delay_connection(monkeypatch, sec):
     """
     Return a _connect_gen function delayed by the amount of seconds