From: Daniele Varrazzo Date: Sun, 3 Jul 2022 00:25:19 +0000 (+0100) Subject: feat: don't block on address resolution in async connections X-Git-Tag: 3.1~54^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=96c7b0421baab0466bd8d0456b4498aec9c541f1;p=thirdparty%2Fpsycopg.git feat: don't block on address resolution in async connections Use the same algorithm implemented in the `_dns` module, but based on asyncio getaddrinfo: this avoids the need of the external dns package and works correctly with the /etc/hosts file. There were problems of resource leaking in Python 3.6, but as psycopg 3.1 is 3.7+ only, let's go for it! --- diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index 710113843..cab164012 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -10,8 +10,6 @@ import re from random import randint from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence from typing import TYPE_CHECKING -from functools import lru_cache -from ipaddress import ip_address from collections import defaultdict try: @@ -24,6 +22,7 @@ except ImportError: ) from . 
import errors as e +from .conninfo import is_ip_address if TYPE_CHECKING: from dns.rdtypes.IN.SRV import SRV @@ -327,13 +326,3 @@ class Rfc2782Resolver: del entries[i] return out - - -@lru_cache() -def is_ip_address(s: str) -> bool: - """Return True if the string represent a valid ip address.""" - try: - ip_address(s) - except ValueError: - return False - return True diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 301ffac9c..8c6571d0c 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -20,7 +20,7 @@ from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .conninfo import make_conninfo, conninfo_to_dict +from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify @@ -181,6 +181,9 @@ class AsyncConnection(BaseConnection[Row]): else: params["connect_timeout"] = None + # Resolve host addresses in non-blocking way + params = await resolve_hostaddr_async(params) + return params async def close(self) -> None: diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py index 767952bca..1bc85ad5c 100644 --- a/psycopg/psycopg/conninfo.py +++ b/psycopg/psycopg/conninfo.py @@ -4,10 +4,15 @@ Functions to manipulate conninfo strings # Copyright (C) 2020 The Psycopg Team +import os import re +import socket +import asyncio from typing import Any, Dict, List, Optional from pathlib import Path from datetime import tzinfo +from functools import lru_cache +from ipaddress import ip_address from . import pq from . 
import errors as e @@ -263,3 +268,106 @@ class ConnectionInfo: def _get_pgconn_attr(self, name: str) -> str: value: bytes = getattr(self.pgconn, name) return value.decode(self.encoding) + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + :param params: The input parameters, for instance as returned by + `~psycopg.conninfo.conninfo_to_dict()`. + + If a ``host`` param is present but not ``hostaddr``, resolve the host + addresses dynamically. + + The function may change the input ``host``, ``hostaddr``, ``port`` to allow + connecting without further DNS lookups, possibly removing hosts that + cannot be resolved, keeping the lists of hosts and ports consistent. + + Raise `~psycopg.OperationalError` if connection is not possible (e.g. no + host can be resolved, inconsistent list lengths). + """ + hostaddr_arg = params.get("hostaddr", os.environ.get("PGHOSTADDR", "")) + if hostaddr_arg: + # Already resolved + return params + + host_arg: str = params.get("host", os.environ.get("PGHOST", "")) + if not host_arg: + # Nothing to resolve + return params + + hosts_in = host_arg.split(",") + port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) + ports_in = port_arg.split(",") + default_port = "5432" + + if len(ports_in) == 1: + # If only one port is specified, the libpq will apply it to all + # the hosts, so don't mangle it. + default_port = ports_in.pop() + + elif len(ports_in) > 1: + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what is raised if the libpq fails to connect in the same case. 
+ raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers" + ) + ports_out = [] + + hosts_out = [] + hostaddr_out = [] + loop = asyncio.get_running_loop() + for i, host in enumerate(hosts_in): + if not host or host.startswith("/") or host[1:2] == ":": + # Local path + hosts_out.append(host) + hostaddr_out.append("") + if ports_in: + ports_out.append(ports_in[i]) + continue + + # If the host is already an ip address don't try to resolve it + if is_ip_address(host): + hosts_out.append(host) + hostaddr_out.append(host) + if ports_in: + ports_out.append(ports_in[i]) + continue + + try: + port = ports_in[i] if ports_in else default_port + ans = await loop.getaddrinfo( + host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM + ) + except OSError as ex: + last_exc = ex + else: + for item in ans: + hosts_out.append(host) + hostaddr_out.append(item[4][0]) + if ports_in: + ports_out.append(ports_in[i]) + + # Throw an exception if no host could be resolved + if not hosts_out: + raise e.OperationalError(str(last_exc)) + + out = params.copy() + out["host"] = ",".join(hosts_out) + out["hostaddr"] = ",".join(hostaddr_out) + if ports_in: + out["port"] = ",".join(ports_out) + + return out + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/tests/test_connection.py b/tests/test_connection.py index 58c655634..9922e4adf 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -706,13 +706,13 @@ def test_set_transaction_param_strange(conn): conninfo_params_timeout = [ ( "", - {"host": "localhost", "connect_timeout": None}, - ({"host": "localhost"}, None), + {"dbname": "mydb", "connect_timeout": None}, + ({"dbname": "mydb"}, None), ), ( "", - {"host": "localhost", "connect_timeout": 1}, - ({"host": "localhost", "connect_timeout": "1"}, 1), + {"dbname": "mydb", 
"connect_timeout": 1}, + ({"dbname": "mydb", "connect_timeout": "1"}, 1), ), ( "dbname=postgres", diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 93ee762f5..a0b4b9935 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -12,6 +12,7 @@ from .utils import gc_collect from .test_cursor import my_row_factory from .test_connection import tx_params, tx_values_map, conninfo_params_timeout from .test_adapt import make_bin_dumper, make_dumper +from .test_conninfo import fake_resolve # noqa: F401 pytestmark = pytest.mark.asyncio @@ -350,14 +351,14 @@ async def test_autocommit_unknown(aconn): [ ((), {}, ""), (("",), {}, ""), - (("host=foo user=bar",), {}, "host=foo user=bar"), - (("host=foo",), {"user": "baz"}, "host=foo user=baz"), + (("dbname=foo user=bar",), {}, "dbname=foo user=bar"), + (("dbname=foo",), {"user": "baz"}, "dbname=foo user=baz"), ( - ("host=foo port=5432",), - {"host": "qux", "user": "joe"}, - "host=qux user=joe port=5432", + ("dbname=foo port=5432",), + {"dbname": "qux", "user": "joe"}, + "dbname=qux user=joe port=5432", ), - (("host=foo",), {"user": None}, "host=foo"), + (("dbname=foo",), {"user": None}, "dbname=foo"), ], ) async def test_connect_args(monkeypatch, pgconn, args, kwargs, want): @@ -721,3 +722,20 @@ async def test_connect_context_copy(dsn, aconn): async def test_cancel_closed(aconn): await aconn.close() aconn.cancel() + + +async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): # noqa: F811 + got = [] + + def fake_connect_gen(conninfo, **kwargs): + got.append(conninfo) + 1 / 0 + + monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen) + + with pytest.raises(ZeroDivisionError): + await psycopg.AsyncConnection.connect("host=foo.com") + + assert len(got) == 1 + want = {"host": "foo.com", "hostaddr": "1.1.1.1"} + assert conninfo_to_dict(got[0]) == want diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index a661fc604..d637ecbdc 100644 --- 
a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -1,3 +1,5 @@ +import socket +import asyncio import datetime as dt import pytest @@ -5,6 +7,7 @@ import pytest import psycopg from psycopg import ProgrammingError from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from psycopg.conninfo import resolve_hostaddr_async from psycopg._encodings import pg2pyenc snowman = "\u2603" @@ -298,3 +301,140 @@ class TestConnectionInfo: cur.execute("set client_encoding to EUC_TW") with pytest.raises(psycopg.NotSupportedError): cur.execute("select 'x'") + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", "", None), + ("host='' user=bar", "host='' user=bar", None), + ( + "host=127.0.0.1 user=bar", + "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + "host=foo.com port=5432", + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_no_resolve( + monkeypatch, conninfo, want, env, fail_resolve +): + if env: + for k, v in env.items(): + monkeypatch.setenv(k, v) + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=foo.com,qux.com port=5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + None, + ), + ( + "host=foo.com,nosuchhost.com", + 
"host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com, port=5432,5433", + "host=foo.com, hostaddr=1.1.1.1, port=5432,5433", + None, + ), + ( + "host=nosuchhost.com,foo.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + {}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve): + if env: + for k, v in env.items(): + monkeypatch.setenv(k, v) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + await resolve_hostaddr_async(params) + + +@pytest.fixture +async def fake_resolve(monkeypatch): + fake_hosts = { + "localhost": "127.0.0.1", + "foo.com": "1.1.1.1", + "qux.com": "2.2.2.2", + } + + async def fake_getaddrinfo(host, port, **kwargs): + try: + addr = fake_hosts[host] + except KeyError: + raise OSError(f"unknown test host: {host}") + else: + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))] + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + +@pytest.fixture +async def fail_resolve(monkeypatch): + async def fail_getaddrinfo(host, port, **kwargs): + pytest.fail(f"shouldn't try to resolve {host}") + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo) diff --git a/tests/test_dns.py b/tests/test_dns.py index ca8202c21..66c708518 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -2,7 +2,6 @@ import pytest import psycopg from psycopg.conninfo 
import conninfo_to_dict -from psycopg.rows import Row pytestmark = [pytest.mark.dns] @@ -121,36 +120,6 @@ async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resol await psycopg._dns.resolve_hostaddr_async(params) # type: ignore[attr-defined] -@pytest.mark.asyncio -async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): - got = [] - - def fake_connect_gen(conninfo, **kwargs): - got.append(conninfo) - 1 / 0 - - monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen) - - # TODO: not enabled by default, but should be usable to make a subclass - class AsyncDnsConnection(psycopg.AsyncConnection[Row]): - @classmethod - async def _get_connection_params(cls, conninfo, **kwargs): - params = await super()._get_connection_params(conninfo, **kwargs) - params = await ( - psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined] - params - ) - ) - return params - - with pytest.raises(ZeroDivisionError): - await AsyncDnsConnection.connect("host=foo.com") - - assert len(got) == 1 - want = {"host": "foo.com", "hostaddr": "1.1.1.1"} - assert conninfo_to_dict(got[0]) == want - - @pytest.fixture def fake_resolve(monkeypatch): import_dnspython()