]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add support for RFC 4471 predecessor() and successor() methods. (#1002)
authorBob Halley <halley@dnspython.org>
Sat, 4 Nov 2023 23:22:34 +0000 (16:22 -0700)
committerGitHub <noreply@github.com>
Sat, 4 Nov 2023 23:22:34 +0000 (16:22 -0700)
dns/name.py
doc/rfc.rst
tests/test_name.py

index f452bfed7f636724dc551a7c323ca400a1f4e2b7..2e44763ca8144a9b8caf626b01776cb7ea75e38e 100644 (file)
@@ -20,8 +20,9 @@
 
 import copy
 import encodings.idna  # type: ignore
+import functools
 import struct
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 try:
     import idna  # type: ignore
@@ -128,6 +129,10 @@ class IDNAException(dns.exception.DNSException):
         super().__init__(*args, **kwargs)
 
 
+class NeedSubdomainOfOrigin(dns.exception.DNSException):
+    """An absolute name was provided that is not a subdomain of the specified origin."""
+
+
 _escaped = b'"().;\\@$'
 _escaped_text = '"().;\\@$'
 
@@ -843,6 +848,42 @@ class Name:
             raise NoParent
         return Name(self.labels[1:])
 
+    def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
+        """Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
+        whose origin is *origin*, or return the longest name under *origin* if the
+        name is origin (i.e. wrap around to the longest name, which may still be
+        *origin* due to length considerations.
+
+        The relativity of the name is preserved, so if this name is relative
+        then the method will return a relative name, and likewise if this name
+        is absolute then the predecessor will be absolute.
+
+        *prefix_ok* indicates if prefixing labels is allowed, and
+        defaults to ``True``.  Normally it is good to allow this, but if computing
+        a maximal predecessor at a zone cut point then ``False`` must be specified.
+        """
+        return _handle_relativity_and_call(
+            _absolute_predecessor, self, origin, prefix_ok
+        )
+
+    def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
+        """Return the minimal successor of *name* in the DNSSEC ordering in the zone
+        whose origin is *origin*, or return *origin* if the successor cannot be
+        computed due to name length limitations.
+
+        Note that *origin* is returned in the "too long" cases because wrapping
+        around to the origin is how NSEC records express "end of the zone".
+
+        The relativity of the name is preserved, so if this name is relative
+        then the method will return a relative name, and likewise if this name
+        is absolute then the successor will be absolute.
+
+        *prefix_ok* indicates if prefixing a new minimal label is allowed, and
+        defaults to ``True``.  Normally it is good to allow this, but if computing
+        a minimal successor at a zone cut point then ``False`` must be specified.
+        """
+        return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
+
 
 #: The root name, '.'
 root = Name([b""])
@@ -1082,3 +1123,161 @@ def from_wire(message: bytes, current: int) -> Tuple[Name, int]:
     parser = dns.wire.Parser(message, current)
     name = from_wire_parser(parser)
     return (name, parser.current - current)
+
+
+# RFC 4471 Support
+
+_MINIMAL_OCTET = b"\x00"
+_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET)
+_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET])
+_MAXIMAL_OCTET = b"\xff"
+_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET)
+_AT_SIGN_VALUE = ord("@")
+_LEFT_SQUARE_BRACKET_VALUE = ord("[")
+
+
+def _wire_length(labels):
+    return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0)
+
+
+def _pad_to_max_name(name):
+    needed = 255 - _wire_length(name.labels)
+    new_labels = []
+    while needed > 64:
+        new_labels.append(_MAXIMAL_OCTET * 63)
+        needed -= 64
+    if needed >= 2:
+        new_labels.append(_MAXIMAL_OCTET * (needed - 1))
+    # Note we're already maximal in the needed == 1 case as while we'd like
+    # to add one more byte as a new label, we can't, as adding a new non-empty
+    # label requires at least 2 bytes.
+    new_labels = list(reversed(new_labels))
+    new_labels.extend(name.labels)
+    return Name(new_labels)
+
+
+def _pad_to_max_label(label, suffix_labels):
+    length = len(label)
+    # We have to subtract one here to account for the length byte of label.
+    remaining = 255 - _wire_length(suffix_labels) - length - 1
+    if remaining <= 0:
+        # Shouldn't happen!
+        return label
+    needed = min(63 - length, remaining)
+    return label + _MAXIMAL_OCTET * needed
+
+
+def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name:
+    # This is the RFC 4471 predecessor algorithm using the "absolute method" of section
+    # 3.1.1.
+    #
+    # Our caller must ensure that the name and origin are absolute, and that name is a
+    # subdomain of origin.
+    if name == origin:
+        return _pad_to_max_name(name)
+    least_significant_label = name[0]
+    if least_significant_label == _MINIMAL_OCTET:
+        return name.parent()
+    least_octet = least_significant_label[-1]
+    suffix_labels = name.labels[1:]
+    if least_octet == _MINIMAL_OCTET_VALUE:
+        new_labels = [least_significant_label[:-1]]
+    else:
+        octets = bytearray(least_significant_label)
+        octet = octets[-1]
+        if octet == _LEFT_SQUARE_BRACKET_VALUE:
+            octet = _AT_SIGN_VALUE
+        else:
+            octet -= 1
+        octets[-1] = octet
+        least_significant_label = bytes(octets)
+        new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)]
+    new_labels.extend(suffix_labels)
+    name = Name(new_labels)
+    if prefix_ok:
+        return _pad_to_max_name(name)
+    else:
+        return name
+
+
+def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name:
+    # This is the RFC 4471 successor algorithm using the "absolute method" of section
+    # 3.1.2.
+    #
+    # Our caller must ensure that the name and origin are absolute, and that name is a
+    # subdomain of origin.
+    if prefix_ok:
+        # Try prefixing \000 as new label
+        try:
+            return _SUCCESSOR_PREFIX.concatenate(name)
+        except NameTooLong:
+            pass
+    while name != origin:
+        # Try extending the least significant label.
+        least_significant_label = name[0]
+        if len(least_significant_label) < 63:
+            # We may be able to extend the least label with a minimal additional byte.
+            # This is only "may" because we could have a maximal length name even though
+            # the least significant label isn't maximally long.
+            new_labels = [least_significant_label + _MINIMAL_OCTET]
+            new_labels.extend(name.labels[1:])
+            try:
+                return dns.name.Name(new_labels)
+            except dns.name.NameTooLong:
+                pass
+        # We can't extend the label either, so we'll try to increment the least
+        # signficant non-maximal byte in it.
+        octets = bytearray(least_significant_label)
+        # We do this reversed iteration with an explicit indexing variable because
+        # if we find something to increment, we're going to want to truncate everything
+        # to the right of it.
+        for i in range(len(octets) - 1, -1, -1):
+            octet = octets[i]
+            if octet == _MAXIMAL_OCTET_VALUE:
+                # We can't increment this, so keep looking.
+                continue
+            # Finally, something we can increment.  We have to apply a special rule for
+            # incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
+            # comparing names, uppercase letters compare as if they were their
+            # lower-case equivalents. If we increment "@" to "A", then it would compare
+            # as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
+            # skipped the most minimal successor, namely "[".
+            if octet == _AT_SIGN_VALUE:
+                octet = _LEFT_SQUARE_BRACKET_VALUE
+            else:
+                octet += 1
+            octets[i] = octet
+            # We can now truncate all of the maximal values we skipped (if any)
+            new_labels = [bytes(octets[: i + 1])]
+            new_labels.extend(name.labels[1:])
+            # We haven't changed the length of the name, so the Name constructor will
+            # always work.
+            return Name(new_labels)
+        # We couldn't increment, so chop off the least significant label and try
+        # again.
+        name = name.parent()
+
+    # We couldn't increment at all, so return the origin, as wrapping around is the
+    # DNSSEC way.
+    return origin
+
+
+def _handle_relativity_and_call(
+    function: Callable[[Name, Name, bool], Name],
+    name: Name,
+    origin: Name,
+    prefix_ok: bool,
+) -> Name:
+    # Make "name" absolute if needed, ensure that the origin is absolute,
+    # call function(), and then relativize the result if needed.
+    if not origin.is_absolute():
+        raise NeedAbsoluteNameOrOrigin
+    relative = not name.is_absolute()
+    if relative:
+        name = name.derelativize(origin)
+    elif not name.is_subdomain(origin):
+        raise NeedSubdomainOfOrigin
+    result_name = function(name, origin, prefix_ok)
+    if relative:
+        result_name = result_name.relativize(origin)
+    return result_name
index 23212d8226e1ee7f43d0a1a29ebeeeb1951373ea..9e98bfa640d3240dd2dd10b1e4b0fb0e3f6e047d 100644 (file)
@@ -81,6 +81,9 @@ DNSSEC RFCs
 `RFC 4470 <https://tools.ietf.org/html/rfc4470>`_
     Minimally covering NSEC records and On-line Signing.
 
