]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Refactor zone transactions to always use versioned CoW code. 734/head
authorBob Halley <halley@dnspython.org>
Wed, 1 Dec 2021 14:48:58 +0000 (06:48 -0800)
committerBob Halley <halley@dnspython.org>
Wed, 1 Dec 2021 14:48:58 +0000 (06:48 -0800)
dns/versioned.py
dns/zone.py

index 686a83b0221583ea10a32fcc6b483f8c88a1de03..42f2c8140bd10181ca0067aaabb3cc0cb13a0458 100644 (file)
@@ -11,12 +11,9 @@ except ImportError:  # pragma: no cover
 import dns.exception
 import dns.immutable
 import dns.name
-import dns.node
 import dns.rdataclass
 import dns.rdatatype
-import dns.rdata
 import dns.rdtypes.ANY.SOA
-import dns.transaction
 import dns.zone
 
 
@@ -24,142 +21,13 @@ class UseTransaction(dns.exception.DNSException):
     """To alter a versioned zone, use a transaction."""
 
 
-class Version:
-    def __init__(self, zone, id):
-        self.zone = zone
-        self.id = id
-        self.nodes = {}
-
-    def _validate_name(self, name):
-        if name.is_absolute():
-            if not name.is_subdomain(self.zone.origin):
-                raise KeyError("name is not a subdomain of the zone origin")
-            if self.zone.relativize:
-                name = name.relativize(self.origin)
-        return name
-
-    def get_node(self, name):
-        name = self._validate_name(name)
-        return self.nodes.get(name)
-
-    def get_rdataset(self, name, rdtype, covers):
-        node = self.get_node(name)
-        if node is None:
-            return None
-        return node.get_rdataset(self.zone.rdclass, rdtype, covers)
-
-    def items(self):
-        return self.nodes.items()  # pylint: disable=dict-items-not-iterating
-
-
-class WritableVersion(Version):
-    def __init__(self, zone, replacement=False):
-        # The zone._versions_lock must be held by our caller.
-        if len(zone._versions) > 0:
-            id = zone._versions[-1].id + 1
-        else:
-            id = 1
-        super().__init__(zone, id)
-        if not replacement:
-            # We copy the map, because that gives us a simple and thread-safe
-            # way of doing versions, and we have a garbage collector to help
-            # us.  We only make new node objects if we actually change the
-            # node.
-            self.nodes.update(zone.nodes)
-        # We have to copy the zone origin as it may be None in the first
-        # version, and we don't want to mutate the zone until we commit.
-        self.origin = zone.origin
-        self.changed = set()
-
-    def _maybe_cow(self, name):
-        name = self._validate_name(name)
-        node = self.nodes.get(name)
-        if node is None or node.id != self.id:
-            new_node = self.zone.node_factory()
-            new_node.id = self.id
-            if node is not None:
-                # moo!  copy on write!
-                new_node.rdatasets.extend(node.rdatasets)
-            self.nodes[name] = new_node
-            self.changed.add(name)
-            return new_node
-        else:
-            return node
-
-    def delete_node(self, name):
-        name = self._validate_name(name)
-        if name in self.nodes:
-            del self.nodes[name]
-            self.changed.add(name)
-
-    def put_rdataset(self, name, rdataset):
-        node = self._maybe_cow(name)
-        node.replace_rdataset(rdataset)
-
-    def delete_rdataset(self, name, rdtype, covers):
-        node = self._maybe_cow(name)
-        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
-        if len(node) == 0:
-            del self.nodes[name]
-
-
-@dns.immutable.immutable
-class ImmutableVersion(Version):
-    def __init__(self, version):
-        # We tell super() that it's a replacement as we don't want it
-        # to copy the nodes, as we're about to do that with an
-        # immutable Dict.
-        super().__init__(version.zone, True)
-        # set the right id!
-        self.id = version.id
-        # Make changed nodes immutable
-        for name in version.changed:
-            node = version.nodes.get(name)
-            # it might not exist if we deleted it in the version
-            if node:
-                version.nodes[name] = ImmutableNode(node)
-        self.nodes = dns.immutable.Dict(version.nodes, True)
-
-
-# A node with a version id.
-
-class Node(dns.node.Node):
-    __slots__ = ['id']
-
-    def __init__(self):
-        super().__init__()
-        # A proper id will get set by the Version
-        self.id = 0
-
-
-@dns.immutable.immutable
-class ImmutableNode(Node):
-    __slots__ = ['id']
-
-    def __init__(self, node):
-        super().__init__()
-        self.id = node.id
-        self.rdatasets = tuple(
-            [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
-        )
-
-    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                      create=False):
-        if create:
-            raise TypeError("immutable")
-        return super().find_rdataset(rdclass, rdtype, covers, False)
-
-    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                     create=False):
-        if create:
-            raise TypeError("immutable")
-        return super().get_rdataset(rdclass, rdtype, covers, False)
-
-    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
-        raise TypeError("immutable")
-
-    def replace_rdataset(self, replacement):
-        raise TypeError("immutable")
+# Backwards compatibility
+Node = dns.zone.VersionedNode
+ImmutableNode = dns.zone.ImmutableVersionedNode
+Version = dns.zone.Version
+WritableVersion = dns.zone.WritableVersion
+ImmutableVersion = dns.zone.ImmutableVersion
+Transaction = dns.zone.Transaction
 
 
 class Zone(dns.zone.Zone):
