]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
unify chaining code
authorBob Halley <halley@dnspython.org>
Tue, 21 Jul 2020 14:32:27 +0000 (07:32 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 21 Jul 2020 14:32:27 +0000 (07:32 -0700)
dns/message.py
dns/resolver.py
tests/test_resolution.py
tests/test_resolver.py

index 7f665722069b2123ea66556fa5e0dc3ecb06751e..87484f518188fa4ab599e0494a4fb53f38a2745d 100644 (file)
@@ -80,6 +80,18 @@ class Truncated(dns.exception.DNSException):
         return self.kwargs['message']
 
 
+class NotQueryResponse(dns.exception.DNSException):
+    """Message is not a response to a query."""
+
+
+class ChainTooLong(dns.exception.DNSException):
+    """The CNAME chain is too long."""
+
+
+class AnswerForNXDOMAIN(dns.exception.DNSException):
+    """The rcode is NXDOMAIN but an answer was found."""
+
+
 class MessageSection(dns.enum.IntEnum):
     """Message sections"""
     QUESTION = 0
@@ -94,6 +106,7 @@ class MessageSection(dns.enum.IntEnum):
 globals().update(MessageSection.__members__)
 
 DEFAULT_EDNS_PAYLOAD = 1232
+MAX_CHAIN = 16
 
 class Message:
     """A DNS message."""
@@ -232,8 +245,10 @@ class Message:
            dns.opcode.from_flags(self.flags) != \
            dns.opcode.from_flags(other.flags):
             return False
-        if dns.rcode.from_flags(other.flags, other.ednsflags) != \
-                dns.rcode.NOERROR:
+        if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL,
+                             dns.rcode.NOTIMP, dns.rcode.REFUSED}:
+            # We don't check the question section in these cases, even
+            # though they still ought to have the same question.
             return True
         if dns.opcode.is_update(self.flags):
             # This is assuming the "sender doesn't include anything
@@ -696,7 +711,96 @@ class Message:
 
 
 class QueryMessage(Message):
-    pass
+    def resolve_chaining(self):
+        """Follow the CNAME chain in the response to determine the answer
+        RRset.
+
+        Raises NotQueryResponse if the message is not a response.
+
+        Raises dns.message.ChainTooLong if the CNAME chain is too long.
+
+        Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was
+        found.
+
+        Raises dns.exception.FormError if the question count is not 1.
+
+        Returns a tuple (dns.name.Name, int, rrset) where the name is the
+        canonical name, the int is the minimized TTL, and rrset is their
+        answer RRset, which may be ``None`` if the chain was dangling or
+        the response is an NXDOMAIN.
+        """
+        if self.flags & dns.flags.QR == 0:
+            raise NotQueryResponse
+        if len(self.question) != 1:
+            raise dns.exception.FormError
+        question = self.question[0]
+        qname = question.name
+        min_ttl = -1
+        rrset = None
+        count = 0
+        while count < MAX_CHAIN:
+            try:
+                rrset = self.find_rrset(self.answer, qname, question.rdclass,
+                                        question.rdtype)
+                if min_ttl == -1 or rrset.ttl < min_ttl:
+                    min_ttl = rrset.ttl
+                break
+            except KeyError:
+                if question.rdtype != dns.rdatatype.CNAME:
+                    try:
+                        crrset = self.find_rrset(self.answer, qname,
+                                                 question.rdclass,
+                                                 dns.rdatatype.CNAME)
+                        if min_ttl == -1 or crrset.ttl < min_ttl:
+                            min_ttl = crrset.ttl
+                        for rd in crrset:
+                            qname = rd.target
+                            break
+                        count += 1
+                        continue
+                    except KeyError:
+                        # Exit the chaining loop
+                        break
+        if count >= MAX_CHAIN:
+            raise ChainTooLong
+        if self.rcode() == dns.rcode.NXDOMAIN and rrset is not None:
+            raise AnswerForNXDOMAIN
+        if rrset is None:
+            # Further minimize the TTL with NCACHE.
+            auname = qname
+            while True:
+                # Look for an SOA RR whose owner name is a superdomain
+                # of qname.
+                try:
+                    srrset = self.find_rrset(self.authority, auname,
+                                             question.rdclass,
+                                             dns.rdatatype.SOA)
+                    if min_ttl == -1 or srrset.ttl < min_ttl:
+                        min_ttl = srrset.ttl
+                    if srrset[0].minimum < min_ttl:
+                        min_ttl = srrset[0].minimum
+                    break
+                except KeyError:
+                    try:
+                        auname = auname.parent()
+                    except dns.name.NoParent:
+                        break
+        return (qname, min_ttl, rrset)
+
+    def canonical_name(self):
+        """Return the canonical name of the first name in the question
+        section.
+
+        Raises dns.message.NotQueryResponse if the message is not a response.
+
+        Raises dns.message.ChainTooLong if the CNAME chain is too long.
+
+        Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was
+        found.
+
+        Raises dns.exception.FormError if the question count is not 1.
+        """
+        return self.resolve_chaining()[0]
 
 
 def _maybe_import_update():
