]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Transaction support. 569/head
authorBob Halley <halley@dnspython.org>
Thu, 30 Jul 2020 16:21:03 +0000 (09:21 -0700)
committerBob Halley <halley@dnspython.org>
Tue, 11 Aug 2020 01:14:30 +0000 (18:14 -0700)
14 files changed:
dns/__init__.py
dns/_immutable_attr.py [new file with mode: 0644]
dns/_immutable_ctx.py [new file with mode: 0644]
dns/exception.py
dns/immutable.py
dns/masterfile.py [new file with mode: 0644]
dns/node.py
dns/rdataset.py
dns/transaction.py [new file with mode: 0644]
dns/versioned.py [new file with mode: 0644]
dns/zone.py
tests/test_immutable.py
tests/test_rdataset.py
tests/test_transaction.py [new file with mode: 0644]

index b944701d3c84e3e4080e676ab75e1bc11d72ba81..eafdcc4d539d2b6fdca4ee8d48c011262e7d3a16 100644 (file)
@@ -27,9 +27,11 @@ __all__ = [
     'entropy',
     'exception',
     'flags',
+    'immutable',
     'inet',
     'ipv4',
     'ipv6',
+    'masterfile',
     'message',
     'name',
     'namedict',
@@ -48,12 +50,14 @@ __all__ = [
     'serial',
     'set',
     'tokenizer',
+    'transaction',
     'tsig',
     'tsigkeyring',
     'ttl',
     'rdtypes',
     'update',
     'version',
+    'versioned',
     'wire',
     'zone',
 ]
diff --git a/dns/_immutable_attr.py b/dns/_immutable_attr.py
new file mode 100644 (file)
index 0000000..e858337
--- /dev/null
@@ -0,0 +1,66 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# This implementation of the immutable decorator is for python 3.6,
+# which doesn't have Context Variables.  This implementation is somewhat
+# costly for classes with slots, as it adds a __dict__ to them.
+
+class _Immutable:
+    """Immutable mixin class"""
+
+    # Note we MUST NOT have __slots__ as that causes
+    #
+    #    TypeError: multiple bases have instance lay-out conflict
+    #
+    # when we get mixed in with another class with slots.  When we
+    # get mixed into something with slots, it effectively adds __dict__ to
+    # the slots of the other class, which allows attribute setting to work,
+    # albeit at the cost of the dictionary.
+
+    def __setattr__(self, name, value):
+        if not hasattr(self, '_immutable_init') or \
+           self._immutable_init is not self:
+            raise TypeError("object doesn't support attribute assignment")
+        else:
+            super().__setattr__(name, value)
+
+    def __delattr__(self, name):
+        if not hasattr(self, '_immutable_init') or \
+           self._immutable_init is not self:
+            raise TypeError("object doesn't support attribute assignment")
+        else:
+            super().__delattr__(name)
+
+
+def _immutable_init(f):
+    def nf(*args, **kwargs):
+        try:
+            # Are we already initializing an immutable class?
+            previous = args[0]._immutable_init
+        except AttributeError:
+            # We are the first!
+            previous = None
+            object.__setattr__(args[0], '_immutable_init', args[0])
+        try:
+            # call the actual __init__
+            f(*args, **kwargs)
+        finally:
+            if not previous:
+                # If we started the initialzation, establish immutability
+                # by removing the attribute that allows mutation
+                object.__delattr__(args[0], '_immutable_init')
+    return nf
+
+
+def immutable(cls):
+    if _Immutable in cls.__mro__:
+        # Some ancestor already has the mixin, so just make sure we keep
+        # following the __init__ protocol.
+        cls.__init__ = _immutable_init(cls.__init__)
+        ncls = cls
+    else:
+        # Mixin the Immutable class and follow the __init__ protocol.
+        class ncls(_Immutable, cls):
+            @_immutable_init
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+    return ncls
diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py
new file mode 100644 (file)
index 0000000..babdde0
--- /dev/null
@@ -0,0 +1,60 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# This implementation of the immutable decorator requires python >=
+# 3.7, and is significantly more storage efficient when making classes
+# with slots immutable.  It's also faster.
+
+import contextvars
+
+_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
+
+
+class _Immutable:
+    """Immutable mixin class"""
+
+    # We set slots to the empty list to say "we don't have any attributes".
+    # We do this so that if we're mixed in with a class with __slots__, we
+    # don't cause a __dict__ to be added which would waste space.
+
+    __slots__ = ()
+
+    def __setattr__(self, name, value):
+        if _in__init__.get() is not self:
+            raise TypeError("object doesn't support attribute assignment")
+        else:
+            super().__setattr__(name, value)
+
+    def __delattr__(self, name):
+        if _in__init__.get() is not self:
+            raise TypeError("object doesn't support attribute assignment")
+        else:
+            super().__delattr__(name)
+
+
+def _immutable_init(f):
+    def nf(*args, **kwargs):
+        previous = _in__init__.set(args[0])
+        try:
+            # call the actual __init__
+            f(*args, **kwargs)
+        finally:
+            _in__init__.reset(previous)
+    return nf
+
+
+def immutable(cls):
+    if _Immutable in cls.__mro__:
+        # Some ancestor already has the mixin, so just make sure we keep
+        # following the __init__ protocol.
+        cls.__init__ = _immutable_init(cls.__init__)
+        ncls = cls
+    else:
+        # Mixin the Immutable class and follow the __init__ protocol.
+        class ncls(_Immutable, cls):
+            # We have to do the __slots__ declaration here too!
+            __slots__ = ()
+
+            @_immutable_init
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+    return ncls
index 9486f4507421ce115d32e9e2dd0c42cb50035bdd..9392373492e05f86d43c190dec5d37231276b6ce 100644 (file)
@@ -138,5 +138,5 @@ class ExceptionWrapper:
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None and not isinstance(exc_val,
                                                    self.exception_class):
-            raise self.exception_class() from exc_val
+            raise self.exception_class(str(exc_val)) from exc_val
         return False
index dc48fe852af628dcc7b7780c5fa17d1462f38afb..7cc39dd0a34a2164e9c978902c370bbe68da231b 100644 (file)
@@ -3,13 +3,19 @@
 import collections.abc
 import sys
 
+# pylint: disable=unused-import
 if sys.version_info >= (3, 7):
     odict = dict
+    from dns._immutable_ctx import immutable
 else:
-    from collections import OrderedDict as odict  # pragma: no cover
+    # pragma: no cover
+    from collections import OrderedDict as odict
+    from dns._immutable_attr import immutable  # noqa
+# pylint: enable=unused-import
 
 
-class ImmutableDict(collections.abc.Mapping):
+@immutable
+class Dict(collections.abc.Mapping):
     def __init__(self, dictionary, no_copy=False):
         """Make an immutable dictionary from the specified dictionary.
 
@@ -28,9 +34,10 @@ class ImmutableDict(collections.abc.Mapping):
 
     def __hash__(self):
         if self._hash is None:
-            self._hash = 0
+            h = 0
             for key in sorted(self._odict.keys()):
-                self._hash ^= hash(key)
+                h ^= hash(key)
+            object.__setattr__(self, '_hash', h)
         return self._hash
 
     def __len__(self):
@@ -58,5 +65,5 @@ def constify(o):
         cdict = odict()
         for k, v in o.items():
             cdict[k] = constify(v)
-        return ImmutableDict(cdict, True)
+        return Dict(cdict, True)
     return o
diff --git a/dns/masterfile.py b/dns/masterfile.py
new file mode 100644 (file)
index 0000000..30553b5
--- /dev/null
@@ -0,0 +1,404 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""DNS Zones."""
+
+import re
+import sys
+
+import dns.exception
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdata
+import dns.rdtypes.ANY.SOA
+import dns.rrset
+import dns.tokenizer
+import dns.transaction
+import dns.ttl
+import dns.grange
+
+
+class UnknownOrigin(dns.exception.DNSException):
+    """Unknown origin"""
+
+
+class Reader:
+
+    """Read a DNS master file into a transaction."""
+
+    def __init__(self, tok, origin, rdclass, relativize, txn,
+                 allow_include=False):
+        if isinstance(origin, str):
+            origin = dns.name.from_text(origin)
+        self.tok = tok
+        self.current_origin = origin
+        self.relativize = relativize
+        self.last_ttl = 0
+        self.last_ttl_known = False
+        self.default_ttl = 0
+        self.default_ttl_known = False
+        self.last_name = self.current_origin
+        self.zone_origin = origin
+        self.zone_rdclass = rdclass
+        self.txn = txn
+        self.saved_state = []
+        self.current_file = None
+        self.allow_include = allow_include
+
+    def _eat_line(self):
+        while 1:
+            token = self.tok.get()
+            if token.is_eol_or_eof():
+                break
+
+    def _rr_line(self):
+        """Process one line from a DNS master file."""
+        # Name
+        if self.current_origin is None:
+            raise UnknownOrigin
+        token = self.tok.get(want_leading=True)
+        if not token.is_whitespace():
+            self.last_name = self.tok.as_name(token, self.current_origin)
+        else:
+            token = self.tok.get()
+            if token.is_eol_or_eof():
+                # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+                return
+            self.tok.unget(token)
+        name = self.last_name
+        if not name.is_subdomain(self.zone_origin):
+            self._eat_line()
+            return
+        if self.relativize:
+            name = name.relativize(self.zone_origin)
+        token = self.tok.get()
+        if not token.is_identifier():
+            raise dns.exception.SyntaxError
+
+        # TTL
+        ttl = None
+        try:
+            ttl = dns.ttl.from_text(token.value)
+            self.last_ttl = ttl
+            self.last_ttl_known = True
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except dns.ttl.BadTTL:
+            if self.default_ttl_known:
+                ttl = self.default_ttl
+            elif self.last_ttl_known:
+                ttl = self.last_ttl
+
+        # Class
+        try:
+            rdclass = dns.rdataclass.from_text(token.value)
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except dns.exception.SyntaxError:
+            raise
+        except Exception:
+            rdclass = self.zone_rdclass
+        if rdclass != self.zone_rdclass:
+            raise dns.exception.SyntaxError("RR class is not zone's class")
+        # Type
+        try:
+            rdtype = dns.rdatatype.from_text(token.value)
+        except Exception:
+            raise dns.exception.SyntaxError(
+                "unknown rdatatype '%s'" % token.value)
+        try:
+            rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
+                                     self.current_origin, self.relativize,
+                                     self.zone_origin)
+        except dns.exception.SyntaxError:
+            # Catch and reraise.
+            raise
+        except Exception:
+            # All exceptions that occur in the processing of rdata
+            # are treated as syntax errors.  This is not strictly
+            # correct, but it is correct almost all of the time.
+            # We convert them to syntax errors so that we can emit
+            # helpful filename:line info.
+            (ty, va) = sys.exc_info()[:2]
+            raise dns.exception.SyntaxError(
+                "caught exception {}: {}".format(str(ty), str(va)))
+
+        if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
+            # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
+            # TTL from the SOA minttl if no $TTL statement is present before the
+            # SOA is parsed.
+            self.default_ttl = rd.minimum
+            self.default_ttl_known = True
+            if ttl is None:
+                # if we didn't have a TTL on the SOA, set it!
+                ttl = rd.minimum
+
+        # TTL check.  We had to wait until now to do this as the SOA RR's
+        # own TTL can be inferred from its minimum.
+        if ttl is None:
+            raise dns.exception.SyntaxError("Missing default TTL value")
+
+        self.txn.add(name, ttl, rd)
+
+    def _parse_modify(self, side):
+        # Here we catch everything in '{' '}' in a group so we can replace it
+        # with ''.
+        is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$")
+        is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$")
+        is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$")
+        # Sometimes there are modifiers in the hostname. These come after
+        # the dollar sign. They are in the form: ${offset[,width[,base]]}.
+        # Make names
+        g1 = is_generate1.match(side)
+        if g1:
+            mod, sign, offset, width, base = g1.groups()
+            if sign == '':
+                sign = '+'
+        g2 = is_generate2.match(side)
+        if g2:
+            mod, sign, offset = g2.groups()
+            if sign == '':
+                sign = '+'
+            width = 0
+            base = 'd'
+        g3 = is_generate3.match(side)
+        if g3:
+            mod, sign, offset, width = g3.groups()
+            if sign == '':
+                sign = '+'
+            base = 'd'
+
+        if not (g1 or g2 or g3):
+            mod = ''
+            sign = '+'
+            offset = 0
+            width = 0
+            base = 'd'
+
+        if base != 'd':
+            raise NotImplementedError()
+
+        return mod, sign, offset, width, base
+
+    def _generate_line(self):
+        # range lhs [ttl] [class] type rhs [ comment ]
+        """Process one line containing the GENERATE statement from a DNS
+        master file."""
+        if self.current_origin is None:
+            raise UnknownOrigin
+
+        token = self.tok.get()
+        # Range (required)
+        try:
+            start, stop, step = dns.grange.from_text(token.value)
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except Exception:
+            raise dns.exception.SyntaxError
+
+        # lhs (required)
+        try:
+            lhs = token.value
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except Exception:
+            raise dns.exception.SyntaxError
+
+        # TTL
+        try:
+            ttl = dns.ttl.from_text(token.value)
+            self.last_ttl = ttl
+            self.last_ttl_known = True
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except dns.ttl.BadTTL:
+            if not (self.last_ttl_known or self.default_ttl_known):
+                raise dns.exception.SyntaxError("Missing default TTL value")
+            if self.default_ttl_known:
+                ttl = self.default_ttl
+            elif self.last_ttl_known:
+                ttl = self.last_ttl
+        # Class
+        try:
+            rdclass = dns.rdataclass.from_text(token.value)
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except dns.exception.SyntaxError:
+            raise dns.exception.SyntaxError
+        except Exception:
+            rdclass = self.zone_rdclass
+        if rdclass != self.zone_rdclass:
+            raise dns.exception.SyntaxError("RR class is not zone's class")
+        # Type
+        try:
+            rdtype = dns.rdatatype.from_text(token.value)
+            token = self.tok.get()
+            if not token.is_identifier():
+                raise dns.exception.SyntaxError
+        except Exception:
+            raise dns.exception.SyntaxError("unknown rdatatype '%s'" %
+                                            token.value)
+
+        # rhs (required)
+        rhs = token.value
+
+        # The code currently only supports base 'd', so the last value
+        # in the tuple _parse_modify returns is ignored
+        lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs)
+        rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs)
+        for i in range(start, stop + 1, step):
+            # +1 because bind is inclusive and python is exclusive
+
+            if lsign == '+':
+                lindex = i + int(loffset)
+            elif lsign == '-':
+                lindex = i - int(loffset)
+
+            if rsign == '-':
+                rindex = i - int(roffset)
+            elif rsign == '+':
+                rindex = i + int(roffset)
+
+            lzfindex = str(lindex).zfill(int(lwidth))
+            rzfindex = str(rindex).zfill(int(rwidth))
+
+            name = lhs.replace('$%s' % (lmod), lzfindex)
+            rdata = rhs.replace('$%s' % (rmod), rzfindex)
+
+            self.last_name = dns.name.from_text(name, self.current_origin,
+                                                self.tok.idna_codec)
+            name = self.last_name
+            if not name.is_subdomain(self.zone_origin):
+                self._eat_line()
+                return
+            if self.relativize:
+                name = name.relativize(self.zone_origin)
+
+            try:
+                rd = dns.rdata.from_text(rdclass, rdtype, rdata,
+                                         self.current_origin, self.relativize,
+                                         self.zone_origin)
+            except dns.exception.SyntaxError:
+                # Catch and reraise.
+                raise
+            except Exception:
+                # All exceptions that occur in the processing of rdata
+                # are treated as syntax errors.  This is not strictly
+                # correct, but it is correct almost all of the time.
+                # We convert them to syntax errors so that we can emit
+                # helpful filename:line info.
+                (ty, va) = sys.exc_info()[:2]
+                raise dns.exception.SyntaxError("caught exception %s: %s" %
+                                                (str(ty), str(va)))
+
+            self.txn.add(name, ttl, rd)
+
+    def read(self):
+        """Read a DNS master file and build a zone object.
+
+        @raises dns.zone.NoSOA: No SOA RR was found at the zone origin
+        @raises dns.zone.NoNS: No NS RRset was found at the zone origin
+        """
+
+        try:
+            while 1:
+                token = self.tok.get(True, True)
+                if token.is_eof():
+                    if self.current_file is not None:
+                        self.current_file.close()
+                    if len(self.saved_state) > 0:
+                        (self.tok,
+                         self.current_origin,
+                         self.last_name,
+                         self.current_file,
+                         self.last_ttl,
+                         self.last_ttl_known,
+                         self.default_ttl,
+                         self.default_ttl_known) = self.saved_state.pop(-1)
+                        continue
+                    break
+                elif token.is_eol():
+                    continue
+                elif token.is_comment():
+                    self.tok.get_eol()
+                    continue
+                elif token.value[0] == '$':
+                    c = token.value.upper()
+                    if c == '$TTL':
+                        token = self.tok.get()
+                        if not token.is_identifier():
+                            raise dns.exception.SyntaxError("bad $TTL")
+                        self.default_ttl = dns.ttl.from_text(token.value)
+                        self.default_ttl_known = True
+                        self.tok.get_eol()
+                    elif c == '$ORIGIN':
+                        self.current_origin = self.tok.get_name()
+                        self.tok.get_eol()
+                        if self.zone_origin is None:
+                            self.zone_origin = self.current_origin
+                        self.txn._set_origin(self.current_origin)
+                    elif c == '$INCLUDE' and self.allow_include:
+                        token = self.tok.get()
+                        filename = token.value
+                        token = self.tok.get()
+                        if token.is_identifier():
+                            new_origin =\
+                                dns.name.from_text(token.value,
+                                                   self.current_origin,
+                                                   self.tok.idna_codec)
+                            self.tok.get_eol()
+                        elif not token.is_eol_or_eof():
+                            raise dns.exception.SyntaxError(
+                                "bad origin in $INCLUDE")
+                        else:
+                            new_origin = self.current_origin
+                        self.saved_state.append((self.tok,
+                                                 self.current_origin,
+                                                 self.last_name,
+                                                 self.current_file,
+                                                 self.last_ttl,
+                                                 self.last_ttl_known,
+                                                 self.default_ttl,
+                                                 self.default_ttl_known))
+                        self.current_file = open(filename, 'r')
+                        self.tok = dns.tokenizer.Tokenizer(self.current_file,
+                                                           filename)
+                        self.current_origin = new_origin
+                    elif c == '$GENERATE':
+                        self._generate_line()
+                    else:
+                        raise dns.exception.SyntaxError(
+                            "Unknown master file directive '" + c + "'")
+                    continue
+                self.tok.unget(token)
+                self._rr_line()
+        except dns.exception.SyntaxError as detail:
+            (filename, line_number) = self.tok.where()
+            if detail is None:
+                detail = "syntax error"
+            ex = dns.exception.SyntaxError(
+                "%s:%d: %s" % (filename, line_number, detail))
+            tb = sys.exc_info()[2]
+            raise ex.with_traceback(tb) from None
index b7e21b54678e5a486c5ca01854829e507354e68b..8e1451f3d2e2def8fbb939c7ab727db31c569186 100644 (file)
@@ -183,3 +183,33 @@ class Node:
         self.delete_rdataset(replacement.rdclass, replacement.rdtype,
                              replacement.covers)
         self.rdatasets.append(replacement)
