]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
refactor resolver, extracting all business logic
authorBob Halley <halley@dnspython.org>
Mon, 18 May 2020 15:08:05 +0000 (08:08 -0700)
committerBob Halley <halley@dnspython.org>
Mon, 18 May 2020 15:08:05 +0000 (08:08 -0700)
dns/resolver.py
dns/trio/resolver.py

index 3af35f44d31eebbd0913b9c8f91ab3a223dbce4f..474219a015593d2fcc6de86569a6909cee8ad9d7 100644 (file)
@@ -490,6 +490,169 @@ class LRUCache(object):
                     node = next
                 self.data = {}
 
+class _Resolution(object):
+    """Helper class for dns.resolver.Resolver.resolve().
+
+    All of the "business logic" of resolution is encapsulated in this
+    class, allowing us to have multiple resolve() implementations
+    using different I/O schemes without copying all of the
+    complicated logic.
+
+    This class is a "friend" to dns.resolver.Resolver and manipulates
+    resolver data structures directly.
+    """
+
+    def __init__(self, resolver, qname, rdtype, rdclass, tcp,
+                 raise_on_no_answer, search):
+        if isinstance(qname, str):
+            qname = dns.name.from_text(qname, None)
+        if isinstance(rdtype, str):
+            rdtype = dns.rdatatype.from_text(rdtype)
+        if dns.rdatatype.is_metatype(rdtype):
+            raise NoMetaqueries
+        if isinstance(rdclass, str):
+            rdclass = dns.rdataclass.from_text(rdclass)
+        if dns.rdataclass.is_metaclass(rdclass):
+            raise NoMetaqueries
+        self.resolver = resolver
+        self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
+        self.qnames = self.qnames_to_try[:]
+        self.rdtype = rdtype
+        self.rdclass = rdclass
+        self.tcp = tcp
+        self.raise_on_no_answer = raise_on_no_answer
+        self.nxdomain_responses = {}
+
+    def next_request(self):
+        """Get the next request to send, and check the cache.
+
+        Returns a (request, answer) tuple.  At most one of request or
+        answer will not be None.
+        """
+
+        # We return a tuple instead of Union[Message,Answer] as it lets
+        # the caller avoid isinstance.
+
+        if len(self.qnames) == 0:
+            #
+            # We've tried everything and only gotten NXDOMAINs.  (We know
+            # it's only NXDOMAINs as anything else would have returned
+            # before now.)
+            #
+            raise NXDOMAIN(qnames=self.qnames_to_try,
+                           responses=self.nxdomain_responses)
+
+        self.qname = self.qnames.pop()
+
+        # Do we know the answer?
+        if self.resolver.cache:
+            answer = self.resolver.cache.get((self.qname, self.rdtype,
+                                              self.rdclass))
+            if answer is not None:
+                if answer.rrset is None and self.raise_on_no_answer:
+                    raise NoAnswer(response=answer.response)
+                else:
+                    return (None, answer)
+
+        # Build the request
+        request = dns.message.make_query(self.qname, self.rdtype, self.rdclass)
+        if self.resolver.keyname is not None:
+            request.use_tsig(self.resolver.keyring, self.resolver.keyname,
+                             algorithm=self.resolver.keyalgorithm)
+        request.use_edns(self.resolver.edns, self.resolver.ednsflags,
+                         self.resolver.payload)
+        if self.resolver.flags is not None:
+            request.flags = self.resolver.flags
+
+        self.nameservers = self.resolver.nameservers[:]
+        if self.resolver.rotate:
+            random.shuffle(self.nameservers)
+        self.current_nameservers = self.nameservers[:]
+        self.errors = []
+        self.nameserver = None
+        self.tcp_attempt = False
+        self.retry_with_tcp = False
+        self.request = request
+        self.backoff = 0.10
+
+        return (request, None)
+
+    def next_nameserver(self):
+        if self.retry_with_tcp:
+            assert self.nameserver is not None
+            self.tcp_attempt = True
+            self.retry_with_tcp = False
+            return (self.nameserver, self.port, True)
+
+        backoff = 0
+        if not self.current_nameservers:
+            if len(self.nameservers) == 0:
+                # Out of things to try!
+                raise NoNameservers(request=self.request, errors=self.errors)
+            self.current_nameservers = self.nameservers[:]
+            backoff = self.backoff
+            self.backoff = min(self.backoff * 2, 2)
+
+        self.nameserver = self.current_nameservers.pop()
+        self.port = self.resolver.nameserver_ports.get(self.nameserver,
+                                                       self.resolver.port)
+        self.tcp_attempt = self.tcp
+        return (self.nameserver, self.port, self.tcp_attempt, backoff)
+
+    def query_result(self, response, ex):
+        #
+        # returns an (answer: Answer, end_loop: bool) tuple.
+        #
+        if ex:
+            # Exception during I/O or from_wire()
+            assert response is None
+            self.errors.append((self.nameserver, self.tcp_attempt, self.port,
+                                ex, response))
+            if isinstance(ex, dns.exception.FormError) or \
+               isinstance(ex, EOFError) or \
+               isinstance(ex, NotImplementedError):
+                # This nameserver is no good, take it out of the mix.
+                self.nameservers.remove(self.nameserver)
+            elif isinstance(ex, dns.message.Truncated):
+                if self.tcp_attempt:
+                    # Truncation with TCP is no good!
+                    self.nameservers.remove(self.nameserver)
+                else:
+                    self.retry_with_tcp = True
+            return (None, False)
+        # We got an answer!
+        assert response is not None
+        rcode = response.rcode()
+        if rcode == dns.rcode.NOERROR:
+            answer = Answer(self.qname, self.rdtype, self.rdclass, response,
+                            self.raise_on_no_answer, self.nameserver,
+                            self.port)
+            if self.resolver.cache:
+                self.resolver.cache.put((self.qname, self.rdtype,
+                                         self.rdclass), answer)
+            return (answer, True)
+        elif rcode == dns.rcode.NXDOMAIN:
+            self.nxdomain_responses[self.qname] = response
+            # Make next_nameserver() return None, so caller breaks its
+            # inner loop and calls next_request().
+            return (None, True)
+        elif rcode == dns.rcode.YXDOMAIN:
+            yex = YXDOMAIN()
+            self.errors.append((self.nameserver, self.tcp_attempt,
+                                self.port, yex, response))
+            raise yex
+        else:
+            #
+            # We got a response, but we're not happy with the
+            # rcode in it.  Remove the server from the mix if
+            # the rcode isn't SERVFAIL.
+            #
+            if rcode != dns.rcode.SERVFAIL or not self.retry_servfail:
+                self.nameservers.remove(self.nameserver)
+            self.errors.append((self.nameserver, self.tcp_attempt, self.port,
+                                dns.rcode.to_text(rcode), response))
+            return (None, False)
+
 class Resolver(object):
     """DNS stub resolver."""
 
