# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import collections.abc
-from typing import Any
+from typing import Any, Callable
from dns._immutable_ctx import immutable
@immutable
class Dict(collections.abc.Mapping): # lgtm[py/missing-equals]
- def __init__(self, dictionary: Any, no_copy: bool = False):
+ def __init__(
+ self,
+ dictionary: Any,
+ no_copy: bool = False,
+ map_factory: Callable[[], collections.abc.MutableMapping] = dict,
+ ):
"""Make an immutable dictionary from the specified dictionary.
If *no_copy* is `True`, then *dictionary* will be wrapped instead
of copied. Only set this if you are sure there will be no external
references to the dictionary.
"""
- if no_copy and isinstance(dictionary, dict):
+ if no_copy and isinstance(dictionary, collections.abc.MutableMapping):
self._odict = dictionary
else:
- self._odict = dict(dictionary)
+ self._odict = map_factory()
+ self._odict.update(dictionary)
self._hash = None
def __getitem__(self, key):
import io
import os
import struct
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ MutableMapping,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
import dns.exception
import dns.grange
the zone.
"""
- node_factory = dns.node.Node
+ node_factory: Callable[[], dns.node.Node] = dns.node.Node
+ map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
+ writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
+ immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
__slots__ = ["rdclass", "origin", "nodes", "relativize"]
raise ValueError("origin parameter must be an absolute name")
self.origin = origin
self.rdclass = rdclass
- self.nodes: Dict[dns.name.Name, dns.node.Node] = {}
+ self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory()
self.relativize = relativize
def __eq__(self, other):
self,
zone: Zone,
id: int,
- nodes: Optional[Dict[dns.name.Name, dns.node.Node]] = None,
+ nodes: Optional[MutableMapping[dns.name.Name, dns.node.Node]] = None,
origin: Optional[dns.name.Name] = None,
):
self.zone = zone
if nodes is not None:
self.nodes = nodes
else:
- self.nodes = {}
+ self.nodes = zone.map_factory()
self.origin = origin
def _validate_name(self, name: dns.name.Name) -> dns.name.Name:
version.nodes[name] = ImmutableVersionedNode(node)
# We're changing the type of the nodes dictionary here on purpose, so
# we ignore the mypy error.
- self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore
+ self.nodes = dns.immutable.Dict(version.nodes, True, self.zone.map_factory) # type: ignore
class Transaction(dns.transaction.Transaction):
def _setup_version(self):
assert self.version is None
- self.version = WritableVersion(self.zone, self.replacement)
+ factory = self.manager.writable_version_factory
+ if factory is None:
+ factory = WritableVersion
+ self.version = factory(self.zone, self.replacement)
def _get_rdataset(self, name, rdtype, covers):
return self.version.get_rdataset(name, rdtype, covers)
self.zone._end_read(self)
elif commit and len(self.version.changed) > 0:
if self.make_immutable:
- version = ImmutableVersion(self.version)
+ factory = self.manager.immutable_version_factory
+ if factory is None:
+ factory = ImmutableVersion
+ version = factory(self.version)
else:
version = self.version
self.zone._commit_version(self, version, self.version.origin)