return self.kwargs['message']
+class NotQueryResponse(dns.exception.DNSException):
+ """Message is not a response to a query."""
+
+
+class ChainTooLong(dns.exception.DNSException):
+ """The CNAME chain is too long."""
+
+
+class AnswerForNXDOMAIN(dns.exception.DNSException):
+ """The rcode is NXDOMAIN but an answer was found."""
+
+
class MessageSection(dns.enum.IntEnum):
"""Message sections"""
QUESTION = 0
globals().update(MessageSection.__members__)
DEFAULT_EDNS_PAYLOAD = 1232
+MAX_CHAIN = 16
class Message:
"""A DNS message."""
dns.opcode.from_flags(self.flags) != \
dns.opcode.from_flags(other.flags):
return False
- if dns.rcode.from_flags(other.flags, other.ednsflags) != \
- dns.rcode.NOERROR:
+ if other.rcode() in {dns.rcode.FORMERR, dns.rcode.SERVFAIL,
+ dns.rcode.NOTIMP, dns.rcode.REFUSED}:
+ # We don't check the question section in these cases, even
+ # though they still ought to have the same question.
return True
if dns.opcode.is_update(self.flags):
# This is assuming the "sender doesn't include anything
class QueryMessage(Message):
- pass
+ def resolve_chaining(self):
+ """Follow the CNAME chain in the response to determine the answer
+ RRset.
+
+ Raises NotQueryResponse if the message is not a response.
+
+ Raises dns.message.ChainTooLong if the CNAME chain is too long.
+
+ Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was
+ found.
+
+ Raises dns.exception.FormError if the question count is not 1.
+
+ Returns a tuple (dns.name.Name, int, rrset) where the name is the
+ canonical name, the int is the minimized TTL, and rrset is their
+ answer RRset, which may be ``None`` if the chain was dangling or
+ the response is an NXDOMAIN.
+ """
+ if self.flags & dns.flags.QR == 0:
+ raise NotQueryResponse
+ if len(self.question) != 1:
+ raise dns.exception.FormError
+ question = self.question[0]
+ qname = question.name
+ min_ttl = -1
+ rrset = None
+ count = 0
+ while count < MAX_CHAIN:
+ try:
+ rrset = self.find_rrset(self.answer, qname, question.rdclass,
+ question.rdtype)
+ if min_ttl == -1 or rrset.ttl < min_ttl:
+ min_ttl = rrset.ttl
+ break
+ except KeyError:
+ if question.rdtype != dns.rdatatype.CNAME:
+ try:
+ crrset = self.find_rrset(self.answer, qname,
+ question.rdclass,
+ dns.rdatatype.CNAME)
+ if min_ttl == -1 or crrset.ttl < min_ttl:
+ min_ttl = crrset.ttl
+ for rd in crrset:
+ qname = rd.target
+ break
+ count += 1
+ continue
+ except KeyError:
+ # Exit the chaining loop
+ break
+ if count >= MAX_CHAIN:
+ raise ChainTooLong
+ if self.rcode() == dns.rcode.NXDOMAIN and rrset is not None:
+ raise AnswerForNXDOMAIN
+ if rrset is None:
+ # Further minimize the TTL with NCACHE.
+ auname = qname
+ while True:
+ # Look for an SOA RR whose owner name is a superdomain
+ # of qname.
+ try:
+ srrset = self.find_rrset(self.authority, auname,
+ question.rdclass,
+ dns.rdatatype.SOA)
+ if min_ttl == -1 or srrset.ttl < min_ttl:
+ min_ttl = srrset.ttl
+ if srrset[0].minimum < min_ttl:
+ min_ttl = srrset[0].minimum
+ break
+ except KeyError:
+ try:
+ auname = auname.parent()
+ except dns.name.NoParent:
+ break
+ return (qname, min_ttl, rrset)
+
+ def canonical_name(self):
+ """Return the canonical name of the first name in the question
+ section.
+
+ Raises dns.message.NotQueryResponse if the message is not a response.
+
+ Raises dns.message.ChainTooLong if the CNAME chain is too long.
+
+ Raises AnswerForNXDOMAIN if the rcode is NXDOMAIN but an answer was
+ found.
+
+ Raises dns.exception.FormError if the question count is not 1.
+ """
+ return self.resolve_chaining()[0]
def _maybe_import_update():
"""Return the unresolved canonical name."""
if 'qnames' not in self.kwargs:
raise TypeError("parametrized exception required")
- IN = dns.rdataclass.IN
- CNAME = dns.rdatatype.CNAME
- cname = None
- # This code assumes the CNAME chain is in proper order, though
- # the Answer code does not make a similar assumption when
- # chaining.
for qname in self.kwargs['qnames']:
response = self.kwargs['responses'][qname]
- for answer in response.answer:
- if answer.rdtype != CNAME or answer.rdclass != IN:
- continue
- cname = answer[0].target
- if cname is not None:
- return cname
+ try:
+ cname = response.canonical_name()
+ if cname != qname:
+ return cname
+ except Exception:
+ # We can just eat this exception as it means there was
+ # something wrong with the response.
+ pass
return self.kwargs['qnames'][0]
def __add__(self, e_nx):
self.response = response
self.nameserver = nameserver
self.port = port
- min_ttl = -1
- rrset = None
- for count in range(0, 15):
- try:
- rrset = response.find_rrset(response.answer, qname,
- rdclass, rdtype)
- if min_ttl == -1 or rrset.ttl < min_ttl:
- min_ttl = rrset.ttl
- break
- except KeyError:
- if rdtype != dns.rdatatype.CNAME:
- try:
- crrset = response.find_rrset(response.answer,
- qname,
- rdclass,
- dns.rdatatype.CNAME)
- if min_ttl == -1 or crrset.ttl < min_ttl:
- min_ttl = crrset.ttl
- for rd in crrset:
- qname = rd.target
- break
- continue
- except KeyError:
- # Exit the chaining loop
- break
- self.canonical_name = qname
- self.rrset = rrset
- if rrset is None:
- while 1:
- # Look for a SOA RR whose owner name is a superdomain
- # of qname.
- try:
- srrset = response.find_rrset(response.authority, qname,
- rdclass, dns.rdatatype.SOA)
- if min_ttl == -1 or srrset.ttl < min_ttl:
- min_ttl = srrset.ttl
- if srrset[0].minimum < min_ttl:
- min_ttl = srrset[0].minimum
- break
- except KeyError:
- try:
- qname = qname.parent()
- except dns.name.NoParent:
- break
+ (self.canonical_name, min_ttl, self.rrset) = response.resolve_chaining()
self.expiration = time.time() + min_ttl
def __getattr__(self, attr): # pragma: no cover
assert response is not None
rcode = response.rcode()
if rcode == dns.rcode.NOERROR:
- answer = Answer(self.qname, self.rdtype, self.rdclass, response,
- self.nameserver, self.port)
+ try:
+ answer = Answer(self.qname, self.rdtype, self.rdclass, response,
+ self.nameserver, self.port)
+ except Exception:
+ # The nameserver is no good, take it out of the mix.
+ self.nameservers.remove(self.nameserver)
+ return (None, False)
if self.resolver.cache:
self.resolver.cache.put((self.qname, self.rdtype,
self.rdclass), answer)
raise NoAnswer(response=answer.response)
return (answer, True)
elif rcode == dns.rcode.NXDOMAIN:
- self.nxdomain_responses[self.qname] = response
- # Make next_nameserver() return None, so caller breaks its
- # inner loop and calls next_request().
- if self.resolver.cache:
+ # Further validate the response by making an Answer, even
+ # if we aren't going to cache it.
+ try:
answer = Answer(self.qname, dns.rdatatype.ANY,
dns.rdataclass.IN, response)
+ except Exception:
+ # The nameserver is no good, take it out of the mix.
+ self.nameservers.remove(self.nameserver)
+ return (None, False)
+ self.nxdomain_responses[self.qname] = response
+ if self.resolver.cache:
self.resolver.cache.put((self.qname,
dns.rdatatype.ANY,
self.rdclass), answer)
-
+ # Make next_nameserver() return None, so caller breaks its
+ # inner loop and calls next_request().
return (None, True)
elif rcode == dns.rcode.YXDOMAIN:
yex = YXDOMAIN()
r.set_rcode(dns.rcode.NXDOMAIN)
return r
+ def make_long_chain_response(self, q, count):
+ r = dns.message.make_response(q)
+ name = self.qname
+ for i in range(count):
+ rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN,
+ dns.rdatatype.CNAME, create=True)
+ tname = dns.name.from_text(f'target{i}.')
+ rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CNAME,
+ str(tname)), 300)
+ name = tname
+ rrs = r.get_rrset(r.answer, name, dns.rdataclass.IN,
+ dns.rdatatype.A, create=True)
+ rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
+ '10.0.0.1'), 300)
+ return r
+
def test_next_request_cache_hit(self):
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
self.assertTrue(answer is None)
self.assertTrue(done)
+ def test_query_result_nxdomain_but_has_answer(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ r.set_rcode(dns.rcode.NXDOMAIN)
+ (_, _) = self.resn.next_request()
+ (nameserver, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertIsNone(answer)
+ self.assertFalse(done)
+ self.assertTrue(nameserver not in self.resn.nameservers)
+
+ def test_query_result_chain_not_too_long(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_long_chain_response(q, 15)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertIsNotNone(answer)
+ self.assertTrue(done)
+
+ def test_query_result_chain_too_long(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_long_chain_response(q, 16)
+ (_, _) = self.resn.next_request()
+ (nameserver, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertIsNone(answer)
+ self.assertFalse(done)
+ self.assertTrue(nameserver not in self.resn.nameservers)
+
def test_query_result_nxdomain_cached(self):
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
;ADDITIONAL
"""
+message_text_mx = """id 1234
+opcode QUERY
+rcode NOERROR
+flags QR AA RD
+;QUESTION
+example. IN MX
+;ANSWER
+example. 1 IN A 10.0.0.1
+;AUTHORITY
+;ADDITIONAL
+"""
+
dangling_cname_0_message_text = """id 10000
opcode QUERY
rcode NOERROR
def testIndexErrorOnEmptyRRsetAccess(self):
def bad():
- message = dns.message.from_text(message_text)
+ message = dns.message.from_text(message_text_mx)
name = dns.name.from_text('example.')
answer = dns.resolver.Answer(name, dns.rdatatype.MX,
dns.rdataclass.IN, message,
def testIndexErrorOnEmptyRRsetDelete(self):
def bad():
- message = dns.message.from_text(message_text)
+ message = dns.message.from_text(message_text_mx)
name = dns.name.from_text('example.')
answer = dns.resolver.Answer(name, dns.rdatatype.MX,
dns.rdataclass.IN, message,