]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add dns.resolver.resolve_name(). (#903)
authorBrian Wellington <bwelling@xbill.org>
Sat, 11 Mar 2023 02:12:02 +0000 (18:12 -0800)
committerGitHub <noreply@github.com>
Sat, 11 Mar 2023 02:12:02 +0000 (18:12 -0800)
* Add dns.resolver.resolve_name().

* Add missing type annotations.

* Add async resolve_name().

* Replace List[Answer] with HostAnswers.

* Switch addresses_and_families() tuple order

* Fix comment.

dns/asyncresolver.py
dns/resolver.py
tests/test_async.py
tests/test_resolver.py

index 9ba84de07b69cac800b67cc7d787fd1a79ad6b86..55ef0fb75254304c099135635439795fef87d064 100644 (file)
@@ -19,6 +19,7 @@
 
 from typing import Any, Dict, Optional, Union
 
+import socket
 import time
 
 import dns.asyncbackend
@@ -135,6 +136,71 @@ class Resolver(dns.resolver.BaseResolver):
             dns.reversename.from_address(ipaddr), *args, **modified_kwargs
         )
 
+
+    async def resolve_name(
+        self,
+        name: Union[dns.name.Name, str],
+        family: int = socket.AF_UNSPEC,
+        **kwargs: Any
+    ) -> dns.resolver.HostAnswers:
+        """Use an asynchronous resolver to query for address records.
+
+        This utilizes the resolve() method to perform A and/or AAAA lookups on
+        the specified name.
+
+        *qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
+
+        *family*, an ``int``, the address family.  If socket.AF_UNSPEC
+        (the default), both A and AAAA records will be retrieved.
+
+        All other arguments that can be passed to the resolve() function
+        except for rdtype and rdclass are also supported by this
+        function.
+        """
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs: Dict[str, Any] = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs.pop("rdtype", None)
+        modified_kwargs["rdclass"] = dns.rdataclass.IN
+
+        if family == socket.AF_INET:
+            v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs)
+            return dns.resolver.HostAnswers.make(v4=v4)
+        elif family == socket.AF_INET6:
+            v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
+            return dns.resolver.HostAnswers.make(v6=v6)
+        elif family != socket.AF_UNSPEC:
+            raise NotImplementedError(f"unknown address family {family}")
+
+        raise_on_no_answer = modified_kwargs.pop('raise_on_no_answer', True)
+        lifetime = modified_kwargs.pop('lifetime', None)
+        start = time.time()
+        v6 = await self.resolve(name, dns.rdatatype.AAAA,
+                                raise_on_no_answer=False,
+                                lifetime=self._compute_timeout(start, lifetime),
+                                **modified_kwargs)
+        # Note that setting name ensures we query the same name
+        # for A as we did for AAAA.  (This is just in case search lists
+        # are active by default in the resolver configuration and
+        # we might be talking to a server that says NXDOMAIN when it
+        # wants to say NOERROR no data.
+        name = v6.qname
+        v4 = await self.resolve(name, dns.rdatatype.A,
+                                raise_on_no_answer=False,
+                                lifetime=self._compute_timeout(start, lifetime),
+                                **modified_kwargs)
+        answers = dns.resolver.HostAnswers.make(
+            v6=v6,
+            v4=v4,
+            add_empty=not raise_on_no_answer
+        )
+        if not answers:
+            raise NoAnswer(response=v6.response)
+        return answers
+
+
     # pylint: disable=redefined-outer-name
 
     async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