@@ -198,7 +66,9 @@ class Zone(dns.zone.Zone):
         self._write_event = None
         self._write_waiters = collections.deque()
         self._readers = set()
-        self._commit_version_unlocked(None, WritableVersion(self), origin)
+        self._commit_version_unlocked(None,
+                                      WritableVersion(self, replacement=True),
+                                      origin)
 
     def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
         if id is not None and serial is not None:
@@ -247,7 +117,8 @@ class Zone(dns.zone.Zone):
                     # give up the lock, so that we hold the lock as
                     # short a time as possible.  This is why we call
                     # _setup_version() below.
-                    self._write_txn = Transaction(self, replacement)
+                    self._write_txn = Transaction(self, replacement,
+                                                  make_immutable=True)
                     # give up our exclusive right to make a Transaction
                     self._write_event = None
                     break
@@ -367,6 +238,13 @@ class Zone(dns.zone.Zone):
         with self._version_lock:
             self._commit_version_unlocked(txn, version, origin)
 
+    def _get_next_version_id(self):
+        if len(self._versions) > 0:
+            id = self._versions[-1].id + 1
+        else:
+            id = 1
+        return id
+
     def find_node(self, name, create=False):
         if create:
             raise UseTransaction
@@ -394,62 +272,3 @@ class Zone(dns.zone.Zone):
 
     def replace_rdataset(self, name, replacement):
         raise UseTransaction
-
-
-class Transaction(dns.transaction.Transaction):
-
-    def __init__(self, zone, replacement, version=None):
-        read_only = version is not None
-        super().__init__(zone, replacement, read_only)
-        self.version = version
-
-    @property
-    def zone(self):
-        return self.manager
-
-    def _setup_version(self):
-        assert self.version is None
-        self.version = WritableVersion(self.zone, self.replacement)
-
-    def _get_rdataset(self, name, rdtype, covers):
-        return self.version.get_rdataset(name, rdtype, covers)
-
-    def _put_rdataset(self, name, rdataset):
-        assert not self.read_only
-        self.version.put_rdataset(name, rdataset)
-
-    def _delete_name(self, name):
-        assert not self.read_only
-        self.version.delete_node(name)
-
-    def _delete_rdataset(self, name, rdtype, covers):
-        assert not self.read_only
-        self.version.delete_rdataset(name, rdtype, covers)
-
-    def _name_exists(self, name):
-        return self.version.get_node(name) is not None
-
-    def _changed(self):
-        if self.read_only:
-            return False
-        else:
-            return len(self.version.changed) > 0
-
-    def _end_transaction(self, commit):
-        if self.read_only:
-            self.zone._end_read(self)
-        elif commit and len(self.version.changed) > 0:
-            self.zone._commit_version(self, ImmutableVersion(self.version),
-                                      self.version.origin)
-        else:
-            # rollback
-            self.zone._end_write(self)
-
-    def _set_origin(self, origin):
-        if self.version.origin is None:
-            self.version.origin = origin
-
-    def _iterate_rdatasets(self):
-        for (name, node) in self.version.items():
-            for rdataset in node:
-                yield (name, rdataset)
index 2f99b1b7d508525fbcfc3ded68750363e9f4b946..510be2dfcc40d38e3219e29d2be70b84b68ea314 100644 (file)
@@ -24,6 +24,7 @@ import os
 import struct
 
 import dns.exception
+import dns.immutable
 import dns.name
 import dns.node
 import dns.rdataclass
@@ -772,10 +773,13 @@ class Zone(dns.transaction.TransactionManager):
     # TransactionManager methods
 
     def reader(self):
-        return Transaction(self, False, True)
+        return Transaction(self, False,
+                           Version(self, 1, self.nodes, self.origin))
 
     def writer(self, replacement=False):
-        return Transaction(self, replacement, False)
+        txn = Transaction(self, replacement)
+        txn._setup_version()
+        return txn
 
     def origin_information(self):
         if self.relativize:
@@ -787,107 +791,238 @@ class Zone(dns.transaction.TransactionManager):
     def get_class(self):
         return self.rdclass
 
