]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add function to perform SRV DNS resolution
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Aug 2021 03:21:53 +0000 (05:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 30 Aug 2021 03:33:29 +0000 (05:33 +0200)
docs/api/dns.rst
psycopg/psycopg/_dns.py
tests/test_dns.py
tests/test_dns_srv.py [new file with mode: 0644]

index 14c1e4026d5553a66c26b176ad3a227f4908b823..af09a344bd47fdb686869b6411c3b6d44e696d27 100644 (file)
@@ -40,6 +40,32 @@ server before performing a connection.
                    return params
 
 
+.. autofunction:: resolve_srv
+
+   .. warning::
+       This is an experimental functionality.
+
+   .. note::
+       One possible way to use this function automatically is to subclass
+       `~psycopg.Connection`, extending the
+       `~psycopg.Connection._get_connection_params()` method::
+
+           import psycopg._dns  # not imported automatically
+
+           class SrvCognizantConnection(psycopg.Connection):
+               @classmethod
+               def _get_connection_params(cls, conninfo, **kwargs):
+                   params = super()._get_connection_params(conninfo, **kwargs)
+                   params = psycopg._dns.resolve_srv(params)
+                   return params
+
+           # The name will be resolved to db1.example.com
+           cnn = SrvCognizantConnection.connect("host=_postgres._tcp.db.psycopg.org")
+
+
+.. autofunction:: resolve_srv_async
+
+
 .. automethod:: psycopg.Connection._get_connection_params
 
     .. warning::
index c77f294d1a7a402cda6e4b4038680244b4a97a8d..689809415cf7e1326c4cae7083bcb0910ade9088 100644 (file)
@@ -6,13 +6,17 @@ DNS query support
 # Copyright (C) 2021 The Psycopg Team
 
 import os
-from typing import Any, Dict
+import re
+from random import randint
+from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
+from typing import TYPE_CHECKING
 from functools import lru_cache
 from ipaddress import ip_address
+from collections import defaultdict
 
 try:
-    from dns.resolver import Cache
-    from dns.asyncresolver import Resolver
+    from dns.resolver import Resolver, Cache
+    from dns.asyncresolver import Resolver as AsyncResolver
     from dns.exception import DNSException
 except ImportError:
     raise ImportError(
@@ -22,7 +26,13 @@ except ImportError:
 from . import pq
 from . import errors as e
 
-async_resolver = Resolver()
+if TYPE_CHECKING:
+    from dns.rdtypes.IN.SRV import SRV
+
+resolver = Resolver()
+resolver.cache = Cache()
+
+async_resolver = AsyncResolver()
 async_resolver.cache = Cache()
 
 
@@ -136,6 +146,201 @@ async def resolve_hostaddr_async(params: Dict[str, Any]) -> Dict[str, Any]:
     return out
 
 
+def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
+    """Apply SRV DNS lookup as defined in :RFC:`2782`.
+
+    :param params: The input parameters, for instance as returned by
+        `~psycopg.conninfo.conninfo_to_dict()`.
+    :return: An updated dict of connection parameters.
+
+    For every host defined in the ``params["host"]`` list (comma-separated),
+    perform SRV lookup if the host is in the form ``_Service._Proto.Target``.
+    If lookup is successful, return a params dict with hosts and ports replaced
+    with the looked-up entries.
+
+    If ``hostaddr`` is specified, or if no host requires an SRV lookup, the
+    input parameters are returned unchanged.
+
+    Raise `~psycopg.OperationalError` if no lookup is successful and no host
+    (looked up or unchanged) could be returned.
+
+    In addition to the rules defined by RFC 2782 about the host name pattern,
+    perform SRV lookup also if the port is the string ``SRV`` (case
+    insensitive).
+    """
+    return Rfc2782Resolver().resolve(params)
+
+
+async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
+    """Async equivalent of `resolve_srv()`.
+
+    Accept the same *params* and return the same kind of updated dict; the
+    work is delegated to `Rfc2782Resolver.resolve_async()`.
+    """
+    return await Rfc2782Resolver().resolve_async(params)
+
+
+class HostPort(NamedTuple):
+    """A host/port pair, either taken from the input or from an SRV answer."""
+
+    # Host name: as given in the connection params, or an SRV record target.
+    host: str
+    # Port, kept as a string (it may be empty or the marker "srv").
+    port: str
+    # True if an SRV lookup must be attempted for this entry.
+    totry: bool = False
+    # The "target" part of a ``_Service._Proto.Target`` host name, if the
+    # host matched that pattern; None otherwise.
+    target: Optional[str] = None
+
+
+class Rfc2782Resolver:
+    """Implement SRV RR Resolution as per RFC 2782
+
+    The class is organised to minimise code duplication between the sync and
+    the async paths: only the methods performing the DNS query differ, all
+    the parameters parsing and the results handling is shared.
+    """
+
+    # Match a host name in the form ``_Service._Proto.Target``.
+    re_srv_rr = re.compile(
+        r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)"
+    )
+
+    def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        """Update the parameters host and port after SRV lookup.
+
+        Return *params* itself, unchanged, if no lookup is requested.
+        Raise `~psycopg.OperationalError` if no host is left after lookup.
+        """
+        attempts = self._get_attempts(params)
+        if not attempts:
+            return params
+
+        hps = []
+        for hp in attempts:
+            if hp.totry:
+                hps.extend(self._resolve_srv(hp))
+            else:
+                # Host not subject to SRV lookup: pass it through unchanged.
+                hps.append(hp)
+
+        return self._return_params(params, hps)
+
+    async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        """Update the parameters host and port after SRV lookup.
+
+        Async version of `resolve()`, with the same return value and errors.
+        """
+        attempts = self._get_attempts(params)
+        if not attempts:
+            return params
+
+        hps = []
+        for hp in attempts:
+            if hp.totry:
+                hps.extend(await self._resolve_srv_async(hp))
+            else:
+                # Host not subject to SRV lookup: pass it through unchanged.
+                hps.append(hp)
+
+        return self._return_params(params, hps)
+
+    def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
+        """
+        Return the list of hosts, and for each host if SRV lookup must be tried.
+
+        Hosts and ports are read from *params*, falling back on the
+        ``PGHOST`` and ``PGPORT`` environment variables.
+
+        Return an empty list if no lookup is requested.
+
+        Raise `~psycopg.OperationalError` if the number of ports doesn't
+        match the number of hosts.
+        """
+        # If hostaddr is defined don't do any resolution.
+        if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
+            return []
+
+        host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+        hosts_in = host_arg.split(",")
+        port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+        ports_in = port_arg.split(",")
+
+        if len(ports_in) == 1:
+            # If only one port is specified, it applies to all the hosts.
+            ports_in *= len(hosts_in)
+        if len(ports_in) != len(hosts_in):
+            # ProgrammingError would have been more appropriate, but this is
+            # what is raised if the libpq fails to connect in the same case.
+            raise e.OperationalError(
+                f"cannot match {len(hosts_in)} hosts with {len(ports_in)}"
+                " port numbers"
+            )
+
+        out = []
+        srv_found = False
+        for host, port in zip(hosts_in, ports_in):
+            m = self.re_srv_rr.match(host)
+            # Lookup is requested either by the host name pattern or by
+            # using "srv" as port.
+            if m or port.lower() == "srv":
+                srv_found = True
+                target = m.group("target") if m else None
+                hp = HostPort(host=host, port=port, totry=True, target=target)
+            else:
+                hp = HostPort(host=host, port=port)
+            out.append(hp)
+
+        if srv_found:
+            return out
+        else:
+            return []
+
+    def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
+        """Query the SRV records for *hp* using the blocking resolver.
+
+        A failed DNS query is treated as an empty answer.
+        """
+        try:
+            ans = resolver.resolve(hp.host, "SRV")
+        except DNSException:
+            ans = ()
+        return self._get_solved_entries(hp, ans)
+
+    async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
+        """Query the SRV records for *hp* using the async resolver.
+
+        A failed DNS query is treated as an empty answer.
+        """
+        try:
+            # Use the async resolver here: the sync one would block the
+            # event loop for the duration of the DNS query.
+            ans = await async_resolver.resolve(hp.host, "SRV")
+        except DNSException:
+            ans = ()
+        return self._get_solved_entries(hp, ans)
+
+    def _get_solved_entries(
+        self, hp: HostPort, entries: "Sequence[SRV]"
+    ) -> List[HostPort]:
+        """Convert the SRV answer for *hp* into a list of hosts to connect to.
+
+        Return an empty list if the lookup yielded nothing usable.
+        """
+        if not entries:
+            # No SRV entry found. Delegate the libpq a QNAME=target lookup
+            if hp.target and hp.port.lower() != "srv":
+                return [HostPort(host=hp.target, port=hp.port)]
+            else:
+                return []
+
+        # If there is precisely one SRV RR, and its Target is "." (the root
+        # domain), abort.
+        if len(entries) == 1 and str(entries[0].target) == ".":
+            return []
+
+        # Strip the trailing dot of the fully qualified target name.
+        return [
+            HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
+            for entry in self.sort_rfc2782(entries)
+        ]
+
+    def _return_params(
+        self, params: Dict[str, Any], hps: List[HostPort]
+    ) -> Dict[str, Any]:
+        """Return a copy of *params* with host and port replaced by *hps*.
+
+        Raise `~psycopg.OperationalError` if *hps* is empty.
+        """
+        if not hps:
+            # Nothing found, we ended up with an empty list
+            raise e.OperationalError("no host found after SRV RR lookup")
+
+        out = params.copy()
+        out["host"] = ",".join(hp.host for hp in hps)
+        out["port"] = ",".join(str(hp.port) for hp in hps)
+        return out
+
+    def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
+        """
+        Implement the priority/weight ordering defined in RFC 2782.
+
+        Entries are returned by ascending priority; entries sharing the same
+        priority are ordered by weighted random selection.
+        """
+        # Divide the entries by priority:
+        priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
+        out: "List[SRV]" = []
+        for entry in ans:
+            priorities[entry.priority].append(entry)
+
+        for pri, entries in sorted(priorities.items()):
+            if len(entries) == 1:
+                out.append(entries[0])
+                continue
+
+            # Lowest weights first, so that weight 0 entries come first in
+            # the running sum below.
+            entries.sort(key=lambda ent: ent.weight)
+            total_weight = sum(ent.weight for ent in entries)
+            while entries:
+                # Pick the first entry whose running sum of weights reaches
+                # a random number in [0, total_weight], then repeat with the
+                # remaining entries.
+                r = randint(0, total_weight)
+                csum = 0
+                for i, ent in enumerate(entries):
+                    csum += ent.weight
+                    if csum >= r:
+                        break
+                out.append(ent)
+                total_weight -= ent.weight
+                del entries[i]
+
+        return out
+
+
 @lru_cache()
 def is_ip_address(s: str) -> bool:
     """Return True if the string represent a valid ip address."""
index d76fe60db5afaef69175aab2f3e2d51487075a9e..cf3e3e3ed7bb9e3b316b668a0303a4e1f98603e0 100644 (file)
@@ -146,7 +146,7 @@ async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
 
 @pytest.fixture
 def fake_resolve(monkeypatch):
-    _import_dnspython()
+    import_dnspython()
 
     import dns.rdtypes.IN.A
     from dns.exception import DNSException
@@ -170,7 +170,7 @@ def fake_resolve(monkeypatch):
 
 @pytest.fixture
 def fail_resolve(monkeypatch):
-    _import_dnspython()
+    import_dnspython()
 
     async def fail_resolve_(qname):
         pytest.fail(f"shouldn't try to resolve {qname}")
@@ -178,7 +178,7 @@ def fail_resolve(monkeypatch):
     monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", fail_resolve_)
 
 
-def _import_dnspython():
+def import_dnspython():
     try:
         import dns.rdtypes.IN.A  # noqa: F401
     except ImportError:
diff --git a/tests/test_dns_srv.py b/tests/test_dns_srv.py
new file mode 100644 (file)
index 0000000..e642a50
--- /dev/null
@@ -0,0 +1,141 @@
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from .test_dns import import_dnspython
+
+# Each sample is (conninfo, want, env): the input connection string, the
+# connection string expected after SRV resolution, and an optional dict of
+# environment variables to set for the test. The DNS answers are the ones
+# defined in the fake_srv fixture.
+samples_ok = [
+    # Nothing to resolve
+    ("", "", None),
+    # Single SRV record
+    ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None),
+    # Host taken from the environment
+    ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}),
+    # Hosts not matching the SRV pattern are passed through unchanged
+    (
+        "host=foo.com,_pg._tcp.foo.com",
+        "host=foo.com,db1.example.com port=,5432",
+        None,
+    ),
+    # A host resolving to a single "." target is dropped
+    (
+        "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com",
+        "host=foo.com,db1.example.com port=,5432",
+        None,
+    ),
+    # Several records: ordered by priority, then (randomly) by weight
+    (
+        "host=_pg._tcp.bar.com",
+        (
+            "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com"
+            " port=5432,5432,5433,5432"
+        ),
+        None,
+    ),
+    # Lookup requested by port=srv on a host not matching the SRV pattern
+    (
+        "host=service.foo.com port=srv",
+        ("host=service.example.com port=15432"),
+        None,
+    ),
+    # No resolution
+    (
+        "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+        "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+        None,
+    ),
+]
+
+
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+def test_srv(conninfo, want, env, fake_srv, retries, monkeypatch):
+    """Check that resolve_srv() maps each conninfo sample to the wanted one."""
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+    # retries are needed because weight order is random, although wrong order
+    # is unlikely.
+    # NOTE(review): `retries` is a fixture defined outside this file; it
+    # presumably re-runs the block on assertion failure - confirm.
+    for retry in retries:
+        with retry:
+            params = conninfo_to_dict(conninfo)
+            params = psycopg._dns.resolve_srv(params)
+            assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+async def test_srv_async(conninfo, want, env, fake_srv, retries, monkeypatch):
+    """Check resolve_srv_async() against the same samples as test_srv()."""
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+    # Retry because the order of entries with the same priority is random.
+    async for retry in retries:
+        with retry:
+            params = conninfo_to_dict(conninfo)
+            params = await psycopg._dns.resolve_srv_async(params)
+            assert conninfo_to_dict(want) == params
+
+
+# Each sample is (conninfo, env): a connection string for which resolution
+# must fail with OperationalError, and an optional dict of environment
+# variables to set for the test.
+samples_bad = [
+    # The only SRV record has "." as target: no host is left to connect to
+    ("host=_pg._tcp.dot.com", None),
+    # The number of ports doesn't match the number of hosts
+    ("host=_pg._tcp.foo.com port=1,2", None),
+]
+
+
+@pytest.mark.parametrize("conninfo,  env", samples_bad)
+def test_srv_bad(conninfo, env, fake_srv, monkeypatch):
+    """Check that resolve_srv() raises OperationalError on the bad samples."""
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.OperationalError):
+        psycopg._dns.resolve_srv(params)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo,  env", samples_bad)
+async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch):
+    """Check that resolve_srv_async() raises OperationalError on bad samples."""
+    if env:
+        for k, v in env.items():
+            monkeypatch.setenv(k, v)
+    params = conninfo_to_dict(conninfo)
+    with pytest.raises(psycopg.OperationalError):
+        await psycopg._dns.resolve_srv_async(params)
+
+
+@pytest.fixture
+def fake_srv(monkeypatch):
+    """Replace the DNS resolvers' `resolve()` with canned answers.
+
+    Both `psycopg._dns.resolver` and `psycopg._dns.async_resolver` are
+    patched, so that no real DNS query is performed. Names not listed in
+    `fake_hosts` fail with a `DNSException`.
+    """
+    import_dnspython()
+
+    from dns.rdtypes.IN.A import A
+    from dns.rdtypes.IN.SRV import SRV
+    from dns.exception import DNSException
+
+    # Map (qname, rdtype) -> records. SRV records are expressed as
+    # "priority weight port target".
+    fake_hosts = {
+        ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."],
+        ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."],
+        ("_pg._tcp.bar.com", "SRV"): [
+            "1 0 5432 db2.example.com.",
+            "1 255 5433 db3.example.com.",
+            "0 0 5432 db1.example.com.",
+            "1 65535 5432 db4.example.com.",
+        ],
+        ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."],
+    }
+
+    def fake_srv_(qname, rdtype):
+        try:
+            ans = fake_hosts[qname, rdtype]
+        except KeyError:
+            raise DNSException(f"unknown test host: {qname} {rdtype}")
+        rv = []
+
+        if rdtype == "A":
+            for entry in ans:
+                rv.append(A("IN", "A", entry))
+        else:
+            for entry in ans:
+                pri, w, port, target = entry.split()
+                rv.append(
+                    SRV("IN", "SRV", int(pri), int(w), int(port), target)
+                )
+
+        return rv
+
+    async def afake_srv_(qname, rdtype):
+        # Delegate to the sync fake (fake_srv_, not the fake_srv fixture,
+        # which cannot be called directly).
+        return fake_srv_(qname, rdtype)
+
+    monkeypatch.setattr(psycopg._dns.resolver, "resolve", fake_srv_)
+    monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", afake_srv_)