+
+
+@dns.immutable.immutable
+class ImmutableNode(Node):
+
+    """An ImmutableNode is an immutable set of rdatasets."""
+
+    def __init__(self, node):
+        super().__init__()
+        self.rdatasets = tuple(
+            [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+        )
+
+    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+                      create=False):
+        if create:
+            raise TypeError("immutable")
+        return super().find_rdataset(rdclass, rdtype, covers, False)
+
+    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+                     create=False):
+        if create:
+            raise TypeError("immutable")
+        return super().get_rdataset(rdclass, rdtype, covers, False)
+
+    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+        raise TypeError("immutable")
+
+    def replace_rdataset(self, replacement):
+        raise TypeError("immutable")
index ba93ab4365619505af2cc75d82314365a2ead8f6..1f372cd61242a80de8267d2d3fa1385cb135b402 100644 (file)
@@ -22,6 +22,7 @@ import random
 import struct
 
 import dns.exception
+import dns.immutable
 import dns.rdatatype
 import dns.rdataclass
 import dns.rdata
@@ -306,6 +307,52 @@ class Rdataset(dns.set.Set):
         return False
 
 
+@dns.immutable.immutable
+class ImmutableRdataset(Rdataset):
+
+    """An immutable DNS rdataset."""
+
+    def __init__(self, rdataset):
+        """Create an immutable rdataset from the specified rdataset."""
+
+        super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
+                         rdataset.ttl)
+        self.items = dns.immutable.Dict(rdataset.items)
+
+    def update_ttl(self, ttl):
+        raise TypeError('immutable')
+
+    def add(self, rd, ttl=None):
+        raise TypeError('immutable')
+
+    def union_update(self, other):
+        raise TypeError('immutable')
+
+    def intersection_update(self, other):
+        raise TypeError('immutable')
+
+    def update(self, other):
+        raise TypeError('immutable')
+
+    def __delitem__(self, i):
+        raise TypeError('immutable')
+
+    def __ior__(self, other):
+        raise TypeError('immutable')
+
+    def __iand__(self, other):
+        raise TypeError('immutable')
+
+    def __iadd__(self, other):
+        raise TypeError('immutable')
+
+    def __isub__(self, other):
+        raise TypeError('immutable')
+
+    def clear(self):
+        raise TypeError('immutable')
+
+
 def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
                    origin=None, relativize=True, relativize_to=None):
     """Create an rdataset with the specified class, type, and TTL, and with
diff --git a/dns/transaction.py b/dns/transaction.py
new file mode 100644 (file)
index 0000000..20d6939
--- /dev/null
@@ -0,0 +1,383 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import collections
+
+import dns.exception
+import dns.name
+import dns.rdataclass
+import dns.rdataset
+import dns.rdatatype
+import dns.rrset
+import dns.ttl
+
+
+class TransactionManager:
+    def reader(self):
+        """Begin a read-only transaction."""
+        raise NotImplementedError  # pragma: no cover
+
+    def writer(self, replacement=False):
+        """Begin a writable transaction.
+
+        *replacement*, a `bool`.  If `True`, the content of the
+        transaction completely replaces any prior content.  If False,
+        the default, then the content of the transaction updates the
+        existing content.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+
+class DeleteNotExact(dns.exception.DNSException):
+    """Existing data did not match data specified by an exact delete."""
+
+
+class ReadOnly(dns.exception.DNSException):
+    """Tried to write to a read-only transaction."""
+
+
+class Transaction:
+
+    def __init__(self, replacement=False, read_only=False):
+        self.replacement = replacement
+        self.read_only = read_only
+
+    #
+    # This is the high level API
+    #
+
+    def get(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE):
+        """Return the rdataset associated with *name*, *rdclass*, *rdtype*,
+        and *covers*, or `None` if not found.
+
+        Note that the returned rdataset is immutable.
+        """
+        if isinstance(name, str):
+            name = dns.name.from_text(name, None)
+        rdclass = dns.rdataclass.RdataClass.make(rdclass)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        rdataset = self._get_rdataset(name, rdclass, rdtype, covers)
+        if rdataset is not None and \
+           not isinstance(rdataset, dns.rdataset.ImmutableRdataset):
+            rdataset = dns.rdataset.ImmutableRdataset(rdataset)
+        return rdataset
+
+    def _check_read_only(self):
+        if self.read_only:
+            raise ReadOnly
+
+    def add(self, *args):
+        """Add records.
+
+        The arguments may be:
+
+            - rrset
+
+            - name, rdataset...
+
+            - name, ttl, rdata...
+        """
+        self._check_read_only()
+        return self._add(False, args)
+
+    def replace(self, *args):
+        """Replace the existing rdataset at the name with the specified
+        rdataset, or add the specified rdataset if there was no existing
+        rdataset.
+
+        The arguments may be:
+
+            - rrset
+
+            - name, rdataset...
+
+            - name, ttl, rdata...
+
+        Note that if you want to replace the entire node, you should do
+        a delete of the name followed by one or more calls to add() or
+        replace().
+        """
+        self._check_read_only()
+        return self._add(True, args)
+
+    def delete(self, *args):
+        """Delete records.
+
+        It is not an error if some of the records are not in the existing
+        set.
+
+        The arguments may be:
+
+            - rrset
+
+            - name
+
+            - name, rdataclass, rdatatype, [covers]
+
+            - name, rdataset...
+
+            - name, rdata...
+        """
+        self._check_read_only()
+        return self._delete(False, args)
+
+    def delete_exact(self, *args):
+        """Delete records.
+
+        The arguments may be:
+
+            - rrset
+
+            - name
+
+            - name, rdataclass, rdatatype, [covers]
+
+            - name, rdataset...
+
+            - name, rdata...
+
+        Raises dns.transaction.DeleteNotExact if some of the records
+        are not in the existing set.
+
+        """
+        self._check_read_only()
+        return self._delete(True, args)
+
+    def name_exists(self, name):
+        """Does the specified name exist?"""
+        if isinstance(name, str):
+            name = dns.name.from_text(name, None)
+        return self._name_exists(name)
+
+    def set_serial(self, increment=1, value=None, name=dns.name.empty,
+                   rdclass=dns.rdataclass.IN):
+        if isinstance(name, str):
+            name = dns.name.from_text(name, None)
+        rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA,
+                                      dns.rdatatype.NONE)
+        if rdataset is None or len(rdataset) == 0:
+            raise KeyError
+        if value is not None:
+            serial = value
+        else:
+            serial = rdataset[0].serial
+        serial += increment
+        if serial > 0xffffffff or serial < 1:
+            serial = 1
+        rdata = rdataset[0].replace(serial=serial)
+        new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
+        self.replace(name, new_rdataset)
+
+    def __iter__(self):
+        return self._iterate_rdatasets()
+
+    #
+    # Helper methods
+    #
+
+    def _raise_if_not_empty(self, method, args):
+        if len(args) != 0:
+            raise TypeError(f'extra parameters to {method}')
+
+    def _rdataset_from_args(self, method, deleting, args):
+        try:
+            arg = args.popleft()
+            if isinstance(arg, dns.rdataset.Rdataset):
+                rdataset = arg
+            else:
+                if deleting:
+                    ttl = 0
+                else:
+                    if isinstance(arg, int):
+                        ttl = arg
+                        if ttl > dns.ttl.MAX_TTL:
+                            raise ValueError(f'{method}: TTL value too big')
+                    else:
+                        raise TypeError(f'{method}: expected a TTL')
+                    arg = args.popleft()
+                if isinstance(arg, dns.rdata.Rdata):
+                    rdataset = dns.rdataset.from_rdata(ttl, arg)
+                else:
+                    raise TypeError(f'{method}: expected an Rdata')
+            return rdataset
+        except IndexError:
+            if deleting:
+                return None
+            else:
+                # reraise
+                raise TypeError(f'{method}: expected more arguments')
+
+    def _add(self, replace, args):
+        try:
+            args = collections.deque(args)
+            if replace:
+                method = 'replace()'
+            else:
+                method = 'add()'
+            arg = args.popleft()
+            if isinstance(arg, str):
+                arg = dns.name.from_text(arg, None)
+            if isinstance(arg, dns.name.Name):
+                name = arg
+                rdataset = self._rdataset_from_args(method, False, args)
+            elif isinstance(arg, dns.rrset.RRset):
+                rrset = arg
+                name = rrset.name
+                # rrsets are also rdatasets, but they don't print the
+                # same, so convert.
+                rdataset = dns.rdataset.Rdataset(rrset.rdclass, rrset.rdtype,
+                                                 rrset.covers, rrset.ttl)
+                rdataset.union_update(rrset)
+            else:
+                raise TypeError(f'{method} requires a name or RRset ' +
+                                'as the first argument')
+            self._raise_if_not_empty(method, args)
+            if not replace:
+                existing = self._get_rdataset(name, rdataset.rdclass,
+                                              rdataset.rdtype, rdataset.covers)
+                if existing is not None:
+                    if isinstance(existing, dns.rdataset.ImmutableRdataset):
+                        trds = dns.rdataset.Rdataset(existing.rdclass,
+                                                     existing.rdtype,
+                                                     existing.covers)
+                        trds.update(existing)
+                        existing = trds
+                    rdataset = existing.union(rdataset)
+            self._put_rdataset(name, rdataset)
+        except IndexError:
+            raise TypeError(f'not enough parameters to {method}')
+
+    def _delete(self, exact, args):
+        try:
+            args = collections.deque(args)
+            if exact:
+                method = 'delete_exact()'
+            else:
+                method = 'delete()'
+            arg = args.popleft()
+            if isinstance(arg, str):
+                arg = dns.name.from_text(arg, None)
+            if isinstance(arg, dns.name.Name):
+                name = arg
+                if len(args) > 0 and isinstance(args[0], int):
+                    # deleting by type and class
+                    rdclass = dns.rdataclass.RdataClass.make(args.popleft())
+                    rdtype = dns.rdatatype.RdataType.make(args.popleft())
+                    if len(args) > 0:
+                        covers = dns.rdatatype.RdataType.make(args.popleft())
+                    else:
+                        covers = dns.rdatatype.NONE
+                    self._raise_if_not_empty(method, args)
+                    existing = self._get_rdataset(name, rdclass, rdtype, covers)
+                    if existing is None:
+                        if exact:
+                            raise DeleteNotExact(f'{method}: missing rdataset')
+                    else:
+                        self._delete_rdataset(name, rdclass, rdtype, covers)
+                    return
+                else:
+                    rdataset = self._rdataset_from_args(method, True, args)
+            elif isinstance(arg, dns.rrset.RRset):
+                rdataset = arg  # rrsets are also rdatasets
+                name = rdataset.name
+            else:
+                raise TypeError(f'{method} requires a name or RRset ' +
+                                'as the first argument')
+            self._raise_if_not_empty(method, args)
+            if rdataset:
+                existing = self._get_rdataset(name, rdataset.rdclass,
+                                              rdataset.rdtype, rdataset.covers)
+                if existing is not None:
+                    if exact:
+                        intersection = existing.intersection(rdataset)
+                        if intersection != rdataset:
+                            raise DeleteNotExact(f'{method}: missing rdatas')
+                    rdataset = existing.difference(rdataset)
+                    if len(rdataset) == 0:
+                        self._delete_rdataset(name, rdataset.rdclass,
+                                              rdataset.rdtype, rdataset.covers)
+                    else:
+                        self._put_rdataset(name, rdataset)
+                elif exact:
+                    raise DeleteNotExact(f'{method}: missing rdataset')
+            else:
+                if exact and not self._name_exists(name):
+                    raise DeleteNotExact(f'{method}: name not known')
+                self._delete_name(name)
+        except IndexError:
+            raise TypeError(f'not enough parameters to {method}')
+
+    #
+    # Transactions are context managers.
+    #
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            self._end_transaction(True)
+        else:
+            self._end_transaction(False)
+        return False
+
+    #
+    # This is the low level API, which must be implemented by subclasses
+    # of Transaction.
+    #
+
+    def _get_rdataset(self, name, rdclass, rdtype, covers):
+        """Return the rdataset associated with *name*, *rdclass*, *rdtype*,
+        and *covers*, or `None` if not found."""
+        raise NotImplementedError  # pragma: no cover
+
+    def _put_rdataset(self, name, rdataset):
+        """Store the rdataset."""
+        raise NotImplementedError  # pragma: no cover
+
+    def _delete_name(self, name):
+        """Delete all data associated with *name*.
+
+        It is not an error if the rdataset does not exist.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def _delete_rdataset(self, name, rdclass, rdtype, covers):
+        """Delete all data associated with *name*, *rdclass*, *rdtype*, and
+        *covers*.
+
+        It is not an error if the rdataset does not exist.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def _name_exists(self, name):
+        """Does name exist?
+
+        Returns a bool.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def _end_transaction(self, commit):
+        """End the transaction.
+
+        *commit*, a bool.  If ``True``, commit the transaction, otherwise
+        roll it back.
+
+        Raises an exception if committing failed.
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def _set_origin(self, origin):
+        """Set the origin.
+
+        This method is called when reading a possibly relativized
+        source, and an origin setting operation occurs (e.g. $ORIGIN
+        in a masterfile).
+        """
+        raise NotImplementedError  # pragma: no cover
+
+    def _iterate_rdatasets(self):
+        """Return an iterator that yields (name, rdataset) tuples.
+
+        Not all Transaction subclasses implement this.
+        """
+        raise NotImplementedError  # pragma: no cover
diff --git a/dns/versioned.py b/dns/versioned.py
new file mode 100644 (file)
index 0000000..6f911e1
--- /dev/null
@@ -0,0 +1,392 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""DNS Versioned Zones."""
+
+import collections
+try:
+    import threading as _threading
+except ImportError:  # pragma: no cover
+    import dummy_threading as _threading    # type: ignore
+
+import dns.exception
+import dns.immutable
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdata
+import dns.rdtypes.ANY.SOA
+import dns.transaction
+import dns.zone
+
+
+class UseTransaction(dns.exception.DNSException):
+    """To alter a versioned zone, use a transaction."""
+
+
+class Version:
+    def __init__(self, zone, id):
+        self.zone = zone
+        self.id = id
+        self.nodes = {}
+
+    def _validate_name(self, name):
+        if name.is_absolute():
+            if not name.is_subdomain(self.zone.origin):
+                raise KeyError("name is not a subdomain of the zone origin")
+            if self.zone.relativize:
+                name = name.relativize(self.origin)
+        return name
+
+    def get_node(self, name):
+        name = self._validate_name(name)
+        return self.nodes.get(name)
+
+    def get_rdataset(self, name, rdtype, covers):
+        node = self.get_node(name)
+        if node is None:
+            return None
+        return node.get_rdataset(self.zone.rdclass, rdtype, covers)
+
+    def items(self):
+        return self.nodes.items()  # pylint: disable=dict-items-not-iterating
+
+    def _print(self):  # pragma: no cover
+        # XXXRTH  This is for debugging
+        print('VERSION', self.id)
+        for (name, node) in self.nodes.items():
+            for rdataset in node:
+                print(rdataset.to_text(name))
+
+
+class WritableVersion(Version):
+    def __init__(self, zone, replacement=False):
+        if len(zone.versions) > 0:
+            id = zone.versions[-1].id + 1
+        else:
+            id = 1
+        super().__init__(zone, id)
+        if not replacement:
+            # We copy the map, because that gives us a simple and thread-safe
+            # way of doing versions, and we have a garbage collector to help
+            # us.  We only make new node objects if we actually change the
+            # node.
+            self.nodes.update(zone.nodes)
+        # We have to copy the zone origin as it may be None in the first
+        # version, and we don't want to mutate the zone until we commit.
+        self.origin = zone.origin
+        self.changed = set()
+
+    def _validate_name(self, name):
+        if name.is_absolute():
+            if not name.is_subdomain(self.origin):
+                raise KeyError("name is not a subdomain of the zone origin")
+            if self.zone.relativize:
+                name = name.relativize(self.origin)
+        return name
+
+    def _maybe_cow(self, name):
+        name = self._validate_name(name)
+        node = self.nodes.get(name)
+        if node is None or node.id != self.id:
+            new_node = self.zone.node_factory()
+            new_node.id = self.id
+            if node is not None:
+                # moo!  copy on write!
+                new_node.rdatasets.extend(node.rdatasets)
+            self.nodes[name] = new_node
+            self.changed.add(name)
+            return new_node
+        else:
+            return node
+
+    def delete_node(self, name):
+        name = self._validate_name(name)
+        if name in self.nodes:
+            del self.nodes[name]
+            return True
+        return False
+
+    def put_rdataset(self, name, rdataset):
+        node = self._maybe_cow(name)
+        node.replace_rdataset(rdataset)
+
+    def delete_rdataset(self, name, rdtype, covers):
+        node = self._maybe_cow(name)
+        if not node.get_rdataset(self.zone.rdclass, rdtype, covers):
+            return False
+        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+        if len(node) == 0:
+            del self.nodes[name]
+        return True
+
+
+@dns.immutable.immutable
+class ImmutableVersion(Version):
+    def __init__(self, version):
+        # We tell super() that it's a replacement as we don't want it
+        # to copy the nodes, as we're about to do that with an
+        # immutable Dict.
+        super().__init__(version.zone, True)
+        # set the right id!
+        self.id = version.id
+        # Make changed nodes immutable
+        for name in version.changed:
+            node = version.nodes.get(name)
+            # it might not exist if we deleted it in the version
+            if node:
+                version.nodes[name] = ImmutableNode(node)
+        self.nodes = dns.immutable.Dict(version.nodes, True)
+
+
+# A node with a version id.
+
+class Node(dns.node.Node):
+    __slots__ = ['id']
+
+    def __init__(self):
+        super().__init__()
+        # A proper id will get set by the Version
+        self.id = 0
+
+
+# It would be nice if this were a subclass of Node (just above) but it's
+# less code duplication this way as we inherit all of the method disabling
+# code.
+
+@dns.immutable.immutable
+class ImmutableNode(dns.node.ImmutableNode):
+    __slots__ = ['id']
+
+    def __init__(self, node):
+        super().__init__(node)
+        self.id = node.id
+
+
+class Zone(dns.zone.Zone):
+
+    __slots__ = ['versions', '_write_txn', '_write_waiters', '_write_event',
+                 '_pruning_policy']
+
+    node_factory = Node
+
+    def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
+                 pruning_policy=None):
+        """Initialize a versioned zone object.
+
+        *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
+        a ``str``, or ``None``.  If ``None``, then the zone's origin will
+        be set by the first ``$ORIGIN`` line in a masterfile.
+
+        *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
+
+        *relativize*, a ``bool``, determine's whether domain names are
+        relativized to the zone's origin.  The default is ``True``.
+
+        *pruning policy*, a function taking a `Version` and returning
+        a `bool`, or `None`.  Should the version be pruned?  If `None`,
+        the default policy, which retains one version is used.
+        """
+        super().__init__(origin, rdclass, relativize)
+        self.versions = collections.deque()
+        self.version_lock = _threading.Lock()
+        if pruning_policy is None:
+            self._pruning_policy = self._default_pruning_policy
+        else:
+            self._pruning_policy = pruning_policy
+        self._write_txn = None
+        self._write_event = None
+        self._write_waiters = collections.deque()
+        self._commit_version_unlocked(WritableVersion(self), origin)
+
+    def reader(self):
+        with self.version_lock:
+            return Transaction(False, self, self.versions[-1])
+
+    def writer(self, replacement=False):
+        event = None
+        while True:
+            with self.version_lock:
+                # Checking event == self._write_event ensures that either
+                # no one was waiting before we got lucky and found no write
+                # txn, or we were the one who was waiting and got woken up.
+                # This prevents "taking cuts" when creating a write txn.
+                if self._write_txn is None and event == self._write_event:
+                    # Creating the transaction defers version setup
+                    # (i.e.  copying the nodes dictionary) until we
+                    # give up the lock, so that we hold the lock as
+                    # short a time as possible.  This is why we call
+                    # _setup_version() below.
+                    self._write_txn = Transaction(replacement, self)
+                    # give up our exclusive right to make a Transaction
+                    self._write_event = None
+                    break
+                # Someone else is writing already, so we will have to
+                # wait, but we want to do the actual wait outside the
+                # lock.
+                event = _threading.Event()
+                self._write_waiters.append(event)
+            # wait (note we gave up the lock!)
+            #
+            # We only wake one sleeper at a time, so it's important
+            # that no event waiter can exit this method (e.g. via
+            # cancelation) without returning a transaction or waking
+            # someone else up.
+            #
+            # This is not a problem with Threading module threads as
+            # they cannot be canceled, but could be an issue with trio
+            # or curio tasks when we do the async version of writer().
+            # I.e. we'd need to do something like:
+            #
+            # try:
+            #     event.wait()
+            # except trio.Cancelled:
+            #     with self.version_lock:
+            #         self._maybe_wakeup_one_waiter_unlocked()
+            #     raise
+            #
+            event.wait()
+        # Do the deferred version setup.
+        self._write_txn._setup_version()
+        return self._write_txn
+
+    def _maybe_wakeup_one_waiter_unlocked(self):
+        if len(self._write_waiters) > 0:
+            self._write_event = self._write_waiters.popleft()
+            self._write_event.set()
+
+    # pylint: disable=unused-argument
+    def _default_pruning_policy(self, zone, version):
+        return True
+    # pylint: enable=unused-argument
+
+    def _prune_versions_unlocked(self):
+        while len(self.versions) > 1 and \
+              self._pruning_policy(self, self.versions[0]):
+            self.versions.popleft()
+
+    def set_max_versions(self, max_versions):
+        """Set a pruning policy that retains up to the specified number
+        of versions
+        """
+        if max_versions is not None and max_versions < 1:
+            raise ValueError('max versions must be at least 1')
+        if max_versions is None:
+            def policy(*_):
+                return False
+        else:
+            def policy(zone, _):
+                return len(zone.versions) > max_versions
+        self.set_pruning_policy(policy)
+
+    def set_pruning_policy(self, policy):
+        """Set the pruning policy for the zone.
+
+        The *policy* function takes a `Version` and returns `True` if
+        the version should be pruned, and `False` otherwise.  `None`
+        may also be specified for policy, in which case the default policy
+        is used.
+
+        Pruning checking proceeds from the least version and the first
+        time the function returns `False`, the checking stops.  I.e. the
+        retained versions are always a consecutive sequence.
+        """
+        if policy is None:
+            policy = self._default_pruning_policy
+        with self.version_lock:
+            self._pruning_policy = policy
+            self._prune_versions_unlocked()
+
+    def _commit_version_unlocked(self, version, origin):
+        self.versions.append(version)
+        self._prune_versions_unlocked()
+        self.nodes = version.nodes
+        if self.origin is None:
+            self.origin = origin
+        self._write_txn = None
+        self._maybe_wakeup_one_waiter_unlocked()
+
+    def _commit_version(self, version, origin):
+        with self.version_lock:
+            self._commit_version_unlocked(version, origin)
+
+    def find_node(self, name, create=False):
+        if create:
+            raise UseTransaction
+        return super().find_node(name)
+
+    def delete_node(self, name):
+        raise UseTransaction
+
+    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
+                      create=False):
+        if create:
+            raise UseTransaction
+        rdataset = super().find_rdataset(name, rdtype, covers)
+        return dns.rdataset.ImmutableRdataset(rdataset)
+
+    def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
+                     create=False):
+        if create:
+            raise UseTransaction
+        rdataset = super().get_rdataset(name, rdtype, covers)
+        return dns.rdataset.ImmutableRdataset(rdataset)
+
+    def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
+        raise UseTransaction
+
+    def replace_rdataset(self, name, replacement):
+        raise UseTransaction
+
+
+class Transaction(dns.transaction.Transaction):
+
+    def __init__(self, replacement, zone, version=None):
+        read_only = version is not None
+        super().__init__(replacement, read_only)
+        self.zone = zone
+        self.version = version
+
+    def _setup_version(self):
+        assert self.version is None
+        self.version = WritableVersion(self.zone, self.replacement)
+
+    def _get_rdataset(self, name, rdclass, rdtype, covers):
+        if rdclass != self.zone.rdclass:
+            raise ValueError(f'class {rdclass} != ' +
+                             f'zone class {self.zone.rdclass}')
+        return self.version.get_rdataset(name, rdtype, covers)
+
+    def _put_rdataset(self, name, rdataset):
+        assert not self.read_only
+        if rdataset.rdclass != self.zone.rdclass:
+            raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
+                             f'zone class {self.zone.rdclass}')
+        self.version.put_rdataset(name, rdataset)
+
+    def _delete_name(self, name):
+        assert not self.read_only
+        self.version.delete_node(name)
+
+    def _delete_rdataset(self, name, rdclass, rdtype, covers):
+        assert not self.read_only
+        self.version.delete_rdataset(name, rdtype, covers)
+
+    def _name_exists(self, name):
+        return self.version.get_node(name) is not None
+
+    def _end_transaction(self, commit):
+        if self.read_only:
+            return
+        if commit and len(self.version.changed) > 0:
+            self.zone._commit_version(ImmutableVersion(self.version),
+                                      self.version.origin)
+
+    def _set_origin(self, origin):
+        if self.version.origin is None:
+            self.version.origin = origin
+
+    def _iterate_rdatasets(self):
+        for (name, node) in self.version.items():
+            for rdataset in node:
+                yield (name, rdataset)
index d5bb30583f94d4f1e73725768d10edd3eaaa775a..2ca9bc212ffe6462de5ef098155333b8410df226 100644 (file)
 import contextlib
 import io
 import os
-import re
-import sys
 
 import dns.exception
+import dns.masterfile
 import dns.name
 import dns.node
 import dns.rdataclass
@@ -32,6 +31,7 @@ import dns.rdata
 import dns.rdtypes.ANY.SOA
 import dns.rrset
 import dns.tokenizer
+import dns.transaction
 import dns.ttl
 import dns.grange
 
@@ -56,7 +56,7 @@ class UnknownOrigin(BadZone):
     """The DNS zone's origin is unknown."""
 
 