+    # Transaction methods
 
-class Transaction(dns.transaction.Transaction):
+    def _end_read(self, txn):
+        pass
+
+    def _end_write(self, txn):
+        pass
+
+    def _commit_version(self, txn, version, origin):
+        self.nodes = version.nodes
+        if self.origin is None:
+            self.origin = origin
+
+    def _get_next_version_id(self):
+        # Versions are ephemeral and all have id 1
+        return 1
+
+
+# These classes used to be in dns.versioned, but have moved here so we can use
+# the copy-on-write transaction mechanism for both kinds of zones.  In a
+# regular zone, the version only exists during the transaction, and the nodes
+# are regular dns.node.Nodes.
+
+# A node with a version id.
+
+class VersionedNode(dns.node.Node):
+    __slots__ = ['id']
+
+    def __init__(self):
+        super().__init__()
+        # A proper id will get set by the Version
+        self.id = 0
+
+
+@dns.immutable.immutable
+class ImmutableVersionedNode(VersionedNode):
+    __slots__ = ['id']
+
+    def __init__(self, node):
+        super().__init__()
+        self.id = node.id
+        self.rdatasets = tuple(
+            [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+        )
+
+    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+                      create=False):
+        if create:
+            raise TypeError("immutable")
+        return super().find_rdataset(rdclass, rdtype, covers, False)
+
+    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+                     create=False):
+        if create:
+            raise TypeError("immutable")
+        return super().get_rdataset(rdclass, rdtype, covers, False)
 
-    _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY,
-                                              dns.rdatatype.ANY)
+    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+        raise TypeError("immutable")
+
+    def replace_rdataset(self, replacement):
+        raise TypeError("immutable")
+
+
+class Version:
+    def __init__(self, zone, id, nodes=None, origin=None):
+        self.zone = zone
+        self.id = id
+        if nodes is not None:
+            self.nodes = nodes
+        else:
+            self.nodes = {}
+        self.origin = origin
+
+    def _validate_name(self, name):
+        if name.is_absolute():
+            if not name.is_subdomain(self.zone.origin):
+                raise KeyError("name is not a subdomain of the zone origin")
+            if self.zone.relativize:
+                # XXXRTH should it be an error if self.origin is still None?
+                name = name.relativize(self.origin)
+        return name
+
+    def get_node(self, name):
+        name = self._validate_name(name)
+        return self.nodes.get(name)
+
+    def get_rdataset(self, name, rdtype, covers):
+        node = self.get_node(name)
+        if node is None:
+            return None
+        return node.get_rdataset(self.zone.rdclass, rdtype, covers)
 
-    def __init__(self, zone, replacement, read_only):
+    def items(self):
+        return self.nodes.items()  # pylint: disable=dict-items-not-iterating
+
+
+class WritableVersion(Version):
+    def __init__(self, zone, replacement=False):
+        # The zone._versions_lock must be held by our caller in a versioned
+        # zone.
+        id = zone._get_next_version_id()
+        super().__init__(zone, id)
+        if not replacement:
+            # We copy the map, because that gives us a simple and thread-safe
+            # way of doing versions, and we have a garbage collector to help
+            # us.  We only make new node objects if we actually change the
+            # node.
+            self.nodes.update(zone.nodes)
+        # We have to copy the zone origin as it may be None in the first
+        # version, and we don't want to mutate the zone until we commit.
+        self.origin = zone.origin
+        self.changed = set()
+
+    def _maybe_cow(self, name):
+        name = self._validate_name(name)
+        node = self.nodes.get(name)
+        if node is None or name not in self.changed:
+            new_node = self.zone.node_factory()
+            if hasattr(new_node, 'id'):
+                # We keep doing this for backwards compatibility, as earlier
+                # code used new_node.id != self.id for the "do we need to CoW?"
+                # test.  Now we use the changed set as this works with both
+                # regular zones and versioned zones.
+                new_node.id = self.id
+            if node is not None:
+                # moo!  copy on write!
+                new_node.rdatasets.extend(node.rdatasets)
+            self.nodes[name] = new_node
+            self.changed.add(name)
+            return new_node
+        else:
+            return node
+
+    def delete_node(self, name):
+        name = self._validate_name(name)
+        if name in self.nodes:
+            del self.nodes[name]
+            self.changed.add(name)
+
+    def put_rdataset(self, name, rdataset):
+        node = self._maybe_cow(name)
+        node.replace_rdataset(rdataset)
+
+    def delete_rdataset(self, name, rdtype, covers):
+        node = self._maybe_cow(name)
+        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+        if len(node) == 0:
+            del self.nodes[name]
+
+
+@dns.immutable.immutable
+class ImmutableVersion(Version):
+    def __init__(self, version):
+        # We tell super() that it's a replacement as we don't want it
+        # to copy the nodes, as we're about to do that with an
+        # immutable Dict.
+        super().__init__(version.zone, True)
+        # set the right id!
+        self.id = version.id
+        # keep the origin
+        self.origin = version.origin
+        # Make changed nodes immutable
+        for name in version.changed:
+            node = version.nodes.get(name)
+            # it might not exist if we deleted it in the version
+            if node:
+                version.nodes[name] = ImmutableVersionedNode(node)
+        self.nodes = dns.immutable.Dict(version.nodes, True)
+
+
+class Transaction(dns.transaction.Transaction):
+
+    def __init__(self, zone, replacement, version=None, make_immutable=False):
+        read_only = version is not None
         super().__init__(zone, replacement, read_only)
