From: Bob Halley Date: Sat, 4 Nov 2023 23:22:34 +0000 (-0700) Subject: Add support for RFC 4471 predecessor() and successor() methods. (#1002) X-Git-Tag: v2.5.0rc1~31 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c007d40ec82b29d4ff6288770fe3d3a88b4094bd;p=thirdparty%2Fdnspython.git Add support for RFC 4471 predecessor() and successor() methods. (#1002) --- diff --git a/dns/name.py b/dns/name.py index f452bfed..2e44763c 100644 --- a/dns/name.py +++ b/dns/name.py @@ -20,8 +20,9 @@ import copy import encodings.idna # type: ignore +import functools import struct -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union try: import idna # type: ignore @@ -128,6 +129,10 @@ class IDNAException(dns.exception.DNSException): super().__init__(*args, **kwargs) +class NeedSubdomainOfOrigin(dns.exception.DNSException): + """An absolute name was provided that is not a subdomain of the specified origin.""" + + _escaped = b'"().;\\@$' _escaped_text = '"().;\\@$' @@ -843,6 +848,42 @@ class Name: raise NoParent return Name(self.labels[1:]) + def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name": + """Return the maximal predecessor of *name* in the DNSSEC ordering in the zone + whose origin is *origin*, or return the longest name under *origin* if the + name is origin (i.e. wrap around to the longest name, which may still be + *origin* due to length considerations. + + The relativity of the name is preserved, so if this name is relative + then the method will return a relative name, and likewise if this name + is absolute then the predecessor will be absolute. + + *prefix_ok* indicates if prefixing labels is allowed, and + defaults to ``True``. Normally it is good to allow this, but if computing + a maximal predecessor at a zone cut point then ``False`` must be specified. + """ + return _handle_relativity_and_call( + _absolute_predecessor, self, origin, prefix_ok + ) + + def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name": + """Return the minimal successor of *name* in the DNSSEC ordering in the zone + whose origin is *origin*, or return *origin* if the successor cannot be + computed due to name length limitations. + + Note that *origin* is returned in the "too long" cases because wrapping + around to the origin is how NSEC records express "end of the zone". + + The relativity of the name is preserved, so if this name is relative + then the method will return a relative name, and likewise if this name + is absolute then the successor will be absolute. + + *prefix_ok* indicates if prefixing a new minimal label is allowed, and + defaults to ``True``. Normally it is good to allow this, but if computing + a minimal successor at a zone cut point then ``False`` must be specified. + """ + return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok) + #: The root name, '.' root = Name([b""]) @@ -1082,3 +1123,161 @@ def from_wire(message: bytes, current: int) -> Tuple[Name, int]: parser = dns.wire.Parser(message, current) name = from_wire_parser(parser) return (name, parser.current - current) + + +# RFC 4471 Support + +_MINIMAL_OCTET = b"\x00" +_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET) +_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET]) +_MAXIMAL_OCTET = b"\xff" +_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET) +_AT_SIGN_VALUE = ord("@") +_LEFT_SQUARE_BRACKET_VALUE = ord("[") + + +def _wire_length(labels): + return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0) + + +def _pad_to_max_name(name): + needed = 255 - _wire_length(name.labels) + new_labels = [] + while needed > 64: + new_labels.append(_MAXIMAL_OCTET * 63) + needed -= 64 + if needed >= 2: + new_labels.append(_MAXIMAL_OCTET * (needed - 1)) + # Note we're already maximal in the needed == 1 case as while we'd like + # to add one more byte as a new label, we can't, as adding a new non-empty + # label requires at least 2 bytes. + new_labels = list(reversed(new_labels)) + new_labels.extend(name.labels) + return Name(new_labels) + + +def _pad_to_max_label(label, suffix_labels): + length = len(label) + # We have to subtract one here to account for the length byte of label. + remaining = 255 - _wire_length(suffix_labels) - length - 1 + if remaining <= 0: + # Shouldn't happen! + return label + needed = min(63 - length, remaining) + return label + _MAXIMAL_OCTET * needed + + +def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name: + # This is the RFC 4471 predecessor algorithm using the "absolute method" of section + # 3.1.1. + # + # Our caller must ensure that the name and origin are absolute, and that name is a + # subdomain of origin. + if name == origin: + return _pad_to_max_name(name) + least_significant_label = name[0] + if least_significant_label == _MINIMAL_OCTET: + return name.parent() + least_octet = least_significant_label[-1] + suffix_labels = name.labels[1:] + if least_octet == _MINIMAL_OCTET_VALUE: + new_labels = [least_significant_label[:-1]] + else: + octets = bytearray(least_significant_label) + octet = octets[-1] + if octet == _LEFT_SQUARE_BRACKET_VALUE: + octet = _AT_SIGN_VALUE + else: + octet -= 1 + octets[-1] = octet + least_significant_label = bytes(octets) + new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)] + new_labels.extend(suffix_labels) + name = Name(new_labels) + if prefix_ok: + return _pad_to_max_name(name) + else: + return name + + +def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name: + # This is the RFC 4471 successor algorithm using the "absolute method" of section + # 3.1.2. + # + # Our caller must ensure that the name and origin are absolute, and that name is a + # subdomain of origin. + if prefix_ok: + # Try prefixing \000 as new label + try: + return _SUCCESSOR_PREFIX.concatenate(name) + except NameTooLong: + pass + while name != origin: + # Try extending the least significant label. + least_significant_label = name[0] + if len(least_significant_label) < 63: + # We may be able to extend the least label with a minimal additional byte. + # This is only "may" because we could have a maximal length name even though + # the least significant label isn't maximally long. + new_labels = [least_significant_label + _MINIMAL_OCTET] + new_labels.extend(name.labels[1:]) + try: + return dns.name.Name(new_labels) + except dns.name.NameTooLong: + pass + # We can't extend the label either, so we'll try to increment the least + # signficant non-maximal byte in it. + octets = bytearray(least_significant_label) + # We do this reversed iteration with an explicit indexing variable because + # if we find something to increment, we're going to want to truncate everything + # to the right of it. + for i in range(len(octets) - 1, -1, -1): + octet = octets[i] + if octet == _MAXIMAL_OCTET_VALUE: + # We can't increment this, so keep looking. + continue + # Finally, something we can increment. We have to apply a special rule for + # incrementing "@", sending it to "[", because RFC 4034 6.1 says that when + # comparing names, uppercase letters compare as if they were their + # lower-case equivalents. If we increment "@" to "A", then it would compare + # as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have + # skipped the most minimal successor, namely "[". + if octet == _AT_SIGN_VALUE: + octet = _LEFT_SQUARE_BRACKET_VALUE + else: + octet += 1 + octets[i] = octet + # We can now truncate all of the maximal values we skipped (if any) + new_labels = [bytes(octets[: i + 1])] + new_labels.extend(name.labels[1:]) + # We haven't changed the length of the name, so the Name constructor will + # always work. + return Name(new_labels) + # We couldn't increment, so chop off the least significant label and try + # again. + name = name.parent() + + # We couldn't increment at all, so return the origin, as wrapping around is the + # DNSSEC way. + return origin + + +def _handle_relativity_and_call( + function: Callable[[Name, Name, bool], Name], + name: Name, + origin: Name, + prefix_ok: bool, +) -> Name: + # Make "name" absolute if needed, ensure that the origin is absolute, + # call function(), and then relativize the result if needed. + if not origin.is_absolute(): + raise NeedAbsoluteNameOrOrigin + relative = not name.is_absolute() + if relative: + name = name.derelativize(origin) + elif not name.is_subdomain(origin): + raise NeedSubdomainOfOrigin + result_name = function(name, origin, prefix_ok) + if relative: + result_name = result_name.relativize(origin) + return result_name diff --git a/doc/rfc.rst b/doc/rfc.rst index 23212d82..9e98bfa6 100644 --- a/doc/rfc.rst +++ b/doc/rfc.rst @@ -81,6 +81,9 @@ DNSSEC RFCs `RFC 4470 `_ Minimally covering NSEC records and On-line Signing. +`RFC 4471 `_ + Derivation of DNS Name Predecessor and Successor. + `RFC 5155 `_ DNS Security (DNSSEC) Hashed Authenticated Denial of Existence. [NSEC3] diff --git a/tests/test_name.py b/tests/test_name.py index 1d18ce3e..96c03459 100644 --- a/tests/test_name.py +++ b/tests/test_name.py @@ -16,21 +16,67 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -from typing import Dict # pylint: disable=unused-import import copy import operator import pickle import unittest - from io import BytesIO +from typing import Dict # pylint: disable=unused-import +import dns.e164 import dns.name import dns.reversename -import dns.e164 # pylint: disable=line-too-long,unsupported-assignment-operation +def expand(text): + # This is a helper routine to expand name patterns for RFC 4471 tests. + # + # Basically it turns {} into instances of the character. + # For example: + # + # r"fo{2}.example." => r"foo.example.". + # + # Two characters get special treatment: + # "-" is mapped to r"\000" and "+" is mapped to r"\255". For example + # + # r"+{3}-.example." -> r"\255\255\255\000.example." + # + # We do this just to make parsing simpler, so we don't have to process escapes + # ourselves. + i = 0 + l = len(text) + previous = "" + reading_count = False + count = 0 + expanded = [] + for c in text: + if c == "-": + c = r"\000" + elif c == "+": + c = r"\255" + if reading_count: + assert len(c) == 1 + if c >= "0" and c <= "9": + count *= 10 + count += ord(c) - ord("0") + elif c == "}": + expanded.append(previous * count) + previous = "" + reading_count = False + count = 0 + elif c == "{": + reading_count = True + else: + expanded.append(previous) + previous = c + # don't forget the last char (if there is one) + expanded.append(previous) + x = "".join(expanded) + return x + + class NameTestCase(unittest.TestCase): def setUp(self): self.origin = dns.name.from_text("example.") @@ -1139,6 +1185,104 @@ class NameTestCase(unittest.TestCase): n2 = pickle.loads(p) self.assertEqual(n1, n2) + def test_pad_to_max_name(self): + # Test edge cases in our padding helper. + tests = [ + ("o{61}.o{63}.o{63}.o{63}.", "o{61}.o{63}.o{63}.o{63}."), + ("o{60}.o{63}.o{63}.o{63}.", "o{60}.o{63}.o{63}.o{63}."), + ("o{59}.o{63}.o{63}.o{63}.", "+.o{59}.o{63}.o{63}.o{63}."), + ("o{63}.o{63}.o{63}.", "+{61}.o{63}.o{63}.o{63}."), + ("o{63}.o{63}.", "+{61}.+{63}.o{63}.o{63}."), + ] + for name_text, expected_text in tests: + name = dns.name.from_text(expand(name_text)) + expected = dns.name.from_text(expand(expected_text)) + self.assertEqual(dns.name._pad_to_max_name(name), expected) + + def test_predecessors_and_successors(self): + # Test RFC 4471 predecessor and successor methods. + + # Here we're actually testing the test suite, but expand() is complicated enough + # to deserve a little testing. + self.assertEqual(expand("f+{3}-o.example."), r"f\255\255\255\000o.example.") + + # Ok, now test successors! + origin = dns.name.from_text("example.com.") + tests = [ + # Examples from the RFC. + ("foo", True, "\\000.foo"), + ("foo", False, "foo\\000"), + # The syntax here is almost the RFC's "alternate syntax" except that to + # make expand simpler, i.e. not have to understand that \000 was one octet, + # I made the convention that "-" means 0 and "+" means 255. We use raw + # string constants where needed to avoid escaping backslash. + ( + "fo{47}.o{63}.o{63}.o{63}", + True, + "fo{47}-.o{63}.o{63}.o{63}", + ), + ( + "fo{48}.o{63}.o{63}.o{63}", + True, + "fo{47}p.o{63}.o{63}.o{63}", + ), + ("+{49}.o{63}.o{63}.o{63}", True, "o{62}p.o{63}.o{63}"), + ( + "fo{40}+{8}.o{63}.o{63}.o{63}", + True, + "fo{39}p.o{63}.o{63}.o{63}", + ), + ( + r"fo{47}\@.o{63}.o{63}.o{63}", + True, + r"fo{47}\[.o{63}.o{63}.o{63}", + ), + ("+{49}.+{63}.+{63}.+{63}", True, ""), + # Some more tests not in the RFC + ("+{49}.+{63}.o{63}.o{63}", True, "o{62}p.o{63}"), + ( + "+{48}.o{63}.o{63}.o{63}", + True, + "+{48}-.o{63}.o{63}.o{63}", + ), + ] + for test_origin in [origin, None]: + for name_text, prefix_ok, expected_successor_text in tests: + name = dns.name.from_text(expand(name_text), test_origin) + expected_successor = dns.name.from_text( + expand(expected_successor_text), test_origin + ) + successor = name.successor(origin, prefix_ok) + self.assertEqual(successor, expected_successor) + self.assertTrue( + successor > name + or successor == origin + or successor == dns.name.empty + ) + # Now test the predecessor + predecessor = successor.predecessor(origin, prefix_ok) + self.assertEqual(predecessor, name) + + # Finally, test that a maximal length origin is its own predecessor and + # successor. + origin = dns.name.from_text(expand("+{49}.+{63}.o{63}.o{63}.example.com.")) + assert origin.successor(origin, True) == origin + assert origin.predecessor(origin, True) == origin + + def test_predecessor_and_successor_errors(self): + name = dns.name.from_text("name", None) + origin = dns.name.from_text("origin", None) # note Relative! + with self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin): + name.successor(origin, True) + with self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin): + name.predecessor(origin, True) + name = dns.name.from_text("name.") # Note absolute + origin = dns.name.from_text("origin.") # 'name' is not a subdomain of 'origin' + with self.assertRaises(dns.name.NeedSubdomainOfOrigin): + name.successor(origin, True) + with self.assertRaises(dns.name.NeedSubdomainOfOrigin): + name.predecessor(origin, True) + if __name__ == "__main__": unittest.main()