]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Btree Zones (#1215)
authorBob Halley <halley@dnspython.org>
Sun, 10 Aug 2025 21:16:34 +0000 (14:16 -0700)
committerGitHub <noreply@github.com>
Sun, 10 Aug 2025 21:16:34 +0000 (14:16 -0700)
* Add BTree zone

dns/btree.py [new file with mode: 0644]
dns/btreezone.py [new file with mode: 0644]
dns/versioned.py
dns/zone.py
tests/test_btree.py [new file with mode: 0644]
tests/test_btreezone.py [new file with mode: 0644]
tests/test_zone.py

diff --git a/dns/btree.py b/dns/btree.py
new file mode 100644 (file)
index 0000000..f544edd
--- /dev/null
@@ -0,0 +1,850 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""
+A BTree in the style of Cormen, Leiserson, and Rivest's "Algorithms" book, with
+copy-on-write node updates, cursors, and optional space optimization for mostly-in-order
+insertion.
+"""
+
+from collections.abc import MutableMapping, MutableSet
+from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, cast
+
+DEFAULT_T = 127
+
+KT = TypeVar("KT")  # the type of a key in Element
+
+
+class Element(Generic[KT]):
+    """All items stored in the BTree are Elements."""
+
+    def key(self) -> KT:
+        """The key for this element; the returned type must implement comparison."""
+        raise NotImplementedError  # pragma: no cover
+
+
+ET = TypeVar("ET", bound=Element)  # the type of a value in a _KV
+
+
+def _MIN(t: int) -> int:
+    """The minimum number of keys in a non-root node for a BTree with the specified
+    ``t``
+    """
+    return t - 1
+
+
+def _MAX(t: int) -> int:
+    """The maximum number of keys in node for a BTree with the specified ``t``"""
+    return 2 * t - 1
+
+
+class _Creator:
+    """A _Creator class instance is used as a unique id for the BTree which created
+    a node.
+
+    We use a dedicated creator rather than just a BTree reference to avoid circularity
+    that would complicate GC.
+    """
+
+    def __str__(self):  # pragma: no cover
+        return f"{id(self):x}"
+
+
+class _Node(Generic[KT, ET]):
+    """A Node in the BTree.
+
+    A Node (leaf or internal) of the BTree.
+    """
+
+    __slots__ = ["t", "creator", "is_leaf", "elts", "children"]
+
+    def __init__(self, t: int, creator: _Creator, is_leaf: bool):
+        assert t >= 3
+        self.t = t
+        self.creator = creator
+        self.is_leaf = is_leaf
+        self.elts: list[ET] = []
+        self.children: list[_Node[KT, ET]] = []
+
+    def is_maximal(self) -> bool:
+        """Does this node have the maximal number of keys?"""
+        assert len(self.elts) <= _MAX(self.t)
+        return len(self.elts) == _MAX(self.t)
+
+    def is_minimal(self) -> bool:
+        """Does this node have the minimal number of keys?"""
+        assert len(self.elts) >= _MIN(self.t)
+        return len(self.elts) == _MIN(self.t)
+
+    def search_in_node(self, key: KT) -> tuple[int, bool]:
+        """Get the index of the ``Element`` matching ``key`` or the index of its
+        least successor.
+
+        Returns a tuple of the index and an ``equal`` boolean that is ``True`` iff.
+        the key was found.
+        """
+        l = len(self.elts)
+        if l > 0 and key > self.elts[l - 1].key():
+            # This is optimizing near in-order insertion.
+            return l, False
+        l = 0
+        i = len(self.elts)
+        r = i - 1
+        equal = False
+        while l <= r:
+            m = (l + r) // 2
+            k = self.elts[m].key()
+            if key == k:
+                i = m
+                equal = True
+                break
+            elif key < k:
+                i = m
+                r = m - 1
+            else:
+                l = m + 1
+        return i, equal
+
+    def maybe_cow_child(self, index: int) -> "_Node[KT, ET]":
+        assert not self.is_leaf
+        child = self.children[index]
+        cloned = child.maybe_cow(self.creator)
+        if cloned:
+            self.children[index] = cloned
+            return cloned
+        else:
+            return child
+
+    def _get_node(self, key: KT) -> Tuple[Optional["_Node[KT, ET]"], int]:
+        """Get the node associated with key and its index, doing
+        copy-on-write if we have to descend.
+
+        Returns a tuple of the node and the index, or the tuple ``(None, 0)``
+        if the key was not found.
+        """
+        i, equal = self.search_in_node(key)
+        if equal:
+            return (self, i)
+        elif self.is_leaf:
+            return (None, 0)
+        else:
+            child = self.maybe_cow_child(i)
+            return child._get_node(key)
+
+    def get(self, key: KT) -> Optional[ET]:
+        """Get the element associated with *key* or return ``None``"""
+        i, equal = self.search_in_node(key)
+        if equal:
+            return self.elts[i]
+        elif self.is_leaf:
+            return None
+        else:
+            return self.children[i].get(key)
+
+    def optimize_in_order_insertion(self, index: int) -> None:
+        """Try to minimize the number of Nodes in a BTree where the insertion
+        is done in-order or close to it, by stealing as much as we can from our
+        right sibling.
+
+        If we don't do this, then an in-order insertion will produce a BTree
+        where most of the nodes are minimal.
+        """
+        if index == 0:
+            return
+        left = self.children[index - 1]
+        if len(left.elts) == _MAX(self.t):
+            return
+        left = self.maybe_cow_child(index - 1)
+        while len(left.elts) < _MAX(self.t):
+            if not left.try_right_steal(self, index - 1):
+                break
+
+    def insert_nonfull(self, element: ET, in_order: bool) -> Optional[ET]:
+        assert not self.is_maximal()
+        while True:
+            key = element.key()
+            i, equal = self.search_in_node(key)
+            if equal:
+                # replace
+                old = self.elts[i]
+                self.elts[i] = element
+                return old
+            elif self.is_leaf:
+                self.elts.insert(i, element)
+                return None
+            else:
+                child = self.maybe_cow_child(i)
+                if child.is_maximal():
+                    self.adopt(*child.split())
+                    # Splitting might result in our target moving to us, so
+                    # search again.
+                    continue
+                oelt = child.insert_nonfull(element, in_order)
+                if in_order:
+                    self.optimize_in_order_insertion(i)
+                return oelt
+
+    def split(self) -> tuple["_Node[KT, ET]", ET, "_Node[KT, ET]"]:
+        """Split a maximal node into two minimal ones and a central element."""
+        assert self.is_maximal()
+        right = self.__class__(self.t, self.creator, self.is_leaf)
+        right.elts = list(self.elts[_MIN(self.t) + 1 :])
+        middle = self.elts[_MIN(self.t)]
+        self.elts = list(self.elts[: _MIN(self.t)])
+        if not self.is_leaf:
+            right.children = list(self.children[_MIN(self.t) + 1 :])
+            self.children = list(self.children[: _MIN(self.t) + 1])
+        return self, middle, right
+
+    def try_left_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
+        """Try to steal from this Node's left sibling for balancing purposes.
+
+        Returns ``True`` if the theft was successful, or ``False`` if not.
+        """
+        if index != 0:
+            left = parent.children[index - 1]
+            if not left.is_minimal():
+                left = parent.maybe_cow_child(index - 1)
+                elt = parent.elts[index - 1]
+                parent.elts[index - 1] = left.elts.pop()
+                self.elts.insert(0, elt)
+                if not left.is_leaf:
+                    assert not self.is_leaf
+                    child = left.children.pop()
+                    self.children.insert(0, child)
+                return True
+        return False
+
+    def try_right_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
+        """Try to steal from this Node's right sibling for balancing purposes.
+
+        Returns ``True`` if the theft was successful, or ``False`` if not.
+        """
+        if index + 1 < len(parent.children):
+            right = parent.children[index + 1]
+            if not right.is_minimal():
+                right = parent.maybe_cow_child(index + 1)
+                elt = parent.elts[index]
+                parent.elts[index] = right.elts.pop(0)
+                self.elts.append(elt)
+                if not right.is_leaf:
+                    assert not self.is_leaf
+                    child = right.children.pop(0)
+                    self.children.append(child)
+                return True
+        return False
+
+    def adopt(self, left: "_Node[KT, ET]", middle: ET, right: "_Node[KT, ET]") -> None:
+        """Adopt left, middle, and right into our Node (which must not be maximal,
+        and which must not be a leaf).  In the case were we are not the new root,
+        then the left child must already be in the Node."""
+        assert not self.is_maximal()
+        assert not self.is_leaf
+        key = middle.key()
+        i, equal = self.search_in_node(key)
+        assert not equal
+        self.elts.insert(i, middle)
+        if len(self.children) == 0:
+            # We are the new root
+            self.children = [left, right]
+        else:
+            assert self.children[i] == left
+            self.children.insert(i + 1, right)
+
+    def merge(self, parent: "_Node[KT, ET]", index: int) -> None:
+        """Merge this node's parent and its right sibling into this node."""
+        right = parent.children.pop(index + 1)
+        self.elts.append(parent.elts.pop(index))
+        self.elts.extend(right.elts)
+        if not self.is_leaf:
+            self.children.extend(right.children)
+
+    def minimum(self) -> ET:
+        """The least element in this subtree."""
+        if self.is_leaf:
+            return self.elts[0]
+        else:
+            return self.children[0].minimum()
+
+    def maximum(self) -> ET:
+        """The greatest element in this subtree."""
+        if self.is_leaf:
+            return self.elts[-1]
+        else:
+            return self.children[-1].maximum()
+
+    def balance(self, parent: "_Node[KT, ET]", index: int) -> None:
+        """This Node is minimal, and we want to make it non-minimal so we can delete.
+        We try to steal from our siblings, and if that doesn't work we will merge
+        with one of them."""
+        assert not parent.is_leaf
+        if self.try_left_steal(parent, index):
+            return
+        if self.try_right_steal(parent, index):
+            return
+        # Stealing didn't work, so both siblings must be minimal.
+        if index == 0:
+            # We are the left-most node so merge with our right sibling.
+            self.merge(parent, index)
+        else:
+            # Have our left sibling merge with us.  This lets us only have "merge right"
+            # code.
+            left = parent.maybe_cow_child(index - 1)
+            left.merge(parent, index - 1)
+
+    def delete(
+        self, key: KT, parent: Optional["_Node[KT, ET]"], exact: Optional[ET]
+    ) -> Optional[ET]:
+        """Delete an element matching *key* if it exists.  If *exact* is not ``None``
+        then it must be an exact match with that element.  The Node must not be
+        minimal unless it is the root."""
+        assert parent is None or not self.is_minimal()
+        i, equal = self.search_in_node(key)
+        original_key = None
+        if equal:
+            # Note we use "is" here as we meant "exactly this object".
+            if exact is not None and self.elts[i] is not exact:
+                raise ValueError("exact delete did not match existing elt")
+            if self.is_leaf:
+                return self.elts.pop(i)
+            # Note we need to ensure exact is None going forward as we've
+            # already checked exactness and are about to change our target key
+            # to the least successor.
+            exact = None
+            original_key = key
+            least_successor = self.children[i + 1].minimum()
+            key = least_successor.key()
+            i = i + 1
+        if self.is_leaf:
+            # No match
+            if exact is not None:
+                raise ValueError("exact delete had no match")
+            return None
+        # recursively delete in the appropriate child
+        child = self.maybe_cow_child(i)
+        if child.is_minimal():
+            child.balance(self, i)
+            # Things may have moved.
+            i, equal = self.search_in_node(key)
+            assert not equal
+            child = self.children[i]
+            assert not child.is_minimal()
+        elt = child.delete(key, self, exact)
+        if original_key is not None:
+            node, i = self._get_node(original_key)
+            assert node is not None
+            assert elt is not None
+            oelt = node.elts[i]
+            node.elts[i] = elt
+            elt = oelt
+        return elt
+
+    def visit_in_order(self, visit: Callable[[ET], None]) -> None:
+        """Call *visit* on all of the elements in order."""
+        for i, elt in enumerate(self.elts):
+            if not self.is_leaf:
+                self.children[i].visit_in_order(visit)
+            visit(elt)
+        if not self.is_leaf:
+            self.children[-1].visit_in_order(visit)
+
+    def _visit_preorder_by_node(self, visit: Callable[["_Node[KT, ET]"], None]) -> None:
+        """Visit nodes in preorder.  This method is only used for testing."""
+        visit(self)
+        if not self.is_leaf:
+            for child in self.children:
+                child._visit_preorder_by_node(visit)
+
+    def maybe_cow(self, creator: _Creator) -> Optional["_Node[KT, ET]"]:
+        """Return a clone of this Node if it was not created by *creator*, or ``None``
+        otherwise (i.e. copy for copy-on-write if we haven't already copied it)."""
+        if self.creator is not creator:
+            return self.clone(creator)
+        else:
+            return None
+
+    def clone(self, creator: _Creator) -> "_Node[KT, ET]":
+        """Make a shallow-copy duplicate of this node."""
+        cloned = self.__class__(self.t, creator, self.is_leaf)
+        cloned.elts.extend(self.elts)
+        if not self.is_leaf:
+            cloned.children.extend(self.children)
+        return cloned
+
+    def __str__(self):  # pragma: no cover
+        if not self.is_leaf:
+            children = " " + " ".join([f"{id(c):x}" for c in self.children])
+        else:
+            children = ""
+        return f"{id(self):x} {self.creator} {self.elts}{children}"
+
+
+class Cursor(Generic[KT, ET]):
+    """A seekable cursor for a BTree.
+
+    If you are going to use a cursor on a mutable BTree, you should use it
+    in a ``with`` block so that any mutations of the BTree automatically park
+    the cursor.
+    """
+
+    def __init__(self, btree: "BTree[KT, ET]"):
+        self.btree = btree
+        self.current_node: Optional[_Node] = None
+        # The current index is the element index within the current node, or
+        # if there is no current node then it is 0 on the left boundary and 1
+        # on the right boundary.
+        self.current_index: int = 0
+        self.recurse = False
+        self.increasing = True
+        self.parents: list[tuple[_Node, int]] = []
+        self.parked = False
+        self.parking_key: Optional[KT] = None
+        self.parking_key_read = False
+
+    def _seek_least(self) -> None:
+        # seek to the least value in the subtree beneath the current index of the
+        # current node
+        assert self.current_node is not None
+        while not self.current_node.is_leaf:
+            self.parents.append((self.current_node, self.current_index))
+            self.current_node = self.current_node.children[self.current_index]
+            assert self.current_node is not None
+            self.current_index = 0
+
+    def _seek_greatest(self) -> None:
+        # seek to the greatest value in the subtree beneath the current index of the
+        # current node
+        assert self.current_node is not None
+        while not self.current_node.is_leaf:
+            self.parents.append((self.current_node, self.current_index))
+            self.current_node = self.current_node.children[self.current_index]
+            assert self.current_node is not None
+            self.current_index = len(self.current_node.elts)
+
+    def park(self):
+        """Park the cursor.
+
+        A cursor must be "parked" before mutating the BTree to avoid undefined behavior.
+        Cursors created in a ``with`` block register with their BTree and will park
+        automatically.  Note that a parked cursor may not observe some changes made when
+        it is parked; for example a cursor being iterated with next() will not see items
+        inserted before its current position.
+        """
+        if not self.parked:
+            self.parked = True
+
+    def _maybe_unpark(self):
+        if self.parked:
+            if self.parking_key is not None:
+                # remember our increasing hint, as seeking might change it
+                increasing = self.increasing
+                if self.parking_key_read:
+                    # We've already returned the parking key, so we want to be before it
+                    # if decreasing and after it if increasing.
+                    before = not self.increasing
+                else:
+                    # We haven't returned the parking key, so we've parked right
+                    # after seeking or are on a boundary.  Either way, the before
+                    # hint we want is the value of self.increasing.
+                    before = self.increasing
+                self.seek(self.parking_key, before)
+                self.increasing = increasing  # might have been altered by seek()
+            self.parked = False
+            self.parking_key = None
+
+    def prev(self) -> Optional[ET]:
+        """Get the previous element, or return None if on the left boundary."""
+        self._maybe_unpark()
+        self.parking_key = None
+        if self.current_node is None:
+            # on a boundary
+            if self.current_index == 0:
+                # left boundary, there is no prev
+                return None
+            else:
+                assert self.current_index == 1
+                # right boundary; seek to the actual boundary
+                # so we can do a prev()
+                self.current_node = self.btree.root
+                self.current_index = len(self.btree.root.elts)
+                self._seek_greatest()
+        while True:
+            if self.recurse:
+                if not self.increasing:
+                    # We only want to recurse if we are continuing in the decreasing
+                    # direction.
+                    self._seek_greatest()
+                self.recurse = False
+            self.increasing = False
+            self.current_index -= 1
+            if self.current_index >= 0:
+                elt = self.current_node.elts[self.current_index]
+                if not self.current_node.is_leaf:
+                    self.recurse = True
+                self.parking_key = elt.key()
+                self.parking_key_read = True
+                return elt
+            else:
+                if len(self.parents) > 0:
+                    self.current_node, self.current_index = self.parents.pop()
+                else:
+                    self.current_node = None
+                    self.current_index = 0
+                    return None
+
+    def next(self) -> Optional[ET]:
+        """Get the next element, or return None if on the right boundary."""
+        self._maybe_unpark()
+        self.parking_key = None
+        if self.current_node is None:
+            # on a boundary
+            if self.current_index == 1:
+                # right boundary, there is no next
+                return None
+            else:
+                assert self.current_index == 0
+                # left boundary; seek to the actual boundary
+                # so we can do a next()
+                self.current_node = self.btree.root
+                self.current_index = 0
+                self._seek_least()
+        while True:
+            if self.recurse:
+                if self.increasing:
+                    # We only want to recurse if we are continuing in the increasing
+                    # direction.
+                    self._seek_least()
+                self.recurse = False
+            self.increasing = True
+            if self.current_index < len(self.current_node.elts):
+                elt = self.current_node.elts[self.current_index]
+                self.current_index += 1
+                if not self.current_node.is_leaf:
+                    self.recurse = True
+                self.parking_key = elt.key()
+                self.parking_key_read = True
+                return elt
+            else:
+                if len(self.parents) > 0:
+                    self.current_node, self.current_index = self.parents.pop()
+                else:
+                    self.current_node = None
+                    self.current_index = 1
+                    return None
+
+    def _adjust_for_before(self, before: bool, i: int) -> None:
+        if before:
+            self.current_index = i
+        else:
+            self.current_index = i + 1
+
+    def seek(self, key: KT, before: bool = True) -> None:
+        """Seek to the specified key.
+
+        If *before* is ``True`` (the default) then the cursor is positioned just
+        before *key* if it exists, or before its least successor if it doesn't.  A
+        subsequent next() will retrieve this value.  If *before* is ``False``, then
+        the cursor is positioned just after *key* if it exists, or its greatest
+        precessessor if it doesn't.  A subsequent prev() will return this value.
+        """
+        self.current_node = self.btree.root
+        assert self.current_node is not None
+        self.recurse = False
+        self.parents = []
+        self.increasing = before
+        self.parked = False
+        self.parking_key = key
+        self.parking_key_read = False
+        while not self.current_node.is_leaf:
+            i, equal = self.current_node.search_in_node(key)
+            if equal:
+                self._adjust_for_before(before, i)
+                if before:
+                    self._seek_greatest()
+                else:
+                    self._seek_least()
+                return
+            self.parents.append((self.current_node, i))
+            self.current_node = self.current_node.children[i]
+            assert self.current_node is not None
+        i, equal = self.current_node.search_in_node(key)
+        if equal:
+            self._adjust_for_before(before, i)
+        else:
+            self.current_index = i
+
+    def seek_first(self) -> None:
+        """Seek to the left boundary (i.e. just before the least element).
+
+        A subsequent next() will return the least element if the BTree isn't empty."""
+        self.current_node = None
+        self.current_index = 0
+        self.recurse = False
+        self.increasing = True
+        self.parents = []
+        self.parked = False
+        self.parking_key = None
+
+    def seek_last(self) -> None:
+        """Seek to the right boundary (i.e. just after the greatest element).
+
+        A subsequent prev() will return the greatest element if the BTree isn't empty.
+        """
+        self.current_node = None
+        self.current_index = 1
+        self.recurse = False
+        self.increasing = False
+        self.parents = []
+        self.parked = False
+        self.parking_key = None
+
+    def __enter__(self):
+        self.btree.register_cursor(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.btree.deregister_cursor(self)
+        return False
+
+
+class Immutable(Exception):
+    """The BTree is immutable."""
+
+
+class BTree(Generic[KT, ET]):
+    """An in-memory BTree with copy-on-write and cursors."""
+
+    def __init__(self, *, t: int = DEFAULT_T, original: Optional["BTree"] = None):
+        """Create a BTree.
+
+        If *original* is not ``None``, then the BTree is shallow-cloned from
+        *original* using copy-on-write.  Otherwise a new BTree with the specified
+        *t* value is created.
+
+        The BTree is not thread-safe.
+        """
+        # We don't use a reference to ourselves as a creator as we don't want
+        # to prevent GC of old btrees.
+        self.creator = _Creator()
+        self._immutable = False
+        self.t: int
+        self.root: _Node
+        self.size: int
+        self.cursors: set[Cursor] = set()
+        if original is not None:
+            if not original._immutable:
+                raise ValueError("original BTree is not immutable")
+            self.t = original.t
+            self.root = original.root
+            self.size = original.size
+        else:
+            if t < 3:
+                raise ValueError("t must be >= 3")
+            self.t = t
+            self.root = _Node(self.t, self.creator, True)
+            self.size = 0
+
+    def make_immutable(self):
+        """Make the BTree immutable.
+
+        Attempts to alter the BTree after making it immutable will raise an
+        Immutable exception.  This operation cannot be undone.
+        """
+        if not self._immutable:
+            self._immutable = True
+
+    def _check_mutable_and_park(self) -> None:
+        if self._immutable:
+            raise Immutable
+        for cursor in self.cursors:
+            cursor.park()
+
+    # Note that we don't use insert() and delete() but rather insert_element() and
+    # delete_key() so that BTreeDict can be a proper MutableMapping and supply the
+    # rest of the standard mapping API.
+
+    def insert_element(self, elt: ET, in_order: bool = False) -> Optional[ET]:
+        """Insert the element into the BTree.
+
+        If *in_order* is ``True``, then extra work will be done to make left siblings
+        full, which optimizes storage space when the the elements are inserted in-order
+        or close to it.
+
+        Returns the previously existing element at the element's key or ``None``.
+        """
+        self._check_mutable_and_park()
+        cloned = self.root.maybe_cow(self.creator)
+        if cloned:
+            self.root = cloned
+        if self.root.is_maximal():
+            old_root = self.root
+            self.root = _Node(self.t, self.creator, False)
+            self.root.adopt(*old_root.split())
+        oelt = self.root.insert_nonfull(elt, in_order)
+        if oelt is None:
+            # We did not replace, so something was added.
+            self.size += 1
+        return oelt
+
+    def get_element(self, key: KT) -> Optional[ET]:
+        """Get the element matching *key* from the BTree, or return ``None`` if it
+        does not exist.
+        """
+        return self.root.get(key)
+
+    def _delete(self, key: KT, exact: Optional[ET]) -> Optional[ET]:
+        self._check_mutable_and_park()
+        cloned = self.root.maybe_cow(self.creator)
+        if cloned:
+            self.root = cloned
+        elt = self.root.delete(key, None, exact)
+        if elt is not None:
+            # We deleted something
+            self.size -= 1
+            if len(self.root.elts) == 0:
+                # The root is now empty.  If there is a child, then collapse this root
+                # level and make the child the new root.
+                if not self.root.is_leaf:
+                    assert len(self.root.children) == 1
+                    self.root = self.root.children[0]
+        return elt
+
+    def delete_key(self, key: KT) -> Optional[ET]:
+        """Delete the element matching *key* from the BTree.
+
+        Returns the matching element or ``None`` if it does not exist.
+        """
+        return self._delete(key, None)
+
+    def delete_exact(self, element: ET) -> Optional[ET]:
+        """Delete *element* from the BTree.
+
+        Returns the matching element or ``None`` if it was not in the BTree.
+        """
+        delt = self._delete(element.key(), element)
+        assert delt is element
+        return delt
+
+    def __len__(self):
+        return self.size
+
+    def visit_in_order(self, visit: Callable[[ET], None]) -> None:
+        """Call *visit*(element) on all elements in the tree in sorted order."""
+        self.root.visit_in_order(visit)
+
+    def _visit_preorder_by_node(self, visit: Callable[[_Node], None]) -> None:
+        self.root._visit_preorder_by_node(visit)
+
+    def cursor(self) -> Cursor[KT, ET]:
+        """Create a cursor."""
+        return Cursor(self)
+
+    def register_cursor(self, cursor: Cursor) -> None:
+        """Register a cursor for the automatic parking service."""
+        self.cursors.add(cursor)
+
+    def deregister_cursor(self, cursor: Cursor) -> None:
+        """Deregister a cursor from the automatic parking service."""
+        self.cursors.discard(cursor)
+
+    def __copy__(self):
+        return self.__class__(original=self)
+
+    def __iter__(self):
+        with self.cursor() as cursor:
+            while True:
+                elt = cursor.next()
+                if elt is None:
+                    break
+                yield elt.key()
+
+
+VT = TypeVar("VT")  # the type of a value in a BTreeDict
+
+
+class KV(Element, Generic[KT, VT]):
+    """The BTree element type used in a ``BTreeDict``."""
+
+    def __init__(self, key: KT, value: VT):
+        self._key = key
+        self._value = value
+
+    def key(self) -> KT:
+        return self._key
+
+    def value(self) -> VT:
+        return self._value
+
+    def __str__(self):  # pragma: no cover
+        return f"KV({self._key}, {self._value})"
+
+    def __repr__(self):  # pragma: no cover
+        return f"KV({self._key}, {self._value})"
+
+
+class BTreeDict(Generic[KT, VT], BTree[KT, KV[KT, VT]], MutableMapping[KT, VT]):
+    """A MutableMapping implemented with a BTree.
+
+    Unlike a normal Python dict, the BTreeDict may be mutated while iterating.
+    """
+
+    def __init__(
+        self,
+        *,
+        t: int = DEFAULT_T,
+        original: Optional[BTree] = None,
+        in_order: bool = False,
+    ):
+        super().__init__(t=t, original=original)
+        self.in_order = in_order
+
+    def __getitem__(self, key: KT) -> VT:
+        elt = self.get_element(key)
+        if elt is None:
+            raise KeyError
+        else:
+            return cast(KV, elt).value()
+
+    def __setitem__(self, key: KT, value: VT) -> None:
+        elt = KV(key, value)
+        self.insert_element(elt, self.in_order)
+
+    def __delitem__(self, key: KT) -> None:
+        if self.delete_key(key) is None:
+            raise KeyError
+
+
+class Member(Element, Generic[KT]):
+    """The BTree element type used in a ``BTreeSet``."""
+
+    def __init__(self, key: KT):
+        self._key = key
+
+    def key(self) -> KT:
+        return self._key
+
+
+class BTreeSet(BTree, Generic[KT], MutableSet[KT]):
+    """A MutableSet implemented with a BTree.
+
+    Unlike a normal Python set, the BTreeSet may be mutated while iterating.
+    """
+
+    def __init__(
+        self,
+        *,
+        t: int = DEFAULT_T,
+        original: Optional[BTree] = None,
+        in_order: bool = False,
+    ):
+        super().__init__(t=t, original=original)
+        self.in_order = in_order
+
+    def __contains__(self, key: Any) -> bool:
+        return self.get_element(key) is not None
+
+    def add(self, value: KT) -> None:
+        elt = Member(value)
+        self.insert_element(elt, self.in_order)
+
+    def discard(self, value: KT) -> None:
+        self.delete_key(value)
diff --git a/dns/btreezone.py b/dns/btreezone.py
new file mode 100644 (file)
index 0000000..c71bd5c
--- /dev/null
@@ -0,0 +1,369 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and
+# a separate per-version delegation index.  These additions let us
+#
+# 1) Do efficient CoW versioning (useful for future online updates).
+# 2) Maintain sort order
+# 3) Allow delegations to be found easily
+# 4) Handle glue
+# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant.  The ORIGIN
+#    flag is set at the origin node, the DELEGATION FLAG is set at delegation
+#    points, and the GLUE flag is set on nodes beneath delegation points.
+
+import enum
+from dataclasses import dataclass
+from typing import Callable, MutableMapping, Optional, Tuple, Union, cast
+
+import dns.btree
+import dns.immutable
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdataset
+import dns.rdatatype
+import dns.versioned
+import dns.zone
+
+
+class NodeFlags(enum.IntFlag):
+    ORIGIN = 0x01
+    DELEGATION = 0x02
+    GLUE = 0x04
+
+
+class Node(dns.node.Node):
+    __slots__ = ["flags", "id"]
+
+    def __init__(self, flags: Optional[NodeFlags] = None):
+        super().__init__()
+        if flags is None:
+            # We allow optional flags rather than a default
+            # as pyright doesn't like assigning a literal 0
+            # to flags.
+            flags = NodeFlags(0)
+        self.flags = flags
+        self.id = 0
+
+    def is_delegation(self):
+        return (self.flags & NodeFlags.DELEGATION) != 0
+
+    def is_glue(self):
+        return (self.flags & NodeFlags.GLUE) != 0
+
+    def is_origin(self):
+        return (self.flags & NodeFlags.ORIGIN) != 0
+
+    def is_origin_or_glue(self):
+        return (self.flags & (NodeFlags.ORIGIN | NodeFlags.GLUE)) != 0
+
+
+@dns.immutable.immutable
+class ImmutableNode(Node):
+    def __init__(self, node: Node):
+        super().__init__()
+        self.id = node.id
+        self.rdatasets = tuple(  # type: ignore
+            [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+        )
+        self.flags = node.flags
+
+    def find_rdataset(
+        self,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+        create: bool = False,
+    ) -> dns.rdataset.Rdataset:
+        if create:
+            raise TypeError("immutable")
+        return super().find_rdataset(rdclass, rdtype, covers, False)
+
+    def get_rdataset(
+        self,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+        create: bool = False,
+    ) -> Optional[dns.rdataset.Rdataset]:
+        if create:
+            raise TypeError("immutable")
+        return super().get_rdataset(rdclass, rdtype, covers, False)
+
+    def delete_rdataset(
+        self,
+        rdclass: dns.rdataclass.RdataClass,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+    ) -> None:
+        raise TypeError("immutable")
+
+    def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
+        raise TypeError("immutable")
+
+    def is_immutable(self) -> bool:
+        return True
+
+
+class Delegations(dns.btree.BTreeSet[dns.name.Name]):
+    def get_delegation(
+        self, name: dns.name.Name
+    ) -> Tuple[Optional[dns.name.Name], bool]:
+        """Get the delegation applicable to *name*, if it exists.
+
+        If there delegation, then return a tuple consisting of the name of
+        the delegation point, and a boolean which is `True` if the name is a proper
+        subdomain of the delegation point, and `False` if it is equal to the delegation
+        point.
+        """
+        cursor = self.cursor()
+        cursor.seek(name, before=False)
+        prev = cursor.prev()
+        if prev is None:
+            return None, False
+        cut = prev.key()
+        reln, _, _ = name.fullcompare(cut)
+        is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN
+        if is_subdomain or reln == dns.name.NameRelation.EQUAL:
+            return cut, is_subdomain
+        else:
+            return None, False
+
+    def is_glue(self, name: dns.name.Name) -> bool:
+        """Is *name* glue, i.e. is it beneath a delegation?"""
+        cursor = self.cursor()
+        cursor.seek(name, before=False)
+        cut, is_subdomain = self.get_delegation(name)
+        if cut is None:
+            return False
+        return is_subdomain
+
+
+class WritableVersion(dns.zone.WritableVersion):
+    def __init__(self, zone: dns.zone.Zone, replacement: bool = False):
+        super().__init__(zone, True)
+        if not replacement:
+            assert isinstance(zone, dns.versioned.Zone)
+            version = zone._versions[-1]
+            self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict(
+                original=version.nodes  # type: ignore
+            )
+            self.delegations = Delegations(original=version.delegations)  # type: ignore
+        else:
+            self.delegations = Delegations()
+
+    def _is_origin(self, name: dns.name.Name) -> bool:
+        # Assumes name has already been validated (and thus adjusted to the right
+        # relativity too)
+        if self.zone.relativize:
+            return name == dns.name.empty
+        else:
+            return name == self.zone.origin
+
+    def _maybe_cow_with_name(
+        self, name: dns.name.Name
+    ) -> Tuple[dns.node.Node, dns.name.Name]:
+        (node, name) = super()._maybe_cow_with_name(name)
+        node = cast(Node, node)
+        if self._is_origin(name):
+            node.flags |= NodeFlags.ORIGIN
+        elif self.delegations.is_glue(name):
+            node.flags |= NodeFlags.GLUE
+        return (node, name)
+
+    def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None:
+        cursor = self.nodes.cursor()  # type: ignore
+        cursor.seek(name, False)
+        updates = []
+        while True:
+            elt = cursor.next()
+            if elt is None:
+                break
+            ename = elt.key()
+            if not ename.is_subdomain(name):
+                break
+            node = cast(dns.node.Node, elt.value())
+            if ename not in self.changed:
+                new_node = self.zone.node_factory()
+                new_node.id = self.id  # type: ignore
+                new_node.rdatasets.extend(node.rdatasets)
+                self.changed.add(ename)
+                node = new_node
+            assert isinstance(node, Node)
+            if is_glue:
+                node.flags |= NodeFlags.GLUE
+            else:
+                node.flags &= ~NodeFlags.GLUE
+            # We don't update node here as any insertion could disturb the
+            # btree and invalidate our cursor.  We could use the cursor in a
+            # with block and avoid this, but it would do a lot of parking and
+            # unparking so the deferred update mode may still be better.
+            updates.append((ename, node))
+        for ename, node in updates:
+            self.nodes[ename] = node
+
+    def delete_node(self, name: dns.name.Name) -> None:
+        name = self._validate_name(name)
+        node = self.nodes.get(name)
+        if node is not None:
+            if node.is_delegation():  # type: ignore
+                self.delegations.discard(name)
+                self.update_glue_flag(name, False)
+            del self.nodes[name]
+            self.changed.add(name)
+
+    def put_rdataset(
+        self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
+    ) -> None:
+        (node, name) = self._maybe_cow_with_name(name)
+        if (
+            rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue()  # type: ignore
+        ):
+            node.flags |= NodeFlags.DELEGATION  # type: ignore
+            if name not in self.delegations:
+                self.delegations.add(name)
+                self.update_glue_flag(name, True)
+        node.replace_rdataset(rdataset)
+
+    def delete_rdataset(
+        self,
+        name: dns.name.Name,
+        rdtype: dns.rdatatype.RdataType,
+        covers: dns.rdatatype.RdataType,
+    ) -> None:
+        (node, name) = self._maybe_cow_with_name(name)
+        if rdtype == dns.rdatatype.NS and name in self.delegations:  # type: ignore
+            node.flags &= ~NodeFlags.DELEGATION  # type: ignore
+            self.delegations.discard(name)  # type: ignore
+            self.update_glue_flag(name, False)
+        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+        if len(node) == 0:
+            del self.nodes[name]
+
+
+@dataclass(frozen=True)
+class Bounds:
+    name: dns.name.Name
+    left: dns.name.Name
+    right: Optional[dns.name.Name]
+    closest_encloser: dns.name.Name
+    is_equal: bool
+    is_delegation: bool
+
+    def __str__(self):
+        if self.is_equal:
+            op = "="
+        else:
+            op = "<"
+        if self.is_delegation:
+            zonecut = " zonecut"
+        else:
+            zonecut = ""
+        return (
+            f"{self.left} {op} {self.name} < {self.right}{zonecut}; "
+            f"{self.closest_encloser}"
+        )
+
+
+@dns.immutable.immutable
+class ImmutableVersion(dns.zone.Version):
+    def __init__(self, version: dns.zone.Version):
+        if not isinstance(version, WritableVersion):
+            raise ValueError(
+                "a dns.btreezone.ImmutableVersion requires a "
+                "dns.btreezone.WritableVersion"
+            )
+        super().__init__(version.zone, True)
+        self.id = version.id
+        self.origin = version.origin
+        for name in version.changed:
+            node = version.nodes.get(name)
+            if node:
+                version.nodes[name] = ImmutableNode(node)
+        # the cast below is for mypy
+        self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes)
+        self.nodes.make_immutable()  # type: ignore
+        self.delegations = version.delegations
+        self.delegations.make_immutable()
+
+    def bounds(self, name: Union[dns.name.Name, str]) -> Bounds:
+        """Return the 'bounds' of *name* in its zone.
+
+        The bounds information is useful when making an authoritative response, as
+        it can be used to determine whether the query name is at or beneath a delegation
+        point.  The other data in the ``Bounds`` object is useful for making on-the-fly
+        DNSSEC signatures.
+
+        The left bound of *name* is *name* itself if it is in the zone, or the greatest
+        predecessor which is in the zone.
+
+        The right bound of *name* is the least successor of *name*, or ``None`` if
+        no name in the zone is greater than *name*.
+
+        The closest encloser of *name* is *name* itself, if *name* is in the zone;
+        otherwise it is the name with the largest number of labels in common with
+        *name* that is in the zone, either explicitly or by the implied existence
+        of empty non-terminals.
+
+        The bounds *is_equal* field is ``True`` if and only if *name* is equal to
+        its left bound.
+
+        The bounds *is_delegation* field is ``True`` if and only if the left bound is a
+        delegation point.
+        """
+        assert self.origin is not None
+        # validate the origin because we may need to relativize
+        origin = self.zone._validate_name(self.origin)
+        name = self.zone._validate_name(name)
+        cut, _ = self.delegations.get_delegation(name)
+        if cut is not None:
+            target = cut
+            is_delegation = True
+        else:
+            target = name
+            is_delegation = False
+        c = cast(dns.btree.BTreeDict, self.nodes).cursor()
+        c.seek(target, False)
+        left = c.prev()
+        assert left is not None
+        c.next()  # skip over left
+        while True:
+            right = c.next()
+            if right is None or not right.value().is_glue():
+                break
+        left_comparison = left.key().fullcompare(name)
+        if right is not None:
+            right_key = right.key()
+            right_comparison = right_key.fullcompare(name)
+        else:
+            right_comparison = (
+                dns.name.NAMERELN_COMMONANCESTOR,
+                -1,
+                len(origin),
+            )
+            right_key = None
+        closest_encloser = dns.name.Name(
+            name[-max(left_comparison[2], right_comparison[2]) :]
+        )
+        return Bounds(
+            name,
+            left.key(),
+            right_key,
+            closest_encloser,
+            left_comparison[0] == dns.name.NameRelation.EQUAL,
+            is_delegation,
+        )
+
+
+class Zone(dns.versioned.Zone):
+    node_factory: Callable[[], dns.node.Node] = Node
+    map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast(
+        Callable[[], MutableMapping[dns.name.Name, dns.node.Node]],
+        dns.btree.BTreeDict[dns.name.Name, Node],
+    )
+    writable_version_factory: Optional[
+        Callable[[dns.zone.Zone, bool], dns.zone.Version]
+    ] = WritableVersion
+    immutable_version_factory: Optional[
+        Callable[[dns.zone.Version], dns.zone.Version]
+    ] = ImmutableVersion
index 6479ae47e0b137bcce48b055db6049c38172dc0f..260eea1b60be55f2ca85ce98ce785e858f085f52 100644 (file)
@@ -41,7 +41,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
         "_readers",
     ]
 
-    node_factory = Node
+    node_factory: Callable[[], dns.node.Node] = Node
 
     def __init__(
         self,
index b1a52f6325ab191b3f5ddd3922a2a1079b79ade0..05170fe8f02bd2dd6480745b34003614dcff802e 100644 (file)
@@ -131,8 +131,10 @@ class Zone(dns.transaction.TransactionManager):
 
     node_factory: Callable[[], dns.node.Node] = dns.node.Node
     map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
-    writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
-    immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
+    # We only require the version types as "Version" to allow for flexibility, as
+    # only the version protocol matters
+    writable_version_factory: Optional[Callable[["Zone", bool], "Version"]] = None
+    immutable_version_factory: Optional[Callable[["Version"], "Version"]] = None
 
     __slots__ = ["rdclass", "origin", "nodes", "relativize"]
 
@@ -1026,7 +1028,9 @@ class WritableVersion(Version):
         self.origin = zone.origin
         self.changed: Set[dns.name.Name] = set()
 
-    def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node:
+    def _maybe_cow_with_name(
+        self, name: dns.name.Name
+    ) -> Tuple[dns.node.Node, dns.name.Name]:
         name = self._validate_name(name)
         node = self.nodes.get(name)
         if node is None or name not in self.changed:
@@ -1044,9 +1048,12 @@ class WritableVersion(Version):
                 new_node.rdatasets.extend(node.rdatasets)
             self.nodes[name] = new_node
             self.changed.add(name)
-            return new_node
+            return (new_node, name)
         else:
-            return node
+            return (node, name)
+
+    def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node:
+        return self._maybe_cow_with_name(name)[0]
 
     def delete_node(self, name: dns.name.Name) -> None:
         name = self._validate_name(name)
@@ -1074,7 +1081,11 @@ class WritableVersion(Version):
 
 @dns.immutable.immutable
 class ImmutableVersion(Version):
-    def __init__(self, version: WritableVersion):
+    def __init__(self, version: Version):
+        if not isinstance(version, WritableVersion):
+            raise ValueError(
+                "a dns.zone.ImmutableVersion requires a dns.zone.WritableVersion"
+            )
         # We tell super() that it's a replacement as we don't want it
         # to copy the nodes, as we're about to do that with an
         # immutable Dict.
diff --git a/tests/test_btree.py b/tests/test_btree.py
new file mode 100644 (file)
index 0000000..e245353
--- /dev/null
@@ -0,0 +1,722 @@
+import copy
+
+import pytest
+
+import dns.btree as btree
+
+
+class BTreeDict(btree.BTreeDict):
+    # We mostly test with an in-order optimized BTreeDict with t=3 as that's how we
+    # generated the data.
+    def __init__(self, *, t=3, original=None):
+        super().__init__(t=t, original=original, in_order=True)
+
+
+def add_keys(b, keys):
+    if isinstance(keys, int):
+        keys = range(keys)
+    for key in keys:
+        b[key] = True
+
+
+def test_replace():
+    N = 8
+    b = BTreeDict()
+    add_keys(b, N)
+    b[0] = False
+    b[5] = False
+    b[7] = False
+    for key in range(N):
+        if key in {0, 5, 7}:
+            assert b[key] == False
+        else:
+            assert b[key] == True
+
+
+def test_min_max():
+    N = 8
+    b = BTreeDict()
+    add_keys(b, N)
+    assert b.root.minimum().key() == 0
+    assert b.root.maximum().key() == N - 1
+    del b[N - 1]
+    del b[0]
+    assert b.root.minimum().key() == 1
+    assert b.root.maximum().key() == N - 2
+
+
+def test_nonexistent():
+    N = 8
+    b = BTreeDict()
+    add_keys(b, N)
+    with pytest.raises(KeyError):
+        b[1.5] == False
+    assert b.delete_key(1.5) is None
+    with pytest.raises(KeyError):
+        del b[1.5]
+
+
+def test_in_order():
+    N = 100
+    b = BTreeDict()
+    add_keys(b, N)
+    expected = list(range(N))
+    assert list(b.keys()) == expected
+
+    keys = list(range(N - 1, -1, -1))
+    expected = N
+    for key in keys:
+        l = len(b)
+        assert l == expected
+        expected -= 1
+        del b[key]
+    assert len(b) == 0
+
+
+# Some key orderings generated randomly but hardcoded here for test stability.
+# The keys lead to 100% coverage in insert, find, and delete.
+
+random_keys_1 = [
+    36,
+    14,
+    89,
+    67,
+    80,
+    98,
+    71,
+    29,
+    92,
+    91,
+    79,
+    49,
+    63,
+    74,
+    19,
+    4,
+    23,
+    60,
+    10,
+    31,
+    94,
+    46,
+    18,
+    84,
+    61,
+    42,
+    77,
+    54,
+    76,
+    38,
+    26,
+    37,
+    24,
+    99,
+    45,
+    7,
+    97,
+    32,
+    53,
+    96,
+    82,
+    52,
+    8,
+    58,
+    11,
+    3,
+    15,
+    47,
+    17,
+    21,
+    28,
+    2,
+    20,
+    12,
+    95,
+    44,
+    16,
+    9,
+    51,
+    30,
+    33,
+    34,
+    88,
+    55,
+    43,
+    72,
+    57,
+    66,
+    22,
+    56,
+    68,
+    87,
+    73,
+    6,
+    25,
+    59,
+    0,
+    75,
+    90,
+    78,
+    50,
+    13,
+    83,
+    93,
+    39,
+    81,
+    41,
+    70,
+    48,
+    35,
+    65,
+    64,
+    62,
+    5,
+    27,
+    86,
+    40,
+    1,
+    85,
+    69,
+]
+
+random_keys_2 = [
+    49,
+    28,
+    0,
+    19,
+    14,
+    76,
+    65,
+    8,
+    12,
+    90,
+    71,
+    36,
+    31,
+    24,
+    83,
+    59,
+    98,
+    48,
+    26,
+    82,
+    46,
+    84,
+    80,
+    33,
+    74,
+    75,
+    60,
+    99,
+    20,
+    61,
+    88,
+    81,
+    41,
+    58,
+    85,
+    54,
+    96,
+    23,
+    72,
+    66,
+    1,
+    37,
+    57,
+    64,
+    27,
+    13,
+    40,
+    73,
+    69,
+    32,
+    55,
+    34,
+    5,
+    2,
+    39,
+    9,
+    93,
+    50,
+    47,
+    92,
+    79,
+    78,
+    63,
+    10,
+    30,
+    77,
+    87,
+    53,
+    7,
+    56,
+    21,
+    18,
+    62,
+    6,
+    11,
+    95,
+    70,
+    44,
+    42,
+    97,
+    35,
+    91,
+    43,
+    16,
+    89,
+    45,
+    67,
+    4,
+    22,
+    17,
+    25,
+    51,
+    94,
+    52,
+    68,
+    3,
+    15,
+    86,
+    38,
+    29,
+]
+
+
+def test_random_trees():
+    N = len(random_keys_1)
+    b = BTreeDict()
+    add_keys(b, random_keys_1)
+    expected = list(range(N))
+    assert list(b.keys()) == expected
+
+    for key in random_keys_1:
+        assert b[key]
+
+    keys = random_keys_2
+    expected_len = N
+    for key in keys:
+        l = len(b)
+        assert l == expected_len
+        expected_len -= 1
+        del b[key]
+    assert len(b) == 0
+
+
+def test_random_trees_no_in_order_optimization():
+    N = len(random_keys_1)
+    b = btree.BTreeDict(t=3)
+    add_keys(b, random_keys_1)
+    expected = list(range(N))
+    assert list(b.keys()) == expected
+
+    for key in random_keys_1:
+        assert b[key]
+
+    keys = random_keys_2
+    expected_len = N
+    for key in keys:
+        l = len(b)
+        assert l == expected_len
+        expected_len -= 1
+        del b[key]
+    assert len(b) == 0
+
+
+def node_set(b):
+    s = set()
+    b._visit_preorder_by_node(lambda n: s.add(n))
+    return s
+
+
+def test_cow():
+    N = len(random_keys_1)
+    b = BTreeDict()
+    add_keys(b, random_keys_1)
+    expected = list(range(N))
+    assert list(b.keys()) == expected
+    nsb = node_set(b)
+
+    with pytest.raises(ValueError):
+        d = BTreeDict(original=b)
+    b.make_immutable()
+    with pytest.raises(btree.Immutable):
+        b[100] = True
+    with pytest.raises(btree.Immutable):
+        del b[1]
+
+    b2 = BTreeDict(original=b)
+    keys = random_keys_2
+    expected_len = N
+    for key in keys:
+        l = len(b2)
+        assert l == expected_len
+        expected_len -= 1
+        del b2[key]
+    assert len(b2) == 0
+    b2[100] = True
+    b2[101] = True
+    assert list(b2.keys()) == [100, 101]
+
+    # and b is unchanged
+    assert list(b.keys()) == expected
+    nsb2 = node_set(b)
+    assert nsb == nsb2
+
+    # copy should be the same as b
+    b3 = copy.copy(b)
+    assert list(b.keys()) == expected
+    nsb3 = node_set(b3)
+    assert nsb == nsb3
+
+
+def test_cow_minimality():
+    b = BTreeDict()
+    add_keys(b, 8)
+    b.make_immutable()
+    b2 = BTreeDict(original=b)
+
+    assert b.root is b2.root
+    b2[7] = 100
+    assert b.root is not b2.root
+    assert b.root.children[0] is b2.root.children[0]
+    assert b.root.children[1] is not b2.root.children[1]
+    del b2[5]
+    assert b.root is not b2.root
+    assert b.root.children[0] is not b2.root.children[0]
+    assert b.root.children[1] is not b2.root.children[1]
+
+
+def test_cursor_seek():
+    N = len(random_keys_1)
+    b = BTreeDict()
+    add_keys(b, random_keys_1)
+
+    l = []
+    c = b.cursor()
+    while True:
+        elt = c.next()
+        if elt is None:
+            break
+        else:
+            l.append(elt.key())
+    expected = list(range(N))
+    assert l == expected
+    assert c.next() is None
+
+    # same as previous but with explicit seek_first()
+    l = []
+    c = b.cursor()
+    c.seek_first()
+    while True:
+        elt = c.next()
+        if elt is None:
+            break
+        else:
+            l.append(elt.key())
+    expected = list(range(N))
+    assert l == expected
+    assert c.next() is None
+
+    l = []
+    c = b.cursor()
+    c.seek_last()
+    while True:
+        elt = c.prev()
+        if elt is None:
+            break
+        else:
+            l.append(elt.key())
+    expected = list(range(N - 1, -1, -1))
+    assert l == expected
+    assert c.prev() is None
+
+
+def test_cursor_seek_before_and_after():
+    N = 8
+    b = BTreeDict()
+    add_keys(b, N)
+
+    c = b.cursor()
+
+    # Seek before, leaf
+    c.seek(2)
+    assert c.next().key() == 2
+    assert c.prev().key() == 2
+
+    # Seek before, parent
+    c.seek(5)
+    assert c.next().key() == 5
+    c.seek(5)
+    assert c.prev().key() == 4
+
+    # Seek after, leaf
+    c.seek(2, False)
+    assert c.next().key() == 3
+    c.seek(2, False)
+    assert c.prev().key() == 2
+
+    # Seek after, leaf
+    c.seek(2, False)
+    assert c.next().key() == 3
+    c.seek(2, False)
+    assert c.prev().key() == 2
+
+    # Seek after, parent
+    c.seek(5, False)
+    assert c.next().key() == 6
+    c.seek(5, False)
+    assert c.prev().key() == 5
+
+    # Nonexistent
+    c.seek(2.5)
+    assert c.next().key() == 3
+    c.seek(2.5)
+    assert c.prev().key() == 2
+    c.seek(4.5)
+    assert c.next().key() == 5
+    c.seek(5.5)
+    assert c.prev().key() == 5
+
+
+def test_cursor_reversing_in_parentnode():
+    N = 11
+    b = BTreeDict()
+    add_keys(b, N)
+    c = b.cursor()
+    c.seek(5)
+    assert c.next().key() == 5
+    assert c.prev().key() == 5
+    c.seek(5, False)
+    assert c.prev().key() == 5
+    assert c.next().key() == 5
+
+
+def test_cursor_empty_tree_seeks():
+    b = BTreeDict()
+    c = b.cursor()
+    c.seek(5)
+    assert c.next() == None
+    assert c.prev() == None
+
+
+def test_parking():
+    N = 11
+    b = BTreeDict()
+    add_keys(b, N)
+    c = b.cursor()
+
+    c.seek(5)
+    assert c.next().key() == 5
+    c.park()
+    assert c.next().key() == 6
+
+    c.seek(5)
+    assert c.prev().key() == 4
+    c.park()
+    assert c.prev().key() == 3
+
+    c.seek(5)
+    assert c.next().key() == 5
+    c.park()
+    assert c.prev().key() == 5
+
+    c.seek(5)
+    assert c.prev().key() == 4
+    c.park()
+    assert c.next().key() == 4
+
+    expected = list(range(11))
+
+    c.seek_first()
+    got = []
+    while True:
+        e = c.next()
+        if e is None:
+            break
+        got.append(e.key())
+        c.park()
+    assert got == expected
+
+    c.seek_last()
+    got = []
+    while True:
+        e = c.prev()
+        if e is None:
+            break
+        got.append(e.key())
+        c.park()
+    assert got == list(reversed(expected))
+
+    # parking on the boundary
+
+    c.seek_first()
+    c.park()
+    assert c.next().key() == 0
+
+    c.seek_last()
+    c.park()
+    assert c.prev().key() == 10
+
+    c.seek(0)
+    c.park()
+    assert c.next().key() == 0
+
+    c.seek(10, False)
+    c.park()
+    assert c.prev().key() == 10
+
+    # double parking (parking idempotency)
+
+    c.seek_first()
+    c.park()
+    c.park()
+    assert c.next().key() == 0
+
+    # mutation
+
+    c.seek(5)
+    c.park()
+    b[4.5] = True
+    assert c.next().key() == 5
+
+    c.seek(5)
+    c.park()
+    b[5.5] = True
+    assert c.next().key() == 5
+
+    c.seek(5)
+    c.park()
+    del b[5]
+    assert c.next().key() == 5.5
+
+    c.seek(5)
+    c.park()
+    b[4.49] = True
+    b[4.51] = True
+    assert c.prev().key() == 4.51
+
+    c.seek(5)
+    c.park()
+    b[5.49] = True
+    b[5.51] = True
+    assert c.next().key() == 5.49
+
+
+def test_automatic_parking():
+    N = 11
+    b = BTreeDict()
+    add_keys(b, N)
+
+    with b.cursor() as c:
+        assert c in b.cursors
+        c.seek(5)
+        assert c.next().key() == 5
+        b[5.5] = True
+        assert c.next().key() == 5.5
+    assert len(b.cursors) == 0
+
+
+def test_visit_in_order():
+    N = 100
+    b = BTreeDict()
+    add_keys(b, N)
+
+    l = []
+    b.visit_in_order(lambda elt: l.append(elt.key()))
+    assert l == list(range(N))
+
+
+def test_visit_preorder_by_node():
+    N = 100
+    b = BTreeDict()
+    add_keys(b, N)
+
+    kl = []
+    b._visit_preorder_by_node(lambda node: kl.append([elt.key() for elt in node.elts]))
+    expected = [
+        [35, 71],
+        [5, 11, 17, 23, 29],
+        [0, 1, 2, 3, 4],
+        [6, 7, 8, 9, 10],
+        [12, 13, 14, 15, 16],
+        [18, 19, 20, 21, 22],
+        [24, 25, 26, 27, 28],
+        [30, 31, 32, 33, 34],
+        [41, 47, 53, 59, 65],
+        [36, 37, 38, 39, 40],
+        [42, 43, 44, 45, 46],
+        [48, 49, 50, 51, 52],
+        [54, 55, 56, 57, 58],
+        [60, 61, 62, 63, 64],
+        [66, 67, 68, 69, 70],
+        [77, 83, 89, 95],
+        [72, 73, 74, 75, 76],
+        [78, 79, 80, 81, 82],
+        [84, 85, 86, 87, 88],
+        [90, 91, 92, 93, 94],
+        [96, 97, 98, 99],
+    ]
+    assert kl == expected
+
+
+def test_exact_delete():
+    N = 8
+    b = BTreeDict()
+    add_keys(b, N)
+    for key in [5, 0, 7]:
+        elt = b.get_element(key)
+        assert elt is not None
+        bogus_elt = btree.KV(key, False)
+        with pytest.raises(ValueError):
+            b.delete_exact(bogus_elt)
+        b.delete_exact(elt)
+        with pytest.raises(ValueError):
+            b.delete_exact(elt)
+
+
+def test_find_nonexistent_node():
+    # This is just for 100% coverage
+    b = BTreeDict()
+    b[0] = True
+    assert b.root._get_node(0) == (b.root, 0)
+    assert b.root._get_node(1) == (None, 0)
+
+
+def test_t_too_small():
+    with pytest.raises(ValueError):
+        BTreeDict(t=2)
+
+
+def test_immutable_idempotent():
+    # Again just for coverage.
+    b = BTreeDict()
+    b.make_immutable()
+    assert b._immutable
+    b.make_immutable()
+    assert b._immutable
+
+
+def test_btree_set():
+    b: btree.BTreeSet[int] = btree.BTreeSet()
+    b.add(1)
+    assert 1 in b
+    b.add(2)
+    assert 1 in b
+    assert 2 in b
+    assert len(b) == 2
+    assert list(b) == [1, 2]
+    b.discard(1)
+    assert 1 not in b
+    assert 2 in b
+    assert len(b) == 1
+    assert list(b) == [2]
+    b.discard(1)
+    assert 1 not in b
+    assert 2 in b
+    assert len(b) == 1
+    b.discard(2)
+    assert 1 not in b
+    assert 2 not in b
+    assert len(b) == 0
+    assert list(b) == []
diff --git a/tests/test_btreezone.py b/tests/test_btreezone.py
new file mode 100644 (file)
index 0000000..b8eb489
--- /dev/null
@@ -0,0 +1,236 @@
+from typing import cast
+
+import dns.btreezone
+import dns.rdataset
+import dns.zone
+
+Node = dns.btreezone.Node
+
+simple_zone = """
+$ORIGIN example.
+$TTL 300
+@ soa foo bar 1 2 3 4 5
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+a txt "a"
+c.b.a txt "cba"
+b txt "b"
+sub ns ns1.sub
+sub ns ns2.sub
+ns1.sub a 10.0.0.3
+ns2.sub a 10.0.0.4
+ns1.sub2 a 10.0.0.5
+ns2.sub2 a 10.0.0.6
+text txt "here to be after sub2"
+z txt "z"
+"""
+
+
+def make_example(text: str, relativize: bool = False) -> dns.btreezone.Zone:
+    z = dns.zone.from_text(
+        simple_zone, "example.", relativize=relativize, zone_factory=dns.btreezone.Zone
+    )
+    return cast(dns.btreezone.Zone, z)
+
+
+def do_test_node_flags(relativize: bool):
+    z = make_example(simple_zone, relativize)
+    n = cast(Node, z.get_node("@"))
+    assert not n.is_delegation()
+    assert not n.is_glue()
+    assert n.is_origin()
+    assert n.is_origin_or_glue()
+    assert n.is_immutable()
+    n = cast(Node, z.get_node("sub"))
+    assert n.is_delegation()
+    assert not n.is_glue()
+    assert not n.is_origin()
+    assert not n.is_origin_or_glue()
+    n = cast(Node, z.get_node("ns1.sub"))
+    assert not n.is_delegation()
+    assert n.is_glue()
+    assert not n.is_origin()
+    assert n.is_origin_or_glue()
+
+
+def test_node_flags_absolute():
+    do_test_node_flags(False)
+
+
+def test_node_flags_relative():
+    do_test_node_flags(True)
+
+
+def test_flags_in_constructor():
+    n = Node()
+    assert n.flags == 0
+    n = Node(dns.btreezone.NodeFlags.ORIGIN)
+    assert n.is_origin()
+
+
+def do_test_obscure_and_expose(relativize: bool):
+    z = make_example(simple_zone, relativize=relativize)
+    n = cast(Node, z.get_node("ns1.sub2"))
+    assert not n.is_delegation()
+    assert not n.is_glue()
+    assert not n.is_origin()
+    assert not n.is_origin_or_glue()
+    sub2_name = z._validate_name("sub2")
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        assert sub2_name not in version.delegations
+    rds = dns.rdataset.from_text("in", "ns", 300, "ns1.sub2", "ns2.sub2")
+    with z.writer() as txn:
+        txn.replace("sub2", rds)
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        assert sub2_name in version.delegations
+    n = cast(Node, z.get_node("ns1.sub2"))
+    assert not n.is_delegation()
+    assert n.is_glue()
+    assert not n.is_origin()
+    assert n.is_origin_or_glue()
+    with z.writer() as txn:
+        txn.delete("sub2")
+        txn.delete("ns2.sub2")  # for other coverage purposes!
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        assert sub2_name not in version.delegations
+    n = cast(Node, z.get_node("ns1.sub2"))
+    assert not n.is_delegation()
+    assert not n.is_glue()
+    assert not n.is_origin()
+    assert not n.is_origin_or_glue()
+    # repeat but delete just the rdataset
+    rds = dns.rdataset.from_text("in", "ns", 300, "ns1.sub2", "ns2.sub2")
+    with z.writer() as txn:
+        txn.replace("sub2", rds)
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        assert sub2_name in version.delegations
+    n = cast(Node, z.get_node("ns1.sub2"))
+    assert not n.is_delegation()
+    assert n.is_glue()
+    assert not n.is_origin()
+    assert n.is_origin_or_glue()
+    with z.writer() as txn:
+        txn.delete("sub2", "NS")
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        assert sub2_name not in version.delegations
+    n = cast(Node, z.get_node("ns1.sub2"))
+    assert not n.is_delegation()
+    assert not n.is_glue()
+    assert not n.is_origin()
+    assert not n.is_origin_or_glue()
+
+
+def test_obscure_and_expose_absolute():
+    do_test_obscure_and_expose(False)
+
+
+def test_obscure_and_expose_relative():
+    do_test_obscure_and_expose(True)
+
+
+def do_test_delegations(relativize: bool):
+    z = make_example(simple_zone, relativize=relativize)
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        name = z._validate_name("a.b.c.sub.example.")
+        delegation, is_glue = version.delegations.get_delegation(name)
+        assert delegation == z._validate_name("sub.example.")
+        assert is_glue
+        assert version.delegations.is_glue(name)
+        name = z._validate_name("sub.example.")
+        delegation, is_glue = version.delegations.get_delegation(name)
+        assert delegation == z._validate_name("sub.example.")
+        assert not is_glue
+        assert not version.delegations.is_glue(name)
+        name = z._validate_name("text.example.")
+        delegation, is_glue = version.delegations.get_delegation(name)
+        assert delegation is None
+        assert not is_glue
+        assert not version.delegations.is_glue(name)
+
+
+def test_delegations_absolute():
+    do_test_delegations(False)
+
+
+def test_delegations_relative():
+    do_test_delegations(True)
+
+
+def do_test_bounds(relativize: bool):
+    z = make_example(simple_zone, relativize=relativize)
+    with z.reader() as txn:
+        version = cast(dns.btreezone.ImmutableVersion, txn.version)
+        # tuple is (name, left, right, closest, is_equal, is_delegation)
+        tests = [
+            ("example.", "example.", "a.example.", "example.", True, False),
+            ("a.z.example.", "z.example.", None, "z.example.", False, False),
+            (
+                "a.b.a.example.",
+                "a.example.",
+                "c.b.a.example.",
+                "b.a.example.",
+                False,
+                False,
+            ),
+            (
+                "d.b.a.example.",
+                "c.b.a.example.",
+                "b.example.",
+                "b.a.example.",
+                False,
+                False,
+            ),
+            (
+                "d.c.b.a.example.",
+                "c.b.a.example.",
+                "b.example.",
+                "c.b.a.example.",
+                False,
+                False,
+            ),
+            (
+                "sub.example.",
+                "sub.example.",
+                "ns1.sub2.example.",
+                "sub.example.",
+                True,
+                True,
+            ),
+            (
+                "ns1.sub.example.",
+                "sub.example.",
+                "ns1.sub2.example.",
+                "sub.example.",
+                False,
+                True,
+            ),
+        ]
+        for name, left, right, closest, is_equal, is_delegation in tests:
+            name = z._validate_name(name)
+            left = z._validate_name(left)
+            if right is not None:
+                right = z._validate_name(right)
+            closest = z._validate_name(closest)
+            bounds = version.bounds(name)
+            print(bounds)
+            assert bounds.left == left
+            assert bounds.right == right
+            assert bounds.closest_encloser == closest
+            assert bounds.is_equal == is_equal
+            assert bounds.is_delegation == is_delegation
+
+
+def test_bounds_absolute():
+    do_test_bounds(False)
+
+
+def test_bounds_relative():
+    do_test_bounds(True)
index 35b5f42a8c14a1962310773c170fef41b3f21b48..dc0a225ec845c3f7758ca02358c7c568ba5c9ab9 100644 (file)
@@ -23,6 +23,7 @@ import unittest
 from io import BytesIO, StringIO
 from typing import cast
 
+import dns.btreezone
 import dns.exception
 import dns.message
 import dns.name
@@ -1172,9 +1173,11 @@ class ZoneTestCase(unittest.TestCase):
 
 
 class VersionedZoneTestCase(unittest.TestCase):
+    zone_factory = dns.versioned.Zone
+
     def testUseTransaction(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         with self.assertRaises(dns.versioned.UseTransaction):
             z.find_node("not_there", True)
@@ -1191,7 +1194,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testImmutableNodes(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         node = z.find_node("@")
         with self.assertRaises(TypeError):
@@ -1205,7 +1208,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testSelectDefaultPruningPolicy(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         z.set_pruning_policy(None)
         self.assertEqual(z._pruning_policy, z._default_pruning_policy)
@@ -1219,28 +1222,28 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testCannotSpecifyBothSerialAndVersionIdToReader(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         with self.assertRaises(ValueError):
             z.reader(1, 1)
 
     def testUnknownVersion(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         with self.assertRaises(KeyError):
             z.reader(99999)
 
     def testUnknownSerial(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         with self.assertRaises(KeyError):
             z.reader(serial=99999)
 
     def testNoRelativizeReader(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=False, zone_factory=self.zone_factory
         )
         with z.reader(serial=1) as txn:
             rds = txn.get("example.", "soa")
@@ -1248,7 +1251,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testNoRelativizeReaderOriginInText(self):
         z = dns.zone.from_text(
-            example_text, relativize=False, zone_factory=dns.versioned.Zone
+            example_text, relativize=False, zone_factory=self.zone_factory
         )
         with z.reader(serial=1) as txn:
             rds = txn.get("example.", "soa")
@@ -1256,7 +1259,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testNoRelativizeReaderAbsoluteGet(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=False, zone_factory=self.zone_factory
         )
         with z.reader(serial=1) as txn:
             rds = txn.get(dns.name.empty, "soa")
@@ -1264,7 +1267,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testCnameAndOtherDataAddOther(self):
         z = dns.zone.from_text(
-            example_cname, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_cname, "example.", relativize=True, zone_factory=self.zone_factory
         )
         rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1")
         with z.writer() as txn:
@@ -1290,7 +1293,7 @@ class VersionedZoneTestCase(unittest.TestCase):
             example_other_data,
             "example.",
             relativize=True,
-            zone_factory=dns.versioned.Zone,
+            zone_factory=self.zone_factory,
         )
         rds = dns.rdataset.from_text("in", "cname", 300, "www")
         with z.writer() as txn:
@@ -1305,7 +1308,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testGetSoa(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         soa = z.get_soa()
         self.assertTrue(soa.rdtype, dns.rdatatype.SOA)
@@ -1313,7 +1316,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testGetSoaTxn(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         with z.reader(serial=1) as txn:
             soa = z.get_soa(txn)
@@ -1327,7 +1330,7 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testGetRdataset1(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         rds = z.get_rdataset("@", "soa")
         exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5")
@@ -1335,11 +1338,15 @@ class VersionedZoneTestCase(unittest.TestCase):
 
     def testGetRdataset2(self):
         z = dns.zone.from_text(
-            example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+            example_text, "example.", relativize=True, zone_factory=self.zone_factory
         )
         rds = z.get_rdataset("@", "loc")
         self.assertTrue(rds is None)
 
 
+class BTreeZoneTestCase(VersionedZoneTestCase):
+    zone_factory = dns.btreezone.Zone
+
+
 if __name__ == "__main__":
     unittest.main()