]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
open versions by id or serial; cleanups
authorBob Halley <halley@dnspython.org>
Tue, 11 Aug 2020 14:38:26 +0000 (07:38 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 11 Aug 2020 14:38:26 +0000 (07:38 -0700)
dns/versioned.py
tests/test_transaction.py

index 6f911e1d6c4dc753cac72cbbd1ec075c0529dfe0..45ede79b3fd829e72a27ecb2aa8976b7ad3d93c0 100644 (file)
@@ -51,13 +51,6 @@ class Version:
     def items(self):
         return self.nodes.items()  # pylint: disable=dict-items-not-iterating
 
-    def _print(self):  # pragma: no cover
-        # XXXRTH  This is for debugging
-        print('VERSION', self.id)
-        for (name, node) in self.nodes.items():
-            for rdataset in node:
-                print(rdataset.to_text(name))
-
 
 class WritableVersion(Version):
     def __init__(self, zone, replacement=False):
@@ -77,14 +70,6 @@ class WritableVersion(Version):
         self.origin = zone.origin
         self.changed = set()
 
-    def _validate_name(self, name):
-        if name.is_absolute():
-            if not name.is_subdomain(self.origin):
-                raise KeyError("name is not a subdomain of the zone origin")
-            if self.zone.relativize:
-                name = name.relativize(self.origin)
-        return name
-
     def _maybe_cow(self, name):
         name = self._validate_name(name)
         node = self.nodes.get(name)
@@ -150,17 +135,34 @@ class Node(dns.node.Node):
         self.id = 0
 
 
-# It would be nice if this were a subclass of Node (just above) but it's
-# less code duplication this way as we inherit all of the method disabling
-# code.
-
 @dns.immutable.immutable
-class ImmutableNode(dns.node.ImmutableNode):
+class ImmutableNode(Node):
     __slots__ = ['id']
 
     def __init__(self, node):
-        super().__init__(node)
+        super().__init__()
         self.id = node.id
+        self.rdatasets = tuple(
+            [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+        )
+
+    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+                      create=False):
+        if create:
+            raise TypeError("immutable")
+        return super().find_rdataset(rdclass, rdtype, covers, False)
+
+    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+                     create=False):
+        if create:
+            raise TypeError("immutable")
+        return super().get_rdataset(rdclass, rdtype, covers, False)
+
+    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+        raise TypeError("immutable")
+
+    def replace_rdataset(self, replacement):
+        raise TypeError("immutable")
 
 
 class Zone(dns.zone.Zone):
@@ -199,9 +201,36 @@ class Zone(dns.zone.Zone):
         self._write_waiters = collections.deque()
         self._commit_version_unlocked(WritableVersion(self), origin)
 
-    def reader(self):
+    def reader(self, id=None, serial=None):
+        if id is not None and serial is not None:
+            raise ValueError('cannot specify both id and serial')
         with self.version_lock:
-            return Transaction(False, self, self.versions[-1])
+            if id is not None:
+                version = None
+                for v in reversed(self.versions):
+                    if v.id == id:
+                        version = v
+                        break
+                if version is None:
+                    raise KeyError('version not found')
+            elif serial is not None:
+                if self.relativize:
+                    oname = dns.name.empty
+                else:
+                    oname = self.origin
+                version = None
+                for v in reversed(self.versions):
+                    n = v.nodes.get(oname)
+                    if n:
+                        rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
+                        if rds and rds[0].serial == serial:
+                            version = v
+                            break
+                if version is None:
+                    raise KeyError('serial not found')
+            else:
+                version = self.versions[-1]
+            return Transaction(False, self, version)
 
     def writer(self, replacement=False):
         event = None
index ed154fc5db228b7da1c060d50ac86eb0e01dbc2a..64705ed4352ea02c12f512a3b13c1ddba1c4acd1 100644 (file)
@@ -398,19 +398,27 @@ def test_vzone_multiple_versions(vzone):
     with vzone.writer() as txn:
         txn.set_serial()
     with vzone.writer() as txn:
-        txn.set_serial()
+        txn.set_serial(increment=0, value=1000)
     rdataset = vzone.find_rdataset('@', 'soa')
-    assert rdataset[0].serial == 4
+    assert rdataset[0].serial == 1000
     assert len(vzone.versions) == 4
+    with vzone.reader(id=5) as txn:
+        assert txn.version.id == 5
+        rdataset = txn.get('@', 'in', 'soa')
+        assert rdataset[0].serial == 1000
+    with vzone.reader(serial=1000) as txn:
+        assert txn.version.id == 5
+        rdataset = txn.get('@', 'in', 'soa')
+        assert rdataset[0].serial == 1000
     vzone.set_max_versions(2)
     assert len(vzone.versions) == 2
-    # The ones that survived should be 3 and 4
+    # The ones that survived should be 3 and 1000
     rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
                                               dns.rdatatype.NONE)
     assert rdataset[0].serial == 3
     rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
                                               dns.rdatatype.NONE)
-    assert rdataset[0].serial == 4
+    assert rdataset[0].serial == 1000
     with pytest.raises(ValueError):
         vzone.set_max_versions(0)