--- /dev/null
+# type: ignore # dnspython is currently optional and mypy fails if missing
+"""
+DNS query support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from typing import Any, Dict
+from functools import lru_cache
+from ipaddress import ip_address
+
+try:
+ from dns.resolver import Cache
+ from dns.asyncresolver import Resolver
+ from dns.exception import DNSException
+except ImportError:
+ raise ImportError(
+ "the module psycopg._dns requires the package 'dnspython' installed"
+ )
+
+from . import pq
+from . import errors as e
+
+async_resolver = Resolver()
+async_resolver.cache = Cache()
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Perform async DNS lookup of the hosts and return a new params dict.
+
+ If a ``host`` param is present but not ``hostname``, resolve the host
+ addresses dynamically.
+
+ Change ``host``, ``hostname``, ``port`` in place to allow to connect
+ without further DNS lookups (remove hosts that are not resolved, keep the
+ lists consistent).
+
+ Raise `OperationalError` if connection is not possible (e.g. no host
+ resolve, inconsistent lists length).
+
+ See `the PostgreSQL docs`__ for explanation of how these params are used,
+ and how they support multiple entries.
+
+ .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+ #LIBPQ-PARAMKEYWORDS
+
+ .. warning::
+ This function doesn't handle the ``/etc/hosts`` file.
+ """
+ if params.get("hostaddr") or not params.get("host"):
+ return params
+
+ if pq.version() < 100000:
+ # hostaddr not supported
+ return params
+
+ host = params["host"]
+
+ if host.startswith("/") or host[1:2] == ":":
+ # Local path
+ return params
+
+ hosts_in = host.split(",")
+ ports_in = str(params["port"]).split(",") if params.get("port") else []
+ if len(ports_in) <= 1:
+ # If only one port is specified, the libpq will apply it to all
+ # the hosts, so don't mangle it.
+ del ports_in[:]
+ else:
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)}"
+ " port numbers"
+ )
+ ports_out = []
+
+ hosts_out = []
+ hostaddr_out = []
+ for i, host in enumerate(hosts_in):
+ # If the host is already an ip address don't try to resolve it
+ if is_ip_address(host):
+ hosts_out.append(host)
+ hostaddr_out.append(host)
+ if ports_in:
+ ports_out.append(ports_in[i])
+ continue
+
+ try:
+ ans = await async_resolver.resolve(host)
+ except DNSException as ex:
+ # Special case localhost: on MacOS it doesn't get resolved.
+ # I assue it is just resolved by /etc/hosts, which is not handled
+ # by dnspython.
+ if host == "localhost":
+ hosts_out.append(host)
+ hostaddr_out.append("127.0.0.1")
+ if ports_in:
+ ports_out.append(ports_in[i])
+ else:
+ last_exc = ex
+ else:
+ for rdata in ans:
+ hosts_out.append(host)
+ hostaddr_out.append(rdata.address)
+ if ports_in:
+ ports_out.append(ports_in[i])
+
+ # Throw an exception if no host could be resolved
+ if not hosts_out:
+ raise e.OperationalError(str(last_exc))
+
+ out = params.copy()
+ out["host"] = ",".join(hosts_out)
+ out["hostaddr"] = ",".join(hostaddr_out)
+ if ports_in:
+ out["port"] = ",".join(ports_out)
+
+ return out
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+ """Return True if the string represent a valid ip address."""
+ try:
+ ip_address(s)
+ except ValueError:
+ return False
+ return True
--- /dev/null
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+
+@pytest.mark.parametrize(
+ "conninfo, want",
+ [
+ ("", ""),
+ ("host='' user=bar", "host='' user=bar"),
+ (
+ "host=127.0.0.1 user=bar",
+ "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(conninfo, want, fail_resolve):
+ params = conninfo_to_dict(conninfo)
+ params = await psycopg._dns.resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, want",
+ [
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ params = await psycopg._dns.resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo",
+ [
+ "host=bad1.com,bad2.com",
+ "host=foo.com port=1,2",
+ "host=1.1.1.1,2.2.2.2 port=5432,5433,5434",
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(conninfo, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises((TypeError, psycopg.Error)):
+ await psycopg._dns.resolve_hostaddr_async(params)
+
+
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
+ got = []
+
+ def fake_connect_gen(conninfo, **kwargs):
+ got.append(conninfo)
+ 1 / 0
+
+ monkeypatch.setattr(
+ psycopg.AsyncConnection, "_connect_gen", fake_connect_gen
+ )
+
+ # TODO: not enabled by default, but should be usable to make a subclass
+ class AsyncDnsConnection(psycopg.AsyncConnection):
+ @classmethod
+ async def _get_connection_params(cls, conninfo, **kwargs):
+ params = await super()._get_connection_params(conninfo, **kwargs)
+ params = await psycopg._dns.resolve_hostaddr_async(params)
+ return params
+
+ with pytest.raises(ZeroDivisionError):
+ await AsyncDnsConnection.connect("host=foo.com")
+
+ assert len(got) == 1
+ want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+ assert conninfo_to_dict(got[0]) == want
+
+
+@pytest.fixture
+def fake_resolve(monkeypatch):
+ _import_dnspython()
+
+ import dns.rdtypes.IN.A
+ from dns.exception import DNSException
+
+ fake_hosts = {
+ "localhost": "127.0.0.1",
+ "foo.com": "1.1.1.1",
+ "qux.com": "2.2.2.2",
+ }
+
+ async def fake_resolve_(qname):
+ try:
+ addr = fake_hosts[qname]
+ except KeyError:
+ raise DNSException(f"unknown test host: {qname}")
+ else:
+ return [dns.rdtypes.IN.A.A("IN", "A", addr)]
+
+ monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fake_resolve_)
+
+
+@pytest.fixture
+def fail_resolve(monkeypatch):
+ _import_dnspython()
+
+ async def fail_resolve_(qname):
+ pytest.fail(f"shouldn't try to resolve {qname}")
+
+ monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fail_resolve_)
+
+
+def _import_dnspython():
+ try:
+ import dns.rdtypes.IN.A # noqa: F401
+ except ImportError:
+ pytest.skip("dnspython package not available")
+
+ import psycopg._dns # noqa: F401