]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add some test coverage for dns.query.xfr()
authorBob Halley <halley@dnspython.org>
Sun, 14 Jun 2020 19:44:23 +0000 (12:44 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 14 Jun 2020 19:44:23 +0000 (12:44 -0700)
tests/test_query.py

index b9699d2738480413ce0f81b684f20a928046bdd6..90fee03f9d235838115d70bf50bfe343ada0f4e5 100644 (file)
@@ -24,11 +24,13 @@ try:
 except Exception:
     have_ssl = False
 
+import dns.exception
 import dns.message
 import dns.name
 import dns.rdataclass
 import dns.rdatatype
 import dns.query
+import dns.zone
 
 # Some tests require the internet to be available to run, so let's
 # skip those if it's not there.
@@ -38,6 +40,17 @@ try:
 except socket.gaierror:
     _network_available = False
 
+# Some tests use a "nano nameserver" for testing.  It requires trio
+# and threading, so try to import it and if it doesn't work, skip
+# those tests.
+try:
+    from .nanonameserver import Server
+    _nanonameserver_available = True
+except ImportError:
+    _nanonameserver_available = False
+    class Server(object):
+        pass
+
 @unittest.skipIf(not _network_available, "Internet not reachable")
 class QueryTests(unittest.TestCase):
 
@@ -142,3 +155,210 @@ class QueryTests(unittest.TestCase):
         q = dns.message.make_query(qname, dns.rdatatype.A)
         (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
         self.assertFalse(tcp)
+
+
+axfr_zone = '''
+$ORIGIN example.
+$TTL 300
+@ SOA ns1 root 1 7200 900 1209600 86400
+@ NS ns1
+@ NS ns2
+ns1 A 10.0.0.1
+ns2 A 10.0.0.1
+'''
+
+class AXFRNanoNameserver(Server):
+
+    def handle(self, message, peer, connection_type):
+        self.zone = dns.zone.from_text(axfr_zone, origin='example')
+        self.origin = self.zone.origin
+        items = []
+        soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)
+        response = dns.message.make_response(message)
+        response.flags |= dns.flags.AA
+        response.answer.append(soa)
+        items.append(response)
+        response = dns.message.make_response(message)
+        response.question = []
+        response.flags |= dns.flags.AA
+        for (name, rdataset) in self.zone.iterate_rdatasets():
+            if rdataset.rdtype == dns.rdatatype.SOA and \
+               name == dns.name.empty:
+                continue
+            rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype,
+                                    rdataset.covers)
+            rrset.update(rdataset)
+            response.answer.append(rrset)
+        items.append(response)
+        response = dns.message.make_response(message)
+        response.question = []
+        response.flags |= dns.flags.AA
+        response.answer.append(soa)
+        items.append(response)
+        return items
+
+ixfr_message = '''id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
+deleted.example. 300 IN A 10.0.0.1
+changed.example. 300 IN A 10.0.0.2
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+changed.example. 300 IN A 10.0.0.4
+added.example. 300 IN A 10.0.0.3
+example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+added2.example. 300 IN A 10.0.0.5
+example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
+'''
+
+ixfr_trailing_junk = ixfr_message + 'junk.example. 300 IN A 10.0.0.6'
+
+ixfr_up_to_date_message = '''id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
+'''
+
+axfr_trailing_junk = '''id 12345
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN AXFR
+;ANSWER
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+added.example. 300 IN A 10.0.0.3
+added2.example. 300 IN A 10.0.0.5
+changed.example. 300 IN A 10.0.0.4
+example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
+junk.example. 300 IN A 10.0.0.6
+'''
+
+class IXFRNanoNameserver(Server):
+
+    def __init__(self, response_text):
+        super().__init__()
+        self.response_text = response_text
+
+    def handle(self, message, peer, connection_type):
+        try:
+            r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
+            r.id = message.id
+            return r
+        except Exception:
+            pass
+
+@unittest.skipIf(not _nanonameserver_available,
+                 "Internet and nanonameserver required")
+class XfrTests(unittest.TestCase):
+
+    def test_axfr(self):
+        expected = dns.zone.from_text(axfr_zone, origin='example')
+        with AXFRNanoNameserver() as ns:
+            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                port=ns.tcp_address[1])
+            zone = dns.zone.from_xfr(xfr)
+            self.assertEqual(zone, expected)
+
+    def test_axfr_udp(self):
+        def bad():
+            with AXFRNanoNameserver() as ns:
+                xfr = dns.query.xfr(ns.udp_address[0], 'example',
+                                    port=ns.udp_address[1], use_udp=True)
+                l = list(xfr)
+        self.assertRaises(ValueError, bad)
+
+    def test_axfr_bad_rcode(self):
+        def bad():
+            # We just use Server here as by default it will refuse.
+            with Server() as ns:
+                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                    port=ns.tcp_address[1])
+                l = list(xfr)
+        self.assertRaises(dns.query.TransferError, bad)
+
+    def test_axfr_trailing_junk(self):
+        # we use the IXFR server here as it returns messages
+        def bad():
+            with IXFRNanoNameserver(axfr_trailing_junk) as ns:
+                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                    dns.rdatatype.AXFR,
+                                    port=ns.tcp_address[1])
+                l = list(xfr)
+        self.assertRaises(dns.exception.FormError, bad)
+
+    def test_ixfr_tcp(self):
+        with IXFRNanoNameserver(ixfr_message) as ns:
+            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                dns.rdatatype.IXFR,
+                                port=ns.tcp_address[1],
+                                serial=2,
+                                relativize=False)
+            l = list(xfr)
+            self.assertEqual(len(l), 1)
+            expected = dns.message.from_text(ixfr_message,
+                                             one_rr_per_rrset=True)
+            expected.id = l[0].id
+            self.assertEqual(l[0], expected)
+
+    def test_ixfr_udp(self):
+        with IXFRNanoNameserver(ixfr_message) as ns:
+            xfr = dns.query.xfr(ns.udp_address[0], 'example',
+                                dns.rdatatype.IXFR,
+                                port=ns.udp_address[1],
+                                serial=2,
+                                relativize=False, use_udp=True)
+            l = list(xfr)
+            self.assertEqual(len(l), 1)
+            expected = dns.message.from_text(ixfr_message,
+                                             one_rr_per_rrset=True)
+            expected.id = l[0].id
+            self.assertEqual(l[0], expected)
+
+    def test_ixfr_up_to_date(self):
+        with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
+            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                dns.rdatatype.IXFR,
+                                port=ns.tcp_address[1],
+                                serial=2,
+                                relativize=False)
+            l = list(xfr)
+            self.assertEqual(len(l), 1)
+            expected = dns.message.from_text(ixfr_up_to_date_message,
+                                             one_rr_per_rrset=True)
+            expected.id = l[0].id
+            print(expected)
+            print(l[0])
+            self.assertEqual(l[0], expected)
+
+    def test_ixfr_trailing_junk(self):
+        def bad():
+            with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
+                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                    dns.rdatatype.IXFR,
+                                    port=ns.tcp_address[1],
+                                    serial=2,
+                                    relativize=False)
+                l = list(xfr)
+        self.assertRaises(dns.exception.FormError, bad)
+
+    def test_ixfr_base_serial_mismatch(self):
+        def bad():
+            with IXFRNanoNameserver(ixfr_message) as ns:
+                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
+                                    dns.rdatatype.IXFR,
+                                    port=ns.tcp_address[1],
+                                    serial=1,
+                                    relativize=False)
+                l = list(xfr)
+        self.assertRaises(dns.exception.FormError, bad)