@@ -228,6 +294,20 @@ async def resolve_address(
     return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
 
 
+async def resolve_name(
+    name: Union[dns.name.Name, str],
+    family: int = socket.AF_UNSPEC,
+    **kwargs: Any
+) -> dns.resolver.HostAnswers:
+    """Use a resolver to asynchronously query for address records.
+
+    See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more
+    information on the parameters.
+    """
+
+    return await get_default_resolver().resolve_name(name, family, **kwargs)
+
+
 async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
     """Determine the canonical name of *name*.
 
index 63a57eee3ac29df18c03dbe16806d7fec7f0320d..2675b49900b36b2af91b26c5604dccd9d4ed8161 100644 (file)
@@ -17,7 +17,7 @@
 
 """DNS stub resolver."""
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from urllib.parse import urlparse
 import contextlib
@@ -310,6 +310,61 @@ class Answer:
         del self.rrset[i]
 
 
+class Answers(dict):
+    """A dict of DNS stub resolver answers, indexed by type."""
+    pass
+
+class HostAnswers(Answers):
+    """A dict of DNS stub resolver answers to a host name lookup, indexed by
+    type.
+    """
+
+    @classmethod
+    def make(
+        self,
+        v6 : Optional[Answer] = None,
+        v4 : Optional[Answer] = None,
+        add_empty : bool = True
+    ) -> 'HostAnswers':
+        answers = HostAnswers()
+        if v6 is not None and (add_empty or v6.rrset):
+            answers[dns.rdatatype.AAAA] = v6
+        if v4 is not None and (add_empty or v4.rrset):
+            answers[dns.rdatatype.A] = v4
+        return answers
+
+
+    # Returns pairs of (address, family) from this result, potentiallys
+    # filtering by address family.
+    def addresses_and_families(
+        self,
+        family : int = socket.AF_UNSPEC
+    ) -> Iterator[Tuple[str, int]]:
+        if family == socket.AF_UNSPEC:
+            yield from self.addresses_and_families(socket.AF_INET6)
+            yield from self.addresses_and_families(socket.AF_INET)
+            return
+        elif family == socket.AF_INET6:
+            answer = self.get(dns.rdatatype.AAAA)
+        elif family == socket.AF_INET:
+            answer = self.get(dns.rdatatype.A)
+        else:
+            raise NotImplementedError(f"unknown address family {family}")
+        if answer:
+            for rdata in answer:
+                yield (rdata.address, family)
+
+    # Returns addresses from this result, potentially filtering by
+    # address family.
+    def addresses(self, family : int = socket.AF_UNSPEC) -> Iterator[str]:
+        return (pair[0] for pair in self.addresses_and_families(family))
+
+    # Returns the canonical name from this result.
+    def canonical_name(self) -> dns.name.Name:
+        answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A))
+        return answer.canonical_name
+
+
 class CacheStatistics:
     """Cache Statistics"""
 
@@ -1343,6 +1398,66 @@ class Resolver(BaseResolver):
             dns.reversename.from_address(ipaddr), *args, **modified_kwargs
         )
 
+    def resolve_name(
+        self,
+        name: Union[dns.name.Name, str],
+        family: int = socket.AF_UNSPEC,
+        **kwargs: Any
+    ) -> HostAnswers:
+        """Use a resolver to query for address records.
+
+        This utilizes the resolve() method to perform A and/or AAAA lookups on
+        the specified name.
+
+        *qname*, a ``dns.name.Name`` or ``str``, the name to resolve.
+
+        *family*, an ``int``, the address family.  If socket.AF_UNSPEC
+        (the default), both A and AAAA records will be retrieved.
+
+        All other arguments that can be passed to the resolve() function
+        except for rdtype and rdclass are also supported by this
+        function.
+        """
+        # We make a modified kwargs for type checking happiness, as otherwise
+        # we get a legit warning about possibly having rdtype and rdclass
+        # in the kwargs more than once.
+        modified_kwargs: Dict[str, Any] = {}
+        modified_kwargs.update(kwargs)
+        modified_kwargs.pop("rdtype", None)
+        modified_kwargs["rdclass"] = dns.rdataclass.IN
+
+        if family == socket.AF_INET:
+            v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs)
+            return HostAnswers.make(v4=v4)
+        elif family == socket.AF_INET6:
+            v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs)
+            return HostAnswers.make(v6=v6)
+        elif family != socket.AF_UNSPEC:
+            raise NotImplementedError(f"unknown address family {family}")
+
+        raise_on_no_answer = modified_kwargs.pop('raise_on_no_answer', True)
+        lifetime = modified_kwargs.pop('lifetime', None)
+        start = time.time()
+        v6 = self.resolve(name, dns.rdatatype.AAAA,
+                          raise_on_no_answer=False,
+                          lifetime=self._compute_timeout(start, lifetime),
+                          **modified_kwargs)
+        # Note that setting name ensures we query the same name
+        # for A as we did for AAAA.  (This is just in case search lists
+        # are active by default in the resolver configuration and
+        # we might be talking to a server that says NXDOMAIN when it
+        # wants to say NOERROR no data.
+        name = v6.qname
+        v4 = self.resolve(name, dns.rdatatype.A,
+                          raise_on_no_answer=False,
+                          lifetime=self._compute_timeout(start, lifetime),
+                          **modified_kwargs)
+        answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer)
+        if not answers:
+            raise NoAnswer(response=v6.response)
+        return answers
+
+
     # pylint: disable=redefined-outer-name
 
     def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name:
@@ -1468,6 +1583,20 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer:
     return get_default_resolver().resolve_address(ipaddr, *args, **kwargs)
 
 
+def resolve_name(
+    name: Union[dns.name.Name, str],
+    family: int = socket.AF_UNSPEC,
+    **kwargs: Any
+) -> HostAnswers:
+    """Use a resolver to query for address records.
+
+    See ``dns.resolver.Resolver.resolve_name`` for more information on the
+    parameters.
+    """
+
+    return get_default_resolver().resolve_name(name, family, **kwargs)
+
+
 def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
     """Determine the canonical name of *name*.
 
@@ -1606,8 +1735,7 @@ def _getaddrinfo(
         )
     if host is None and service is None:
         raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
-    v6addrs = []
-    v4addrs = []
+    addrs = []
     canonical_name = None  # pylint: disable=redefined-outer-name
     # Is host None or an address literal?  If so, use the system's
     # getaddrinfo().
@@ -1623,24 +1751,9 @@ def _getaddrinfo(
         pass
     # Something needs resolution!
     try:
-        if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
-            v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False)
-            # Note that setting host ensures we query the same name
-            # for A as we did for AAAA.  (This is just in case search lists
-            # are active by default in the resolver configuration and
-            # we might be talking to a server that says NXDOMAIN when it
-            # wants to say NOERROR no data.
-            host = v6.qname
-            canonical_name = v6.canonical_name.to_text(True)
-            if v6.rrset is not None:
-                for rdata in v6.rrset:
-                    v6addrs.append(rdata.address)
-        if family == socket.AF_INET or family == socket.AF_UNSPEC:
-            v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False)
-            canonical_name = v4.canonical_name.to_text(True)
-            if v4.rrset is not None:
-                for rdata in v4.rrset:
-                    v4addrs.append(rdata.address)
+        answers = _resolver.resolve_name(host, family)
+        addrs = answers.addresses_and_families()
+        canonical_name = answers.canonical_name().to_text(True)
     except dns.resolver.NXDOMAIN:
         raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
     except Exception:
