]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Resolver "nameserver" object support. (#897)
authorBob Halley <halley@dnspython.org>
Sat, 25 Feb 2023 19:43:26 +0000 (11:43 -0800)
committerGitHub <noreply@github.com>
Sat, 25 Feb 2023 19:43:26 +0000 (11:43 -0800)
* Resolver "nameserver" object support.

This turns the list of nameserver strings in the resolver into a tuple
of nameserver objects, which abstract away making queries to a
nameserver of a given type.

The resolver's legacy nameserver list is "enriched" into a tuple of
nameserver objects whenever it is set.  Note that you cannot mutate
the object other than by setting,
e.g. res.nameservers.append("1.2.3.4") will not work.

Error message accumulation has been updated to refer to the
nameservers using a descriptive text form.

* doco fix

* more doco fixes

* do enrichment at Resolution time

* require a later mypy, fix type issues

* add nameserver doc

dns/asyncresolver.py
dns/nameserver.py [new file with mode: 0644]
dns/resolver.py
doc/resolver-class.rst
doc/resolver-nameserver.rst [new file with mode: 0644]
doc/resolver.rst
doc/whatsnew.rst
pyproject.toml
tests/test_resolution.py
tests/test_resolver.py

index 506530e29f157d2b25c7ab5e9f5b78a246a94ce1..9ba84de07b69cac800b67cc7d787fd1a79ad6b86 100644 (file)
@@ -83,37 +83,19 @@ class Resolver(dns.resolver.BaseResolver):
             assert request is not None  # needed for type checking
             done = False
             while not done:
-                (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+                (nameserver, tcp, backoff) = resolution.next_nameserver()
                 if backoff:
                     await backend.sleep(backoff)
                 timeout = self._compute_timeout(start, lifetime, resolution.errors)
                 try:
-                    if dns.inet.is_address(nameserver):
-                        if tcp:
-                            response = await _tcp(
-                                request,
-                                nameserver,
-                                timeout,
-                                port,
-                                source,
-                                source_port,
-                                backend=backend,
-                            )
-                        else:
-                            response = await _udp(
-                                request,
-                                nameserver,
-                                timeout,
-                                port,
-                                source,
-                                source_port,
-                                raise_on_truncation=True,
-                                backend=backend,
-                            )
-                    else:
-                        response = await dns.asyncquery.https(
-                            request, nameserver, timeout=timeout
-                        )
+                    response = await nameserver.async_query(
+                        request,
+                        timeout=timeout,
+                        source=source,
+                        source_port=source_port,
+                        max_size=tcp,
+                        backend=backend,
+                    )
                 except Exception as ex:
                     (_, done) = resolution.query_result(None, ex)
                     continue
diff --git a/dns/nameserver.py b/dns/nameserver.py
new file mode 100644 (file)
index 0000000..7de0abb
--- /dev/null
@@ -0,0 +1,315 @@
+from urllib.parse import urlparse
+
+from typing import Optional, Union
+
+import dns.asyncbackend
+import dns.asyncquery
+import dns.inet
+import dns.message
+import dns.query
+
+
+class Nameserver:
+    def __init__(self):
+        pass
+
+    def __str__(self):
+        raise NotImplementedError
+
+    def is_always_max_size(self) -> bool:
+        raise NotImplementedError
+
+    def answer_nameserver(self) -> str:
+        raise NotImplementedError
+
+    def answer_port(self) -> int:
+        raise NotImplementedError
+
+    def query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        raise NotImplementedError
+
+    async def async_query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        backend: dns.asyncbackend.Backend,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        raise NotImplementedError
+
+
+class AddressAndPortNameserver(Nameserver):
+    def __init__(self, address: str, port: int):
+        super().__init__()
+        self.address = address
+        self.port = port
+
+    def kind(self) -> str:
+        raise NotImplementedError
+
+    def is_always_max_size(self) -> bool:
+        return False
+
+    def __str__(self):
+        ns_kind = self.kind()
+        return f"{ns_kind}:{self.address}@{self.port}"
+
+    def answer_nameserver(self) -> str:
+        return self.address
+
+    def answer_port(self) -> int:
+        return self.port
+
+
+class Do53Nameserver(AddressAndPortNameserver):
+    def __init__(self, address: str, port: int = 53):
+        super().__init__(address, port)
+
+    def kind(self):
+        return "Do53"
+
+    def query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        if max_size:
+            response = dns.query.tcp(
+                request,
+                self.address,
+                timeout=timeout,
+                port=self.port,
+                source=source,
+                source_port=source_port,
+                one_rr_per_rrset=one_rr_per_rrset,
+                ignore_trailing=ignore_trailing,
+            )
+        else:
+            response = dns.query.udp(
+                request,
+                self.address,
+                timeout=timeout,
+                port=self.port,
+                source=source,
+                source_port=source_port,
+                raise_on_truncation=True,
+                one_rr_per_rrset=one_rr_per_rrset,
+                ignore_trailing=ignore_trailing,
+            )
+        return response
+
+    async def async_query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        backend: dns.asyncbackend.Backend,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        if max_size:
+            response = await dns.asyncquery.tcp(
+                request,
+                self.address,
+                timeout=timeout,
+                port=self.port,
+                source=source,
+                source_port=source_port,
+                backend=backend,
+                one_rr_per_rrset=one_rr_per_rrset,
+                ignore_trailing=ignore_trailing,
+            )
+        else:
+            response = await dns.asyncquery.udp(
+                request,
+                self.address,
+                timeout=timeout,
+                port=self.port,
+                source=source,
+                source_port=source_port,
+                raise_on_truncation=True,
+                backend=backend,
+                one_rr_per_rrset=one_rr_per_rrset,
+                ignore_trailing=ignore_trailing,
+            )
+        return response
+
+
+class DoHNameserver(Nameserver):
+    def __init__(self, url: str, bootstrap_address: Optional[str] = None):
+        super().__init__()
+        self.url = url
+        self.bootstrap_address = bootstrap_address
+
+    def is_always_max_size(self) -> bool:
+        return True
+
+    def __str__(self):
+        return self.url
+
+    def answer_nameserver(self) -> str:
+        return self.url
+
+    def answer_port(self) -> int:
+        port = urlparse(self.url).port
+        if port is None:
+            port = 443
+        return port
+
+    def query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool = False,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        return dns.query.https(
+            request,
+            self.url,
+            timeout=timeout,
+            bootstrap_address=self.bootstrap_address,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+        )
+
+    async def async_query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        backend: dns.asyncbackend.Backend,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        return await dns.asyncquery.https(
+            request,
+            self.url,
+            timeout=timeout,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+        )
+
+
+class DoTNameserver(AddressAndPortNameserver):
+    def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
+        super().__init__(address, port)
+        self.hostname = hostname
+
+    def kind(self):
+        return "DoT"
+
+    def query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool = False,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        return dns.query.tls(
+            request,
+            self.address,
+            port=self.port,
+            timeout=timeout,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+            server_hostname=self.hostname,
+        )
+
+    async def async_query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        backend: dns.asyncbackend.Backend,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        return await dns.asyncquery.tls(
+            request,
+            self.address,
+            port=self.port,
+            timeout=timeout,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+            server_hostname=self.hostname,
+        )
+
+
+class DoQNameserver(AddressAndPortNameserver):
+    def __init__(self, address: str, port: int = 853, verify: Union[bool, str] = True):
+        super().__init__(address, port)
+        self.verify = verify
+
+    def kind(self):
+        return "DoQ"
+
+    def query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool = False,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        return dns.query.quic(
+            request,
+            self.address,
+            port=self.port,
+            timeout=timeout,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+            verify=self.verify,
+        )
+
+    async def async_query(
+        self,
+        request: dns.message.QueryMessage,
+        timeout: float,
+        source: Optional[str],
+        source_port: int,
+        max_size: bool,
+        backend: dns.asyncbackend.Backend,
+        one_rr_per_rrset: bool = False,
+        ignore_trailing: bool = False,
+    ) -> dns.message.Message:
+        return await dns.asyncquery.quic(
+            request,
+            self.address,
+            port=self.port,
+            timeout=timeout,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+            verify=self.verify,
+        )
index 5ba8601e029a34a6f0b20330ca83c33173530300..4fc5bfd36b9c03a4edcbc9322cd47e1cb52bc063 100644 (file)
@@ -36,6 +36,7 @@ import dns.ipv4
 import dns.ipv6
 import dns.message
 import dns.name
+import dns.nameserver
 import dns.query
 import dns.rcode
 import dns.rdataclass
@@ -140,7 +141,11 @@ class YXDOMAIN(dns.exception.DNSException):
 
 
 ErrorTuple = Tuple[
-    Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]
+    Optional[str],
+    bool,
+    int,
+    Union[Exception, str],
+    Optional[dns.message.Message],
 ]
 
 
@@ -148,11 +153,7 @@ def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
     """Turn a resolution errors trace into a list of text."""
     texts = []
     for err in errors:
-        texts.append(
-            "Server {} {} port {} answered {}".format(
-                err[0], "TCP" if err[1] else "UDP", err[2], err[3]
-            )
-        )
+        texts.append("Server {} answered {}".format(err[0], err[3]))
     return texts
 
 
@@ -377,7 +378,7 @@ class Cache(CacheBase):
         now = time.time()
         if self.next_cleaning <= now:
             keys_to_delete = []
-            for (k, v) in self.data.items():
+            for k, v in self.data.items():
                 if v.expiration <= now:
                     keys_to_delete.append(k)
             for k in keys_to_delete:
@@ -609,11 +610,10 @@ class _Resolution:
         self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
         # Initialize other things to help analysis tools
         self.qname = dns.name.empty
-        self.nameservers: List[str] = []
-        self.current_nameservers: List[str] = []
+        self.nameservers: List[dns.nameserver.Nameserver] = []
+        self.current_nameservers: List[dns.nameserver.Nameserver] = []
         self.errors: List[ErrorTuple] = []
-        self.nameserver: Optional[str] = None
-        self.port = 0
+        self.nameserver: Optional[dns.nameserver.Nameserver] = None
         self.tcp_attempt = False
         self.retry_with_tcp = False
         self.request: Optional[dns.message.QueryMessage] = None
@@ -670,7 +670,9 @@ class _Resolution:
             if self.resolver.flags is not None:
                 request.flags = self.resolver.flags
 
-            self.nameservers = self.resolver.nameservers[:]
+            self.nameservers = self.resolver._enrich_nameservers(
+                self.resolver._nameservers
+            )
             if self.resolver.rotate:
                 random.shuffle(self.nameservers)
             self.current_nameservers = self.nameservers[:]
@@ -690,12 +692,13 @@ class _Resolution:
         #
         raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses)
 
