]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: perform multiple attemps if a host name resolve to multiple hosts
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 16:24:48 +0000 (17:24 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Jan 2024 19:22:02 +0000 (20:22 +0100)
We now perform DNS resolution in Python both in the sync and async code.
Because the two code paths are now very similar, make sure they pass the
same tests.

Close #699.

12 files changed:
docs/news.rst
psycopg/psycopg/_conninfo_attempts.py [new file with mode: 0644]
psycopg/psycopg/_conninfo_attempts_async.py [new file with mode: 0644]
psycopg/psycopg/_conninfo_utils.py [new file with mode: 0644]
psycopg/psycopg/conninfo.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py
tests/test_conninfo_attempts.py [new file with mode: 0644]
tests/test_conninfo_attempts_async.py [new file with mode: 0644]
tests/test_dns.py
tests/test_psycopg_dbapi20.py

index 5fd7c51287aaaf8b507ae834d18e8db57f3a6e6d..31e3dcd674c97545817524cfeb61870d194fc462 100644 (file)
@@ -13,6 +13,8 @@ Future releases
 Psycopg 3.1.17 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Fix multiple connection attempts when a host name resolve to multiple
+  IP addresses (:ticket:`699`).
 - Use `typing.Self` as a more correct return value annotation of context
   managers and other self-returning methods (see :ticket:`708`).
 
diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py
new file mode 100644 (file)
index 0000000..5262ab7
--- /dev/null
@@ -0,0 +1,90 @@
+"""
+Separate connection attempts from a connection string.
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from __future__ import annotations
+
+import socket
+import logging
+from random import shuffle
+
+from . import errors as e
+from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from ._conninfo_utils import split_attempts
+
+logger = logging.getLogger("psycopg")
+
+
+def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
+    """Split a set of connection params on the single attempts to perform.
+
+    A connection param can perform more than one attempt more than one ``host``
+    is provided.
+
+    Also perform async resolution of the hostname into hostaddr. Because a host
+    can resolve to more than one address, this can lead to yield more attempts
+    too. Raise `OperationalError` if no host could be resolved.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    last_exc = None
+    attempts = []
+    for attempt in split_attempts(params):
+        try:
+            attempts.extend(_resolve_hostnames(attempt))
+        except OSError as ex:
+            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
+            last_exc = ex
+
+    if not attempts:
+        assert last_exc
+        # We couldn't resolve anything
+        raise e.OperationalError(str(last_exc))
+
+    if get_param(params, "load_balance_hosts") == "random":
+        shuffle(attempts)
+
+    return attempts
+
+
+def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
+    """
+    Perform DNS lookup of the hosts and return a list of connection attempts.
+
+    If a ``host`` param is present but not ``hostname``, resolve the host
+    addresses asynchronously.
+
+    :param params: The input parameters, for instance as returned by
+        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+        a single entry for host, hostaddr because it is designed to further
+        process the input of split_attempts().
+
+    :return: A list of attempts to make (to include the case of a hostname
+        resolving to more than one IP).
+    """
+    host = get_param(params, "host")
+    if not host or host.startswith("/") or host[1:2] == ":":
+        # Local path, or no host to resolve
+        return [params]
+
+    hostaddr = get_param(params, "hostaddr")
+    if hostaddr:
+        # Already resolved
+        return [params]
+
+    if is_ip_address(host):
+        # If the host is already an ip address don't try to resolve it
+        return [{**params, "hostaddr": host}]
+
+    port = get_param(params, "port")
+    if not port:
+        port_def = get_param_def("port")
+        port = port_def and port_def.compiled or "5432"
+
+    ans = socket.getaddrinfo(
+        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+    )
+    return [{**params, "hostaddr": item[4][0]} for item in ans]
diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py
new file mode 100644 (file)
index 0000000..f0f4572
--- /dev/null
@@ -0,0 +1,92 @@
+"""
+Separate connection attempts from a connection string.
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from __future__ import annotations
+
+import socket
+import asyncio
+import logging
+from random import shuffle
+
+from . import errors as e
+from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from ._conninfo_utils import split_attempts
+
+logger = logging.getLogger("psycopg")
+
+
+async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
+    """Split a set of connection params on the single attempts to perform.
+
+    A connection param can perform more than one attempt more than one ``host``
+    is provided.
+
+    Also perform async resolution of the hostname into hostaddr. Because a host
+    can resolve to more than one address, this can lead to yield more attempts
+    too. Raise `OperationalError` if no host could be resolved.
+
+    Because the libpq async function doesn't honour the timeout, we need to
+    reimplement the repeated attempts.
+    """
+    last_exc = None
+    attempts = []
+    for attempt in split_attempts(params):
+        try:
+            attempts.extend(await _resolve_hostnames(attempt))
+        except OSError as ex:
+            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
+            last_exc = ex
+
+    if not attempts:
+        assert last_exc
+        # We couldn't resolve anything
+        raise e.OperationalError(str(last_exc))
+
+    if get_param(params, "load_balance_hosts") == "random":
+        shuffle(attempts)
+
+    return attempts
+
+
+async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
+    """
+    Perform async DNS lookup of the hosts and return a list of connection attempts.
+
+    If a ``host`` param is present but not ``hostname``, resolve the host
+    addresses asynchronously.
+
+    :param params: The input parameters, for instance as returned by
+        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+        a single entry for host, hostaddr because it is designed to further
+        process the input of split_attempts().
+
+    :return: A list of attempts to make (to include the case of a hostname
+        resolving to more than one IP).
+    """
+    host = get_param(params, "host")
+    if not host or host.startswith("/") or host[1:2] == ":":
+        # Local path, or no host to resolve
+        return [params]
+
+    hostaddr = get_param(params, "hostaddr")
+    if hostaddr:
+        # Already resolved
+        return [params]
+
+    if is_ip_address(host):
+        # If the host is already an ip address don't try to resolve it
+        return [{**params, "hostaddr": host}]
+
+    port = get_param(params, "port")
+    if not port:
+        port_def = get_param_def("port")
+        port = port_def and port_def.compiled or "5432"
+
+    loop = asyncio.get_running_loop()
+    ans = await loop.getaddrinfo(
+        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+    )
+    return [{**params, "hostaddr": item[4][0]} for item in ans]
diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py
new file mode 100644 (file)
index 0000000..8940c93
--- /dev/null
@@ -0,0 +1,127 @@
+"""
+Internal utilities to manipulate connection strings
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from __future__ import annotations
+
+import os
+from typing import Any
+from functools import lru_cache
+from ipaddress import ip_address
+from dataclasses import dataclass
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+
+ConnDict: TypeAlias = "dict[str, Any]"
+
+
+def split_attempts(params: ConnDict) -> list[ConnDict]:
+    """
+    Split connection parameters with a sequence of hosts into separate attempts.
+    """
+
+    def split_val(key: str) -> list[str]:
+        val = get_param(params, key)
+        return val.split(",") if val else []
+
+    hosts = split_val("host")
+    hostaddrs = split_val("hostaddr")
+    ports = split_val("port")
+
+    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
+        raise e.OperationalError(
+            f"could not match {len(hosts)} host names"
+            f" with {len(hostaddrs)} hostaddr values"
+        )
+
+    nhosts = max(len(hosts), len(hostaddrs))
+
+    if 1 < len(ports) != nhosts:
+        raise e.OperationalError(
+            f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
+        )
+
+    # A single attempt to make. Don't mangle the conninfo string.
+    if nhosts <= 1:
+        return [params]
+
+    if len(ports) == 1:
+        ports *= nhosts
+
+    # Now all lists are either empty or have the same length
+    rv = []
+    for i in range(nhosts):
+        attempt = params.copy()
+        if hosts:
+            attempt["host"] = hosts[i]
+        if hostaddrs:
+            attempt["hostaddr"] = hostaddrs[i]
+        if ports:
+            attempt["port"] = ports[i]
+        rv.append(attempt)
+
+    return rv
+
+
+def get_param(params: ConnDict, name: str) -> str | None:
+    """
+    Return a value from a connection string.
+
+    The value may be also specified in a PG* env var.
+    """
+    if name in params:
+        return str(params[name])
+
+    # TODO: check if in service
+
+    paramdef = get_param_def(name)
+    if not paramdef:
+        return None
+
+    env = os.environ.get(paramdef.envvar)
+    if env is not None:
+        return env
+
+    return None
+
+
+@dataclass
+class ParamDef:
+    """
+    Information about defaults and env vars for connection params
+    """
+
+    keyword: str
+    envvar: str
+    compiled: str | None
+
+
+def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
+    """
+    Return the ParamDef of a connection string parameter.
+    """
+    if not _cache:
+        defs = pq.Conninfo.get_defaults()
+        for d in defs:
+            cd = ParamDef(
+                keyword=d.keyword.decode(),
+                envvar=d.envvar.decode() if d.envvar else "",
+                compiled=d.compiled.decode() if d.compiled is not None else None,
+            )
+            _cache[cd.keyword] = cd
+
+    return _cache.get(keyword)
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+    """Return True if the string represent a valid ip address."""
+    try:
+        ip_address(s)
+    except ValueError:
+        return False
+    return True
index 9044550cb555fc8e851fd0723a3071dcfd071842..82da5882259057c9b668ff198d6f42a6aa4114b0 100644 (file)
@@ -6,30 +6,26 @@ Functions to manipulate conninfo strings
 
 from __future__ import annotations
 
-import os
 import re
-import socket
-import asyncio
-import logging
 from typing import Any
-from random import shuffle
-from functools import lru_cache
-from ipaddress import ip_address
-from dataclasses import dataclass
-from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 
-ConnDict: TypeAlias = "dict[str, Any]"
+from . import _conninfo_utils
+from . import _conninfo_attempts
+from . import _conninfo_attempts_async
+
+# re-exoprts
+ConnDict = _conninfo_utils.ConnDict
+conninfo_attempts = _conninfo_attempts.conninfo_attempts
+conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
 
 # Default timeout for connection a attempt.
 # Arbitrary timeout, what applied by the libpq on my computer.
 # Your mileage won't vary.
 _DEFAULT_CONNECT_TIMEOUT = 130
 
-logger = logging.getLogger("psycopg")
-
 
 def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
     """
@@ -127,149 +123,6 @@ def _param_escape(s: str) -> str:
     return s
 
 
-def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
-    """Split a set of connection params on the single attempts to perform.
-
-    A connection param can perform more than one attempt more than one ``host``
-    is provided.
-
-    Because the libpq async function doesn't honour the timeout, we need to
-    reimplement the repeated attempts.
-    """
-    # TODO: we should actually resolve the hosts ourselves.
-    # If an host resolves to more than one ip, the libpq will make more than
-    # one attempt and wouldn't get to try the following ones, as before
-    # fixing #674.
-    attempts = _split_attempts(params)
-    if _get_param(params, "load_balance_hosts") == "random":
-        shuffle(attempts)
-    return attempts
-
-
-async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
-    """Split a set of connection params on the single attempts to perform.
-
-    A connection param can perform more than one attempt more than one ``host``
-    is provided.
-
-    Also perform async resolution of the hostname into hostaddr in order to
-    avoid blocking. Because a host can resolve to more than one address, this
-    can lead to yield more attempts too. Raise `OperationalError` if no host
-    could be resolved.
-
-    Because the libpq async function doesn't honour the timeout, we need to
-    reimplement the repeated attempts.
-    """
-    last_exc = None
-    attempts = []
-    for attempt in _split_attempts(params):
-        try:
-            attempts.extend(await _resolve_hostnames(attempt))
-        except OSError as ex:
-            logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
-            last_exc = ex
-
-    if not attempts:
-        assert last_exc
-        # We couldn't resolve anything
-        raise e.OperationalError(str(last_exc))
-
-    if _get_param(params, "load_balance_hosts") == "random":
-        shuffle(attempts)
-
-    return attempts
-
-
-def _split_attempts(params: ConnDict) -> list[ConnDict]:
-    """
-    Split connection parameters with a sequence of hosts into separate attempts.
-    """
-
-    def split_val(key: str) -> list[str]:
-        val = _get_param(params, key)
-        return val.split(",") if val else []
-
-    hosts = split_val("host")
-    hostaddrs = split_val("hostaddr")
-    ports = split_val("port")
-
-    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
-        raise e.OperationalError(
-            f"could not match {len(hosts)} host names"
-            f" with {len(hostaddrs)} hostaddr values"
-        )
-
-    nhosts = max(len(hosts), len(hostaddrs))
-
-    if 1 < len(ports) != nhosts:
-        raise e.OperationalError(
-            f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
-        )
-
-    # A single attempt to make. Don't mangle the conninfo string.
-    if nhosts <= 1:
-        return [params]
-
-    if len(ports) == 1:
-        ports *= nhosts
-
-    # Now all lists are either empty or have the same length
-    rv = []
-    for i in range(nhosts):
-        attempt = params.copy()
-        if hosts:
-            attempt["host"] = hosts[i]
-        if hostaddrs:
-            attempt["hostaddr"] = hostaddrs[i]
-        if ports:
-            attempt["port"] = ports[i]
-        rv.append(attempt)
-
-    return rv
-
-
-async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
-    """
-    Perform async DNS lookup of the hosts and return a new params dict.
-
-    If a ``host`` param is present but not ``hostname``, resolve the host
-    addresses asynchronously.
-
-    :param params: The input parameters, for instance as returned by
-        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
-        a single entry for host, hostaddr because it is designed to further
-        process the input of _split_attempts().
-
-    :return: A list of attempts to make (to include the case of a hostname
-        resolving to more than one IP).
-    """
-    host = _get_param(params, "host")
-    if not host or host.startswith("/") or host[1:2] == ":":
-        # Local path, or no host to resolve
-        return [params]
-
-    hostaddr = _get_param(params, "hostaddr")
-    if hostaddr:
-        # Already resolved
-        return [params]
-
-    if is_ip_address(host):
-        # If the host is already an ip address don't try to resolve it
-        return [{**params, "hostaddr": host}]
-
-    loop = asyncio.get_running_loop()
-
-    port = _get_param(params, "port")
-    if not port:
-        port_def = _get_param_def("port")
-        port = port_def and port_def.compiled or "5432"
-
-    ans = await loop.getaddrinfo(
-        host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
-    )
-    return [{**params, "hostaddr": item[4][0]} for item in ans]
-
-
 def timeout_from_conninfo(params: ConnDict) -> int:
     """
     Return the timeout in seconds from the connection parameters.
@@ -281,7 +134,7 @@ def timeout_from_conninfo(params: ConnDict) -> int:
     # - at least 2 seconds.
     #
     # See connectDBComplete in fe-connect.c
-    value: str | int | None = _get_param(params, "connect_timeout")
+    value: str | int | None = _conninfo_utils.get_param(params, "connect_timeout")
     if value is None:
         value = _DEFAULT_CONNECT_TIMEOUT
     try:
@@ -299,63 +152,3 @@ def timeout_from_conninfo(params: ConnDict) -> int:
         timeout = 2
 
     return timeout
-
-
-def _get_param(params: ConnDict, name: str) -> str | None:
-    """
-    Return a value from a connection string.
-
-    The value may be also specified in a PG* env var.
-    """
-    if name in params:
-        return str(params[name])
-
-    # TODO: check if in service
-
-    paramdef = _get_param_def(name)
-    if not paramdef:
-        return None
-
-    env = os.environ.get(paramdef.envvar)
-    if env is not None:
-        return env
-
-    return None
-
-
-@dataclass
-class ParamDef:
-    """
-    Information about defaults and env vars for connection params
-    """
-
-    keyword: str
-    envvar: str
-    compiled: str | None
-
-
-def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
-    """
-    Return the ParamDef of a connection string parameter.
-    """
-    if not _cache:
-        defs = pq.Conninfo.get_defaults()
-        for d in defs:
-            cd = ParamDef(
-                keyword=d.keyword.decode(),
-                envvar=d.envvar.decode() if d.envvar else "",
-                compiled=d.compiled.decode() if d.compiled is not None else None,
-            )
-            _cache[cd.keyword] = cd
-
-    return _cache.get(keyword)
-
-
-@lru_cache()
-def is_ip_address(s: str) -> bool:
-    """Return True if the string represent a valid ip address."""
-    try:
-        ip_address(s)
-    except ValueError:
-        return False
-    return True
index 0ee8aea3136e7c737616092d6c5ccc3cbd3e8329..072d467cd071cefbff1e12cd588029f42cee129c 100644 (file)
@@ -393,17 +393,19 @@ def test_autocommit_unknown(conn):
     [
         ((), {}, ""),
         (("",), {}, ""),
-        (("host=foo user=bar",), {}, "host=foo user=bar"),
-        (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+        (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"),
+        (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"),
         (
             ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
             "dbname=qux user=joe port=5433",
         ),
-        (("host=foo",), {"user": None}, "host=foo"),
+        (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"),
     ],
 )
-def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want):
+def test_connect_args(
+    conn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want
+):
     got_conninfo: str
 
     def fake_connect(conninfo):
@@ -861,3 +863,20 @@ def test_connect_context_copy(conn_cls, dsn, conn):
 def test_cancel_closed(conn):
     conn.close()
     conn.cancel()
+
+
+def test_resolve_hostaddr_conn(conn_cls, monkeypatch, fake_resolve):
+    got = []
+
+    def fake_connect_gen(conninfo, **kwargs):
+        got.append(conninfo)
+        1 / 0
+
+    monkeypatch.setattr(conn_cls, "_connect_gen", fake_connect_gen)
+
+    with pytest.raises(ZeroDivisionError):
+        conn_cls.connect("host=foo.com")
+
+    assert len(got) == 1
+    want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+    assert conninfo_to_dict(got[0]) == want
index ced5db583d3e0c07557961858aa6c6d72dcc8126..1cf0403497e5a52d8bcd6a592618ea8f04a87f09 100644 (file)
@@ -398,18 +398,18 @@ async def test_autocommit_unknown(aconn):
     [
         ((), {}, ""),
         (("",), {}, ""),
-        (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
-        (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
+        (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"),
+        (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"),
         (
             ("dbname=foo port=5433",),
             {"dbname": "qux", "user": "joe"},
             "dbname=qux user=joe port=5433",
         ),
-        (("dbname=foo",), {"user": None}, "dbname=foo"),
+        (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"),
     ],
 )
 async def test_connect_args(
-    aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
+    aconn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want
 ):
     got_conninfo: str
 
@@ -803,17 +803,17 @@ async def test_cancel_closed(aconn):
     aconn.cancel()
 
 
-async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
+async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch, fake_resolve):
     got = []
 
     def fake_connect_gen(conninfo, **kwargs):
         got.append(conninfo)
         1 / 0
 
-    monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+    monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
 
     with pytest.raises(ZeroDivisionError):
-        await psycopg.AsyncConnection.connect("host=foo.com")
+        await aconn_cls.connect("host=foo.com")
 
     assert len(got) == 1
     want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
index 2e8c44822fc896621ab75788c4017933c7b0932c..badd5d92c8063a58d0f5e19af9b64bfcd4c06677 100644 (file)
@@ -1,12 +1,7 @@
-import socket
-import asyncio
-
 import pytest
 
-import psycopg
 from psycopg import ProgrammingError
 from psycopg.conninfo import make_conninfo, conninfo_to_dict
-from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async
 from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT
 
 snowman = "\u2603"
@@ -92,246 +87,6 @@ def test_no_munging():
     assert dsnin == dsnout
 
 
-@pytest.mark.parametrize(
-    "conninfo, want, env",
-    [
-        ("", [""], None),
-        ("service=foo", ["service=foo"], None),
-        ("host='' user=bar", ["host='' user=bar"], None),
-        (
-            "host=127.0.0.1 user=bar",
-            ["host=127.0.0.1 user=bar"],
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 user=bar",
-            ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"],
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 port=5432",
-            ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"],
-            None,
-        ),
-        (
-            "host=1.1.1.1,1.1.1.1 port=5432,",
-            ["host=1.1.1.1 port=5432", "host=1.1.1.1 port=''"],
-            None,
-        ),
-        (
-            "host=foo.com port=5432",
-            ["host=foo.com port=5432"],
-            {"PGHOSTADDR": "1.2.3.4"},
-        ),
-    ],
-)
-@pytest.mark.anyio
-def test_conninfo_attempts(setpgenv, conninfo, want, env):
-    setpgenv(env)
-    params = conninfo_to_dict(conninfo)
-    attempts = conninfo_attempts(params)
-    want = list(map(conninfo_to_dict, want))
-    assert want == attempts
-
-
-@pytest.mark.parametrize(
-    "conninfo, want, env",
-    [
-        ("", [""], None),
-        ("host='' user=bar", ["host='' user=bar"], None),
-        (
-            "host=127.0.0.1 user=bar port=''",
-            ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
-            None,
-        ),
-        (
-            "host=127.0.0.1 user=bar",
-            ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 user=bar",
-            [
-                "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
-                "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
-            ],
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 port=5432",
-            [
-                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
-                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
-            ],
-            None,
-        ),
-        (
-            "host=1.1.1.1,2.2.2.2 port=5432,",
-            [
-                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
-                "host=2.2.2.2 port='' hostaddr=2.2.2.2",
-            ],
-            None,
-        ),
-        (
-            "port=5432",
-            [
-                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
-                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
-            ],
-            {"PGHOST": "1.1.1.1,2.2.2.2"},
-        ),
-        (
-            "host=foo.com port=5432",
-            ["host=foo.com port=5432"],
-            {"PGHOSTADDR": "1.2.3.4"},
-        ),
-    ],
-)
-@pytest.mark.anyio
-async def test_conninfo_attempts_async_no_resolve(
-    setpgenv, conninfo, want, env, fail_resolve
-):
-    setpgenv(env)
-    params = conninfo_to_dict(conninfo)
-    attempts = await conninfo_attempts_async(params)
-    want = list(map(conninfo_to_dict, want))
-    assert want == attempts
-
-
-@pytest.mark.parametrize(
-    "conninfo, want, env",
-    [
-        (
-            "host=foo.com,qux.com",
-            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
-            None,
-        ),
-        (
-            "host=foo.com,qux.com port=5433",
-            [
-                "host=foo.com hostaddr=1.1.1.1 port=5433",
-                "host=qux.com hostaddr=2.2.2.2 port=5433",
-            ],
-            None,
-        ),
-        (
-            "host=foo.com,qux.com port=5432,5433",
-            [
-                "host=foo.com hostaddr=1.1.1.1 port=5432",
-                "host=qux.com hostaddr=2.2.2.2 port=5433",
-            ],
-            None,
-        ),
-        (
-            "host=foo.com,foo.com port=5432,",
-            [
-                "host=foo.com hostaddr=1.1.1.1 port=5432",
-                "host=foo.com hostaddr=1.1.1.1 port=''",
-            ],
-            None,
-        ),
-        (
-            "host=foo.com,nosuchhost.com",
-            ["host=foo.com hostaddr=1.1.1.1"],
-            None,
-        ),
-        (
-            "host=foo.com, port=5432,5433",
-            ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
-            None,
-        ),
-        (
-            "host=nosuchhost.com,foo.com",
-            ["host=foo.com hostaddr=1.1.1.1"],
-            None,
-        ),
-        (
-            "host=foo.com,qux.com",
-            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
-            {},
-        ),
-        (
-            "host=dup.com",
-            ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
-            None,
-        ),
-    ],
-)
-@pytest.mark.anyio
-async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
-    params = conninfo_to_dict(conninfo)
-    attempts = await conninfo_attempts_async(params)
-    want = list(map(conninfo_to_dict, want))
-    assert want == attempts
-
-
-@pytest.mark.parametrize(
-    "conninfo, env",
-    [
-        ("host=bad1.com,bad2.com", None),
-        ("host=foo.com port=1,2", None),
-        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
-        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
-    ],
-)
-@pytest.mark.anyio
-async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve):
-    setpgenv(env)
-    params = conninfo_to_dict(conninfo)
-    with pytest.raises(psycopg.Error):
-        await conninfo_attempts_async(params)
-
-
-@pytest.mark.parametrize(
-    "conninfo, env",
-    [
-        ("host=foo.com port=1,2", None),
-        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
-        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
-    ],
-)
-@pytest.mark.anyio
-def test_conninfo_attempts_bad(setpgenv, conninfo, env):
-    setpgenv(env)
-    params = conninfo_to_dict(conninfo)
-    with pytest.raises(psycopg.Error):
-        conninfo_attempts(params)
-
-
-def test_conninfo_random():
-    hosts = [f"host{n:02d}" for n in range(50)]
-    args = {"host": ",".join(hosts)}
-    ahosts = [att["host"] for att in conninfo_attempts(args)]
-    assert ahosts == hosts
-
-    args["load_balance_hosts"] = "disable"
-    ahosts = [att["host"] for att in conninfo_attempts(args)]
-    assert ahosts == hosts
-
-    args["load_balance_hosts"] = "random"
-    ahosts = [att["host"] for att in conninfo_attempts(args)]
-    assert ahosts != hosts
-    ahosts.sort()
-    assert ahosts == hosts
-
-
-@pytest.mark.anyio
-async def test_conninfo_random_async(fake_resolve):
-    args = {"host": "alot.com"}
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
-    assert len(hostaddrs) == 20
-    assert hostaddrs == sorted(hostaddrs)
-
-    args["load_balance_hosts"] = "disable"
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
-    assert hostaddrs == sorted(hostaddrs)
-
-    args["load_balance_hosts"] = "random"
-    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
-    assert hostaddrs != sorted(hostaddrs)
-
-
 @pytest.mark.parametrize(
     "conninfo, want, env",
     [
diff --git a/tests/test_conninfo_attempts.py b/tests/test_conninfo_attempts.py
new file mode 100644 (file)
index 0000000..f7bd141
--- /dev/null
@@ -0,0 +1,181 @@
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict, conninfo_attempts
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        ("", [""], None),
+        ("service=foo", ["service=foo"], None),
+        ("host='' user=bar", ["host='' user=bar"], None),
+        (
+            "host=127.0.0.1 user=bar port=''",
+            ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
+            None,
+        ),
+        (
+            "host=127.0.0.1 user=bar",
+            ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            [
+                "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+                "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432,",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port='' hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "port=5432",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
+            {"PGHOST": "1.1.1.1,2.2.2.2"},
+        ),
+        (
+            "host=foo.com port=5432",
+            ["host=foo.com port=5432"],
+            {"PGHOSTADDR": "1.2.3.4"},
+        ),
+    ],
+)
+def test_conninfo_attempts_no_resolve(setpgenv, conninfo, want, env, fail_resolve):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    attempts = conninfo_attempts(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        (
+            "host=foo.com,qux.com",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+            None,
+        ),
+        (
+            "host=foo.com,qux.com port=5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5433",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,qux.com port=5432,5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,foo.com port=5432,",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=foo.com hostaddr=1.1.1.1 port=''",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,nosuchhost.com",
+            ["host=foo.com hostaddr=1.1.1.1"],
+            None,
+        ),
+        (
+            "host=foo.com, port=5432,5433",
+            ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
+            None,
+        ),
+        (
+            "host=nosuchhost.com,foo.com",
+            ["host=foo.com hostaddr=1.1.1.1"],
+            None,
+        ),
+        (
+            "host=foo.com,qux.com",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+            {},
+        ),
+        (
+            "host=dup.com",
+            ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
+            None,
+        ),
+    ],
+)
+def test_conninfo_attempts(conninfo, want, env, fake_resolve):
+    params = conninfo_to_dict(conninfo)
+    attempts = conninfo_attempts(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
+
+
+@pytest.mark.parametrize(
+    "conninfo, env",
+    [
+        ("host=bad1.com,bad2.com", None),
+        ("host=foo.com port=1,2", None),
+        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+    ],
+)
+def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.Error):
+        conninfo_attempts(params)
+
+
+def test_conninfo_random_multi_host():
+    hosts = [f"host{n:02d}" for n in range(50)]
+    args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))}
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "disable"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "random"
+    ahosts = [att["host"] for att in conninfo_attempts(args)]
+    assert ahosts != hosts
+    ahosts.sort()
+    assert ahosts == hosts
+
+
+def test_conninfo_random_multi_ips(fake_resolve):
+    args = {"host": "alot.com"}
+    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    assert len(hostaddrs) == 20
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "disable"
+    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "random"
+    hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+    assert hostaddrs != sorted(hostaddrs)
diff --git a/tests/test_conninfo_attempts_async.py b/tests/test_conninfo_attempts_async.py
new file mode 100644 (file)
index 0000000..bf6da88
--- /dev/null
@@ -0,0 +1,185 @@
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict, conninfo_attempts_async
+
+pytestmark = pytest.mark.anyio
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        ("", [""], None),
+        ("service=foo", ["service=foo"], None),
+        ("host='' user=bar", ["host='' user=bar"], None),
+        (
+            "host=127.0.0.1 user=bar port=''",
+            ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
+            None,
+        ),
+        (
+            "host=127.0.0.1 user=bar",
+            ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            [
+                "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+                "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432,",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port='' hostaddr=2.2.2.2",
+            ],
+            None,
+        ),
+        (
+            "port=5432",
+            [
+                "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+                "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+            ],
+            {"PGHOST": "1.1.1.1,2.2.2.2"},
+        ),
+        (
+            "host=foo.com port=5432",
+            ["host=foo.com port=5432"],
+            {"PGHOSTADDR": "1.2.3.4"},
+        ),
+    ],
+)
+async def test_conninfo_attempts_no_resolve(
+    setpgenv, conninfo, want, env, fail_resolve
+):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    attempts = await conninfo_attempts_async(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
+
+
+@pytest.mark.parametrize(
+    "conninfo, want, env",
+    [
+        (
+            "host=foo.com,qux.com",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+            None,
+        ),
+        (
+            "host=foo.com,qux.com port=5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5433",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,qux.com port=5432,5433",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=qux.com hostaddr=2.2.2.2 port=5433",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,foo.com port=5432,",
+            [
+                "host=foo.com hostaddr=1.1.1.1 port=5432",
+                "host=foo.com hostaddr=1.1.1.1 port=''",
+            ],
+            None,
+        ),
+        (
+            "host=foo.com,nosuchhost.com",
+            ["host=foo.com hostaddr=1.1.1.1"],
+            None,
+        ),
+        (
+            "host=foo.com, port=5432,5433",
+            ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
+            None,
+        ),
+        (
+            "host=nosuchhost.com,foo.com",
+            ["host=foo.com hostaddr=1.1.1.1"],
+            None,
+        ),
+        (
+            "host=foo.com,qux.com",
+            ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+            {},
+        ),
+        (
+            "host=dup.com",
+            ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
+            None,
+        ),
+    ],
+)
+async def test_conninfo_attempts(conninfo, want, env, fake_resolve):
+    params = conninfo_to_dict(conninfo)
+    attempts = await conninfo_attempts_async(params)
+    want = list(map(conninfo_to_dict, want))
+    assert want == attempts
+
+
+@pytest.mark.parametrize(
+    "conninfo, env",
+    [
+        ("host=bad1.com,bad2.com", None),
+        ("host=foo.com port=1,2", None),
+        ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+        ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+    ],
+)
+async def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve):
+    setpgenv(env)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.Error):
+        await conninfo_attempts_async(params)
+
+
+async def test_conninfo_random_multi_host():
+    hosts = [f"host{n:02d}" for n in range(50)]
+    args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))}
+    ahosts = [att["host"] for att in await conninfo_attempts_async(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "disable"
+    ahosts = [att["host"] for att in await conninfo_attempts_async(args)]
+    assert ahosts == hosts
+
+    args["load_balance_hosts"] = "random"
+    ahosts = [att["host"] for att in await conninfo_attempts_async(args)]
+    assert ahosts != hosts
+    ahosts.sort()
+    assert ahosts == hosts
+
+
+async def test_conninfo_random_multi_ips(fake_resolve):
+    args = {"host": "alot.com"}
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    assert len(hostaddrs) == 20
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "disable"
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    assert hostaddrs == sorted(hostaddrs)
+
+    args["load_balance_hosts"] = "random"
+    hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+    assert hostaddrs != sorted(hostaddrs)
index 7ea89ad681dd2a751ac600a796d288fd03cd6199..ded4f8408b14fd40097ad2460e0dfaea20b6851a 100644 (file)
@@ -4,24 +4,6 @@ import psycopg
 from psycopg.conninfo import conninfo_to_dict
 
 
-@pytest.mark.usefixtures("fake_resolve")
-async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
-    got = []
-
-    def fake_connect_gen(conninfo, **kwargs):
-        got.append(conninfo)
-        1 / 0
-
-    monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
-
-    with pytest.raises(ZeroDivisionError):
-        await aconn_cls.connect("host=foo.com")
-
-    assert len(got) == 1
-    want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-    assert conninfo_to_dict(got[0]) == want
-
-
 @pytest.mark.dns
 @pytest.mark.anyio
 async def test_resolve_hostaddr_async_warning(recwarn):
index 3c4ae3ac5fccdf3a3bd19a868782b4d737a76ca2..ffd18a9789a89a68adc79661220d0d8d6a1c957f 100644 (file)
@@ -122,17 +122,17 @@ def test_time_from_ticks(ticks, want):
     [
         ((), {}, ""),
         (("",), {}, ""),
-        (("host=foo user=bar",), {}, "host=foo user=bar"),
-        (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+        (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"),
+        (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"),
         (
-            ("host=foo port=5433",),
-            {"host": "qux", "user": "joe"},
-            "host=qux user=joe port=5433",
+            ("host=foo.com port=5433",),
+            {"host": "qux.com", "user": "joe"},
+            "host=qux.com user=joe port=5433 hostaddr=2.2.2.2",
         ),
-        (("host=foo",), {"user": None}, "host=foo"),
+        (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"),
     ],
 )
-def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv):
+def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve):
     got_conninfo: str
 
     def fake_connect(conninfo):