--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""
+A BTree in the style of Cormen, Leiserson, and Rivest's "Algorithms" book, with
+copy-on-write node updates, cursors, and optional space optimization for mostly-in-order
+insertion.
+"""
+
+from collections.abc import MutableMapping, MutableSet
+from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, cast
+
+DEFAULT_T = 127
+
+KT = TypeVar("KT") # the type of a key in Element
+
+
+class Element(Generic[KT]):
+ """All items stored in the BTree are Elements."""
+
+ def key(self) -> KT:
+ """The key for this element; the returned type must implement comparison."""
+ raise NotImplementedError # pragma: no cover
+
+
+ET = TypeVar("ET", bound=Element) # the type of a value in a _KV
+
+
+def _MIN(t: int) -> int:
+ """The minimum number of keys in a non-root node for a BTree with the specified
+ ``t``
+ """
+ return t - 1
+
+
+def _MAX(t: int) -> int:
+ """The maximum number of keys in node for a BTree with the specified ``t``"""
+ return 2 * t - 1
+
+
+class _Creator:
+ """A _Creator class instance is used as a unique id for the BTree which created
+ a node.
+
+ We use a dedicated creator rather than just a BTree reference to avoid circularity
+ that would complicate GC.
+ """
+
+ def __str__(self): # pragma: no cover
+ return f"{id(self):x}"
+
+
+class _Node(Generic[KT, ET]):
+ """A Node in the BTree.
+
+ A Node (leaf or internal) of the BTree.
+ """
+
+ __slots__ = ["t", "creator", "is_leaf", "elts", "children"]
+
+ def __init__(self, t: int, creator: _Creator, is_leaf: bool):
+ assert t >= 3
+ self.t = t
+ self.creator = creator
+ self.is_leaf = is_leaf
+ self.elts: list[ET] = []
+ self.children: list[_Node[KT, ET]] = []
+
+ def is_maximal(self) -> bool:
+ """Does this node have the maximal number of keys?"""
+ assert len(self.elts) <= _MAX(self.t)
+ return len(self.elts) == _MAX(self.t)
+
+ def is_minimal(self) -> bool:
+ """Does this node have the minimal number of keys?"""
+ assert len(self.elts) >= _MIN(self.t)
+ return len(self.elts) == _MIN(self.t)
+
+ def search_in_node(self, key: KT) -> tuple[int, bool]:
+ """Get the index of the ``Element`` matching ``key`` or the index of its
+ least successor.
+
+ Returns a tuple of the index and an ``equal`` boolean that is ``True`` iff.
+ the key was found.
+ """
+ l = len(self.elts)
+ if l > 0 and key > self.elts[l - 1].key():
+ # This is optimizing near in-order insertion.
+ return l, False
+ l = 0
+ i = len(self.elts)
+ r = i - 1
+ equal = False
+ while l <= r:
+ m = (l + r) // 2
+ k = self.elts[m].key()
+ if key == k:
+ i = m
+ equal = True
+ break
+ elif key < k:
+ i = m
+ r = m - 1
+ else:
+ l = m + 1
+ return i, equal
+
+ def maybe_cow_child(self, index: int) -> "_Node[KT, ET]":
+ assert not self.is_leaf
+ child = self.children[index]
+ cloned = child.maybe_cow(self.creator)
+ if cloned:
+ self.children[index] = cloned
+ return cloned
+ else:
+ return child
+
+ def _get_node(self, key: KT) -> Tuple[Optional["_Node[KT, ET]"], int]:
+ """Get the node associated with key and its index, doing
+ copy-on-write if we have to descend.
+
+ Returns a tuple of the node and the index, or the tuple ``(None, 0)``
+ if the key was not found.
+ """
+ i, equal = self.search_in_node(key)
+ if equal:
+ return (self, i)
+ elif self.is_leaf:
+ return (None, 0)
+ else:
+ child = self.maybe_cow_child(i)
+ return child._get_node(key)
+
+ def get(self, key: KT) -> Optional[ET]:
+ """Get the element associated with *key* or return ``None``"""
+ i, equal = self.search_in_node(key)
+ if equal:
+ return self.elts[i]
+ elif self.is_leaf:
+ return None
+ else:
+ return self.children[i].get(key)
+
+ def optimize_in_order_insertion(self, index: int) -> None:
+ """Try to minimize the number of Nodes in a BTree where the insertion
+ is done in-order or close to it, by stealing as much as we can from our
+ right sibling.
+
+ If we don't do this, then an in-order insertion will produce a BTree
+ where most of the nodes are minimal.
+ """
+ if index == 0:
+ return
+ left = self.children[index - 1]
+ if len(left.elts) == _MAX(self.t):
+ return
+ left = self.maybe_cow_child(index - 1)
+ while len(left.elts) < _MAX(self.t):
+ if not left.try_right_steal(self, index - 1):
+ break
+
+ def insert_nonfull(self, element: ET, in_order: bool) -> Optional[ET]:
+ assert not self.is_maximal()
+ while True:
+ key = element.key()
+ i, equal = self.search_in_node(key)
+ if equal:
+ # replace
+ old = self.elts[i]
+ self.elts[i] = element
+ return old
+ elif self.is_leaf:
+ self.elts.insert(i, element)
+ return None
+ else:
+ child = self.maybe_cow_child(i)
+ if child.is_maximal():
+ self.adopt(*child.split())
+ # Splitting might result in our target moving to us, so
+ # search again.
+ continue
+ oelt = child.insert_nonfull(element, in_order)
+ if in_order:
+ self.optimize_in_order_insertion(i)
+ return oelt
+
+ def split(self) -> tuple["_Node[KT, ET]", ET, "_Node[KT, ET]"]:
+ """Split a maximal node into two minimal ones and a central element."""
+ assert self.is_maximal()
+ right = self.__class__(self.t, self.creator, self.is_leaf)
+ right.elts = list(self.elts[_MIN(self.t) + 1 :])
+ middle = self.elts[_MIN(self.t)]
+ self.elts = list(self.elts[: _MIN(self.t)])
+ if not self.is_leaf:
+ right.children = list(self.children[_MIN(self.t) + 1 :])
+ self.children = list(self.children[: _MIN(self.t) + 1])
+ return self, middle, right
+
+ def try_left_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
+ """Try to steal from this Node's left sibling for balancing purposes.
+
+ Returns ``True`` if the theft was successful, or ``False`` if not.
+ """
+ if index != 0:
+ left = parent.children[index - 1]
+ if not left.is_minimal():
+ left = parent.maybe_cow_child(index - 1)
+ elt = parent.elts[index - 1]
+ parent.elts[index - 1] = left.elts.pop()
+ self.elts.insert(0, elt)
+ if not left.is_leaf:
+ assert not self.is_leaf
+ child = left.children.pop()
+ self.children.insert(0, child)
+ return True
+ return False
+
+ def try_right_steal(self, parent: "_Node[KT, ET]", index: int) -> bool:
+ """Try to steal from this Node's right sibling for balancing purposes.
+
+ Returns ``True`` if the theft was successful, or ``False`` if not.
+ """
+ if index + 1 < len(parent.children):
+ right = parent.children[index + 1]
+ if not right.is_minimal():
+ right = parent.maybe_cow_child(index + 1)
+ elt = parent.elts[index]
+ parent.elts[index] = right.elts.pop(0)
+ self.elts.append(elt)
+ if not right.is_leaf:
+ assert not self.is_leaf
+ child = right.children.pop(0)
+ self.children.append(child)
+ return True
+ return False
+
+ def adopt(self, left: "_Node[KT, ET]", middle: ET, right: "_Node[KT, ET]") -> None:
+ """Adopt left, middle, and right into our Node (which must not be maximal,
+ and which must not be a leaf). In the case were we are not the new root,
+ then the left child must already be in the Node."""
+ assert not self.is_maximal()
+ assert not self.is_leaf
+ key = middle.key()
+ i, equal = self.search_in_node(key)
+ assert not equal
+ self.elts.insert(i, middle)
+ if len(self.children) == 0:
+ # We are the new root
+ self.children = [left, right]
+ else:
+ assert self.children[i] == left
+ self.children.insert(i + 1, right)
+
+ def merge(self, parent: "_Node[KT, ET]", index: int) -> None:
+ """Merge this node's parent and its right sibling into this node."""
+ right = parent.children.pop(index + 1)
+ self.elts.append(parent.elts.pop(index))
+ self.elts.extend(right.elts)
+ if not self.is_leaf:
+ self.children.extend(right.children)
+
+ def minimum(self) -> ET:
+ """The least element in this subtree."""
+ if self.is_leaf:
+ return self.elts[0]
+ else:
+ return self.children[0].minimum()
+
+ def maximum(self) -> ET:
+ """The greatest element in this subtree."""
+ if self.is_leaf:
+ return self.elts[-1]
+ else:
+ return self.children[-1].maximum()
+
+ def balance(self, parent: "_Node[KT, ET]", index: int) -> None:
+ """This Node is minimal, and we want to make it non-minimal so we can delete.
+ We try to steal from our siblings, and if that doesn't work we will merge
+ with one of them."""
+ assert not parent.is_leaf
+ if self.try_left_steal(parent, index):
+ return
+ if self.try_right_steal(parent, index):
+ return
+ # Stealing didn't work, so both siblings must be minimal.
+ if index == 0:
+ # We are the left-most node so merge with our right sibling.
+ self.merge(parent, index)
+ else:
+ # Have our left sibling merge with us. This lets us only have "merge right"
+ # code.
+ left = parent.maybe_cow_child(index - 1)
+ left.merge(parent, index - 1)
+
+ def delete(
+ self, key: KT, parent: Optional["_Node[KT, ET]"], exact: Optional[ET]
+ ) -> Optional[ET]:
+ """Delete an element matching *key* if it exists. If *exact* is not ``None``
+ then it must be an exact match with that element. The Node must not be
+ minimal unless it is the root."""
+ assert parent is None or not self.is_minimal()
+ i, equal = self.search_in_node(key)
+ original_key = None
+ if equal:
+ # Note we use "is" here as we meant "exactly this object".
+ if exact is not None and self.elts[i] is not exact:
+ raise ValueError("exact delete did not match existing elt")
+ if self.is_leaf:
+ return self.elts.pop(i)
+ # Note we need to ensure exact is None going forward as we've
+ # already checked exactness and are about to change our target key
+ # to the least successor.
+ exact = None
+ original_key = key
+ least_successor = self.children[i + 1].minimum()
+ key = least_successor.key()
+ i = i + 1
+ if self.is_leaf:
+ # No match
+ if exact is not None:
+ raise ValueError("exact delete had no match")
+ return None
+ # recursively delete in the appropriate child
+ child = self.maybe_cow_child(i)
+ if child.is_minimal():
+ child.balance(self, i)
+ # Things may have moved.
+ i, equal = self.search_in_node(key)
+ assert not equal
+ child = self.children[i]
+ assert not child.is_minimal()
+ elt = child.delete(key, self, exact)
+ if original_key is not None:
+ node, i = self._get_node(original_key)
+ assert node is not None
+ assert elt is not None
+ oelt = node.elts[i]
+ node.elts[i] = elt
+ elt = oelt
+ return elt
+
+ def visit_in_order(self, visit: Callable[[ET], None]) -> None:
+ """Call *visit* on all of the elements in order."""
+ for i, elt in enumerate(self.elts):
+ if not self.is_leaf:
+ self.children[i].visit_in_order(visit)
+ visit(elt)
+ if not self.is_leaf:
+ self.children[-1].visit_in_order(visit)
+
+ def _visit_preorder_by_node(self, visit: Callable[["_Node[KT, ET]"], None]) -> None:
+ """Visit nodes in preorder. This method is only used for testing."""
+ visit(self)
+ if not self.is_leaf:
+ for child in self.children:
+ child._visit_preorder_by_node(visit)
+
+ def maybe_cow(self, creator: _Creator) -> Optional["_Node[KT, ET]"]:
+ """Return a clone of this Node if it was not created by *creator*, or ``None``
+ otherwise (i.e. copy for copy-on-write if we haven't already copied it)."""
+ if self.creator is not creator:
+ return self.clone(creator)
+ else:
+ return None
+
+ def clone(self, creator: _Creator) -> "_Node[KT, ET]":
+ """Make a shallow-copy duplicate of this node."""
+ cloned = self.__class__(self.t, creator, self.is_leaf)
+ cloned.elts.extend(self.elts)
+ if not self.is_leaf:
+ cloned.children.extend(self.children)
+ return cloned
+
+ def __str__(self): # pragma: no cover
+ if not self.is_leaf:
+ children = " " + " ".join([f"{id(c):x}" for c in self.children])
+ else:
+ children = ""
+ return f"{id(self):x} {self.creator} {self.elts}{children}"
+
+
+class Cursor(Generic[KT, ET]):
+ """A seekable cursor for a BTree.
+
+ If you are going to use a cursor on a mutable BTree, you should use it
+ in a ``with`` block so that any mutations of the BTree automatically park
+ the cursor.
+ """
+
+ def __init__(self, btree: "BTree[KT, ET]"):
+ self.btree = btree
+ self.current_node: Optional[_Node] = None
+ # The current index is the element index within the current node, or
+ # if there is no current node then it is 0 on the left boundary and 1
+ # on the right boundary.
+ self.current_index: int = 0
+ self.recurse = False
+ self.increasing = True
+ self.parents: list[tuple[_Node, int]] = []
+ self.parked = False
+ self.parking_key: Optional[KT] = None
+ self.parking_key_read = False
+
+ def _seek_least(self) -> None:
+ # seek to the least value in the subtree beneath the current index of the
+ # current node
+ assert self.current_node is not None
+ while not self.current_node.is_leaf:
+ self.parents.append((self.current_node, self.current_index))
+ self.current_node = self.current_node.children[self.current_index]
+ assert self.current_node is not None
+ self.current_index = 0
+
+ def _seek_greatest(self) -> None:
+ # seek to the greatest value in the subtree beneath the current index of the
+ # current node
+ assert self.current_node is not None
+ while not self.current_node.is_leaf:
+ self.parents.append((self.current_node, self.current_index))
+ self.current_node = self.current_node.children[self.current_index]
+ assert self.current_node is not None
+ self.current_index = len(self.current_node.elts)
+
+ def park(self):
+ """Park the cursor.
+
+ A cursor must be "parked" before mutating the BTree to avoid undefined behavior.
+ Cursors created in a ``with`` block register with their BTree and will park
+ automatically. Note that a parked cursor may not observe some changes made when
+ it is parked; for example a cursor being iterated with next() will not see items
+ inserted before its current position.
+ """
+ if not self.parked:
+ self.parked = True
+
+ def _maybe_unpark(self):
+ if self.parked:
+ if self.parking_key is not None:
+ # remember our increasing hint, as seeking might change it
+ increasing = self.increasing
+ if self.parking_key_read:
+ # We've already returned the parking key, so we want to be before it
+ # if decreasing and after it if increasing.
+ before = not self.increasing
+ else:
+ # We haven't returned the parking key, so we've parked right
+ # after seeking or are on a boundary. Either way, the before
+ # hint we want is the value of self.increasing.
+ before = self.increasing
+ self.seek(self.parking_key, before)
+ self.increasing = increasing # might have been altered by seek()
+ self.parked = False
+ self.parking_key = None
+
+ def prev(self) -> Optional[ET]:
+ """Get the previous element, or return None if on the left boundary."""
+ self._maybe_unpark()
+ self.parking_key = None
+ if self.current_node is None:
+ # on a boundary
+ if self.current_index == 0:
+ # left boundary, there is no prev
+ return None
+ else:
+ assert self.current_index == 1
+ # right boundary; seek to the actual boundary
+ # so we can do a prev()
+ self.current_node = self.btree.root
+ self.current_index = len(self.btree.root.elts)
+ self._seek_greatest()
+ while True:
+ if self.recurse:
+ if not self.increasing:
+ # We only want to recurse if we are continuing in the decreasing
+ # direction.
+ self._seek_greatest()
+ self.recurse = False
+ self.increasing = False
+ self.current_index -= 1
+ if self.current_index >= 0:
+ elt = self.current_node.elts[self.current_index]
+ if not self.current_node.is_leaf:
+ self.recurse = True
+ self.parking_key = elt.key()
+ self.parking_key_read = True
+ return elt
+ else:
+ if len(self.parents) > 0:
+ self.current_node, self.current_index = self.parents.pop()
+ else:
+ self.current_node = None
+ self.current_index = 0
+ return None
+
+ def next(self) -> Optional[ET]:
+ """Get the next element, or return None if on the right boundary."""
+ self._maybe_unpark()
+ self.parking_key = None
+ if self.current_node is None:
+ # on a boundary
+ if self.current_index == 1:
+ # right boundary, there is no next
+ return None
+ else:
+ assert self.current_index == 0
+ # left boundary; seek to the actual boundary
+ # so we can do a next()
+ self.current_node = self.btree.root
+ self.current_index = 0
+ self._seek_least()
+ while True:
+ if self.recurse:
+ if self.increasing:
+ # We only want to recurse if we are continuing in the increasing
+ # direction.
+ self._seek_least()
+ self.recurse = False
+ self.increasing = True
+ if self.current_index < len(self.current_node.elts):
+ elt = self.current_node.elts[self.current_index]
+ self.current_index += 1
+ if not self.current_node.is_leaf:
+ self.recurse = True
+ self.parking_key = elt.key()
+ self.parking_key_read = True
+ return elt
+ else:
+ if len(self.parents) > 0:
+ self.current_node, self.current_index = self.parents.pop()
+ else:
+ self.current_node = None
+ self.current_index = 1
+ return None
+
+ def _adjust_for_before(self, before: bool, i: int) -> None:
+ if before:
+ self.current_index = i
+ else:
+ self.current_index = i + 1
+
+ def seek(self, key: KT, before: bool = True) -> None:
+ """Seek to the specified key.
+
+ If *before* is ``True`` (the default) then the cursor is positioned just
+ before *key* if it exists, or before its least successor if it doesn't. A
+ subsequent next() will retrieve this value. If *before* is ``False``, then
+ the cursor is positioned just after *key* if it exists, or its greatest
+ precessessor if it doesn't. A subsequent prev() will return this value.
+ """
+ self.current_node = self.btree.root
+ assert self.current_node is not None
+ self.recurse = False
+ self.parents = []
+ self.increasing = before
+ self.parked = False
+ self.parking_key = key
+ self.parking_key_read = False
+ while not self.current_node.is_leaf:
+ i, equal = self.current_node.search_in_node(key)
+ if equal:
+ self._adjust_for_before(before, i)
+ if before:
+ self._seek_greatest()
+ else:
+ self._seek_least()
+ return
+ self.parents.append((self.current_node, i))
+ self.current_node = self.current_node.children[i]
+ assert self.current_node is not None
+ i, equal = self.current_node.search_in_node(key)
+ if equal:
+ self._adjust_for_before(before, i)
+ else:
+ self.current_index = i
+
+ def seek_first(self) -> None:
+ """Seek to the left boundary (i.e. just before the least element).
+
+ A subsequent next() will return the least element if the BTree isn't empty."""
+ self.current_node = None
+ self.current_index = 0
+ self.recurse = False
+ self.increasing = True
+ self.parents = []
+ self.parked = False
+ self.parking_key = None
+
+ def seek_last(self) -> None:
+ """Seek to the right boundary (i.e. just after the greatest element).
+
+ A subsequent prev() will return the greatest element if the BTree isn't empty.
+ """
+ self.current_node = None
+ self.current_index = 1
+ self.recurse = False
+ self.increasing = False
+ self.parents = []
+ self.parked = False
+ self.parking_key = None
+
+ def __enter__(self):
+ self.btree.register_cursor(self)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.btree.deregister_cursor(self)
+ return False
+
+
+class Immutable(Exception):
+ """The BTree is immutable."""
+
+
+class BTree(Generic[KT, ET]):
+ """An in-memory BTree with copy-on-write and cursors."""
+
+ def __init__(self, *, t: int = DEFAULT_T, original: Optional["BTree"] = None):
+ """Create a BTree.
+
+ If *original* is not ``None``, then the BTree is shallow-cloned from
+ *original* using copy-on-write. Otherwise a new BTree with the specified
+ *t* value is created.
+
+ The BTree is not thread-safe.
+ """
+ # We don't use a reference to ourselves as a creator as we don't want
+ # to prevent GC of old btrees.
+ self.creator = _Creator()
+ self._immutable = False
+ self.t: int
+ self.root: _Node
+ self.size: int
+ self.cursors: set[Cursor] = set()
+ if original is not None:
+ if not original._immutable:
+ raise ValueError("original BTree is not immutable")
+ self.t = original.t
+ self.root = original.root
+ self.size = original.size
+ else:
+ if t < 3:
+ raise ValueError("t must be >= 3")
+ self.t = t
+ self.root = _Node(self.t, self.creator, True)
+ self.size = 0
+
+ def make_immutable(self):
+ """Make the BTree immutable.
+
+ Attempts to alter the BTree after making it immutable will raise an
+ Immutable exception. This operation cannot be undone.
+ """
+ if not self._immutable:
+ self._immutable = True
+
+ def _check_mutable_and_park(self) -> None:
+ if self._immutable:
+ raise Immutable
+ for cursor in self.cursors:
+ cursor.park()
+
+ # Note that we don't use insert() and delete() but rather insert_element() and
+ # delete_key() so that BTreeDict can be a proper MutableMapping and supply the
+ # rest of the standard mapping API.
+
+ def insert_element(self, elt: ET, in_order: bool = False) -> Optional[ET]:
+ """Insert the element into the BTree.
+
+ If *in_order* is ``True``, then extra work will be done to make left siblings
+ full, which optimizes storage space when the the elements are inserted in-order
+ or close to it.
+
+ Returns the previously existing element at the element's key or ``None``.
+ """
+ self._check_mutable_and_park()
+ cloned = self.root.maybe_cow(self.creator)
+ if cloned:
+ self.root = cloned
+ if self.root.is_maximal():
+ old_root = self.root
+ self.root = _Node(self.t, self.creator, False)
+ self.root.adopt(*old_root.split())
+ oelt = self.root.insert_nonfull(elt, in_order)
+ if oelt is None:
+ # We did not replace, so something was added.
+ self.size += 1
+ return oelt
+
+ def get_element(self, key: KT) -> Optional[ET]:
+ """Get the element matching *key* from the BTree, or return ``None`` if it
+ does not exist.
+ """
+ return self.root.get(key)
+
+ def _delete(self, key: KT, exact: Optional[ET]) -> Optional[ET]:
+ self._check_mutable_and_park()
+ cloned = self.root.maybe_cow(self.creator)
+ if cloned:
+ self.root = cloned
+ elt = self.root.delete(key, None, exact)
+ if elt is not None:
+ # We deleted something
+ self.size -= 1
+ if len(self.root.elts) == 0:
+ # The root is now empty. If there is a child, then collapse this root
+ # level and make the child the new root.
+ if not self.root.is_leaf:
+ assert len(self.root.children) == 1
+ self.root = self.root.children[0]
+ return elt
+
+ def delete_key(self, key: KT) -> Optional[ET]:
+ """Delete the element matching *key* from the BTree.
+
+ Returns the matching element or ``None`` if it does not exist.
+ """
+ return self._delete(key, None)
+
+ def delete_exact(self, element: ET) -> Optional[ET]:
+ """Delete *element* from the BTree.
+
+ Returns the matching element or ``None`` if it was not in the BTree.
+ """
+ delt = self._delete(element.key(), element)
+ assert delt is element
+ return delt
+
+ def __len__(self):
+ return self.size
+
+ def visit_in_order(self, visit: Callable[[ET], None]) -> None:
+ """Call *visit*(element) on all elements in the tree in sorted order."""
+ self.root.visit_in_order(visit)
+
+ def _visit_preorder_by_node(self, visit: Callable[[_Node], None]) -> None:
+ self.root._visit_preorder_by_node(visit)
+
+ def cursor(self) -> Cursor[KT, ET]:
+ """Create a cursor."""
+ return Cursor(self)
+
+ def register_cursor(self, cursor: Cursor) -> None:
+ """Register a cursor for the automatic parking service."""
+ self.cursors.add(cursor)
+
+ def deregister_cursor(self, cursor: Cursor) -> None:
+ """Deregister a cursor from the automatic parking service."""
+ self.cursors.discard(cursor)
+
+ def __copy__(self):
+ return self.__class__(original=self)
+
+ def __iter__(self):
+ with self.cursor() as cursor:
+ while True:
+ elt = cursor.next()
+ if elt is None:
+ break
+ yield elt.key()
+
+
+VT = TypeVar("VT") # the type of a value in a BTreeDict
+
+
+class KV(Element, Generic[KT, VT]):
+ """The BTree element type used in a ``BTreeDict``."""
+
+ def __init__(self, key: KT, value: VT):
+ self._key = key
+ self._value = value
+
+ def key(self) -> KT:
+ return self._key
+
+ def value(self) -> VT:
+ return self._value
+
+ def __str__(self): # pragma: no cover
+ return f"KV({self._key}, {self._value})"
+
+ def __repr__(self): # pragma: no cover
+ return f"KV({self._key}, {self._value})"
+
+
+class BTreeDict(Generic[KT, VT], BTree[KT, KV[KT, VT]], MutableMapping[KT, VT]):
+ """A MutableMapping implemented with a BTree.
+
+ Unlike a normal Python dict, the BTreeDict may be mutated while iterating.
+ """
+
+ def __init__(
+ self,
+ *,
+ t: int = DEFAULT_T,
+ original: Optional[BTree] = None,
+ in_order: bool = False,
+ ):
+ super().__init__(t=t, original=original)
+ self.in_order = in_order
+
+ def __getitem__(self, key: KT) -> VT:
+ elt = self.get_element(key)
+ if elt is None:
+ raise KeyError
+ else:
+ return cast(KV, elt).value()
+
+ def __setitem__(self, key: KT, value: VT) -> None:
+ elt = KV(key, value)
+ self.insert_element(elt, self.in_order)
+
+ def __delitem__(self, key: KT) -> None:
+ if self.delete_key(key) is None:
+ raise KeyError
+
+
+class Member(Element, Generic[KT]):
+ """The BTree element type used in a ``BTreeSet``."""
+
+ def __init__(self, key: KT):
+ self._key = key
+
+ def key(self) -> KT:
+ return self._key
+
+
+class BTreeSet(BTree, Generic[KT], MutableSet[KT]):
+ """A MutableSet implemented with a BTree.
+
+ Unlike a normal Python set, the BTreeSet may be mutated while iterating.
+ """
+
+ def __init__(
+ self,
+ *,
+ t: int = DEFAULT_T,
+ original: Optional[BTree] = None,
+ in_order: bool = False,
+ ):
+ super().__init__(t=t, original=original)
+ self.in_order = in_order
+
+ def __contains__(self, key: Any) -> bool:
+ return self.get_element(key) is not None
+
+ def add(self, value: KT) -> None:
+ elt = Member(value)
+ self.insert_element(elt, self.in_order)
+
+ def discard(self, value: KT) -> None:
+ self.delete_key(value)
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and
+# a separate per-version delegation index. These additions let us
+#
+# 1) Do efficient CoW versioning (useful for future online updates).
+# 2) Maintain sort order
+# 3) Allow delegations to be found easily
+# 4) Handle glue
+# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant. The ORIGIN
+# flag is set at the origin node, the DELEGATION FLAG is set at delegation
+# points, and the GLUE flag is set on nodes beneath delegation points.
+
+import enum
+from dataclasses import dataclass
+from typing import Callable, MutableMapping, Optional, Tuple, Union, cast
+
+import dns.btree
+import dns.immutable
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdataset
+import dns.rdatatype
+import dns.versioned
+import dns.zone
+
+
+class NodeFlags(enum.IntFlag):
+ ORIGIN = 0x01
+ DELEGATION = 0x02
+ GLUE = 0x04
+
+
+class Node(dns.node.Node):
+ __slots__ = ["flags", "id"]
+
+ def __init__(self, flags: Optional[NodeFlags] = None):
+ super().__init__()
+ if flags is None:
+ # We allow optional flags rather than a default
+ # as pyright doesn't like assigning a literal 0
+ # to flags.
+ flags = NodeFlags(0)
+ self.flags = flags
+ self.id = 0
+
+ def is_delegation(self):
+ return (self.flags & NodeFlags.DELEGATION) != 0
+
+ def is_glue(self):
+ return (self.flags & NodeFlags.GLUE) != 0
+
+ def is_origin(self):
+ return (self.flags & NodeFlags.ORIGIN) != 0
+
+ def is_origin_or_glue(self):
+ return (self.flags & (NodeFlags.ORIGIN | NodeFlags.GLUE)) != 0
+
+
+@dns.immutable.immutable
+class ImmutableNode(Node):
+ def __init__(self, node: Node):
+ super().__init__()
+ self.id = node.id
+ self.rdatasets = tuple( # type: ignore
+ [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+ )
+ self.flags = node.flags
+
+ def find_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> dns.rdataset.Rdataset:
+ if create:
+ raise TypeError("immutable")
+ return super().find_rdataset(rdclass, rdtype, covers, False)
+
+ def get_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ create: bool = False,
+ ) -> Optional[dns.rdataset.Rdataset]:
+ if create:
+ raise TypeError("immutable")
+ return super().get_rdataset(rdclass, rdtype, covers, False)
+
+ def delete_rdataset(
+ self,
+ rdclass: dns.rdataclass.RdataClass,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
+ ) -> None:
+ raise TypeError("immutable")
+
+ def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
+ raise TypeError("immutable")
+
+ def is_immutable(self) -> bool:
+ return True
+
+
+class Delegations(dns.btree.BTreeSet[dns.name.Name]):
+ def get_delegation(
+ self, name: dns.name.Name
+ ) -> Tuple[Optional[dns.name.Name], bool]:
+ """Get the delegation applicable to *name*, if it exists.
+
+ If there delegation, then return a tuple consisting of the name of
+ the delegation point, and a boolean which is `True` if the name is a proper
+ subdomain of the delegation point, and `False` if it is equal to the delegation
+ point.
+ """
+ cursor = self.cursor()
+ cursor.seek(name, before=False)
+ prev = cursor.prev()
+ if prev is None:
+ return None, False
+ cut = prev.key()
+ reln, _, _ = name.fullcompare(cut)
+ is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN
+ if is_subdomain or reln == dns.name.NameRelation.EQUAL:
+ return cut, is_subdomain
+ else:
+ return None, False
+
+ def is_glue(self, name: dns.name.Name) -> bool:
+ """Is *name* glue, i.e. is it beneath a delegation?"""
+ cursor = self.cursor()
+ cursor.seek(name, before=False)
+ cut, is_subdomain = self.get_delegation(name)
+ if cut is None:
+ return False
+ return is_subdomain
+
+
+class WritableVersion(dns.zone.WritableVersion):
+ def __init__(self, zone: dns.zone.Zone, replacement: bool = False):
+ super().__init__(zone, True)
+ if not replacement:
+ assert isinstance(zone, dns.versioned.Zone)
+ version = zone._versions[-1]
+ self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict(
+ original=version.nodes # type: ignore
+ )
+ self.delegations = Delegations(original=version.delegations) # type: ignore
+ else:
+ self.delegations = Delegations()
+
+ def _is_origin(self, name: dns.name.Name) -> bool:
+ # Assumes name has already been validated (and thus adjusted to the right
+ # relativity too)
+ if self.zone.relativize:
+ return name == dns.name.empty
+ else:
+ return name == self.zone.origin
+
+ def _maybe_cow_with_name(
+ self, name: dns.name.Name
+ ) -> Tuple[dns.node.Node, dns.name.Name]:
+ (node, name) = super()._maybe_cow_with_name(name)
+ node = cast(Node, node)
+ if self._is_origin(name):
+ node.flags |= NodeFlags.ORIGIN
+ elif self.delegations.is_glue(name):
+ node.flags |= NodeFlags.GLUE
+ return (node, name)
+
+ def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None:
+ cursor = self.nodes.cursor() # type: ignore
+ cursor.seek(name, False)
+ updates = []
+ while True:
+ elt = cursor.next()
+ if elt is None:
+ break
+ ename = elt.key()
+ if not ename.is_subdomain(name):
+ break
+ node = cast(dns.node.Node, elt.value())
+ if ename not in self.changed:
+ new_node = self.zone.node_factory()
+ new_node.id = self.id # type: ignore
+ new_node.rdatasets.extend(node.rdatasets)
+ self.changed.add(ename)
+ node = new_node
+ assert isinstance(node, Node)
+ if is_glue:
+ node.flags |= NodeFlags.GLUE
+ else:
+ node.flags &= ~NodeFlags.GLUE
+ # We don't update node here as any insertion could disturb the
+ # btree and invalidate our cursor. We could use the cursor in a
+ # with block and avoid this, but it would do a lot of parking and
+ # unparking so the deferred update mode may still be better.
+ updates.append((ename, node))
+ for ename, node in updates:
+ self.nodes[ename] = node
+
+ def delete_node(self, name: dns.name.Name) -> None:
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is not None:
+ if node.is_delegation(): # type: ignore
+ self.delegations.discard(name)
+ self.update_glue_flag(name, False)
+ del self.nodes[name]
+ self.changed.add(name)
+
+ def put_rdataset(
+ self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
+ ) -> None:
+ (node, name) = self._maybe_cow_with_name(name)
+ if (
+ rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue() # type: ignore
+ ):
+ node.flags |= NodeFlags.DELEGATION # type: ignore
+ if name not in self.delegations:
+ self.delegations.add(name)
+ self.update_glue_flag(name, True)
+ node.replace_rdataset(rdataset)
+
+ def delete_rdataset(
+ self,
+ name: dns.name.Name,
+ rdtype: dns.rdatatype.RdataType,
+ covers: dns.rdatatype.RdataType,
+ ) -> None:
+ (node, name) = self._maybe_cow_with_name(name)
+ if rdtype == dns.rdatatype.NS and name in self.delegations: # type: ignore
+ node.flags &= ~NodeFlags.DELEGATION # type: ignore
+ self.delegations.discard(name) # type: ignore
+ self.update_glue_flag(name, False)
+ node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+ if len(node) == 0:
+ del self.nodes[name]
+
+
+@dataclass(frozen=True)
+class Bounds:
+ name: dns.name.Name
+ left: dns.name.Name
+ right: Optional[dns.name.Name]
+ closest_encloser: dns.name.Name
+ is_equal: bool
+ is_delegation: bool
+
+ def __str__(self):
+ if self.is_equal:
+ op = "="
+ else:
+ op = "<"
+ if self.is_delegation:
+ zonecut = " zonecut"
+ else:
+ zonecut = ""
+ return (
+ f"{self.left} {op} {self.name} < {self.right}{zonecut}; "
+ f"{self.closest_encloser}"
+ )
+
+
+@dns.immutable.immutable
+class ImmutableVersion(dns.zone.Version):
+ def __init__(self, version: dns.zone.Version):
+ if not isinstance(version, WritableVersion):
+ raise ValueError(
+ "a dns.btreezone.ImmutableVersion requires a "
+ "dns.btreezone.WritableVersion"
+ )
+ super().__init__(version.zone, True)
+ self.id = version.id
+ self.origin = version.origin
+ for name in version.changed:
+ node = version.nodes.get(name)
+ if node:
+ version.nodes[name] = ImmutableNode(node)
+ # the cast below is for mypy
+ self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes)
+ self.nodes.make_immutable() # type: ignore
+ self.delegations = version.delegations
+ self.delegations.make_immutable()
+
+ def bounds(self, name: Union[dns.name.Name, str]) -> Bounds:
+ """Return the 'bounds' of *name* in its zone.
+
+ The bounds information is useful when making an authoritative response, as
+ it can be used to determine whether the query name is at or beneath a delegation
+ point. The other data in the ``Bounds`` object is useful for making on-the-fly
+ DNSSEC signatures.
+
+ The left bound of *name* is *name* itself if it is in the zone, or the greatest
+ predecessor which is in the zone.
+
+ The right bound of *name* is the least successor of *name*, or ``None`` if
+ no name in the zone is greater than *name*.
+
+ The closest encloser of *name* is *name* itself, if *name* is in the zone;
+ otherwise it is the name with the largest number of labels in common with
+ *name* that is in the zone, either explicitly or by the implied existence
+ of empty non-terminals.
+
+ The bounds *is_equal* field is ``True`` if and only if *name* is equal to
+ its left bound.
+
+ The bounds *is_delegation* field is ``True`` if and only if the left bound is a
+ delegation point.
+ """
+ assert self.origin is not None
+ # validate the origin because we may need to relativize
+ origin = self.zone._validate_name(self.origin)
+ name = self.zone._validate_name(name)
+ cut, _ = self.delegations.get_delegation(name)
+ if cut is not None:
+ target = cut
+ is_delegation = True
+ else:
+ target = name
+ is_delegation = False
+ c = cast(dns.btree.BTreeDict, self.nodes).cursor()
+ c.seek(target, False)
+ left = c.prev()
+ assert left is not None
+ c.next() # skip over left
+ while True:
+ right = c.next()
+ if right is None or not right.value().is_glue():
+ break
+ left_comparison = left.key().fullcompare(name)
+ if right is not None:
+ right_key = right.key()
+ right_comparison = right_key.fullcompare(name)
+ else:
+ right_comparison = (
+ dns.name.NAMERELN_COMMONANCESTOR,
+ -1,
+ len(origin),
+ )
+ right_key = None
+ closest_encloser = dns.name.Name(
+ name[-max(left_comparison[2], right_comparison[2]) :]
+ )
+ return Bounds(
+ name,
+ left.key(),
+ right_key,
+ closest_encloser,
+ left_comparison[0] == dns.name.NameRelation.EQUAL,
+ is_delegation,
+ )
+
+
+class Zone(dns.versioned.Zone):
+ node_factory: Callable[[], dns.node.Node] = Node
+ map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast(
+ Callable[[], MutableMapping[dns.name.Name, dns.node.Node]],
+ dns.btree.BTreeDict[dns.name.Name, Node],
+ )
+ writable_version_factory: Optional[
+ Callable[[dns.zone.Zone, bool], dns.zone.Version]
+ ] = WritableVersion
+ immutable_version_factory: Optional[
+ Callable[[dns.zone.Version], dns.zone.Version]
+ ] = ImmutableVersion
"_readers",
]
- node_factory = Node
+ node_factory: Callable[[], dns.node.Node] = Node
def __init__(
self,
node_factory: Callable[[], dns.node.Node] = dns.node.Node
map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict
- writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None
- immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None
+ # We only require the version types as "Version" to allow for flexibility, as
+ # only the version protocol matters
+ writable_version_factory: Optional[Callable[["Zone", bool], "Version"]] = None
+ immutable_version_factory: Optional[Callable[["Version"], "Version"]] = None
__slots__ = ["rdclass", "origin", "nodes", "relativize"]
self.origin = zone.origin
self.changed: Set[dns.name.Name] = set()
- def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node:
+ def _maybe_cow_with_name(
+ self, name: dns.name.Name
+ ) -> Tuple[dns.node.Node, dns.name.Name]:
name = self._validate_name(name)
node = self.nodes.get(name)
if node is None or name not in self.changed:
new_node.rdatasets.extend(node.rdatasets)
self.nodes[name] = new_node
self.changed.add(name)
- return new_node
+ return (new_node, name)
else:
- return node
+ return (node, name)
+
+ def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node:
+ return self._maybe_cow_with_name(name)[0]
def delete_node(self, name: dns.name.Name) -> None:
name = self._validate_name(name)
@dns.immutable.immutable
class ImmutableVersion(Version):
- def __init__(self, version: WritableVersion):
+ def __init__(self, version: Version):
+ if not isinstance(version, WritableVersion):
+ raise ValueError(
+ "a dns.zone.ImmutableVersion requires a dns.zone.WritableVersion"
+ )
# We tell super() that it's a replacement as we don't want it
# to copy the nodes, as we're about to do that with an
# immutable Dict.
--- /dev/null
+import copy
+
+import pytest
+
+import dns.btree as btree
+
+
+class BTreeDict(btree.BTreeDict):
+ # We mostly test with an in-order optimized BTreeDict with t=3 as that's how we
+ # generated the data.
+ def __init__(self, *, t=3, original=None):
+ super().__init__(t=t, original=original, in_order=True)
+
+
+def add_keys(b, keys):
+ if isinstance(keys, int):
+ keys = range(keys)
+ for key in keys:
+ b[key] = True
+
+
+def test_replace():
+ N = 8
+ b = BTreeDict()
+ add_keys(b, N)
+ b[0] = False
+ b[5] = False
+ b[7] = False
+ for key in range(N):
+ if key in {0, 5, 7}:
+ assert b[key] == False
+ else:
+ assert b[key] == True
+
+
+def test_min_max():
+ N = 8
+ b = BTreeDict()
+ add_keys(b, N)
+ assert b.root.minimum().key() == 0
+ assert b.root.maximum().key() == N - 1
+ del b[N - 1]
+ del b[0]
+ assert b.root.minimum().key() == 1
+ assert b.root.maximum().key() == N - 2
+
+
+def test_nonexistent():
+ N = 8
+ b = BTreeDict()
+ add_keys(b, N)
+ with pytest.raises(KeyError):
+ b[1.5] == False
+ assert b.delete_key(1.5) is None
+ with pytest.raises(KeyError):
+ del b[1.5]
+
+
+def test_in_order():
+ N = 100
+ b = BTreeDict()
+ add_keys(b, N)
+ expected = list(range(N))
+ assert list(b.keys()) == expected
+
+ keys = list(range(N - 1, -1, -1))
+ expected = N
+ for key in keys:
+ l = len(b)
+ assert l == expected
+ expected -= 1
+ del b[key]
+ assert len(b) == 0
+
+
+# Some key orderings generated randomly but hardcoded here for test stability.
+# The keys lead to 100% coverage in insert, find, and delete.
+
+random_keys_1 = [
+ 36,
+ 14,
+ 89,
+ 67,
+ 80,
+ 98,
+ 71,
+ 29,
+ 92,
+ 91,
+ 79,
+ 49,
+ 63,
+ 74,
+ 19,
+ 4,
+ 23,
+ 60,
+ 10,
+ 31,
+ 94,
+ 46,
+ 18,
+ 84,
+ 61,
+ 42,
+ 77,
+ 54,
+ 76,
+ 38,
+ 26,
+ 37,
+ 24,
+ 99,
+ 45,
+ 7,
+ 97,
+ 32,
+ 53,
+ 96,
+ 82,
+ 52,
+ 8,
+ 58,
+ 11,
+ 3,
+ 15,
+ 47,
+ 17,
+ 21,
+ 28,
+ 2,
+ 20,
+ 12,
+ 95,
+ 44,
+ 16,
+ 9,
+ 51,
+ 30,
+ 33,
+ 34,
+ 88,
+ 55,
+ 43,
+ 72,
+ 57,
+ 66,
+ 22,
+ 56,
+ 68,
+ 87,
+ 73,
+ 6,
+ 25,
+ 59,
+ 0,
+ 75,
+ 90,
+ 78,
+ 50,
+ 13,
+ 83,
+ 93,
+ 39,
+ 81,
+ 41,
+ 70,
+ 48,
+ 35,
+ 65,
+ 64,
+ 62,
+ 5,
+ 27,
+ 86,
+ 40,
+ 1,
+ 85,
+ 69,
+]
+
+random_keys_2 = [
+ 49,
+ 28,
+ 0,
+ 19,
+ 14,
+ 76,
+ 65,
+ 8,
+ 12,
+ 90,
+ 71,
+ 36,
+ 31,
+ 24,
+ 83,
+ 59,
+ 98,
+ 48,
+ 26,
+ 82,
+ 46,
+ 84,
+ 80,
+ 33,
+ 74,
+ 75,
+ 60,
+ 99,
+ 20,
+ 61,
+ 88,
+ 81,
+ 41,
+ 58,
+ 85,
+ 54,
+ 96,
+ 23,
+ 72,
+ 66,
+ 1,
+ 37,
+ 57,
+ 64,
+ 27,
+ 13,
+ 40,
+ 73,
+ 69,
+ 32,
+ 55,
+ 34,
+ 5,
+ 2,
+ 39,
+ 9,
+ 93,
+ 50,
+ 47,
+ 92,
+ 79,
+ 78,
+ 63,
+ 10,
+ 30,
+ 77,
+ 87,
+ 53,
+ 7,
+ 56,
+ 21,
+ 18,
+ 62,
+ 6,
+ 11,
+ 95,
+ 70,
+ 44,
+ 42,
+ 97,
+ 35,
+ 91,
+ 43,
+ 16,
+ 89,
+ 45,
+ 67,
+ 4,
+ 22,
+ 17,
+ 25,
+ 51,
+ 94,
+ 52,
+ 68,
+ 3,
+ 15,
+ 86,
+ 38,
+ 29,
+]
+
+
+def test_random_trees():
+ N = len(random_keys_1)
+ b = BTreeDict()
+ add_keys(b, random_keys_1)
+ expected = list(range(N))
+ assert list(b.keys()) == expected
+
+ for key in random_keys_1:
+ assert b[key]
+
+ keys = random_keys_2
+ expected_len = N
+ for key in keys:
+ l = len(b)
+ assert l == expected_len
+ expected_len -= 1
+ del b[key]
+ assert len(b) == 0
+
+
+def test_random_trees_no_in_order_optimization():
+ N = len(random_keys_1)
+ b = btree.BTreeDict(t=3)
+ add_keys(b, random_keys_1)
+ expected = list(range(N))
+ assert list(b.keys()) == expected
+
+ for key in random_keys_1:
+ assert b[key]
+
+ keys = random_keys_2
+ expected_len = N
+ for key in keys:
+ l = len(b)
+ assert l == expected_len
+ expected_len -= 1
+ del b[key]
+ assert len(b) == 0
+
+
+def node_set(b):
+ s = set()
+ b._visit_preorder_by_node(lambda n: s.add(n))
+ return s
+
+
+def test_cow():
+ N = len(random_keys_1)
+ b = BTreeDict()
+ add_keys(b, random_keys_1)
+ expected = list(range(N))
+ assert list(b.keys()) == expected
+ nsb = node_set(b)
+
+ with pytest.raises(ValueError):
+ d = BTreeDict(original=b)
+ b.make_immutable()
+ with pytest.raises(btree.Immutable):
+ b[100] = True
+ with pytest.raises(btree.Immutable):
+ del b[1]
+
+ b2 = BTreeDict(original=b)
+ keys = random_keys_2
+ expected_len = N
+ for key in keys:
+ l = len(b2)
+ assert l == expected_len
+ expected_len -= 1
+ del b2[key]
+ assert len(b2) == 0
+ b2[100] = True
+ b2[101] = True
+ assert list(b2.keys()) == [100, 101]
+
+ # and b is unchanged
+ assert list(b.keys()) == expected
+ nsb2 = node_set(b)
+ assert nsb == nsb2
+
+ # copy should be the same as b
+ b3 = copy.copy(b)
+ assert list(b.keys()) == expected
+ nsb3 = node_set(b3)
+ assert nsb == nsb3
+
+
+def test_cow_minimality():
+ b = BTreeDict()
+ add_keys(b, 8)
+ b.make_immutable()
+ b2 = BTreeDict(original=b)
+
+ assert b.root is b2.root
+ b2[7] = 100
+ assert b.root is not b2.root
+ assert b.root.children[0] is b2.root.children[0]
+ assert b.root.children[1] is not b2.root.children[1]
+ del b2[5]
+ assert b.root is not b2.root
+ assert b.root.children[0] is not b2.root.children[0]
+ assert b.root.children[1] is not b2.root.children[1]
+
+
+def test_cursor_seek():
+ N = len(random_keys_1)
+ b = BTreeDict()
+ add_keys(b, random_keys_1)
+
+ l = []
+ c = b.cursor()
+ while True:
+ elt = c.next()
+ if elt is None:
+ break
+ else:
+ l.append(elt.key())
+ expected = list(range(N))
+ assert l == expected
+ assert c.next() is None
+
+ # same as previous but with explicit seek_first()
+ l = []
+ c = b.cursor()
+ c.seek_first()
+ while True:
+ elt = c.next()
+ if elt is None:
+ break
+ else:
+ l.append(elt.key())
+ expected = list(range(N))
+ assert l == expected
+ assert c.next() is None
+
+ l = []
+ c = b.cursor()
+ c.seek_last()
+ while True:
+ elt = c.prev()
+ if elt is None:
+ break
+ else:
+ l.append(elt.key())
+ expected = list(range(N - 1, -1, -1))
+ assert l == expected
+ assert c.prev() is None
+
+
+def test_cursor_seek_before_and_after():
+ N = 8
+ b = BTreeDict()
+ add_keys(b, N)
+
+ c = b.cursor()
+
+ # Seek before, leaf
+ c.seek(2)
+ assert c.next().key() == 2
+ assert c.prev().key() == 2
+
+ # Seek before, parent
+ c.seek(5)
+ assert c.next().key() == 5
+ c.seek(5)
+ assert c.prev().key() == 4
+
+ # Seek after, leaf
+ c.seek(2, False)
+ assert c.next().key() == 3
+ c.seek(2, False)
+ assert c.prev().key() == 2
+
+ # Seek after, leaf
+ c.seek(2, False)
+ assert c.next().key() == 3
+ c.seek(2, False)
+ assert c.prev().key() == 2
+
+ # Seek after, parent
+ c.seek(5, False)
+ assert c.next().key() == 6
+ c.seek(5, False)
+ assert c.prev().key() == 5
+
+ # Nonexistent
+ c.seek(2.5)
+ assert c.next().key() == 3
+ c.seek(2.5)
+ assert c.prev().key() == 2
+ c.seek(4.5)
+ assert c.next().key() == 5
+ c.seek(5.5)
+ assert c.prev().key() == 5
+
+
+def test_cursor_reversing_in_parentnode():
+ N = 11
+ b = BTreeDict()
+ add_keys(b, N)
+ c = b.cursor()
+ c.seek(5)
+ assert c.next().key() == 5
+ assert c.prev().key() == 5
+ c.seek(5, False)
+ assert c.prev().key() == 5
+ assert c.next().key() == 5
+
+
+def test_cursor_empty_tree_seeks():
+ b = BTreeDict()
+ c = b.cursor()
+ c.seek(5)
+ assert c.next() == None
+ assert c.prev() == None
+
+
+def test_parking():
+ N = 11
+ b = BTreeDict()
+ add_keys(b, N)
+ c = b.cursor()
+
+ c.seek(5)
+ assert c.next().key() == 5
+ c.park()
+ assert c.next().key() == 6
+
+ c.seek(5)
+ assert c.prev().key() == 4
+ c.park()
+ assert c.prev().key() == 3
+
+ c.seek(5)
+ assert c.next().key() == 5
+ c.park()
+ assert c.prev().key() == 5
+
+ c.seek(5)
+ assert c.prev().key() == 4
+ c.park()
+ assert c.next().key() == 4
+
+ expected = list(range(11))
+
+ c.seek_first()
+ got = []
+ while True:
+ e = c.next()
+ if e is None:
+ break
+ got.append(e.key())
+ c.park()
+ assert got == expected
+
+ c.seek_last()
+ got = []
+ while True:
+ e = c.prev()
+ if e is None:
+ break
+ got.append(e.key())
+ c.park()
+ assert got == list(reversed(expected))
+
+ # parking on the boundary
+
+ c.seek_first()
+ c.park()
+ assert c.next().key() == 0
+
+ c.seek_last()
+ c.park()
+ assert c.prev().key() == 10
+
+ c.seek(0)
+ c.park()
+ assert c.next().key() == 0
+
+ c.seek(10, False)
+ c.park()
+ assert c.prev().key() == 10
+
+ # double parking (parking idempotency)
+
+ c.seek_first()
+ c.park()
+ c.park()
+ assert c.next().key() == 0
+
+ # mutation
+
+ c.seek(5)
+ c.park()
+ b[4.5] = True
+ assert c.next().key() == 5
+
+ c.seek(5)
+ c.park()
+ b[5.5] = True
+ assert c.next().key() == 5
+
+ c.seek(5)
+ c.park()
+ del b[5]
+ assert c.next().key() == 5.5
+
+ c.seek(5)
+ c.park()
+ b[4.49] = True
+ b[4.51] = True
+ assert c.prev().key() == 4.51
+
+ c.seek(5)
+ c.park()
+ b[5.49] = True
+ b[5.51] = True
+ assert c.next().key() == 5.49
+
+
+def test_automatic_parking():
+ N = 11
+ b = BTreeDict()
+ add_keys(b, N)
+
+ with b.cursor() as c:
+ assert c in b.cursors
+ c.seek(5)
+ assert c.next().key() == 5
+ b[5.5] = True
+ assert c.next().key() == 5.5
+ assert len(b.cursors) == 0
+
+
+def test_visit_in_order():
+ N = 100
+ b = BTreeDict()
+ add_keys(b, N)
+
+ l = []
+ b.visit_in_order(lambda elt: l.append(elt.key()))
+ assert l == list(range(N))
+
+
+def test_visit_preorder_by_node():
+ N = 100
+ b = BTreeDict()
+ add_keys(b, N)
+
+ kl = []
+ b._visit_preorder_by_node(lambda node: kl.append([elt.key() for elt in node.elts]))
+ expected = [
+ [35, 71],
+ [5, 11, 17, 23, 29],
+ [0, 1, 2, 3, 4],
+ [6, 7, 8, 9, 10],
+ [12, 13, 14, 15, 16],
+ [18, 19, 20, 21, 22],
+ [24, 25, 26, 27, 28],
+ [30, 31, 32, 33, 34],
+ [41, 47, 53, 59, 65],
+ [36, 37, 38, 39, 40],
+ [42, 43, 44, 45, 46],
+ [48, 49, 50, 51, 52],
+ [54, 55, 56, 57, 58],
+ [60, 61, 62, 63, 64],
+ [66, 67, 68, 69, 70],
+ [77, 83, 89, 95],
+ [72, 73, 74, 75, 76],
+ [78, 79, 80, 81, 82],
+ [84, 85, 86, 87, 88],
+ [90, 91, 92, 93, 94],
+ [96, 97, 98, 99],
+ ]
+ assert kl == expected
+
+
+def test_exact_delete():
+ N = 8
+ b = BTreeDict()
+ add_keys(b, N)
+ for key in [5, 0, 7]:
+ elt = b.get_element(key)
+ assert elt is not None
+ bogus_elt = btree.KV(key, False)
+ with pytest.raises(ValueError):
+ b.delete_exact(bogus_elt)
+ b.delete_exact(elt)
+ with pytest.raises(ValueError):
+ b.delete_exact(elt)
+
+
+def test_find_nonexistent_node():
+ # This is just for 100% coverage
+ b = BTreeDict()
+ b[0] = True
+ assert b.root._get_node(0) == (b.root, 0)
+ assert b.root._get_node(1) == (None, 0)
+
+
+def test_t_too_small():
+ with pytest.raises(ValueError):
+ BTreeDict(t=2)
+
+
+def test_immutable_idempotent():
+ # Again just for coverage.
+ b = BTreeDict()
+ b.make_immutable()
+ assert b._immutable
+ b.make_immutable()
+ assert b._immutable
+
+
+def test_btree_set():
+ b: btree.BTreeSet[int] = btree.BTreeSet()
+ b.add(1)
+ assert 1 in b
+ b.add(2)
+ assert 1 in b
+ assert 2 in b
+ assert len(b) == 2
+ assert list(b) == [1, 2]
+ b.discard(1)
+ assert 1 not in b
+ assert 2 in b
+ assert len(b) == 1
+ assert list(b) == [2]
+ b.discard(1)
+ assert 1 not in b
+ assert 2 in b
+ assert len(b) == 1
+ b.discard(2)
+ assert 1 not in b
+ assert 2 not in b
+ assert len(b) == 0
+ assert list(b) == []
--- /dev/null
+from typing import cast
+
+import dns.btreezone
+import dns.rdataset
+import dns.zone
+
+Node = dns.btreezone.Node
+
+simple_zone = """
+$ORIGIN example.
+$TTL 300
+@ soa foo bar 1 2 3 4 5
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+a txt "a"
+c.b.a txt "cba"
+b txt "b"
+sub ns ns1.sub
+sub ns ns2.sub
+ns1.sub a 10.0.0.3
+ns2.sub a 10.0.0.4
+ns1.sub2 a 10.0.0.5
+ns2.sub2 a 10.0.0.6
+text txt "here to be after sub2"
+z txt "z"
+"""
+
+
+def make_example(text: str, relativize: bool = False) -> dns.btreezone.Zone:
+ z = dns.zone.from_text(
+ simple_zone, "example.", relativize=relativize, zone_factory=dns.btreezone.Zone
+ )
+ return cast(dns.btreezone.Zone, z)
+
+
+def do_test_node_flags(relativize: bool):
+ z = make_example(simple_zone, relativize)
+ n = cast(Node, z.get_node("@"))
+ assert not n.is_delegation()
+ assert not n.is_glue()
+ assert n.is_origin()
+ assert n.is_origin_or_glue()
+ assert n.is_immutable()
+ n = cast(Node, z.get_node("sub"))
+ assert n.is_delegation()
+ assert not n.is_glue()
+ assert not n.is_origin()
+ assert not n.is_origin_or_glue()
+ n = cast(Node, z.get_node("ns1.sub"))
+ assert not n.is_delegation()
+ assert n.is_glue()
+ assert not n.is_origin()
+ assert n.is_origin_or_glue()
+
+
+def test_node_flags_absolute():
+ do_test_node_flags(False)
+
+
+def test_node_flags_relative():
+ do_test_node_flags(True)
+
+
+def test_flags_in_constructor():
+ n = Node()
+ assert n.flags == 0
+ n = Node(dns.btreezone.NodeFlags.ORIGIN)
+ assert n.is_origin()
+
+
+def do_test_obscure_and_expose(relativize: bool):
+ z = make_example(simple_zone, relativize=relativize)
+ n = cast(Node, z.get_node("ns1.sub2"))
+ assert not n.is_delegation()
+ assert not n.is_glue()
+ assert not n.is_origin()
+ assert not n.is_origin_or_glue()
+ sub2_name = z._validate_name("sub2")
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ assert sub2_name not in version.delegations
+ rds = dns.rdataset.from_text("in", "ns", 300, "ns1.sub2", "ns2.sub2")
+ with z.writer() as txn:
+ txn.replace("sub2", rds)
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ assert sub2_name in version.delegations
+ n = cast(Node, z.get_node("ns1.sub2"))
+ assert not n.is_delegation()
+ assert n.is_glue()
+ assert not n.is_origin()
+ assert n.is_origin_or_glue()
+ with z.writer() as txn:
+ txn.delete("sub2")
+ txn.delete("ns2.sub2") # for other coverage purposes!
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ assert sub2_name not in version.delegations
+ n = cast(Node, z.get_node("ns1.sub2"))
+ assert not n.is_delegation()
+ assert not n.is_glue()
+ assert not n.is_origin()
+ assert not n.is_origin_or_glue()
+ # repeat but delete just the rdataset
+ rds = dns.rdataset.from_text("in", "ns", 300, "ns1.sub2", "ns2.sub2")
+ with z.writer() as txn:
+ txn.replace("sub2", rds)
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ assert sub2_name in version.delegations
+ n = cast(Node, z.get_node("ns1.sub2"))
+ assert not n.is_delegation()
+ assert n.is_glue()
+ assert not n.is_origin()
+ assert n.is_origin_or_glue()
+ with z.writer() as txn:
+ txn.delete("sub2", "NS")
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ assert sub2_name not in version.delegations
+ n = cast(Node, z.get_node("ns1.sub2"))
+ assert not n.is_delegation()
+ assert not n.is_glue()
+ assert not n.is_origin()
+ assert not n.is_origin_or_glue()
+
+
+def test_obscure_and_expose_absolute():
+ do_test_obscure_and_expose(False)
+
+
+def test_obscure_and_expose_relative():
+ do_test_obscure_and_expose(True)
+
+
+def do_test_delegations(relativize: bool):
+ z = make_example(simple_zone, relativize=relativize)
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ name = z._validate_name("a.b.c.sub.example.")
+ delegation, is_glue = version.delegations.get_delegation(name)
+ assert delegation == z._validate_name("sub.example.")
+ assert is_glue
+ assert version.delegations.is_glue(name)
+ name = z._validate_name("sub.example.")
+ delegation, is_glue = version.delegations.get_delegation(name)
+ assert delegation == z._validate_name("sub.example.")
+ assert not is_glue
+ assert not version.delegations.is_glue(name)
+ name = z._validate_name("text.example.")
+ delegation, is_glue = version.delegations.get_delegation(name)
+ assert delegation is None
+ assert not is_glue
+ assert not version.delegations.is_glue(name)
+
+
+def test_delegations_absolute():
+ do_test_delegations(False)
+
+
+def test_delegations_relative():
+ do_test_delegations(True)
+
+
+def do_test_bounds(relativize: bool):
+ z = make_example(simple_zone, relativize=relativize)
+ with z.reader() as txn:
+ version = cast(dns.btreezone.ImmutableVersion, txn.version)
+ # tuple is (name, left, right, closest, is_equal, is_delegation)
+ tests = [
+ ("example.", "example.", "a.example.", "example.", True, False),
+ ("a.z.example.", "z.example.", None, "z.example.", False, False),
+ (
+ "a.b.a.example.",
+ "a.example.",
+ "c.b.a.example.",
+ "b.a.example.",
+ False,
+ False,
+ ),
+ (
+ "d.b.a.example.",
+ "c.b.a.example.",
+ "b.example.",
+ "b.a.example.",
+ False,
+ False,
+ ),
+ (
+ "d.c.b.a.example.",
+ "c.b.a.example.",
+ "b.example.",
+ "c.b.a.example.",
+ False,
+ False,
+ ),
+ (
+ "sub.example.",
+ "sub.example.",
+ "ns1.sub2.example.",
+ "sub.example.",
+ True,
+ True,
+ ),
+ (
+ "ns1.sub.example.",
+ "sub.example.",
+ "ns1.sub2.example.",
+ "sub.example.",
+ False,
+ True,
+ ),
+ ]
+ for name, left, right, closest, is_equal, is_delegation in tests:
+ name = z._validate_name(name)
+ left = z._validate_name(left)
+ if right is not None:
+ right = z._validate_name(right)
+ closest = z._validate_name(closest)
+ bounds = version.bounds(name)
+ print(bounds)
+ assert bounds.left == left
+ assert bounds.right == right
+ assert bounds.closest_encloser == closest
+ assert bounds.is_equal == is_equal
+ assert bounds.is_delegation == is_delegation
+
+
+def test_bounds_absolute():
+ do_test_bounds(False)
+
+
+def test_bounds_relative():
+ do_test_bounds(True)
from io import BytesIO, StringIO
from typing import cast
+import dns.btreezone
import dns.exception
import dns.message
import dns.name
class VersionedZoneTestCase(unittest.TestCase):
+ zone_factory = dns.versioned.Zone
+
def testUseTransaction(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
with self.assertRaises(dns.versioned.UseTransaction):
z.find_node("not_there", True)
def testImmutableNodes(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
node = z.find_node("@")
with self.assertRaises(TypeError):
def testSelectDefaultPruningPolicy(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
z.set_pruning_policy(None)
self.assertEqual(z._pruning_policy, z._default_pruning_policy)
def testCannotSpecifyBothSerialAndVersionIdToReader(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
with self.assertRaises(ValueError):
z.reader(1, 1)
def testUnknownVersion(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
with self.assertRaises(KeyError):
z.reader(99999)
def testUnknownSerial(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
with self.assertRaises(KeyError):
z.reader(serial=99999)
def testNoRelativizeReader(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=False, zone_factory=self.zone_factory
)
with z.reader(serial=1) as txn:
rds = txn.get("example.", "soa")
def testNoRelativizeReaderOriginInText(self):
z = dns.zone.from_text(
- example_text, relativize=False, zone_factory=dns.versioned.Zone
+ example_text, relativize=False, zone_factory=self.zone_factory
)
with z.reader(serial=1) as txn:
rds = txn.get("example.", "soa")
def testNoRelativizeReaderAbsoluteGet(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=False, zone_factory=self.zone_factory
)
with z.reader(serial=1) as txn:
rds = txn.get(dns.name.empty, "soa")
def testCnameAndOtherDataAddOther(self):
z = dns.zone.from_text(
- example_cname, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_cname, "example.", relativize=True, zone_factory=self.zone_factory
)
rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1")
with z.writer() as txn:
example_other_data,
"example.",
relativize=True,
- zone_factory=dns.versioned.Zone,
+ zone_factory=self.zone_factory,
)
rds = dns.rdataset.from_text("in", "cname", 300, "www")
with z.writer() as txn:
def testGetSoa(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
soa = z.get_soa()
self.assertTrue(soa.rdtype, dns.rdatatype.SOA)
def testGetSoaTxn(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
with z.reader(serial=1) as txn:
soa = z.get_soa(txn)
def testGetRdataset1(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
rds = z.get_rdataset("@", "soa")
exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5")
def testGetRdataset2(self):
z = dns.zone.from_text(
- example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone
+ example_text, "example.", relativize=True, zone_factory=self.zone_factory
)
rds = z.get_rdataset("@", "loc")
self.assertTrue(rds is None)
+class BTreeZoneTestCase(VersionedZoneTestCase):
+ zone_factory = dns.btreezone.Zone
+
+
if __name__ == "__main__":
unittest.main()