]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
finish testing of resolver business logic
authorBob Halley <halley@dnspython.org>
Wed, 20 May 2020 02:01:40 +0000 (19:01 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 20 May 2020 02:01:40 +0000 (19:01 -0700)
tests/test_resolution.py

index 2819842724423d8ccc80d3240404f333c51ab2cc..92223495ed8497248d4ab99546f4703f644618eb 100644 (file)
@@ -2,6 +2,7 @@ import unittest
 
 import dns.message
 import dns.name
+import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver
@@ -45,14 +46,28 @@ class ResolutionTestCase(unittest.TestCase):
         (request, answer) = self.resn.next_request()
         self.assertRaises(dns.resolver.NXDOMAIN, bad)
 
-    def test_next_request_cache_hit(self):
-        self.resolver.cache = dns.resolver.Cache()
-        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+    def make_address_response(self, q):
         r = dns.message.make_response(q)
         rrs = r.get_rrset(r.answer, self.qname, dns.rdataclass.IN,
                           dns.rdatatype.A, create=True)
         rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
                                     '10.0.0.1'), 300)
+        return r
+
+    def make_negative_response(self, q, nxdomain=False):
+        r = dns.message.make_response(q)
+        rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
+                          dns.rdatatype.SOA, create=True)
+        rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
+                                    '. . 1 2 3 4 300'), 300)
+        if nxdomain:
+            r.set_rcode(dns.rcode.NXDOMAIN)
+        return r
+
+    def test_next_request_cache_hit(self):
+        self.resolver.cache = dns.resolver.Cache()
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
         cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A,
                                            dns.rdataclass.IN, r)
         self.resolver.cache.put((self.qname, dns.rdatatype.A,
@@ -65,12 +80,9 @@ class ResolutionTestCase(unittest.TestCase):
         # In default mode, we should raise on a no-answer hit
         self.resolver.cache = dns.resolver.Cache()
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
-        r = dns.message.make_response(q)
-        # We need an SOA so the cache doesn't expire the answer immediately.
-        rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
-                          dns.rdatatype.SOA, create=True)
-        rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
-                                    '. . 1 2 3 4 300'), 300)
+        # Note we need an SOA so the cache doesn't expire the answer
+        # immediately, but our negative response code does that.
+        r = self.make_negative_response(q)
         cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A,
                                            dns.rdataclass.IN, r, False)
         self.resolver.cache.put((self.qname, dns.rdatatype.A,
@@ -87,15 +99,14 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertTrue(answer is cache_answer)
 
     def test_next_nameserver_udp(self):
-        nameservers = {'10.0.0.1', '10.0.0.2'}
         (request, answer) = self.resn.next_request()
         (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver1 in nameservers)
+        self.assertTrue(nameserver1 in self.resolver.nameservers)
         self.assertEqual(port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
         (nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver2 in nameservers)
+        self.assertTrue(nameserver2 in self.resolver.nameservers)
         self.assertTrue(nameserver2 != nameserver1)
         self.assertEqual(port, 53)
         self.assertFalse(tcp)
@@ -117,10 +128,9 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertEqual(backoff, 0.2)
 
     def test_next_nameserver_retry_with_tcp(self):
-        nameservers = {'10.0.0.1', '10.0.0.2'}
         (request, answer) = self.resn.next_request()
         (nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver1 in nameservers)
+        self.assertTrue(nameserver1 in self.resolver.nameservers)
         self.assertEqual(port, 53)
         self.assertFalse(tcp)
         self.assertEqual(backoff, 0.0)
@@ -131,7 +141,7 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertTrue(tcp)
         self.assertEqual(backoff, 0.0)
         (nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
-        self.assertTrue(nameserver3 in nameservers)
+        self.assertTrue(nameserver3 in self.resolver.nameservers)
         self.assertTrue(nameserver3 != nameserver1)
         self.assertEqual(port, 53)
         self.assertFalse(tcp)
@@ -146,3 +156,132 @@ class ResolutionTestCase(unittest.TestCase):
         def bad():
             (nameserver, _, _, _) = self.resn.next_nameserver()
         self.assertRaises(dns.resolver.NoNameservers, bad)
+
+    def test_query_result_nameserver_removing_exceptions(self):
+        # add some nameservers so we have enough to remove :)
+        self.resolver.nameservers.extend(['10.0.0.3', '10.0.0.4'])
+        (request, _) = self.resn.next_request()
+        exceptions = [dns.exception.FormError(), EOFError(),
+                      NotImplementedError(), dns.message.Truncated()]
+        for i in range(4):
+            (nameserver, _, _, _) = self.resn.next_nameserver()
+            if i == 3:
+                # Truncated is only bad if we're doing TCP, make it look
+                # like that's the case
+                self.resn.tcp_attempt = True
+            self.assertTrue(nameserver in self.resn.nameservers)
+            (answer, done) = self.resn.query_result(None, exceptions[i])
+            self.assertTrue(answer is None)
+            self.assertFalse(done)
+            self.assertFalse(nameserver in self.resn.nameservers)
+        self.assertEqual(len(self.resn.nameservers), 0)
+
+    def test_query_result_nameserver_continuing_exception(self):
+        # except for the exceptions tested in
+        # test_query_result_nameserver_removing_exceptions(), we should
+        # not remove any nameservers and just continue resolving.
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        nameservers = self.resn.nameservers[:]
+        (answer, done) = self.resn.query_result(None, dns.exception.Timeout())
+        self.assertTrue(answer is None)
+        self.assertFalse(done)
+        self.assertEqual(nameservers, self.resn.nameservers)
+
+    def test_query_result_retry_with_tcp(self):
+        (request, _) = self.resn.next_request()
+        (nameserver, _, tcp, _) = self.resn.next_nameserver()
+        self.assertFalse(tcp)
+        (answer, done) = self.resn.query_result(None, dns.message.Truncated())
+        self.assertTrue(answer is None)
+        self.assertFalse(done)
+        self.assertTrue(self.resn.retry_with_tcp)
+        # The rest of TCP retry logic was tested above in
+        # test_next_nameserver_retry_with_tcp(), so we do not repeat
+        # it.
+
+    def test_query_result_no_error_with_data(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertFalse(answer is None)
+        self.assertTrue(done)
+        self.assertEqual(answer.qname, self.qname)
+        self.assertEqual(answer.rdtype, dns.rdatatype.A)
+
+    def test_query_result_no_error_with_data_cached(self):
+        self.resolver.cache = dns.resolver.Cache()
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertFalse(answer is None)
+        cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.A,
+                                                dns.rdataclass.IN))
+        self.assertTrue(answer is cache_answer)
+
+    def test_query_result_no_error_no_data(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_negative_response(q)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        def bad():
+            (answer, done) = self.resn.query_result(r, None)
+        self.assertRaises(dns.resolver.NoAnswer, bad)
+
+    def test_query_result_nxdomain(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_negative_response(q, True)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertTrue(answer is None)
+        self.assertTrue(done)
+
+    def test_query_result_yxdomain(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        r.set_rcode(dns.rcode.YXDOMAIN)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        def bad():
+            (answer, done) = self.resn.query_result(r, None)
+        self.assertRaises(dns.resolver.YXDOMAIN, bad)
+
+    def test_query_result_servfail_no_retry(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        r.set_rcode(dns.rcode.SERVFAIL)
+        (_, _) = self.resn.next_request()
+        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertTrue(answer is None)
+        self.assertFalse(done)
+        self.assertTrue(nameserver not in self.resn.nameservers)
+
+    def test_query_result_servfail_with_retry(self):
+        self.resolver.retry_servfail = True
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        r.set_rcode(dns.rcode.SERVFAIL)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        nameservers = self.resn.nameservers[:]
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertTrue(answer is None)
+        self.assertFalse(done)
+        self.assertEqual(nameservers, self.resn.nameservers)
+
+    def test_query_result_other_unhappy_rcode(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        r.set_rcode(dns.rcode.REFUSED)
+        (_, _) = self.resn.next_request()
+        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertTrue(answer is None)
+        self.assertFalse(done)
+        self.assertTrue(nameserver not in self.resn.nameservers)