From 2dab021d03750b56cb4eca5420603f5471054a6e Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Tue, 18 May 2021 06:59:58 -0700 Subject: [PATCH] Eliminate the need for a serial parameter to inbound_xfr() --- dns/asyncquery.py | 5 +++-- dns/query.py | 4 +++- dns/xfr.py | 19 +++++++++++++++++++ tests/test_xfr.py | 15 +++++++++++++++ 4 files changed, 40 insertions(+), 3 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index c02a4789..deeff274 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -356,8 +356,7 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, async def inbound_xfr(where, txn_manager, query=None, port=53, timeout=None, lifetime=None, source=None, - source_port=0, udp_mode=UDPMode.NEVER, serial=0, - backend=None): + source_port=0, udp_mode=UDPMode.NEVER, backend=None): """Conduct an inbound transfer and apply it via a transaction from the txn_manager. @@ -369,6 +368,8 @@ async def inbound_xfr(where, txn_manager, query=None, """ if query is None: (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) rdtype = query.question[0].rdtype is_ixfr = rdtype == dns.rdatatype.IXFR origin = txn_manager.from_wire_origin() diff --git a/dns/query.py b/dns/query.py index 7cfa16c1..934bf410 100644 --- a/dns/query.py +++ b/dns/query.py @@ -997,7 +997,7 @@ class UDPMode(enum.IntEnum): def inbound_xfr(where, txn_manager, query=None, port=53, timeout=None, lifetime=None, source=None, - source_port=0, udp_mode=UDPMode.NEVER, serial=0): + source_port=0, udp_mode=UDPMode.NEVER): """Conduct an inbound transfer and apply it via a transaction from the txn_manager. @@ -1036,6 +1036,8 @@ def inbound_xfr(where, txn_manager, query=None, """ if query is None: (query, serial) = dns.xfr.make_query(txn_manager) + else: + serial = dns.xfr.extract_serial_from_query(query) rdtype = query.question[0].rdtype is_ixfr = rdtype == dns.rdatatype.IXFR origin = txn_manager.from_wire_origin() diff --git a/dns/xfr.py b/dns/xfr.py index 84059a3a..5efa6991 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -295,3 +295,22 @@ def make_query(txn_manager, serial=0, if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) return (q, serial) + +def extract_serial_from_query(query): + """Extract the SOA serial number from query if it is an IXFR and return + it, otherwise return None. + + *query* is a dns.message.QueryMessage that is an IXFR or AXFR request. + + Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have + an appropriate SOA RRset in the authority section.""" + + question = query.question[0] + if question.rdtype == dns.rdatatype.AXFR: + return None + elif question.rdtype != dns.rdatatype.IXFR: + raise ValueError("query is not an AXFR or IXFR") + print(question.name, question.rdclass) + soa = query.find_rrset(query.authority, question.name, question.rdclass, + dns.rdatatype.SOA) + return soa[0].serial diff --git a/tests/test_xfr.py b/tests/test_xfr.py index c1a011c6..7429fa35 100644 --- a/tests/test_xfr.py +++ b/tests/test_xfr.py @@ -593,6 +593,21 @@ def test_make_query_bad_serial(): with pytest.raises(ValueError): dns.xfr.make_query(z, serial=4294967296) +def test_extract_serial_from_query(): + z = dns.versioned.Zone('example.') + (q, s) = dns.xfr.make_query(z) + xs = dns.xfr.extract_serial_from_query(q) + assert s is None + assert s == xs + (q, s) = dns.xfr.make_query(z, serial=10) + print(q) + xs = dns.xfr.extract_serial_from_query(q) + assert s == 10 + assert s == xs + q = dns.message.make_query('example', 'a') + with pytest.raises(ValueError): + dns.xfr.extract_serial_from_query(q) + class XFRNanoNameserver(Server): -- 2.47.3