@@ -1672,20 +1785,11 @@ def _getaddrinfo(
         cname = canonical_name
     else:
         cname = ""
-    if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
-        for addr in v6addrs:
-            for socktype in socktypes:
-                for proto in _protocols_for_socktype[socktype]:
-                    tuples.append(
-                        (socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0))
-                    )
-    if family == socket.AF_INET or family == socket.AF_UNSPEC:
-        for addr in v4addrs:
-            for socktype in socktypes:
-                for proto in _protocols_for_socktype[socktype]:
-                    tuples.append(
-                        (socket.AF_INET, socktype, proto, cname, (addr, port))
-                    )
+    for addr, af in addrs:
+        for socktype in socktypes:
+            for proto in _protocols_for_socktype[socktype]:
+                addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
+                tuples.append((af, socktype, proto, cname, addr_tuple))
     if len(tuples) == 0:
         raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
     return tuples
index 65f1ed9bfca20effce2f71906c45e68e1f68368c..52ba2e28c09419d7193a0fc0e08091506595b0de 100644 (file)
@@ -187,6 +187,48 @@ class AsyncTests(unittest.TestCase):
         dnsgoogle = dns.name.from_text("dns.google.")
         self.assertEqual(answer[0].target, dnsgoogle)
 
+    def testResolveName(self):
+        async def run1():
+            return await dns.asyncresolver.resolve_name("dns.google.")
+
+        answers = self.async_run(run1)
+        seen = set(answers.addresses())
+        self.assertEqual(len(seen), 4)
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+        self.assertIn("2001:4860:4860::8844", seen)
+        self.assertIn("2001:4860:4860::8888", seen)
+
+        async def run2():
+            return await dns.asyncresolver.resolve_name("dns.google.", socket.AF_INET)
+
+        answers = self.async_run(run2)
+        seen = set(answers.addresses())
+        self.assertEqual(len(seen), 2)
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+
+        async def run3():
+            return await dns.asyncresolver.resolve_name("dns.google.", socket.AF_INET6)
+
+        answers = self.async_run(run3)
+        seen = set(answers.addresses())
+        self.assertEqual(len(seen), 2)
+        self.assertIn("2001:4860:4860::8844", seen)
+        self.assertIn("2001:4860:4860::8888", seen)
+
+        async def run4():
+            await dns.asyncresolver.resolve_name("nxdomain.dnspython.org")
+
+        with self.assertRaises(dns.resolver.NXDOMAIN):
+            self.async_run(run4)
+
+        async def run5():
+            await dns.asyncresolver.resolve_name(dns.reversename.from_address("8.8.8.8"))
+
+        with self.assertRaises(dns.resolver.NoAnswer):
+            self.async_run(run5)
+
     def testCanonicalNameNoCNAME(self):
         cname = dns.name.from_text("www.google.com")
 
