From: Daniele Varrazzo Date: Wed, 13 Dec 2023 00:03:55 +0000 (+0100) Subject: fix: shuffle attempts when one host resolves to more than one IP X-Git-Tag: 3.1.15~1^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bd259ee3176aad975304e8abc95d2965f17e5de4;p=thirdparty%2Fpsycopg.git fix: shuffle attempts when one host resolves to more than one IP This behaviour (first resolve all the hosts, then shuffle the IPs) mimics better what the libpq does in non-async mode. --- diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index eb06d1cd8..86f3468d2 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -52,21 +52,13 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: hostaddrs: list[str] = [] ports: list[str] = [] - for attempt in conninfo._split_attempts(params): - try: - async for a2 in conninfo._split_attempts_and_resolve(attempt): - if a2.get("host") is not None: - hosts.append(a2["host"]) - if a2.get("hostaddr") is not None: - hostaddrs.append(a2["hostaddr"]) - if a2.get("port") is not None: - ports.append(str(a2["port"])) - except OSError as ex: - last_exc = ex - - if params.get("host") and not hosts: - # We couldn't resolve anything - raise e.OperationalError(str(last_exc)) + async for attempt in conninfo.conninfo_attempts_async(params): + if attempt.get("host") is not None: + hosts.append(attempt["host"]) + if attempt.get("hostaddr") is not None: + hostaddrs.append(attempt["hostaddr"]) + if attempt.get("port") is not None: + ports.append(str(attempt["port"])) out = params.copy() shosts = ",".join(hosts) diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 5f56eb387..ee01cd6ac 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -295,12 +295,10 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]: # If an host resolves to more than one ip, the libpq will make more than # one attempt and wouldn't get to try the following ones, as before # fixing #674. + attempts = _split_attempts(params) if params.get("load_balance_hosts", "disable") == "random": - attempts = list(_split_attempts(params)) shuffle(attempts) - yield from attempts - else: - yield from _split_attempts(params) + yield from attempts async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: @@ -317,25 +315,27 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]: Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ - # TODO: the function should resolve all hosts and shuffle the results - # to replicate the same libpq algorithm. - yielded = False last_exc = None + attempts = [] for attempt in _split_attempts(params): try: - async for a2 in _split_attempts_and_resolve(attempt): - yielded = True - yield a2 + attempts.extend(await _resolve_hostnames(attempt)) except OSError as ex: last_exc = ex - if not yielded: + if not attempts: assert last_exc # We couldn't resolve anything raise e.OperationalError(str(last_exc)) + if params.get("load_balance_hosts", "disable") == "random": + shuffle(attempts) + + for attempt in attempts: + yield attempt -def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: + +def _split_attempts(params: ConnDict) -> list[ConnDict]: """ Split connection parameters with a sequence of hosts into separate attempts. """ @@ -363,13 +363,13 @@ def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: # A single attempt to make. Don't mangle the conninfo string. if nhosts <= 1: - yield params - return + return [params] if len(ports) == 1: ports *= nhosts # Now all lists are either empty or have the same length + rv = [] for i in range(nhosts): attempt = params.copy() if hosts: @@ -378,41 +378,39 @@ def _split_attempts(params: ConnDict) -> Iterator[ConnDict]: attempt["hostaddr"] = hostaddrs[i] if ports: attempt["port"] = ports[i] - yield attempt + rv.append(attempt) + + return rv -async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]: +async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: """ Perform async DNS lookup of the hosts and return a new params dict. + If a ``host`` param is present but not ``hostname``, resolve the host + addresses asynchronously. + :param params: The input parameters, for instance as returned by `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most a single entry for host, hostaddr because it is designed to further process the input of _split_attempts(). - If a ``host`` param is present but not ``hostname``, resolve the host - addresses asynchronously. - - The function may change the input ``host``, ``hostname``, ``port`` to allow - connecting without further DNS lookups. + :return: A list of attempts to make (to include the case of a hostname + resolving to more than one IP). """ host = _get_param(params, "host") if not host or host.startswith("/") or host[1:2] == ":": # Local path, or no host to resolve - yield params - return + return [params] hostaddr = _get_param(params, "hostaddr") if hostaddr: # Already resolved - yield params - return + return [params] if is_ip_address(host): # If the host is already an ip address don't try to resolve it - params["hostaddr"] = host - yield params - return + return [{**params, "hostaddr": host}] loop = asyncio.get_running_loop() @@ -426,9 +424,7 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic ans = await loop.getaddrinfo( host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM ) - - for item in ans: - yield {**params, "hostaddr": item[4][0]} + return [{**params, "hostaddr": item[4][0]} for item in ans] def _get_param(params: ConnDict, name: str) -> str | None: diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 825419926..db2e38b8a 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -498,7 +498,7 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env): list(conninfo_attempts(params)) -def test_conninfo_random(dsn, conn_cls): +def test_conninfo_random(): hosts = [f"host{n:02d}" for n in range(50)] args = {"host": ",".join(hosts)} ahosts = [att["host"] for att in conninfo_attempts(args)] @@ -515,6 +515,22 @@ def test_conninfo_random(dsn, conn_cls): assert ahosts == hosts +@pytest.mark.anyio +async def test_conninfo_random_async(fake_resolve): + args = {"host": "alot.com"} + hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)] + assert len(hostaddrs) == 20 + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "disable" + hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)] + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "random" + hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)] + assert hostaddrs != sorted(hostaddrs) + + @pytest.fixture async def fake_resolve(monkeypatch): fake_hosts = { @@ -522,6 +538,7 @@ async def fake_resolve(monkeypatch): "foo.com": ["1.1.1.1"], "qux.com": ["2.2.2.2"], "dup.com": ["3.3.3.3", "3.3.3.4"], + "alot.com": [f"4.4.4.{n}" for n in range(10, 30)], } def family(host):