]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
break name/wire circular imports; name type tweaks
authorBob Halley <halley@dnspython.org>
Wed, 24 Dec 2025 01:16:00 +0000 (17:16 -0800)
committerBob Halley <halley@dnspython.org>
Wed, 24 Dec 2025 01:16:00 +0000 (17:16 -0800)
dns/__init__.py
dns/name.py
dns/wire.py
dns/wirebase.py [new file with mode: 0644]

index d30fd742a227e80f984eac729d441d265612867e..df8edbda8ac307a8ccbf9d16a12db1a12539982b 100644 (file)
@@ -63,6 +63,7 @@ __all__ = [
     "version",
     "versioned",
     "wire",
+    "wirebase",
     "xfr",
     "zone",
     "zonetypes",
index 53f08c7e90511144de1a59df70c1f6548dbae1d8..f8bc0f15cca363d8cd6a5f9f00012105c09d103b 100644 (file)
@@ -28,7 +28,7 @@ import dns._features
 import dns.enum
 import dns.exception
 import dns.immutable
-import dns.wire
+import dns.wirebase
 
 # Dnspython will never access idna if the import fails, but pyright can't figure
 # that out, so...
@@ -356,9 +356,8 @@ def _maybe_convert_to_binary(label: bytes | str) -> bytes:
 
     if isinstance(label, bytes):
         return label
-    if isinstance(label, str):
+    else:
         return label.encode()
-    raise ValueError  # pragma: no cover
 
 
 @dns.immutable.immutable
@@ -743,7 +742,7 @@ class Name:
 
         return len(self.labels)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: Any) -> Any:
         return self.labels[index]
 
     def __add__(self, other):
@@ -798,7 +797,7 @@ class Name:
         Returns a ``dns.name.Name``.
         """
 
-        if origin is not None and self.is_subdomain(origin):
+        if self.is_subdomain(origin):
             return Name(self[: -len(origin)])
         else:
             return self
@@ -1070,13 +1069,10 @@ def from_text(
     return Name(labels)
 
 
-# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other.
-
-
-def from_wire_parser(parser: "dns.wire.Parser") -> Name:
+def from_wire_parser(parser: dns.wirebase.Parser) -> Name:
     """Convert possibly compressed wire format into a Name.
 
-    *parser* is a dns.wire.Parser.
+    *parser* is a dns.wirebase.Parser.
 
     Raises ``dns.name.BadPointer`` if a compression pointer did not
     point backwards in the message.
