From 5ee0d6714a0bc0303491a253d86146b342c27918 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 10 Mar 2023 18:12:02 -0800 Subject: [PATCH] Add dns.resolver.resolve_name(). (#903) * Add dns.resolver.resolve_name(). * Add missing type annotations. * Add async resolve_name(). * Replace List[Answer] with HostAnswers. * Switch addresses_and_families() tuple order * Fix comment. --- dns/asyncresolver.py | 80 +++++++++++++++++++ dns/resolver.py | 174 ++++++++++++++++++++++++++++++++--------- tests/test_async.py | 42 ++++++++++ tests/test_resolver.py | 28 +++++++ 4 files changed, 289 insertions(+), 35 deletions(-) diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py index 9ba84de0..55ef0fb7 100644 --- a/dns/asyncresolver.py +++ b/dns/asyncresolver.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, Union +import socket import time import dns.asyncbackend @@ -135,6 +136,71 @@ class Resolver(dns.resolver.BaseResolver): dns.reversename.from_address(ipaddr), *args, **modified_kwargs ) + + async def resolve_name( + self, + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any + ) -> dns.resolver.HostAnswers: + """Use an asynchronous resolver to query for address records. + + This utilizes the resolve() method to perform A and/or AAAA lookups on + the specified name. + + *qname*, a ``dns.name.Name`` or ``str``, the name to resolve. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC + (the default), both A and AAAA records will be retrieved. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs.pop("rdtype", None) + modified_kwargs["rdclass"] = dns.rdataclass.IN + + if family == socket.AF_INET: + v4 = await self.resolve(name, dns.rdatatype.A, **modified_kwargs) + return dns.resolver.HostAnswers.make(v4=v4) + elif family == socket.AF_INET6: + v6 = await self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) + return dns.resolver.HostAnswers.make(v6=v6) + elif family != socket.AF_UNSPEC: + raise NotImplementedError(f"unknown address family {family}") + + raise_on_no_answer = modified_kwargs.pop('raise_on_no_answer', True) + lifetime = modified_kwargs.pop('lifetime', None) + start = time.time() + v6 = await self.resolve(name, dns.rdatatype.AAAA, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs) + # Note that setting name ensures we query the same name + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. + name = v6.qname + v4 = await self.resolve(name, dns.rdatatype.A, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs) + answers = dns.resolver.HostAnswers.make( + v6=v6, + v4=v4, + add_empty=not raise_on_no_answer + ) + if not answers: + raise NoAnswer(response=v6.response) + return answers + + # pylint: disable=redefined-outer-name async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: @@ -228,6 +294,20 @@ async def resolve_address( return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +async def resolve_name( + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any +) -> dns.resolver.HostAnswers: + """Use a resolver to asynchronously query for address records. + + See :py:func:`dns.asyncresolver.Resolver.resolve_name` for more + information on the parameters. + """ + + return await get_default_resolver().resolve_name(name, family, **kwargs) + + async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. diff --git a/dns/resolver.py b/dns/resolver.py index 63a57eee..2675b499 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -17,7 +17,7 @@ """DNS stub resolver.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from urllib.parse import urlparse import contextlib @@ -310,6 +310,61 @@ class Answer: del self.rrset[i] +class Answers(dict): + """A dict of DNS stub resolver answers, indexed by type.""" + pass + +class HostAnswers(Answers): + """A dict of DNS stub resolver answers to a host name lookup, indexed by + type. + """ + + @classmethod + def make( + self, + v6 : Optional[Answer] = None, + v4 : Optional[Answer] = None, + add_empty : bool = True + ) -> 'HostAnswers': + answers = HostAnswers() + if v6 is not None and (add_empty or v6.rrset): + answers[dns.rdatatype.AAAA] = v6 + if v4 is not None and (add_empty or v4.rrset): + answers[dns.rdatatype.A] = v4 + return answers + + + # Returns pairs of (address, family) from this result, potentiallys + # filtering by address family. + def addresses_and_families( + self, + family : int = socket.AF_UNSPEC + ) -> Iterator[Tuple[str, int]]: + if family == socket.AF_UNSPEC: + yield from self.addresses_and_families(socket.AF_INET6) + yield from self.addresses_and_families(socket.AF_INET) + return + elif family == socket.AF_INET6: + answer = self.get(dns.rdatatype.AAAA) + elif family == socket.AF_INET: + answer = self.get(dns.rdatatype.A) + else: + raise NotImplementedError(f"unknown address family {family}") + if answer: + for rdata in answer: + yield (rdata.address, family) + + # Returns addresses from this result, potentially filtering by + # address family. + def addresses(self, family : int = socket.AF_UNSPEC) -> Iterator[str]: + return (pair[0] for pair in self.addresses_and_families(family)) + + # Returns the canonical name from this result. + def canonical_name(self) -> dns.name.Name: + answer = self.get(dns.rdatatype.AAAA, self.get(dns.rdatatype.A)) + return answer.canonical_name + + class CacheStatistics: """Cache Statistics""" @@ -1343,6 +1398,66 @@ class Resolver(BaseResolver): dns.reversename.from_address(ipaddr), *args, **modified_kwargs ) + def resolve_name( + self, + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any + ) -> HostAnswers: + """Use a resolver to query for address records. + + This utilizes the resolve() method to perform A and/or AAAA lookups on + the specified name. + + *qname*, a ``dns.name.Name`` or ``str``, the name to resolve. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC + (the default), both A and AAAA records will be retrieved. + + All other arguments that can be passed to the resolve() function + except for rdtype and rdclass are also supported by this + function. + """ + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs: Dict[str, Any] = {} + modified_kwargs.update(kwargs) + modified_kwargs.pop("rdtype", None) + modified_kwargs["rdclass"] = dns.rdataclass.IN + + if family == socket.AF_INET: + v4 = self.resolve(name, dns.rdatatype.A, **modified_kwargs) + return HostAnswers.make(v4=v4) + elif family == socket.AF_INET6: + v6 = self.resolve(name, dns.rdatatype.AAAA, **modified_kwargs) + return HostAnswers.make(v6=v6) + elif family != socket.AF_UNSPEC: + raise NotImplementedError(f"unknown address family {family}") + + raise_on_no_answer = modified_kwargs.pop('raise_on_no_answer', True) + lifetime = modified_kwargs.pop('lifetime', None) + start = time.time() + v6 = self.resolve(name, dns.rdatatype.AAAA, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs) + # Note that setting name ensures we query the same name + # for A as we did for AAAA. (This is just in case search lists + # are active by default in the resolver configuration and + # we might be talking to a server that says NXDOMAIN when it + # wants to say NOERROR no data. + name = v6.qname + v4 = self.resolve(name, dns.rdatatype.A, + raise_on_no_answer=False, + lifetime=self._compute_timeout(start, lifetime), + **modified_kwargs) + answers = HostAnswers.make(v6=v6, v4=v4, add_empty=not raise_on_no_answer) + if not answers: + raise NoAnswer(response=v6.response) + return answers + + # pylint: disable=redefined-outer-name def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: @@ -1468,6 +1583,20 @@ def resolve_address(ipaddr: str, *args: Any, **kwargs: Any) -> Answer: return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) +def resolve_name( + name: Union[dns.name.Name, str], + family: int = socket.AF_UNSPEC, + **kwargs: Any +) -> HostAnswers: + """Use a resolver to query for address records. + + See ``dns.resolver.Resolver.resolve_name`` for more information on the + parameters. + """ + + return get_default_resolver().resolve_name(name, family, **kwargs) + + def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. @@ -1606,8 +1735,7 @@ def _getaddrinfo( ) if host is None and service is None: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") - v6addrs = [] - v4addrs = [] + addrs = [] canonical_name = None # pylint: disable=redefined-outer-name # Is host None or an address literal? If so, use the system's # getaddrinfo(). @@ -1623,24 +1751,9 @@ def _getaddrinfo( pass # Something needs resolution! try: - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - v6 = _resolver.resolve(host, dns.rdatatype.AAAA, raise_on_no_answer=False) - # Note that setting host ensures we query the same name - # for A as we did for AAAA. (This is just in case search lists - # are active by default in the resolver configuration and - # we might be talking to a server that says NXDOMAIN when it - # wants to say NOERROR no data. - host = v6.qname - canonical_name = v6.canonical_name.to_text(True) - if v6.rrset is not None: - for rdata in v6.rrset: - v6addrs.append(rdata.address) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - v4 = _resolver.resolve(host, dns.rdatatype.A, raise_on_no_answer=False) - canonical_name = v4.canonical_name.to_text(True) - if v4.rrset is not None: - for rdata in v4.rrset: - v4addrs.append(rdata.address) + answers = _resolver.resolve_name(host, family) + addrs = answers.addresses_and_families() + canonical_name = answers.canonical_name().to_text(True) except dns.resolver.NXDOMAIN: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") except Exception: @@ -1672,20 +1785,11 @@ def _getaddrinfo( cname = canonical_name else: cname = "" - if family == socket.AF_INET6 or family == socket.AF_UNSPEC: - for addr in v6addrs: - for socktype in socktypes: - for proto in _protocols_for_socktype[socktype]: - tuples.append( - (socket.AF_INET6, socktype, proto, cname, (addr, port, 0, 0)) - ) - if family == socket.AF_INET or family == socket.AF_UNSPEC: - for addr in v4addrs: - for socktype in socktypes: - for proto in _protocols_for_socktype[socktype]: - tuples.append( - (socket.AF_INET, socktype, proto, cname, (addr, port)) - ) + for addr, af in addrs: + for socktype in socktypes: + for proto in _protocols_for_socktype[socktype]: + addr_tuple = dns.inet.low_level_address_tuple((addr, port), af) + tuples.append((af, socktype, proto, cname, addr_tuple)) if len(tuples) == 0: raise socket.gaierror(socket.EAI_NONAME, "Name or service not known") return tuples diff --git a/tests/test_async.py b/tests/test_async.py index 65f1ed9b..52ba2e28 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -187,6 +187,48 @@ class AsyncTests(unittest.TestCase): dnsgoogle = dns.name.from_text("dns.google.") self.assertEqual(answer[0].target, dnsgoogle) + def testResolveName(self): + async def run1(): + return await dns.asyncresolver.resolve_name("dns.google.") + + answers = self.async_run(run1) + seen = set(answers.addresses()) + self.assertEqual(len(seen), 4) + self.assertIn("8.8.8.8", seen) + self.assertIn("8.8.4.4", seen) + self.assertIn("2001:4860:4860::8844", seen) + self.assertIn("2001:4860:4860::8888", seen) + + async def run2(): + return await dns.asyncresolver.resolve_name("dns.google.", socket.AF_INET) + + answers = self.async_run(run2) + seen = set(answers.addresses()) + self.assertEqual(len(seen), 2) + self.assertIn("8.8.8.8", seen) + self.assertIn("8.8.4.4", seen) + + async def run3(): + return await dns.asyncresolver.resolve_name("dns.google.", socket.AF_INET6) + + answers = self.async_run(run3) + seen = set(answers.addresses()) + self.assertEqual(len(seen), 2) + self.assertIn("2001:4860:4860::8844", seen) + self.assertIn("2001:4860:4860::8888", seen) + + async def run4(): + await dns.asyncresolver.resolve_name("nxdomain.dnspython.org") + + with self.assertRaises(dns.resolver.NXDOMAIN): + self.async_run(run4) + + async def run5(): + await dns.asyncresolver.resolve_name(dns.reversename.from_address("8.8.8.8")) + + with self.assertRaises(dns.resolver.NoAnswer): + self.async_run(run5) + def testCanonicalNameNoCNAME(self): cname = dns.name.from_text("www.google.com") diff --git a/tests/test_resolver.py b/tests/test_resolver.py index c1a97bf8..c73451e6 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -31,6 +31,7 @@ import dns.quic import dns.rdataclass import dns.rdatatype import dns.resolver +import dns.reversename import dns.tsig import dns.tsigkeyring import tests.util @@ -665,6 +666,33 @@ class LiveResolverTests(unittest.TestCase): dnsgoogle = dns.name.from_text("dns.google.") self.assertEqual(answer[0].target, dnsgoogle) + def testResolveName(self): + answers = dns.resolver.resolve_name("dns.google.") + seen = set(answers.addresses()) + self.assertEqual(len(seen), 4) + self.assertIn("8.8.8.8", seen) + self.assertIn("8.8.4.4", seen) + self.assertIn("2001:4860:4860::8844", seen) + self.assertIn("2001:4860:4860::8888", seen) + + answers = dns.resolver.resolve_name("dns.google.", socket.AF_INET) + seen = set(answers.addresses()) + self.assertEqual(len(seen), 2) + self.assertIn("8.8.8.8", seen) + self.assertIn("8.8.4.4", seen) + + answers = dns.resolver.resolve_name("dns.google.", socket.AF_INET6) + seen = set(answers.addresses()) + self.assertEqual(len(seen), 2) + self.assertIn("2001:4860:4860::8844", seen) + self.assertIn("2001:4860:4860::8888", seen) + + with self.assertRaises(dns.resolver.NXDOMAIN): + dns.resolver.resolve_name("nxdomain.dnspython.org") + + with self.assertRaises(dns.resolver.NoAnswer): + dns.resolver.resolve_name(dns.reversename.from_address("8.8.8.8")) + @patch.object(dns.message.Message, "use_edns") def testResolveEdnsOptions(self, message_use_edns_mock): resolver = dns.resolver.Resolver() -- 2.47.3