From: Bob Halley Date: Tue, 11 Aug 2020 14:38:26 +0000 (-0700) Subject: open versions by id or serial; cleanups X-Git-Tag: v2.1.0rc1~90 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ca553435093b0dfae0eefc8d634529fe61ac8721;p=thirdparty%2Fdnspython.git open versions by id or serial; cleanups --- diff --git a/dns/versioned.py b/dns/versioned.py index 6f911e1d..45ede79b 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -51,13 +51,6 @@ class Version: def items(self): return self.nodes.items() # pylint: disable=dict-items-not-iterating - def _print(self): # pragma: no cover - # XXXRTH This is for debugging - print('VERSION', self.id) - for (name, node) in self.nodes.items(): - for rdataset in node: - print(rdataset.to_text(name)) - class WritableVersion(Version): def __init__(self, zone, replacement=False): @@ -77,14 +70,6 @@ class WritableVersion(Version): self.origin = zone.origin self.changed = set() - def _validate_name(self, name): - if name.is_absolute(): - if not name.is_subdomain(self.origin): - raise KeyError("name is not a subdomain of the zone origin") - if self.zone.relativize: - name = name.relativize(self.origin) - return name - def _maybe_cow(self, name): name = self._validate_name(name) node = self.nodes.get(name) @@ -150,17 +135,34 @@ class Node(dns.node.Node): self.id = 0 -# It would be nice if this were a subclass of Node (just above) but it's -# less code duplication this way as we inherit all of the method disabling -# code. - @dns.immutable.immutable -class ImmutableNode(dns.node.ImmutableNode): +class ImmutableNode(Node): __slots__ = ['id'] def __init__(self, node): - super().__init__(node) + super().__init__() self.id = node.id + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + raise TypeError("immutable") + + def replace_rdataset(self, replacement): + raise TypeError("immutable") class Zone(dns.zone.Zone): @@ -199,9 +201,36 @@ class Zone(dns.zone.Zone): self._write_waiters = collections.deque() self._commit_version_unlocked(WritableVersion(self), origin) - def reader(self): + def reader(self, id=None, serial=None): + if id is not None and serial is not None: + raise ValueError('cannot specify both id and serial') with self.version_lock: - return Transaction(False, self, self.versions[-1]) + if id is not None: + version = None + for v in reversed(self.versions): + if v.id == id: + version = v + break + if version is None: + raise KeyError('version not found') + elif serial is not None: + if self.relativize: + oname = dns.name.empty + else: + oname = self.origin + version = None + for v in reversed(self.versions): + n = v.nodes.get(oname) + if n: + rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA) + if rds and rds[0].serial == serial: + version = v + break + if version is None: + raise KeyError('serial not found') + else: + version = self.versions[-1] + return Transaction(False, self, version) def writer(self, replacement=False): event = None diff --git a/tests/test_transaction.py b/tests/test_transaction.py index ed154fc5..64705ed4 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -398,19 +398,27 @@ def test_vzone_multiple_versions(vzone): with vzone.writer() as txn: txn.set_serial() with vzone.writer() as txn: - txn.set_serial() + txn.set_serial(increment=0, value=1000) rdataset = vzone.find_rdataset('@', 'soa') - assert rdataset[0].serial == 4 + assert rdataset[0].serial == 1000 assert len(vzone.versions) == 4 + with vzone.reader(id=5) as txn: + assert txn.version.id == 5 + rdataset = txn.get('@', 'in', 'soa') + assert rdataset[0].serial == 1000 + with vzone.reader(serial=1000) as txn: + assert txn.version.id == 5 + rdataset = txn.get('@', 'in', 'soa') + assert rdataset[0].serial == 1000 vzone.set_max_versions(2) assert len(vzone.versions) == 2 - # The ones that survived should be 3 and 4 + # The ones that survived should be 3 and 1000 rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA, dns.rdatatype.NONE) assert rdataset[0].serial == 3 rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA, dns.rdatatype.NONE) - assert rdataset[0].serial == 4 + assert rdataset[0].serial == 1000 with pytest.raises(ValueError): vzone.set_max_versions(0)