-class Zone:
+class Zone(dns.transaction.TransactionManager):
 
     """A DNS zone.
 
@@ -642,415 +642,108 @@ class Zone:
         if self.get_rdataset(name, dns.rdatatype.NS) is None:
             raise NoNS
 
+    def reader(self):
+        return Transaction(False, True, self)
 
-class _MasterReader:
-
-    """Read a DNS master file
-
-    @ivar tok: The tokenizer
-    @type tok: dns.tokenizer.Tokenizer object
-    @ivar last_ttl: The last seen explicit TTL for an RR
-    @type last_ttl: int
-    @ivar last_ttl_known: Has last TTL been detected
-    @type last_ttl_known: bool
-    @ivar default_ttl: The default TTL from a $TTL directive or SOA RR
-    @type default_ttl: int
-    @ivar default_ttl_known: Has default TTL been detected
-    @type default_ttl_known: bool
-    @ivar last_name: The last name read
-    @type last_name: dns.name.Name object
-    @ivar current_origin: The current origin
-    @type current_origin: dns.name.Name object
-    @ivar relativize: should names in the zone be relativized?
-    @type relativize: bool
-    @ivar zone: the zone
-    @type zone: dns.zone.Zone object
-    @ivar saved_state: saved reader state (used when processing $INCLUDE)
-    @type saved_state: list of (tokenizer, current_origin, last_name, file,
-    last_ttl, last_ttl_known, default_ttl, default_ttl_known) tuples.
-    @ivar current_file: the file object of the $INCLUDed file being parsed
-    (None if no $INCLUDE is active).
-    @ivar allow_include: is $INCLUDE allowed?
-    @type allow_include: bool
-    @ivar check_origin: should sanity checks of the origin node be done?
-    The default is True.
-    @type check_origin: bool
-    """
+    def writer(self, replacement=False):
+        return Transaction(replacement, False, self)
 
-    def __init__(self, tok, origin, rdclass, relativize, zone_factory=Zone,
-                 allow_include=False, check_origin=True):
-        if isinstance(origin, str):
-            origin = dns.name.from_text(origin)
-        self.tok = tok
-        self.current_origin = origin
-        self.relativize = relativize
-        self.last_ttl = 0
-        self.last_ttl_known = False
-        self.default_ttl = 0
-        self.default_ttl_known = False
-        self.last_name = self.current_origin
-        self.zone = zone_factory(origin, rdclass, relativize=relativize)
-        self.saved_state = []
-        self.current_file = None
-        self.allow_include = allow_include
-        self.check_origin = check_origin
-
-    def _eat_line(self):
-        while 1:
-            token = self.tok.get()
-            if token.is_eol_or_eof():
-                break
-
-    def _rr_line(self):
-        """Process one line from a DNS master file."""
-        # Name
-        if self.current_origin is None:
-            raise UnknownOrigin
-        token = self.tok.get(want_leading=True)
-        if not token.is_whitespace():
-            self.last_name = self.tok.as_name(token, self.current_origin)
-        else:
-            token = self.tok.get()
-            if token.is_eol_or_eof():
-                # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
-                return
-            self.tok.unget(token)
-        name = self.last_name
-        if not name.is_subdomain(self.zone.origin):
-            self._eat_line()
-            return
-        if self.relativize:
-            name = name.relativize(self.zone.origin)
-        token = self.tok.get()
-        if not token.is_identifier():
-            raise dns.exception.SyntaxError
 
-        # TTL
-        ttl = None
-        try:
-            ttl = dns.ttl.from_text(token.value)
-            self.last_ttl = ttl
-            self.last_ttl_known = True
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except dns.ttl.BadTTL:
-            if self.default_ttl_known:
-                ttl = self.default_ttl
-            elif self.last_ttl_known:
-                ttl = self.last_ttl
-
-        # Class
-        try:
-            rdclass = dns.rdataclass.from_text(token.value)
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except dns.exception.SyntaxError:
-            raise
-        except Exception:
-            rdclass = self.zone.rdclass
-        if rdclass != self.zone.rdclass:
-            raise dns.exception.SyntaxError("RR class is not zone's class")
-        # Type
-        try:
-            rdtype = dns.rdatatype.from_text(token.value)
-        except Exception:
-            raise dns.exception.SyntaxError(
-                "unknown rdatatype '%s'" % token.value)
-        n = self.zone.nodes.get(name)
-        if n is None:
-            n = self.zone.node_factory()
-            self.zone.nodes[name] = n
-        try:
-            rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
-                                     self.current_origin, self.relativize,
-                                     self.zone.origin)
-        except dns.exception.SyntaxError:
-            # Catch and reraise.
-            raise
-        except Exception:
-            # All exceptions that occur in the processing of rdata
-            # are treated as syntax errors.  This is not strictly
-            # correct, but it is correct almost all of the time.
-            # We convert them to syntax errors so that we can emit
-            # helpful filename:line info.
-            (ty, va) = sys.exc_info()[:2]
-            raise dns.exception.SyntaxError(
-                "caught exception {}: {}".format(str(ty), str(va)))
-
-        if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
-            # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
-            # TTL from the SOA minttl if no $TTL statement is present before the
-            # SOA is parsed.
-            self.default_ttl = rd.minimum
-            self.default_ttl_known = True
-            if ttl is None:
-                # if we didn't have a TTL on the SOA, set it!
-                ttl = rd.minimum
-
-        # TTL check.  We had to wait until now to do this as the SOA RR's
-        # own TTL can be inferred from its minimum.
-        if ttl is None:
-            raise dns.exception.SyntaxError("Missing default TTL value")
-
-        covers = rd.covers()
-        rds = n.find_rdataset(rdclass, rdtype, covers, True)
-        rds.add(rd, ttl)
-
-    def _parse_modify(self, side):
-        # Here we catch everything in '{' '}' in a group so we can replace it
-        # with ''.
-        is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$")
-        is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$")
-        is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$")
-        # Sometimes there are modifiers in the hostname. These come after
-        # the dollar sign. They are in the form: ${offset[,width[,base]]}.
-        # Make names
-        g1 = is_generate1.match(side)
-        if g1:
-            mod, sign, offset, width, base = g1.groups()
-            if sign == '':
-                sign = '+'
-        g2 = is_generate2.match(side)
-        if g2:
-            mod, sign, offset = g2.groups()
-            if sign == '':
-                sign = '+'
-            width = 0
-            base = 'd'
-        g3 = is_generate3.match(side)
-        if g3:
-            mod, sign, offset, width = g3.groups()
-            if sign == '':
-                sign = '+'
-            base = 'd'
-
-        if not (g1 or g2 or g3):
-            mod = ''
-            sign = '+'
-            offset = 0
-            width = 0
-            base = 'd'
-
-        if base != 'd':
-            raise NotImplementedError()
-
-        return mod, sign, offset, width, base
-
-    def _generate_line(self):
-        # range lhs [ttl] [class] type rhs [ comment ]
-        """Process one line containing the GENERATE statement from a DNS
-        master file."""
-        if self.current_origin is None:
-            raise UnknownOrigin
-
-        token = self.tok.get()
-        # Range (required)
-        try:
-            start, stop, step = dns.grange.from_text(token.value)
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except Exception:
-            raise dns.exception.SyntaxError
-
-        # lhs (required)
-        try:
-            lhs = token.value
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except Exception:
-            raise dns.exception.SyntaxError
-
-        # TTL
-        try:
-            ttl = dns.ttl.from_text(token.value)
-            self.last_ttl = ttl
-            self.last_ttl_known = True
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except dns.ttl.BadTTL:
-            if not (self.last_ttl_known or self.default_ttl_known):
-                raise dns.exception.SyntaxError("Missing default TTL value")
-            if self.default_ttl_known:
-                ttl = self.default_ttl
-            elif self.last_ttl_known:
-                ttl = self.last_ttl
-        # Class
-        try:
-            rdclass = dns.rdataclass.from_text(token.value)
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except dns.exception.SyntaxError:
-            raise dns.exception.SyntaxError
-        except Exception:
-            rdclass = self.zone.rdclass
+class Transaction(dns.transaction.Transaction):
+
+    _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY,
+                                              dns.rdatatype.ANY)
+
+    def __init__(self, replacement, read_only, zone):
+        super().__init__(replacement, read_only)
+        self.zone = zone
+        self.rdatasets = {}
+
+    def _get_rdataset(self, name, rdclass, rdtype, covers):
         if rdclass != self.zone.rdclass:
-            raise dns.exception.SyntaxError("RR class is not zone's class")
-        # Type
-        try:
-            rdtype = dns.rdatatype.from_text(token.value)
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except Exception:
-            raise dns.exception.SyntaxError("unknown rdatatype '%s'" %
-                                            token.value)
-
-        # rhs (required)
-        rhs = token.value
-
-        # The code currently only supports base 'd', so the last value
-        # in the tuple _parse_modify returns is ignored
-        lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs)
-        rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs)
-        for i in range(start, stop + 1, step):
-            # +1 because bind is inclusive and python is exclusive
-
-            if lsign == '+':
-                lindex = i + int(loffset)
-            elif lsign == '-':
-                lindex = i - int(loffset)
-
-            if rsign == '-':
-                rindex = i - int(roffset)
-            elif rsign == '+':
-                rindex = i + int(roffset)
-
-            lzfindex = str(lindex).zfill(int(lwidth))
-            rzfindex = str(rindex).zfill(int(rwidth))
-
-            name = lhs.replace('$%s' % (lmod), lzfindex)
-            rdata = rhs.replace('$%s' % (rmod), rzfindex)
-
-            self.last_name = dns.name.from_text(name, self.current_origin,
-                                                self.tok.idna_codec)
-            name = self.last_name
-            if not name.is_subdomain(self.zone.origin):
-                self._eat_line()
-                return
-            if self.relativize:
-                name = name.relativize(self.zone.origin)
-
-            n = self.zone.nodes.get(name)
-            if n is None:
-                n = self.zone.node_factory()
-                self.zone.nodes[name] = n
-            try:
-                rd = dns.rdata.from_text(rdclass, rdtype, rdata,
-                                         self.current_origin, self.relativize,
-                                         self.zone.origin)
-            except dns.exception.SyntaxError:
-                # Catch and reraise.
-                raise
-            except Exception:
-                # All exceptions that occur in the processing of rdata
-                # are treated as syntax errors.  This is not strictly
-                # correct, but it is correct almost all of the time.
-                # We convert them to syntax errors so that we can emit
-                # helpful filename:line info.
-                (ty, va) = sys.exc_info()[:2]
-                raise dns.exception.SyntaxError("caught exception %s: %s" %
-                                                (str(ty), str(va)))
-
-            covers = rd.covers()
-            rds = n.find_rdataset(rdclass, rdtype, covers, True)
-            rds.add(rd, ttl)
-
-    def read(self):
-        """Read a DNS master file and build a zone object.
-
-        @raises dns.zone.NoSOA: No SOA RR was found at the zone origin
-        @raises dns.zone.NoNS: No NS RRset was found at the zone origin
-        """
+            raise ValueError(f'class {rdclass} != ' +
+                             f'zone class {self.zone.rdclass}')
+        rdataset = self.rdatasets.get((name, rdtype, covers))
+        if rdataset is self._deleted_rdataset:
+            return None
+        elif rdataset is None:
+            rdataset = self.zone.get_rdataset(name, rdtype, covers)
+        return rdataset
 
+    def _put_rdataset(self, name, rdataset):
+        assert not self.read_only
+        self.zone._validate_name(name)
+        if rdataset.rdclass != self.zone.rdclass:
+            raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
+                             f'zone class {self.zone.rdclass}')
+        self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+
+    def _delete_name(self, name):
+        assert not self.read_only
+        # First remove any changes involving the name
+        remove = []
+        for key in self.rdatasets:
+            if key[0] == name:
+                remove.append(key)
+        if len(remove) > 0:
+            for key in remove:
+                del self.rdatasets[key]
+        # Next add deletion records for any rdatasets matching the
+        # name in the zone
+        node = self.zone.get_node(name)
+        if node is not None:
+            for rdataset in node.rdatasets:
+                self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
+                    self._deleted_rdataset
+
+    def _delete_rdataset(self, name, rdclass, rdtype, covers):
+        assert not self.read_only
+        # The high-level code always does a _get_rdataset() before any
+        # situation where it would call _delete_rdataset(), so we don't
+        # need to check if rdclass != self.zone.rdclass.
         try:
-            while 1:
-                token = self.tok.get(True, True)
-                if token.is_eof():
-                    if self.current_file is not None:
-                        self.current_file.close()
-                    if len(self.saved_state) > 0:
-                        (self.tok,
-                         self.current_origin,
-                         self.last_name,
-                         self.current_file,
-                         self.last_ttl,
-                         self.last_ttl_known,
-                         self.default_ttl,
-                         self.default_ttl_known) = self.saved_state.pop(-1)
-                        continue
-                    break
-                elif token.is_eol():
-                    continue
-                elif token.is_comment():
-                    self.tok.get_eol()
-                    continue
-                elif token.value[0] == '$':
-                    c = token.value.upper()
-                    if c == '$TTL':
-                        token = self.tok.get()
-                        if not token.is_identifier():
-                            raise dns.exception.SyntaxError("bad $TTL")
-                        self.default_ttl = dns.ttl.from_text(token.value)
-                        self.default_ttl_known = True
-                        self.tok.get_eol()
-                    elif c == '$ORIGIN':
-                        self.current_origin = self.tok.get_name()
-                        self.tok.get_eol()
-                        if self.zone.origin is None:
-                            self.zone.origin = self.current_origin
-                    elif c == '$INCLUDE' and self.allow_include:
-                        token = self.tok.get()
-                        filename = token.value
-                        token = self.tok.get()
-                        if token.is_identifier():
-                            new_origin =\
-                                dns.name.from_text(token.value,
-                                                   self.current_origin,
-                                                   self.tok.idna_codec)
-                            self.tok.get_eol()
-                        elif not token.is_eol_or_eof():
-                            raise dns.exception.SyntaxError(
-                                "bad origin in $INCLUDE")
-                        else:
-                            new_origin = self.current_origin
-                        self.saved_state.append((self.tok,
-                                                 self.current_origin,
-                                                 self.last_name,
-                                                 self.current_file,
-                                                 self.last_ttl,
-                                                 self.last_ttl_known,
-                                                 self.default_ttl,
-                                                 self.default_ttl_known))
-                        self.current_file = open(filename, 'r')
-                        self.tok = dns.tokenizer.Tokenizer(self.current_file,
-                                                           filename)
-                        self.current_origin = new_origin
-                    elif c == '$GENERATE':
-                        self._generate_line()
-                    else:
-                        raise dns.exception.SyntaxError(
-                            "Unknown master file directive '" + c + "'")
-                    continue
-                self.tok.unget(token)
-                self._rr_line()
-        except dns.exception.SyntaxError as detail:
-            (filename, line_number) = self.tok.where()
-            if detail is None:
-                detail = "syntax error"
-            ex = dns.exception.SyntaxError(
-                "%s:%d: %s" % (filename, line_number, detail))
-            tb = sys.exc_info()[2]
-            raise ex.with_traceback(tb) from None
-
-        # Now that we're done reading, do some basic checking of the zone.
-        if self.check_origin:
-            self.zone.check_origin()
+            del self.rdatasets[(name, rdtype, covers)]
+        except KeyError:
+            pass
+        rdataset = self.zone.get_rdataset(name, rdtype, covers)
+        if rdataset is not None:
+            self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
+                self._deleted_rdataset
+
+    def _name_exists(self, name):
+        for key, rdataset in self.rdatasets.items():
+            if key[0] == name:
+                if rdataset != self._deleted_rdataset:
+                    return True
+                else:
+                    return None
+        self.zone._validate_name(name)
+        if self.zone.get_node(name):
+            return True
+        return False
+
+    def _end_transaction(self, commit):
+        if commit and not self.read_only:
+            for (name, rdtype, covers), rdataset in \
+                self.rdatasets.items():
+                if rdataset is self._deleted_rdataset:
+                    self.zone.delete_rdataset(name, rdtype, covers)
+                else:
+                    self.zone.replace_rdataset(name, rdataset)
+
+    def _set_origin(self, origin):
+        if self.zone.origin is None:
+            self.zone.origin = origin
+
+    def _iterate_rdatasets(self):
+        # Expensive but simple!  Use a versioned zone for efficient txn
+        # iteration.
+        rdatasets = {}
+        for (name, rdataset) in self.zone.iterate_rdatasets():
+            rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+        rdatasets.update(self.rdatasets)
+        for (name, _, _), rdataset in rdatasets.items():
+            yield (name, rdataset)
 
 
 def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
@@ -1103,12 +796,20 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
 
     if filename is None:
         filename = '<string>'
-    tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
-    reader = _MasterReader(tok, origin, rdclass, relativize, zone_factory,
-                           allow_include=allow_include,
-                           check_origin=check_origin)
-    reader.read()
-    return reader.zone
+    zone = zone_factory(origin, rdclass, relativize=relativize)
+    with zone.writer(True) as txn:
+        tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
+        reader = dns.masterfile.Reader(tok, origin, rdclass, relativize, txn,
+                                       allow_include=allow_include)
+        try:
+            reader.read()
+        except dns.masterfile.UnknownOrigin:
+            # for backwards compatibility
+            raise dns.zone.UnknownOrigin
+    # Now that we're done reading, do some basic checking of the zone.
+    if check_origin:
+        zone.check_origin()
+    return zone
 
 
 def from_file(f, origin=None, rdclass=dns.rdataclass.IN,
index 0385fc9175eced3814a922107351d581e2b8844f..1a70e3d41147d895ffecd71aaef4cac25b4755b0 100644 (file)
@@ -3,20 +3,30 @@
 import unittest
 
 import dns.immutable
+import dns._immutable_attr
+
+try:
+    import dns._immutable_ctx as immutable_ctx
+    _have_contextvars = True
+except ImportError:
+    _have_contextvars = False
+
+    class immutable_ctx:
+        pass
 
 
 class ImmutableTestCase(unittest.TestCase):
 
-    def test_ImmutableDict_hash(self):
-        d1 = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
-        d2 = dns.immutable.ImmutableDict({'b': 2, 'a': 1})
+    def test_immutable_dict_hash(self):
+        d1 = dns.immutable.Dict({'a': 1, 'b': 2})
+        d2 = dns.immutable.Dict({'b': 2, 'a': 1})
         d3 = {'b': 2, 'a': 1}
         self.assertEqual(d1, d2)
         self.assertEqual(d2, d3)
         self.assertEqual(hash(d1), hash(d2))
 
-    def test_ImmutableDict_hash_cache(self):
-        d = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
+    def test_immutable_dict_hash_cache(self):
+        d = dns.immutable.Dict({'a': 1, 'b': 2})
         self.assertEqual(d._hash, None)
         h1 = hash(d)
         self.assertEqual(d._hash, h1)
@@ -30,11 +40,121 @@ class ImmutableTestCase(unittest.TestCase):
             ((1, [2], 3), (1, (2,), 3)),
             ([1, 2, 3], (1, 2, 3)),
             ([1, {'a': [1, 2]}],
-             (1, dns.immutable.ImmutableDict({'a': (1, 2)}))),
+             (1, dns.immutable.Dict({'a': (1, 2)}))),
             ('hi', 'hi'),
             (b'hi', b'hi'),
         )
         for input, expected in items:
             self.assertEqual(dns.immutable.constify(input), expected)
         self.assertIsInstance(dns.immutable.constify({'a': 1}),
-                              dns.immutable.ImmutableDict)
+                              dns.immutable.Dict)
+
+
+class DecoratorTestCase(unittest.TestCase):
+
+    immutable_module = dns._immutable_attr
+
+    def make_classes(self):
+        class A:
+            def __init__(self, a, akw=10):
+                self.a = a
+                self.akw = akw
+
+        class B(A):
+            def __init__(self, a, b):
+                super().__init__(a, akw=20)
+                self.b = b
+        B = self.immutable_module.immutable(B)
+
+        # note C is immutable by inheritance
+        class C(B):
+            def __init__(self, a, b, c):
+                super().__init__(a, b)
+                self.c = c
+        C = self.immutable_module.immutable(C)
+
+        class SA:
+            __slots__ = ('a', 'akw')
+            def __init__(self, a, akw=10):
+                self.a = a
+                self.akw = akw
+
+        class SB(A):
+            __slots__ = ('b')
+            def __init__(self, a, b):
+                super().__init__(a, akw=20)
+                self.b = b
+        SB = self.immutable_module.immutable(SB)
+
+        # note SC is immutable by inheritance and has no slots of its own
+        class SC(SB):
+            def __init__(self, a, b, c):
+                super().__init__(a, b)
+                self.c = c
+        SC = self.immutable_module.immutable(SC)
+
+        return ((A, B, C), (SA, SB, SC))
+
+    def test_basic(self):
+        for A, B, C in self.make_classes():
+            a = A(1)
+            self.assertEqual(a.a, 1)
+            self.assertEqual(a.akw, 10)
+            b = B(11, 21)
+            self.assertEqual(b.a, 11)
+            self.assertEqual(b.akw, 20)
+            self.assertEqual(b.b, 21)
+            c = C(111, 211, 311)
+            self.assertEqual(c.a, 111)
+            self.assertEqual(c.akw, 20)
+            self.assertEqual(c.b, 211)
+            self.assertEqual(c.c, 311)
+            # changing A is ok!
+            a.a = 11
+            self.assertEqual(a.a, 11)
+            # changing B is not!
+            with self.assertRaises(TypeError):
+                b.a = 11
+            with self.assertRaises(TypeError):
+                del b.a
+
+    def test_constructor_deletes_attribute(self):
+        class A:
+            def __init__(self, a):
+                self.a = a
+                self.b = a
+                del self.b
+        A = self.immutable_module.immutable(A)
+        a = A(10)
+        self.assertEqual(a.a, 10)
+        self.assertFalse(hasattr(a, 'b'))
+
+    def test_no_collateral_damage(self):
+
+        # A and B are immutable but not related.  The magic that lets
+        # us write to immutable things while initializing B should not let
+        # B mess with A.
+
+        class A:
+            def __init__(self, a):
+                self.a = a
+        A = self.immutable_module.immutable(A)
+
+        class B:
+            def __init__(self, a, b):
+                self.b = a.a + b
+                # rudely attempt to mutate innocent immutable bystander 'a'
+                a.a = 1000
+        B = self.immutable_module.immutable(B)
+
+        a = A(10)
+        self.assertEqual(a.a, 10)
+        with self.assertRaises(TypeError):
+            B(a, 20)
+        self.assertEqual(a.a, 10)
+
+
+@unittest.skipIf(not _have_contextvars, "contextvars not available")
+class CtxDecoratorTestCase(DecoratorTestCase):
+
+    immutable_module = immutable_ctx
index a80d65042d0f50901d88930720cdf056308727d5..88b48400bd2cf19a571c83d2cc133330568ba76f 100644 (file)
@@ -122,5 +122,34 @@ class RdatasetTestCase(unittest.TestCase):
             '<DNS IN RRSIG(NSEC) rdataset:'))
 
 
+class ImmutableRdatasetTestCase(unittest.TestCase):
+
+    def test_basic(self):
+        rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2')
+        rd = dns.rdata.from_text('in', 'a', '10.0.0.3')
+        irds = dns.rdataset.ImmutableRdataset(rds)
+        with self.assertRaises(TypeError):
+            irds.update_ttl(100)
+        with self.assertRaises(TypeError):
+            irds.add(rd, 300)
+        with self.assertRaises(TypeError):
+            irds.union_update(rds)
+        with self.assertRaises(TypeError):
+            irds.intersection_update(rds)
+        with self.assertRaises(TypeError):
+            irds.update(rds)
+        with self.assertRaises(TypeError):
+            irds += rds
+        with self.assertRaises(TypeError):
+            irds -= rds
+        with self.assertRaises(TypeError):
+            irds &= rds
+        with self.assertRaises(TypeError):
+            irds |= rds
+        with self.assertRaises(TypeError):
+            del irds[0]
+        with self.assertRaises(TypeError):
+            irds.clear()
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
new file mode 100644 (file)
index 0000000..ed154fc
--- /dev/null
@@ -0,0 +1,451 @@
+import time
+
+import pytest
+
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdataset
+import dns.rrset
+import dns.transaction
+import dns.versioned
+import dns.zone
+
+class DB(dns.transaction.TransactionManager):
+    def __init__(self):
+        self.rdatasets = {}
+
+    def reader(self):
+        return Transaction(False, True, self)
+
+    def writer(self, replacement=False):
+        return Transaction(replacement, False, self)
+
+
+class Transaction(dns.transaction.Transaction):
+    def __init__(self, replacement, read_only, db):
+        super().__init__(replacement)
+        self.db = db
+        self.rdatasets = {}
+        self.read_only = read_only
+        if not replacement:
+            self.rdatasets.update(db.rdatasets)
+
+    def _get_rdataset(self, name, rdclass, rdtype, covers):
+        return self.rdatasets.get((name, rdclass, rdtype, covers))
+
+    def _put_rdataset(self, name, rdataset):
+        self.rdatasets[(name, rdataset.rdclass, rdataset.rdtype,
+                        rdataset.covers)] = rdataset
+
+    def _delete_name(self, name):
+        remove = []
+        for key in self.rdatasets.keys():
+            if key[0] == name:
+                remove.append(key)
+        if len(remove) > 0:
+            for key in remove:
+                del self.rdatasets[key]
+
+    def _delete_rdataset(self, name, rdclass, rdtype, covers):
+        del self.rdatasets[(name, rdclass, rdtype, covers)]
+
+    def _name_exists(self, name):
+        for key in self.rdatasets.keys():
+            if key[0] == name:
+                return True
+        return False
+
+    def _end_transaction(self, commit):
+        if commit:
+            self.db.rdatasets = self.rdatasets
+
+    def _set_origin(self, origin):
+        pass
+
+@pytest.fixture
+def db():
+    db = DB()
+    rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
+    db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] = rrset
+    return db
+
+def test_basic(db):
+    # successful txn
+    with db.writer() as txn:
+        rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                    '10.0.0.1', '10.0.0.2')
+        txn.add(rrset)
+        assert txn.name_exists(rrset.name)
+    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+        rrset
+    # rollback
+    with pytest.raises(Exception):
+        with db.writer() as txn:
+            rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                         '10.0.0.3', '10.0.0.4')
+            txn.add(rrset2)
+            raise Exception()
+    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+        rrset
+    with db.writer() as txn:
+        txn.delete(rrset.name)
+    assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+        is None
+
+def test_get(db):
+    with db.writer() as txn:
+        content = dns.name.from_text('content', None)
+        rdataset = txn.get(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+        assert rdataset is not None
+        assert rdataset[0].strings == (b'content',)
+        assert isinstance(rdataset, dns.rdataset.ImmutableRdataset)
+
+def test_add(db):
+    with db.writer() as txn:
+        rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                    '10.0.0.1', '10.0.0.2')
+        txn.add(rrset)
+        rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                     '10.0.0.3', '10.0.0.4')
+        txn.add(rrset2)
+    expected = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                   '10.0.0.1', '10.0.0.2',
+                                   '10.0.0.3', '10.0.0.4')
+    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+        expected
+
+def test_replacement(db):
+    with db.writer() as txn:
+        rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                    '10.0.0.1', '10.0.0.2')
+        txn.add(rrset)
+        rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+                                     '10.0.0.3', '10.0.0.4')
+        txn.replace(rrset2)
+    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+        rrset2
+
+def test_delete(db):
+    with db.writer() as txn:
+        txn.delete(dns.name.from_text('nonexistent', None))
+        content = dns.name.from_text('content', None)
+        content2 = dns.name.from_text('content2', None)
+        txn.delete(content)
+        assert not txn.name_exists(content)
+        txn.delete(content2, dns.rdataclass.IN, dns.rdatatype.TXT)
+        rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content')
+        txn.add(rrset)
+        assert txn.name_exists(content)
+        txn.delete(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+        assert not txn.name_exists(content)
+        rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content')
+        txn.delete(rrset)
+    content_keys = [k for k in db.rdatasets if k[0] == content]
+    assert len(content_keys) == 0
+
+def test_delete_exact(db):
+    with db.writer() as txn:
+        rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'bad-content')
+        with pytest.raises(dns.transaction.DeleteNotExact):
+            txn.delete_exact(rrset)
+        rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'bad-content')
+        with pytest.raises(dns.transaction.DeleteNotExact):
+            txn.delete_exact(rrset)
+        with pytest.raises(dns.transaction.DeleteNotExact):
+            txn.delete_exact(rrset.name)
+        with pytest.raises(dns.transaction.DeleteNotExact):
+            txn.delete_exact(rrset.name, dns.rdataclass.IN, dns.rdatatype.TXT)
+        rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
+        txn.delete_exact(rrset)
+    assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+        is None
+
+def test_parameter_forms(db):
+    with db.writer() as txn:
+        foo = dns.name.from_text('foo', None)
+        rdataset = dns.rdataset.from_text('in', 'a', 300,
+                                          '10.0.0.1', '10.0.0.2')
+        rdata1 = dns.rdata.from_text('in', 'a', '10.0.0.3')
+        rdata2 = dns.rdata.from_text('in', 'a', '10.0.0.4')
+        txn.add(foo, rdataset)
+        txn.add(foo, 100, rdata1)
+        txn.add(foo, 30, rdata2)
+    expected = dns.rrset.from_text('foo', 30, 'in', 'a',
+                                   '10.0.0.1', '10.0.0.2',
+                                   '10.0.0.3', '10.0.0.4')
+    assert db.rdatasets[(foo, rdataset.rdclass, rdataset.rdtype, 0)] == \
+        expected
+    with db.writer() as txn:
+        txn.delete(foo, rdataset)
+        txn.delete(foo, rdata1)
+        txn.delete(foo, rdata2)
+    assert db.rdatasets.get((foo, rdataset.rdclass, rdataset.rdtype, 0)) \
+        is None
+
+def test_bad_parameters(db):
+    with db.writer() as txn:
+        with pytest.raises(TypeError):
+            txn.add(1)
+        with pytest.raises(TypeError):
+            rrset = dns.rrset.from_text('bar', 300, 'in', 'txt', 'bar')
+            txn.add(rrset, 1)
+        with pytest.raises(ValueError):
+            foo = dns.name.from_text('foo', None)
+            rdata = dns.rdata.from_text('in', 'a', '10.0.0.3')
+            txn.add(foo, 0x80000000, rdata)
+        with pytest.raises(TypeError):
+            txn.add(foo)
+        with pytest.raises(TypeError):
+            txn.add()
+        with pytest.raises(TypeError):
+            txn.add(foo, 300)
+        with pytest.raises(TypeError):
+            txn.add(foo, 300, 'hi')
+        with pytest.raises(TypeError):
+            txn.add(foo, 'hi')
+        with pytest.raises(TypeError):
+            txn.delete()
+        with pytest.raises(TypeError):
+            txn.delete(1)
+
+example_text = """$TTL 3600
+$ORIGIN example.
+@ soa foo bar 1 2 3 4 5
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+$TTL 300
+$ORIGIN foo.example.
+bar mx 0 blaz
+"""
+
+example_text_output = """@ 3600 IN SOA foo bar 1 2 3 4 5
+@ 3600 IN NS ns1
+@ 3600 IN NS ns2
+@ 3600 IN NS ns3
+ns1 3600 IN A 10.0.0.1
+ns2 3600 IN A 10.0.0.2
+ns3 3600 IN A 10.0.0.3
+"""
+
+@pytest.fixture(params=[dns.zone.Zone, dns.versioned.Zone])
+def zone(request):
+    return dns.zone.from_text(example_text, zone_factory=request.param)
+
+def test_zone_basic(zone):
+    with zone.writer() as txn:
+        txn.delete(dns.name.from_text('bar.foo', None))
+        rd = dns.rdata.from_text('in', 'ns', 'ns3')
+        txn.add(dns.name.empty, 3600, rd)
+        rd = dns.rdata.from_text('in', 'a', '10.0.0.3')
+        txn.add(dns.name.from_text('ns3', None), 3600, rd)
+    output = zone.to_text()
+    assert output == example_text_output
+
+def test_zone_base_layer(zone):
+    with zone.writer() as txn:
+        # Get a set from the zone layer
+        rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+                           dns.rdatatype.NS, dns.rdatatype.NONE)
+        expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
+        assert rdataset == expected
+
+def test_zone_transaction_layer(zone):
+    with zone.writer() as txn:
+        # Make a change
+        rd = dns.rdata.from_text('in', 'ns', 'ns3')
+        txn.add(dns.name.empty, 3600, rd)
+        # Get a set from the transaction layer
+        expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3')
+        rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+                           dns.rdatatype.NS, dns.rdatatype.NONE)
+        assert rdataset == expected
+        assert txn.name_exists(dns.name.empty)
+        ns1 = dns.name.from_text('ns1', None)
+        assert txn.name_exists(ns1)
+        ns99 = dns.name.from_text('ns99', None)
+        assert not txn.name_exists(ns99)
+
+def test_zone_add_and_delete(zone):
+    with zone.writer() as txn:
+        a99 = dns.name.from_text('a99', None)
+        a100 = dns.name.from_text('a100', None)
+        a101 = dns.name.from_text('a101', None)
+        rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+        txn.add(a99, rds)
+        txn.delete(a99, dns.rdataclass.IN, dns.rdatatype.A)
+        txn.delete(a100, dns.rdataclass.IN, dns.rdatatype.A)
+        txn.delete(a101)
+        assert not txn.name_exists(a99)
+        assert not txn.name_exists(a100)
+        assert not txn.name_exists(a101)
+        ns1 = dns.name.from_text('ns1', None)
+        txn.delete(ns1, dns.rdataclass.IN, dns.rdatatype.A)
+        assert not txn.name_exists(ns1)
+    with zone.writer() as txn:
+        txn.add(a99, rds)
+        txn.delete(a99)
+        assert not txn.name_exists(a99)
+    with zone.writer() as txn:
+        txn.add(a100, rds)
+        txn.delete(a99)
+        assert not txn.name_exists(a99)
+        assert txn.name_exists(a100)
+
+def test_zone_get_deleted(zone):
+    with zone.writer() as txn:
+        print(zone.to_text())
+        ns1 = dns.name.from_text('ns1', None)
+        assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is not None
+        txn.delete(ns1)
+        assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is None
+        ns2 = dns.name.from_text('ns2', None)
+        txn.delete(ns2, dns.rdataclass.IN, dns.rdatatype.A)
+        assert txn.get(ns2, dns.rdataclass.IN, dns.rdatatype.A) is None
+
+def test_zone_bad_class(zone):
+    with zone.writer() as txn:
+        with pytest.raises(ValueError):
+            txn.get(dns.name.empty, dns.rdataclass.CH,
+                    dns.rdatatype.NS, dns.rdatatype.NONE)
+        rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2')
+        with pytest.raises(ValueError):
+            txn.add(dns.name.empty, rds)
+        with pytest.raises(ValueError):
+            txn.replace(dns.name.empty, rds)
+        with pytest.raises(ValueError):
+            txn.delete(dns.name.empty, rds)
+        with pytest.raises(ValueError):
+            txn.delete(dns.name.empty, dns.rdataclass.CH,
+                       dns.rdatatype.NS, dns.rdatatype.NONE)
+
+def test_set_serial(zone):
+    # basic
+    with zone.writer() as txn:
+        txn.set_serial()
+    rdataset = zone.find_rdataset('@', 'soa')
+    assert rdataset[0].serial == 2
+    # max
+    with zone.writer() as txn:
+        txn.set_serial(0, 0xffffffff)
+    rdataset = zone.find_rdataset('@', 'soa')
+    assert rdataset[0].serial == 0xffffffff
+    # wraparound to 1
+    with zone.writer() as txn:
+        txn.set_serial()
+    rdataset = zone.find_rdataset('@', 'soa')
+    assert rdataset[0].serial == 1
+    # trying to set to zero sets to 1
+    with zone.writer() as txn:
+        txn.set_serial(0, 0)
+    rdataset = zone.find_rdataset('@', 'soa')
+    assert rdataset[0].serial == 1
+    with pytest.raises(KeyError):
+        with zone.writer() as txn:
+            txn.set_serial(name=dns.name.from_text('unknown', None))
+
+class ExpectedException(Exception):
+    pass
+
+def test_zone_rollback(zone):
+    try:
+        with zone.writer() as txn:
+            a99 = dns.name.from_text('a99.example.')
+            rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+            txn.add(a99, rds)
+            assert txn.name_exists(a99)
+            raise ExpectedException
+    except ExpectedException:
+        pass
+    assert not zone.get_node(a99)
+
+def test_zone_ooz_name(zone):
+    with zone.writer() as txn:
+        with pytest.raises(KeyError):
+            a99 = dns.name.from_text('a99.not-example.')
+            assert txn.name_exists(a99)
+
+def test_zone_iteration(zone):
+    expected = {}
+    for (name, rdataset) in zone.iterate_rdatasets():
+        expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+    with zone.writer() as txn:
+        actual = {}
+        for (name, rdataset) in txn:
+            actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+    assert actual == expected
+
+@pytest.fixture
+def vzone():
+    return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone)
+
+def test_vzone_read_only(vzone):
+    with vzone.reader() as txn:
+        rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+                           dns.rdatatype.NS, dns.rdatatype.NONE)
+        expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
+        assert rdataset == expected
+        with pytest.raises(dns.transaction.ReadOnly):
+            txn.replace(dns.name.empty, expected)
+
+def test_vzone_multiple_versions(vzone):
+    assert len(vzone.versions) == 1
+    vzone.set_max_versions(None)  # unlimited!
+    with vzone.writer() as txn:
+        txn.set_serial()
+    with vzone.writer() as txn:
+        txn.set_serial()
+    with vzone.writer() as txn:
+        txn.set_serial()
+    rdataset = vzone.find_rdataset('@', 'soa')
+    assert rdataset[0].serial == 4
+    assert len(vzone.versions) == 4
+    vzone.set_max_versions(2)
+    assert len(vzone.versions) == 2
+    # The ones that survived should be 3 and 4
+    rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+                                              dns.rdatatype.NONE)
+    assert rdataset[0].serial == 3
+    rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+                                              dns.rdatatype.NONE)
+    assert rdataset[0].serial == 4
+    with pytest.raises(ValueError):
+        vzone.set_max_versions(0)
+
+try:
+    import threading
+
+    one_got_lock = threading.Event()
+
+    def run_one(zone):
+        with zone.writer() as txn:
+            one_got_lock.set()
+            # wait until two blocks
+            while len(zone._write_waiters) == 0:
+                time.sleep(0.01)
+            rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.98')
+            txn.add('a98', rds)
+
+    def run_two(zone):
+        # wait until one has the lock so we know we will block if we
+        # get the call done before the sleep in one completes
+        one_got_lock.wait()
+        with zone.writer() as txn:
+            rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+            txn.add('a99', rds)
+
+    def test_vzone_concurrency(vzone):
+        t1 = threading.Thread(target=run_one, args=(vzone,))
+        t1.start()
+        t2 = threading.Thread(target=run_two, args=(vzone,))
+        t2.start()
+        t1.join()
+        t2.join()
+        with vzone.reader() as txn:
+            assert txn.name_exists('a98')
+            assert txn.name_exists('a99')
+
+except ImportError:  # pragma: no cover
+    pass