-    def next_nameserver(self) -> Tuple[str, int, bool, float]:
+    def next_nameserver(self) -> Tuple[dns.nameserver.Nameserver, bool, float]:
         if self.retry_with_tcp:
             assert self.nameserver is not None
+            assert not self.nameserver.is_always_max_size()
             self.tcp_attempt = True
             self.retry_with_tcp = False
-            return (self.nameserver, self.port, True, 0)
+            return (self.nameserver, True, 0)
 
         backoff = 0.0
         if not self.current_nameservers:
@@ -707,11 +710,8 @@ class _Resolution:
             self.backoff = min(self.backoff * 2, 2)
 
         self.nameserver = self.current_nameservers.pop(0)
-        self.port = self.resolver.nameserver_ports.get(
-            self.nameserver, self.resolver.port
-        )
-        self.tcp_attempt = self.tcp
-        return (self.nameserver, self.port, self.tcp_attempt, backoff)
+        self.tcp_attempt = self.tcp or self.nameserver.is_always_max_size()
+        return (self.nameserver, self.tcp_attempt, backoff)
 
     def query_result(
         self, response: Optional[dns.message.Message], ex: Optional[Exception]
@@ -724,7 +724,13 @@ class _Resolution:
             # Exception during I/O or from_wire()
             assert response is None
             self.errors.append(
-                (self.nameserver, self.tcp_attempt, self.port, ex, response)
+                (
+                    str(self.nameserver),
+                    self.tcp_attempt,
+                    self.nameserver.answer_port(),
+                    ex,
+                    response,
+                )
             )
             if (
                 isinstance(ex, dns.exception.FormError)
@@ -752,12 +758,18 @@ class _Resolution:
                     self.rdtype,
                     self.rdclass,
                     response,
-                    self.nameserver,
-                    self.port,
+                    self.nameserver.answer_nameserver(),
+                    self.nameserver.answer_port(),
                 )
             except Exception as e:
                 self.errors.append(
-                    (self.nameserver, self.tcp_attempt, self.port, e, response)
+                    (
+                        str(self.nameserver),
+                        self.tcp_attempt,
+                        self.nameserver.answer_port(),
+                        e,
+                        response,
+                    )
                 )
                 # The nameserver is no good, take it out of the mix.
                 self.nameservers.remove(self.nameserver)
@@ -776,7 +788,13 @@ class _Resolution:
                 )
             except Exception as e:
                 self.errors.append(
-                    (self.nameserver, self.tcp_attempt, self.port, e, response)
+                    (
+                        str(self.nameserver),
+                        self.tcp_attempt,
+                        self.nameserver.answer_port(),
+                        e,
+                        response,
+                    )
                 )
                 # The nameserver is no good, take it out of the mix.
                 self.nameservers.remove(self.nameserver)
