]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add _dns module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 22 Aug 2021 04:03:18 +0000 (06:03 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Aug 2021 17:33:20 +0000 (19:33 +0200)
Add a function to allow async resolution of the `host` entry in the
connection string.

The module is experimental and depends on the `dnspython` external
package, which is currently not installed as a dependency.

psycopg/psycopg/_dns.py [new file with mode: 0644]
tests/test_dns.py [new file with mode: 0644]

diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py
new file mode 100644 (file)
index 0000000..ef2028c
--- /dev/null
@@ -0,0 +1,131 @@
+# type: ignore  # dnspython is currently optional and mypy fails if missing
+"""
+DNS query support
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from typing import Any, Dict
+from functools import lru_cache
+from ipaddress import ip_address
+
+try:
+    from dns.resolver import Cache
+    from dns.asyncresolver import Resolver
+    from dns.exception import DNSException
+except ImportError:
+    raise ImportError(
+        "the module psycopg._dns requires the package 'dnspython' installed"
+    )
+
+from . import pq
+from . import errors as e
+
+async_resolver = Resolver()
+async_resolver.cache = Cache()
+
+
+async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Perform async DNS lookup of the hosts and return a new params dict.
+
+    If a ``host`` param is present but not ``hostname``, resolve the host
+    addresses dynamically.
+
+    Change ``host``, ``hostname``, ``port`` in place to allow to connect
+    without further DNS lookups (remove hosts that are not resolved, keep the
+    lists consistent).
+
+    Raise `OperationalError` if connection is not possible (e.g. no host
+    resolve, inconsistent lists length).
+
+    See `the PostgreSQL docs`__ for explanation of how these params are used,
+    and how they support multiple entries.
+
+    .. __: https://www.postgresql.org/docs/current/libpq-connect.html
+           #LIBPQ-PARAMKEYWORDS
+
+    .. warning::
+        This function doesn't handle the ``/etc/hosts`` file.
+    """
+    if params.get("hostaddr") or not params.get("host"):
+        return params
+
+    if pq.version() < 100000:
+        # hostaddr not supported
+        return params
+
+    host = params["host"]
+
+    if host.startswith("/") or host[1:2] == ":":
+        # Local path
+        return params
+
+    hosts_in = host.split(",")
+    ports_in = str(params["port"]).split(",") if params.get("port") else []
+    if len(ports_in) <= 1:
+        # If only one port is specified, the libpq will apply it to all
+        # the hosts, so don't mangle it.
+        del ports_in[:]
+    else:
+        if len(ports_in) != len(hosts_in):
+            # ProgrammingError would have been more appropriate, but this is
+            # what the raise if the libpq fails connect in the same case.
+            raise e.OperationalError(
+                f"cannot match {len(hosts_in)} hosts with {len(ports_in)}"
+                " port numbers"
+            )
+        ports_out = []
+
+    hosts_out = []
+    hostaddr_out = []
+    for i, host in enumerate(hosts_in):
+        # If the host is already an ip address don't try to resolve it
+        if is_ip_address(host):
+            hosts_out.append(host)
+            hostaddr_out.append(host)
+            if ports_in:
+                ports_out.append(ports_in[i])
+            continue
+
+        try:
+            ans = await async_resolver.resolve(host)
+        except DNSException as ex:
+            # Special case localhost: on MacOS it doesn't get resolved.
+            # I assue it is just resolved by /etc/hosts, which is not handled
+            # by dnspython.
+            if host == "localhost":
+                hosts_out.append(host)
+                hostaddr_out.append("127.0.0.1")
+                if ports_in:
+                    ports_out.append(ports_in[i])
+            else:
+                last_exc = ex
+        else:
+            for rdata in ans:
+                hosts_out.append(host)
+                hostaddr_out.append(rdata.address)
+                if ports_in:
+                    ports_out.append(ports_in[i])
+
+    # Throw an exception if no host could be resolved
+    if not hosts_out:
+        raise e.OperationalError(str(last_exc))
+
+    out = params.copy()
+    out["host"] = ",".join(hosts_out)
+    out["hostaddr"] = ",".join(hostaddr_out)
+    if ports_in:
+        out["port"] = ",".join(ports_out)
+
+    return out
+
+
+@lru_cache()
+def is_ip_address(s: str) -> bool:
+    """Return True if the string represent a valid ip address."""
+    try:
+        ip_address(s)
+    except ValueError:
+        return False
+    return True
diff --git a/tests/test_dns.py b/tests/test_dns.py
new file mode 100644 (file)
index 0000000..7ecd6e3
--- /dev/null
@@ -0,0 +1,148 @@
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+
+@pytest.mark.parametrize(
+    "conninfo, want",
+    [
+        ("", ""),
+        ("host='' user=bar", "host='' user=bar"),
+        (
+            "host=127.0.0.1 user=bar",
+            "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 user=bar",
+            "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+        ),
+        (
+            "host=1.1.1.1,2.2.2.2 port=5432",
+            "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(conninfo, want, fail_resolve):
+    params = conninfo_to_dict(conninfo)
+    params = await psycopg._dns.resolve_hostaddr_async(params)
+    assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+    "conninfo, want",
+    [
+        (
+            "host=foo.com,qux.com",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+        ),
+        (
+            "host=foo.com,qux.com port=5433",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+        ),
+        (
+            "host=foo.com,qux.com port=5432,5433",
+            "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+        ),
+        (
+            "host=foo.com,nosuchhost.com",
+            "host=foo.com hostaddr=1.1.1.1",
+        ),
+        (
+            "host=nosuchhost.com,foo.com",
+            "host=foo.com hostaddr=1.1.1.1",
+        ),
+    ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, fake_resolve):
+    params = conninfo_to_dict(conninfo)
+    params = await psycopg._dns.resolve_hostaddr_async(params)
+    assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+    "conninfo",
+    [
+        "host=bad1.com,bad2.com",
+        "host=foo.com port=1,2",
+        "host=1.1.1.1,2.2.2.2 port=5432,5433,5434",
+    ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(conninfo, fake_resolve):
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises((TypeError, psycopg.Error)):
+        await psycopg._dns.resolve_hostaddr_async(params)
+
+
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
+    got = []
+
+    def fake_connect_gen(conninfo, **kwargs):
+        got.append(conninfo)
+        1 / 0
+
+    monkeypatch.setattr(
+        psycopg.AsyncConnection, "_connect_gen", fake_connect_gen
+    )
+
+    # TODO: not enabled by default, but should be usable to make a subclass
+    class AsyncDnsConnection(psycopg.AsyncConnection):
+        @classmethod
+        async def _get_connection_params(cls, conninfo, **kwargs):
+            params = await super()._get_connection_params(conninfo, **kwargs)
+            params = await psycopg._dns.resolve_hostaddr_async(params)
+            return params
+
+    with pytest.raises(ZeroDivisionError):
+        await AsyncDnsConnection.connect("host=foo.com")
+
+    assert len(got) == 1
+    want = {"host": "foo.com", "hostaddr": "1.1.1.1"}
+    assert conninfo_to_dict(got[0]) == want
+
+
+@pytest.fixture
+def fake_resolve(monkeypatch):
+    _import_dnspython()
+
+    import dns.rdtypes.IN.A
+    from dns.exception import DNSException
+
+    fake_hosts = {
+        "localhost": "127.0.0.1",
+        "foo.com": "1.1.1.1",
+        "qux.com": "2.2.2.2",
+    }
+
+    async def fake_resolve_(qname):
+        try:
+            addr = fake_hosts[qname]
+        except KeyError:
+            raise DNSException(f"unknown test host: {qname}")
+        else:
+            return [dns.rdtypes.IN.A.A("IN", "A", addr)]
+
+    monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fake_resolve_)
+
+
+@pytest.fixture
+def fail_resolve(monkeypatch):
+    _import_dnspython()
+
+    async def fail_resolve_(qname):
+        pytest.fail(f"shouldn't try to resolve {qname}")
+
+    monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fail_resolve_)
+
+
+def _import_dnspython():
+    try:
+        import dns.rdtypes.IN.A  # noqa: F401
+    except ImportError:
+        pytest.skip("dnspython package not available")
+
+    import psycopg._dns  # noqa: F401