We now perform DNS resolution in Python both in the sync and async code.
Because the two code paths are now very similar, make sure they pass the
same tests.
Close #699.
Psycopg 3.1.17 (unreleased)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
+- Fix multiple connection attempts when a host name resolve to multiple
+ IP addresses (:ticket:`699`).
- Use `typing.Self` as a more correct return value annotation of context
managers and other self-returning methods (see :ticket:`708`).
--- /dev/null
+"""
+Separate connection attempts from a connection string.
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from __future__ import annotations
+
+import socket
+import logging
+from random import shuffle
+
+from . import errors as e
+from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from ._conninfo_utils import split_attempts
+
+logger = logging.getLogger("psycopg")
+
+
+def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
+ """Split a set of connection params on the single attempts to perform.
+
+ A connection param can perform more than one attempt more than one ``host``
+ is provided.
+
+ Also perform async resolution of the hostname into hostaddr. Because a host
+ can resolve to more than one address, this can lead to yield more attempts
+ too. Raise `OperationalError` if no host could be resolved.
+
+ Because the libpq async function doesn't honour the timeout, we need to
+ reimplement the repeated attempts.
+ """
+ last_exc = None
+ attempts = []
+ for attempt in split_attempts(params):
+ try:
+ attempts.extend(_resolve_hostnames(attempt))
+ except OSError as ex:
+ logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
+ last_exc = ex
+
+ if not attempts:
+ assert last_exc
+ # We couldn't resolve anything
+ raise e.OperationalError(str(last_exc))
+
+ if get_param(params, "load_balance_hosts") == "random":
+ shuffle(attempts)
+
+ return attempts
+
+
+def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
+ """
+ Perform DNS lookup of the hosts and return a list of connection attempts.
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses asynchronously.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+ a single entry for host, hostaddr because it is designed to further
+ process the input of split_attempts().
+
+ :return: A list of attempts to make (to include the case of a hostname
+ resolving to more than one IP).
+ """
+ host = get_param(params, "host")
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path, or no host to resolve
+ return [params]
+
+ hostaddr = get_param(params, "hostaddr")
+ if hostaddr:
+ # Already resolved
+ return [params]
+
+ if is_ip_address(host):
+ # If the host is already an ip address don't try to resolve it
+ return [{**params, "hostaddr": host}]
+
+ port = get_param(params, "port")
+ if not port:
+ port_def = get_param_def("port")
+ port = port_def and port_def.compiled or "5432"
+
+ ans = socket.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+ return [{**params, "hostaddr": item[4][0]} for item in ans]
--- /dev/null
+"""
+Separate connection attempts from a connection string.
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from __future__ import annotations
+
+import socket
+import asyncio
+import logging
+from random import shuffle
+
+from . import errors as e
+from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def
+from ._conninfo_utils import split_attempts
+
+logger = logging.getLogger("psycopg")
+
+
+async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
+ """Split a set of connection params on the single attempts to perform.
+
+ A connection param can perform more than one attempt more than one ``host``
+ is provided.
+
+ Also perform async resolution of the hostname into hostaddr. Because a host
+ can resolve to more than one address, this can lead to yield more attempts
+ too. Raise `OperationalError` if no host could be resolved.
+
+ Because the libpq async function doesn't honour the timeout, we need to
+ reimplement the repeated attempts.
+ """
+ last_exc = None
+ attempts = []
+ for attempt in split_attempts(params):
+ try:
+ attempts.extend(await _resolve_hostnames(attempt))
+ except OSError as ex:
+ logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
+ last_exc = ex
+
+ if not attempts:
+ assert last_exc
+ # We couldn't resolve anything
+ raise e.OperationalError(str(last_exc))
+
+ if get_param(params, "load_balance_hosts") == "random":
+ shuffle(attempts)
+
+ return attempts
+
+
+async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
+ """
+ Perform async DNS lookup of the hosts and return a list of connection attempts.
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses asynchronously.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
+ a single entry for host, hostaddr because it is designed to further
+ process the input of split_attempts().
+
+ :return: A list of attempts to make (to include the case of a hostname
+ resolving to more than one IP).
+ """
+ host = get_param(params, "host")
+ if not host or host.startswith("/") or host[1:2] == ":":
+ # Local path, or no host to resolve
+ return [params]
+
+ hostaddr = get_param(params, "hostaddr")
+ if hostaddr:
+ # Already resolved
+ return [params]
+
+ if is_ip_address(host):
+ # If the host is already an ip address don't try to resolve it
+ return [{**params, "hostaddr": host}]
+
+ port = get_param(params, "port")
+ if not port:
+ port_def = get_param_def("port")
+ port = port_def and port_def.compiled or "5432"
+
+ loop = asyncio.get_running_loop()
+ ans = await loop.getaddrinfo(
+ host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
+ )
+ return [{**params, "hostaddr": item[4][0]} for item in ans]
--- /dev/null
+"""
+Internal utilities to manipulate connection strings
+"""
+
+# Copyright (C) 2024 The Psycopg Team
+
+from __future__ import annotations
+
+import os
+from typing import Any
+from functools import lru_cache
+from ipaddress import ip_address
+from dataclasses import dataclass
+from typing_extensions import TypeAlias
+
+from . import pq
+from . import errors as e
+
+ConnDict: TypeAlias = "dict[str, Any]"
+
+
+def split_attempts(params: ConnDict) -> list[ConnDict]:
+ """
+ Split connection parameters with a sequence of hosts into separate attempts.
+ """
+
+ def split_val(key: str) -> list[str]:
+ val = get_param(params, key)
+ return val.split(",") if val else []
+
+ hosts = split_val("host")
+ hostaddrs = split_val("hostaddr")
+ ports = split_val("port")
+
+ if hosts and hostaddrs and len(hosts) != len(hostaddrs):
+ raise e.OperationalError(
+ f"could not match {len(hosts)} host names"
+ f" with {len(hostaddrs)} hostaddr values"
+ )
+
+ nhosts = max(len(hosts), len(hostaddrs))
+
+ if 1 < len(ports) != nhosts:
+ raise e.OperationalError(
+ f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
+ )
+
+ # A single attempt to make. Don't mangle the conninfo string.
+ if nhosts <= 1:
+ return [params]
+
+ if len(ports) == 1:
+ ports *= nhosts
+
+ # Now all lists are either empty or have the same length
+ rv = []
+ for i in range(nhosts):
+ attempt = params.copy()
+ if hosts:
+ attempt["host"] = hosts[i]
+ if hostaddrs:
+ attempt["hostaddr"] = hostaddrs[i]
+ if ports:
+ attempt["port"] = ports[i]
+ rv.append(attempt)
+
+ return rv
+
+
+def get_param(params: ConnDict, name: str) -> str | None:
+ """
+ Return a value from a connection string.
+
+ The value may be also specified in a PG* env var.
+ """
+ if name in params:
+ return str(params[name])
+
+ # TODO: check if in service
+
+ paramdef = get_param_def(name)
+ if not paramdef:
+ return None
+
+ env = os.environ.get(paramdef.envvar)
+ if env is not None:
+ return env
+
+ return None
+
+
+@dataclass
+class ParamDef:
+ """
+ Information about defaults and env vars for connection params
+ """
+
+ keyword: str
+ envvar: str
+ compiled: str | None
+
+
+def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
+ """
+ Return the ParamDef of a connection string parameter.
+ """
+ if not _cache:
+ defs = pq.Conninfo.get_defaults()
+ for d in defs:
+ cd = ParamDef(
+ keyword=d.keyword.decode(),
+ envvar=d.envvar.decode() if d.envvar else "",
+ compiled=d.compiled.decode() if d.compiled is not None else None,
+ )
+ _cache[cd.keyword] = cd
+
+ return _cache.get(keyword)
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+ """Return True if the string represent a valid ip address."""
+ try:
+ ip_address(s)
+ except ValueError:
+ return False
+ return True
from __future__ import annotations
-import os
import re
-import socket
-import asyncio
-import logging
from typing import Any
-from random import shuffle
-from functools import lru_cache
-from ipaddress import ip_address
-from dataclasses import dataclass
-from typing_extensions import TypeAlias
from . import pq
from . import errors as e
-ConnDict: TypeAlias = "dict[str, Any]"
+from . import _conninfo_utils
+from . import _conninfo_attempts
+from . import _conninfo_attempts_async
+
+# re-exoprts
+ConnDict = _conninfo_utils.ConnDict
+conninfo_attempts = _conninfo_attempts.conninfo_attempts
+conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async
# Default timeout for connection a attempt.
# Arbitrary timeout, what applied by the libpq on my computer.
# Your mileage won't vary.
_DEFAULT_CONNECT_TIMEOUT = 130
-logger = logging.getLogger("psycopg")
-
def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
"""
return s
-def conninfo_attempts(params: ConnDict) -> list[ConnDict]:
- """Split a set of connection params on the single attempts to perform.
-
- A connection param can perform more than one attempt more than one ``host``
- is provided.
-
- Because the libpq async function doesn't honour the timeout, we need to
- reimplement the repeated attempts.
- """
- # TODO: we should actually resolve the hosts ourselves.
- # If an host resolves to more than one ip, the libpq will make more than
- # one attempt and wouldn't get to try the following ones, as before
- # fixing #674.
- attempts = _split_attempts(params)
- if _get_param(params, "load_balance_hosts") == "random":
- shuffle(attempts)
- return attempts
-
-
-async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]:
- """Split a set of connection params on the single attempts to perform.
-
- A connection param can perform more than one attempt more than one ``host``
- is provided.
-
- Also perform async resolution of the hostname into hostaddr in order to
- avoid blocking. Because a host can resolve to more than one address, this
- can lead to yield more attempts too. Raise `OperationalError` if no host
- could be resolved.
-
- Because the libpq async function doesn't honour the timeout, we need to
- reimplement the repeated attempts.
- """
- last_exc = None
- attempts = []
- for attempt in _split_attempts(params):
- try:
- attempts.extend(await _resolve_hostnames(attempt))
- except OSError as ex:
- logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex))
- last_exc = ex
-
- if not attempts:
- assert last_exc
- # We couldn't resolve anything
- raise e.OperationalError(str(last_exc))
-
- if _get_param(params, "load_balance_hosts") == "random":
- shuffle(attempts)
-
- return attempts
-
-
-def _split_attempts(params: ConnDict) -> list[ConnDict]:
- """
- Split connection parameters with a sequence of hosts into separate attempts.
- """
-
- def split_val(key: str) -> list[str]:
- val = _get_param(params, key)
- return val.split(",") if val else []
-
- hosts = split_val("host")
- hostaddrs = split_val("hostaddr")
- ports = split_val("port")
-
- if hosts and hostaddrs and len(hosts) != len(hostaddrs):
- raise e.OperationalError(
- f"could not match {len(hosts)} host names"
- f" with {len(hostaddrs)} hostaddr values"
- )
-
- nhosts = max(len(hosts), len(hostaddrs))
-
- if 1 < len(ports) != nhosts:
- raise e.OperationalError(
- f"could not match {len(ports)} port numbers to {len(hosts)} hosts"
- )
-
- # A single attempt to make. Don't mangle the conninfo string.
- if nhosts <= 1:
- return [params]
-
- if len(ports) == 1:
- ports *= nhosts
-
- # Now all lists are either empty or have the same length
- rv = []
- for i in range(nhosts):
- attempt = params.copy()
- if hosts:
- attempt["host"] = hosts[i]
- if hostaddrs:
- attempt["hostaddr"] = hostaddrs[i]
- if ports:
- attempt["port"] = ports[i]
- rv.append(attempt)
-
- return rv
-
-
-async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
- """
- Perform async DNS lookup of the hosts and return a new params dict.
-
- If a ``host`` param is present but not ``hostname``, resolve the host
- addresses asynchronously.
-
- :param params: The input parameters, for instance as returned by
- `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
- a single entry for host, hostaddr because it is designed to further
- process the input of _split_attempts().
-
- :return: A list of attempts to make (to include the case of a hostname
- resolving to more than one IP).
- """
- host = _get_param(params, "host")
- if not host or host.startswith("/") or host[1:2] == ":":
- # Local path, or no host to resolve
- return [params]
-
- hostaddr = _get_param(params, "hostaddr")
- if hostaddr:
- # Already resolved
- return [params]
-
- if is_ip_address(host):
- # If the host is already an ip address don't try to resolve it
- return [{**params, "hostaddr": host}]
-
- loop = asyncio.get_running_loop()
-
- port = _get_param(params, "port")
- if not port:
- port_def = _get_param_def("port")
- port = port_def and port_def.compiled or "5432"
-
- ans = await loop.getaddrinfo(
- host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
- )
- return [{**params, "hostaddr": item[4][0]} for item in ans]
-
-
def timeout_from_conninfo(params: ConnDict) -> int:
"""
Return the timeout in seconds from the connection parameters.
# - at least 2 seconds.
#
# See connectDBComplete in fe-connect.c
- value: str | int | None = _get_param(params, "connect_timeout")
+ value: str | int | None = _conninfo_utils.get_param(params, "connect_timeout")
if value is None:
value = _DEFAULT_CONNECT_TIMEOUT
try:
timeout = 2
return timeout
-
-
-def _get_param(params: ConnDict, name: str) -> str | None:
- """
- Return a value from a connection string.
-
- The value may be also specified in a PG* env var.
- """
- if name in params:
- return str(params[name])
-
- # TODO: check if in service
-
- paramdef = _get_param_def(name)
- if not paramdef:
- return None
-
- env = os.environ.get(paramdef.envvar)
- if env is not None:
- return env
-
- return None
-
-
-@dataclass
-class ParamDef:
- """
- Information about defaults and env vars for connection params
- """
-
- keyword: str
- envvar: str
- compiled: str | None
-
-
-def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
- """
- Return the ParamDef of a connection string parameter.
- """
- if not _cache:
- defs = pq.Conninfo.get_defaults()
- for d in defs:
- cd = ParamDef(
- keyword=d.keyword.decode(),
- envvar=d.envvar.decode() if d.envvar else "",
- compiled=d.compiled.decode() if d.compiled is not None else None,
- )
- _cache[cd.keyword] = cd
-
- return _cache.get(keyword)
-
-
-@lru_cache()
-def is_ip_address(s: str) -> bool:
- """Return True if the string represent a valid ip address."""
- try:
- ip_address(s)
- except ValueError:
- return False
- return True
[
((), {}, ""),
(("",), {}, ""),
- (("host=foo user=bar",), {}, "host=foo user=bar"),
- (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"),
+ (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"),
(
("dbname=foo port=5433",),
{"dbname": "qux", "user": "joe"},
"dbname=qux user=joe port=5433",
),
- (("host=foo",), {"user": None}, "host=foo"),
+ (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"),
],
)
-def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want):
+def test_connect_args(
+ conn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want
+):
got_conninfo: str
def fake_connect(conninfo):
def test_cancel_closed(conn):
conn.close()
conn.cancel()
+
+
+def test_resolve_hostaddr_conn(conn_cls, monkeypatch, fake_resolve):
+ got = []
+
+ def fake_connect_gen(conninfo, **kwargs):
+ got.append(conninfo)
+ 1 / 0
+
+ monkeypatch.setattr(conn_cls, "_connect_gen", fake_connect_gen)
+
+ with pytest.raises(ZeroDivisionError):
+ conn_cls.connect("host=foo.com")
+
+ assert len(got) == 1
+ want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+ assert conninfo_to_dict(got[0]) == want
[
((), {}, ""),
(("",), {}, ""),
- (("dbname=foo user=bar",), {}, "dbname=foo user=bar"),
- (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"),
+ (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"),
+ (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"),
(
("dbname=foo port=5433",),
{"dbname": "qux", "user": "joe"},
"dbname=qux user=joe port=5433",
),
- (("dbname=foo",), {"user": None}, "dbname=foo"),
+ (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"),
],
)
async def test_connect_args(
- aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want
+ aconn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want
):
got_conninfo: str
aconn.cancel()
-async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
+async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch, fake_resolve):
got = []
def fake_connect_gen(conninfo, **kwargs):
got.append(conninfo)
1 / 0
- monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
+ monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
with pytest.raises(ZeroDivisionError):
- await psycopg.AsyncConnection.connect("host=foo.com")
+ await aconn_cls.connect("host=foo.com")
assert len(got) == 1
want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
-import socket
-import asyncio
-
import pytest
-import psycopg
from psycopg import ProgrammingError
from psycopg.conninfo import make_conninfo, conninfo_to_dict
-from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async
from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT
snowman = "\u2603"
assert dsnin == dsnout
-@pytest.mark.parametrize(
- "conninfo, want, env",
- [
- ("", [""], None),
- ("service=foo", ["service=foo"], None),
- ("host='' user=bar", ["host='' user=bar"], None),
- (
- "host=127.0.0.1 user=bar",
- ["host=127.0.0.1 user=bar"],
- None,
- ),
- (
- "host=1.1.1.1,2.2.2.2 user=bar",
- ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"],
- None,
- ),
- (
- "host=1.1.1.1,2.2.2.2 port=5432",
- ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"],
- None,
- ),
- (
- "host=1.1.1.1,1.1.1.1 port=5432,",
- ["host=1.1.1.1 port=5432", "host=1.1.1.1 port=''"],
- None,
- ),
- (
- "host=foo.com port=5432",
- ["host=foo.com port=5432"],
- {"PGHOSTADDR": "1.2.3.4"},
- ),
- ],
-)
-@pytest.mark.anyio
-def test_conninfo_attempts(setpgenv, conninfo, want, env):
- setpgenv(env)
- params = conninfo_to_dict(conninfo)
- attempts = conninfo_attempts(params)
- want = list(map(conninfo_to_dict, want))
- assert want == attempts
-
-
-@pytest.mark.parametrize(
- "conninfo, want, env",
- [
- ("", [""], None),
- ("host='' user=bar", ["host='' user=bar"], None),
- (
- "host=127.0.0.1 user=bar port=''",
- ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
- None,
- ),
- (
- "host=127.0.0.1 user=bar",
- ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
- None,
- ),
- (
- "host=1.1.1.1,2.2.2.2 user=bar",
- [
- "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
- "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
- ],
- None,
- ),
- (
- "host=1.1.1.1,2.2.2.2 port=5432",
- [
- "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
- "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
- ],
- None,
- ),
- (
- "host=1.1.1.1,2.2.2.2 port=5432,",
- [
- "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
- "host=2.2.2.2 port='' hostaddr=2.2.2.2",
- ],
- None,
- ),
- (
- "port=5432",
- [
- "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
- "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
- ],
- {"PGHOST": "1.1.1.1,2.2.2.2"},
- ),
- (
- "host=foo.com port=5432",
- ["host=foo.com port=5432"],
- {"PGHOSTADDR": "1.2.3.4"},
- ),
- ],
-)
-@pytest.mark.anyio
-async def test_conninfo_attempts_async_no_resolve(
- setpgenv, conninfo, want, env, fail_resolve
-):
- setpgenv(env)
- params = conninfo_to_dict(conninfo)
- attempts = await conninfo_attempts_async(params)
- want = list(map(conninfo_to_dict, want))
- assert want == attempts
-
-
-@pytest.mark.parametrize(
- "conninfo, want, env",
- [
- (
- "host=foo.com,qux.com",
- ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
- None,
- ),
- (
- "host=foo.com,qux.com port=5433",
- [
- "host=foo.com hostaddr=1.1.1.1 port=5433",
- "host=qux.com hostaddr=2.2.2.2 port=5433",
- ],
- None,
- ),
- (
- "host=foo.com,qux.com port=5432,5433",
- [
- "host=foo.com hostaddr=1.1.1.1 port=5432",
- "host=qux.com hostaddr=2.2.2.2 port=5433",
- ],
- None,
- ),
- (
- "host=foo.com,foo.com port=5432,",
- [
- "host=foo.com hostaddr=1.1.1.1 port=5432",
- "host=foo.com hostaddr=1.1.1.1 port=''",
- ],
- None,
- ),
- (
- "host=foo.com,nosuchhost.com",
- ["host=foo.com hostaddr=1.1.1.1"],
- None,
- ),
- (
- "host=foo.com, port=5432,5433",
- ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
- None,
- ),
- (
- "host=nosuchhost.com,foo.com",
- ["host=foo.com hostaddr=1.1.1.1"],
- None,
- ),
- (
- "host=foo.com,qux.com",
- ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
- {},
- ),
- (
- "host=dup.com",
- ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
- None,
- ),
- ],
-)
-@pytest.mark.anyio
-async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve):
- params = conninfo_to_dict(conninfo)
- attempts = await conninfo_attempts_async(params)
- want = list(map(conninfo_to_dict, want))
- assert want == attempts
-
-
-@pytest.mark.parametrize(
- "conninfo, env",
- [
- ("host=bad1.com,bad2.com", None),
- ("host=foo.com port=1,2", None),
- ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
- ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
- ],
-)
-@pytest.mark.anyio
-async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve):
- setpgenv(env)
- params = conninfo_to_dict(conninfo)
- with pytest.raises(psycopg.Error):
- await conninfo_attempts_async(params)
-
-
-@pytest.mark.parametrize(
- "conninfo, env",
- [
- ("host=foo.com port=1,2", None),
- ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
- ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
- ],
-)
-@pytest.mark.anyio
-def test_conninfo_attempts_bad(setpgenv, conninfo, env):
- setpgenv(env)
- params = conninfo_to_dict(conninfo)
- with pytest.raises(psycopg.Error):
- conninfo_attempts(params)
-
-
-def test_conninfo_random():
- hosts = [f"host{n:02d}" for n in range(50)]
- args = {"host": ",".join(hosts)}
- ahosts = [att["host"] for att in conninfo_attempts(args)]
- assert ahosts == hosts
-
- args["load_balance_hosts"] = "disable"
- ahosts = [att["host"] for att in conninfo_attempts(args)]
- assert ahosts == hosts
-
- args["load_balance_hosts"] = "random"
- ahosts = [att["host"] for att in conninfo_attempts(args)]
- assert ahosts != hosts
- ahosts.sort()
- assert ahosts == hosts
-
-
-@pytest.mark.anyio
-async def test_conninfo_random_async(fake_resolve):
- args = {"host": "alot.com"}
- hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
- assert len(hostaddrs) == 20
- assert hostaddrs == sorted(hostaddrs)
-
- args["load_balance_hosts"] = "disable"
- hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
- assert hostaddrs == sorted(hostaddrs)
-
- args["load_balance_hosts"] = "random"
- hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
- assert hostaddrs != sorted(hostaddrs)
-
-
@pytest.mark.parametrize(
"conninfo, want, env",
[
--- /dev/null
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict, conninfo_attempts
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ ("", [""], None),
+ ("service=foo", ["service=foo"], None),
+ ("host='' user=bar", ["host='' user=bar"], None),
+ (
+ "host=127.0.0.1 user=bar port=''",
+ ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
+ None,
+ ),
+ (
+ "host=127.0.0.1 user=bar",
+ ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ [
+ "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+ "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+ ],
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+ ],
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432,",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port='' hostaddr=2.2.2.2",
+ ],
+ None,
+ ),
+ (
+ "port=5432",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+ ],
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ ["host=foo.com port=5432"],
+ {"PGHOSTADDR": "1.2.3.4"},
+ ),
+ ],
+)
+def test_conninfo_attempts_no_resolve(setpgenv, conninfo, want, env, fail_resolve):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ attempts = conninfo_attempts(params)
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ (
+ "host=foo.com,qux.com",
+ ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5433",
+ "host=qux.com hostaddr=2.2.2.2 port=5433",
+ ],
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5432",
+ "host=qux.com hostaddr=2.2.2.2 port=5433",
+ ],
+ None,
+ ),
+ (
+ "host=foo.com,foo.com port=5432,",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5432",
+ "host=foo.com hostaddr=1.1.1.1 port=''",
+ ],
+ None,
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ ["host=foo.com hostaddr=1.1.1.1"],
+ None,
+ ),
+ (
+ "host=foo.com, port=5432,5433",
+ ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
+ None,
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ ["host=foo.com hostaddr=1.1.1.1"],
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+ {},
+ ),
+ (
+ "host=dup.com",
+ ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
+ None,
+ ),
+ ],
+)
+def test_conninfo_attempts(conninfo, want, env, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ attempts = conninfo_attempts(params)
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
+
+
+@pytest.mark.parametrize(
+ "conninfo, env",
+ [
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+ ],
+)
+def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.Error):
+ conninfo_attempts(params)
+
+
+def test_conninfo_random_multi_host():
+ hosts = [f"host{n:02d}" for n in range(50)]
+ args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))}
+ ahosts = [att["host"] for att in conninfo_attempts(args)]
+ assert ahosts == hosts
+
+ args["load_balance_hosts"] = "disable"
+ ahosts = [att["host"] for att in conninfo_attempts(args)]
+ assert ahosts == hosts
+
+ args["load_balance_hosts"] = "random"
+ ahosts = [att["host"] for att in conninfo_attempts(args)]
+ assert ahosts != hosts
+ ahosts.sort()
+ assert ahosts == hosts
+
+
+def test_conninfo_random_multi_ips(fake_resolve):
+ args = {"host": "alot.com"}
+ hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+ assert len(hostaddrs) == 20
+ assert hostaddrs == sorted(hostaddrs)
+
+ args["load_balance_hosts"] = "disable"
+ hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+ assert hostaddrs == sorted(hostaddrs)
+
+ args["load_balance_hosts"] = "random"
+ hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)]
+ assert hostaddrs != sorted(hostaddrs)
--- /dev/null
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict, conninfo_attempts_async
+
+pytestmark = pytest.mark.anyio
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ ("", [""], None),
+ ("service=foo", ["service=foo"], None),
+ ("host='' user=bar", ["host='' user=bar"], None),
+ (
+ "host=127.0.0.1 user=bar port=''",
+ ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"],
+ None,
+ ),
+ (
+ "host=127.0.0.1 user=bar",
+ ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"],
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ [
+ "host=1.1.1.1 user=bar hostaddr=1.1.1.1",
+ "host=2.2.2.2 user=bar hostaddr=2.2.2.2",
+ ],
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+ ],
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432,",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port='' hostaddr=2.2.2.2",
+ ],
+ None,
+ ),
+ (
+ "port=5432",
+ [
+ "host=1.1.1.1 port=5432 hostaddr=1.1.1.1",
+ "host=2.2.2.2 port=5432 hostaddr=2.2.2.2",
+ ],
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ ["host=foo.com port=5432"],
+ {"PGHOSTADDR": "1.2.3.4"},
+ ),
+ ],
+)
+async def test_conninfo_attempts_no_resolve(
+ setpgenv, conninfo, want, env, fail_resolve
+):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ attempts = await conninfo_attempts_async(params)
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ (
+ "host=foo.com,qux.com",
+ ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5433",
+ "host=qux.com hostaddr=2.2.2.2 port=5433",
+ ],
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5432",
+ "host=qux.com hostaddr=2.2.2.2 port=5433",
+ ],
+ None,
+ ),
+ (
+ "host=foo.com,foo.com port=5432,",
+ [
+ "host=foo.com hostaddr=1.1.1.1 port=5432",
+ "host=foo.com hostaddr=1.1.1.1 port=''",
+ ],
+ None,
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ ["host=foo.com hostaddr=1.1.1.1"],
+ None,
+ ),
+ (
+ "host=foo.com, port=5432,5433",
+ ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"],
+ None,
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ ["host=foo.com hostaddr=1.1.1.1"],
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"],
+ {},
+ ),
+ (
+ "host=dup.com",
+ ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"],
+ None,
+ ),
+ ],
+)
+async def test_conninfo_attempts(conninfo, want, env, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ attempts = await conninfo_attempts_async(params)
+ want = list(map(conninfo_to_dict, want))
+ assert want == attempts
+
+
+@pytest.mark.parametrize(
+ "conninfo, env",
+ [
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+ ],
+)
+async def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.Error):
+ await conninfo_attempts_async(params)
+
+
+async def test_conninfo_random_multi_host():
+ hosts = [f"host{n:02d}" for n in range(50)]
+ args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))}
+ ahosts = [att["host"] for att in await conninfo_attempts_async(args)]
+ assert ahosts == hosts
+
+ args["load_balance_hosts"] = "disable"
+ ahosts = [att["host"] for att in await conninfo_attempts_async(args)]
+ assert ahosts == hosts
+
+ args["load_balance_hosts"] = "random"
+ ahosts = [att["host"] for att in await conninfo_attempts_async(args)]
+ assert ahosts != hosts
+ ahosts.sort()
+ assert ahosts == hosts
+
+
+async def test_conninfo_random_multi_ips(fake_resolve):
+ args = {"host": "alot.com"}
+ hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+ assert len(hostaddrs) == 20
+ assert hostaddrs == sorted(hostaddrs)
+
+ args["load_balance_hosts"] = "disable"
+ hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+ assert hostaddrs == sorted(hostaddrs)
+
+ args["load_balance_hosts"] = "random"
+ hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)]
+ assert hostaddrs != sorted(hostaddrs)
from psycopg.conninfo import conninfo_to_dict
-@pytest.mark.usefixtures("fake_resolve")
-async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch):
- got = []
-
- def fake_connect_gen(conninfo, **kwargs):
- got.append(conninfo)
- 1 / 0
-
- monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen)
-
- with pytest.raises(ZeroDivisionError):
- await aconn_cls.connect("host=foo.com")
-
- assert len(got) == 1
- want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
- assert conninfo_to_dict(got[0]) == want
-
-
@pytest.mark.dns
@pytest.mark.anyio
async def test_resolve_hostaddr_async_warning(recwarn):
[
((), {}, ""),
(("",), {}, ""),
- (("host=foo user=bar",), {}, "host=foo user=bar"),
- (("host=foo",), {"user": "baz"}, "host=foo user=baz"),
+ (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"),
+ (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"),
(
- ("host=foo port=5433",),
- {"host": "qux", "user": "joe"},
- "host=qux user=joe port=5433",
+ ("host=foo.com port=5433",),
+ {"host": "qux.com", "user": "joe"},
+ "host=qux.com user=joe port=5433 hostaddr=2.2.2.2",
),
- (("host=foo",), {"user": None}, "host=foo"),
+ (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"),
],
)
-def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv):
+def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve):
got_conninfo: str
def fake_connect(conninfo):