+`RFC 4471 <https://tools.ietf.org/html/rfc4471>`_
+    Derivation of DNS Name Predecessor and Successor.
+
 `RFC 5155 <https://tools.ietf.org/html/rfc5155>`_
     DNS Security (DNSSEC) Hashed Authenticated Denial of Existence.  [NSEC3]
 
index 1d18ce3e3ae4ec725e0ff987b71da36d4edba068..96c0345960225ad078167a36e43e74e435d00868 100644 (file)
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-from typing import Dict  # pylint: disable=unused-import
 import copy
 import operator
 import pickle
 import unittest
-
 from io import BytesIO
+from typing import Dict  # pylint: disable=unused-import
 
+import dns.e164
 import dns.name
 import dns.reversename
-import dns.e164
 
 # pylint: disable=line-too-long,unsupported-assignment-operation
 
 
+def expand(text):
+    # This is a helper routine to expand name patterns for RFC 4471 tests.
+    #
+    # Basically it turns <character>{<n>} into <n> instances of the character.
+    # For example:
+    #
+    #       r"fo{2}.example." => r"foo.example.".
+    #
+    # Two characters get special treatment:
+    # "-" is mapped to r"\000" and "+" is mapped to r"\255".  For example
+    #
+    #       r"+{3}-.example." -> r"\255\255\255\000.example."
+    #
+    # We do this just to make parsing simpler, so we don't have to process escapes
+    # ourselves.
+    i = 0
+    l = len(text)
+    previous = ""
+    reading_count = False
+    count = 0
+    expanded = []
+    for c in text:
+        if c == "-":
+            c = r"\000"
+        elif c == "+":
+            c = r"\255"
+        if reading_count:
+            assert len(c) == 1
+            if c >= "0" and c <= "9":
+                count *= 10
+                count += ord(c) - ord("0")
+            elif c == "}":
+                expanded.append(previous * count)
+                previous = ""
+                reading_count = False
+                count = 0
+        elif c == "{":
+            reading_count = True
+        else:
+            expanded.append(previous)
+            previous = c
+    # don't forget the last char (if there is one)
+    expanded.append(previous)
+    x = "".join(expanded)
+    return x
+
+
 class NameTestCase(unittest.TestCase):
     def setUp(self):
         self.origin = dns.name.from_text("example.")
