]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
add Bitmap.from_rdtypes() (#906)
authorJakob Schlyter <jakob@kirei.se>
Sat, 11 Mar 2023 02:13:51 +0000 (03:13 +0100)
committerGitHub <noreply@github.com>
Sat, 11 Mar 2023 02:13:51 +0000 (18:13 -0800)
* add `Bitmap.from_rdtypes()` and add missing typing

* more typing

* add missing import

* add more typing

* fix tok type

dns/rdtypes/util.py

index 74596f05287bc3f6ab2bfc20daff7fe2a4b59b06..46c98cf76f5e2fe4acbe8c197e4c570f1167fa82 100644 (file)
@@ -18,6 +18,7 @@
 import collections
 import random
 import struct
+from typing import Any, List
 
 import dns.exception
 import dns.ipv4
@@ -132,7 +133,7 @@ class Bitmap:
             if len(bitmap) == 0 or len(bitmap) > 32:
                 raise ValueError(f"bad {self.type_name} octets")
 
-    def to_text(self):
+    def to_text(self) -> str:
         text = ""
         for (window, bitmap) in self.windows:
             bits = []
@@ -145,14 +146,18 @@ class Bitmap:
         return text
 
     @classmethod
-    def from_text(cls, tok):
+    def from_text(cls, tok: "dns.tokenizer.Tokenizer") -> "Bitmap":
         rdtypes = []
         for token in tok.get_remaining():
             rdtype = dns.rdatatype.from_text(token.unescape().value)
             if rdtype == 0:
                 raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
             rdtypes.append(rdtype)
-        rdtypes.sort()
+        return cls.from_rdtypes(rdtypes)
+
+    @classmethod
+    def from_rdtypes(cls, rdtypes: List[dns.rdatatype.RdataType]) -> "Bitmap":
+        rdtypes = sorted(rdtypes)
         window = 0
         octets = 0
         prior_rdtype = 0
@@ -177,13 +182,13 @@ class Bitmap:
             windows.append((window, bytes(bitmap[0:octets])))
         return cls(windows)
 
-    def to_wire(self, file):
+    def to_wire(self, file: Any) -> None:
         for (window, bitmap) in self.windows:
             file.write(struct.pack("!BB", window, len(bitmap)))
             file.write(bitmap)
 
     @classmethod
-    def from_wire_parser(cls, parser):
+    def from_wire_parser(cls, parser: "dns.wire.Parser") -> "Bitmap":
         windows = []
         while parser.remaining() > 0:
             window = parser.get_uint8()