From 292762b33a90fade9b121c98ad03bd72b1d1b9a3 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 13 Nov 2023 20:12:45 +0000 Subject: [PATCH] fix: add support for the 'load_balance_hosts' connection parameter --- docs/news.rst | 3 ++- psycopg/psycopg/conninfo.py | 11 +++++++++-- tests/test_conninfo.py | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/docs/news.rst b/docs/news.rst index ded054906..7d63c0030 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -38,7 +38,8 @@ Psycopg 3.1.13 (unreleased) `~zoneinfo.ZoneInfo` (ambiguous offset, see :ticket:`#652`). - Handle gracefully EINTR on signals instead of raising `InterruptedError`, consistently with :pep:`475` guideline (:ticket:`#667`). -- Fix support for connection strings with multiple hosts (:ticket:`#674`). +- Fix support for connection strings with multiple hosts/ports and for the + ``load_balance_hosts`` connection parameter (:ticket:`#674`). Current release diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 4f633fff4..6356a2d0f 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -11,6 +11,7 @@ import re import socket import asyncio from typing import Any, Iterator, AsyncIterator +from random import shuffle from pathlib import Path from datetime import tzinfo from functools import lru_cache @@ -339,8 +340,14 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ - for attempt in _split_attempts(_inject_defaults(params)): - yield attempt + if params.get("load_balance_hosts", "disable") == "random": + attempts = list(_split_attempts(_inject_defaults(params))) + shuffle(attempts) + for attempt in attempts: + yield attempt + else: + for attempt in _split_attempts(_inject_defaults(params)): + yield attempt async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 832c88099..5c8b99a4f 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -479,6 +479,23 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env): list(conninfo_attempts(params)) +def test_conninfo_random(dsn, conn_cls): + hosts = [f"host{n:02d}" for n in range(50)] + args = {"host": ",".join(hosts)} + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "disable" + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "random" + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts != hosts + ahosts.sort() + assert ahosts == hosts + + @pytest.fixture async def fake_resolve(monkeypatch): fake_hosts = { -- 2.47.3