index a5079d46e6a09ecb58d3fb2bc809a8249aa2fd6f..4e1124712c985a901c69e8212485bdfa62f4a0be 100644 (file)
@@ -78,20 +78,16 @@ class NXDOMAIN(dns.exception.DNSException):
         """Return the unresolved canonical name."""
         if 'qnames' not in self.kwargs:
             raise TypeError("parametrized exception required")
-        IN = dns.rdataclass.IN
-        CNAME = dns.rdatatype.CNAME
-        cname = None
-        # This code assumes the CNAME chain is in proper order, though
-        # the Answer code does not make a similar assumption when
-        # chaining.
         for qname in self.kwargs['qnames']:
             response = self.kwargs['responses'][qname]
-            for answer in response.answer:
-                if answer.rdtype != CNAME or answer.rdclass != IN:
-                    continue
-                cname = answer[0].target
-            if cname is not None:
-                return cname
+            try:
+                cname = response.canonical_name()
+                if cname != qname:
+                    return cname
+            except Exception:
+                # We can just eat this exception as it means there was
+                # something wrong with the response.
+                pass
         return self.kwargs['qnames'][0]
 
     def __add__(self, e_nx):
@@ -209,50 +205,7 @@ class Answer:
         self.response = response
         self.nameserver = nameserver
         self.port = port
-        min_ttl = -1
-        rrset = None
-        for count in range(0, 15):
-            try:
-                rrset = response.find_rrset(response.answer, qname,
-                                            rdclass, rdtype)
-                if min_ttl == -1 or rrset.ttl < min_ttl:
-                    min_ttl = rrset.ttl
-                break
-            except KeyError:
-                if rdtype != dns.rdatatype.CNAME:
-                    try:
-                        crrset = response.find_rrset(response.answer,
-                                                     qname,
-                                                     rdclass,
-                                                     dns.rdatatype.CNAME)
-                        if min_ttl == -1 or crrset.ttl < min_ttl:
-                            min_ttl = crrset.ttl
-                        for rd in crrset:
-                            qname = rd.target
-                            break
-                        continue
-                    except KeyError:
-                        # Exit the chaining loop
-                        break
-        self.canonical_name = qname
-        self.rrset = rrset
-        if rrset is None:
-            while 1:
-                # Look for a SOA RR whose owner name is a superdomain
-                # of qname.
-                try:
-                    srrset = response.find_rrset(response.authority, qname,
-                                                 rdclass, dns.rdatatype.SOA)
-                    if min_ttl == -1 or srrset.ttl < min_ttl:
-                        min_ttl = srrset.ttl
-                    if srrset[0].minimum < min_ttl:
-                        min_ttl = srrset[0].minimum
-                    break
-                except KeyError:
-                    try:
-                        qname = qname.parent()
-                    except dns.name.NoParent:
-                        break
+        (self.canonical_name, min_ttl, self.rrset) = response.resolve_chaining()
         self.expiration = time.time() + min_ttl
 
     def __getattr__(self, attr):  # pragma: no cover
@@ -698,8 +651,13 @@ class _Resolution:
         assert response is not None
         rcode = response.rcode()
         if rcode == dns.rcode.NOERROR:
-            answer = Answer(self.qname, self.rdtype, self.rdclass, response,
-                            self.nameserver, self.port)
+            try:
+                answer = Answer(self.qname, self.rdtype, self.rdclass, response,
+                                self.nameserver, self.port)
+            except Exception:
+                # The nameserver is no good, take it out of the mix.
+                self.nameservers.remove(self.nameserver)
+                return (None, False)
             if self.resolver.cache:
                 self.resolver.cache.put((self.qname, self.rdtype,
                                          self.rdclass), answer)
@@ -707,16 +665,22 @@ class _Resolution:
                 raise NoAnswer(response=answer.response)
             return (answer, True)
         elif rcode == dns.rcode.NXDOMAIN:
-            self.nxdomain_responses[self.qname] = response
-            # Make next_nameserver() return None, so caller breaks its
-            # inner loop and calls next_request().
-            if self.resolver.cache:
+            # Further validate the response by making an Answer, even
+            # if we aren't going to cache it.
+            try:
                 answer = Answer(self.qname, dns.rdatatype.ANY,
                                 dns.rdataclass.IN, response)
