import copy
import encodings.idna # type: ignore
+import functools
import struct
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
try:
import idna # type: ignore
super().__init__(*args, **kwargs)
+class NeedSubdomainOfOrigin(dns.exception.DNSException):
+ """An absolute name was provided that is not a subdomain of the specified origin."""
+
+
_escaped = b'"().;\\@$'
_escaped_text = '"().;\\@$'
raise NoParent
return Name(self.labels[1:])
+ def predecessor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
+ """Return the maximal predecessor of *name* in the DNSSEC ordering in the zone
+ whose origin is *origin*, or return the longest name under *origin* if the
+ name is origin (i.e. wrap around to the longest name, which may still be
+ *origin* due to length considerations.
+
+ The relativity of the name is preserved, so if this name is relative
+ then the method will return a relative name, and likewise if this name
+ is absolute then the predecessor will be absolute.
+
+ *prefix_ok* indicates if prefixing labels is allowed, and
+ defaults to ``True``. Normally it is good to allow this, but if computing
+ a maximal predecessor at a zone cut point then ``False`` must be specified.
+ """
+ return _handle_relativity_and_call(
+ _absolute_predecessor, self, origin, prefix_ok
+ )
+
+ def successor(self, origin: "Name", prefix_ok: bool = True) -> "Name":
+ """Return the minimal successor of *name* in the DNSSEC ordering in the zone
+ whose origin is *origin*, or return *origin* if the successor cannot be
+ computed due to name length limitations.
+
+ Note that *origin* is returned in the "too long" cases because wrapping
+ around to the origin is how NSEC records express "end of the zone".
+
+ The relativity of the name is preserved, so if this name is relative
+ then the method will return a relative name, and likewise if this name
+ is absolute then the successor will be absolute.
+
+ *prefix_ok* indicates if prefixing a new minimal label is allowed, and
+ defaults to ``True``. Normally it is good to allow this, but if computing
+ a minimal successor at a zone cut point then ``False`` must be specified.
+ """
+ return _handle_relativity_and_call(_absolute_successor, self, origin, prefix_ok)
+
#: The root name, '.'
root = Name([b""])
parser = dns.wire.Parser(message, current)
name = from_wire_parser(parser)
return (name, parser.current - current)
+
+
+# RFC 4471 Support
+
+_MINIMAL_OCTET = b"\x00"
+_MINIMAL_OCTET_VALUE = ord(_MINIMAL_OCTET)
+_SUCCESSOR_PREFIX = Name([_MINIMAL_OCTET])
+_MAXIMAL_OCTET = b"\xff"
+_MAXIMAL_OCTET_VALUE = ord(_MAXIMAL_OCTET)
+_AT_SIGN_VALUE = ord("@")
+_LEFT_SQUARE_BRACKET_VALUE = ord("[")
+
+
+def _wire_length(labels):
+ return functools.reduce(lambda v, x: v + len(x) + 1, labels, 0)
+
+
+def _pad_to_max_name(name):
+ needed = 255 - _wire_length(name.labels)
+ new_labels = []
+ while needed > 64:
+ new_labels.append(_MAXIMAL_OCTET * 63)
+ needed -= 64
+ if needed >= 2:
+ new_labels.append(_MAXIMAL_OCTET * (needed - 1))
+ # Note we're already maximal in the needed == 1 case as while we'd like
+ # to add one more byte as a new label, we can't, as adding a new non-empty
+ # label requires at least 2 bytes.
+ new_labels = list(reversed(new_labels))
+ new_labels.extend(name.labels)
+ return Name(new_labels)
+
+
+def _pad_to_max_label(label, suffix_labels):
+ length = len(label)
+ # We have to subtract one here to account for the length byte of label.
+ remaining = 255 - _wire_length(suffix_labels) - length - 1
+ if remaining <= 0:
+ # Shouldn't happen!
+ return label
+ needed = min(63 - length, remaining)
+ return label + _MAXIMAL_OCTET * needed
+
+
+def _absolute_predecessor(name: Name, origin: Name, prefix_ok: bool) -> Name:
+ # This is the RFC 4471 predecessor algorithm using the "absolute method" of section
+ # 3.1.1.
+ #
+ # Our caller must ensure that the name and origin are absolute, and that name is a
+ # subdomain of origin.
+ if name == origin:
+ return _pad_to_max_name(name)
+ least_significant_label = name[0]
+ if least_significant_label == _MINIMAL_OCTET:
+ return name.parent()
+ least_octet = least_significant_label[-1]
+ suffix_labels = name.labels[1:]
+ if least_octet == _MINIMAL_OCTET_VALUE:
+ new_labels = [least_significant_label[:-1]]
+ else:
+ octets = bytearray(least_significant_label)
+ octet = octets[-1]
+ if octet == _LEFT_SQUARE_BRACKET_VALUE:
+ octet = _AT_SIGN_VALUE
+ else:
+ octet -= 1
+ octets[-1] = octet
+ least_significant_label = bytes(octets)
+ new_labels = [_pad_to_max_label(least_significant_label, suffix_labels)]
+ new_labels.extend(suffix_labels)
+ name = Name(new_labels)
+ if prefix_ok:
+ return _pad_to_max_name(name)
+ else:
+ return name
+
+
+def _absolute_successor(name: Name, origin: Name, prefix_ok: bool) -> Name:
+ # This is the RFC 4471 successor algorithm using the "absolute method" of section
+ # 3.1.2.
+ #
+ # Our caller must ensure that the name and origin are absolute, and that name is a
+ # subdomain of origin.
+ if prefix_ok:
+ # Try prefixing \000 as new label
+ try:
+ return _SUCCESSOR_PREFIX.concatenate(name)
+ except NameTooLong:
+ pass
+ while name != origin:
+ # Try extending the least significant label.
+ least_significant_label = name[0]
+ if len(least_significant_label) < 63:
+ # We may be able to extend the least label with a minimal additional byte.
+ # This is only "may" because we could have a maximal length name even though
+ # the least significant label isn't maximally long.
+ new_labels = [least_significant_label + _MINIMAL_OCTET]
+ new_labels.extend(name.labels[1:])
+ try:
+ return dns.name.Name(new_labels)
+ except dns.name.NameTooLong:
+ pass
+ # We can't extend the label either, so we'll try to increment the least
+ # signficant non-maximal byte in it.
+ octets = bytearray(least_significant_label)
+ # We do this reversed iteration with an explicit indexing variable because
+ # if we find something to increment, we're going to want to truncate everything
+ # to the right of it.
+ for i in range(len(octets) - 1, -1, -1):
+ octet = octets[i]
+ if octet == _MAXIMAL_OCTET_VALUE:
+ # We can't increment this, so keep looking.
+ continue
+ # Finally, something we can increment. We have to apply a special rule for
+ # incrementing "@", sending it to "[", because RFC 4034 6.1 says that when
+ # comparing names, uppercase letters compare as if they were their
+ # lower-case equivalents. If we increment "@" to "A", then it would compare
+ # as "a", which is after "[", "\", "]", "^", "_", and "`", so we would have
+ # skipped the most minimal successor, namely "[".
+ if octet == _AT_SIGN_VALUE:
+ octet = _LEFT_SQUARE_BRACKET_VALUE
+ else:
+ octet += 1
+ octets[i] = octet
+ # We can now truncate all of the maximal values we skipped (if any)
+ new_labels = [bytes(octets[: i + 1])]
+ new_labels.extend(name.labels[1:])
+ # We haven't changed the length of the name, so the Name constructor will
+ # always work.
+ return Name(new_labels)
+ # We couldn't increment, so chop off the least significant label and try
+ # again.
+ name = name.parent()
+
+ # We couldn't increment at all, so return the origin, as wrapping around is the
+ # DNSSEC way.
+ return origin
+
+
+def _handle_relativity_and_call(
+ function: Callable[[Name, Name, bool], Name],
+ name: Name,
+ origin: Name,
+ prefix_ok: bool,
+) -> Name:
+ # Make "name" absolute if needed, ensure that the origin is absolute,
+ # call function(), and then relativize the result if needed.
+ if not origin.is_absolute():
+ raise NeedAbsoluteNameOrOrigin
+ relative = not name.is_absolute()
+ if relative:
+ name = name.derelativize(origin)
+ elif not name.is_subdomain(origin):
+ raise NeedSubdomainOfOrigin
+ result_name = function(name, origin, prefix_ok)
+ if relative:
+ result_name = result_name.relativize(origin)
+ return result_name
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-from typing import Dict # pylint: disable=unused-import
import copy
import operator
import pickle
import unittest
-
from io import BytesIO
+from typing import Dict # pylint: disable=unused-import
+import dns.e164
import dns.name
import dns.reversename
-import dns.e164
# pylint: disable=line-too-long,unsupported-assignment-operation
+def expand(text):
+ # This is a helper routine to expand name patterns for RFC 4471 tests.
+ #
+ # Basically it turns <character>{<n>} into <n> instances of the character.
+ # For example:
+ #
+ # r"fo{2}.example." => r"foo.example.".
+ #
+ # Two characters get special treatment:
+ # "-" is mapped to r"\000" and "+" is mapped to r"\255". For example
+ #
+ # r"+{3}-.example." -> r"\255\255\255\000.example."
+ #
+ # We do this just to make parsing simpler, so we don't have to process escapes
+ # ourselves.
+ i = 0
+ l = len(text)
+ previous = ""
+ reading_count = False
+ count = 0
+ expanded = []
+ for c in text:
+ if c == "-":
+ c = r"\000"
+ elif c == "+":
+ c = r"\255"
+ if reading_count:
+ assert len(c) == 1
+ if c >= "0" and c <= "9":
+ count *= 10
+ count += ord(c) - ord("0")
+ elif c == "}":
+ expanded.append(previous * count)
+ previous = ""
+ reading_count = False
+ count = 0
+ elif c == "{":
+ reading_count = True
+ else:
+ expanded.append(previous)
+ previous = c
+ # don't forget the last char (if there is one)
+ expanded.append(previous)
+ x = "".join(expanded)
+ return x
+
+
class NameTestCase(unittest.TestCase):
def setUp(self):
self.origin = dns.name.from_text("example.")
n2 = pickle.loads(p)
self.assertEqual(n1, n2)
+ def test_pad_to_max_name(self):
+ # Test edge cases in our padding helper.
+ tests = [
+ ("o{61}.o{63}.o{63}.o{63}.", "o{61}.o{63}.o{63}.o{63}."),
+ ("o{60}.o{63}.o{63}.o{63}.", "o{60}.o{63}.o{63}.o{63}."),
+ ("o{59}.o{63}.o{63}.o{63}.", "+.o{59}.o{63}.o{63}.o{63}."),
+ ("o{63}.o{63}.o{63}.", "+{61}.o{63}.o{63}.o{63}."),
+ ("o{63}.o{63}.", "+{61}.+{63}.o{63}.o{63}."),
+ ]
+ for name_text, expected_text in tests:
+ name = dns.name.from_text(expand(name_text))
+ expected = dns.name.from_text(expand(expected_text))
+ self.assertEqual(dns.name._pad_to_max_name(name), expected)
+
+ def test_predecessors_and_successors(self):
+ # Test RFC 4471 predecessor and successor methods.
+
+ # Here we're actually testing the test suite, but expand() is complicated enough
+ # to deserve a little testing.
+ self.assertEqual(expand("f+{3}-o.example."), r"f\255\255\255\000o.example.")
+
+ # Ok, now test successors!
+ origin = dns.name.from_text("example.com.")
+ tests = [
+ # Examples from the RFC.
+ ("foo", True, "\\000.foo"),
+ ("foo", False, "foo\\000"),
+ # The syntax here is almost the RFC's "alternate syntax" except that to
+ # make expand simpler, i.e. not have to understand that \000 was one octet,
+ # I made the convention that "-" means 0 and "+" means 255. We use raw
+ # string constants where needed to avoid escaping backslash.
+ (
+ "fo{47}.o{63}.o{63}.o{63}",
+ True,
+ "fo{47}-.o{63}.o{63}.o{63}",
+ ),
+ (
+ "fo{48}.o{63}.o{63}.o{63}",
+ True,
+ "fo{47}p.o{63}.o{63}.o{63}",
+ ),
+ ("+{49}.o{63}.o{63}.o{63}", True, "o{62}p.o{63}.o{63}"),
+ (
+ "fo{40}+{8}.o{63}.o{63}.o{63}",
+ True,
+ "fo{39}p.o{63}.o{63}.o{63}",
+ ),
+ (
+ r"fo{47}\@.o{63}.o{63}.o{63}",
+ True,
+ r"fo{47}\[.o{63}.o{63}.o{63}",
+ ),
+ ("+{49}.+{63}.+{63}.+{63}", True, ""),
+ # Some more tests not in the RFC
+ ("+{49}.+{63}.o{63}.o{63}", True, "o{62}p.o{63}"),
+ (
+ "+{48}.o{63}.o{63}.o{63}",
+ True,
+ "+{48}-.o{63}.o{63}.o{63}",
+ ),
+ ]
+ for test_origin in [origin, None]:
+ for name_text, prefix_ok, expected_successor_text in tests:
+ name = dns.name.from_text(expand(name_text), test_origin)
+ expected_successor = dns.name.from_text(
+ expand(expected_successor_text), test_origin
+ )
+ successor = name.successor(origin, prefix_ok)
+ self.assertEqual(successor, expected_successor)
+ self.assertTrue(
+ successor > name
+ or successor == origin
+ or successor == dns.name.empty
+ )
+ # Now test the predecessor
+ predecessor = successor.predecessor(origin, prefix_ok)
+ self.assertEqual(predecessor, name)
+
+ # Finally, test that a maximal length origin is its own predecessor and
+ # successor.
+ origin = dns.name.from_text(expand("+{49}.+{63}.o{63}.o{63}.example.com."))
+ assert origin.successor(origin, True) == origin
+ assert origin.predecessor(origin, True) == origin
+
+ def test_predecessor_and_successor_errors(self):
+ name = dns.name.from_text("name", None)
+ origin = dns.name.from_text("origin", None) # note Relative!
+ with self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin):
+ name.successor(origin, True)
+ with self.assertRaises(dns.name.NeedAbsoluteNameOrOrigin):
+ name.predecessor(origin, True)
+ name = dns.name.from_text("name.") # Note absolute
+ origin = dns.name.from_text("origin.") # 'name' is not a subdomain of 'origin'
+ with self.assertRaises(dns.name.NeedSubdomainOfOrigin):
+ name.successor(origin, True)
+ with self.assertRaises(dns.name.NeedSubdomainOfOrigin):
+ name.predecessor(origin, True)
+
if __name__ == "__main__":
unittest.main()