From: Bob Halley Date: Sun, 10 Aug 2025 21:16:34 +0000 (-0700) Subject: Btree Zones (#1215) X-Git-Tag: v2.8.0rc1~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a650a406c33045372b52dde70fba06bbb6d682da;p=thirdparty%2Fdnspython.git Btree Zones (#1215) * Add BTree zone --- diff --git a/dns/btree.py b/dns/btree.py new file mode 100644 index 00000000..f544edd8 --- /dev/null +++ b/dns/btree.py @@ -0,0 +1,850 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +""" +A BTree in the style of Cormen, Leiserson, and Rivest's "Algorithms" book, with +copy-on-write node updates, cursors, and optional space optimization for mostly-in-order +insertion. +""" + +from collections.abc import MutableMapping, MutableSet +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, cast + +DEFAULT_T = 127 + +KT = TypeVar("KT") # the type of a key in Element + + +class Element(Generic[KT]): + """All items stored in the BTree are Elements.""" + + def key(self) -> KT: + """The key for this element; the returned type must implement comparison.""" + raise NotImplementedError # pragma: no cover + + +ET = TypeVar("ET", bound=Element) # the type of a value in a _KV + + +def _MIN(t: int) -> int: + """The minimum number of keys in a non-root node for a BTree with the specified + ``t`` + """ + return t - 1 + + +def _MAX(t: int) -> int: + """The maximum number of keys in node for a BTree with the specified ``t``""" + return 2 * t - 1 + + +class _Creator: + """A _Creator class instance is used as a unique id for the BTree which created + a node. + + We use a dedicated creator rather than just a BTree reference to avoid circularity + that would complicate GC. + """ + + def __str__(self): # pragma: no cover + return f"{id(self):x}" + + +class _Node(Generic[KT, ET]): + """A Node in the BTree. + + A Node (leaf or internal) of the BTree. 
+ """ + + __slots__ = ["t", "creator", "is_leaf", "elts", "children"] + + def __init__(self, t: int, creator: _Creator, is_leaf: bool): + assert t >= 3 + self.t = t + self.creator = creator + self.is_leaf = is_leaf + self.elts: list[ET] = [] + self.children: list[_Node[KT, ET]] = [] + + def is_maximal(self) -> bool: + """Does this node have the maximal number of keys?""" + assert len(self.elts) <= _MAX(self.t) + return len(self.elts) == _MAX(self.t) + + def is_minimal(self) -> bool: + """Does this node have the minimal number of keys?""" + assert len(self.elts) >= _MIN(self.t) + return len(self.elts) == _MIN(self.t) + + def search_in_node(self, key: KT) -> tuple[int, bool]: + """Get the index of the ``Element`` matching ``key`` or the index of its + least successor. + + Returns a tuple of the index and an ``equal`` boolean that is ``True`` iff. + the key was found. + """ + l = len(self.elts) + if l > 0 and key > self.elts[l - 1].key(): + # This is optimizing near in-order insertion. + return l, False + l = 0 + i = len(self.elts) + r = i - 1 + equal = False + while l <= r: + m = (l + r) // 2 + k = self.elts[m].key() + if key == k: + i = m + equal = True + break + elif key < k: + i = m + r = m - 1 + else: + l = m + 1 + return i, equal + + def maybe_cow_child(self, index: int) -> "_Node[KT, ET]": + assert not self.is_leaf + child = self.children[index] + cloned = child.maybe_cow(self.creator) + if cloned: + self.children[index] = cloned + return cloned + else: + return child + + def _get_node(self, key: KT) -> Tuple[Optional["_Node[KT, ET]"], int]: + """Get the node associated with key and its index, doing + copy-on-write if we have to descend. + + Returns a tuple of the node and the index, or the tuple ``(None, 0)`` + if the key was not found. 
+ """ + i, equal = self.search_in_node(key) + if equal: + return (self, i) + elif self.is_leaf: + return (None, 0) + else: + child = self.maybe_cow_child(i) + return child._get_node(key) + + def get(self, key: KT) -> Optional[ET]: + """Get the element associated with *key* or return ``None``""" + i, equal = self.search_in_node(key) + if equal: + return self.elts[i] + elif self.is_leaf: + return None + else: + return self.children[i].get(key) + + def optimize_in_order_insertion(self, index: int) -> None: + """Try to minimize the number of Nodes in a BTree where the insertion + is done in-order or close to it, by stealing as much as we can from our + right sibling. + + If we don't do this, then an in-order insertion will produce a BTree + where most of the nodes are minimal. + """ + if index == 0: + return + left = self.children[index - 1] + if len(left.elts) == _MAX(self.t): + return + left = self.maybe_cow_child(index - 1) + while len(left.elts) < _MAX(self.t): + if not left.try_right_steal(self, index - 1): + break + + def insert_nonfull(self, element: ET, in_order: bool) -> Optional[ET]: + assert not self.is_maximal() + while True: + key = element.key() + i, equal = self.search_in_node(key) + if equal: + # replace + old = self.elts[i] + self.elts[i] = element + return old + elif self.is_leaf: + self.elts.insert(i, element) + return None + else: + child = self.maybe_cow_child(i) + if child.is_maximal(): + self.adopt(*child.split()) + # Splitting might result in our target moving to us, so + # search again. 
+ continue + oelt = child.insert_nonfull(element, in_order) + if in_order: + self.optimize_in_order_insertion(i) + return oelt + + def split(self) -> tuple["_Node[KT, ET]", ET, "_Node[KT, ET]"]: + """Split a maximal node into two minimal ones and a central element.""" + assert self.is_maximal() + right = self.__class__(self.t, self.creator, self.is_leaf) + right.elts = list(self.elts[_MIN(self.t) + 1 :]) + middle = self.elts[_MIN(self.t)] + self.elts = list(self.elts[: _MIN(self.t)]) + if not self.is_leaf: + right.children = list(self.children[_MIN(self.t) + 1 :]) + self.children = list(self.children[: _MIN(self.t) + 1]) + return self, middle, right + + def try_left_steal(self, parent: "_Node[KT, ET]", index: int) -> bool: + """Try to steal from this Node's left sibling for balancing purposes. + + Returns ``True`` if the theft was successful, or ``False`` if not. + """ + if index != 0: + left = parent.children[index - 1] + if not left.is_minimal(): + left = parent.maybe_cow_child(index - 1) + elt = parent.elts[index - 1] + parent.elts[index - 1] = left.elts.pop() + self.elts.insert(0, elt) + if not left.is_leaf: + assert not self.is_leaf + child = left.children.pop() + self.children.insert(0, child) + return True + return False + + def try_right_steal(self, parent: "_Node[KT, ET]", index: int) -> bool: + """Try to steal from this Node's right sibling for balancing purposes. + + Returns ``True`` if the theft was successful, or ``False`` if not. 
+ """ + if index + 1 < len(parent.children): + right = parent.children[index + 1] + if not right.is_minimal(): + right = parent.maybe_cow_child(index + 1) + elt = parent.elts[index] + parent.elts[index] = right.elts.pop(0) + self.elts.append(elt) + if not right.is_leaf: + assert not self.is_leaf + child = right.children.pop(0) + self.children.append(child) + return True + return False + + def adopt(self, left: "_Node[KT, ET]", middle: ET, right: "_Node[KT, ET]") -> None: + """Adopt left, middle, and right into our Node (which must not be maximal, + and which must not be a leaf). In the case where we are not the new root, + then the left child must already be in the Node.""" + assert not self.is_maximal() + assert not self.is_leaf + key = middle.key() + i, equal = self.search_in_node(key) + assert not equal + self.elts.insert(i, middle) + if len(self.children) == 0: + # We are the new root + self.children = [left, right] + else: + assert self.children[i] == left + self.children.insert(i + 1, right) + + def merge(self, parent: "_Node[KT, ET]", index: int) -> None: + """Merge this node's parent and its right sibling into this node.""" + right = parent.children.pop(index + 1) + self.elts.append(parent.elts.pop(index)) + self.elts.extend(right.elts) + if not self.is_leaf: + self.children.extend(right.children) + + def minimum(self) -> ET: + """The least element in this subtree.""" + if self.is_leaf: + return self.elts[0] + else: + return self.children[0].minimum() + + def maximum(self) -> ET: + """The greatest element in this subtree.""" + if self.is_leaf: + return self.elts[-1] + else: + return self.children[-1].maximum() + + def balance(self, parent: "_Node[KT, ET]", index: int) -> None: + """This Node is minimal, and we want to make it non-minimal so we can delete. 
+ We try to steal from our siblings, and if that doesn't work we will merge + with one of them.""" + assert not parent.is_leaf + if self.try_left_steal(parent, index): + return + if self.try_right_steal(parent, index): + return + # Stealing didn't work, so both siblings must be minimal. + if index == 0: + # We are the left-most node so merge with our right sibling. + self.merge(parent, index) + else: + # Have our left sibling merge with us. This lets us only have "merge right" + # code. + left = parent.maybe_cow_child(index - 1) + left.merge(parent, index - 1) + + def delete( + self, key: KT, parent: Optional["_Node[KT, ET]"], exact: Optional[ET] + ) -> Optional[ET]: + """Delete an element matching *key* if it exists. If *exact* is not ``None`` + then it must be an exact match with that element. The Node must not be + minimal unless it is the root.""" + assert parent is None or not self.is_minimal() + i, equal = self.search_in_node(key) + original_key = None + if equal: + # Note we use "is" here as we meant "exactly this object". + if exact is not None and self.elts[i] is not exact: + raise ValueError("exact delete did not match existing elt") + if self.is_leaf: + return self.elts.pop(i) + # Note we need to ensure exact is None going forward as we've + # already checked exactness and are about to change our target key + # to the least successor. + exact = None + original_key = key + least_successor = self.children[i + 1].minimum() + key = least_successor.key() + i = i + 1 + if self.is_leaf: + # No match + if exact is not None: + raise ValueError("exact delete had no match") + return None + # recursively delete in the appropriate child + child = self.maybe_cow_child(i) + if child.is_minimal(): + child.balance(self, i) + # Things may have moved. 
+ i, equal = self.search_in_node(key) + assert not equal + child = self.children[i] + assert not child.is_minimal() + elt = child.delete(key, self, exact) + if original_key is not None: + node, i = self._get_node(original_key) + assert node is not None + assert elt is not None + oelt = node.elts[i] + node.elts[i] = elt + elt = oelt + return elt + + def visit_in_order(self, visit: Callable[[ET], None]) -> None: + """Call *visit* on all of the elements in order.""" + for i, elt in enumerate(self.elts): + if not self.is_leaf: + self.children[i].visit_in_order(visit) + visit(elt) + if not self.is_leaf: + self.children[-1].visit_in_order(visit) + + def _visit_preorder_by_node(self, visit: Callable[["_Node[KT, ET]"], None]) -> None: + """Visit nodes in preorder. This method is only used for testing.""" + visit(self) + if not self.is_leaf: + for child in self.children: + child._visit_preorder_by_node(visit) + + def maybe_cow(self, creator: _Creator) -> Optional["_Node[KT, ET]"]: + """Return a clone of this Node if it was not created by *creator*, or ``None`` + otherwise (i.e. copy for copy-on-write if we haven't already copied it).""" + if self.creator is not creator: + return self.clone(creator) + else: + return None + + def clone(self, creator: _Creator) -> "_Node[KT, ET]": + """Make a shallow-copy duplicate of this node.""" + cloned = self.__class__(self.t, creator, self.is_leaf) + cloned.elts.extend(self.elts) + if not self.is_leaf: + cloned.children.extend(self.children) + return cloned + + def __str__(self): # pragma: no cover + if not self.is_leaf: + children = " " + " ".join([f"{id(c):x}" for c in self.children]) + else: + children = "" + return f"{id(self):x} {self.creator} {self.elts}{children}" + + +class Cursor(Generic[KT, ET]): + """A seekable cursor for a BTree. + + If you are going to use a cursor on a mutable BTree, you should use it + in a ``with`` block so that any mutations of the BTree automatically park + the cursor. 
+ """ + + def __init__(self, btree: "BTree[KT, ET]"): + self.btree = btree + self.current_node: Optional[_Node] = None + # The current index is the element index within the current node, or + # if there is no current node then it is 0 on the left boundary and 1 + # on the right boundary. + self.current_index: int = 0 + self.recurse = False + self.increasing = True + self.parents: list[tuple[_Node, int]] = [] + self.parked = False + self.parking_key: Optional[KT] = None + self.parking_key_read = False + + def _seek_least(self) -> None: + # seek to the least value in the subtree beneath the current index of the + # current node + assert self.current_node is not None + while not self.current_node.is_leaf: + self.parents.append((self.current_node, self.current_index)) + self.current_node = self.current_node.children[self.current_index] + assert self.current_node is not None + self.current_index = 0 + + def _seek_greatest(self) -> None: + # seek to the greatest value in the subtree beneath the current index of the + # current node + assert self.current_node is not None + while not self.current_node.is_leaf: + self.parents.append((self.current_node, self.current_index)) + self.current_node = self.current_node.children[self.current_index] + assert self.current_node is not None + self.current_index = len(self.current_node.elts) + + def park(self): + """Park the cursor. + + A cursor must be "parked" before mutating the BTree to avoid undefined behavior. + Cursors created in a ``with`` block register with their BTree and will park + automatically. Note that a parked cursor may not observe some changes made when + it is parked; for example a cursor being iterated with next() will not see items + inserted before its current position. 
+ """ + if not self.parked: + self.parked = True + + def _maybe_unpark(self): + if self.parked: + if self.parking_key is not None: + # remember our increasing hint, as seeking might change it + increasing = self.increasing + if self.parking_key_read: + # We've already returned the parking key, so we want to be before it + # if decreasing and after it if increasing. + before = not self.increasing + else: + # We haven't returned the parking key, so we've parked right + # after seeking or are on a boundary. Either way, the before + # hint we want is the value of self.increasing. + before = self.increasing + self.seek(self.parking_key, before) + self.increasing = increasing # might have been altered by seek() + self.parked = False + self.parking_key = None + + def prev(self) -> Optional[ET]: + """Get the previous element, or return None if on the left boundary.""" + self._maybe_unpark() + self.parking_key = None + if self.current_node is None: + # on a boundary + if self.current_index == 0: + # left boundary, there is no prev + return None + else: + assert self.current_index == 1 + # right boundary; seek to the actual boundary + # so we can do a prev() + self.current_node = self.btree.root + self.current_index = len(self.btree.root.elts) + self._seek_greatest() + while True: + if self.recurse: + if not self.increasing: + # We only want to recurse if we are continuing in the decreasing + # direction. 
+ self._seek_greatest() + self.recurse = False + self.increasing = False + self.current_index -= 1 + if self.current_index >= 0: + elt = self.current_node.elts[self.current_index] + if not self.current_node.is_leaf: + self.recurse = True + self.parking_key = elt.key() + self.parking_key_read = True + return elt + else: + if len(self.parents) > 0: + self.current_node, self.current_index = self.parents.pop() + else: + self.current_node = None + self.current_index = 0 + return None + + def next(self) -> Optional[ET]: + """Get the next element, or return None if on the right boundary.""" + self._maybe_unpark() + self.parking_key = None + if self.current_node is None: + # on a boundary + if self.current_index == 1: + # right boundary, there is no next + return None + else: + assert self.current_index == 0 + # left boundary; seek to the actual boundary + # so we can do a next() + self.current_node = self.btree.root + self.current_index = 0 + self._seek_least() + while True: + if self.recurse: + if self.increasing: + # We only want to recurse if we are continuing in the increasing + # direction. + self._seek_least() + self.recurse = False + self.increasing = True + if self.current_index < len(self.current_node.elts): + elt = self.current_node.elts[self.current_index] + self.current_index += 1 + if not self.current_node.is_leaf: + self.recurse = True + self.parking_key = elt.key() + self.parking_key_read = True + return elt + else: + if len(self.parents) > 0: + self.current_node, self.current_index = self.parents.pop() + else: + self.current_node = None + self.current_index = 1 + return None + + def _adjust_for_before(self, before: bool, i: int) -> None: + if before: + self.current_index = i + else: + self.current_index = i + 1 + + def seek(self, key: KT, before: bool = True) -> None: + """Seek to the specified key. + + If *before* is ``True`` (the default) then the cursor is positioned just + before *key* if it exists, or before its least successor if it doesn't. 
A + subsequent next() will retrieve this value. If *before* is ``False``, then + the cursor is positioned just after *key* if it exists, or its greatest + predecessor if it doesn't. A subsequent prev() will return this value. + """ + self.current_node = self.btree.root + assert self.current_node is not None + self.recurse = False + self.parents = [] + self.increasing = before + self.parked = False + self.parking_key = key + self.parking_key_read = False + while not self.current_node.is_leaf: + i, equal = self.current_node.search_in_node(key) + if equal: + self._adjust_for_before(before, i) + if before: + self._seek_greatest() + else: + self._seek_least() + return + self.parents.append((self.current_node, i)) + self.current_node = self.current_node.children[i] + assert self.current_node is not None + i, equal = self.current_node.search_in_node(key) + if equal: + self._adjust_for_before(before, i) + else: + self.current_index = i + + def seek_first(self) -> None: + """Seek to the left boundary (i.e. just before the least element). + + A subsequent next() will return the least element if the BTree isn't empty.""" + self.current_node = None + self.current_index = 0 + self.recurse = False + self.increasing = True + self.parents = [] + self.parked = False + self.parking_key = None + + def seek_last(self) -> None: + """Seek to the right boundary (i.e. just after the greatest element). + + A subsequent prev() will return the greatest element if the BTree isn't empty. 
+ """ + self.current_node = None + self.current_index = 1 + self.recurse = False + self.increasing = False + self.parents = [] + self.parked = False + self.parking_key = None + + def __enter__(self): + self.btree.register_cursor(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.btree.deregister_cursor(self) + return False + + +class Immutable(Exception): + """The BTree is immutable.""" + + +class BTree(Generic[KT, ET]): + """An in-memory BTree with copy-on-write and cursors.""" + + def __init__(self, *, t: int = DEFAULT_T, original: Optional["BTree"] = None): + """Create a BTree. + + If *original* is not ``None``, then the BTree is shallow-cloned from + *original* using copy-on-write. Otherwise a new BTree with the specified + *t* value is created. + + The BTree is not thread-safe. + """ + # We don't use a reference to ourselves as a creator as we don't want + # to prevent GC of old btrees. + self.creator = _Creator() + self._immutable = False + self.t: int + self.root: _Node + self.size: int + self.cursors: set[Cursor] = set() + if original is not None: + if not original._immutable: + raise ValueError("original BTree is not immutable") + self.t = original.t + self.root = original.root + self.size = original.size + else: + if t < 3: + raise ValueError("t must be >= 3") + self.t = t + self.root = _Node(self.t, self.creator, True) + self.size = 0 + + def make_immutable(self): + """Make the BTree immutable. + + Attempts to alter the BTree after making it immutable will raise an + Immutable exception. This operation cannot be undone. + """ + if not self._immutable: + self._immutable = True + + def _check_mutable_and_park(self) -> None: + if self._immutable: + raise Immutable + for cursor in self.cursors: + cursor.park() + + # Note that we don't use insert() and delete() but rather insert_element() and + # delete_key() so that BTreeDict can be a proper MutableMapping and supply the + # rest of the standard mapping API. 
+ + def insert_element(self, elt: ET, in_order: bool = False) -> Optional[ET]: + """Insert the element into the BTree. + + If *in_order* is ``True``, then extra work will be done to make left siblings + full, which optimizes storage space when the elements are inserted in-order + or close to it. + + Returns the previously existing element at the element's key or ``None``. + """ + self._check_mutable_and_park() + cloned = self.root.maybe_cow(self.creator) + if cloned: + self.root = cloned + if self.root.is_maximal(): + old_root = self.root + self.root = _Node(self.t, self.creator, False) + self.root.adopt(*old_root.split()) + oelt = self.root.insert_nonfull(elt, in_order) + if oelt is None: + # We did not replace, so something was added. + self.size += 1 + return oelt + + def get_element(self, key: KT) -> Optional[ET]: + """Get the element matching *key* from the BTree, or return ``None`` if it + does not exist. + """ + return self.root.get(key) + + def _delete(self, key: KT, exact: Optional[ET]) -> Optional[ET]: + self._check_mutable_and_park() + cloned = self.root.maybe_cow(self.creator) + if cloned: + self.root = cloned + elt = self.root.delete(key, None, exact) + if elt is not None: + # We deleted something + self.size -= 1 + if len(self.root.elts) == 0: + # The root is now empty. If there is a child, then collapse this root + # level and make the child the new root. + if not self.root.is_leaf: + assert len(self.root.children) == 1 + self.root = self.root.children[0] + return elt + + def delete_key(self, key: KT) -> Optional[ET]: + """Delete the element matching *key* from the BTree. + + Returns the matching element or ``None`` if it does not exist. + """ + return self._delete(key, None) + + def delete_exact(self, element: ET) -> Optional[ET]: + """Delete *element* from the BTree. + + Returns *element*, or raises ``ValueError`` if it was not in the BTree. 
+ """ + delt = self._delete(element.key(), element) + assert delt is element + return delt + + def __len__(self): + return self.size + + def visit_in_order(self, visit: Callable[[ET], None]) -> None: + """Call *visit*(element) on all elements in the tree in sorted order.""" + self.root.visit_in_order(visit) + + def _visit_preorder_by_node(self, visit: Callable[[_Node], None]) -> None: + self.root._visit_preorder_by_node(visit) + + def cursor(self) -> Cursor[KT, ET]: + """Create a cursor.""" + return Cursor(self) + + def register_cursor(self, cursor: Cursor) -> None: + """Register a cursor for the automatic parking service.""" + self.cursors.add(cursor) + + def deregister_cursor(self, cursor: Cursor) -> None: + """Deregister a cursor from the automatic parking service.""" + self.cursors.discard(cursor) + + def __copy__(self): + return self.__class__(original=self) + + def __iter__(self): + with self.cursor() as cursor: + while True: + elt = cursor.next() + if elt is None: + break + yield elt.key() + + +VT = TypeVar("VT") # the type of a value in a BTreeDict + + +class KV(Element, Generic[KT, VT]): + """The BTree element type used in a ``BTreeDict``.""" + + def __init__(self, key: KT, value: VT): + self._key = key + self._value = value + + def key(self) -> KT: + return self._key + + def value(self) -> VT: + return self._value + + def __str__(self): # pragma: no cover + return f"KV({self._key}, {self._value})" + + def __repr__(self): # pragma: no cover + return f"KV({self._key}, {self._value})" + + +class BTreeDict(Generic[KT, VT], BTree[KT, KV[KT, VT]], MutableMapping[KT, VT]): + """A MutableMapping implemented with a BTree. + + Unlike a normal Python dict, the BTreeDict may be mutated while iterating. 
+ """ + + def __init__( + self, + *, + t: int = DEFAULT_T, + original: Optional[BTree] = None, + in_order: bool = False, + ): + super().__init__(t=t, original=original) + self.in_order = in_order + + def __getitem__(self, key: KT) -> VT: + elt = self.get_element(key) + if elt is None: + raise KeyError + else: + return cast(KV, elt).value() + + def __setitem__(self, key: KT, value: VT) -> None: + elt = KV(key, value) + self.insert_element(elt, self.in_order) + + def __delitem__(self, key: KT) -> None: + if self.delete_key(key) is None: + raise KeyError + + +class Member(Element, Generic[KT]): + """The BTree element type used in a ``BTreeSet``.""" + + def __init__(self, key: KT): + self._key = key + + def key(self) -> KT: + return self._key + + +class BTreeSet(BTree, Generic[KT], MutableSet[KT]): + """A MutableSet implemented with a BTree. + + Unlike a normal Python set, the BTreeSet may be mutated while iterating. + """ + + def __init__( + self, + *, + t: int = DEFAULT_T, + original: Optional[BTree] = None, + in_order: bool = False, + ): + super().__init__(t=t, original=original) + self.in_order = in_order + + def __contains__(self, key: Any) -> bool: + return self.get_element(key) is not None + + def add(self, value: KT) -> None: + elt = Member(value) + self.insert_element(elt, self.in_order) + + def discard(self, value: KT) -> None: + self.delete_key(value) diff --git a/dns/btreezone.py b/dns/btreezone.py new file mode 100644 index 00000000..c71bd5c6 --- /dev/null +++ b/dns/btreezone.py @@ -0,0 +1,369 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and +# a separate per-version delegation index. These additions let us +# +# 1) Do efficient CoW versioning (useful for future online updates). +# 2) Maintain sort order +# 3) Allow delegations to be found easily +# 4) Handle glue +# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant. 
The ORIGIN +# flag is set at the origin node, the DELEGATION FLAG is set at delegation +# points, and the GLUE flag is set on nodes beneath delegation points. + +import enum +from dataclasses import dataclass +from typing import Callable, MutableMapping, Optional, Tuple, Union, cast + +import dns.btree +import dns.immutable +import dns.name +import dns.node +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.versioned +import dns.zone + + +class NodeFlags(enum.IntFlag): + ORIGIN = 0x01 + DELEGATION = 0x02 + GLUE = 0x04 + + +class Node(dns.node.Node): + __slots__ = ["flags", "id"] + + def __init__(self, flags: Optional[NodeFlags] = None): + super().__init__() + if flags is None: + # We allow optional flags rather than a default + # as pyright doesn't like assigning a literal 0 + # to flags. + flags = NodeFlags(0) + self.flags = flags + self.id = 0 + + def is_delegation(self): + return (self.flags & NodeFlags.DELEGATION) != 0 + + def is_glue(self): + return (self.flags & NodeFlags.GLUE) != 0 + + def is_origin(self): + return (self.flags & NodeFlags.ORIGIN) != 0 + + def is_origin_or_glue(self): + return (self.flags & (NodeFlags.ORIGIN | NodeFlags.GLUE)) != 0 + + +@dns.immutable.immutable +class ImmutableNode(Node): + def __init__(self, node: Node): + super().__init__() + self.id = node.id + self.rdatasets = tuple( # type: ignore + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + self.flags = node.flags + + def find_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> dns.rdataset.Rdataset: + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + create: bool = False, + ) -> 
Optional[dns.rdataset.Rdataset]: + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset( + self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType = dns.rdatatype.NONE, + ) -> None: + raise TypeError("immutable") + + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: + raise TypeError("immutable") + + def is_immutable(self) -> bool: + return True + + +class Delegations(dns.btree.BTreeSet[dns.name.Name]): + def get_delegation( + self, name: dns.name.Name + ) -> Tuple[Optional[dns.name.Name], bool]: + """Get the delegation applicable to *name*, if it exists. + + If there is a delegation, then return a tuple consisting of the name of + the delegation point, and a boolean which is `True` if the name is a proper + subdomain of the delegation point, and `False` if it is equal to the delegation + point. Otherwise return ``(None, False)``. + """ + cursor = self.cursor() + cursor.seek(name, before=False) + prev = cursor.prev() + if prev is None: + return None, False + cut = prev.key() + reln, _, _ = name.fullcompare(cut) + is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN + if is_subdomain or reln == dns.name.NameRelation.EQUAL: + return cut, is_subdomain + else: + return None, False + + def is_glue(self, name: dns.name.Name) -> bool: + """Is *name* glue, i.e. 
is it beneath a delegation?""" + cursor = self.cursor() + cursor.seek(name, before=False) + cut, is_subdomain = self.get_delegation(name) + if cut is None: + return False + return is_subdomain + + +class WritableVersion(dns.zone.WritableVersion): + def __init__(self, zone: dns.zone.Zone, replacement: bool = False): + super().__init__(zone, True) + if not replacement: + assert isinstance(zone, dns.versioned.Zone) + version = zone._versions[-1] + self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict( + original=version.nodes # type: ignore + ) + self.delegations = Delegations(original=version.delegations) # type: ignore + else: + self.delegations = Delegations() + + def _is_origin(self, name: dns.name.Name) -> bool: + # Assumes name has already been validated (and thus adjusted to the right + # relativity too) + if self.zone.relativize: + return name == dns.name.empty + else: + return name == self.zone.origin + + def _maybe_cow_with_name( + self, name: dns.name.Name + ) -> Tuple[dns.node.Node, dns.name.Name]: + (node, name) = super()._maybe_cow_with_name(name) + node = cast(Node, node) + if self._is_origin(name): + node.flags |= NodeFlags.ORIGIN + elif self.delegations.is_glue(name): + node.flags |= NodeFlags.GLUE + return (node, name) + + def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None: + cursor = self.nodes.cursor() # type: ignore + cursor.seek(name, False) + updates = [] + while True: + elt = cursor.next() + if elt is None: + break + ename = elt.key() + if not ename.is_subdomain(name): + break + node = cast(dns.node.Node, elt.value()) + if ename not in self.changed: + new_node = self.zone.node_factory() + new_node.id = self.id # type: ignore + new_node.rdatasets.extend(node.rdatasets) + self.changed.add(ename) + node = new_node + assert isinstance(node, Node) + if is_glue: + node.flags |= NodeFlags.GLUE + else: + node.flags &= ~NodeFlags.GLUE + # We don't update node here as any insertion could disturb the + # btree and 
invalidate our cursor. We could use the cursor in a + # with block and avoid this, but it would do a lot of parking and + # unparking so the deferred update mode may still be better. + updates.append((ename, node)) + for ename, node in updates: + self.nodes[ename] = node + + def delete_node(self, name: dns.name.Name) -> None: + name = self._validate_name(name) + node = self.nodes.get(name) + if node is not None: + if node.is_delegation(): # type: ignore + self.delegations.discard(name) + self.update_glue_flag(name, False) + del self.nodes[name] + self.changed.add(name) + + def put_rdataset( + self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset + ) -> None: + (node, name) = self._maybe_cow_with_name(name) + if ( + rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue() # type: ignore + ): + node.flags |= NodeFlags.DELEGATION # type: ignore + if name not in self.delegations: + self.delegations.add(name) + self.update_glue_flag(name, True) + node.replace_rdataset(rdataset) + + def delete_rdataset( + self, + name: dns.name.Name, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType, + ) -> None: + (node, name) = self._maybe_cow_with_name(name) + if rdtype == dns.rdatatype.NS and name in self.delegations: # type: ignore + node.flags &= ~NodeFlags.DELEGATION # type: ignore + self.delegations.discard(name) # type: ignore + self.update_glue_flag(name, False) + node.delete_rdataset(self.zone.rdclass, rdtype, covers) + if len(node) == 0: + del self.nodes[name] + + +@dataclass(frozen=True) +class Bounds: + name: dns.name.Name + left: dns.name.Name + right: Optional[dns.name.Name] + closest_encloser: dns.name.Name + is_equal: bool + is_delegation: bool + + def __str__(self): + if self.is_equal: + op = "=" + else: + op = "<" + if self.is_delegation: + zonecut = " zonecut" + else: + zonecut = "" + return ( + f"{self.left} {op} {self.name} < {self.right}{zonecut}; " + f"{self.closest_encloser}" + ) + + +@dns.immutable.immutable +class 
ImmutableVersion(dns.zone.Version): + def __init__(self, version: dns.zone.Version): + if not isinstance(version, WritableVersion): + raise ValueError( + "a dns.btreezone.ImmutableVersion requires a " + "dns.btreezone.WritableVersion" + ) + super().__init__(version.zone, True) + self.id = version.id + self.origin = version.origin + for name in version.changed: + node = version.nodes.get(name) + if node: + version.nodes[name] = ImmutableNode(node) + # the cast below is for mypy + self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes) + self.nodes.make_immutable() # type: ignore + self.delegations = version.delegations + self.delegations.make_immutable() + + def bounds(self, name: Union[dns.name.Name, str]) -> Bounds: + """Return the 'bounds' of *name* in its zone. + + The bounds information is useful when making an authoritative response, as + it can be used to determine whether the query name is at or beneath a delegation + point. The other data in the ``Bounds`` object is useful for making on-the-fly + DNSSEC signatures. + + The left bound of *name* is *name* itself if it is in the zone, or the greatest + predecessor which is in the zone. + + The right bound of *name* is the least successor of *name*, or ``None`` if + no name in the zone is greater than *name*. + + The closest encloser of *name* is *name* itself, if *name* is in the zone; + otherwise it is the name with the largest number of labels in common with + *name* that is in the zone, either explicitly or by the implied existence + of empty non-terminals. + + The bounds *is_equal* field is ``True`` if and only if *name* is equal to + its left bound. + + The bounds *is_delegation* field is ``True`` if and only if the left bound is a + delegation point. 
+ """ + assert self.origin is not None + # validate the origin because we may need to relativize + origin = self.zone._validate_name(self.origin) + name = self.zone._validate_name(name) + cut, _ = self.delegations.get_delegation(name) + if cut is not None: + target = cut + is_delegation = True + else: + target = name + is_delegation = False + c = cast(dns.btree.BTreeDict, self.nodes).cursor() + c.seek(target, False) + left = c.prev() + assert left is not None + c.next() # skip over left + while True: + right = c.next() + if right is None or not right.value().is_glue(): + break + left_comparison = left.key().fullcompare(name) + if right is not None: + right_key = right.key() + right_comparison = right_key.fullcompare(name) + else: + right_comparison = ( + dns.name.NAMERELN_COMMONANCESTOR, + -1, + len(origin), + ) + right_key = None + closest_encloser = dns.name.Name( + name[-max(left_comparison[2], right_comparison[2]) :] + ) + return Bounds( + name, + left.key(), + right_key, + closest_encloser, + left_comparison[0] == dns.name.NameRelation.EQUAL, + is_delegation, + ) + + +class Zone(dns.versioned.Zone): + node_factory: Callable[[], dns.node.Node] = Node + map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast( + Callable[[], MutableMapping[dns.name.Name, dns.node.Node]], + dns.btree.BTreeDict[dns.name.Name, Node], + ) + writable_version_factory: Optional[ + Callable[[dns.zone.Zone, bool], dns.zone.Version] + ] = WritableVersion + immutable_version_factory: Optional[ + Callable[[dns.zone.Version], dns.zone.Version] + ] = ImmutableVersion diff --git a/dns/versioned.py b/dns/versioned.py index 6479ae47..260eea1b 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -41,7 +41,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] "_readers", ] - node_factory = Node + node_factory: Callable[[], dns.node.Node] = Node def __init__( self, diff --git a/dns/zone.py b/dns/zone.py index b1a52f63..05170fe8 100644 --- a/dns/zone.py +++ b/dns/zone.py 
@@ -131,8 +131,10 @@ class Zone(dns.transaction.TransactionManager): node_factory: Callable[[], dns.node.Node] = dns.node.Node map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict - writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None - immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None + # We only require the version types as "Version" to allow for flexibility, as + # only the version protocol matters + writable_version_factory: Optional[Callable[["Zone", bool], "Version"]] = None + immutable_version_factory: Optional[Callable[["Version"], "Version"]] = None __slots__ = ["rdclass", "origin", "nodes", "relativize"] @@ -1026,7 +1028,9 @@ class WritableVersion(Version): self.origin = zone.origin self.changed: Set[dns.name.Name] = set() - def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node: + def _maybe_cow_with_name( + self, name: dns.name.Name + ) -> Tuple[dns.node.Node, dns.name.Name]: name = self._validate_name(name) node = self.nodes.get(name) if node is None or name not in self.changed: @@ -1044,9 +1048,12 @@ class WritableVersion(Version): new_node.rdatasets.extend(node.rdatasets) self.nodes[name] = new_node self.changed.add(name) - return new_node + return (new_node, name) else: - return node + return (node, name) + + def _maybe_cow(self, name: dns.name.Name) -> dns.node.Node: + return self._maybe_cow_with_name(name)[0] def delete_node(self, name: dns.name.Name) -> None: name = self._validate_name(name) @@ -1074,7 +1081,11 @@ class WritableVersion(Version): @dns.immutable.immutable class ImmutableVersion(Version): - def __init__(self, version: WritableVersion): + def __init__(self, version: Version): + if not isinstance(version, WritableVersion): + raise ValueError( + "a dns.zone.ImmutableVersion requires a dns.zone.WritableVersion" + ) # We tell super() that it's a replacement as we don't want it # to copy the nodes, as we're about to do that with an # immutable Dict. 
diff --git a/tests/test_btree.py b/tests/test_btree.py new file mode 100644 index 00000000..e2453531 --- /dev/null +++ b/tests/test_btree.py @@ -0,0 +1,722 @@ +import copy + +import pytest + +import dns.btree as btree + + +class BTreeDict(btree.BTreeDict): + # We mostly test with an in-order optimized BTreeDict with t=3 as that's how we + # generated the data. + def __init__(self, *, t=3, original=None): + super().__init__(t=t, original=original, in_order=True) + + +def add_keys(b, keys): + if isinstance(keys, int): + keys = range(keys) + for key in keys: + b[key] = True + + +def test_replace(): + N = 8 + b = BTreeDict() + add_keys(b, N) + b[0] = False + b[5] = False + b[7] = False + for key in range(N): + if key in {0, 5, 7}: + assert b[key] == False + else: + assert b[key] == True + + +def test_min_max(): + N = 8 + b = BTreeDict() + add_keys(b, N) + assert b.root.minimum().key() == 0 + assert b.root.maximum().key() == N - 1 + del b[N - 1] + del b[0] + assert b.root.minimum().key() == 1 + assert b.root.maximum().key() == N - 2 + + +def test_nonexistent(): + N = 8 + b = BTreeDict() + add_keys(b, N) + with pytest.raises(KeyError): + b[1.5] == False + assert b.delete_key(1.5) is None + with pytest.raises(KeyError): + del b[1.5] + + +def test_in_order(): + N = 100 + b = BTreeDict() + add_keys(b, N) + expected = list(range(N)) + assert list(b.keys()) == expected + + keys = list(range(N - 1, -1, -1)) + expected = N + for key in keys: + l = len(b) + assert l == expected + expected -= 1 + del b[key] + assert len(b) == 0 + + +# Some key orderings generated randomly but hardcoded here for test stability. +# The keys lead to 100% coverage in insert, find, and delete. 
+ +random_keys_1 = [ + 36, + 14, + 89, + 67, + 80, + 98, + 71, + 29, + 92, + 91, + 79, + 49, + 63, + 74, + 19, + 4, + 23, + 60, + 10, + 31, + 94, + 46, + 18, + 84, + 61, + 42, + 77, + 54, + 76, + 38, + 26, + 37, + 24, + 99, + 45, + 7, + 97, + 32, + 53, + 96, + 82, + 52, + 8, + 58, + 11, + 3, + 15, + 47, + 17, + 21, + 28, + 2, + 20, + 12, + 95, + 44, + 16, + 9, + 51, + 30, + 33, + 34, + 88, + 55, + 43, + 72, + 57, + 66, + 22, + 56, + 68, + 87, + 73, + 6, + 25, + 59, + 0, + 75, + 90, + 78, + 50, + 13, + 83, + 93, + 39, + 81, + 41, + 70, + 48, + 35, + 65, + 64, + 62, + 5, + 27, + 86, + 40, + 1, + 85, + 69, +] + +random_keys_2 = [ + 49, + 28, + 0, + 19, + 14, + 76, + 65, + 8, + 12, + 90, + 71, + 36, + 31, + 24, + 83, + 59, + 98, + 48, + 26, + 82, + 46, + 84, + 80, + 33, + 74, + 75, + 60, + 99, + 20, + 61, + 88, + 81, + 41, + 58, + 85, + 54, + 96, + 23, + 72, + 66, + 1, + 37, + 57, + 64, + 27, + 13, + 40, + 73, + 69, + 32, + 55, + 34, + 5, + 2, + 39, + 9, + 93, + 50, + 47, + 92, + 79, + 78, + 63, + 10, + 30, + 77, + 87, + 53, + 7, + 56, + 21, + 18, + 62, + 6, + 11, + 95, + 70, + 44, + 42, + 97, + 35, + 91, + 43, + 16, + 89, + 45, + 67, + 4, + 22, + 17, + 25, + 51, + 94, + 52, + 68, + 3, + 15, + 86, + 38, + 29, +] + + +def test_random_trees(): + N = len(random_keys_1) + b = BTreeDict() + add_keys(b, random_keys_1) + expected = list(range(N)) + assert list(b.keys()) == expected + + for key in random_keys_1: + assert b[key] + + keys = random_keys_2 + expected_len = N + for key in keys: + l = len(b) + assert l == expected_len + expected_len -= 1 + del b[key] + assert len(b) == 0 + + +def test_random_trees_no_in_order_optimization(): + N = len(random_keys_1) + b = btree.BTreeDict(t=3) + add_keys(b, random_keys_1) + expected = list(range(N)) + assert list(b.keys()) == expected + + for key in random_keys_1: + assert b[key] + + keys = random_keys_2 + expected_len = N + for key in keys: + l = len(b) + assert l == expected_len + expected_len -= 1 + del b[key] + assert len(b) == 0 
+ + +def node_set(b): + s = set() + b._visit_preorder_by_node(lambda n: s.add(n)) + return s + + +def test_cow(): + N = len(random_keys_1) + b = BTreeDict() + add_keys(b, random_keys_1) + expected = list(range(N)) + assert list(b.keys()) == expected + nsb = node_set(b) + + with pytest.raises(ValueError): + d = BTreeDict(original=b) + b.make_immutable() + with pytest.raises(btree.Immutable): + b[100] = True + with pytest.raises(btree.Immutable): + del b[1] + + b2 = BTreeDict(original=b) + keys = random_keys_2 + expected_len = N + for key in keys: + l = len(b2) + assert l == expected_len + expected_len -= 1 + del b2[key] + assert len(b2) == 0 + b2[100] = True + b2[101] = True + assert list(b2.keys()) == [100, 101] + + # and b is unchanged + assert list(b.keys()) == expected + nsb2 = node_set(b) + assert nsb == nsb2 + + # copy should be the same as b + b3 = copy.copy(b) + assert list(b.keys()) == expected + nsb3 = node_set(b3) + assert nsb == nsb3 + + +def test_cow_minimality(): + b = BTreeDict() + add_keys(b, 8) + b.make_immutable() + b2 = BTreeDict(original=b) + + assert b.root is b2.root + b2[7] = 100 + assert b.root is not b2.root + assert b.root.children[0] is b2.root.children[0] + assert b.root.children[1] is not b2.root.children[1] + del b2[5] + assert b.root is not b2.root + assert b.root.children[0] is not b2.root.children[0] + assert b.root.children[1] is not b2.root.children[1] + + +def test_cursor_seek(): + N = len(random_keys_1) + b = BTreeDict() + add_keys(b, random_keys_1) + + l = [] + c = b.cursor() + while True: + elt = c.next() + if elt is None: + break + else: + l.append(elt.key()) + expected = list(range(N)) + assert l == expected + assert c.next() is None + + # same as previous but with explicit seek_first() + l = [] + c = b.cursor() + c.seek_first() + while True: + elt = c.next() + if elt is None: + break + else: + l.append(elt.key()) + expected = list(range(N)) + assert l == expected + assert c.next() is None + + l = [] + c = b.cursor() + 
c.seek_last() + while True: + elt = c.prev() + if elt is None: + break + else: + l.append(elt.key()) + expected = list(range(N - 1, -1, -1)) + assert l == expected + assert c.prev() is None + + +def test_cursor_seek_before_and_after(): + N = 8 + b = BTreeDict() + add_keys(b, N) + + c = b.cursor() + + # Seek before, leaf + c.seek(2) + assert c.next().key() == 2 + assert c.prev().key() == 2 + + # Seek before, parent + c.seek(5) + assert c.next().key() == 5 + c.seek(5) + assert c.prev().key() == 4 + + # Seek after, leaf + c.seek(2, False) + assert c.next().key() == 3 + c.seek(2, False) + assert c.prev().key() == 2 + + # Seek after, leaf + c.seek(2, False) + assert c.next().key() == 3 + c.seek(2, False) + assert c.prev().key() == 2 + + # Seek after, parent + c.seek(5, False) + assert c.next().key() == 6 + c.seek(5, False) + assert c.prev().key() == 5 + + # Nonexistent + c.seek(2.5) + assert c.next().key() == 3 + c.seek(2.5) + assert c.prev().key() == 2 + c.seek(4.5) + assert c.next().key() == 5 + c.seek(5.5) + assert c.prev().key() == 5 + + +def test_cursor_reversing_in_parentnode(): + N = 11 + b = BTreeDict() + add_keys(b, N) + c = b.cursor() + c.seek(5) + assert c.next().key() == 5 + assert c.prev().key() == 5 + c.seek(5, False) + assert c.prev().key() == 5 + assert c.next().key() == 5 + + +def test_cursor_empty_tree_seeks(): + b = BTreeDict() + c = b.cursor() + c.seek(5) + assert c.next() == None + assert c.prev() == None + + +def test_parking(): + N = 11 + b = BTreeDict() + add_keys(b, N) + c = b.cursor() + + c.seek(5) + assert c.next().key() == 5 + c.park() + assert c.next().key() == 6 + + c.seek(5) + assert c.prev().key() == 4 + c.park() + assert c.prev().key() == 3 + + c.seek(5) + assert c.next().key() == 5 + c.park() + assert c.prev().key() == 5 + + c.seek(5) + assert c.prev().key() == 4 + c.park() + assert c.next().key() == 4 + + expected = list(range(11)) + + c.seek_first() + got = [] + while True: + e = c.next() + if e is None: + break + got.append(e.key()) 
+ c.park() + assert got == expected + + c.seek_last() + got = [] + while True: + e = c.prev() + if e is None: + break + got.append(e.key()) + c.park() + assert got == list(reversed(expected)) + + # parking on the boundary + + c.seek_first() + c.park() + assert c.next().key() == 0 + + c.seek_last() + c.park() + assert c.prev().key() == 10 + + c.seek(0) + c.park() + assert c.next().key() == 0 + + c.seek(10, False) + c.park() + assert c.prev().key() == 10 + + # double parking (parking idempotency) + + c.seek_first() + c.park() + c.park() + assert c.next().key() == 0 + + # mutation + + c.seek(5) + c.park() + b[4.5] = True + assert c.next().key() == 5 + + c.seek(5) + c.park() + b[5.5] = True + assert c.next().key() == 5 + + c.seek(5) + c.park() + del b[5] + assert c.next().key() == 5.5 + + c.seek(5) + c.park() + b[4.49] = True + b[4.51] = True + assert c.prev().key() == 4.51 + + c.seek(5) + c.park() + b[5.49] = True + b[5.51] = True + assert c.next().key() == 5.49 + + +def test_automatic_parking(): + N = 11 + b = BTreeDict() + add_keys(b, N) + + with b.cursor() as c: + assert c in b.cursors + c.seek(5) + assert c.next().key() == 5 + b[5.5] = True + assert c.next().key() == 5.5 + assert len(b.cursors) == 0 + + +def test_visit_in_order(): + N = 100 + b = BTreeDict() + add_keys(b, N) + + l = [] + b.visit_in_order(lambda elt: l.append(elt.key())) + assert l == list(range(N)) + + +def test_visit_preorder_by_node(): + N = 100 + b = BTreeDict() + add_keys(b, N) + + kl = [] + b._visit_preorder_by_node(lambda node: kl.append([elt.key() for elt in node.elts])) + expected = [ + [35, 71], + [5, 11, 17, 23, 29], + [0, 1, 2, 3, 4], + [6, 7, 8, 9, 10], + [12, 13, 14, 15, 16], + [18, 19, 20, 21, 22], + [24, 25, 26, 27, 28], + [30, 31, 32, 33, 34], + [41, 47, 53, 59, 65], + [36, 37, 38, 39, 40], + [42, 43, 44, 45, 46], + [48, 49, 50, 51, 52], + [54, 55, 56, 57, 58], + [60, 61, 62, 63, 64], + [66, 67, 68, 69, 70], + [77, 83, 89, 95], + [72, 73, 74, 75, 76], + [78, 79, 80, 81, 82], + [84, 
85, 86, 87, 88], + [90, 91, 92, 93, 94], + [96, 97, 98, 99], + ] + assert kl == expected + + +def test_exact_delete(): + N = 8 + b = BTreeDict() + add_keys(b, N) + for key in [5, 0, 7]: + elt = b.get_element(key) + assert elt is not None + bogus_elt = btree.KV(key, False) + with pytest.raises(ValueError): + b.delete_exact(bogus_elt) + b.delete_exact(elt) + with pytest.raises(ValueError): + b.delete_exact(elt) + + +def test_find_nonexistent_node(): + # This is just for 100% coverage + b = BTreeDict() + b[0] = True + assert b.root._get_node(0) == (b.root, 0) + assert b.root._get_node(1) == (None, 0) + + +def test_t_too_small(): + with pytest.raises(ValueError): + BTreeDict(t=2) + + +def test_immutable_idempotent(): + # Again just for coverage. + b = BTreeDict() + b.make_immutable() + assert b._immutable + b.make_immutable() + assert b._immutable + + +def test_btree_set(): + b: btree.BTreeSet[int] = btree.BTreeSet() + b.add(1) + assert 1 in b + b.add(2) + assert 1 in b + assert 2 in b + assert len(b) == 2 + assert list(b) == [1, 2] + b.discard(1) + assert 1 not in b + assert 2 in b + assert len(b) == 1 + assert list(b) == [2] + b.discard(1) + assert 1 not in b + assert 2 in b + assert len(b) == 1 + b.discard(2) + assert 1 not in b + assert 2 not in b + assert len(b) == 0 + assert list(b) == [] diff --git a/tests/test_btreezone.py b/tests/test_btreezone.py new file mode 100644 index 00000000..b8eb4892 --- /dev/null +++ b/tests/test_btreezone.py @@ -0,0 +1,236 @@ +from typing import cast + +import dns.btreezone +import dns.rdataset +import dns.zone + +Node = dns.btreezone.Node + +simple_zone = """ +$ORIGIN example. 
+$TTL 300 +@ soa foo bar 1 2 3 4 5 +@ ns ns1 +@ ns ns2 +ns1 a 10.0.0.1 +ns2 a 10.0.0.2 +a txt "a" +c.b.a txt "cba" +b txt "b" +sub ns ns1.sub +sub ns ns2.sub +ns1.sub a 10.0.0.3 +ns2.sub a 10.0.0.4 +ns1.sub2 a 10.0.0.5 +ns2.sub2 a 10.0.0.6 +text txt "here to be after sub2" +z txt "z" +""" + + +def make_example(text: str, relativize: bool = False) -> dns.btreezone.Zone: + z = dns.zone.from_text( + simple_zone, "example.", relativize=relativize, zone_factory=dns.btreezone.Zone + ) + return cast(dns.btreezone.Zone, z) + + +def do_test_node_flags(relativize: bool): + z = make_example(simple_zone, relativize) + n = cast(Node, z.get_node("@")) + assert not n.is_delegation() + assert not n.is_glue() + assert n.is_origin() + assert n.is_origin_or_glue() + assert n.is_immutable() + n = cast(Node, z.get_node("sub")) + assert n.is_delegation() + assert not n.is_glue() + assert not n.is_origin() + assert not n.is_origin_or_glue() + n = cast(Node, z.get_node("ns1.sub")) + assert not n.is_delegation() + assert n.is_glue() + assert not n.is_origin() + assert n.is_origin_or_glue() + + +def test_node_flags_absolute(): + do_test_node_flags(False) + + +def test_node_flags_relative(): + do_test_node_flags(True) + + +def test_flags_in_constructor(): + n = Node() + assert n.flags == 0 + n = Node(dns.btreezone.NodeFlags.ORIGIN) + assert n.is_origin() + + +def do_test_obscure_and_expose(relativize: bool): + z = make_example(simple_zone, relativize=relativize) + n = cast(Node, z.get_node("ns1.sub2")) + assert not n.is_delegation() + assert not n.is_glue() + assert not n.is_origin() + assert not n.is_origin_or_glue() + sub2_name = z._validate_name("sub2") + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, txn.version) + assert sub2_name not in version.delegations + rds = dns.rdataset.from_text("in", "ns", 300, "ns1.sub2", "ns2.sub2") + with z.writer() as txn: + txn.replace("sub2", rds) + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, 
txn.version) + assert sub2_name in version.delegations + n = cast(Node, z.get_node("ns1.sub2")) + assert not n.is_delegation() + assert n.is_glue() + assert not n.is_origin() + assert n.is_origin_or_glue() + with z.writer() as txn: + txn.delete("sub2") + txn.delete("ns2.sub2") # for other coverage purposes! + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, txn.version) + assert sub2_name not in version.delegations + n = cast(Node, z.get_node("ns1.sub2")) + assert not n.is_delegation() + assert not n.is_glue() + assert not n.is_origin() + assert not n.is_origin_or_glue() + # repeat but delete just the rdataset + rds = dns.rdataset.from_text("in", "ns", 300, "ns1.sub2", "ns2.sub2") + with z.writer() as txn: + txn.replace("sub2", rds) + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, txn.version) + assert sub2_name in version.delegations + n = cast(Node, z.get_node("ns1.sub2")) + assert not n.is_delegation() + assert n.is_glue() + assert not n.is_origin() + assert n.is_origin_or_glue() + with z.writer() as txn: + txn.delete("sub2", "NS") + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, txn.version) + assert sub2_name not in version.delegations + n = cast(Node, z.get_node("ns1.sub2")) + assert not n.is_delegation() + assert not n.is_glue() + assert not n.is_origin() + assert not n.is_origin_or_glue() + + +def test_obscure_and_expose_absolute(): + do_test_obscure_and_expose(False) + + +def test_obscure_and_expose_relative(): + do_test_obscure_and_expose(True) + + +def do_test_delegations(relativize: bool): + z = make_example(simple_zone, relativize=relativize) + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, txn.version) + name = z._validate_name("a.b.c.sub.example.") + delegation, is_glue = version.delegations.get_delegation(name) + assert delegation == z._validate_name("sub.example.") + assert is_glue + assert version.delegations.is_glue(name) + name = 
z._validate_name("sub.example.") + delegation, is_glue = version.delegations.get_delegation(name) + assert delegation == z._validate_name("sub.example.") + assert not is_glue + assert not version.delegations.is_glue(name) + name = z._validate_name("text.example.") + delegation, is_glue = version.delegations.get_delegation(name) + assert delegation is None + assert not is_glue + assert not version.delegations.is_glue(name) + + +def test_delegations_absolute(): + do_test_delegations(False) + + +def test_delegations_relative(): + do_test_delegations(True) + + +def do_test_bounds(relativize: bool): + z = make_example(simple_zone, relativize=relativize) + with z.reader() as txn: + version = cast(dns.btreezone.ImmutableVersion, txn.version) + # tuple is (name, left, right, closest, is_equal, is_delegation) + tests = [ + ("example.", "example.", "a.example.", "example.", True, False), + ("a.z.example.", "z.example.", None, "z.example.", False, False), + ( + "a.b.a.example.", + "a.example.", + "c.b.a.example.", + "b.a.example.", + False, + False, + ), + ( + "d.b.a.example.", + "c.b.a.example.", + "b.example.", + "b.a.example.", + False, + False, + ), + ( + "d.c.b.a.example.", + "c.b.a.example.", + "b.example.", + "c.b.a.example.", + False, + False, + ), + ( + "sub.example.", + "sub.example.", + "ns1.sub2.example.", + "sub.example.", + True, + True, + ), + ( + "ns1.sub.example.", + "sub.example.", + "ns1.sub2.example.", + "sub.example.", + False, + True, + ), + ] + for name, left, right, closest, is_equal, is_delegation in tests: + name = z._validate_name(name) + left = z._validate_name(left) + if right is not None: + right = z._validate_name(right) + closest = z._validate_name(closest) + bounds = version.bounds(name) + print(bounds) + assert bounds.left == left + assert bounds.right == right + assert bounds.closest_encloser == closest + assert bounds.is_equal == is_equal + assert bounds.is_delegation == is_delegation + + +def test_bounds_absolute(): + do_test_bounds(False) 
+ + +def test_bounds_relative(): + do_test_bounds(True) diff --git a/tests/test_zone.py b/tests/test_zone.py index 35b5f42a..dc0a225e 100644 --- a/tests/test_zone.py +++ b/tests/test_zone.py @@ -23,6 +23,7 @@ import unittest from io import BytesIO, StringIO from typing import cast +import dns.btreezone import dns.exception import dns.message import dns.name @@ -1172,9 +1173,11 @@ class ZoneTestCase(unittest.TestCase): class VersionedZoneTestCase(unittest.TestCase): + zone_factory = dns.versioned.Zone + def testUseTransaction(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) with self.assertRaises(dns.versioned.UseTransaction): z.find_node("not_there", True) @@ -1191,7 +1194,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testImmutableNodes(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) node = z.find_node("@") with self.assertRaises(TypeError): @@ -1205,7 +1208,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testSelectDefaultPruningPolicy(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) z.set_pruning_policy(None) self.assertEqual(z._pruning_policy, z._default_pruning_policy) @@ -1219,28 +1222,28 @@ class VersionedZoneTestCase(unittest.TestCase): def testCannotSpecifyBothSerialAndVersionIdToReader(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) with self.assertRaises(ValueError): z.reader(1, 1) def testUnknownVersion(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, 
zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) with self.assertRaises(KeyError): z.reader(99999) def testUnknownSerial(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) with self.assertRaises(KeyError): z.reader(serial=99999) def testNoRelativizeReader(self): z = dns.zone.from_text( - example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=False, zone_factory=self.zone_factory ) with z.reader(serial=1) as txn: rds = txn.get("example.", "soa") @@ -1248,7 +1251,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testNoRelativizeReaderOriginInText(self): z = dns.zone.from_text( - example_text, relativize=False, zone_factory=dns.versioned.Zone + example_text, relativize=False, zone_factory=self.zone_factory ) with z.reader(serial=1) as txn: rds = txn.get("example.", "soa") @@ -1256,7 +1259,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testNoRelativizeReaderAbsoluteGet(self): z = dns.zone.from_text( - example_text, "example.", relativize=False, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=False, zone_factory=self.zone_factory ) with z.reader(serial=1) as txn: rds = txn.get(dns.name.empty, "soa") @@ -1264,7 +1267,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testCnameAndOtherDataAddOther(self): z = dns.zone.from_text( - example_cname, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_cname, "example.", relativize=True, zone_factory=self.zone_factory ) rds = dns.rdataset.from_text("in", "a", 300, "10.0.0.1") with z.writer() as txn: @@ -1290,7 +1293,7 @@ class VersionedZoneTestCase(unittest.TestCase): example_other_data, "example.", relativize=True, - zone_factory=dns.versioned.Zone, + zone_factory=self.zone_factory, ) rds = 
dns.rdataset.from_text("in", "cname", 300, "www") with z.writer() as txn: @@ -1305,7 +1308,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testGetSoa(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) soa = z.get_soa() self.assertTrue(soa.rdtype, dns.rdatatype.SOA) @@ -1313,7 +1316,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testGetSoaTxn(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) with z.reader(serial=1) as txn: soa = z.get_soa(txn) @@ -1327,7 +1330,7 @@ class VersionedZoneTestCase(unittest.TestCase): def testGetRdataset1(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) rds = z.get_rdataset("@", "soa") exrds = dns.rdataset.from_text("IN", "SOA", 300, "foo bar 1 2 3 4 5") @@ -1335,11 +1338,15 @@ class VersionedZoneTestCase(unittest.TestCase): def testGetRdataset2(self): z = dns.zone.from_text( - example_text, "example.", relativize=True, zone_factory=dns.versioned.Zone + example_text, "example.", relativize=True, zone_factory=self.zone_factory ) rds = z.get_rdataset("@", "loc") self.assertTrue(rds is None) +class BTreeZoneTestCase(VersionedZoneTestCase): + zone_factory = dns.btreezone.Zone + + if __name__ == "__main__": unittest.main()