From: Daniele Varrazzo Date: Sat, 6 Jan 2024 16:24:48 +0000 (+0100) Subject: fix: perform multiple attemps if a host name resolve to multiple hosts X-Git-Tag: 3.1.17~3^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5f77a654e5f4c03b41fa2779f390bfca309b50e5;p=thirdparty%2Fpsycopg.git fix: perform multiple attemps if a host name resolve to multiple hosts We now perform DNS resolution in Python both in the sync and async code. Because the two code paths are now very similar, make sure they pass the same tests. Close #699. --- diff --git a/docs/news.rst b/docs/news.rst index 5fd7c5128..31e3dcd67 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -13,6 +13,8 @@ Future releases Psycopg 3.1.17 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +- Fix multiple connection attempts when a host name resolve to multiple + IP addresses (:ticket:`699`). - Use `typing.Self` as a more correct return value annotation of context managers and other self-returning methods (see :ticket:`708`). diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py new file mode 100644 index 000000000..5262ab785 --- /dev/null +++ b/psycopg/psycopg/_conninfo_attempts.py @@ -0,0 +1,90 @@ +""" +Separate connection attempts from a connection string. +""" + +# Copyright (C) 2024 The Psycopg Team + +from __future__ import annotations + +import socket +import logging +from random import shuffle + +from . import errors as e +from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from ._conninfo_utils import split_attempts + +logger = logging.getLogger("psycopg") + + +def conninfo_attempts(params: ConnDict) -> list[ConnDict]: + """Split a set of connection params on the single attempts to perform. + + A connection param can perform more than one attempt more than one ``host`` + is provided. + + Also perform async resolution of the hostname into hostaddr. Because a host + can resolve to more than one address, this can lead to yield more attempts + too. Raise `OperationalError` if no host could be resolved. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + last_exc = None + attempts = [] + for attempt in split_attempts(params): + try: + attempts.extend(_resolve_hostnames(attempt)) + except OSError as ex: + logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) + last_exc = ex + + if not attempts: + assert last_exc + # We couldn't resolve anything + raise e.OperationalError(str(last_exc)) + + if get_param(params, "load_balance_hosts") == "random": + shuffle(attempts) + + return attempts + + +def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: + """ + Perform DNS lookup of the hosts and return a list of connection attempts. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses asynchronously. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most + a single entry for host, hostaddr because it is designed to further + process the input of split_attempts(). + + :return: A list of attempts to make (to include the case of a hostname + resolving to more than one IP). + """ + host = get_param(params, "host") + if not host or host.startswith("/") or host[1:2] == ":": + # Local path, or no host to resolve + return [params] + + hostaddr = get_param(params, "hostaddr") + if hostaddr: + # Already resolved + return [params] + + if is_ip_address(host): + # If the host is already an ip address don't try to resolve it + return [{**params, "hostaddr": host}] + + port = get_param(params, "port") + if not port: + port_def = get_param_def("port") + port = port_def and port_def.compiled or "5432" + + ans = socket.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + return [{**params, "hostaddr": item[4][0]} for item in ans] diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py new file mode 100644 index 000000000..f0f457250 --- /dev/null +++ b/psycopg/psycopg/_conninfo_attempts_async.py @@ -0,0 +1,92 @@ +""" +Separate connection attempts from a connection string. +""" + +# Copyright (C) 2024 The Psycopg Team + +from __future__ import annotations + +import socket +import asyncio +import logging +from random import shuffle + +from . import errors as e +from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def +from ._conninfo_utils import split_attempts + +logger = logging.getLogger("psycopg") + + +async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: + """Split a set of connection params on the single attempts to perform. + + A connection param can perform more than one attempt more than one ``host`` + is provided. + + Also perform async resolution of the hostname into hostaddr. Because a host + can resolve to more than one address, this can lead to yield more attempts + too. Raise `OperationalError` if no host could be resolved. + + Because the libpq async function doesn't honour the timeout, we need to + reimplement the repeated attempts. + """ + last_exc = None + attempts = [] + for attempt in split_attempts(params): + try: + attempts.extend(await _resolve_hostnames(attempt)) + except OSError as ex: + logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) + last_exc = ex + + if not attempts: + assert last_exc + # We couldn't resolve anything + raise e.OperationalError(str(last_exc)) + + if get_param(params, "load_balance_hosts") == "random": + shuffle(attempts) + + return attempts + + +async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: + """ + Perform async DNS lookup of the hosts and return a list of connection attempts. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses asynchronously. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most + a single entry for host, hostaddr because it is designed to further + process the input of split_attempts(). + + :return: A list of attempts to make (to include the case of a hostname + resolving to more than one IP). + """ + host = get_param(params, "host") + if not host or host.startswith("/") or host[1:2] == ":": + # Local path, or no host to resolve + return [params] + + hostaddr = get_param(params, "hostaddr") + if hostaddr: + # Already resolved + return [params] + + if is_ip_address(host): + # If the host is already an ip address don't try to resolve it + return [{**params, "hostaddr": host}] + + port = get_param(params, "port") + if not port: + port_def = get_param_def("port") + port = port_def and port_def.compiled or "5432" + + loop = asyncio.get_running_loop() + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + return [{**params, "hostaddr": item[4][0]} for item in ans] diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py new file mode 100644 index 000000000..8940c937b --- /dev/null +++ b/psycopg/psycopg/_conninfo_utils.py @@ -0,0 +1,127 @@ +""" +Internal utilities to manipulate connection strings +""" + +# Copyright (C) 2024 The Psycopg Team + +from __future__ import annotations + +import os +from typing import Any +from functools import lru_cache +from ipaddress import ip_address +from dataclasses import dataclass +from typing_extensions import TypeAlias + +from . import pq +from . import errors as e + +ConnDict: TypeAlias = "dict[str, Any]" + + +def split_attempts(params: ConnDict) -> list[ConnDict]: + """ + Split connection parameters with a sequence of hosts into separate attempts. + """ + + def split_val(key: str) -> list[str]: + val = get_param(params, key) + return val.split(",") if val else [] + + hosts = split_val("host") + hostaddrs = split_val("hostaddr") + ports = split_val("port") + + if hosts and hostaddrs and len(hosts) != len(hostaddrs): + raise e.OperationalError( + f"could not match {len(hosts)} host names" + f" with {len(hostaddrs)} hostaddr values" + ) + + nhosts = max(len(hosts), len(hostaddrs)) + + if 1 < len(ports) != nhosts: + raise e.OperationalError( + f"could not match {len(ports)} port numbers to {len(hosts)} hosts" + ) + + # A single attempt to make. Don't mangle the conninfo string. + if nhosts <= 1: + return [params] + + if len(ports) == 1: + ports *= nhosts + + # Now all lists are either empty or have the same length + rv = [] + for i in range(nhosts): + attempt = params.copy() + if hosts: + attempt["host"] = hosts[i] + if hostaddrs: + attempt["hostaddr"] = hostaddrs[i] + if ports: + attempt["port"] = ports[i] + rv.append(attempt) + + return rv + + +def get_param(params: ConnDict, name: str) -> str | None: + """ + Return a value from a connection string. + + The value may be also specified in a PG* env var. + """ + if name in params: + return str(params[name]) + + # TODO: check if in service + + paramdef = get_param_def(name) + if not paramdef: + return None + + env = os.environ.get(paramdef.envvar) + if env is not None: + return env + + return None + + +@dataclass +class ParamDef: + """ + Information about defaults and env vars for connection params + """ + + keyword: str + envvar: str + compiled: str | None + + +def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None: + """ + Return the ParamDef of a connection string parameter. + """ + if not _cache: + defs = pq.Conninfo.get_defaults() + for d in defs: + cd = ParamDef( + keyword=d.keyword.decode(), + envvar=d.envvar.decode() if d.envvar else "", + compiled=d.compiled.decode() if d.compiled is not None else None, + ) + _cache[cd.keyword] = cd + + return _cache.get(keyword) + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 9044550cb..82da58822 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -6,30 +6,26 @@ Functions to manipulate conninfo strings from __future__ import annotations -import os import re -import socket -import asyncio -import logging from typing import Any -from random import shuffle -from functools import lru_cache -from ipaddress import ip_address -from dataclasses import dataclass -from typing_extensions import TypeAlias from . import pq from . import errors as e -ConnDict: TypeAlias = "dict[str, Any]" +from . import _conninfo_utils +from . import _conninfo_attempts +from . import _conninfo_attempts_async + +# re-exoprts +ConnDict = _conninfo_utils.ConnDict +conninfo_attempts = _conninfo_attempts.conninfo_attempts +conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async # Default timeout for connection a attempt. # Arbitrary timeout, what applied by the libpq on my computer. # Your mileage won't vary. _DEFAULT_CONNECT_TIMEOUT = 130 -logger = logging.getLogger("psycopg") - def make_conninfo(conninfo: str = "", **kwargs: Any) -> str: """ @@ -127,149 +123,6 @@ def _param_escape(s: str) -> str: return s -def conninfo_attempts(params: ConnDict) -> list[ConnDict]: - """Split a set of connection params on the single attempts to perform. - - A connection param can perform more than one attempt more than one ``host`` - is provided. - - Because the libpq async function doesn't honour the timeout, we need to - reimplement the repeated attempts. - """ - # TODO: we should actually resolve the hosts ourselves. - # If an host resolves to more than one ip, the libpq will make more than - # one attempt and wouldn't get to try the following ones, as before - # fixing #674. - attempts = _split_attempts(params) - if _get_param(params, "load_balance_hosts") == "random": - shuffle(attempts) - return attempts - - -async def conninfo_attempts_async(params: ConnDict) -> list[ConnDict]: - """Split a set of connection params on the single attempts to perform. - - A connection param can perform more than one attempt more than one ``host`` - is provided. - - Also perform async resolution of the hostname into hostaddr in order to - avoid blocking. Because a host can resolve to more than one address, this - can lead to yield more attempts too. Raise `OperationalError` if no host - could be resolved. - - Because the libpq async function doesn't honour the timeout, we need to - reimplement the repeated attempts. - """ - last_exc = None - attempts = [] - for attempt in _split_attempts(params): - try: - attempts.extend(await _resolve_hostnames(attempt)) - except OSError as ex: - logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) - last_exc = ex - - if not attempts: - assert last_exc - # We couldn't resolve anything - raise e.OperationalError(str(last_exc)) - - if _get_param(params, "load_balance_hosts") == "random": - shuffle(attempts) - - return attempts - - -def _split_attempts(params: ConnDict) -> list[ConnDict]: - """ - Split connection parameters with a sequence of hosts into separate attempts. - """ - - def split_val(key: str) -> list[str]: - val = _get_param(params, key) - return val.split(",") if val else [] - - hosts = split_val("host") - hostaddrs = split_val("hostaddr") - ports = split_val("port") - - if hosts and hostaddrs and len(hosts) != len(hostaddrs): - raise e.OperationalError( - f"could not match {len(hosts)} host names" - f" with {len(hostaddrs)} hostaddr values" - ) - - nhosts = max(len(hosts), len(hostaddrs)) - - if 1 < len(ports) != nhosts: - raise e.OperationalError( - f"could not match {len(ports)} port numbers to {len(hosts)} hosts" - ) - - # A single attempt to make. Don't mangle the conninfo string. - if nhosts <= 1: - return [params] - - if len(ports) == 1: - ports *= nhosts - - # Now all lists are either empty or have the same length - rv = [] - for i in range(nhosts): - attempt = params.copy() - if hosts: - attempt["host"] = hosts[i] - if hostaddrs: - attempt["hostaddr"] = hostaddrs[i] - if ports: - attempt["port"] = ports[i] - rv.append(attempt) - - return rv - - -async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: - """ - Perform async DNS lookup of the hosts and return a new params dict. - - If a ``host`` param is present but not ``hostname``, resolve the host - addresses asynchronously. - - :param params: The input parameters, for instance as returned by - `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most - a single entry for host, hostaddr because it is designed to further - process the input of _split_attempts(). - - :return: A list of attempts to make (to include the case of a hostname - resolving to more than one IP). - """ - host = _get_param(params, "host") - if not host or host.startswith("/") or host[1:2] == ":": - # Local path, or no host to resolve - return [params] - - hostaddr = _get_param(params, "hostaddr") - if hostaddr: - # Already resolved - return [params] - - if is_ip_address(host): - # If the host is already an ip address don't try to resolve it - return [{**params, "hostaddr": host}] - - loop = asyncio.get_running_loop() - - port = _get_param(params, "port") - if not port: - port_def = _get_param_def("port") - port = port_def and port_def.compiled or "5432" - - ans = await loop.getaddrinfo( - host, int(port), proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM - ) - return [{**params, "hostaddr": item[4][0]} for item in ans] - - def timeout_from_conninfo(params: ConnDict) -> int: """ Return the timeout in seconds from the connection parameters. @@ -281,7 +134,7 @@ def timeout_from_conninfo(params: ConnDict) -> int: # - at least 2 seconds. # # See connectDBComplete in fe-connect.c - value: str | int | None = _get_param(params, "connect_timeout") + value: str | int | None = _conninfo_utils.get_param(params, "connect_timeout") if value is None: value = _DEFAULT_CONNECT_TIMEOUT try: @@ -299,63 +152,3 @@ def timeout_from_conninfo(params: ConnDict) -> int: timeout = 2 return timeout - - -def _get_param(params: ConnDict, name: str) -> str | None: - """ - Return a value from a connection string. - - The value may be also specified in a PG* env var. - """ - if name in params: - return str(params[name]) - - # TODO: check if in service - - paramdef = _get_param_def(name) - if not paramdef: - return None - - env = os.environ.get(paramdef.envvar) - if env is not None: - return env - - return None - - -@dataclass -class ParamDef: - """ - Information about defaults and env vars for connection params - """ - - keyword: str - envvar: str - compiled: str | None - - -def _get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None: - """ - Return the ParamDef of a connection string parameter. - """ - if not _cache: - defs = pq.Conninfo.get_defaults() - for d in defs: - cd = ParamDef( - keyword=d.keyword.decode(), - envvar=d.envvar.decode() if d.envvar else "", - compiled=d.compiled.decode() if d.compiled is not None else None, - ) - _cache[cd.keyword] = cd - - return _cache.get(keyword) - - -@lru_cache() -def is_ip_address(s: str) -> bool: - """Return True if the string represent a valid ip address.""" - try: - ip_address(s) - except ValueError: - return False - return True diff --git a/tests/test_connection.py b/tests/test_connection.py index 0ee8aea31..072d467cd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -393,17 +393,19 @@ def test_autocommit_unknown(conn): [ ((), {}, ""), (("",), {}, ""), - (("host=foo user=bar",), {}, "host=foo user=bar"), - (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"), + (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"), ( ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, "dbname=qux user=joe port=5433", ), - (("host=foo",), {"user": None}, "host=foo"), + (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"), ], ) -def test_connect_args(conn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want): +def test_connect_args( + conn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want +): got_conninfo: str def fake_connect(conninfo): @@ -861,3 +863,20 @@ def test_connect_context_copy(conn_cls, dsn, conn): def test_cancel_closed(conn): conn.close() conn.cancel() + + +def test_resolve_hostaddr_conn(conn_cls, monkeypatch, fake_resolve): + got = [] + + def fake_connect_gen(conninfo, **kwargs): + got.append(conninfo) + 1 / 0 + + monkeypatch.setattr(conn_cls, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + conn_cls.connect("host=foo.com") + + assert len(got) == 1 + want = {"host": "foo.com", "hostaddr": "1.1.1.1"} + assert conninfo_to_dict(got[0]) == want diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index ced5db583..1cf040349 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -398,18 +398,18 @@ async def test_autocommit_unknown(aconn): [ ((), {}, ""), (("",), {}, ""), - (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), - (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), + (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"), + (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"), ( ("dbname=foo port=5433",), {"dbname": "qux", "user": "joe"}, "dbname=qux user=joe port=5433", ), - (("dbname=foo",), {"user": None}, "dbname=foo"), + (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"), ], ) async def test_connect_args( - aconn_cls, monkeypatch, setpgenv, pgconn, args, kwargs, want + aconn_cls, monkeypatch, setpgenv, pgconn, fake_resolve, args, kwargs, want ): got_conninfo: str @@ -803,17 +803,17 @@ async def test_cancel_closed(aconn): aconn.cancel() -async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): +async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch, fake_resolve): got = [] def fake_connect_gen(conninfo, **kwargs): got.append(conninfo) 1 / 0 - monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen) + monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen) with pytest.raises(ZeroDivisionError): - await psycopg.AsyncConnection.connect("host=foo.com") + await aconn_cls.connect("host=foo.com") assert len(got) == 1 want = {"host": "foo.com", "hostaddr": "1.1.1.1"} diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 2e8c44822..badd5d92c 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -1,12 +1,7 @@ -import socket -import asyncio - import pytest -import psycopg from psycopg import ProgrammingError from psycopg.conninfo import make_conninfo, conninfo_to_dict -from psycopg.conninfo import conninfo_attempts, conninfo_attempts_async from psycopg.conninfo import timeout_from_conninfo, _DEFAULT_CONNECT_TIMEOUT snowman = "\u2603" @@ -92,246 +87,6 @@ def test_no_munging(): assert dsnin == dsnout -@pytest.mark.parametrize( - "conninfo, want, env", - [ - ("", [""], None), - ("service=foo", ["service=foo"], None), - ("host='' user=bar", ["host='' user=bar"], None), - ( - "host=127.0.0.1 user=bar", - ["host=127.0.0.1 user=bar"], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 user=bar", - ["host=1.1.1.1 user=bar", "host=2.2.2.2 user=bar"], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 port=5432", - ["host=1.1.1.1 port=5432", "host=2.2.2.2 port=5432"], - None, - ), - ( - "host=1.1.1.1,1.1.1.1 port=5432,", - ["host=1.1.1.1 port=5432", "host=1.1.1.1 port=''"], - None, - ), - ( - "host=foo.com port=5432", - ["host=foo.com port=5432"], - {"PGHOSTADDR": "1.2.3.4"}, - ), - ], -) -@pytest.mark.anyio -def test_conninfo_attempts(setpgenv, conninfo, want, env): - setpgenv(env) - params = conninfo_to_dict(conninfo) - attempts = conninfo_attempts(params) - want = list(map(conninfo_to_dict, want)) - assert want == attempts - - -@pytest.mark.parametrize( - "conninfo, want, env", - [ - ("", [""], None), - ("host='' user=bar", ["host='' user=bar"], None), - ( - "host=127.0.0.1 user=bar port=''", - ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"], - None, - ), - ( - "host=127.0.0.1 user=bar", - ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 user=bar", - [ - "host=1.1.1.1 user=bar hostaddr=1.1.1.1", - "host=2.2.2.2 user=bar hostaddr=2.2.2.2", - ], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 port=5432", - [ - "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", - "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", - ], - None, - ), - ( - "host=1.1.1.1,2.2.2.2 port=5432,", - [ - "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", - "host=2.2.2.2 port='' hostaddr=2.2.2.2", - ], - None, - ), - ( - "port=5432", - [ - "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", - "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", - ], - {"PGHOST": "1.1.1.1,2.2.2.2"}, - ), - ( - "host=foo.com port=5432", - ["host=foo.com port=5432"], - {"PGHOSTADDR": "1.2.3.4"}, - ), - ], -) -@pytest.mark.anyio -async def test_conninfo_attempts_async_no_resolve( - setpgenv, conninfo, want, env, fail_resolve -): - setpgenv(env) - params = conninfo_to_dict(conninfo) - attempts = await conninfo_attempts_async(params) - want = list(map(conninfo_to_dict, want)) - assert want == attempts - - -@pytest.mark.parametrize( - "conninfo, want, env", - [ - ( - "host=foo.com,qux.com", - ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], - None, - ), - ( - "host=foo.com,qux.com port=5433", - [ - "host=foo.com hostaddr=1.1.1.1 port=5433", - "host=qux.com hostaddr=2.2.2.2 port=5433", - ], - None, - ), - ( - "host=foo.com,qux.com port=5432,5433", - [ - "host=foo.com hostaddr=1.1.1.1 port=5432", - "host=qux.com hostaddr=2.2.2.2 port=5433", - ], - None, - ), - ( - "host=foo.com,foo.com port=5432,", - [ - "host=foo.com hostaddr=1.1.1.1 port=5432", - "host=foo.com hostaddr=1.1.1.1 port=''", - ], - None, - ), - ( - "host=foo.com,nosuchhost.com", - ["host=foo.com hostaddr=1.1.1.1"], - None, - ), - ( - "host=foo.com, port=5432,5433", - ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], - None, - ), - ( - "host=nosuchhost.com,foo.com", - ["host=foo.com hostaddr=1.1.1.1"], - None, - ), - ( - "host=foo.com,qux.com", - ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], - {}, - ), - ( - "host=dup.com", - ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], - None, - ), - ], -) -@pytest.mark.anyio -async def test_conninfo_attempts_async(conninfo, want, env, fake_resolve): - params = conninfo_to_dict(conninfo) - attempts = await conninfo_attempts_async(params) - want = list(map(conninfo_to_dict, want)) - assert want == attempts - - -@pytest.mark.parametrize( - "conninfo, env", - [ - ("host=bad1.com,bad2.com", None), - ("host=foo.com port=1,2", None), - ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), - ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), - ], -) -@pytest.mark.anyio -async def test_conninfo_attempts_async_bad(setpgenv, conninfo, env, fake_resolve): - setpgenv(env) - params = conninfo_to_dict(conninfo) - with pytest.raises(psycopg.Error): - await conninfo_attempts_async(params) - - -@pytest.mark.parametrize( - "conninfo, env", - [ - ("host=foo.com port=1,2", None), - ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), - ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), - ], -) -@pytest.mark.anyio -def test_conninfo_attempts_bad(setpgenv, conninfo, env): - setpgenv(env) - params = conninfo_to_dict(conninfo) - with pytest.raises(psycopg.Error): - conninfo_attempts(params) - - -def test_conninfo_random(): - hosts = [f"host{n:02d}" for n in range(50)] - args = {"host": ",".join(hosts)} - ahosts = [att["host"] for att in conninfo_attempts(args)] - assert ahosts == hosts - - args["load_balance_hosts"] = "disable" - ahosts = [att["host"] for att in conninfo_attempts(args)] - assert ahosts == hosts - - args["load_balance_hosts"] = "random" - ahosts = [att["host"] for att in conninfo_attempts(args)] - assert ahosts != hosts - ahosts.sort() - assert ahosts == hosts - - -@pytest.mark.anyio -async def test_conninfo_random_async(fake_resolve): - args = {"host": "alot.com"} - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] - assert len(hostaddrs) == 20 - assert hostaddrs == sorted(hostaddrs) - - args["load_balance_hosts"] = "disable" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] - assert hostaddrs == sorted(hostaddrs) - - args["load_balance_hosts"] = "random" - hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] - assert hostaddrs != sorted(hostaddrs) - - @pytest.mark.parametrize( "conninfo, want, env", [ diff --git a/tests/test_conninfo_attempts.py b/tests/test_conninfo_attempts.py new file mode 100644 index 000000000..f7bd141d1 --- /dev/null +++ b/tests/test_conninfo_attempts.py @@ -0,0 +1,181 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict, conninfo_attempts + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", [""], None), + ("service=foo", ["service=foo"], None), + ("host='' user=bar", ["host='' user=bar"], None), + ( + "host=127.0.0.1 user=bar port=''", + ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"], + None, + ), + ( + "host=127.0.0.1 user=bar", + ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + [ + "host=1.1.1.1 user=bar hostaddr=1.1.1.1", + "host=2.2.2.2 user=bar hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432,", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port='' hostaddr=2.2.2.2", + ], + None, + ), + ( + "port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + ["host=foo.com port=5432"], + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +def test_conninfo_attempts_no_resolve(setpgenv, conninfo, want, env, fail_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + attempts = conninfo_attempts(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + None, + ), + ( + "host=foo.com,qux.com port=5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5433", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,foo.com port=5432,", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=foo.com hostaddr=1.1.1.1 port=''", + ], + None, + ), + ( + "host=foo.com,nosuchhost.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com, port=5432,5433", + ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], + None, + ), + ( + "host=nosuchhost.com,foo.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + {}, + ), + ( + "host=dup.com", + ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], + None, + ), + ], +) +def test_conninfo_attempts(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + attempts = conninfo_attempts(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + conninfo_attempts(params) + + +def test_conninfo_random_multi_host(): + hosts = [f"host{n:02d}" for n in range(50)] + args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))} + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "disable" + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "random" + ahosts = [att["host"] for att in conninfo_attempts(args)] + assert ahosts != hosts + ahosts.sort() + assert ahosts == hosts + + +def test_conninfo_random_multi_ips(fake_resolve): + args = {"host": "alot.com"} + hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + assert len(hostaddrs) == 20 + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "disable" + hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "random" + hostaddrs = [att["hostaddr"] for att in conninfo_attempts(args)] + assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_conninfo_attempts_async.py b/tests/test_conninfo_attempts_async.py new file mode 100644 index 000000000..bf6da880f --- /dev/null +++ b/tests/test_conninfo_attempts_async.py @@ -0,0 +1,185 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict, conninfo_attempts_async + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", [""], None), + ("service=foo", ["service=foo"], None), + ("host='' user=bar", ["host='' user=bar"], None), + ( + "host=127.0.0.1 user=bar port=''", + ["host=127.0.0.1 user=bar port='' hostaddr=127.0.0.1"], + None, + ), + ( + "host=127.0.0.1 user=bar", + ["host=127.0.0.1 user=bar hostaddr=127.0.0.1"], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + [ + "host=1.1.1.1 user=bar hostaddr=1.1.1.1", + "host=2.2.2.2 user=bar hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432,", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port='' hostaddr=2.2.2.2", + ], + None, + ), + ( + "port=5432", + [ + "host=1.1.1.1 port=5432 hostaddr=1.1.1.1", + "host=2.2.2.2 port=5432 hostaddr=2.2.2.2", + ], + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + ["host=foo.com port=5432"], + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +async def test_conninfo_attempts_no_resolve( + setpgenv, conninfo, want, env, fail_resolve +): + setpgenv(env) + params = conninfo_to_dict(conninfo) + attempts = await conninfo_attempts_async(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + None, + ), + ( + "host=foo.com,qux.com port=5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5433", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=qux.com hostaddr=2.2.2.2 port=5433", + ], + None, + ), + ( + "host=foo.com,foo.com port=5432,", + [ + "host=foo.com hostaddr=1.1.1.1 port=5432", + "host=foo.com hostaddr=1.1.1.1 port=''", + ], + None, + ), + ( + "host=foo.com,nosuchhost.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com, port=5432,5433", + ["host=foo.com hostaddr=1.1.1.1 port=5432", "host='' port=5433"], + None, + ), + ( + "host=nosuchhost.com,foo.com", + ["host=foo.com hostaddr=1.1.1.1"], + None, + ), + ( + "host=foo.com,qux.com", + ["host=foo.com hostaddr=1.1.1.1", "host=qux.com hostaddr=2.2.2.2"], + {}, + ), + ( + "host=dup.com", + ["host=dup.com hostaddr=3.3.3.3", "host=dup.com hostaddr=3.3.3.4"], + None, + ), + ], +) +async def test_conninfo_attempts(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + attempts = await conninfo_attempts_async(params) + want = list(map(conninfo_to_dict, want)) + assert want == attempts + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +async def test_conninfo_attempts_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + await conninfo_attempts_async(params) + + +async def test_conninfo_random_multi_host(): + hosts = [f"host{n:02d}" for n in range(50)] + args = {"host": ",".join(hosts), "hostaddr": ",".join(["127.0.0.1"] * len(hosts))} + ahosts = [att["host"] for att in await conninfo_attempts_async(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "disable" + ahosts = [att["host"] for att in await conninfo_attempts_async(args)] + assert ahosts == hosts + + args["load_balance_hosts"] = "random" + ahosts = [att["host"] for att in await conninfo_attempts_async(args)] + assert ahosts != hosts + ahosts.sort() + assert ahosts == hosts + + +async def test_conninfo_random_multi_ips(fake_resolve): + args = {"host": "alot.com"} + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + assert len(hostaddrs) == 20 + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "disable" + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + assert hostaddrs == sorted(hostaddrs) + + args["load_balance_hosts"] = "random" + hostaddrs = [att["hostaddr"] for att in await conninfo_attempts_async(args)] + assert hostaddrs != sorted(hostaddrs) diff --git a/tests/test_dns.py b/tests/test_dns.py index 7ea89ad68..ded4f8408 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -4,24 +4,6 @@ import psycopg from psycopg.conninfo import conninfo_to_dict -@pytest.mark.usefixtures("fake_resolve") -async def test_resolve_hostaddr_conn(aconn_cls, monkeypatch): - got = [] - - def fake_connect_gen(conninfo, **kwargs): - got.append(conninfo) - 1 / 0 - - monkeypatch.setattr(aconn_cls, "_connect_gen", fake_connect_gen) - - with pytest.raises(ZeroDivisionError): - await aconn_cls.connect("host=foo.com") - - assert len(got) == 1 - want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert conninfo_to_dict(got[0]) == want - - @pytest.mark.dns @pytest.mark.anyio async def test_resolve_hostaddr_async_warning(recwarn): diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index 3c4ae3ac5..ffd18a978 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -122,17 +122,17 @@ def test_time_from_ticks(ticks, want): [ ((), {}, ""), (("",), {}, ""), - (("host=foo user=bar",), {}, "host=foo user=bar"), - (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + (("host=foo.com user=bar",), {}, "host=foo.com user=bar hostaddr=1.1.1.1"), + (("host=foo.com",), {"user": "baz"}, "host=foo.com user=baz hostaddr=1.1.1.1"), ( - ("host=foo port=5433",), - {"host": "qux", "user": "joe"}, - "host=qux user=joe port=5433", + ("host=foo.com port=5433",), + {"host": "qux.com", "user": "joe"}, + "host=qux.com user=joe port=5433 hostaddr=2.2.2.2", ), - (("host=foo",), {"user": None}, "host=foo"), + (("host=foo.com",), {"user": None}, "host=foo.com hostaddr=1.1.1.1"), ], ) -def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv): +def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve): got_conninfo: str def fake_connect(conninfo):