@@ -1139,6 +1185,104 @@ class NameTestCase(unittest.TestCase):
         n2 = pickle.loads(p)
         self.assertEqual(n1, n2)
 
+    def test_pad_to_max_name(self):
+        # Test edge cases in our padding helper.
+        tests = [
+            ("o{61}.o{63}.o{63}.o{63}.", "o{61}.o{63}.o{63}.o{63}."),
+            ("o{60}.o{63}.o{63}.o{63}.", "o{60}.o{63}.o{63}.o{63}."),
+            ("o{59}.o{63}.o{63}.o{63}.", "+.o{59}.o{63}.o{63}.o{63}."),
+            ("o{63}.o{63}.o{63}.", "+{61}.o{63}.o{63}.o{63}."),
+            ("o{63}.o{63}.", "+{61}.+{63}.o{63}.o{63}."),
+        ]
+        for name_text, expected_text in tests:
+            name = dns.name.from_text(expand(name_text))
+            expected = dns.name.from_text(expand(expected_text))
+            self.assertEqual(dns.name._pad_to_max_name(name), expected)
+
+    def test_predecessors_and_successors(self):
+        # Test RFC 4471 predecessor and successor methods.
+
+        # Here we're actually testing the test suite, but expand() is complicated enough
+        # to deserve a little testing.
+        self.assertEqual(expand("f+{3}-o.example."), r"f\255\255\255\000o.example.")
+
+        # Ok, now test successors!
+        origin = dns.name.from_text("example.com.")
+        tests = [
+            # Examples from the RFC.
+            ("foo", True, "\\000.foo"),
+            ("foo", False, "foo\\000"),
+            # The syntax here is almost the RFC's "alternate syntax"  except that to
+            # make expand simpler, i.e. not have to understand that \000 was one octet,
+            # I made the convention that "-" means 0 and "+" means 255.  We use raw
+            # string constants where needed to avoid escaping backslash.
+            (
+                "fo{47}.o{63}.o{63}.o{63}",
+                True,
+                "fo{47}-.o{63}.o{63}.o{63}",
+            ),
+            (
+                "fo{48}.o{63}.o{63}.o{63}",
+                True,
+                "fo{47}p.o{63}.o{63}.o{63}",
+            ),
+            ("+{49}.o{63}.o{63}.o{63}", True, "o{62}p.o{63}.o{63}"),
+            (
+                "fo{40}+{8}.o{63}.o{63}.o{63}",
+                True,
+                "fo{39}p.o{63}.o{63}.o{63}",
+            ),
+            (
+                r"fo{47}\@.o{63}.o{63}.o{63}",
+                True,
+                r"fo{47}\[.o{63}.o{63}.o{63}",
+            ),
+            ("+{49}.+{63}.+{63}.+{63}", True, ""),
+            # Some more tests not in the RFC
+            ("+{49}.+{63}.o{63}.o{63}", True, "o{62}p.o{63}"),
+            (
+                "+{48}.o{63}.o{63}.o{63}",
+                True,
+                "+{48}-.o{63}.o{63}.o{63}",
+            ),
+        ]
+        for test_origin in [origin, None]:
+            for name_text, prefix_ok, expected_successor_text in tests:
+                name = dns.name.from_text(expand(name_text), test_origin)
+                expected_successor = dns.name.from_text(
+                    expand(expected_successor_text), test_origin
+                )
+                successor = name.successor(origin, prefix_ok)
+                self.assertEqual(successor, expected_successor)
+                self.assertTrue(
+                    successor > name
+                    or successor == origin
+                    or successor == dns.name.empty
+                )
+                # Now test the predecessor
+                predecessor = successor.predecessor(origin, prefix_ok)
+                self.assertEqual(predecessor, name)
+
+        # Finally, test that a maximal length origin is its own predecessor and
+        # successor.
+        origin = dns.name.from_text(expand("+{49}.+{63}.o{63}.o{63}.example.com."))
+        assert origin.successor(origin, True) == origin
+        assert origin.predecessor(origin, True) == origin
+
+    def test_predecessor_and_successor_errors(self):
+        name = dns.name.from_text("name", None)
+        origin = dns.name.from_text("origin", None)  # note Relative!
+        with self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin):
+            name.successor(origin, True)
+        with self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin):
+            name.predecessor(origin, True)
+        name = dns.name.from_text("name.")  # Note absolute
+        origin = dns.name.from_text("origin.")  # 'name' is not a subdomain of 'origin'
+        with self.assertRaises(dns.name.NeedSubdomainOfOrigin):
+            name.successor(origin, True)
+        with self.assertRaises(dns.name.NeedSubdomainOfOrigin):
+            name.predecessor(origin, True)
+
 
 if __name__ == "__main__":
     unittest.main()