]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add NXDOMAIN caching
authorBob Halley <halley@dnspython.org>
Thu, 21 May 2020 21:35:40 +0000 (14:35 -0700)
committerBob Halley <halley@dnspython.org>
Thu, 21 May 2020 21:35:40 +0000 (14:35 -0700)
dns/resolver.py
tests/test_resolution.py

index e50eab8da6087c035deeabd52f41712d1f8debba..cc1f78b305ff4ec1564341214212da27a84e9aa3 100644 (file)
@@ -536,51 +536,61 @@ class _Resolution(object):
         """
 
         # We return a tuple instead of Union[Message,Answer] as it lets
-        # the caller avoid isinstance.
+        # the caller avoid isinstance().
 
-        if len(self.qnames) == 0:
-            #
-            # We've tried everything and only gotten NXDOMAINs.  (We know
-            # it's only NXDOMAINs as anything else would have returned
-            # before now.)
-            #
-            raise NXDOMAIN(qnames=self.qnames_to_try,
-                           responses=self.nxdomain_responses)
-
-        self.qname = self.qnames.pop(0)
-
-        # Do we know the answer?
-        if self.resolver.cache:
-            answer = self.resolver.cache.get((self.qname, self.rdtype,
-                                              self.rdclass))
-            if answer is not None:
-                if answer.rrset is None and self.raise_on_no_answer:
-                    raise NoAnswer(response=answer.response)
-                else:
-                    return (None, answer)
-
-        # Build the request
-        request = dns.message.make_query(self.qname, self.rdtype, self.rdclass)
-        if self.resolver.keyname is not None:
-            request.use_tsig(self.resolver.keyring, self.resolver.keyname,
-                             algorithm=self.resolver.keyalgorithm)
-        request.use_edns(self.resolver.edns, self.resolver.ednsflags,
-                         self.resolver.payload)
-        if self.resolver.flags is not None:
-            request.flags = self.resolver.flags
-
-        self.nameservers = self.resolver.nameservers[:]
-        if self.resolver.rotate:
-            random.shuffle(self.nameservers)
-        self.current_nameservers = self.nameservers[:]
-        self.errors = []
-        self.nameserver = None
-        self.tcp_attempt = False
-        self.retry_with_tcp = False
-        self.request = request
-        self.backoff = 0.10
+        while len(self.qnames) > 0:
+            self.qname = self.qnames.pop(0)
+
+            # Do we know the answer?
+            if self.resolver.cache:
+                answer = self.resolver.cache.get((self.qname, self.rdtype,
+                                                  self.rdclass))
+                if answer is not None:
+                    if answer.rrset is None and self.raise_on_no_answer:
+                        raise NoAnswer(response=answer.response)
+                    else:
+                        return (None, answer)
+                answer = self.resolver.cache.get((self.qname,
+                                                  dns.rdatatype.ANY,
+                                                  self.rdclass))
+                if answer is not None and \
+                   answer.response.rcode() == dns.rcode.NXDOMAIN:
+                    # cached NXDOMAIN; record it and continue to next
+                    # name.
+                    self.nxdomain_responses[self.qname] = answer.response
+                    continue
+
+            # Build the request
+            request = dns.message.make_query(self.qname, self.rdtype,
+                                             self.rdclass)
+            if self.resolver.keyname is not None:
+                request.use_tsig(self.resolver.keyring, self.resolver.keyname,
+                                 algorithm=self.resolver.keyalgorithm)
+            request.use_edns(self.resolver.edns, self.resolver.ednsflags,
+                             self.resolver.payload)
+            if self.resolver.flags is not None:
+                request.flags = self.resolver.flags
+
+            self.nameservers = self.resolver.nameservers[:]
+            if self.resolver.rotate:
+                random.shuffle(self.nameservers)
+            self.current_nameservers = self.nameservers[:]
+            self.errors = []
+            self.nameserver = None
+            self.tcp_attempt = False
+            self.retry_with_tcp = False
+            self.request = request
+            self.backoff = 0.10
+
+            return (request, None)
 
-        return (request, None)
+        #
+        # We've tried everything and only gotten NXDOMAINs.  (We know
+        # it's only NXDOMAINs as anything else would have returned
+        # before now.)
+        #
+        raise NXDOMAIN(qnames=self.qnames_to_try,
+                       responses=self.nxdomain_responses)
 
     def next_nameserver(self):
         if self.retry_with_tcp:
@@ -641,6 +651,13 @@ class _Resolution(object):
             self.nxdomain_responses[self.qname] = response
             # Make next_nameserver() return None, so caller breaks its
             # inner loop and calls next_request().
+            if self.resolver.cache:
+                answer = Answer(self.qname, dns.rdatatype.ANY,
+                                dns.rdataclass.IN, response)
+                self.resolver.cache.put((self.qname,
+                                         dns.rdatatype.ANY,
+                                         self.rdclass), answer)
+
             return (None, True)
         elif rcode == dns.rcode.YXDOMAIN:
             yex = YXDOMAIN()
index 95dd9ae99d336b7613e43c8c1aadcb6c60cf231b..bb1c4b136619296889ba61b060daf874867b85d8 100644 (file)
@@ -56,7 +56,7 @@ class ResolutionTestCase(unittest.TestCase):
 
     def make_negative_response(self, q, nxdomain=False):
         r = dns.message.make_response(q)
-        rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
+        rrs = r.get_rrset(r.authority, q.question[0].name, dns.rdataclass.IN,
                           dns.rdatatype.SOA, create=True)
         rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
                                     '. . 1 2 3 4 300'), 300)
@@ -76,7 +76,7 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertTrue(request is None)
         self.assertTrue(answer is cache_answer)
 
-    def test_next_request_no_answer(self):
+    def test_next_request_cached_no_answer(self):
         # In default mode, we should raise on a no-answer hit
         self.resolver.cache = dns.resolver.Cache()
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
@@ -98,6 +98,35 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertTrue(request is None)
         self.assertTrue(answer is cache_answer)
 
+    def test_next_request_cached_nxdomain(self):
+        # use a relative qname so we have two qnames to try
+        qname = dns.name.from_text('www.dnspython.org', None)
+        self.resn = dns.resolver._Resolution(self.resolver, qname,
+                                             'A', 'IN',
+                                             False, True, False)
+        qname1 = dns.name.from_text('www.dnspython.org.example.')
+        qname2 = dns.name.from_text('www.dnspython.org.')
+        # Arrange to get NXDOMAIN hits on both of those qnames.
+        self.resolver.cache = dns.resolver.Cache()
+        q1 = dns.message.make_query(qname1, dns.rdatatype.A)
+        r1 = self.make_negative_response(q1, True)
+        cache_answer = dns.resolver.Answer(qname1, dns.rdatatype.ANY,
+                                           dns.rdataclass.IN, r1)
+        self.resolver.cache.put((qname1, dns.rdatatype.ANY,
+                                 dns.rdataclass.IN), cache_answer)
+        q2 = dns.message.make_query(qname2, dns.rdatatype.A)
+        r2 = self.make_negative_response(q2, True)
+        cache_answer = dns.resolver.Answer(qname2, dns.rdatatype.ANY,
+                                           dns.rdataclass.IN, r2)
+        self.resolver.cache.put((qname2, dns.rdatatype.ANY,
+                                 dns.rdataclass.IN), cache_answer)
+        try:
+            (request, answer) = self.resn.next_request()
+            self.assertTrue(False)  # should not happen!
+        except dns.resolver.NXDOMAIN as nx:
+            self.assertTrue(nx.response(qname1) is r1)
+            self.assertTrue(nx.response(qname2) is r2)
+
     def test_next_nameserver_udp(self):
         (request, answer) = self.resn.next_request()
         (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
@@ -241,6 +270,19 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertTrue(answer is None)
         self.assertTrue(done)
 
+    def test_query_result_nxdomain_cached(self):
+        self.resolver.cache = dns.resolver.Cache()
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_negative_response(q, True)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertTrue(answer is None)
+        self.assertTrue(done)
+        cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.ANY,
+                                                dns.rdataclass.IN))
+        self.assertTrue(cache_answer.response is r)
+
     def test_query_result_yxdomain(self):
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
         r = self.make_address_response(q)