import dns.message
import dns.name
+import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.resolver
(request, answer) = self.resn.next_request()
self.assertRaises(dns.resolver.NXDOMAIN, bad)
- def test_next_request_cache_hit(self):
- self.resolver.cache = dns.resolver.Cache()
- q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ def make_address_response(self, q):
r = dns.message.make_response(q)
rrs = r.get_rrset(r.answer, self.qname, dns.rdataclass.IN,
dns.rdatatype.A, create=True)
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.A,
'10.0.0.1'), 300)
+ return r
+
+ def make_negative_response(self, q, nxdomain=False):
+ r = dns.message.make_response(q)
+ rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
+ dns.rdatatype.SOA, create=True)
+ rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
+ '. . 1 2 3 4 300'), 300)
+ if nxdomain:
+ r.set_rcode(dns.rcode.NXDOMAIN)
+ return r
+
+ def test_next_request_cache_hit(self):
+ self.resolver.cache = dns.resolver.Cache()
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A,
dns.rdataclass.IN, r)
self.resolver.cache.put((self.qname, dns.rdatatype.A,
# In default mode, we should raise on a no-answer hit
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
- r = dns.message.make_response(q)
- # We need an SOA so the cache doesn't expire the answer immediately.
- rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
- dns.rdatatype.SOA, create=True)
- rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
- '. . 1 2 3 4 300'), 300)
+ # Note we need an SOA so the cache doesn't expire the answer
+ # immediately, but our negative response code does that.
+ r = self.make_negative_response(q)
cache_answer = dns.resolver.Answer(self.qname, dns.rdatatype.A,
dns.rdataclass.IN, r, False)
self.resolver.cache.put((self.qname, dns.rdatatype.A,
self.assertTrue(answer is cache_answer)
def test_next_nameserver_udp(self):
- nameservers = {'10.0.0.1', '10.0.0.2'}
(request, answer) = self.resn.next_request()
(nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver1 in nameservers)
+ self.assertTrue(nameserver1 in self.resolver.nameservers)
self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
(nameserver2, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver2 in nameservers)
+ self.assertTrue(nameserver2 in self.resolver.nameservers)
self.assertTrue(nameserver2 != nameserver1)
self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.2)
def test_next_nameserver_retry_with_tcp(self):
- nameservers = {'10.0.0.1', '10.0.0.2'}
(request, answer) = self.resn.next_request()
(nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver1 in nameservers)
+ self.assertTrue(nameserver1 in self.resolver.nameservers)
self.assertEqual(port, 53)
self.assertFalse(tcp)
self.assertEqual(backoff, 0.0)
self.assertTrue(tcp)
self.assertEqual(backoff, 0.0)
(nameserver3, port, tcp, backoff) = self.resn.next_nameserver()
- self.assertTrue(nameserver3 in nameservers)
+ self.assertTrue(nameserver3 in self.resolver.nameservers)
self.assertTrue(nameserver3 != nameserver1)
self.assertEqual(port, 53)
self.assertFalse(tcp)
def bad():
(nameserver, _, _, _) = self.resn.next_nameserver()
self.assertRaises(dns.resolver.NoNameservers, bad)
+
+ def test_query_result_nameserver_removing_exceptions(self):
+ # add some nameservers so we have enough to remove :)
+ self.resolver.nameservers.extend(['10.0.0.3', '10.0.0.4'])
+ (request, _) = self.resn.next_request()
+ exceptions = [dns.exception.FormError(), EOFError(),
+ NotImplementedError(), dns.message.Truncated()]
+ for i in range(4):
+ (nameserver, _, _, _) = self.resn.next_nameserver()
+ if i == 3:
+ # Truncated is only bad if we're doing TCP, make it look
+ # like that's the case
+ self.resn.tcp_attempt = True
+ self.assertTrue(nameserver in self.resn.nameservers)
+ (answer, done) = self.resn.query_result(None, exceptions[i])
+ self.assertTrue(answer is None)
+ self.assertFalse(done)
+ self.assertFalse(nameserver in self.resn.nameservers)
+ self.assertEqual(len(self.resn.nameservers), 0)
+
+ def test_query_result_nameserver_continuing_exception(self):
+ # except for the exceptions tested in
+ # test_query_result_nameserver_removing_exceptions(), we should
+ # not remove any nameservers and just continue resolving.
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ nameservers = self.resn.nameservers[:]
+ (answer, done) = self.resn.query_result(None, dns.exception.Timeout())
+ self.assertTrue(answer is None)
+ self.assertFalse(done)
+ self.assertEqual(nameservers, self.resn.nameservers)
+
+ def test_query_result_retry_with_tcp(self):
+ (request, _) = self.resn.next_request()
+ (nameserver, _, tcp, _) = self.resn.next_nameserver()
+ self.assertFalse(tcp)
+ (answer, done) = self.resn.query_result(None, dns.message.Truncated())
+ self.assertTrue(answer is None)
+ self.assertFalse(done)
+ self.assertTrue(self.resn.retry_with_tcp)
+ # The rest of TCP retry logic was tested above in
+ # test_next_nameserver_retry_with_tcp(), so we do not repeat
+ # it.
+
+ def test_query_result_no_error_with_data(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertFalse(answer is None)
+ self.assertTrue(done)
+ self.assertEqual(answer.qname, self.qname)
+ self.assertEqual(answer.rdtype, dns.rdatatype.A)
+
+ def test_query_result_no_error_with_data_cached(self):
+ self.resolver.cache = dns.resolver.Cache()
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertFalse(answer is None)
+ cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.A,
+ dns.rdataclass.IN))
+ self.assertTrue(answer is cache_answer)
+
+ def test_query_result_no_error_no_data(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_negative_response(q)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ def bad():
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertRaises(dns.resolver.NoAnswer, bad)
+
+ def test_query_result_nxdomain(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_negative_response(q, True)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertTrue(answer is None)
+ self.assertTrue(done)
+
+ def test_query_result_yxdomain(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ r.set_rcode(dns.rcode.YXDOMAIN)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ def bad():
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertRaises(dns.resolver.YXDOMAIN, bad)
+
+ def test_query_result_servfail_no_retry(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ r.set_rcode(dns.rcode.SERVFAIL)
+ (_, _) = self.resn.next_request()
+ (nameserver, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertTrue(answer is None)
+ self.assertFalse(done)
+ self.assertTrue(nameserver not in self.resn.nameservers)
+
+ def test_query_result_servfail_with_retry(self):
+ self.resolver.retry_servfail = True
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ r.set_rcode(dns.rcode.SERVFAIL)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ nameservers = self.resn.nameservers[:]
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertTrue(answer is None)
+ self.assertFalse(done)
+ self.assertEqual(nameservers, self.resn.nameservers)
+
+ def test_query_result_other_unhappy_rcode(self):
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_address_response(q)
+ r.set_rcode(dns.rcode.REFUSED)
+ (_, _) = self.resn.next_request()
+ (nameserver, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertTrue(answer is None)
+ self.assertFalse(done)
+ self.assertTrue(nameserver not in self.resn.nameservers)