]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
is_udp is better as Inbound attribute not parameter to process_message(). Increase...
authorBob Halley <halley@dnspython.org>
Sun, 23 Aug 2020 02:11:19 +0000 (19:11 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 23 Aug 2020 02:11:19 +0000 (19:11 -0700)
dns/asyncquery.py
dns/query.py
dns/xfr.py
tests/test_xfr.py

index 702de377df81254cb489a02f016dfab89267673f..89c2622fee643a8b874838abeac207e03c937a2c 100644 (file)
@@ -391,7 +391,8 @@ async def inbound_xfr(where, txn_manager, query=None,
             else:
                 tcpmsg = struct.pack("!H", len(wire)) + wire
                 await s.sendall(tcpmsg, expiration)
-            with dns.xfr.Inbound(txn_manager, rdtype, serial) as inbound:
+            with dns.xfr.Inbound(txn_manager, rdtype, serial,
+                                 is_udp) as inbound:
                 done = False
                 tsig_ctx = None
                 while not done:
@@ -419,7 +420,7 @@ async def inbound_xfr(where, txn_manager, query=None,
                                               multi=(not is_udp),
                                               one_rr_per_rrset=is_ixfr)
                     try:
-                        done = inbound.process_message(r, is_udp)
+                        done = inbound.process_message(r)
                     except dns.xfr.UseTCP:
                         assert is_udp  # should not happen if we used TCP!
                         if udp_mode == UDPMode.ONLY:
index 37c727e3e2ae27c4eade5045af9318f14c3889d0..bd62a7a3963709a3419417087206d05144320e0b 100644 (file)
@@ -1059,7 +1059,8 @@ def inbound_xfr(where, txn_manager, query=None,
             else:
                 tcpmsg = struct.pack("!H", len(wire)) + wire
                 _net_write(s, tcpmsg, expiration)
-            with dns.xfr.Inbound(txn_manager, rdtype, serial) as inbound:
+            with dns.xfr.Inbound(txn_manager, rdtype, serial,
+                                 is_udp) as inbound:
                 done = False
                 tsig_ctx = None
                 while not done:
@@ -1079,7 +1080,7 @@ def inbound_xfr(where, txn_manager, query=None,
                                               multi=(not is_udp),
                                               one_rr_per_rrset=is_ixfr)
                     try:
-                        done = inbound.process_message(r, is_udp)
+                        done = inbound.process_message(r)
                     except dns.xfr.UseTCP:
                         assert is_udp  # should not happen if we used TCP!
                         if udp_mode == UDPMode.ONLY:
index 311e60e63475e25d4210baf899ddcef169600f0b..b07f8b9a3a749f3cb037db6a42e64d21a804cef0 100644 (file)
@@ -47,7 +47,7 @@ class Inbound:
     """
 
     def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR,
-                 serial=None):
+                 serial=None, is_udp=False):
         """Initialize an inbound zone transfer.
 
         *txn_manager* is a :py:class:`dns.transaction.TransactionManager`.
@@ -56,29 +56,33 @@ class Inbound:
 
         *serial* is the base serial number for IXFRs, and is required in
         that case.
+
+        *is_udp*, a ``bool`` indidicates if UDP is being used for this
+        XFR.
         """
         self.txn_manager = txn_manager
         self.txn = None
         self.rdtype = rdtype
-        if rdtype == dns.rdatatype.IXFR and serial is None:
-            raise ValueError('a starting serial must be supplied for IXFRs')
+        if rdtype == dns.rdatatype.IXFR:
+            if serial is None:
+                raise ValueError('a starting serial must be supplied for IXFRs')
+        elif is_udp:
+            raise ValueError('is_udp specified for AXFR')
         self.serial = serial
+        self.is_udp = is_udp
         (_, _, self.origin) = txn_manager.origin_information()
         self.soa_rdataset = None
         self.done = False
         self.expecting_SOA = False
         self.delete_mode = False
 
-    def process_message(self, message, is_udp=False):
+    def process_message(self, message):
         """Process one message in the transfer.
 
         The message should have the same relativization as was specified when
         the `dns.xfr.Inbound` was created.  The message should also have been
         created with `one_rr_per_rrset=True` because order matters.
 
-        *is_udp*, a ``bool`` indidicates if this message was received using
-        UDP.
-
         Returns `True` if the transfer is complete, and `False` otherwise.
         """
         if self.txn is None:
@@ -125,7 +129,7 @@ class Inbound:
                           self.serial)
                     raise SerialWentBackwards
                 else:
-                    if is_udp and len(message.answer[answer_index:]) == 0:
+                    if self.is_udp and len(message.answer[answer_index:]) == 0:
                         #
                         # There are no more records, so this is the
                         # "truncated" response.  Say to use TCP
@@ -216,7 +220,7 @@ class Inbound:
                 self.txn.delete_exact(name, rdataset)
             else:
                 self.txn.add(name, rdataset)
-        if is_udp and not self.done:
+        if self.is_udp and not self.done:
             #
             # This is a UDP IXFR and we didn't get to done, and we didn't
             # get the proper "truncated" response
index fbda5fa8a5ad169ba2b54c632bb2c74bc4b4a7dd..c1a011c66d6beda867bfae30e808ce699dd482d2 100644 (file)
@@ -211,6 +211,20 @@ ns3 3600 IN A 10.0.0.3
 @ 3600 IN SOA foo bar 4 2 3 4 5
 '''
 
+unexpected_end_ixfr_2 = '''id 1
+opcode QUERY
+rcode NOERROR
+flags AA
+;QUESTION
+example. IN IXFR
+;ANSWER
+@ 3600 IN SOA foo bar 4 2 3 4 5
+@ 3600 IN SOA foo bar 1 2 3 4 5
+bar.foo 300 IN MX 0 blaz.foo
+ns2 3600 IN A 10.0.0.2
+@ 3600 IN NS ns2
+'''
+
 bad_serial_ixfr = '''id 1
 opcode QUERY
 rcode NOERROR
@@ -329,9 +343,9 @@ def test_retry_tcp_ixfr():
                            zone_factory=dns.versioned.Zone)
     m = dns.message.from_text(retry_tcp_ixfr, origin=z.origin,
                               one_rr_per_rrset=True)
-    with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr:
+    with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1, is_udp=True) as xfr:
         with pytest.raises(dns.xfr.UseTCP):
-            xfr.process_message(m, True)
+            xfr.process_message(m)
 
 def test_bad_empty_ixfr():
     z = dns.zone.from_text(ixfr_expected, 'example.',
@@ -368,7 +382,9 @@ def test_ixfr_requires_serial():
     with pytest.raises(ValueError):
         dns.xfr.Inbound(z, dns.rdatatype.IXFR)
 
-def test_ixfr_unexpected_end():
+def test_ixfr_unexpected_end_bad_diff_sequence():
+    # This is where we get the end serial, but haven't seen all of
+    # the expected diffs
     z = dns.zone.from_text(base, 'example.',
                            zone_factory=dns.versioned.Zone)
     m = dns.message.from_text(unexpected_end_ixfr, origin=z.origin,
@@ -377,6 +393,17 @@ def test_ixfr_unexpected_end():
         with pytest.raises(dns.exception.FormError):
             xfr.process_message(m)
 
+def test_udp_ixfr_unexpected_end_just_stops():
+    # This is where everything looks good, but the IXFR just stops
+    # in the middle.
+    z = dns.zone.from_text(base, 'example.',
+                           zone_factory=dns.versioned.Zone)
+    m = dns.message.from_text(unexpected_end_ixfr_2, origin=z.origin,
+                              one_rr_per_rrset=True)
+    with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1, is_udp=True) as xfr:
+        with pytest.raises(dns.exception.FormError):
+            xfr.process_message(m)
+
 def test_ixfr_bad_serial():
     z = dns.zone.from_text(base, 'example.',
                            zone_factory=dns.versioned.Zone)
@@ -386,6 +413,12 @@ def test_ixfr_bad_serial():
         with pytest.raises(dns.exception.FormError):
             xfr.process_message(m)
 
+def test_no_udp_with_axfr():
+    z = dns.versioned.Zone('example.')
+    with pytest.raises(ValueError):
+        with dns.xfr.Inbound(z, dns.rdatatype.AXFR, is_udp=True) as xfr:
+            pass
+
 refused = '''id 1
 opcode QUERY
 rcode REFUSED