]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: add support for the 'load_balance_hosts' connection parameter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 20:12:45 +0000 (20:12 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 23:37:49 +0000 (00:37 +0100)
docs/news.rst
psycopg/psycopg/conninfo.py
tests/test_conninfo.py

index cc456e5b1a72e6d65484c1ddfe256e88c9ff9e21..0ecea199628436cbd1a1fe2a677b3aa8153ec545 100644 (file)
@@ -18,7 +18,8 @@ Psycopg 3.1.13 (unreleased)
   `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`).
 - Handle gracefully EINTR on signals instead of raising `InterruptedError`,
   consistently with :pep:`475` guideline (:ticket:`#667`).
-- Fix support for connection strings with multiple hosts (:ticket:`#674`).
+- Fix support for connection strings with multiple hosts/ports and for the
+  ``load_balance_hosts`` connection parameter (:ticket:`#674`).
 
 
 Current release
index 4f633fff4a09c9c5bb72758af0c5d0ae7c637344..6356a2d0f4524bcf8051925eb854672441d51584 100644 (file)
@@ -11,6 +11,7 @@ import re
 import socket
 import asyncio
 from typing import Any, Iterator, AsyncIterator
+from random import shuffle
 from pathlib import Path
 from datetime import tzinfo
 from functools import lru_cache
@@ -339,8 +340,14 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
     Because the libpq async function doesn't honour the timeout, we need to
     reimplement the repeated attempts.
     """
-    for attempt in _split_attempts(_inject_defaults(params)):
-        yield attempt
+    if params.get("load_balance_hosts", "disable") == "random":
+        attempts = list(_split_attempts(_inject_defaults(params)))
+        shuffle(attempts)
+        for attempt in attempts:
+            yield attempt
+    else:
+        for attempt in _split_attempts(_inject_defaults(params)):
+            yield attempt
 
 
 async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
index 1ebf42f79a73ba1964283cc4e48a765e9c0d219e..2ae1d142a3b281ea25a53d880082dc03a433bc1f 100644 (file)
@@ -479,6 +479,23 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env):
         list(conninfo_attempts(params))
 
 
+def test_conninfo_random(dsn, conn_cls):
+    hosts = [f"host{n:02d}" for n in range(50)]
+    args = {"host": ",".join(hosts)}
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "disable"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "random"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts != hosts
+    ahosts.sort()
+    assert ahosts == hosts
+
+
 @pytest.fixture
 async def fake_resolve(monkeypatch):
     fake_hosts = {