]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
set_serial() -> update_serial()
authorBob Halley <halley@dnspython.org>
Thu, 13 Aug 2020 13:31:34 +0000 (06:31 -0700)
committerBob Halley <halley@dnspython.org>
Thu, 13 Aug 2020 13:31:34 +0000 (06:31 -0700)
dns/transaction.py
tests/test_transaction.py

index 20d6939607e9ee6b6bcd6542edcb7eadb5b9bda6..630f6303fd559f9fd46a102747917b20d5f51004 100644 (file)
@@ -8,6 +8,7 @@ import dns.rdataclass
 import dns.rdataset
 import dns.rdatatype
 import dns.rrset
+import dns.serial
 import dns.ttl
 
 
@@ -148,20 +149,33 @@ class Transaction:
             name = dns.name.from_text(name, None)
         return self._name_exists(name)
 
-    def set_serial(self, increment=1, value=None, name=dns.name.empty,
-                   rdclass=dns.rdataclass.IN):
+    def update_serial(self, value=1, relative=True, name=dns.name.empty,
+                      rdclass=dns.rdataclass.IN):
+        """Update the serial number.
+
+        *value*, an `int`, is an increment if *relative* is `True`, or the
+        actual value to set if *relative* is `False`.
+
+        Raises `KeyError` if there is no SOA rdataset at *name*.
+
+        Raises `ValueError` if *value* is negative or if the increment is
+        so large that it would cause the new serial to be less than the
+        prior value.
+        """
+        if value < 0:
+            raise ValueError('negative update_serial() value')
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
         rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA,
                                       dns.rdatatype.NONE)
         if rdataset is None or len(rdataset) == 0:
             raise KeyError
-        if value is not None:
-            serial = value
+        if relative:
+            serial = dns.serial.Serial(rdataset[0].serial) + value
         else:
-            serial = rdataset[0].serial
-        serial += increment
-        if serial > 0xffffffff or serial < 1:
+            serial = dns.serial.Serial(value)
+        serial = serial.value  # convert back to int
+        if serial == 0:
             serial = 1
         rdata = rdataset[0].replace(serial=serial)
         new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
index d782f21e948c9dba2ca16d21adab609e2fd10a72..93705f2eadf5dfe6f1ba05914a0f2798ce38ba22 100644 (file)
@@ -334,30 +334,36 @@ def test_zone_bad_class(zone):
             txn.delete(dns.name.empty, dns.rdataclass.CH,
                        dns.rdatatype.NS, dns.rdatatype.NONE)
 
-def test_set_serial(zone):
+def test_update_serial(zone):
     # basic
     with zone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     rdataset = zone.find_rdataset('@', 'soa')
     assert rdataset[0].serial == 2
     # max
     with zone.writer() as txn:
-        txn.set_serial(0, 0xffffffff)
+        txn.update_serial(0xffffffff, False)
     rdataset = zone.find_rdataset('@', 'soa')
     assert rdataset[0].serial == 0xffffffff
     # wraparound to 1
     with zone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     rdataset = zone.find_rdataset('@', 'soa')
     assert rdataset[0].serial == 1
     # trying to set to zero sets to 1
     with zone.writer() as txn:
-        txn.set_serial(0, 0)
+        txn.update_serial(0, False)
     rdataset = zone.find_rdataset('@', 'soa')
     assert rdataset[0].serial == 1
     with pytest.raises(KeyError):
         with zone.writer() as txn:
-            txn.set_serial(name=dns.name.from_text('unknown', None))
+            txn.update_serial(name=dns.name.from_text('unknown', None))
+    with pytest.raises(ValueError):
+        with zone.writer() as txn:
+            txn.update_serial(-1)
+    with pytest.raises(ValueError):
+        with zone.writer() as txn:
+            txn.update_serial(2**31)
 
 class ExpectedException(Exception):
     pass
@@ -407,11 +413,11 @@ def test_vzone_multiple_versions(vzone):
     assert len(vzone._versions) == 1
     vzone.set_max_versions(None)  # unlimited!
     with vzone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     with vzone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     with vzone.writer() as txn:
-        txn.set_serial(increment=0, value=1000)
+        txn.update_serial(1000, False)
     rdataset = vzone.find_rdataset('@', 'soa')
     assert rdataset[0].serial == 1000
     assert len(vzone._versions) == 4
@@ -449,11 +455,11 @@ def test_vzone_open_txn_pins_versions(vzone):
     assert len(vzone._versions) == 1
     vzone.set_max_versions(None)  # unlimited!
     with vzone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     with vzone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     with vzone.writer() as txn:
-        txn.set_serial()
+        txn.update_serial()
     with vzone.reader(id=2) as txn:
         vzone.set_max_versions(1)
         with vzone.reader(id=3) as txn: