--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import collections.abc
+import sys
+
+if sys.version_info >= (3, 7):
+ odict = dict
+else:
+ from collections import OrderedDict as odict # pragma: no cover
+
+
+class ImmutableDict(collections.abc.Mapping):
+ def __init__(self, dictionary, no_copy=False):
+ """Make an immutable dictionary from the specified dictionary.
+
+ If *no_copy* is `True`, then *dictionary* will be wrapped instead
+ of copied. Only set this if you are sure there will be no external
+ references to the dictionary.
+ """
+ if no_copy and isinstance(dictionary, odict):
+ self._odict = dictionary
+ else:
+ self._odict = odict(dictionary)
+ self._hash = None
+
+ def __getitem__(self, key):
+ return self._odict.__getitem__(key)
+
+ def __hash__(self):
+ if self._hash is None:
+ self._hash = 0
+ for key in sorted(self._odict.keys()):
+ self._hash ^= hash(key)
+ return self._hash
+
+ def __len__(self):
+ return len(self._odict)
+
+ def __iter__(self):
+ return iter(self._odict)
+
+
+def constify(o):
+ """
+ Convert mutable types to immutable types.
+ """
+ if isinstance(o, bytearray):
+ return bytes(o)
+ if isinstance(o, tuple):
+ try:
+ hash(o)
+ return o
+ except Exception:
+ return tuple(constify(elt) for elt in o)
+ if isinstance(o, list):
+ return tuple(constify(elt) for elt in o)
+ if isinstance(o, dict):
+ cdict = odict()
+ for k, v in o.items():
+ cdict[k] = constify(v)
+ return ImmutableDict(cdict, True)
+ return o
import dns.wire
import dns.exception
+import dns.immutable
import dns.name
import dns.rdataclass
import dns.rdatatype
return what[0: i + 1]
return what[0:1]
-def _constify(o):
- """
- Convert mutable types to immutable types.
- """
- if isinstance(o, bytearray):
- return bytes(o)
- if isinstance(o, tuple):
- try:
- hash(o)
- return o
- except Exception:
- return tuple(_constify(elt) for elt in o)
- if isinstance(o, list):
- return tuple(_constify(elt) for elt in o)
- return o
+# So we don't have to edit all the rdata classes...
+_constify = dns.immutable.constify
+
class Rdata:
"""Base class for all DNS rdata types."""
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import unittest
+
+import dns.immutable
+
+
+class ImmutableTestCase(unittest.TestCase):
+
+ def test_ImmutableDict_hash(self):
+ d1 = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
+ d2 = dns.immutable.ImmutableDict({'b': 2, 'a': 1})
+ d3 = {'b': 2, 'a': 1}
+ self.assertEqual(d1, d2)
+ self.assertEqual(d2, d3)
+ self.assertEqual(hash(d1), hash(d2))
+
+ def test_ImmutableDict_hash_cache(self):
+ d = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
+ self.assertEqual(d._hash, None)
+ h1 = hash(d)
+ self.assertEqual(d._hash, h1)
+ h2 = hash(d)
+ self.assertEqual(h1, h2)
+
+ def test_constify(self):
+ items = (
+ (bytearray([1, 2, 3]), b'\x01\x02\x03'),
+ ((1, 2, 3), (1, 2, 3)),
+ ((1, [2], 3), (1, (2,), 3)),
+ ([1, 2, 3], (1, 2, 3)),
+ ([1, {'a': [1, 2]}],
+ (1, dns.immutable.ImmutableDict({'a': (1, 2)}))),
+ ('hi', 'hi'),
+ (b'hi', b'hi'),
+ )
+ for input, expected in items:
+ self.assertEqual(dns.immutable.constify(input), expected)
+ self.assertIsInstance(dns.immutable.constify({'a': 1}),
+ dns.immutable.ImmutableDict)