From dda44ecfc8ff710adab8e3fe1cf35c293cefee9b Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 22 Aug 2021 06:03:18 +0200 Subject: [PATCH] Add _dns module Add a function to allow async resolution of the `host` entry in the connection string. The module is experimental and depends on the `dnspython` external package, which is currently not installed as a dependency. --- psycopg/psycopg/_dns.py | 131 +++++++++++++++++++++++++++++++++++ tests/test_dns.py | 148 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 279 insertions(+) create mode 100644 psycopg/psycopg/_dns.py create mode 100644 tests/test_dns.py diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py new file mode 100644 index 000000000..ef2028c92 --- /dev/null +++ b/psycopg/psycopg/_dns.py @@ -0,0 +1,131 @@ +# type: ignore # dnspython is currently optional and mypy fails if missing +""" +DNS query support +""" + +# Copyright (C) 2021 The Psycopg Team + +from typing import Any, Dict +from functools import lru_cache +from ipaddress import ip_address + +try: + from dns.resolver import Cache + from dns.asyncresolver import Resolver + from dns.exception import DNSException +except ImportError: + raise ImportError( + "the module psycopg._dns requires the package 'dnspython' installed" + ) + +from . import pq +from . import errors as e + +async_resolver = Resolver() +async_resolver.cache = Cache() + + +async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform async DNS lookup of the hosts and return a new params dict. + + If a ``host`` param is present but not ``hostname``, resolve the host + addresses dynamically. + + Change ``host``, ``hostname``, ``port`` in place to allow to connect + without further DNS lookups (remove hosts that are not resolved, keep the + lists consistent). + + Raise `OperationalError` if connection is not possible (e.g. no host + resolve, inconsistent lists length). + + See `the PostgreSQL docs`__ for explanation of how these params are used, + and how they support multiple entries. + + .. __: https://www.postgresql.org/docs/current/libpq-connect.html + #LIBPQ-PARAMKEYWORDS + + .. warning:: + This function doesn't handle the ``/etc/hosts`` file. + """ + if params.get("hostaddr") or not params.get("host"): + return params + + if pq.version() < 100000: + # hostaddr not supported + return params + + host = params["host"] + + if host.startswith("/") or host[1:2] == ":": + # Local path + return params + + hosts_in = host.split(",") + ports_in = str(params["port"]).split(",") if params.get("port") else [] + if len(ports_in) <= 1: + # If only one port is specified, the libpq will apply it to all + # the hosts, so don't mangle it. + del ports_in[:] + else: + if len(ports_in) != len(hosts_in): + # ProgrammingError would have been more appropriate, but this is + # what the raise if the libpq fails connect in the same case. + raise e.OperationalError( + f"cannot match {len(hosts_in)} hosts with {len(ports_in)}" + " port numbers" + ) + ports_out = [] + + hosts_out = [] + hostaddr_out = [] + for i, host in enumerate(hosts_in): + # If the host is already an ip address don't try to resolve it + if is_ip_address(host): + hosts_out.append(host) + hostaddr_out.append(host) + if ports_in: + ports_out.append(ports_in[i]) + continue + + try: + ans = await async_resolver.resolve(host) + except DNSException as ex: + # Special case localhost: on MacOS it doesn't get resolved. + # I assue it is just resolved by /etc/hosts, which is not handled + # by dnspython. + if host == "localhost": + hosts_out.append(host) + hostaddr_out.append("127.0.0.1") + if ports_in: + ports_out.append(ports_in[i]) + else: + last_exc = ex + else: + for rdata in ans: + hosts_out.append(host) + hostaddr_out.append(rdata.address) + if ports_in: + ports_out.append(ports_in[i]) + + # Throw an exception if no host could be resolved + if not hosts_out: + raise e.OperationalError(str(last_exc)) + + out = params.copy() + out["host"] = ",".join(hosts_out) + out["hostaddr"] = ",".join(hostaddr_out) + if ports_in: + out["port"] = ",".join(ports_out) + + return out + + +@lru_cache() +def is_ip_address(s: str) -> bool: + """Return True if the string represent a valid ip address.""" + try: + ip_address(s) + except ValueError: + return False + return True diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 000000000..7ecd6e375 --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,148 @@ +import pytest + +import psycopg +from psycopg.conninfo import conninfo_to_dict + + +@pytest.mark.parametrize( + "conninfo, want", + [ + ("", ""), + ("host='' user=bar", "host='' user=bar"), + ( + "host=127.0.0.1 user=bar", + "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_no_resolve(conninfo, want, fail_resolve): + params = conninfo_to_dict(conninfo) + params = await psycopg._dns.resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, want", + [ + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + ), + ( + "host=foo.com,qux.com port=5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + ), + ( + "host=foo.com,qux.com port=5432,5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + ), + ( + "host=foo.com,nosuchhost.com", + "host=foo.com hostaddr=1.1.1.1", + ), + ( + "host=nosuchhost.com,foo.com", + "host=foo.com hostaddr=1.1.1.1", + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async(conninfo, want, fake_resolve): + params = conninfo_to_dict(conninfo) + params = await psycopg._dns.resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo", + [ + "host=bad1.com,bad2.com", + "host=foo.com port=1,2", + "host=1.1.1.1,2.2.2.2 port=5432,5433,5434", + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_bad(conninfo, fake_resolve): + params = conninfo_to_dict(conninfo) + with pytest.raises((TypeError, psycopg.Error)): + await psycopg._dns.resolve_hostaddr_async(params) + + +@pytest.mark.asyncio +async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve): + got = [] + + def fake_connect_gen(conninfo, **kwargs): + got.append(conninfo) + 1 / 0 + + monkeypatch.setattr( + psycopg.AsyncConnection, "_connect_gen", fake_connect_gen + ) + + # TODO: not enabled by default, but should be usable to make a subclass + class AsyncDnsConnection(psycopg.AsyncConnection): + @classmethod + async def _get_connection_params(cls, conninfo, **kwargs): + params = await super()._get_connection_params(conninfo, **kwargs) + params = await psycopg._dns.resolve_hostaddr_async(params) + return params + + with pytest.raises(ZeroDivisionError): + await AsyncDnsConnection.connect("host=foo.com") + + assert len(got) == 1 + want = {"host": "foo.com", "hostaddr": "1.1.1.1"} + assert conninfo_to_dict(got[0]) == want + + +@pytest.fixture +def fake_resolve(monkeypatch): + _import_dnspython() + + import dns.rdtypes.IN.A + from dns.exception import DNSException + + fake_hosts = { + "localhost": "127.0.0.1", + "foo.com": "1.1.1.1", + "qux.com": "2.2.2.2", + } + + async def fake_resolve_(qname): + try: + addr = fake_hosts[qname] + except KeyError: + raise DNSException(f"unknown test host: {qname}") + else: + return [dns.rdtypes.IN.A.A("IN", "A", addr)] + + monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fake_resolve_) + + +@pytest.fixture +def fail_resolve(monkeypatch): + _import_dnspython() + + async def fail_resolve_(qname): + pytest.fail(f"shouldn't try to resolve {qname}") + + monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fail_resolve_) + + +def _import_dnspython(): + try: + import dns.rdtypes.IN.A # noqa: F401 + except ImportError: + pytest.skip("dnspython package not available") + + import psycopg._dns # noqa: F401 -- 2.47.2