From c1fe3c6fac0f5fd57a15dc3bf588cf6cb52fa82c Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Wed, 29 Nov 2023 05:30:03 -0800 Subject: [PATCH] Allow Zones with different map types. (#1015) * Allow Zones with different map types. * Backwards compatibility for python 3.8. --- dns/immutable.py | 14 ++++++++++---- dns/zone.py | 37 +++++++++++++++++++++++++++++-------- tests/test_zone.py | 9 ++++----- 3 files changed, 43 insertions(+), 17 deletions(-) diff --git a/dns/immutable.py b/dns/immutable.py index cab8d6fb..36b0362c 100644 --- a/dns/immutable.py +++ b/dns/immutable.py @@ -1,24 +1,30 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license import collections.abc -from typing import Any +from typing import Any, Callable from dns._immutable_ctx import immutable @immutable class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] - def __init__(self, dictionary: Any, no_copy: bool = False): + def __init__( + self, + dictionary: Any, + no_copy: bool = False, + map_factory: Callable[[], collections.abc.MutableMapping] = dict, + ): """Make an immutable dictionary from the specified dictionary. If *no_copy* is `True`, then *dictionary* will be wrapped instead of copied. Only set this if you are sure there will be no external references to the dictionary. """ - if no_copy and isinstance(dictionary, dict): + if no_copy and isinstance(dictionary, collections.abc.MutableMapping): self._odict = dictionary else: - self._odict = dict(dictionary) + self._odict = map_factory() + self._odict.update(dictionary) self._hash = None def __getitem__(self, key): diff --git a/dns/zone.py b/dns/zone.py index ef54d0a2..9fc82457 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -21,7 +21,19 @@ import contextlib import io import os import struct -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import dns.exception import dns.grange @@ -126,7 +138,10 @@ class Zone(dns.transaction.TransactionManager): the zone. """ - node_factory = dns.node.Node + node_factory: Callable[[], dns.node.Node] = dns.node.Node + map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict + writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None + immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None __slots__ = ["rdclass", "origin", "nodes", "relativize"] @@ -157,7 +172,7 @@ class Zone(dns.transaction.TransactionManager): raise ValueError("origin parameter must be an absolute name") self.origin = origin self.rdclass = rdclass - self.nodes: Dict[dns.name.Name, dns.node.Node] = {} + self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory() self.relativize = relativize def __eq__(self, other): @@ -965,7 +980,7 @@ class Version: self, zone: Zone, id: int, - nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None, + nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None, origin: Optional[dns.name.Name] = None, ): self.zone = zone @@ -973,7 +988,7 @@ class Version: if nodes is not None: self.nodes = nodes else: - self.nodes = {} + self.nodes = zone.map_factory() self.origin = origin def _validate_name(self, name: dns.name.Name) -> dns.name.Name: @@ -1083,7 +1098,7 @@ class ImmutableVersion(Version): version.nodes[name] = ImmutableVersionedNode(node) # We're changing the type of the nodes dictionary here on purpose, so # we ignore the mypy error. - self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore + self.nodes = dns.immutable.Dict(version.nodes, True, self.zone.map_factory) # type: ignore class Transaction(dns.transaction.Transaction): @@ -1099,7 +1114,10 @@ class Transaction(dns.transaction.Transaction): def _setup_version(self): assert self.version is None - self.version = WritableVersion(self.zone, self.replacement) + factory = self.manager.writable_version_factory + if factory is None: + factory = WritableVersion + self.version = factory(self.zone, self.replacement) def _get_rdataset(self, name, rdtype, covers): return self.version.get_rdataset(name, rdtype, covers) @@ -1130,7 +1148,10 @@ class Transaction(dns.transaction.Transaction): self.zone._end_read(self) elif commit and len(self.version.changed) > 0: if self.make_immutable: - version = ImmutableVersion(self.version) + factory = self.manager.immutable_version_factory + if factory is None: + factory = ImmutableVersion + version = factory(self.version) else: version = self.version self.zone._commit_version(self, version, self.version.origin) diff --git a/tests/test_zone.py b/tests/test_zone.py index a9029674..2e590e70 100644 --- a/tests/test_zone.py +++ b/tests/test_zone.py @@ -477,11 +477,10 @@ class ZoneTestCase(unittest.TestCase): def testGenerate(self): z = dns.zone.from_text(example_generate, "example.", relativize=True) f = StringIO() - names = list(z.nodes.keys()) - for n in names: - f.write(z[n].to_text(n)) - f.write("\n") - self.assertEqual(f.getvalue(), example_generate_output) + expected = dns.zone.from_text( + example_generate_output, "example.", relativize=True + ) + self.assertEqual(z, expected) def testTorture1(self): # -- 2.47.3