From: Bob Halley Date: Thu, 13 Aug 2020 13:31:34 +0000 (-0700) Subject: set_serial() -> update_serial() X-Git-Tag: v2.1.0rc1~84 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=060ad0db51fbb26a6f809548d884ef468ad06beb;p=thirdparty%2Fdnspython.git set_serial() -> update_serial() --- diff --git a/dns/transaction.py b/dns/transaction.py index 20d69396..630f6303 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -8,6 +8,7 @@ import dns.rdataclass import dns.rdataset import dns.rdatatype import dns.rrset +import dns.serial import dns.ttl @@ -148,20 +149,33 @@ class Transaction: name = dns.name.from_text(name, None) return self._name_exists(name) - def set_serial(self, increment=1, value=None, name=dns.name.empty, - rdclass=dns.rdataclass.IN): + def update_serial(self, value=1, relative=True, name=dns.name.empty, + rdclass=dns.rdataclass.IN): + """Update the serial number. + + *value*, an `int`, is an increment if *relative* is `True`, or the + actual value to set if *relative* is `False`. + + Raises `KeyError` if there is no SOA rdataset at *name*. + + Raises `ValueError` if *value* is negative or if the increment is + so large that it would cause the new serial to be less than the + prior value. + """ + if value < 0: + raise ValueError('negative update_serial() value') if isinstance(name, str): name = dns.name.from_text(name, None) rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA, dns.rdatatype.NONE) if rdataset is None or len(rdataset) == 0: raise KeyError - if value is not None: - serial = value + if relative: + serial = dns.serial.Serial(rdataset[0].serial) + value else: - serial = rdataset[0].serial - serial += increment - if serial > 0xffffffff or serial < 1: + serial = dns.serial.Serial(value) + serial = serial.value # convert back to int + if serial == 0: serial = 1 rdata = rdataset[0].replace(serial=serial) new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index d782f21e..93705f2e 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -334,30 +334,36 @@ def test_zone_bad_class(zone): txn.delete(dns.name.empty, dns.rdataclass.CH, dns.rdatatype.NS, dns.rdatatype.NONE) -def test_set_serial(zone): +def test_update_serial(zone): # basic with zone.writer() as txn: - txn.set_serial() + txn.update_serial() rdataset = zone.find_rdataset('@', 'soa') assert rdataset[0].serial == 2 # max with zone.writer() as txn: - txn.set_serial(0, 0xffffffff) + txn.update_serial(0xffffffff, False) rdataset = zone.find_rdataset('@', 'soa') assert rdataset[0].serial == 0xffffffff # wraparound to 1 with zone.writer() as txn: - txn.set_serial() + txn.update_serial() rdataset = zone.find_rdataset('@', 'soa') assert rdataset[0].serial == 1 # trying to set to zero sets to 1 with zone.writer() as txn: - txn.set_serial(0, 0) + txn.update_serial(0, False) rdataset = zone.find_rdataset('@', 'soa') assert rdataset[0].serial == 1 with pytest.raises(KeyError): with zone.writer() as txn: - txn.set_serial(name=dns.name.from_text('unknown', None)) + txn.update_serial(name=dns.name.from_text('unknown', None)) + with pytest.raises(ValueError): + with zone.writer() as txn: + txn.update_serial(-1) + with pytest.raises(ValueError): + with zone.writer() as txn: + txn.update_serial(2**31) class ExpectedException(Exception): pass @@ -407,11 +413,11 @@ def test_vzone_multiple_versions(vzone): assert len(vzone._versions) == 1 vzone.set_max_versions(None) # unlimited! with vzone.writer() as txn: - txn.set_serial() + txn.update_serial() with vzone.writer() as txn: - txn.set_serial() + txn.update_serial() with vzone.writer() as txn: - txn.set_serial(increment=0, value=1000) + txn.update_serial(1000, False) rdataset = vzone.find_rdataset('@', 'soa') assert rdataset[0].serial == 1000 assert len(vzone._versions) == 4 @@ -449,11 +455,11 @@ def test_vzone_open_txn_pins_versions(vzone): assert len(vzone._versions) == 1 vzone.set_max_versions(None) # unlimited! with vzone.writer() as txn: - txn.set_serial() + txn.update_serial() with vzone.writer() as txn: - txn.set_serial() + txn.update_serial() with vzone.writer() as txn: - txn.set_serial() + txn.update_serial() with vzone.reader(id=2) as txn: vzone.set_max_versions(1) with vzone.reader(id=3) as txn: