From: Bob Halley Date: Sun, 23 Aug 2020 02:11:19 +0000 (-0700) Subject: is_udp is better as Inbound attribute not parameter to process_message(). Increase... X-Git-Tag: v2.1.0rc1~49 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ae2b35b0145ae3b75c9494a984b14ff3851e9532;p=thirdparty%2Fdnspython.git is_udp is better as Inbound attribute not parameter to process_message(). Increase coverage. --- diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 702de377..89c2622f 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -391,7 +391,8 @@ async def inbound_xfr(where, txn_manager, query=None, else: tcpmsg = struct.pack("!H", len(wire)) + wire await s.sendall(tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial) as inbound: + with dns.xfr.Inbound(txn_manager, rdtype, serial, + is_udp) as inbound: done = False tsig_ctx = None while not done: @@ -419,7 +420,7 @@ async def inbound_xfr(where, txn_manager, query=None, multi=(not is_udp), one_rr_per_rrset=is_ixfr) try: - done = inbound.process_message(r, is_udp) + done = inbound.process_message(r) except dns.xfr.UseTCP: assert is_udp # should not happen if we used TCP! if udp_mode == UDPMode.ONLY: diff --git a/dns/query.py b/dns/query.py index 37c727e3..bd62a7a3 100644 --- a/dns/query.py +++ b/dns/query.py @@ -1059,7 +1059,8 @@ def inbound_xfr(where, txn_manager, query=None, else: tcpmsg = struct.pack("!H", len(wire)) + wire _net_write(s, tcpmsg, expiration) - with dns.xfr.Inbound(txn_manager, rdtype, serial) as inbound: + with dns.xfr.Inbound(txn_manager, rdtype, serial, + is_udp) as inbound: done = False tsig_ctx = None while not done: @@ -1079,7 +1080,7 @@ def inbound_xfr(where, txn_manager, query=None, multi=(not is_udp), one_rr_per_rrset=is_ixfr) try: - done = inbound.process_message(r, is_udp) + done = inbound.process_message(r) except dns.xfr.UseTCP: assert is_udp # should not happen if we used TCP! if udp_mode == UDPMode.ONLY: diff --git a/dns/xfr.py b/dns/xfr.py index 311e60e6..b07f8b9a 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -47,7 +47,7 @@ class Inbound: """ def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR, - serial=None): + serial=None, is_udp=False): """Initialize an inbound zone transfer. *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. @@ -56,29 +56,33 @@ class Inbound: *serial* is the base serial number for IXFRs, and is required in that case. + + *is_udp*, a ``bool`` indidicates if UDP is being used for this + XFR. """ self.txn_manager = txn_manager self.txn = None self.rdtype = rdtype - if rdtype == dns.rdatatype.IXFR and serial is None: - raise ValueError('a starting serial must be supplied for IXFRs') + if rdtype == dns.rdatatype.IXFR: + if serial is None: + raise ValueError('a starting serial must be supplied for IXFRs') + elif is_udp: + raise ValueError('is_udp specified for AXFR') self.serial = serial + self.is_udp = is_udp (_, _, self.origin) = txn_manager.origin_information() self.soa_rdataset = None self.done = False self.expecting_SOA = False self.delete_mode = False - def process_message(self, message, is_udp=False): + def process_message(self, message): """Process one message in the transfer. The message should have the same relativization as was specified when the `dns.xfr.Inbound` was created. The message should also have been created with `one_rr_per_rrset=True` because order matters. - *is_udp*, a ``bool`` indidicates if this message was received using - UDP. - Returns `True` if the transfer is complete, and `False` otherwise. """ if self.txn is None: @@ -125,7 +129,7 @@ class Inbound: self.serial) raise SerialWentBackwards else: - if is_udp and len(message.answer[answer_index:]) == 0: + if self.is_udp and len(message.answer[answer_index:]) == 0: # # There are no more records, so this is the # "truncated" response. Say to use TCP @@ -216,7 +220,7 @@ class Inbound: self.txn.delete_exact(name, rdataset) else: self.txn.add(name, rdataset) - if is_udp and not self.done: + if self.is_udp and not self.done: # # This is a UDP IXFR and we didn't get to done, and we didn't # get the proper "truncated" response diff --git a/tests/test_xfr.py b/tests/test_xfr.py index fbda5fa8..c1a011c6 100644 --- a/tests/test_xfr.py +++ b/tests/test_xfr.py @@ -211,6 +211,20 @@ ns3 3600 IN A 10.0.0.3 @ 3600 IN SOA foo bar 4 2 3 4 5 ''' +unexpected_end_ixfr_2 = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 1 2 3 4 5 +bar.foo 300 IN MX 0 blaz.foo +ns2 3600 IN A 10.0.0.2 +@ 3600 IN NS ns2 +''' + bad_serial_ixfr = '''id 1 opcode QUERY rcode NOERROR @@ -329,9 +343,9 @@ def test_retry_tcp_ixfr(): zone_factory=dns.versioned.Zone) m = dns.message.from_text(retry_tcp_ixfr, origin=z.origin, one_rr_per_rrset=True) - with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1, is_udp=True) as xfr: with pytest.raises(dns.xfr.UseTCP): - xfr.process_message(m, True) + xfr.process_message(m) def test_bad_empty_ixfr(): z = dns.zone.from_text(ixfr_expected, 'example.', @@ -368,7 +382,9 @@ def test_ixfr_requires_serial(): with pytest.raises(ValueError): dns.xfr.Inbound(z, dns.rdatatype.IXFR) -def test_ixfr_unexpected_end(): +def test_ixfr_unexpected_end_bad_diff_sequence(): + # This is where we get the end serial, but haven't seen all of + # the expected diffs z = dns.zone.from_text(base, 'example.', zone_factory=dns.versioned.Zone) m = dns.message.from_text(unexpected_end_ixfr, origin=z.origin, @@ -377,6 +393,17 @@ def test_ixfr_unexpected_end(): with pytest.raises(dns.exception.FormError): xfr.process_message(m) +def test_udp_ixfr_unexpected_end_just_stops(): + # This is where everything looks good, but the IXFR just stops + # in the middle. + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(unexpected_end_ixfr_2, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1, is_udp=True) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + def test_ixfr_bad_serial(): z = dns.zone.from_text(base, 'example.', zone_factory=dns.versioned.Zone) @@ -386,6 +413,12 @@ def test_ixfr_bad_serial(): with pytest.raises(dns.exception.FormError): xfr.process_message(m) +def test_no_udp_with_axfr(): + z = dns.versioned.Zone('example.') + with pytest.raises(ValueError): + with dns.xfr.Inbound(z, dns.rdatatype.AXFR, is_udp=True) as xfr: + pass + refused = '''id 1 opcode QUERY rcode REFUSED