@@ -1125,9 +1121,7 @@ def from_wire(message: bytes, current: int) -> tuple[Name, int]:
     which were consumed reading it.
     """
 
-    if not isinstance(message, bytes):
-        raise ValueError("input to from_wire() must be a byte string")
-    parser = dns.wire.Parser(message, current)
+    parser = dns.wirebase.Parser(message, current)
     name = from_wire_parser(parser)
     return (name, parser.current - current)
 
index ec06b196b6f6ef3a52555757e5adbab5b1a9547a..cff1abac09392e89e99533eed1642978b8bbab43 100644 (file)
@@ -1,98 +1,13 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
-import contextlib
-import struct
-from collections.abc import Iterator
-
-import dns.exception
 import dns.name
+import dns.wirebase
 
 
-class Parser:
-    """Helper class for parsing DNS wire format."""
-
-    def __init__(self, wire: bytes, current: int = 0):
-        """Initialize a Parser
-
-        *wire*, a ``bytes`` contains the data to be parsed, and possibly other data.
-        Typically it is the whole message or a slice of it.
-
-        *current*, an `int`, the offset within *wire* where parsing should begin.
-        """
-        self.wire = wire
-        self.current = 0
-        self.end = len(self.wire)
-        if current:
-            self.seek(current)
-        self.furthest = current
-
-    def remaining(self) -> int:
-        return self.end - self.current
-
-    def get_bytes(self, size: int) -> bytes:
-        assert size >= 0
-        if size > self.remaining():
-            raise dns.exception.FormError
-        output = self.wire[self.current : self.current + size]
-        self.current += size
-        self.furthest = max(self.furthest, self.current)
-        return output
-
-    def get_counted_bytes(self, length_size: int = 1) -> bytes:
-        length = int.from_bytes(self.get_bytes(length_size), "big")
-        return self.get_bytes(length)
-
-    def get_remaining(self) -> bytes:
-        return self.get_bytes(self.remaining())
+class Parser(dns.wirebase.Parser):
 
-    def get_uint8(self) -> int:
-        return struct.unpack("!B", self.get_bytes(1))[0]
-
-    def get_uint16(self) -> int:
-        return struct.unpack("!H", self.get_bytes(2))[0]
-
-    def get_uint32(self) -> int:
-        return struct.unpack("!I", self.get_bytes(4))[0]
-
-    def get_uint48(self) -> int:
-        return int.from_bytes(self.get_bytes(6), "big")
-
-    def get_struct(self, format: str) -> tuple:
-        return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
-
-    def get_name(self, origin: "dns.name.Name | None" = None) -> "dns.name.Name":
+    def get_name(self, origin: dns.name.Name | None = None) -> dns.name.Name:
         name = dns.name.from_wire_parser(self)
         if origin:
             name = name.relativize(origin)
         return name
-
-    def seek(self, where: int) -> None:
-        # Note that seeking to the end is OK!  (If you try to read
-        # after such a seek, you'll get an exception as expected.)
-        if where < 0 or where > self.end:
-            raise dns.exception.FormError
-        self.current = where
-
-    @contextlib.contextmanager
-    def restrict_to(self, size: int) -> Iterator:
-        assert size >= 0
-        if size > self.remaining():
-            raise dns.exception.FormError
-        saved_end = self.end
-        try:
-            self.end = self.current + size
-            yield
-            # We make this check here and not in the finally as we
-            # don't want to raise if we're already raising for some
-            # other reason.
-            if self.current != self.end:
-                raise dns.exception.FormError
-        finally:
-            self.end = saved_end
-
-    @contextlib.contextmanager
-    def restore_furthest(self) -> Iterator:
-        try:
-            yield None
-        finally:
-            self.current = self.furthest
diff --git a/dns/wirebase.py b/dns/wirebase.py
new file mode 100644 (file)
index 0000000..299558c
--- /dev/null
@@ -0,0 +1,93 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# We have wirebase and wire to avoid circularity between name.py and wire.py
+
+import contextlib
+import struct
+from collections.abc import Iterator
+
+import dns.exception
+
+
+class Parser:
+    """Helper class for parsing DNS wire format."""
+
+    def __init__(self, wire: bytes, current: int = 0):
+        """Initialize a Parser
+
+        *wire*, a ``bytes`` contains the data to be parsed, and possibly other data.
+        Typically it is the whole message or a slice of it.
+
+        *current*, an `int`, the offset within *wire* where parsing should begin.
+        """
+        self.wire = wire
+        self.current = 0
+        self.end = len(self.wire)
+        if current:
+            self.seek(current)
+        self.furthest = current
+
+    def remaining(self) -> int:
+        return self.end - self.current
+
+    def get_bytes(self, size: int) -> bytes:
+        assert size >= 0
+        if size > self.remaining():
+            raise dns.exception.FormError
+        output = self.wire[self.current : self.current + size]
+        self.current += size
+        self.furthest = max(self.furthest, self.current)
+        return output
+
+    def get_counted_bytes(self, length_size: int = 1) -> bytes:
+        length = int.from_bytes(self.get_bytes(length_size), "big")
+        return self.get_bytes(length)
+
+    def get_remaining(self) -> bytes:
+        return self.get_bytes(self.remaining())
+
+    def get_uint8(self) -> int:
+        return struct.unpack("!B", self.get_bytes(1))[0]
+
+    def get_uint16(self) -> int:
+        return struct.unpack("!H", self.get_bytes(2))[0]
+
+    def get_uint32(self) -> int:
+        return struct.unpack("!I", self.get_bytes(4))[0]
+
+    def get_uint48(self) -> int:
+        return int.from_bytes(self.get_bytes(6), "big")
+
+    def get_struct(self, format: str) -> tuple:
+        return struct.unpack(format, self.get_bytes(struct.calcsize(format)))
+
+    def seek(self, where: int) -> None:
+        # Note that seeking to the end is OK!  (If you try to read
+        # after such a seek, you'll get an exception as expected.)
+        if where < 0 or where > self.end:
+            raise dns.exception.FormError
+        self.current = where
+
+    @contextlib.contextmanager
+    def restrict_to(self, size: int) -> Iterator:
+        assert size >= 0
+        if size > self.remaining():
+            raise dns.exception.FormError
+        saved_end = self.end
+        try:
+            self.end = self.current + size
+            yield
+            # We make this check here and not in the finally as we
+            # don't want to raise if we're already raising for some
+            # other reason.
+            if self.current != self.end:
+                raise dns.exception.FormError
+        finally:
+            self.end = saved_end
+
+    @contextlib.contextmanager
+    def restore_furthest(self) -> Iterator:
+        try:
+            yield None
+        finally:
+            self.current = self.furthest