]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: add support for the 'load_balance_hosts' connection parameter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 20:12:45 +0000 (20:12 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Nov 2023 20:20:01 +0000 (21:20 +0100)
docs/news.rst
psycopg/psycopg/conninfo.py
tests/test_conninfo.py

index ded0549065206dc165e9528e264849c9da9d53cf..7d63c0030bcb74c2f9dc9f6b4ede3c5c4f67f585 100644 (file)
@@ -38,7 +38,8 @@ Psycopg 3.1.13 (unreleased)
   `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`).
 - Handle gracefully EINTR on signals instead of raising `InterruptedError`,
   consistently with :pep:`475` guideline (:ticket:`#667`).
-- Fix support for connection strings with multiple hosts (:ticket:`#674`).
+- Fix support for connection strings with multiple hosts/ports and for the
+  ``load_balance_hosts`` connection parameter (:ticket:`#674`).
 
 
 Current release
index 4f633fff4a09c9c5bb72758af0c5d0ae7c637344..6356a2d0f4524bcf8051925eb854672441d51584 100644 (file)
@@ -11,6 +11,7 @@ import re
 import socket
 import asyncio
 from typing import Any, Iterator, AsyncIterator
+from random import shuffle
 from pathlib import Path
 from datetime import tzinfo
 from functools import lru_cache
@@ -339,8 +340,14 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
     Because the libpq async function doesn't honour the timeout, we need to
     reimplement the repeated attempts.
     """
-    for attempt in _split_attempts(_inject_defaults(params)):
-        yield attempt
+    if params.get("load_balance_hosts", "disable") == "random":
+        attempts = list(_split_attempts(_inject_defaults(params)))
+        shuffle(attempts)
+        for attempt in attempts:
+            yield attempt
+    else:
+        for attempt in _split_attempts(_inject_defaults(params)):
+            yield attempt
 
 
 async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
index 832c8809908d7d632e8b3040b259092be8a90d2e..5c8b99a4f89e576885b20e876ce30f73b6e36d1b 100644 (file)
@@ -479,6 +479,23 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env):
         list(conninfo_attempts(params))
 
 
+def test_conninfo_random(dsn, conn_cls):
+    hosts = [f"host{n:02d}" for n in range(50)]
+    args = {"host": ",".join(hosts)}
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "disable"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "random"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts != hosts
+    ahosts.sort()
+    assert ahosts == hosts
+
+
 @pytest.fixture
 async def fake_resolve(monkeypatch):
     fake_hosts = {