From: Bob Halley Date: Tue, 21 Jul 2020 14:32:27 +0000 (-0700) Subject: unify chaining code X-Git-Tag: v2.1.0rc1~162^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8c63cfd484a70097312acc368f7b10d80cd5ce87;p=thirdparty%2Fdnspython.git unify chaining code --- diff --git a/dns/message.py b/dns/message.py index 7f665722..87484f51 100644 --- a/dns/message.py +++ b/dns/message.py @@ -80,6 +80,18 @@ class Truncated(dns.exception.DNSException): return self.kwargs['message'] +class NotQueryResponse(dns.exception.DNSException): + """Message is not a response to a query.""" + + +class ChainTooLong(dns.exception.DNSException): + """The CNAME chain is too long.""" + + +class AnswerForNXDOMAIN(dns.exception.DNSException): + """The rcode is NXDOMAIN but an answer was found.""" + + class MessageSection(dns.enum.IntEnum): """Message sections""" QUESTION = 0 @@ -94,6 +106,7 @@ class MessageSection(dns.enum.IntEnum): globals().update(MessageSection.__members__) DEFAULT_EDNS_PAYLOAD = 1232 +MAX_CHAIN = 16 class Message: """A DNS message.""" @@ -232,8 +245,10 @@ class Message: dns.opcode.from_flags(self.flags) != \ dns.opcode.from_flags(other.flags): return False - if dns.rcode.from_flags(other.flags, other.ednsflags) != \ - dns.rcode.NOERROR: + if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL, + dns.rcode.NOTIMP, dns.rcode.REFUSED}: + # We don't check the question section in these cases, even + # though they still ought to have the same question. return True if dns.opcode.is_update(self.flags): # This is assuming the "sender doesn't include anything @@ -696,7 +711,96 @@ class Message: class QueryMessage(Message): - pass + def resolve_chaining(self): + """Follow the CNAME chain in the response to determine the answer + RRset. + + Raises NotQueryResponse if the message is not a response. + + Raises dns.message.ChainTooLong if the CNAME chain is too long. + + Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was + found. + + Raises dns.exception.FormError if the question count is not 1. + + Returns a tuple (dns.name.Name, int, rrset) where the name is the + canonical name, the int is the minimized TTL, and rrset is their + answer RRset, which may be ``None`` if the chain was dangling or + the response is an NXDOMAIN. + """ + if self.flags & dns.flags.QR == 0: + raise NotQueryResponse + if len(self.question) != 1: + raise dns.exception.FormError + question = self.question[0] + qname = question.name + min_ttl = -1 + rrset = None + count = 0 + while count < MAX_CHAIN: + try: + rrset = self.find_rrset(self.answer, qname, question.rdclass, + question.rdtype) + if min_ttl == -1 or rrset.ttl < min_ttl: + min_ttl = rrset.ttl + break + except KeyError: + if question.rdtype != dns.rdatatype.CNAME: + try: + crrset = self.find_rrset(self.answer, qname, + question.rdclass, + dns.rdatatype.CNAME) + if min_ttl == -1 or crrset.ttl < min_ttl: + min_ttl = crrset.ttl + for rd in crrset: + qname = rd.target + break + count += 1 + continue + except KeyError: + # Exit the chaining loop + break + if count >= MAX_CHAIN: + raise ChainTooLong + if self.rcode() == dns.rcode.NXDOMAIN and rrset is not None: + raise AnswerForNXDOMAIN + if rrset is None: + # Further minimize the TTL with NCACHE. + auname = qname + while True: + # Look for an SOA RR whose owner name is a superdomain + # of qname. + try: + srrset = self.find_rrset(self.authority, auname, + question.rdclass, + dns.rdatatype.SOA) + if min_ttl == -1 or srrset.ttl < min_ttl: + min_ttl = srrset.ttl + if srrset[0].minimum < min_ttl: + min_ttl = srrset[0].minimum + break + except KeyError: + try: + auname = auname.parent() + except dns.name.NoParent: + break + return (qname, min_ttl, rrset) + + def canonical_name(self): + """Return the canonical name of the first name in the question + section. + + Raises dns.message.NotQueryResponse if the message is not a response. + + Raises dns.message.ChainTooLong if the CNAME chain is too long. + + Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was + found. + + Raises dns.exception.FormError if the question count is not 1. + """ + return self.resolve_chaining()[0] def _maybe_import_update(): diff --git a/dns/resolver.py b/dns/resolver.py index a5079d46..4e112471 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -78,20 +78,16 @@ class NXDOMAIN(dns.exception.DNSException): """Return the unresolved canonical name.""" if 'qnames' not in self.kwargs: raise TypeError("parametrized exception required") - IN = dns.rdataclass.IN - CNAME = dns.rdatatype.CNAME - cname = None - # This code assumes the CNAME chain is in proper order, though - # the Answer code does not make a similar assumption when - # chaining. for qname in self.kwargs['qnames']: response = self.kwargs['responses'][qname] - for answer in response.answer: - if answer.rdtype != CNAME or answer.rdclass != IN: - continue - cname = answer[0].target - if cname is not None: - return cname + try: + cname = response.canonical_name() + if cname != qname: + return cname + except Exception: + # We can just eat this exception as it means there was + # something wrong with the response. + pass return self.kwargs['qnames'][0] def __add__(self, e_nx): @@ -209,50 +205,7 @@ class Answer: self.response = response self.nameserver = nameserver self.port = port - min_ttl = -1 - rrset = None - for count in range(0, 15): - try: - rrset = response.find_rrset(response.answer, qname, - rdclass, rdtype) - if min_ttl == -1 or rrset.ttl < min_ttl: - min_ttl = rrset.ttl - break - except KeyError: - if rdtype != dns.rdatatype.CNAME: - try: - crrset = response.find_rrset(response.answer, - qname, - rdclass, - dns.rdatatype.CNAME) - if min_ttl == -1 or crrset.ttl < min_ttl: - min_ttl = crrset.ttl - for rd in crrset: - qname = rd.target - break - continue - except KeyError: - # Exit the chaining loop - break - self.canonical_name = qname - self.rrset = rrset - if rrset is None: - while 1: - # Look for a SOA RR whose owner name is a superdomain - # of qname. - try: - srrset = response.find_rrset(response.authority, qname, - rdclass, dns.rdatatype.SOA) - if min_ttl == -1 or srrset.ttl < min_ttl: - min_ttl = srrset.ttl - if srrset[0].minimum < min_ttl: - min_ttl = srrset[0].minimum - break - except KeyError: - try: - qname = qname.parent() - except dns.name.NoParent: - break + (self.canonical_name, min_ttl, self.rrset) = response.resolve_chaining() self.expiration = time.time() + min_ttl def __getattr__(self, attr): # pragma: no cover @@ -698,8 +651,13 @@ class _Resolution: assert response is not None rcode = response.rcode() if rcode == dns.rcode.NOERROR: - answer = Answer(self.qname, self.rdtype, self.rdclass, response, - self.nameserver, self.port) + try: + answer = Answer(self.qname, self.rdtype, self.rdclass, response, + self.nameserver, self.port) + except Exception: + # The nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + return (None, False) if self.resolver.cache: self.resolver.cache.put((self.qname, self.rdtype, self.rdclass), answer) @@ -707,16 +665,22 @@ class _Resolution: raise NoAnswer(response=answer.response) return (answer, True) elif rcode == dns.rcode.NXDOMAIN: - self.nxdomain_responses[self.qname] = response - # Make next_nameserver() return None, so caller breaks its - # inner loop and calls next_request(). - if self.resolver.cache: + # Further validate the response by making an Answer, even + # if we aren't going to cache it. + try: answer = Answer(self.qname, dns.rdatatype.ANY, dns.rdataclass.IN, response) + except Exception: + # The nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + return (None, False) + self.nxdomain_responses[self.qname] = response + if self.resolver.cache: self.resolver.cache.put((self.qname, dns.rdatatype.ANY, self.rdclass), answer) - + # Make next_nameserver() return None, so caller breaks its + # inner loop and calls next_request(). return (None, True) elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() diff --git a/tests/test_resolution.py b/tests/test_resolution.py index 9145f167..db42d469 100644 --- a/tests/test_resolution.py +++ b/tests/test_resolution.py @@ -83,6 +83,22 @@ class ResolutionTestCase(unittest.TestCase): r.set_rcode(dns.rcode.NXDOMAIN) return r + def make_long_chain_response(self, q, count): + r = dns.message.make_response(q) + name = self.qname + for i in range(count): + rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN, + dns.rdatatype.CNAME, create=True) + tname = dns.name.from_text(f'target{i}.') + rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME, + str(tname)), 300) + name = tname + rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN, + dns.rdatatype.A, create=True) + rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A, + '10.0.0.1'), 300) + return r + def test_next_request_cache_hit(self): self.resolver.cache = dns.resolver.Cache() q = dns.message.make_query(self.qname, dns.rdatatype.A) @@ -353,6 +369,36 @@ class ResolutionTestCase(unittest.TestCase): self.assertTrue(answer is None) self.assertTrue(done) + def test_query_result_nxdomain_but_has_answer(self): + q = dns.message.make_query(self.qname, dns.rdatatype.A) + r = self.make_address_response(q) + r.set_rcode(dns.rcode.NXDOMAIN) + (_, _) = self.resn.next_request() + (nameserver, _, _, _) = self.resn.next_nameserver() + (answer, done) = self.resn.query_result(r, None) + self.assertIsNone(answer) + self.assertFalse(done) + self.assertTrue(nameserver not in self.resn.nameservers) + + def test_query_result_chain_not_too_long(self): + q = dns.message.make_query(self.qname, dns.rdatatype.A) + r = self.make_long_chain_response(q, 15) + (_, _) = self.resn.next_request() + (_, _, _, _) = self.resn.next_nameserver() + (answer, done) = self.resn.query_result(r, None) + self.assertIsNotNone(answer) + self.assertTrue(done) + + def test_query_result_chain_too_long(self): + q = dns.message.make_query(self.qname, dns.rdatatype.A) + r = self.make_long_chain_response(q, 16) + (_, _) = self.resn.next_request() + (nameserver, _, _, _) = self.resn.next_nameserver() + (answer, done) = self.resn.query_result(r, None) + self.assertIsNone(answer) + self.assertFalse(done) + self.assertTrue(nameserver not in self.resn.nameservers) + def test_query_result_nxdomain_cached(self): self.resolver.cache = dns.resolver.Cache() q = dns.message.make_query(self.qname, dns.rdatatype.A) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 171f319b..1eb8b1eb 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -104,6 +104,18 @@ example. 1 IN A 10.0.0.1 ;ADDITIONAL """ +message_text_mx = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD +;QUESTION +example. IN MX +;ANSWER +example. 1 IN A 10.0.0.1 +;AUTHORITY +;ADDITIONAL +""" + dangling_cname_0_message_text = """id 10000 opcode QUERY rcode NOERROR @@ -222,7 +234,7 @@ class BaseResolverTests(unittest.TestCase): def testIndexErrorOnEmptyRRsetAccess(self): def bad(): - message = dns.message.from_text(message_text) + message = dns.message.from_text(message_text_mx) name = dns.name.from_text('example.') answer = dns.resolver.Answer(name, dns.rdatatype.MX, dns.rdataclass.IN, message, @@ -232,7 +244,7 @@ class BaseResolverTests(unittest.TestCase): def testIndexErrorOnEmptyRRsetDelete(self): def bad(): - message = dns.message.from_text(message_text) + message = dns.message.from_text(message_text_mx) name = dns.name.from_text('example.') answer = dns.resolver.Answer(name, dns.rdatatype.MX, dns.rdataclass.IN, message,