From: Bob Halley Date: Sat, 8 Aug 2020 14:18:10 +0000 (-0700) Subject: Add immutable module. X-Git-Tag: v2.1.0rc1~105 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=71d57aa6d122ef78b743642d04738d3d58cf43e0;p=thirdparty%2Fdnspython.git Add immutable module. --- diff --git a/dns/immutable.py b/dns/immutable.py new file mode 100644 index 00000000..dc48fe85 --- /dev/null +++ b/dns/immutable.py @@ -0,0 +1,62 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import collections.abc +import sys + +if sys.version_info >= (3, 7): + odict = dict +else: + from collections import OrderedDict as odict # pragma: no cover + + +class ImmutableDict(collections.abc.Mapping): + def __init__(self, dictionary, no_copy=False): + """Make an immutable dictionary from the specified dictionary. + + If *no_copy* is `True`, then *dictionary* will be wrapped instead + of copied. Only set this if you are sure there will be no external + references to the dictionary. + """ + if no_copy and isinstance(dictionary, odict): + self._odict = dictionary + else: + self._odict = odict(dictionary) + self._hash = None + + def __getitem__(self, key): + return self._odict.__getitem__(key) + + def __hash__(self): + if self._hash is None: + self._hash = 0 + for key in sorted(self._odict.keys()): + self._hash ^= hash(key) + return self._hash + + def __len__(self): + return len(self._odict) + + def __iter__(self): + return iter(self._odict) + + +def constify(o): + """ + Convert mutable types to immutable types. + """ + if isinstance(o, bytearray): + return bytes(o) + if isinstance(o, tuple): + try: + hash(o) + return o + except Exception: + return tuple(constify(elt) for elt in o) + if isinstance(o, list): + return tuple(constify(elt) for elt in o) + if isinstance(o, dict): + cdict = odict() + for k, v in o.items(): + cdict[k] = constify(v) + return ImmutableDict(cdict, True) + return o diff --git a/dns/rdata.py b/dns/rdata.py index ea2f1d24..3f0b6d21 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -26,6 +26,7 @@ import itertools import dns.wire import dns.exception +import dns.immutable import dns.name import dns.rdataclass import dns.rdatatype @@ -92,21 +93,9 @@ def _truncate_bitmap(what): return what[0: i + 1] return what[0:1] -def _constify(o): - """ - Convert mutable types to immutable types. - """ - if isinstance(o, bytearray): - return bytes(o) - if isinstance(o, tuple): - try: - hash(o) - return o - except Exception: - return tuple(_constify(elt) for elt in o) - if isinstance(o, list): - return tuple(_constify(elt) for elt in o) - return o +# So we don't have to edit all the rdata classes... +_constify = dns.immutable.constify + class Rdata: """Base class for all DNS rdata types.""" diff --git a/tests/test_immutable.py b/tests/test_immutable.py new file mode 100644 index 00000000..0385fc91 --- /dev/null +++ b/tests/test_immutable.py @@ -0,0 +1,40 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import unittest + +import dns.immutable + + +class ImmutableTestCase(unittest.TestCase): + + def test_ImmutableDict_hash(self): + d1 = dns.immutable.ImmutableDict({'a': 1, 'b': 2}) + d2 = dns.immutable.ImmutableDict({'b': 2, 'a': 1}) + d3 = {'b': 2, 'a': 1} + self.assertEqual(d1, d2) + self.assertEqual(d2, d3) + self.assertEqual(hash(d1), hash(d2)) + + def test_ImmutableDict_hash_cache(self): + d = dns.immutable.ImmutableDict({'a': 1, 'b': 2}) + self.assertEqual(d._hash, None) + h1 = hash(d) + self.assertEqual(d._hash, h1) + h2 = hash(d) + self.assertEqual(h1, h2) + + def test_constify(self): + items = ( + (bytearray([1, 2, 3]), b'\x01\x02\x03'), + ((1, 2, 3), (1, 2, 3)), + ((1, [2], 3), (1, (2,), 3)), + ([1, 2, 3], (1, 2, 3)), + ([1, {'a': [1, 2]}], + (1, dns.immutable.ImmutableDict({'a': (1, 2)}))), + ('hi', 'hi'), + (b'hi', b'hi'), + ) + for input, expected in items: + self.assertEqual(dns.immutable.constify(input), expected) + self.assertIsInstance(dns.immutable.constify({'a': 1}), + dns.immutable.ImmutableDict)