]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Eliminate the need for a serial parameter to inbound_xfr()
authorBob Halley <halley@dnspython.org>
Tue, 18 May 2021 13:59:58 +0000 (06:59 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 18 May 2021 13:59:58 +0000 (06:59 -0700)
dns/asyncquery.py
dns/query.py
dns/xfr.py
tests/test_xfr.py

index c02a4789d982135e3b4739f4b15aab22275d1206..deeff2741926cfda1e2523d8736a1a251ac62457 100644 (file)
@@ -356,8 +356,7 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
 
 async def inbound_xfr(where, txn_manager, query=None,
                       port=53, timeout=None, lifetime=None, source=None,
-                      source_port=0, udp_mode=UDPMode.NEVER, serial=0,
-                      backend=None):
+                      source_port=0, udp_mode=UDPMode.NEVER, backend=None):
     """Conduct an inbound transfer and apply it via a transaction from the
     txn_manager.
 
@@ -369,6 +368,8 @@ async def inbound_xfr(where, txn_manager, query=None,
     """
     if query is None:
         (query, serial) = dns.xfr.make_query(txn_manager)
+    else:
+        serial = dns.xfr.extract_serial_from_query(query)
     rdtype = query.question[0].rdtype
     is_ixfr = rdtype == dns.rdatatype.IXFR
     origin = txn_manager.from_wire_origin()
index 7cfa16c10dd003f01f02387659d4fc0b1e2c1719..934bf410a82e82c38edd29f470c9b6e720090fa7 100644 (file)
@@ -997,7 +997,7 @@ class UDPMode(enum.IntEnum):
 
 def inbound_xfr(where, txn_manager, query=None,
                 port=53, timeout=None, lifetime=None, source=None,
-                source_port=0, udp_mode=UDPMode.NEVER, serial=0):
+                source_port=0, udp_mode=UDPMode.NEVER):
     """Conduct an inbound transfer and apply it via a transaction from the
     txn_manager.
 
@@ -1036,6 +1036,8 @@ def inbound_xfr(where, txn_manager, query=None,
     """
     if query is None:
         (query, serial) = dns.xfr.make_query(txn_manager)
+    else:
+        serial = dns.xfr.extract_serial_from_query(query)
     rdtype = query.question[0].rdtype
     is_ixfr = rdtype == dns.rdatatype.IXFR
     origin = txn_manager.from_wire_origin()
index 84059a3a5144df66d4fdacc28624ff7322d03611..5efa6991dae025e84a6f66d3fbe70f7a0bc9c8a3 100644 (file)
@@ -295,3 +295,22 @@ def make_query(txn_manager, serial=0,
     if keyring is not None:
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
     return (q, serial)
+
+def extract_serial_from_query(query):
+    """Extract the SOA serial number from query if it is an IXFR and return
+    it, otherwise return None.
+
+    *query* is a dns.message.QueryMessage that is an IXFR or AXFR request.
+
+    Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have
+    an appropriate SOA RRset in the authority section."""
+
+    question = query.question[0]
+    if question.rdtype == dns.rdatatype.AXFR:
+        return None
+    elif question.rdtype != dns.rdatatype.IXFR:
+        raise ValueError("query is not an AXFR or IXFR")
+    print(question.name, question.rdclass)
+    soa = query.find_rrset(query.authority, question.name, question.rdclass,
+                           dns.rdatatype.SOA)
+    return soa[0].serial
index c1a011c66d6beda867bfae30e808ce699dd482d2..7429fa35a6e2199920b7abed860c0de81594bc89 100644 (file)
@@ -593,6 +593,21 @@ def test_make_query_bad_serial():
     with pytest.raises(ValueError):
         dns.xfr.make_query(z, serial=4294967296)
 
+def test_extract_serial_from_query():
+    z = dns.versioned.Zone('example.')
+    (q, s) = dns.xfr.make_query(z)
+    xs = dns.xfr.extract_serial_from_query(q)
+    assert s is None
+    assert s == xs
+    (q, s) = dns.xfr.make_query(z, serial=10)
+    print(q)
+    xs = dns.xfr.extract_serial_from_query(q)
+    assert s == 10
+    assert s == xs
+    q = dns.message.make_query('example', 'a')
+    with pytest.raises(ValueError):
+        dns.xfr.extract_serial_from_query(q)
+
 
 class XFRNanoNameserver(Server):