]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Allow Zones with different map types. (#1015)
authorBob Halley <halley@dnspython.org>
Wed, 29 Nov 2023 13:30:03 +0000 (05:30 -0800)
committerGitHub <noreply@github.com>
Wed, 29 Nov 2023 13:30:03 +0000 (05:30 -0800)
* Allow Zones with different map types.

* Backwards compatibility for python 3.8.

dns/immutable.py
dns/zone.py
tests/test_zone.py

index cab8d6fb5a03164734bf5af4f97ad45b81c0a9fb..36b0362c75199bc1565ec7a8a2c76bbfa34c3637 100644 (file)
@@ -1,24 +1,30 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
 import collections.abc
-from typing import Any
+from typing import Any, Callable
 
 from dns._immutable_ctx import immutable
 
 
 @immutable
 class Dict(collections.abc.Mapping):  # lgtm[py/missing-equals]
-    def __init__(self, dictionary: Any, no_copy: bool = False):
+    def __init__(
+        self,
+        dictionary: Any,
+        no_copy: bool = False,
+        map_factory: Callable[[], collections.abc.MutableMapping] = dict,
+    ):
         """Make an immutable dictionary from the specified dictionary.
 
         If *no_copy* is `True`, then *dictionary* will be wrapped instead
         of copied.  Only set this if you are sure there will be no external
         references to the dictionary.
         """
-        if no_copy and isinstance(dictionary, dict):
+        if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
             self._odict = dictionary
         else:
-            self._odict = dict(dictionary)
+            self._odict = map_factory()
+            self._odict.update(dictionary)
         self._hash = None
 
     def __getitem__(self, key):
index ef54d0a28c00bdc578830708e6c57c66e90f5b42..9fc824571a1361e5a9f7102274ab34cced50faf9 100644 (file)
@@ -21,7 +21,19 @@ import contextlib
 import io
 import os
 import struct
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    MutableMapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import dns.exception
 import dns.grange
@@ -126,7 +138,10 @@ class Zone(dns.transaction.TransactionManager):
     the zone.
     """
 
-    node_factory = dns.node.Node
+    node_factory: Callable[[], dns.node.Node] = dns.node.Node
+    map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
+    writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
+    immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
 
     __slots__ = ["rdclass", "origin", "nodes", "relativize"]
 
@@ -157,7 +172,7 @@ class Zone(dns.transaction.TransactionManager):
                 raise ValueError("origin parameter must be an absolute name")
         self.origin = origin
         self.rdclass = rdclass
-        self.nodes: Dict[dns.name.Name, dns.node.Node] = {}
+        self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory()
         self.relativize = relativize
 
     def __eq__(self, other):
@@ -965,7 +980,7 @@ class Version:
         self,
         zone: Zone,
         id: int,
-        nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None,
+        nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None,
         origin: Optional[dns.name.Name] = None,
     ):
         self.zone = zone
@@ -973,7 +988,7 @@ class Version:
         if nodes is not None:
             self.nodes = nodes
         else:
-            self.nodes = {}
+            self.nodes = zone.map_factory()
         self.origin = origin
 
     def _validate_name(self, name: dns.name.Name) -> dns.name.Name:
@@ -1083,7 +1098,7 @@ class ImmutableVersion(Version):
                 version.nodes[name] = ImmutableVersionedNode(node)
         # We're changing the type of the nodes dictionary here on purpose, so
         # we ignore the mypy error.
-        self.nodes = dns.immutable.Dict(version.nodes, True)  # type: ignore
+        self.nodes = dns.immutable.Dict(version.nodes, True, self.zone.map_factory)  # type: ignore
 
 
 class Transaction(dns.transaction.Transaction):
@@ -1099,7 +1114,10 @@ class Transaction(dns.transaction.Transaction):
 
     def _setup_version(self):
         assert self.version is None
-        self.version = WritableVersion(self.zone, self.replacement)
+        factory = self.manager.writable_version_factory
+        if factory is None:
+            factory = WritableVersion
+        self.version = factory(self.zone, self.replacement)
 
     def _get_rdataset(self, name, rdtype, covers):
         return self.version.get_rdataset(name, rdtype, covers)
@@ -1130,7 +1148,10 @@ class Transaction(dns.transaction.Transaction):
             self.zone._end_read(self)
         elif commit and len(self.version.changed) > 0:
             if self.make_immutable:
-                version = ImmutableVersion(self.version)
+                factory = self.manager.immutable_version_factory
+                if factory is None:
+                    factory = ImmutableVersion
+                version = factory(self.version)
             else:
                 version = self.version
             self.zone._commit_version(self, version, self.version.origin)
index a902967411c746bd332b090173a7d3585075cace..2e590e7087865a5ffd4c9bca6dd8597ecfa7180b 100644 (file)
@@ -477,11 +477,10 @@ class ZoneTestCase(unittest.TestCase):
     def testGenerate(self):
         z = dns.zone.from_text(example_generate, "example.", relativize=True)
         f = StringIO()
-        names = list(z.nodes.keys())
-        for n in names:
-            f.write(z[n].to_text(n))
-            f.write("\n")
-        self.assertEqual(f.getvalue(), example_generate_output)
+        expected = dns.zone.from_text(
+            example_generate_output, "example.", relativize=True
+        )
+        self.assertEqual(z, expected)
 
     def testTorture1(self):
         #