From: Bob Halley Date: Wed, 24 Dec 2025 01:16:00 +0000 (-0800) Subject: break name/wire circular imports; name type tweaks X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d701b6bc8142d23361b841e6f9b7aa44873ee57f;p=thirdparty%2Fdnspython.git break name/wire circular imports; name type tweaks --- diff --git a/dns/__init__.py b/dns/__init__.py index d30fd742..df8edbda 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -63,6 +63,7 @@ __all__ = [ "version", "versioned", "wire", + "wirebase", "xfr", "zone", "zonetypes", diff --git a/dns/name.py b/dns/name.py index 53f08c7e..f8bc0f15 100644 --- a/dns/name.py +++ b/dns/name.py @@ -28,7 +28,7 @@ import dns._features import dns.enum import dns.exception import dns.immutable -import dns.wire +import dns.wirebase # Dnspython will never access idna if the import fails, but pyright can't figure # that out, so... @@ -356,9 +356,8 @@ def _maybe_convert_to_binary(label: bytes | str) -> bytes: if isinstance(label, bytes): return label - if isinstance(label, str): + else: return label.encode() - raise ValueError # pragma: no cover @dns.immutable.immutable @@ -743,7 +742,7 @@ class Name: return len(self.labels) - def __getitem__(self, index): + def __getitem__(self, index: Any) -> Any: return self.labels[index] def __add__(self, other): @@ -798,7 +797,7 @@ class Name: Returns a ``dns.name.Name``. """ - if origin is not None and self.is_subdomain(origin): + if self.is_subdomain(origin): return Name(self[: -len(origin)]) else: return self @@ -1070,13 +1069,10 @@ def from_text( return Name(labels) -# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other. - - -def from_wire_parser(parser: "dns.wire.Parser") -> Name: +def from_wire_parser(parser: dns.wirebase.Parser) -> Name: """Convert possibly compressed wire format into a Name. - *parser* is a dns.wire.Parser. + *parser* is a dns.wirebase.Parser. Raises ``dns.name.BadPointer`` if a compression pointer did not point backwards in the message. @@ -1125,9 +1121,7 @@ def from_wire(message: bytes, current: int) -> tuple[Name, int]: which were consumed reading it. """ - if not isinstance(message, bytes): - raise ValueError("input to from_wire() must be a byte string") - parser = dns.wire.Parser(message, current) + parser = dns.wirebase.Parser(message, current) name = from_wire_parser(parser) return (name, parser.current - current) diff --git a/dns/wire.py b/dns/wire.py index ec06b196..cff1abac 100644 --- a/dns/wire.py +++ b/dns/wire.py @@ -1,98 +1,13 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -import contextlib -import struct -from collections.abc import Iterator - -import dns.exception import dns.name +import dns.wirebase -class Parser: - """Helper class for parsing DNS wire format.""" - - def __init__(self, wire: bytes, current: int = 0): - """Initialize a Parser - - *wire*, a ``bytes`` contains the data to be parsed, and possibly other data. - Typically it is the whole message or a slice of it. - - *current*, an `int`, the offset within *wire* where parsing should begin. - """ - self.wire = wire - self.current = 0 - self.end = len(self.wire) - if current: - self.seek(current) - self.furthest = current - - def remaining(self) -> int: - return self.end - self.current - - def get_bytes(self, size: int) -> bytes: - assert size >= 0 - if size > self.remaining(): - raise dns.exception.FormError - output = self.wire[self.current : self.current + size] - self.current += size - self.furthest = max(self.furthest, self.current) - return output - - def get_counted_bytes(self, length_size: int = 1) -> bytes: - length = int.from_bytes(self.get_bytes(length_size), "big") - return self.get_bytes(length) - - def get_remaining(self) -> bytes: - return self.get_bytes(self.remaining()) +class Parser(dns.wirebase.Parser): - def get_uint8(self) -> int: - return struct.unpack("!B", self.get_bytes(1))[0] - - def get_uint16(self) -> int: - return struct.unpack("!H", self.get_bytes(2))[0] - - def get_uint32(self) -> int: - return struct.unpack("!I", self.get_bytes(4))[0] - - def get_uint48(self) -> int: - return int.from_bytes(self.get_bytes(6), "big") - - def get_struct(self, format: str) -> tuple: - return struct.unpack(format, self.get_bytes(struct.calcsize(format))) - - def get_name(self, origin: "dns.name.Name | None" = None) -> "dns.name.Name": + def get_name(self, origin: dns.name.Name | None = None) -> dns.name.Name: name = dns.name.from_wire_parser(self) if origin: name = name.relativize(origin) return name - - def seek(self, where: int) -> None: - # Note that seeking to the end is OK! (If you try to read - # after such a seek, you'll get an exception as expected.) - if where < 0 or where > self.end: - raise dns.exception.FormError - self.current = where - - @contextlib.contextmanager - def restrict_to(self, size: int) -> Iterator: - assert size >= 0 - if size > self.remaining(): - raise dns.exception.FormError - saved_end = self.end - try: - self.end = self.current + size - yield - # We make this check here and not in the finally as we - # don't want to raise if we're already raising for some - # other reason. - if self.current != self.end: - raise dns.exception.FormError - finally: - self.end = saved_end - - @contextlib.contextmanager - def restore_furthest(self) -> Iterator: - try: - yield None - finally: - self.current = self.furthest diff --git a/dns/wirebase.py b/dns/wirebase.py new file mode 100644 index 00000000..299558c3 --- /dev/null +++ b/dns/wirebase.py @@ -0,0 +1,93 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# We have wirebase and wire to avoid circularity between name.py and wire.py + +import contextlib +import struct +from collections.abc import Iterator + +import dns.exception + + +class Parser: + """Helper class for parsing DNS wire format.""" + + def __init__(self, wire: bytes, current: int = 0): + """Initialize a Parser + + *wire*, a ``bytes`` contains the data to be parsed, and possibly other data. + Typically it is the whole message or a slice of it. + + *current*, an `int`, the offset within *wire* where parsing should begin. + """ + self.wire = wire + self.current = 0 + self.end = len(self.wire) + if current: + self.seek(current) + self.furthest = current + + def remaining(self) -> int: + return self.end - self.current + + def get_bytes(self, size: int) -> bytes: + assert size >= 0 + if size > self.remaining(): + raise dns.exception.FormError + output = self.wire[self.current : self.current + size] + self.current += size + self.furthest = max(self.furthest, self.current) + return output + + def get_counted_bytes(self, length_size: int = 1) -> bytes: + length = int.from_bytes(self.get_bytes(length_size), "big") + return self.get_bytes(length) + + def get_remaining(self) -> bytes: + return self.get_bytes(self.remaining()) + + def get_uint8(self) -> int: + return struct.unpack("!B", self.get_bytes(1))[0] + + def get_uint16(self) -> int: + return struct.unpack("!H", self.get_bytes(2))[0] + + def get_uint32(self) -> int: + return struct.unpack("!I", self.get_bytes(4))[0] + + def get_uint48(self) -> int: + return int.from_bytes(self.get_bytes(6), "big") + + def get_struct(self, format: str) -> tuple: + return struct.unpack(format, self.get_bytes(struct.calcsize(format))) + + def seek(self, where: int) -> None: + # Note that seeking to the end is OK! (If you try to read + # after such a seek, you'll get an exception as expected.) + if where < 0 or where > self.end: + raise dns.exception.FormError + self.current = where + + @contextlib.contextmanager + def restrict_to(self, size: int) -> Iterator: + assert size >= 0 + if size > self.remaining(): + raise dns.exception.FormError + saved_end = self.end + try: + self.end = self.current + size + yield + # We make this check here and not in the finally as we + # don't want to raise if we're already raising for some + # other reason. + if self.current != self.end: + raise dns.exception.FormError + finally: + self.end = saved_end + + @contextlib.contextmanager + def restore_furthest(self) -> Iterator: + try: + yield None + finally: + self.current = self.furthest