+            except Exception:
+                # The nameserver is no good, take it out of the mix.
+                self.nameservers.remove(self.nameserver)
+                return (None, False)
+            self.nxdomain_responses[self.qname] = response
+            if self.resolver.cache:
                 self.resolver.cache.put((self.qname,
                                          dns.rdatatype.ANY,
                                          self.rdclass), answer)
-
+            # Make next_nameserver() return None, so caller breaks its
+            # inner loop and calls next_request().
             return (None, True)
         elif rcode == dns.rcode.YXDOMAIN:
             yex = YXDOMAIN()
index 9145f167a355dcbec298097557c4c7ef575f2312..db42d469d0b59f268fd51ed20dbe4a58c6656041 100644 (file)
@@ -83,6 +83,22 @@ class ResolutionTestCase(unittest.TestCase):
             r.set_rcode(dns.rcode.NXDOMAIN)
         return r
 
+    def make_long_chain_response(self, q, count):
+        r = dns.message.make_response(q)
+        name = self.qname
+        for i in range(count):
+            rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN,
+                              dns.rdatatype.CNAME, create=True)
+            tname = dns.name.from_text(f'target{i}.')
+            rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME,
+                                        str(tname)), 300)
+            name = tname
+        rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN,
+                          dns.rdatatype.A, create=True)
+        rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
+                                    '10.0.0.1'), 300)
+        return r
+
     def test_next_request_cache_hit(self):
         self.resolver.cache = dns.resolver.Cache()
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
@@ -353,6 +369,36 @@ class ResolutionTestCase(unittest.TestCase):
         self.assertTrue(answer is None)
         self.assertTrue(done)
 
+    def test_query_result_nxdomain_but_has_answer(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_address_response(q)
+        r.set_rcode(dns.rcode.NXDOMAIN)
+        (_, _) = self.resn.next_request()
+        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertIsNone(answer)
+        self.assertFalse(done)
+        self.assertTrue(nameserver not in self.resn.nameservers)
+
+    def test_query_result_chain_not_too_long(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_long_chain_response(q, 15)
+        (_, _) = self.resn.next_request()
+        (_, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertIsNotNone(answer)
+        self.assertTrue(done)
+
+    def test_query_result_chain_too_long(self):
+        q = dns.message.make_query(self.qname, dns.rdatatype.A)
+        r = self.make_long_chain_response(q, 16)
+        (_, _) = self.resn.next_request()
+        (nameserver, _, _, _) = self.resn.next_nameserver()
+        (answer, done) = self.resn.query_result(r, None)
+        self.assertIsNone(answer)
+        self.assertFalse(done)
+        self.assertTrue(nameserver not in self.resn.nameservers)
+
     def test_query_result_nxdomain_cached(self):
         self.resolver.cache = dns.resolver.Cache()
         q = dns.message.make_query(self.qname, dns.rdatatype.A)
index 171f319b49a12bdcf54ba6943fd438aee425e5b5..1eb8b1eb39ce716e3b4a35de193f8a1073795e01 100644 (file)
@@ -104,6 +104,18 @@ example. 1 IN A 10.0.0.1
 ;ADDITIONAL
 """
 
+message_text_mx = """id 1234
+opcode QUERY
+rcode NOERROR
+flags QR AA RD
+;QUESTION
+example. IN MX
+;ANSWER
+example. 1 IN A 10.0.0.1
+;AUTHORITY
+;ADDITIONAL
+"""
+
 dangling_cname_0_message_text = """id 10000
 opcode QUERY
 rcode NOERROR
@@ -222,7 +234,7 @@ class BaseResolverTests(unittest.TestCase):
 
     def testIndexErrorOnEmptyRRsetAccess(self):
         def bad():
-            message = dns.message.from_text(message_text)
+            message = dns.message.from_text(message_text_mx)
             name = dns.name.from_text('example.')
             answer = dns.resolver.Answer(name, dns.rdatatype.MX,
                                          dns.rdataclass.IN, message,
@@ -232,7 +244,7 @@ class BaseResolverTests(unittest.TestCase):
 
     def testIndexErrorOnEmptyRRsetDelete(self):
         def bad():
-            message = dns.message.from_text(message_text)
+            message = dns.message.from_text(message_text_mx)
             name = dns.name.from_text('example.')
             answer = dns.resolver.Answer(name, dns.rdatatype.MX,
                                          dns.rdataclass.IN, message,