git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: shuffle attempts when one host resolves to more than one IP
author: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 00:03:55 +0000 (01:03 +0100)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Dec 2023 00:26:10 +0000 (01:26 +0100)
This behaviour (first resolve all the hosts, then shuffle the resulting IPs)
better mimics what libpq does in non-async mode.

psycopg/psycopg/_dns.py
psycopg/psycopg/conninfo.py
tests/test_conninfo.py

index eb06d1cd8f938f951016884f19a62e5180ba940b..86f3468d2c36e2fdbe7f9c67c1ef557d94b1974b 100644 (file)
@@ -52,21 +52,13 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     hostaddrs: list[str] = []
     ports: list[str] = []
 
-    for attempt in conninfo._split_attempts(params):
-        try:
-            async for a2 in conninfo._split_attempts_and_resolve(attempt):
-                if a2.get("host") is not None:
-                    hosts.append(a2["host"])
-                if a2.get("hostaddr") is not None:
-                    hostaddrs.append(a2["hostaddr"])
-                if a2.get("port") is not None:
-                    ports.append(str(a2["port"]))
-        except OSError as ex:
-            last_exc = ex
-
-    if params.get("host") and not hosts:
-        # We couldn't resolve anything
-        raise e.OperationalError(str(last_exc))
+    async for attempt in conninfo.conninfo_attempts_async(params):
+        if attempt.get("host") is not None:
+            hosts.append(attempt["host"])
+        if attempt.get("hostaddr") is not None:
+            hostaddrs.append(attempt["hostaddr"])
+        if attempt.get("port") is not None:
+            ports.append(str(attempt["port"]))
 
     out = params.copy()
     shosts = ",".join(hosts)
index 5f56eb3871fe679fec1ab8ed38efb267c7f95bb4..ee01cd6aca978c7889d884d1ee5149c7dc927444 100644 (file)
@@ -295,12 +295,10 @@ def conninfo_attempts(params: ConnDict) -> Iterator[ConnDict]:
     # If an host resolves to more than one ip, the libpq will make more than
     # one attempt and wouldn't get to try the following ones, as before
     # fixing #674.
+    attempts = _split_attempts(params)
     if params.get("load_balance_hosts", "disable") == "random":
-        attempts = list(_split_attempts(params))
         shuffle(attempts)
-        yield from attempts
-    else:
-        yield from _split_attempts(params)
+    yield from attempts
 
 
 async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
@@ -317,25 +315,27 @@ async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
     Because the libpq async function doesn't honour the timeout, we need to
     reimplement the repeated attempts.
     """
-    # TODO: the function should resolve all hosts and shuffle the results
-    # to replicate the same libpq algorithm.
-    yielded = False
     last_exc = None
+    attempts = []
     for attempt in _split_attempts(params):
         try:
-            async for a2 in _split_attempts_and_resolve(attempt):
-                yielded = True
-                yield a2
+            attempts.extend(await _resolve_hostnames(attempt))
         except OSError as ex:
             last_exc = ex
 
-    if not yielded:
+    if not attempts:
         assert last_exc
         # We couldn't resolve anything
         raise e.OperationalError(str(last_exc))
 
+    if params.get("load_balance_hosts", "disable") == "random":
+        shuffle(attempts)
+
+    for attempt in attempts:
+        yield attempt
 
-def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
+
+def _split_attempts(params: ConnDict) -> list[ConnDict]:
     """
     Split connection parameters with a sequence of hosts into separate attempts.
     """
@@ -363,13 +363,13 @@ def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
 
     # A single attempt to make. Don't mangle the conninfo string.
     if nhosts <= 1:
-        yield params
-        return
+        return [params]
 
     if len(ports) == 1:
         ports *= nhosts
 
     # Now all lists are either empty or have the same length
+    rv = []
     for i in range(nhosts):
         attempt = params.copy()
         if hosts:
@@ -378,41 +378,39 @@ def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
             attempt["hostaddr"] = hostaddrs[i]
         if ports:
             attempt["port"] = ports[i]
-        yield attempt
+        rv.append(attempt)
+
+    return rv
 
 
-async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
+async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
     """
     Perform async DNS lookup of the hosts and return a new params dict.
 
+    If a ``host`` param is present but not ``hostname``, resolve the host
+    addresses asynchronously.
+
     :param params: The input parameters, for instance as returned by
         `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
         a single entry for host, hostaddr because it is designed to further
         process the input of _split_attempts().
 
-    If a ``host`` param is present but not ``hostname``, resolve the host
-    addresses asynchronously.
-
-    The function may change the input ``host``, ``hostname``, ``port`` to allow
-    connecting without further DNS lookups.
+    :return: A list of attempts to make (to include the case of a hostname
+        resolving to more than one IP).
     """
     host = _get_param(params, "host")
     if not host or host.startswith("/") or host[1:2] == ":":
         # Local path, or no host to resolve
-        yield params
-        return
+        return [params]
 
     hostaddr = _get_param(params, "hostaddr")
     if hostaddr:
         # Already resolved
-        yield params
-        return
+        return [params]
 
     if is_ip_address(host):
         # If the host is already an ip address don't try to resolve it
-        params["hostaddr"] = host
-        yield params
-        return
+        return [{**params, "hostaddr": host}]
 
     loop = asyncio.get_running_loop()
 
@@ -426,9 +424,7 @@ async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDic
     ans = await loop.getaddrinfo(
         host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
     )
-
-    for item in ans:
-        yield {**params, "hostaddr": item[4][0]}
+    return [{**params, "hostaddr": item[4][0]} for item in ans]
 
 
 def _get_param(params: ConnDict, name: str) -> str | None:
index 8254199261c7fa1baf4234420b227b9dfb8a8013..db2e38b8a892bd0d384af72d2120da35dc2d1dd3 100644 (file)
@@ -498,7 +498,7 @@ def test_conninfo_attempts_bad(setpgenv, conninfo, env):
         list(conninfo_attempts(params))
 
 
-def test_conninfo_random(dsn, conn_cls):
+def test_conninfo_random():
     hosts = [f"host{n:02d}" for n in range(50)]
     args = {"host": ",".join(hosts)}
     ahosts = [att["host"] for att in conninfo_attempts(args)]
@@ -515,6 +515,22 @@ def test_conninfo_random(dsn, conn_cls):
     assert ahosts == hosts
 
 
+@pytest.mark.anyio
+async def test_conninfo_random_async(fake_resolve):
+    args = {"host": "alot.com"}
+    hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+    assert len(hostaddrs) == 20
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "disable"
+    hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "random"
+    hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+    assert hostaddrs != sorted(hostaddrs)
+
+
 @pytest.fixture
 async def fake_resolve(monkeypatch):
     fake_hosts = {
@@ -522,6 +538,7 @@ async def fake_resolve(monkeypatch):
         "foo.com": ["1.1.1.1"],
         "qux.com": ["2.2.2.2"],
         "dup.com": ["3.3.3.3", "3.3.3.4"],
+        "alot.com": [f"4.4.4.{n}" for n in range(10, 30)],
     }
 
     def family(host):