]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add immutable module.
authorBob Halley <halley@dnspython.org>
Sat, 8 Aug 2020 14:18:10 +0000 (07:18 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 8 Aug 2020 14:18:10 +0000 (07:18 -0700)
dns/immutable.py [new file with mode: 0644]
dns/rdata.py
tests/test_immutable.py [new file with mode: 0644]

diff --git a/dns/immutable.py b/dns/immutable.py
new file mode 100644 (file)
index 0000000..dc48fe8
--- /dev/null
@@ -0,0 +1,62 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import collections.abc
+import sys
+
+if sys.version_info >= (3, 7):
+    odict = dict
+else:
+    from collections import OrderedDict as odict  # pragma: no cover
+
+
+class ImmutableDict(collections.abc.Mapping):
+    def __init__(self, dictionary, no_copy=False):
+        """Make an immutable dictionary from the specified dictionary.
+
+        If *no_copy* is `True`, then *dictionary* will be wrapped instead
+        of copied.  Only set this if you are sure there will be no external
+        references to the dictionary.
+        """
+        if no_copy and isinstance(dictionary, odict):
+            self._odict = dictionary
+        else:
+            self._odict = odict(dictionary)
+        self._hash = None
+
+    def __getitem__(self, key):
+        return self._odict.__getitem__(key)
+
+    def __hash__(self):
+        if self._hash is None:
+            self._hash = 0
+            for key in sorted(self._odict.keys()):
+                self._hash ^= hash(key)
+        return self._hash
+
+    def __len__(self):
+        return len(self._odict)
+
+    def __iter__(self):
+        return iter(self._odict)
+
+
+def constify(o):
+    """
+    Convert mutable types to immutable types.
+    """
+    if isinstance(o, bytearray):
+        return bytes(o)
+    if isinstance(o, tuple):
+        try:
+            hash(o)
+            return o
+        except Exception:
+            return tuple(constify(elt) for elt in o)
+    if isinstance(o, list):
+        return tuple(constify(elt) for elt in o)
+    if isinstance(o, dict):
+        cdict = odict()
+        for k, v in o.items():
+            cdict[k] = constify(v)
+        return ImmutableDict(cdict, True)
+    return o
index ea2f1d249a93e0a2b0742e56f564400338154401..3f0b6d2120e582f92973f9dbb9091557debf4b2a 100644 (file)
@@ -26,6 +26,7 @@ import itertools
 
 import dns.wire
 import dns.exception
+import dns.immutable
 import dns.name
 import dns.rdataclass
 import dns.rdatatype
@@ -92,21 +93,9 @@ def _truncate_bitmap(what):
             return what[0: i + 1]
     return what[0:1]
 
-def _constify(o):
-    """
-    Convert mutable types to immutable types.
-    """
-    if isinstance(o, bytearray):
-        return bytes(o)
-    if isinstance(o, tuple):
-        try:
-            hash(o)
-            return o
-        except Exception:
-            return tuple(_constify(elt) for elt in o)
-    if isinstance(o, list):
-        return tuple(_constify(elt) for elt in o)
-    return o
+# So we don't have to edit all the rdata classes...
+_constify = dns.immutable.constify
+
 
 class Rdata:
     """Base class for all DNS rdata types."""
diff --git a/tests/test_immutable.py b/tests/test_immutable.py
new file mode 100644 (file)
index 0000000..0385fc9
--- /dev/null
@@ -0,0 +1,40 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import unittest
+
+import dns.immutable
+
+
+class ImmutableTestCase(unittest.TestCase):
+
+    def test_ImmutableDict_hash(self):
+        d1 = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
+        d2 = dns.immutable.ImmutableDict({'b': 2, 'a': 1})
+        d3 = {'b': 2, 'a': 1}
+        self.assertEqual(d1, d2)
+        self.assertEqual(d2, d3)
+        self.assertEqual(hash(d1), hash(d2))
+
+    def test_ImmutableDict_hash_cache(self):
+        d = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
+        self.assertEqual(d._hash, None)
+        h1 = hash(d)
+        self.assertEqual(d._hash, h1)
+        h2 = hash(d)
+        self.assertEqual(h1, h2)
+
+    def test_constify(self):
+        items = (
+            (bytearray([1, 2, 3]), b'\x01\x02\x03'),
+            ((1, 2, 3), (1, 2, 3)),
+            ((1, [2], 3), (1, (2,), 3)),
+            ([1, 2, 3], (1, 2, 3)),
+            ([1, {'a': [1, 2]}],
+             (1, dns.immutable.ImmutableDict({'a': (1, 2)}))),
+            ('hi', 'hi'),
+            (b'hi', b'hi'),
+        )
+        for input, expected in items:
+            self.assertEqual(dns.immutable.constify(input), expected)
+        self.assertIsInstance(dns.immutable.constify({'a': 1}),
+                              dns.immutable.ImmutableDict)