-        self.rdatasets = {}
+        self.version = version
+        self.make_immutable = make_immutable
 
     @property
     def zone(self):
         return self.manager
 
+    def _setup_version(self):
+        assert self.version is None
+        self.version = WritableVersion(self.zone, self.replacement)
+
     def _get_rdataset(self, name, rdtype, covers):
-        rdataset = self.rdatasets.get((name, rdtype, covers))
-        if rdataset is self._deleted_rdataset:
-            return None
-        elif rdataset is None and not self.replacement:
-            rdataset = self.zone.get_rdataset(name, rdtype, covers)
-        return rdataset
+        return self.version.get_rdataset(name, rdtype, covers)
 
     def _put_rdataset(self, name, rdataset):
         assert not self.read_only
-        self.zone._validate_name(name)
-        self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+        self.version.put_rdataset(name, rdataset)
 
     def _delete_name(self, name):
         assert not self.read_only
-        # First remove any changes involving the name
-        remove = []
-        for key in self.rdatasets:
-            if key[0] == name:
-                remove.append(key)
-        if len(remove) > 0:
-            for key in remove:
-                del self.rdatasets[key]
-        # Next add deletion records for any rdatasets matching the
-        # name in the zone
-        node = self.zone.get_node(name)
-        if node is not None:
-            for rdataset in node.rdatasets:
-                self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
-                    self._deleted_rdataset
+        self.version.delete_node(name)
 
     def _delete_rdataset(self, name, rdtype, covers):
         assert not self.read_only
-        try:
-            del self.rdatasets[(name, rdtype, covers)]
-        except KeyError:
-            pass
-        rdataset = self.zone.get_rdataset(name, rdtype, covers)
-        if rdataset is not None:
-            self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
-                self._deleted_rdataset
+        self.version.delete_rdataset(name, rdtype, covers)
 
     def _name_exists(self, name):
-        for key, rdataset in self.rdatasets.items():
-            if key[0] == name:
-                if rdataset != self._deleted_rdataset:
-                    return True
-                else:
-                    return None
-        self.zone._validate_name(name)
-        if self.zone.get_node(name):
-            return True
-        return False
+        return self.version.get_node(name) is not None
 
     def _changed(self):
         if self.read_only:
             return False
         else:
-            return len(self.rdatasets) > 0
+            return len(self.version.changed) > 0
 
     def _end_transaction(self, commit):
-        if commit and self._changed():
-            if self.replacement:
-                self.zone.nodes = {}
-            for (name, rdtype, covers), rdataset in \
-                self.rdatasets.items():
-                if rdataset is self._deleted_rdataset:
-                    self.zone.delete_rdataset(name, rdtype, covers)
-                else:
-                    self.zone.replace_rdataset(name, rdataset)
+        if self.read_only:
+            self.zone._end_read(self)
+        elif commit and len(self.version.changed) > 0:
+            if self.make_immutable:
+                version = ImmutableVersion(self.version)
+            else:
+                version = self.version
+            self.zone._commit_version(self, version, self.version.origin)
+        else:
+            # rollback
+            self.zone._end_write(self)
 
     def _set_origin(self, origin):
-        if self.zone.origin is None:
-            self.zone.origin = origin
+        if self.version.origin is None:
+            self.version.origin = origin
 
     def _iterate_rdatasets(self):
-        # Expensive but simple!  Use a versioned zone for efficient txn
-        # iteration.
-        if self.replacement:
-            rdatasets = self.rdatasets
-        else:
-            rdatasets = {}
-            for (name, rdataset) in self.zone.iterate_rdatasets():
-                rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
-            rdatasets.update(self.rdatasets)
-        for (name, _, _), rdataset in rdatasets.items():
-            yield (name, rdataset)
+        for (name, node) in self.version.items():
+            for rdataset in node:
+                yield (name, rdataset)
 
 
 def from_text(text, origin=None, rdclass=dns.rdataclass.IN,