@@ -862,179 +1025,47 @@ class Resolver(object):
 
         """
 
-        if isinstance(qname, str):
-            qname = dns.name.from_text(qname, None)
-        if isinstance(rdtype, str):
-            rdtype = dns.rdatatype.from_text(rdtype)
-        if dns.rdatatype.is_metatype(rdtype):
-            raise NoMetaqueries
-        if isinstance(rdclass, str):
-            rdclass = dns.rdataclass.from_text(rdclass)
-        if dns.rdataclass.is_metaclass(rdclass):
-            raise NoMetaqueries
-        qnames_to_try = self._get_qnames_to_try(qname, search)
-        all_nxdomain = True
-        nxdomain_responses = {}
+        resolution = _Resolution(self, qname, rdtype, rdclass, tcp,
+                                 raise_on_no_answer, search)
         start = time.time()
-        _qname = None  # make pylint happy
-        for _qname in qnames_to_try:
-            if self.cache:
-                answer = self.cache.get((_qname, rdtype, rdclass))
-                if answer is not None:
-                    if answer.rrset is None and raise_on_no_answer:
-                        raise NoAnswer(response=answer.response)
+        while True:
+            (request, answer) = resolution.next_request()
+            if answer:
+                # cache hit!
+                return answer
+            done = False
+            while not done:
+                (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+                if backoff:
+                    time.sleep(backoff)
+                timeout = self._compute_timeout(start, lifetime)
+                try:
+                    if dns.inet.is_address(nameserver):
+                        if tcp:
+                            response = dns.query.tcp(request, nameserver,
+                                                     timeout=timeout,
+                                                     port=port,
+                                                     source=source,
+                                                     source_port=source_port)
+                        else:
+                            response = dns.query.udp(request,
+                                                     nameserver,
+                                                     timeout=timeout,
+                                                     port=port,
+                                                     source=source,
+                                                     source_port=source_port)
                     else:
-                        return answer
-            request = dns.message.make_query(_qname, rdtype, rdclass)
-            if self.keyname is not None:
-                request.use_tsig(self.keyring, self.keyname,
-                                 algorithm=self.keyalgorithm)
-            request.use_edns(self.edns, self.ednsflags, self.payload)
-            if self.flags is not None:
-                request.flags = self.flags
-            response = None
-            #
-            # make a copy of the servers list so we can alter it later.
-            #
-            nameservers = self.nameservers[:]
-            errors = []
-            if self.rotate:
-                random.shuffle(nameservers)
-            backoff = 0.10
-            # keep track of nameserver and port
-            # to include them in Answer
-            nameserver_answered = None
-            port_answered = None
-            while response is None:
-                if len(nameservers) == 0:
-                    raise NoNameservers(request=request, errors=errors)
-                for nameserver in nameservers[:]:
-                    timeout = self._compute_timeout(start, lifetime)
-                    port = self.nameserver_ports.get(nameserver, self.port)
-                    protocol = urlparse(nameserver).scheme
-                    try:
+                        protocol = urlparse(nameserver).scheme
                         if protocol == 'https':
-                            tcp_attempt = True
                             response = dns.query.https(request, nameserver,
                                                        timeout=timeout)
                         elif protocol:
                             continue
-                        else:
-                            tcp_attempt = tcp
-                            if tcp:
-                                response = \
-                                    dns.query.tcp(request, nameserver,
-                                                  timeout=timeout,
-                                                  port=port,
-                                                  source=source,
-                                                  source_port=source_port)
-                            else:
-                                try:
-                                    response = \
-                                        dns.query.udp(request,
-                                                      nameserver,
-                                                      timeout=timeout,
-                                                      port=port,
-                                                      source=source,
-                                                      source_port=source_port)
-                                except dns.message.Truncated:
-                                    # Response truncated; retry with TCP.
-                                    tcp_attempt = True
-                                    timeout = self._compute_timeout(start,
-                                                                    lifetime)
-                                    response = \
-                                        dns.query.tcp(request, nameserver,
-                                                      timeout=timeout,
-                                                      port=port,
-                                                      source=source,
-                                                      source_port=source_port)
-                    except (socket.error, dns.exception.Timeout) as ex:
-                        #
-                        # Communication failure or timeout.  Go to the
-                        # next server
-                        #
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    except dns.query.UnexpectedSource as ex:
-                        #
-                        # Who knows?  Keep going.
-                        #
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    except dns.exception.FormError as ex:
-                        #
-                        # We don't understand what this server is
-                        # saying.  Take it out of the mix and
-                        # continue.
-                        #
-                        nameservers.remove(nameserver)
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    except EOFError as ex:
-                        #
-                        # We're using TCP and they hung up on us.
-                        # Probably they don't support TCP (though
-                        # they're supposed to!).  Take it out of the
-                        # mix and continue.
-                        #
-                        nameservers.remove(nameserver)
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    nameserver_answered = nameserver
-                    port_answered = port
-                    rcode = response.rcode()
-                    if rcode == dns.rcode.YXDOMAIN:
-                        yex = YXDOMAIN()
-                        errors.append((nameserver, tcp_attempt, port, yex,
-                                       response))
-                        raise yex
-                    if rcode == dns.rcode.NOERROR or \
-                            rcode == dns.rcode.NXDOMAIN:
-                        break
-                    #
-                    # We got a response, but we're not happy with the
-                    # rcode in it.  Remove the server from the mix if
-                    # the rcode isn't SERVFAIL.
-                    #
-                    if rcode != dns.rcode.SERVFAIL or not self.retry_servfail:
-                        nameservers.remove(nameserver)
-                    errors.append((nameserver, tcp_attempt, port,
-                                   dns.rcode.to_text(rcode), response))
-                    response = None
-                if response is not None:
-                    break
-                #
-                # All nameservers failed!
-                #
-                if len(nameservers) > 0:
-                    #
-                    # But we still have servers to try.  Sleep a bit
-                    # so we don't pound them!
-                    #
-                    timeout = self._compute_timeout(start, lifetime)
-                    sleep_time = min(timeout, backoff)
-                    backoff *= 2
-                    time.sleep(sleep_time)
-            if response.rcode() == dns.rcode.NXDOMAIN:
-                nxdomain_responses[_qname] = response
-                continue
-            all_nxdomain = False
-            break
-        if all_nxdomain:
-            raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses)
-        answer = Answer(_qname, rdtype, rdclass, response,
-                        raise_on_no_answer, nameserver_answered, port_answered)
-        if self.cache:
-            self.cache.put((_qname, rdtype, rdclass), answer)
-        return answer
+                    (answer, done) = resolution.query_result(response, None)
+                    if answer:
+                        return answer
+                except Exception as ex:
+                    (_, done) = resolution.query_result(None, ex)
 
     def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
               tcp=False, source=None, raise_on_no_answer=True, source_port=0,
index 2f96a52361265851371eaaf024dee71e181dc428..785fde671ccb1b6ba5ff323b0a743605aa8051dc 100644 (file)
 
 """trio async I/O library DNS stub resolver."""
 
-import random
-import socket
 import trio
-from urllib.parse import urlparse
 
 import dns.exception
 import dns.query
 import dns.resolver
 import dns.trio.query
 
-# import resolver symbols for compatibility and brevity
-from dns.resolver import NXDOMAIN, YXDOMAIN, NoAnswer, NoNameservers, \
-    NotAbsolute, NoRootSOA, NoMetaqueries, Answer
+# import some resolver symbols for brevity
+from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
 
 # we do this for indentation reasons below
 _udp = dns.trio.query.udp
@@ -87,62 +83,25 @@ class Resolver(dns.resolver.Resolver):
 
         """
 
