# Copyright (C) 2021 The Psycopg Team
import os
-from typing import Any, Dict
+import re
+from random import randint
+from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Sequence
+from typing import TYPE_CHECKING
from functools import lru_cache
from ipaddress import ip_address
+from collections import defaultdict
try:
- from dns.resolver import Cache
- from dns.asyncresolver import Resolver
+ from dns.resolver import Resolver, Cache
+ from dns.asyncresolver import Resolver as AsyncResolver
from dns.exception import DNSException
except ImportError:
raise ImportError(
from . import pq
from . import errors as e
-async_resolver = Resolver()
+if TYPE_CHECKING:
+ from dns.rdtypes.IN.SRV import SRV
+
+resolver = Resolver()
+resolver.cache = Cache()
+
+async_resolver = AsyncResolver()
async_resolver.cache = Cache()
return out
+def resolve_srv(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Apply SRV DNS lookup as defined in :RFC:`2782`.
+
+ :param params: The input parameters, for instance as returned by
+ `~psycopg.conninfo.conninfo_to_dict()`.
+ :return: An updated list of connection parameters.
+
+ For every host defined in the ``params["host"]`` list (comma-separated),
+ perform SRV lookup if the host is in the form ``_Service._Proto.Target``.
+ If lookup is successful, return a params dict with hosts and ports replaced
+ with the looked-up entries.
+
+ Raise `~psycopg.OperationalError` if no lookup is successful and no host
+ (looked up or unchanged) could be returned.
+
+ In addition to the rules defined by RFC 2782 about the host name pattern,
+ perform SRV lookup also if the the port is the string ``SRV`` (case
+ insensitive).
+ """
+ return Rfc2782Resolver().resolve(params)
+
+
+async def resolve_srv_async(params: Dict[str, Any]) -> Dict[str, Any]:
+ """Async equivalent of `resolve_srv()`."""
+ return await Rfc2782Resolver().resolve_async(params)
+
+
+class HostPort(NamedTuple):
+ host: str
+ port: str
+ totry: bool = False
+ target: Optional[str] = None
+
+
+class Rfc2782Resolver:
+ """Implement SRV RR Resolution as per RFC 2782
+
+ The class is organised to minimise code duplication between the sync and
+ the async paths.
+ """
+
+ re_srv_rr = re.compile(
+ r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)"
+ )
+
+ def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(self._resolve_srv(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ async def resolve_async(self, params: Dict[str, Any]) -> Dict[str, Any]:
+ """Update the parameters host and port after SRV lookup."""
+ attempts = self._get_attempts(params)
+ if not attempts:
+ return params
+
+ hps = []
+ for hp in attempts:
+ if hp.totry:
+ hps.extend(await self._resolve_srv_async(hp))
+ else:
+ hps.append(hp)
+
+ return self._return_params(params, hps)
+
+ def _get_attempts(self, params: Dict[str, Any]) -> List[HostPort]:
+ """
+ Return the list of host, and for each host if SRV lookup must be tried.
+
+ Return an empty list if no lookup is requested.
+ """
+ # If hostaddr is defined don't do any resolution.
+ if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
+ return []
+
+ host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
+ hosts_in = host_arg.split(",")
+ port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
+ ports_in = port_arg.split(",")
+
+ if len(ports_in) == 1:
+ # If only one port is specified, it applies to all the hosts.
+ ports_in *= len(hosts_in)
+ if len(ports_in) != len(hosts_in):
+ # ProgrammingError would have been more appropriate, but this is
+ # what the raise if the libpq fails connect in the same case.
+ raise e.OperationalError(
+ f"cannot match {len(hosts_in)} hosts with {len(ports_in)}"
+ " port numbers"
+ )
+
+ out = []
+ srv_found = False
+ for host, port in zip(hosts_in, ports_in):
+ m = self.re_srv_rr.match(host)
+ if m or port.lower() == "srv":
+ srv_found = True
+ target = m.group("target") if m else None
+ hp = HostPort(host=host, port=port, totry=True, target=target)
+ else:
+ hp = HostPort(host=host, port=port)
+ out.append(hp)
+
+ if srv_found:
+ return out
+ else:
+ return []
+
+ def _resolve_srv(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ async def _resolve_srv_async(self, hp: HostPort) -> List[HostPort]:
+ try:
+ ans = resolver.resolve(hp.host, "SRV")
+ except DNSException:
+ ans = ()
+ return self._get_solved_entries(hp, ans)
+
+ def _get_solved_entries(
+ self, hp: HostPort, entries: "Sequence[SRV]"
+ ) -> List[HostPort]:
+ if not entries:
+ # No SRV entry found. Delegate the libpq a QNAME=target lookup
+ if hp.target and hp.port.lower() != "srv":
+ return [HostPort(host=hp.target, port=hp.port)]
+ else:
+ return []
+
+ # If there is precisely one SRV RR, and its Target is "." (the root
+ # domain), abort.
+ if len(entries) == 1 and str(entries[0].target) == ".":
+ return []
+
+ return [
+ HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
+ for entry in self.sort_rfc2782(entries)
+ ]
+
+ def _return_params(
+ self, params: Dict[str, Any], hps: List[HostPort]
+ ) -> Dict[str, Any]:
+ if not hps:
+ # Nothing found, we ended up with an empty list
+ raise e.OperationalError("no host found after SRV RR lookup")
+
+ out = params.copy()
+ out["host"] = ",".join(hp.host for hp in hps)
+ out["port"] = ",".join(str(hp.port) for hp in hps)
+ return out
+
+ def sort_rfc2782(self, ans: "Sequence[SRV]") -> "List[SRV]":
+ """
+ Implement the priority/weight ordering defined in RFC 2782.
+ """
+ # Divide the entries by priority:
+ priorities: DefaultDict[int, "List[SRV]"] = defaultdict(list)
+ out: "List[SRV]" = []
+ for entry in ans:
+ priorities[entry.priority].append(entry)
+
+ for pri, entries in sorted(priorities.items()):
+ if len(entries) == 1:
+ out.append(entries[0])
+ continue
+
+ entries.sort(key=lambda ent: ent.weight)
+ total_weight = sum(ent.weight for ent in entries)
+ while entries:
+ r = randint(0, total_weight)
+ csum = 0
+ for i, ent in enumerate(entries):
+ csum += ent.weight
+ if csum >= r:
+ break
+ out.append(ent)
+ total_weight -= ent.weight
+ del entries[i]
+
+ return out
+
+
@lru_cache()
def is_ip_address(s: str) -> bool:
"""Return True if the string represent a valid ip address."""
--- /dev/null
+import pytest
+
+import psycopg
+from psycopg.conninfo import conninfo_to_dict
+
+from .test_dns import import_dnspython
+
+samples_ok = [
+ ("", "", None),
+ ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None),
+ ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}),
+ (
+ "host=foo.com,_pg._tcp.foo.com",
+ "host=foo.com,db1.example.com port=,5432",
+ None,
+ ),
+ (
+ "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com",
+ "host=foo.com,db1.example.com port=,5432",
+ None,
+ ),
+ (
+ "host=_pg._tcp.bar.com",
+ (
+ "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com"
+ " port=5432,5432,5433,5432"
+ ),
+ None,
+ ),
+ (
+ "host=service.foo.com port=srv",
+ ("host=service.example.com port=15432"),
+ None,
+ ),
+ # No resolution
+ (
+ "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+ "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+]
+
+
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+def test_srv(conninfo, want, env, fake_srv, retries, monkeypatch):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+ # retries are needed because weight order is random, although wrong order
+ # is unlikely.
+ for retry in retries:
+ with retry:
+ params = conninfo_to_dict(conninfo)
+ params = psycopg._dns.resolve_srv(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, want, env", samples_ok)
+async def test_srv_async(conninfo, want, env, fake_srv, retries, monkeypatch):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+ async for retry in retries:
+ with retry:
+ params = conninfo_to_dict(conninfo)
+ params = await psycopg._dns.resolve_srv_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+samples_bad = [
+ ("host=_pg._tcp.dot.com", None),
+ ("host=_pg._tcp.foo.com port=1,2", None),
+]
+
+
+@pytest.mark.parametrize("conninfo, env", samples_bad)
+def test_srv_bad(conninfo, env, fake_srv, monkeypatch):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.OperationalError):
+ psycopg._dns.resolve_srv(params)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("conninfo, env", samples_bad)
+async def test_srv_bad_async(conninfo, env, fake_srv, monkeypatch):
+ if env:
+ for k, v in env.items():
+ monkeypatch.setenv(k, v)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.OperationalError):
+ await psycopg._dns.resolve_srv_async(params)
+
+
+@pytest.fixture
+def fake_srv(monkeypatch):
+ import_dnspython()
+
+ from dns.rdtypes.IN.A import A
+ from dns.rdtypes.IN.SRV import SRV
+ from dns.exception import DNSException
+
+ fake_hosts = {
+ ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."],
+ ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."],
+ ("_pg._tcp.bar.com", "SRV"): [
+ "1 0 5432 db2.example.com.",
+ "1 255 5433 db3.example.com.",
+ "0 0 5432 db1.example.com.",
+ "1 65535 5432 db4.example.com.",
+ ],
+ ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."],
+ }
+
+ def fake_srv_(qname, rdtype):
+ try:
+ ans = fake_hosts[qname, rdtype]
+ except KeyError:
+ raise DNSException(f"unknown test host: {qname} {rdtype}")
+ rv = []
+
+ if rdtype == "A":
+ for entry in ans:
+ rv.append(A("IN", "A", entry))
+ else:
+ for entry in ans:
+ pri, w, port, target = entry.split()
+ rv.append(
+ SRV("IN", "SRV", int(pri), int(w), int(port), target)
+ )
+
+ return rv
+
+ async def afake_srv_(qname, rdtype):
+ return fake_srv(qname, rdtype)
+
+ monkeypatch.setattr(psycopg._dns.resolver, "resolve", fake_srv_)
+ monkeypatch.setattr(psycopg._dns.async_resolver, "resolve", afake_srv_)