From 97e9f6d0517d2ded7744c20fdf8858ee409c92ed Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sat, 6 Jan 2024 17:24:48 +0100 Subject: [PATCH] fix: perform multiple attemps if a host name resolve to multiple hosts We now perform DNS resolution in Python both in the sync and async code. Because the two code paths are now very similar, make sure they pass the same tests. Close #699. --- docs/news.rst | 2 + psycopg/psycopg/_conninfo_attempts.py | 90 ++++++++ psycopg/psycopg/_conninfo_attempts_async.py | 92 ++++++++ psycopg/psycopg/_conninfo_utils.py | 127 ++++++++++ psycopg/psycopg/conninfo.py | 225 +----------------- tests/test_connection.py | 26 ++- tests/test_connection_async.py | 24 +- tests/test_conninfo.py | 242 -------------------- tests/test_conninfo_attempts.py | 181 +++++++++++++++ tests/test_conninfo_attempts_async.py | 185 +++++++++++++++ tests/test_dns.py | 17 -- tests/test_psycopg_dbapi20.py | 14 +- 12 files changed, 735 insertions(+), 490 deletions(-) create mode 100644 psycopg/psycopg/_conninfo_attempts.py create mode 100644 psycopg/psycopg/_conninfo_attempts_async.py create mode 100644 psycopg/psycopg/_conninfo_utils.py create mode 100644 tests/test_conninfo_attempts.py create mode 100644 tests/test_conninfo_attempts_async.py diff --git a/docs/news.rst b/docs/news.rst index c7540dfca..db3234f4c 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -33,6 +33,8 @@ Psycopg 3.2 (unreleased) Psycopg 3.1.17 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +- Fix multiple connection attempts when a host name resolve to multiple + IP addresses (:ticket:`699`). - Use `typing.Self` as a more correct return value annotation of context managers and other self-returning methods (see :ticket:`708`). diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py new file mode 100644 index 000000000..5262ab785 --- /dev/null +++ b/psycopg/psycopg/_conninfo_attempts.py @@ -0,0 +1,90 @@ +""" +Separate connection attempts from a connection string. +""" + +# Copyright (C) 2024 The Psycopg Team + +from __future__ import annotations + +import socket +import logging +from random import shuffle + +from . import errors as e +from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from ._conninfo_utils import split_attempts + +logger = logging.getLogger("psycopg") + + +def conninfo_attempts(params: ConnDict) -> list[ConnDict]: + """Split a set of connection params on the single attempts to perform. + + A connection param can perform more than one attempt more than one ``host`` + is provided. + + Also perform async resolution of the hostname into hostaddr. Because a host + can resolve to more than one address, this can lead to yield more attempts + too. Raise `OperationalError` if no host could be resolved. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + last_exc = None + attempts = [] + for attempt in split_attempts(params): + try: + attempts.extend(_resolve_hostnames(attempt)) + except OSError as ex: + logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) + last_exc = ex + + if not attempts: + assert last_exc + # We couldn't resolve anything + raise e.OperationalError(str(last_exc)) + + if get_param(params, "load_balance_hosts") == "random": + shuffle(attempts) + + return attempts + + +def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: + """ + Perform DNS lookup of the hosts and return a list of connection attempts. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses asynchronously. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most + a single entry for host, hostaddr because it is designed to further + process the input of split_attempts(). + + :return: A list of attempts to make (to include the case of a hostname + resolving to more than one IP). + """ + host = get_param(params, "host") + if not host or host.startswith("/") or host[1:2] == ":": + # Local path, or no host to resolve + return [params] + + hostaddr = get_param(params, "hostaddr") + if hostaddr: + # Already resolved + return [params] + + if is_ip_address(host): + # If the host is already an ip address don't try to resolve it + return [{**params, "hostaddr": host}] + + port = get_param(params, "port") + if not port: + port_def = get_param_def("port") + port = port_def and port_def.compiled or "5432" + + ans = socket.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + return [{**params, "hostaddr": item[4][0]} for item in ans] diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py new file mode 100644 index 000000000..037f213ff --- /dev/null +++ b/psycopg/psycopg/_conninfo_attempts_async.py @@ -0,0 +1,92 @@ +""" +Separate connection attempts from a connection string. +""" + +# Copyright (C) 2024 The Psycopg Team + +from __future__ import annotations + +import socket +import asyncio +import logging +from random import shuffle + +from . import errors as e +from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from ._conninfo_utils import split_attempts + +logger = logging.getLogger("psycopg") + + +async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: + """Split a set of connection params on the single attempts to perform. + + A connection param can perform more than one attempt more than one ``host`` + is provided. + + Also perform async resolution of the hostname into hostaddr. Because a host + can resolve to more than one address, this can lead to yield more attempts + too. Raise `OperationalError` if no host could be resolved. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + last_exc = None + attempts = [] + for attempt in split_attempts(params): + try: + attempts.extend(await _resolve_hostnames_async(attempt)) + except OSError as ex: + logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) + last_exc = ex + + if not attempts: + assert last_exc + # We couldn't resolve anything + raise e.OperationalError(str(last_exc)) + + if get_param(params, "load_balance_hosts") == "random": + shuffle(attempts) + + return attempts + + +async def _resolve_hostnames_async(params: ConnDict) -> list[ConnDict]: + """ + Perform async DNS lookup of the hosts and return a list of connection attempts. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses asynchronously. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most + a single entry for host, hostaddr because it is designed to further + process the input of split_attempts(). + + :return: A list of attempts to make (to include the case of a hostname + resolving to more than one IP). + """ + host = get_param(params, "host") + if not host or host.startswith("/") or host[1:2] == ":": + # Local path, or no host to resolve + return [params] + + hostaddr = get_param(params, "hostaddr") + if hostaddr: + # Already resolved + return [params] + + if is_ip_address(host): + # If the host is already an ip address don't try to resolve it + return [{**params, "hostaddr": host}] + + port = get_param(params, "port") + if not port: + port_def = get_param_def("port") + port = port_def and port_def.compiled or "5432" + + loop = asyncio.get_running_loop() + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + return [{**params, "hostaddr": item[4][0]} for item in ans] diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py new file mode 100644 index 000000000..8940c937b --- /dev/null +++ b/psycopg/psycopg/_conninfo_utils.py @@ -0,0 +1,127 @@ +""" +Internal utilities to manipulate connection strings +""" + +# Copyright (C) 2024 The Psycopg Team + +from __future__ import annotations + +import os +from typing import Any +from functools import lru_cache +from ipaddress import ip_address +from dataclasses import dataclass +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e + +ConnDict: TypeAlias = "dict[str, Any]" + + +def split_attempts(params: ConnDict) -> list[ConnDict]: + """ + Split connection parameters with a sequence of hosts into separate attempts. + """ + + def split_val(key: str) -> list[str]: + val = get_param(params, key) + return val.split(",") if val else [] + + hosts = split_val("host") + hostaddrs = split_val("hostaddr") + ports = split_val("port") + + if hosts and hostaddrs and len(hosts) != len(hostaddrs): + raise e.OperationalError( + f"could not match {len(hosts)} host names" + f" with {len(hostaddrs)} hostaddr values" + ) + + nhosts = max(len(hosts), len(hostaddrs)) + + if 1 < len(ports) != nhosts: + raise e.OperationalError( + f"could not match {len(ports)} port numbers to {len(hosts)} hosts" + ) + + # A single attempt to make. Don't mangle the conninfo string. + if nhosts <= 1: + return [params] + + if len(ports) == 1: + ports *= nhosts + + # Now all lists are either empty or have the same length + rv = [] + for i in range(nhosts): + attempt = params.copy() + if hosts: + attempt["host"] = hosts[i] + if hostaddrs: + attempt["hostaddr"] = hostaddrs[i] + if ports: + attempt["port"] = ports[i] + rv.append(attempt) + + return rv + + +def get_param(params: ConnDict, name: str) -> str | None: + """ + Return a value from a connection string. + + The value may be also specified in a PG* env var. + """ + if name in params: + return str(params[name]) + + # TODO: check if in service + + paramdef = get_param_def(name) + if not paramdef: + return None + + env = os.environ.get(paramdef.envvar) + if env is not None: + return env + + return None + + +@dataclass +class ParamDef: + """ + Information about defaults and env vars for connection params + """ + + keyword: str + envvar: str + compiled: str | None + + +def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None: + """ + Return the ParamDef of a connection string parameter. + """ + if not _cache: + defs = pq.Conninfo.get_defaults() + for d in defs: + cd = ParamDef( + keyword=d.keyword.decode(), + envvar=d.envvar.decode() if d.envvar else "", + compiled=d.compiled.decode() if d.compiled is not None else None, + ) + _cache[cd.keyword] = cd + + return _cache.get(keyword) + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 9044550cb..82da58822 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -6,30 +6,26 @@ Functions to manipulate conninfo strings from __future__ import annotations -import os import re -import socket -import asyncio -import logging from typing import Any -from random import shuffle -from functools import lru_cache -from ipaddress import ip_address -from dataclasses import dataclass -from typing_extensions import TypeAlias from . import pq from . import errors as e -ConnDict: TypeAlias = "dict[str, Any]" +from . import _conninfo_utils +from . import _conninfo_attempts +from . import _conninfo_attempts_async + +# re-exoprts +ConnDict = _conninfo_utils.ConnDict +conninfo_attempts = _conninfo_attempts.conninfo_attempts +conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async # Default timeout for connection a attempt. # Arbitrary timeout, what applied by the libpq on my computer. # Your mileage won't vary. _DEFAULT_CONNECT_TIMEOUT = 130 -logger = logging.getLogger("psycopg") - def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: """ @@ -127,149 +123,6 @@ def _param_escape(s: str) -> str: return s -def conninfo_attempts(params: ConnDict) -> list[ConnDict]: - """Split a set of connection params on the single attempts to perform. - - A connection param can perform more than one attempt more than one ``host`` - is provided. - - Because the libpq async function doesn't honour the timeout, we need to - reimplement the repeated attempts. - """ - # TODO: we should actually resolve the hosts ourselves. - # If an host resolves to more than one ip, the libpq will make more than - # one attempt and wouldn't get to try the following ones, as before - # fixing #674. - attempts = _split_attempts(params) - if _get_param(params, "load_balance_hosts") == "random": - shuffle(attempts) - return attempts - - -async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: - """Split a set of connection params on the single attempts to perform. - - A connection param can perform more than one attempt more than one ``host`` - is provided. - - Also perform async resolution of the hostname into hostaddr in order to - avoid blocking. Because a host can resolve to more than one address, this - can lead to yield more attempts too. Raise `OperationalError` if no host - could be resolved. - - Because the libpq async function doesn't honour the timeout, we need to - reimplement the repeated attempts. - """ - last_exc = None - attempts = [] - for attempt in _split_attempts(params): - try: - attempts.extend(await _resolve_hostnames(attempt)) - except OSError as ex: - logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) - last_exc = ex - - if not attempts: - assert last_exc - # We couldn't resolve anything - raise e.OperationalError(str(last_exc)) - - if _get_param(params, "load_balance_hosts") == "random": - shuffle(attempts) - - return attempts - - -def _split_attempts(params: ConnDict) -> list[ConnDict]: - """ - Split connection parameters with a sequence of hosts into separate attempts. - """ - - def split_val(key: str) -> list[str]: - val = _get_param(params, key) - return val.split(",") if val else [] - - hosts = split_val("host") - hostaddrs = split_val("hostaddr") - ports = split_val("port") - - if hosts and hostaddrs and len(hosts) != len(hostaddrs): - raise e.OperationalError( - f"could not match {len(hosts)} host names" - f" with {len(hostaddrs)} hostaddr values" - ) - - nhosts = max(len(hosts), len(hostaddrs)) - - if 1 < len(ports) != nhosts: - raise e.OperationalError( - f"could not match {len(ports)} port numbers to {len(hosts)} hosts" - ) - - # A single attempt to make. Don't mangle the conninfo string. - if nhosts <= 1: - return [params] - - if len(ports) == 1: - ports *= nhosts - - # Now all lists are either empty or have the same length - rv = [] - for i in range(nhosts): - attempt = params.copy() - if hosts: - attempt["host"] = hosts[i] - if hostaddrs: - attempt["hostaddr"] = hostaddrs[i] - if ports: - attempt["port"] = ports[i] - rv.append(attempt) - - return rv - - -async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: - """ - Perform async DNS lookup of the hosts and return a new params dict. - - If a ``host`` param is present but not ``hostname``, resolve the host - addresses asynchronously. - - :param params: The input parameters, for instance as returned by - `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most - a single entry for host, hostaddr because it is designed to further - process the input of _split_attempts(). - - :return: A list of attempts to make (to include the case of a hostname - resolving to more than one IP). - """ - host = _get_param(params, "host") - if not host or host.startswith("/") or host[1:2] == ":": - # Local path, or no host to resolve - return [params] - - hostaddr = _get_param(params, "hostaddr") - if hostaddr: - # Already resolved - return [params] - - if is_ip_address(host): - # If the host is already an ip address don't try to resolve it - return [{**params, "hostaddr": host}] - - loop = asyncio.get_running_loop() - - port = _get_param(params, "port") - if not port: - port_def = _get_param_def("port") - port = port_def and port_def.compiled or "5432" - - ans = await loop.getaddrinfo( - host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM - ) - return [{**params, "hostaddr": item[4][0]} for item in ans] - - def timeout_from_conninfo(params: ConnDict) -> int: """ Return the timeout in seconds from the connection parameters. @@ -281,7 +134,7 @@ def timeout_from_conninfo(params: ConnDict) -> int: # - at least 2 seconds. # # See connectDBComplete in fe-connect.c - value: str | int | None = _get_param(params, "connect_timeout") + value: str | int | None = _conninfo_utils.get_param(params, "connect_timeout") if value is None: value = _DEFAULT_CONNECT_TIMEOUT try: @@ -299,63 +152,3 @@ def timeout_from_conninfo(params: ConnDict) -> int: timeout = 2 return timeout - - -def _get_param(params: ConnDict, name: str) -> str | None: - """ - Return a value from a connection string. - - The value may be also specified in a PG* env var. - """ - if name in params: - return str(params[name]) - - # TODO: check if in service - - paramdef = _get_param_def(name) - if not paramdef: - return None - - env = os.environ.get(paramdef.envvar) - if env is not None: - return env - - return None - - -@dataclass -class ParamDef: - """ - Information about defaults and env vars for connection params - """ - - keyword: str - envvar: str - compiled: str | None - - -def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None: - """ - Return the ParamDef of a connection string parameter. - """ - if not _cache: - defs = pq.Conninfo.get_defaults() - for d in defs: - cd = ParamDef( - keyword=d.keyword.decode(), - envvar=d.envvar.decode() if d.envvar else "", - compiled=d.compiled.decode() if d.compiled is not None else None, - ) - _cache[cd.keyword] = cd - - return _cache.get(keyword) - - -@lru_cache() -def is_ip_address(s: str) -> bool: - """Return True if the string represent a valid ip address.""" - try: - ip_address(s) - except ValueError: - return False - return True diff --git a/tests/test_connection.py b/tests/test_connection.py index 9cfa7459f..8456dba45 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -430,17 +430,19 @@ def test_autocommit_unknown(conn): [ ((), {}, ""), (("",), {}, ""), - (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), - (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), + (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"), + (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"), ( ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, "dbname=qux user=joe port=5433", ), - (("dbname=foo",), {"user": None}, "dbname=foo"), + (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"), ], ) -def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want): +def test_connect_args( + conn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want +): got_conninfo: str def fake_connect(conninfo): @@ -857,3 +859,19 @@ def test_connect_context_copy(conn_cls, dsn, conn): def test_cancel_closed(conn): conn.close() conn.cancel() + + +def test_resolve_hostaddr_conn(conn_cls, monkeypatch, fake_resolve): + got = "" + + def fake_connect_gen(conninfo, **kwargs): + nonlocal got + got = conninfo + 1 / 0 + + monkeypatch.setattr(conn_cls, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + conn_cls.connect("host=foo.com") + + assert conninfo_to_dict(got) == {"host": "foo.com", "hostaddr": "1.1.1.1"} diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 754e57b38..2e950aff4 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -428,18 +428,18 @@ async def test_autocommit_unknown(aconn): [ ((), {}, ""), (("",), {}, ""), - (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), - (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), + (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"), + (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"), ( ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, "dbname=qux user=joe port=5433", ), - (("dbname=foo",), {"user": None}, "dbname=foo"), + (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"), ], ) async def test_connect_args( - aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want + aconn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want ): got_conninfo: str @@ -865,3 +865,19 @@ async def test_connect_context_copy(aconn_cls, dsn, aconn): async def test_cancel_closed(aconn): await aconn.close() aconn.cancel() + + +async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch, fake_resolve): + got = "" + + def fake_connect_gen(conninfo, **kwargs): + nonlocal got + got = conninfo + 1 / 0 + + monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + await aconn_cls.connect("host=foo.com") + + assert conninfo_to_dict(got) == {"host": "foo.com", "hostaddr": "1.1.1.1"} diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 8c0208d07..badd5d92c 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -1,9 +1,7 @@ import pytest -import psycopg from psycopg import ProgrammingError from psycopg.conninfo import make_conninfo, conninfo_to_dict -from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT snowman = "\u2603" @@ -89,246 +87,6 @@ def test_no_munging(): assert dsnin == dsnout -@pytest.mark.parametrize( - "conninfo, want, env", - [ - ("", [""], None), - ("service=foo", ["service=foo"], None), - ("host='' user=bar", ["host='' user=bar"], None), - ( - "host=127.0.0.1 user=bar", - ["host=127.0.0.1 user=bar"], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 user=bar", - ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 port=5432", - ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"], - None, - ), - ( - "host=1.1.1.1,1.1.1.1 port=5432,", - ["host=1.1.1.1 port=5432", "host=1.1.1.1 port=''"], - None, - ), - ( - "host=foo.com port=5432", - ["host=foo.com port=5432"], - {"PGHOSTADDR": "1.2.3.4"}, - ), - ], -) -@pytest.mark.anyio -def test_conninfo_attempts(setpgenv, conninfo, want, env): - setpgenv(env) - params = conninfo_to_dict(conninfo) - attempts = conninfo_attempts(params) - want = list(map(conninfo_to_dict, want)) - assert want == attempts - - -@pytest.mark.parametrize( - "conninfo, want, env", - [ - ("", [""], None), - ("host='' user=bar", ["host='' user=bar"], None), - ( - "host=127.0.0.1 user=bar port=''", - ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"], - None, - ), - ( - "host=127.0.0.1 user=bar", - ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 user=bar", - [ - "host=1.1.1.1 user=bar hostaddr=1.1.1.1", - "host=2.2.2.2 user=bar hostaddr=2.2.2.2", - ], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 port=5432", - [ - "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", - "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", - ], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 port=5432,", - [ - "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", - "host=2.2.2.2 port='' hostaddr=2.2.2.2", - ], - None, - ), - ( - "port=5432", - [ - "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", - "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", - ], - {"PGHOST": "1.1.1.1,2.2.2.2"}, - ), - ( - "host=foo.com port=5432", - ["host=foo.com port=5432"], - {"PGHOSTADDR": "1.2.3.4"}, - ), - ], -) -@pytest.mark.anyio -async def test_conninfo_attempts_async_no_resolve( - setpgenv, conninfo, want, env, fail_resolve -): - setpgenv(env) - params = conninfo_to_dict(conninfo) - attempts = await conninfo_attempts_async(params) - want = list(map(conninfo_to_dict, want)) - assert want == attempts - - -@pytest.mark.parametrize( - "conninfo, want, env", - [ - ( - "host=foo.com,qux.com", - ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], - None, - ), - ( - "host=foo.com,qux.com port=5433", - [ - "host=foo.com hostaddr=1.1.1.1 port=5433", - "host=qux.com hostaddr=2.2.2.2 port=5433", - ], - None, - ), - ( - "host=foo.com,qux.com port=5432,5433", - [ - "host=foo.com hostaddr=1.1.1.1 port=5432", - "host=qux.com hostaddr=2.2.2.2 port=5433", - ], - None, - ), - ( - "host=foo.com,foo.com port=5432,", - [ - "host=foo.com hostaddr=1.1.1.1 port=5432", - "host=foo.com hostaddr=1.1.1.1 port=''", - ], - None, - ), - ( - "host=foo.com,nosuchhost.com", - ["host=foo.com hostaddr=1.1.1.1"], - None, - ), - ( - "host=foo.com, port=5432,5433", - ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], - None, - ), - ( - "host=nosuchhost.com,foo.com", - ["host=foo.com hostaddr=1.1.1.1"], - None, - ), - ( - "host=foo.com,qux.com", - ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], - {}, - ), - ( - "host=dup.com", - ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], - None, - ), - ], -) -@pytest.mark.anyio -async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve): - params = conninfo_to_dict(conninfo) - attempts = await conninfo_attempts_async(params) - want = list(map(conninfo_to_dict, want)) - assert want == attempts - - -@pytest.mark.parametrize( - "conninfo, env", - [ - ("host=bad1.com,bad2.com", None), - ("host=foo.com port=1,2", None), - ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), - ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), - ], -) -@pytest.mark.anyio -async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve): - setpgenv(env) - params = conninfo_to_dict(conninfo) - with pytest.raises(psycopg.Error): - await conninfo_attempts_async(params) - - -@pytest.mark.parametrize( - "conninfo, env", - [ - ("host=foo.com port=1,2", None), - ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), - ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), - ], -) -@pytest.mark.anyio -def test_conninfo_attempts_bad(setpgenv, conninfo, env): - setpgenv(env) - params = conninfo_to_dict(conninfo) - with pytest.raises(psycopg.Error): - conninfo_attempts(params) - - -def test_conninfo_random(): - hosts = [f"host{n:02d}" for n in range(50)] - args = {"host": ",".join(hosts)} - ahosts = [att["host"] for att in conninfo_attempts(args)] - assert ahosts == hosts - - args["load_balance_hosts"] = "disable" - ahosts = [att["host"] for att in conninfo_attempts(args)] - assert ahosts == hosts - - args["load_balance_hosts"] = "random" - ahosts = [att["host"] for att in conninfo_attempts(args)] - assert ahosts != hosts - ahosts.sort() - assert ahosts == hosts - - -@pytest.mark.anyio -async def test_conninfo_random_async(fake_resolve): - args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] - assert len(hostaddrs) == 20 - assert hostaddrs == sorted(hostaddrs) - - args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] - assert hostaddrs == sorted(hostaddrs) - - args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] - assert hostaddrs != sorted(hostaddrs) - - @pytest.mark.parametrize( "conninfo, want, env", [ diff --git a/tests/test_conninfo_attempts.py b/tests/test_conninfo_attempts.py new file mode 100644 index 000000000..f7bd141d1 --- /dev/null +++ b/tests/test_conninfo_attempts.py @@ -0,0 +1,181 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict, conninfo_attempts + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", [""], None), + ("service=foo", ["service=foo"], None), + ("host='' user=bar", ["host='' user=bar"], None), + ( + "host=127.0.0.1 user=bar port=''", + ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"], + None, + ), + ( + "host=127.0.0.1 user=bar", + ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + [ + "host=1.1.1.1 user=bar hostaddr=1.1.1.1", + "host=2.2.2.2 user=bar hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432,", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port='' hostaddr=2.2.2.2", + ], + None, + ), + ( + "port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + ["host=foo.com port=5432"], + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +def test_conninfo_attempts_no_resolve(setpgenv, conninfo, want, env, fail_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + attempts = conninfo_attempts(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + None, + ), + ( + "host=foo.com,qux.com port=5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5433", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,foo.com port=5432,", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=foo.com hostaddr=1.1.1.1 port=''", + ], + None, + ), + ( + "host=foo.com,nosuchhost.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com, port=5432,5433", + ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], + None, + ), + ( + "host=nosuchhost.com,foo.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + {}, + ), + ( + "host=dup.com", + ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], + None, + ), + ], +) +def test_conninfo_attempts(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + attempts = conninfo_attempts(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + conninfo_attempts(params) + + +def test_conninfo_random_multi_host(): + hosts = [f"host{n:02d}" for n in range(50)] + args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))} + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "disable" + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "random" + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts != hosts + ahosts.sort() + assert ahosts == hosts + + +def test_conninfo_random_multi_ips(fake_resolve): + args = {"host": "alot.com"} + hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + assert len(hostaddrs) == 20 + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "disable" + hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "random" + hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_conninfo_attempts_async.py b/tests/test_conninfo_attempts_async.py new file mode 100644 index 000000000..bf6da880f --- /dev/null +++ b/tests/test_conninfo_attempts_async.py @@ -0,0 +1,185 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict, conninfo_attempts_async + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", [""], None), + ("service=foo", ["service=foo"], None), + ("host='' user=bar", ["host='' user=bar"], None), + ( + "host=127.0.0.1 user=bar port=''", + ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"], + None, + ), + ( + "host=127.0.0.1 user=bar", + ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + [ + "host=1.1.1.1 user=bar hostaddr=1.1.1.1", + "host=2.2.2.2 user=bar hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432,", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port='' hostaddr=2.2.2.2", + ], + None, + ), + ( + "port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + ["host=foo.com port=5432"], + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +async def test_conninfo_attempts_no_resolve( + setpgenv, conninfo, want, env, fail_resolve +): + setpgenv(env) + params = conninfo_to_dict(conninfo) + attempts = await conninfo_attempts_async(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + None, + ), + ( + "host=foo.com,qux.com port=5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5433", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,foo.com port=5432,", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=foo.com hostaddr=1.1.1.1 port=''", + ], + None, + ), + ( + "host=foo.com,nosuchhost.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com, port=5432,5433", + ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], + None, + ), + ( + "host=nosuchhost.com,foo.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + {}, + ), + ( + "host=dup.com", + ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], + None, + ), + ], +) +async def test_conninfo_attempts(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + attempts = await conninfo_attempts_async(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +async def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + await conninfo_attempts_async(params) + + +async def test_conninfo_random_multi_host(): + hosts = [f"host{n:02d}" for n in range(50)] + args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))} + ahosts = [att["host"] for att in await conninfo_attempts_async(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "disable" + ahosts = [att["host"] for att in await conninfo_attempts_async(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "random" + ahosts = [att["host"] for att in await conninfo_attempts_async(args)] + assert ahosts != hosts + ahosts.sort() + assert ahosts == hosts + + +async def test_conninfo_random_multi_ips(fake_resolve): + args = {"host": "alot.com"} + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + assert len(hostaddrs) == 20 + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "disable" + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "random" + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_dns.py b/tests/test_dns.py index b5dcbaaf4..ded4f8408 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -4,23 +4,6 @@ import psycopg from psycopg.conninfo import conninfo_to_dict -async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch, fake_resolve): - got = [] - - def fake_connect_gen(conninfo, **kwargs): - got.append(conninfo) - 1 / 0 - - monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen) - - with pytest.raises(ZeroDivisionError): - await aconn_cls.connect("host=foo.com") - - assert len(got) == 1 - want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert conninfo_to_dict(got[0]) == want - - @pytest.mark.dns @pytest.mark.anyio async def test_resolve_hostaddr_async_warning(recwarn): diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 2e429eac9..a89344974 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -122,17 +122,17 @@ def test_time_from_ticks(ticks, want): [ ((), {}, ""), (("",), {}, ""), - (("host=foo user=bar",), {}, "host=foo user=bar"), - (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"), + (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"), ( - ("host=foo port=5433",), - {"host": "qux", "user": "joe"}, - "host=qux user=joe port=5433", + ("host=foo.com port=5433",), + {"host": "qux.com", "user": "joe"}, + "host=qux.com user=joe port=5433 hostaddr=2.2.2.2", ), - (("host=foo",), {"user": None}, "host=foo"), + (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"), ], ) -def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv): +def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve): got_conninfo: str def fake_connect(conninfo): -- 2.47.3