hostaddrs: list[str] = []
ports: list[str] = []
- for attempt in conninfo._split_attempts(params):
- try:
- async for a2 in conninfo._split_attempts_and_resolve(attempt):
- if a2.get("host") is not None:
- hosts.append(a2["host"])
- if a2.get("hostaddr") is not None:
- hostaddrs.append(a2["hostaddr"])
- if a2.get("port") is not None:
- ports.append(str(a2["port"]))
- except OSError as ex:
- last_exc = ex
-
- if params.get("host") and not hosts:
- # We couldn't resolve anything
- raise e.OperationalError(str(last_exc))
+ async for attempt in conninfo.conninfo_attempts_async(params):
+ if attempt.get("host") is not None:
+ hosts.append(attempt["host"])
+ if attempt.get("hostaddr") is not None:
+ hostaddrs.append(attempt["hostaddr"])
+ if attempt.get("port") is not None:
+ ports.append(str(attempt["port"]))
out = params.copy()
shosts = ",".join(hosts)
# If an host resolves to more than one ip, the libpq will make more than
# one attempt and wouldn't get to try the following ones, as before
# fixing #674.
+ attempts = _split_attempts(params)
if params.get("load_balance_hosts", "disable") == "random":
- attempts = list(_split_attempts(params))
shuffle(attempts)
- yield from attempts
- else:
- yield from _split_attempts(params)
+ yield from attempts
async def conninfo_attempts_async(params: ConnDict) -> AsyncIterator[ConnDict]:
Because the libpq async function doesn't honour the timeout, we need to
reimplement the repeated attempts.
"""
- # TODO: the function should resolve all hosts and shuffle the results
- # to replicate the same libpq algorithm.
- yielded = False
last_exc = None
+ attempts = []
for attempt in _split_attempts(params):
try:
- async for a2 in _split_attempts_and_resolve(attempt):
- yielded = True
- yield a2
+ attempts.extend(await _resolve_hostnames(attempt))
except OSError as ex:
last_exc = ex
- if not yielded:
+ if not attempts:
assert last_exc
# We couldn't resolve anything
raise e.OperationalError(str(last_exc))
+ if params.get("load_balance_hosts", "disable") == "random":
+ shuffle(attempts)
+
+ for attempt in attempts:
+ yield attempt
-def _split_attempts(params: ConnDict) -> Iterator[ConnDict]:
+
+def _split_attempts(params: ConnDict) -> list[ConnDict]:
"""
Split connection parameters with a sequence of hosts into separate attempts.
"""
# A single attempt to make. Don't mangle the conninfo string.
if nhosts <= 1:
- yield params
- return
+ return [params]
if len(ports) == 1:
ports *= nhosts
# Now all lists are either empty or have the same length
+ rv = []
for i in range(nhosts):
attempt = params.copy()
if hosts:
attempt["hostaddr"] = hostaddrs[i]
if ports:
attempt["port"] = ports[i]
- yield attempt
+ rv.append(attempt)
+
+ return rv
-async def _split_attempts_and_resolve(params: ConnDict) -> AsyncIterator[ConnDict]:
+async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
"""
Perform async DNS lookup of the hosts and return a new params dict.
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses asynchronously.
+
:param params: The input parameters, for instance as returned by
`~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
a single entry for host, hostaddr because it is designed to further
process the input of _split_attempts().
- If a ``host`` param is present but not ``hostname``, resolve the host
- addresses asynchronously.
-
- The function may change the input ``host``, ``hostname``, ``port`` to allow
- connecting without further DNS lookups.
+ :return: A list of attempts to make (to include the case of a hostname
+ resolving to more than one IP).
"""
host = _get_param(params, "host")
if not host or host.startswith("/") or host[1:2] == ":":
# Local path, or no host to resolve
- yield params
- return
+ return [params]
hostaddr = _get_param(params, "hostaddr")
if hostaddr:
# Already resolved
- yield params
- return
+ return [params]
if is_ip_address(host):
# If the host is already an ip address don't try to resolve it
- params["hostaddr"] = host
- yield params
- return
+ return [{**params, "hostaddr": host}]
loop = asyncio.get_running_loop()
ans = await loop.getaddrinfo(
host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
)
-
- for item in ans:
- yield {**params, "hostaddr": item[4][0]}
+ return [{**params, "hostaddr": item[4][0]} for item in ans]
def _get_param(params: ConnDict, name: str) -> str | None:
list(conninfo_attempts(params))
-def test_conninfo_random(dsn, conn_cls):
+def test_conninfo_random():
hosts = [f"host{n:02d}" for n in range(50)]
args = {"host": ",".join(hosts)}
ahosts = [att["host"] for att in conninfo_attempts(args)]
assert ahosts == hosts
+@pytest.mark.anyio
+async def test_conninfo_random_async(fake_resolve):
+ args = {"host": "alot.com"}
+ hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+ assert len(hostaddrs) == 20
+ assert hostaddrs == sorted(hostaddrs)
+
+ args["load_balance_hosts"] = "disable"
+ hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+ assert hostaddrs == sorted(hostaddrs)
+
+ args["load_balance_hosts"] = "random"
+ hostaddrs = [att["hostaddr"] async for att in conninfo_attempts_async(args)]
+ assert hostaddrs != sorted(hostaddrs)
+
+
@pytest.fixture
async def fake_resolve(monkeypatch):
fake_hosts = {
"foo.com": ["1.1.1.1"],
"qux.com": ["2.2.2.2"],
"dup.com": ["3.3.3.3", "3.3.3.4"],
+ "alot.com": [f"4.4.4.{n}" for n in range(10, 30)],
}
def family(host):