index c1a97bf8992b8a2b13f49a52ede3ee024aa203b5..c73451e65ef1754f0b3ab92f6328a67ec225f1ce 100644 (file)
@@ -31,6 +31,7 @@ import dns.quic
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
+import dns.reversename
 import dns.tsig
 import dns.tsigkeyring
 import tests.util
@@ -665,6 +666,33 @@ class LiveResolverTests(unittest.TestCase):
         dnsgoogle = dns.name.from_text("dns.google.")
         self.assertEqual(answer[0].target, dnsgoogle)
 
+    def testResolveName(self):
+        answers = dns.resolver.resolve_name("dns.google.")
+        seen = set(answers.addresses())
+        self.assertEqual(len(seen), 4)
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+        self.assertIn("2001:4860:4860::8844", seen)
+        self.assertIn("2001:4860:4860::8888", seen)
+
+        answers = dns.resolver.resolve_name("dns.google.", socket.AF_INET)
+        seen = set(answers.addresses())
+        self.assertEqual(len(seen), 2)
+        self.assertIn("8.8.8.8", seen)
+        self.assertIn("8.8.4.4", seen)
+
+        answers = dns.resolver.resolve_name("dns.google.", socket.AF_INET6)
+        seen = set(answers.addresses())
+        self.assertEqual(len(seen), 2)
+        self.assertIn("2001:4860:4860::8844", seen)
+        self.assertIn("2001:4860:4860::8888", seen)
+
+        with self.assertRaises(dns.resolver.NXDOMAIN):
+            dns.resolver.resolve_name("nxdomain.dnspython.org")
+
+        with self.assertRaises(dns.resolver.NoAnswer):
+            dns.resolver.resolve_name(dns.reversename.from_address("8.8.8.8"))
+
     @patch.object(dns.message.Message, "use_edns")
     def testResolveEdnsOptions(self, message_use_edns_mock):
         resolver = dns.resolver.Resolver()