@@ -792,7 +810,13 @@ class _Resolution:
         elif rcode == dns.rcode.YXDOMAIN:
             yex = YXDOMAIN()
             self.errors.append(
-                (self.nameserver, self.tcp_attempt, self.port, yex, response)
+                (
+                    str(self.nameserver),
+                    self.tcp_attempt,
+                    self.nameserver.answer_port(),
+                    yex,
+                    response,
+                )
             )
             raise yex
         else:
@@ -804,9 +828,9 @@ class _Resolution:
                 self.nameservers.remove(self.nameserver)
             self.errors.append(
                 (
-                    self.nameserver,
+                    str(self.nameserver),
                     self.tcp_attempt,
-                    self.port,
+                    self.nameserver.answer_port(),
                     dns.rcode.to_text(rcode),
                     response,
                 )
@@ -840,6 +864,7 @@ class BaseResolver:
     retry_servfail: bool
     rotate: bool
     ndots: Optional[int]
+    _nameservers: List[Union[str, dns.nameserver.Nameserver]]
 
     def __init__(
         self, filename: str = "/etc/resolv.conf", configure: bool = True
@@ -868,7 +893,7 @@ class BaseResolver:
         self.domain = dns.name.Name(dns.name.from_text(socket.gethostname())[1:])
         if len(self.domain) == 0:
             self.domain = dns.name.root
-        self.nameservers = []
+        self._nameservers = []
         self.nameserver_ports = {}
         self.port = 53
         self.search = []
@@ -905,6 +930,7 @@ class BaseResolver:
 
         """
 
+        nameservers = []
         if isinstance(f, str):
             try:
                 cm: contextlib.AbstractContextManager = open(f)
@@ -924,7 +950,7 @@ class BaseResolver:
                     continue
 
                 if tokens[0] == "nameserver":
-                    self.nameservers.append(tokens[1])
+                    nameservers.append(tokens[1])
                 elif tokens[0] == "domain":
                     self.domain = dns.name.from_text(tokens[1])
                     # domain and search are exclusive
@@ -952,8 +978,11 @@ class BaseResolver:
                                 self.ndots = int(opt.split(":")[1])
                             except (ValueError, IndexError):
                                 pass
-        if len(self.nameservers) == 0:
+        if len(nameservers) == 0:
             raise NoResolverConfiguration("no nameservers")
+        # Assigning directly instead of appending means we invoke the
+        # setter logic, with additonal checking and enrichment.
+        self.nameservers = nameservers
 
     def read_registry(self) -> None:
         """Extract resolver configuration from the Windows registry."""
@@ -1088,34 +1117,60 @@ class BaseResolver:
 
         self.flags = flags
 
-    @property
-    def nameservers(self) -> List[str]:
-        return self._nameservers
-
-    @nameservers.setter
-    def nameservers(self, nameservers: List[str]) -> None:
-        """
-        *nameservers*, a ``list`` of nameservers.
-
-        Raises ``ValueError`` if *nameservers* is anything other than a
-        ``list``.
-        """
+    def _enrich_nameservers(
+        self, nameservers: List[Union[str, dns.nameserver.Nameserver]]
+    ) -> List[dns.nameserver.Nameserver]:
+        enriched_nameservers = []
         if isinstance(nameservers, list):
             for nameserver in nameservers:
-                if not dns.inet.is_address(nameserver):
+                enriched_nameserver: dns.nameserver.Nameserver
+                if isinstance(nameserver, dns.nameserver.Nameserver):
+                    enriched_nameserver = nameserver
+                elif dns.inet.is_address(nameserver):
+                    port = self.nameserver_ports.get(nameserver, self.port)
+                    enriched_nameserver = dns.nameserver.Do53Nameserver(
+                        nameserver, port
+                    )
+                else:
                     try:
                         if urlparse(nameserver).scheme != "https":
                             raise NotImplementedError
                     except Exception:
                         raise ValueError(
-                            f"nameserver {nameserver} is not an "
-                            "IP address or valid https URL"
+                            f"nameserver {nameserver} is not a "
+                            "dns.nameserver.Nameserver instance or text form, "
+                            "IP address, nor a valid https URL"
                         )
-            self._nameservers = nameservers
+                    enriched_nameserver = dns.nameserver.DoHNameserver(nameserver)
+                enriched_nameservers.append(enriched_nameserver)
         else:
             raise ValueError(
-                "nameservers must be a list (not a {})".format(type(nameservers))
+                "nameservers must be a list or tuple (not a {})".format(
+                    type(nameservers)
+                )
             )
+        return enriched_nameservers
+
+    @property
+    def nameservers(
+        self,
+    ) -> List[Union[str, dns.nameserver.Nameserver]]:
+        return self._nameservers
+
+    @nameservers.setter
+    def nameservers(
+        self, nameservers: List[Union[str, dns.nameserver.Nameserver]]
+    ) -> None:
+        """
+        *nameservers*, a ``list`` of nameservers, where a nameserver is either
+        a string interpretable as a nameserver, or a ``dns.nameserver.Nameserver``
+        instance.
+
+        Raises ``ValueError`` if *nameservers* is not a list of nameservers.
+        """
+        # We just call _enrich_nameservers() for checking
+        self._enrich_nameservers(nameservers)
+        self._nameservers = nameservers
 
 
 class Resolver(BaseResolver):
@@ -1200,33 +1255,18 @@ class Resolver(BaseResolver):
             assert request is not None  # needed for type checking
             done = False
             while not done:
-                (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+                (nameserver, tcp, backoff) = resolution.next_nameserver()
                 if backoff:
                     time.sleep(backoff)
                 timeout = self._compute_timeout(start, lifetime, resolution.errors)
                 try:
-                    if dns.inet.is_address(nameserver):
-                        if tcp:
-                            response = dns.query.tcp(
-                                request,
-                                nameserver,
-                                timeout=timeout,
-                                port=port,
-                                source=source,
-                                source_port=source_port,
-                            )
-                        else:
-                            response = dns.query.udp(
-                                request,
-                                nameserver,
-                                timeout=timeout,
-                                port=port,
-                                source=source,
-                                source_port=source_port,
-                                raise_on_truncation=True,
-                            )
-                    else:
-                        response = dns.query.https(request, nameserver, timeout=timeout)
+                    response = nameserver.query(
+                        request,
+                        timeout=timeout,
+                        source=source,
+                        source_port=source_port,
+                        max_size=tcp,
+                    )
                 except Exception as ex:
                     (_, done) = resolution.query_result(None, ex)
                     continue
@@ -1357,7 +1397,6 @@ def resolve(
     lifetime: Optional[float] = None,
     search: Optional[bool] = None,
 ) -> Answer:  # pragma: no cover
-
     """Query nameservers to find the answer to the question.
 
     This is a convenience function that uses the default resolver
index 5bf01e379b448830d3087070108fa829c25eddad..21c6c466b402f1cb3ee0252f5d5391e3df1568ea 100644 (file)
@@ -12,11 +12,12 @@ The dns.resolver.Resolver and dns.resolver.Answer Classes
 
    .. attribute:: nameservers
 
-      A ``list`` of ``str``, each item containing an IPv4 or IPv6 address.
+      A ``list`` of ``str`` or ``dns.nameserver.Nameserver``.  A string may be
+      an IPv4 or IPv6 address, or an https URL.
 
-      This field is planned to become a property in dnspython 2.4.  Writing to this
-      field other than by direct assignment is deprecated, and so is depending on the
-      mutability and form of the iterable returned when it is read.
+      This field is actually a property, and returns a tuple as of dnspython 2.4.
+      Assigning this this field converts any strings into
+      ``dns.nameserver.Nameserver`` instances.
 
    .. attribute:: search
 
diff --git a/doc/resolver-nameserver.rst b/doc/resolver-nameserver.rst
new file mode 100644 (file)
index 0000000..06f4a1b
--- /dev/null
@@ -0,0 +1,46 @@
+.. _resolver-nameserver:
+
+The dns.nameserver.Nameserver Classes
+-------------------------------------
+
+The ``dns.nameserver.Nameserver`` abstract class represents a remote recursive resolver,
+and is used by the stub resolver to answer queries.
+
+.. autoclass:: dns.nameserver.Nameserver
+   :members:
+
+The dns.nameserver.Do53Nameserver Class
+---------------------------------------
+
+The ``dns.nameserver.Do53Nameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make regular port 53 (Do53) DNS queries to a recursive server.
+
+.. autoclass:: dns.nameserver.Do53Nameserver
+   :members:
+
+The dns.nameserver.DoTNameserver Class
+---------------------------------------
+
+The ``dns.nameserver.DoTNameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make DNS-over-TLS (DoT) queries to a recursive server.
+
+.. autoclass:: dns.nameserver.DoTNameserver
+   :members:
+
+The dns.nameserver.DoHNameserver Class
+---------------------------------------
+
+The ``dns.nameserver.DoHNameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make DNS-over-HTTPS (DoH) queries to a recursive server.
+
+.. autoclass:: dns.nameserver.DoHNameserver
+   :members:
+
+The dns.nameserver.DoQNameserver Class
+---------------------------------------
+
+The ``dns.nameserver.DoQNameserver`` class is a ``dns.nameserver.Nameserver`` class use
+to make DNS-over-QUIC (DoQ) queries to a recursive server.
+
+.. autoclass:: dns.nameserver.DoQNameserver
+   :members:
index e9cf7b2e6c57f65916152c0f0ccf204109229809..138ac3ef551f07e8d79797228d5ff7d3eef2d286 100644 (file)
@@ -13,6 +13,7 @@ be used simply by setting the *nameservers* attribute.
 .. toctree::
 
    resolver-class
+   resolver-nameserver
    resolver-functions
    resolver-caching
    resolver-override
index 54d3847bb879602ee5c9c7b882bf28fdda43fe72..95fa691e44a149eb52ca0252c9ebad6f46cca72e 100644 (file)
@@ -6,6 +6,12 @@ What's New in dnspython
 2.4.0 (in development)
 ----------------------
 
+* The stub resolver now uses instances of ``dns.nameserver.Nameserver`` to represent
+  remote recursive resolvers, and can communicate using
+  DNS over port 53, HTTPS, TLS, and QUIC.  In additional to being able to specify
+  an IPv4, IPv6, or HTTPS URL as a nameserver, instances of ``dns.nameserver.Nameserver``
+  are now permitted.
+
 2.3.0
 -----
 
index 1703a7fb9bafe03d48a92b5bb7ec100f46b45b40..deb52f7a1d17404ec0f5891f7edc91692e303881 100644 (file)
@@ -60,7 +60,7 @@ coverage = "^7.0"
 twine = "^4.0.0"
 wheel = "^0.38.1"
 pylint = "^2.7.4"
-mypy = ">=0.940"
+mypy = ">=1.0.1"
 black = "^23.1.0"
 
 [tool.poetry.extras]
index d2819a122b62e5c7893a7155eaa06b48589656f9..d8bdb2c9152b1421a9fa9b4c0b879332ea17bcae 100644 (file)
@@ -222,8 +222,8 @@ class ResolutionTestCase(unittest.TestCase):
 
     def test_next_request_rotate(self):
         self.resolver.rotate = True
-        order1 = ["10.0.0.1", "10.0.0.2"]
-        order2 = ["10.0.0.2", "10.0.0.1"]
+        order1 = ["Do53:10.0.0.1@53", "Do53:10.0.0.2@53"]
+        order2 = ["Do53:10.0.0.2@53", "Do53:10.0.0.1@53"]
         seen1 = False
         seen2 = False
         # We're not interested in testing the randomness, but we'd
@@ -235,9 +235,11 @@ class ResolutionTestCase(unittest.TestCase):
                 self.resolver, self.qname, "A", "IN", False, True, False
             )
             self.resn.next_request()
-            if self.resn.nameservers == order1:
+            text_form = [str(n) for n in self.resn.nameservers]
+            print(text_form)
+            if text_form == order1:
                 seen1 = True
-            elif self.resn.nameservers == order2:
+            elif text_form == order2:
                 seen2 = True
             else:
                 raise ValueError  # should not happen!
@@ -264,68 +266,71 @@ class ResolutionTestCase(unittest.TestCase):
 
     def test_next_nameserver_udp(self):
         (request, answer) = self.resn.next_request()
-        (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver1 in self.resolver.nameservers)
-        self.assertEqual(port, 53)
+        (nameserver1, tcp, backoff) = self.resn.next_nameserver()
+        self.assertEqual(nameserver1.port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
-        (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver2 in self.resolver.nameservers)
+        (nameserver2, tcp, backoff) = self.resn.next_nameserver()
         self.assertTrue(nameserver2 != nameserver1)
-        self.assertEqual(port, 53)
+        self.assertEqual(nameserver2.port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
-        (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
+        (nameserver3, tcp, backoff) = self.resn.next_nameserver()
         self.assertTrue(nameserver3 is nameserver1)
-        self.assertEqual(port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.1)
-        (nameserver4, port, tcp, backoff) = self.resn.next_nameserver()
+        (nameserver4, tcp, backoff) = self.resn.next_nameserver()
         self.assertTrue(nameserver4 is nameserver2)
-        self.assertEqual(port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
-        (nameserver5, port, tcp, backoff) = self.resn.next_nameserver()
+        (nameserver5, tcp, backoff) = self.resn.next_nameserver()
         self.assertTrue(nameserver5 is nameserver1)
-        self.assertEqual(port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.2)
 
     def test_next_nameserver_retry_with_tcp(self):
         (request, answer) = self.resn.next_request()
-        (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver1 in self.resolver.nameservers)
-        self.assertEqual(port, 53)
+        (nameserver1, tcp, backoff) = self.resn.next_nameserver()
+        self.assertEqual(nameserver1.port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
         self.resn.retry_with_tcp = True
-        (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
+        (nameserver2, tcp, backoff) = self.resn.next_nameserver()
         self.assertTrue(nameserver2 is nameserver1)
-        self.assertEqual(port, 53)
         self.assertTrue(tcp)
         self.assertEqual(backoff, 0.0)
-        (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver3 in self.resolver.nameservers)
+        (nameserver3, tcp, backoff) = self.resn.next_nameserver()
         self.assertTrue(nameserver3 != nameserver1)
-        self.assertEqual(port, 53)
+        self.assertEqual(nameserver3.port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
 
     def test_next_nameserver_no_nameservers(self):
         (request, answer) = self.resn.next_request()
-        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (nameserver, _, _) = self.resn.next_nameserver()
         self.resn.nameservers.remove(nameserver)
-        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (nameserver, _, _) = self.resn.next_nameserver()
         self.resn.nameservers.remove(nameserver)
 
         def bad():
-            (nameserver, _, _, _) = self.resn.next_nameserver()
+            (nameserver, _, _) = self.resn.next_nameserver()
 
         self.assertRaises(dns.resolver.NoNameservers, bad)
 
+    def test_next_nameserver_max_size_nameserver(self):
+        # A query to a nameserver that always supports a maximum size query
+        # always counts as a "tcp attempt" for the state machine
+        self.resolver.nameservers = ["https://127.0.0.1:443/bogus"]
+        (_, _) = self.resn.next_request()
+        (nameserver, tcp_attempt, _) = self.resn.next_nameserver()
+        print(nameserver)
+        assert tcp_attempt
+
     def test_query_result_nameserver_removing_exceptions(self):
         # add some nameservers so we have enough to remove :)
-        self.resolver.nameservers.extend(["10.0.0.3", "10.0.0.4"])
+        new_nameservers = list(self.resolver.nameservers[:])
+        new_nameservers.extend(["10.0.0.3", "10.0.0.4"])
+        self.resolver.nameservers = new_nameservers
         (request, _) = self.resn.next_request()
         exceptions = [
             dns.exception.FormError(),
@@ -334,7 +339,7 @@ class ResolutionTestCase(unittest.TestCase):
             dns.message.Truncated(),
         ]
         for i in range(4):
-            (nameserver, _, _, _) = self.resn.next_nameserver()
+            (nameserver, _, _) = self.resn.next_nameserver()
             if i == 3:
                 # Truncated is only bad if we're doing TCP, make it look
                 # like that's the case
@@ -351,7 +356,7 @@ class ResolutionTestCase(unittest.TestCase):
         # test_query_result_nameserver_removing_exceptions(), we should
         # not remove any nameservers and just continue resolving.
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         nameservers = self.resn.nameservers[:]
         (answer, done) = self.resn.query_result(None, dns.exception.Timeout())
         self.assertTrue(answer is None)
@@ -360,7 +365,7 @@ class ResolutionTestCase(unittest.TestCase):
 
     def test_query_result_retry_with_tcp(self):
         (request, _) = self.resn.next_request()
-        (nameserver, _, tcp, _) = self.resn.next_nameserver()
+        (nameserver, tcp, _) = self.resn.next_nameserver()
         self.assertFalse(tcp)
         (answer, done) = self.resn.query_result(None, dns.message.Truncated())
         self.assertTrue(answer is None)
@@ -374,7 +379,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_address_response(q)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertFalse(answer is None)
         self.assertTrue(done)
@@ -386,7 +391,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_address_response(q)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertFalse(answer is None)
         cache_answer = self.resolver.cache.get(
@@ -398,7 +403,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_negative_response(q)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
 
         def bad():
             (answer, done) = self.resn.query_result(r, None)
@@ -409,7 +414,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_negative_response(q, True)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertTrue(answer is None)
         self.assertTrue(done)
@@ -419,7 +424,7 @@ class ResolutionTestCase(unittest.TestCase):
         r = self.make_address_response(q)
         r.set_rcode(dns.rcode.NXDOMAIN)
         (_, _) = self.resn.next_request()
-        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (nameserver, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertIsNone(answer)
         self.assertFalse(done)
@@ -429,7 +434,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_long_chain_response(q, 15)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertIsNotNone(answer)
         self.assertTrue(done)
@@ -438,7 +443,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_long_chain_response(q, 16)
         (_, _) = self.resn.next_request()
-        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (nameserver, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertIsNone(answer)
         self.assertFalse(done)
@@ -449,7 +454,7 @@ class ResolutionTestCase(unittest.TestCase):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_negative_response(q, True)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertTrue(answer is None)
         self.assertTrue(done)
@@ -463,7 +468,7 @@ class ResolutionTestCase(unittest.TestCase):
         r = self.make_address_response(q)
         r.set_rcode(dns.rcode.YXDOMAIN)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
 
         def bad():
             (answer, done) = self.resn.query_result(r, None)
@@ -475,7 +480,7 @@ class ResolutionTestCase(unittest.TestCase):
         r = self.make_address_response(q)
         r.set_rcode(dns.rcode.SERVFAIL)
         (_, _) = self.resn.next_request()
-        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (nameserver, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertTrue(answer is None)
         self.assertFalse(done)
@@ -487,7 +492,7 @@ class ResolutionTestCase(unittest.TestCase):
         r = self.make_address_response(q)
         r.set_rcode(dns.rcode.SERVFAIL)
         (_, _) = self.resn.next_request()
-        (_, _, _, _) = self.resn.next_nameserver()
+        (_, _, _) = self.resn.next_nameserver()
         nameservers = self.resn.nameservers[:]
         (answer, done) = self.resn.query_result(r, None)
         self.assertTrue(answer is None)
@@ -499,7 +504,7 @@ class ResolutionTestCase(unittest.TestCase):
         r = self.make_address_response(q)
         r.set_rcode(dns.rcode.REFUSED)
         (_, _) = self.resn.next_request()
-        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (nameserver, _, _) = self.resn.next_nameserver()
         (answer, done) = self.resn.query_result(r, None)
         self.assertTrue(answer is None)
         self.assertFalse(done)
index d21127d79d48f4ca8c6d120b57ee381b04c93d53..c1a97bf8992b8a2b13f49a52ede3ee024aa203b5 100644 (file)
@@ -27,6 +27,7 @@ from unittest.mock import patch
 import dns.e164
 import dns.message
 import dns.name
+import dns.quic
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
@@ -717,6 +718,27 @@ class LiveResolverTests(unittest.TestCase):
         answer2 = res.resolve("dns.google.", "A")
         self.assertIs(answer2, answer1)
 
+    @unittest.skipIf(not tests.util.have_ipv4(), "IPv4 not reachable")
+    def testTLSNameserver(self):
+        res = dns.resolver.Resolver(configure=False)
+        res.nameservers = [dns.nameserver.DoTNameserver("8.8.8.8", 853)]
+        answer = res.resolve("dns.google.", "A")
+        seen = set([rdata.address for rdata in answer])
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+
+    @unittest.skipIf(
+        not (tests.util.have_ipv4() and dns.quic.have_quic),
+        "IPv4 not reachable or QUIC not available",
+    )
+    def testQuicNameserver(self):
+        res = dns.resolver.Resolver(configure=False)
+        res.nameservers = [dns.nameserver.DoQNameserver("94.140.14.14", 784)]
+        answer = res.resolve("dns.adguard.com.", "A")
+        seen = set([rdata.address for rdata in answer])
+        self.assertIn("94.140.14.14", seen)
+        self.assertIn("94.140.15.15", seen)
+
     def testCanonicalNameNoCNAME(self):
         cname = dns.name.from_text("www.google.com")
         self.assertEqual(dns.resolver.canonical_name("www.google.com"), cname)
@@ -772,7 +794,6 @@ if hasattr(selectors, "PollSelector"):
 
 
 class NXDOMAINExceptionTestCase(unittest.TestCase):
-
     # pylint: disable=broad-except
 
     def test_nxdomain_compatible(self):
@@ -951,6 +972,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
             "1.2.3.4",
             1234,
             (1, 2, 3, 4),
+            (),
             {"invalid": "nameserver"},
         ]
         for invalid_nameserver in invalid_nameservers:
@@ -1123,7 +1145,7 @@ def testResolverTimeout():
             errors = e.kwargs["errors"]
             assert len(errors) > 1
             for error in errors:
-                assert error[0] == na.udp_address[0]  # address
+                assert str(error[0]) == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}"
                 assert not error[1]  # not TCP
                 assert error[2] == na.udp_address[1]  # port
                 assert isinstance(error[3], dns.exception.Timeout)  # exception
@@ -1145,7 +1167,7 @@ def testResolverNoNameservers():
             errors = e.kwargs["errors"]
             assert len(errors) == 1
             for error in errors:
-                assert error[0] == na.udp_address[0]  # address
+                assert error[0] == f"Do53:{na.udp_address[0]}@{na.udp_address[1]}"
                 assert not error[1]  # not TCP
                 assert error[2] == na.udp_address[1]  # port
                 assert error[3] == "FORMERR"