-        if isinstance(qname, str):
-            qname = dns.name.from_text(qname, None)
-        if isinstance(rdtype, str):
-            rdtype = dns.rdatatype.from_text(rdtype)
-        if dns.rdatatype.is_metatype(rdtype):
-            raise NoMetaqueries
-        if isinstance(rdclass, str):
-            rdclass = dns.rdataclass.from_text(rdclass)
-        if dns.rdataclass.is_metaclass(rdclass):
-            raise NoMetaqueries
-        qnames_to_try = self._get_qnames_to_try(qname, search)
-        all_nxdomain = True
-        nxdomain_responses = {}
-        _qname = None  # make pylint happy
-        for _qname in qnames_to_try:
-            if self.cache:
-                answer = self.cache.get((_qname, rdtype, rdclass))
-                if answer is not None:
-                    if answer.rrset is None and raise_on_no_answer:
-                        raise NoAnswer(response=answer.response)
-                    else:
-                        return answer
-            request = dns.message.make_query(_qname, rdtype, rdclass)
-            if self.keyname is not None:
-                request.use_tsig(self.keyring, self.keyname,
-                                 algorithm=self.keyalgorithm)
-            request.use_edns(self.edns, self.ednsflags, self.payload)
-            if self.flags is not None:
-                request.flags = self.flags
-            response = None
-            #
-            # make a copy of the servers list so we can alter it later.
-            #
-            nameservers = self.nameservers[:]
-            errors = []
-            if self.rotate:
-                random.shuffle(nameservers)
-            backoff = 0.10
-            # keep track of nameserver and port
-            # to include them in Answer
-            nameserver_answered = None
-            port_answered = None
-            loops = 0
-            while response is None:
-                if len(nameservers) == 0:
-                    raise NoNameservers(request=request, errors=errors)
-                for nameserver in nameservers[:]:
-                    port = self.nameserver_ports.get(nameserver, self.port)
-                    protocol = urlparse(nameserver).scheme
-                    try:
-                        with trio.fail_after(self.timeout):
-                            if protocol == 'https':
-                                raise NotImplementedError
-                            elif protocol:
-                                continue
-                            tcp_attempt = tcp
+        resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
+                                              raise_on_no_answer, search)
+        while True:
+            (request, answer) = resolution.next_request()
+            if answer:
+                # cache hit!
+                return answer
+            loops = 1
+            done = False
+            while not done:
+                (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+                if backoff:
+                    loops += 1
+                    if loops >= 5:
+                        raise TooManyAttempts
+                    await trio.sleep(backoff)
+                try:
+                    with trio.fail_after(self.timeout):
+                        if dns.inet.is_address(nameserver):
                             if tcp:
                                 response = await \
                                     _stream(request, nameserver,
@@ -150,113 +109,20 @@ class Resolver(dns.resolver.Resolver):
                                             source=source,
                                             source_port=source_port)
                             else:
-                                try:
-                                    response = await \
-                                        _udp(request,
-                                             nameserver,
-                                             port=port,
-                                             source=source,
-                                             source_port=source_port)
-                                except dns.message.Truncated:
-                                    # Response truncated; retry with TCP.
-                                    tcp_attempt = True
-                                    response = await \
-                                        _stream(request, nameserver,
-                                                port=port,
-                                                source=source,
-                                                source_port=source_port)
-                    except (socket.error, trio.TooSlowError) as ex:
-                        #
-                        # Communication failure or timeout.  Go to the
-                        # next server
-                        #
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    except dns.query.UnexpectedSource as ex:
-                        #
-                        # Who knows?  Keep going.
-                        #
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    except dns.exception.FormError as ex:
-                        #
-                        # We don't understand what this server is
-                        # saying.  Take it out of the mix and
-                        # continue.
-                        #
-                        nameservers.remove(nameserver)
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    except EOFError as ex:
-                        #
-                        # We're using TCP and they hung up on us.
-                        # Probably they don't support TCP (though
-                        # they're supposed to!).  Take it out of the
-                        # mix and continue.
-                        #
-                        nameservers.remove(nameserver)
-                        errors.append((nameserver, tcp_attempt, port, ex,
-                                       response))
-                        response = None
-                        continue
-                    nameserver_answered = nameserver
-                    port_answered = port
-                    rcode = response.rcode()
-                    if rcode == dns.rcode.YXDOMAIN:
-                        yex = YXDOMAIN()
-                        errors.append((nameserver, tcp_attempt, port, yex,
-                                       response))
-                        raise yex
-                    if rcode == dns.rcode.NOERROR or \
-                       rcode == dns.rcode.NXDOMAIN:
-                        break
-                    #
-                    # We got a response, but we're not happy with the
-                    # rcode in it.  Remove the server from the mix if
-                    # the rcode isn't SERVFAIL.
-                    #
-                    if rcode != dns.rcode.SERVFAIL or not self.retry_servfail:
-                        nameservers.remove(nameserver)
-                    errors.append((nameserver, tcp_attempt, port,
-                                   dns.rcode.to_text(rcode), response))
-                    response = None
-                if response is not None:
-                    break
-                #
-                # All nameservers failed!
-                #
-                # Do not loop forever if caller hasn't used a timeout
-                # scope.
-                loops += 1
-                if loops >= 5:
-                    raise TooManyAttempts
-                if len(nameservers) > 0:
-                    #
-                    # But we still have servers to try.  Sleep a bit
-                    # so we don't pound them!
-                    #
-                    await trio.sleep(backoff)
-                    backoff *= 2
-                    if backoff > 2:
-                        backoff = 2
-            if response.rcode() == dns.rcode.NXDOMAIN:
-                nxdomain_responses[_qname] = response
-                continue
-            all_nxdomain = False
-            break
-        if all_nxdomain:
-            raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses)
-        answer = Answer(_qname, rdtype, rdclass, response, raise_on_no_answer,
-                        nameserver_answered, port_answered)
-        if self.cache:
-            self.cache.put((_qname, rdtype, rdclass), answer)
-        return answer
+                                response = await \
+                                    _udp(request,
+                                         nameserver,
+                                         port=port,
+                                         source=source,
+                                         source_port=source_port)
+                        else:
+                            # We don't do DoH yet.
+                            raise NotImplementedError
+                    (answer, done) = resolution.query_result(response, None)
+                    if answer:
+                        return answer
+                except Exception as ex:
+                    (_, done) = resolution.query_result(None, ex)
 
     async def query(self, *args, **kwargs):
         # We have to define something here as we don't want to inherit the