]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Enum typing (#923)
authorBrian Wellington <bwelling@xbill.org>
Thu, 6 Apr 2023 13:03:07 +0000 (06:03 -0700)
committerGitHub <noreply@github.com>
Thu, 6 Apr 2023 13:03:07 +0000 (06:03 -0700)
* IntEnum improvements.

This changes make() to always return an instance of the subclass,
creating one on the fly if the value is not known, and updates the typ
registration code to deal with this.  It also adds typing annotations to
make().

* Add missing int check.

Some older versions of python weren't rejecting non-int values.

* Fix int check.

Raise TypeError for non-int, not ValueError, to make tests happy.

* Annotate to_text/from_text.

* Remove many the_ prefixed variables.

These were needed in the past to work around typing issues.

dns/dnssec.py
dns/edns.py
dns/enum.py
dns/message.py
dns/rdata.py
dns/rdataset.py
dns/resolver.py
dns/rrset.py
dns/update.py
dns/zone.py
dns/zonefile.py

index 3caa22b07d8d02fed9180b5de36e7367de1c3baa..c219965d28b6c9b663cf16bdc36c017dc89d5fc8 100644 (file)
@@ -948,9 +948,9 @@ def _make_dnskey(
         else:
             raise ValueError("unsupported ECDSA curve")
 
-    the_algorithm = Algorithm.make(algorithm)
+    algorithm = Algorithm.make(algorithm)
 
-    _ensure_algorithm_key_combination(the_algorithm, public_key)
+    _ensure_algorithm_key_combination(algorithm, public_key)
 
     if isinstance(public_key, rsa.RSAPublicKey):
         key_bytes = encode_rsa_public_key(public_key)
@@ -974,7 +974,7 @@ def _make_dnskey(
         rdtype=dns.rdatatype.DNSKEY,
         flags=flags,
         protocol=protocol,
-        algorithm=the_algorithm,
+        algorithm=algorithm,
         key=key_bytes,
     )
 
index 64436cde44108018b81ddcbf501f14711e65e722..40899ee83b39b7c3009b79e361d514db9865dc18 100644 (file)
@@ -380,7 +380,7 @@ class EDEOption(Option):  # lgtm[py/missing-equals]
     def from_wire_parser(
         cls, otype: Union[OptionType, str], parser: "dns.wire.Parser"
     ) -> Option:
-        the_code = EDECode.make(parser.get_uint16())
+        code = EDECode.make(parser.get_uint16())
         text = parser.get_remaining()
 
         if text:
@@ -390,7 +390,7 @@ class EDEOption(Option):  # lgtm[py/missing-equals]
         else:
             btext = None
 
-        return cls(the_code, btext)
+        return cls(code, btext)
 
 
 _type_to_class: Dict[OptionType, Any] = {
@@ -424,8 +424,8 @@ def option_from_wire_parser(
 
     Returns an instance of a subclass of ``dns.edns.Option``.
     """
-    the_otype = OptionType.make(otype)
-    cls = get_option_class(the_otype)
+    otype = OptionType.make(otype)
+    cls = get_option_class(otype)
     return cls.from_wire_parser(otype, parser)
 
 
index b5a4aed8efee94937436e2b8ef7dd7a9f62d5d07..968363aa8305d9d18b90b8505a37561182535b7a 100644 (file)
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+from typing import Type, TypeVar, Union
+
 import enum
 
+TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
+
 
 class IntEnum(enum.IntEnum):
+    @classmethod
+    def _missing_(cls, value):
+        cls._check_value(value)
+        val = int.__new__(cls, value)
+        val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
+        val._value_ = value
+        return val
+
     @classmethod
     def _check_value(cls, value):
         max = cls._maximum()
+        if not isinstance(value, int):
+            raise TypeError
         if value < 0 or value > max:
             name = cls._short_name()
-            raise ValueError(f"{name} must be between >= 0 and <= {max}")
+            raise ValueError(f"{name} must be an int between >= 0 and <= {max}")
 
     @classmethod
-    def from_text(cls, text):
+    def from_text(cls : Type[TIntEnum], text: str) -> TIntEnum:
         text = text.upper()
         try:
             return cls[text]
@@ -47,7 +61,7 @@ class IntEnum(enum.IntEnum):
         raise cls._unknown_exception_class()
 
     @classmethod
-    def to_text(cls, value):
+    def to_text(cls : Type[TIntEnum], value : int) -> str:
         cls._check_value(value)
         try:
             text = cls(value).name
@@ -59,7 +73,7 @@ class IntEnum(enum.IntEnum):
         return text
 
     @classmethod
-    def make(cls, value):
+    def make(cls: Type[TIntEnum], value: Union[int, str]) -> TIntEnum:
         """Convert text or a value into an enumerated type, if possible.
 
         *value*, the ``int`` or ``str`` to convert.
@@ -76,10 +90,7 @@ class IntEnum(enum.IntEnum):
         if isinstance(value, str):
             return cls.from_text(value)
         cls._check_value(value)
-        try:
-            return cls(value)
-        except ValueError:
-            return value
+        return cls(value)
 
     @classmethod
     def _maximum(cls):
index 3a6f4273f6a2c7e4693c3a91c1ee9fd06f509fc7..2ccdc2b19b14bbceeef47092b04dea0ac90ac47a 100644 (file)
@@ -1730,13 +1730,11 @@ def make_query(
 
     if isinstance(qname, str):
         qname = dns.name.from_text(qname, idna_codec=idna_codec)
-    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-    the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    rdtype = dns.rdatatype.RdataType.make(rdtype)
+    rdclass = dns.rdataclass.RdataClass.make(rdclass)
     m = QueryMessage(id=id)
     m.flags = dns.flags.Flag(flags)
-    m.find_rrset(
-        m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True
-    )
+    m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True)
     # only pass keywords on to use_edns if they have been set to a
     # non-None value.  Setting a field will turn EDNS on if it hasn't
     # been configured.
index d166b8abbc648025f4c81b59b448e161efe2a1a9..66c07eeca16914b3cd4a83a070cfeb7b02771b6d 100644 (file)
@@ -880,16 +880,19 @@ def register_type(
     it applies to all classes.
     """
 
-    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-    existing_cls = get_rdata_class(rdclass, the_rdtype)
-    if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
-        raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
+    rdtype = dns.rdatatype.RdataType.make(rdtype)
+    existing_cls = get_rdata_class(rdclass, rdtype)
+    if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
+        raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
     try:
-        if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
-            raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
+        if (
+            rdtype in dns.rdatatype.RdataType
+            and dns.rdatatype.RdataType(rdtype).name != rdtype_text
+        ):
+            raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
     except ValueError:
         pass
-    _rdata_classes[(rdclass, the_rdtype)] = getattr(
+    _rdata_classes[(rdclass, rdtype)] = getattr(
         implementation, rdtype_text.replace("-", "_")
     )
-    dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)
+    dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
index c0ede425098db8f390506f76ead2e00e31507172..b562d1f8a7a8738a99cbfc309efe5982e21f16db 100644 (file)
@@ -471,9 +471,9 @@ def from_text_list(
     Returns a ``dns.rdataset.Rdataset`` object.
     """
 
-    the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
-    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-    r = Rdataset(the_rdclass, the_rdtype)
+    rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    rdtype = dns.rdatatype.RdataType.make(rdtype)
+    r = Rdataset(rdclass, rdtype)
     r.update_ttl(ttl)
     for t in text_rdatas:
         rd = dns.rdata.from_text(
index cd041d9250f05d4a1aa067640dbfb4facd2ae4d2..61d00523b4bf0aa46e585bb367d808ff4d98c576 100644 (file)
@@ -647,17 +647,17 @@ class _Resolution:
     ) -> None:
         if isinstance(qname, str):
             qname = dns.name.from_text(qname, None)
-        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if dns.rdatatype.is_metatype(the_rdtype):
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        if dns.rdatatype.is_metatype(rdtype):
             raise NoMetaqueries
-        the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
-        if dns.rdataclass.is_metaclass(the_rdclass):
+        rdclass = dns.rdataclass.RdataClass.make(rdclass)
+        if dns.rdataclass.is_metaclass(rdclass):
             raise NoMetaqueries
         self.resolver = resolver
         self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
         self.qnames = self.qnames_to_try[:]
-        self.rdtype = the_rdtype
-        self.rdclass = the_rdclass
+        self.rdtype = rdtype
+        self.rdclass = rdclass
         self.tcp = tcp
         self.raise_on_no_answer = raise_on_no_answer
         self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
index 3f22a90c1a335710a537081441986219de72477b..0519051e14d2a6be8e361c70f75affd2ab259ff8 100644 (file)
@@ -214,9 +214,9 @@ def from_text_list(
 
     if isinstance(name, str):
         name = dns.name.from_text(name, None, idna_codec=idna_codec)
-    the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
-    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-    r = RRset(name, the_rdclass, the_rdtype)
+    rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    rdtype = dns.rdatatype.RdataType.make(rdtype)
+    r = RRset(name, rdclass, rdtype)
     r.update_ttl(ttl)
     for t in text_rdatas:
         rd = dns.rdata.from_text(
index b10f6ace5ec8f7f66202aecca59ad94a61d5f628..2219ec563b7267574c9cb00ef2bb5a8f15bb50b7 100644 (file)
@@ -335,12 +335,12 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                 True,
             )
         else:
-            the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+            rdtype = dns.rdatatype.RdataType.make(rdtype)
             self.find_rrset(
                 self.prerequisite,
                 name,
                 dns.rdataclass.NONE,
-                the_rdtype,
+                rdtype,
                 dns.rdatatype.NONE,
                 None,
                 True,
index 35724d7783431405a431ca0616ba4616196993dc..647538ce04271d4cd61fe6db5633518cdeb0d945 100644 (file)
@@ -321,11 +321,11 @@ class Zone(dns.transaction.TransactionManager):
         Returns a ``dns.rdataset.Rdataset``.
         """
 
-        the_name = self._validate_name(name)
-        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-        the_covers = dns.rdatatype.RdataType.make(covers)
-        node = self.find_node(the_name, create)
-        return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create)
+        name = self._validate_name(name)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        covers = dns.rdatatype.RdataType.make(covers)
+        node = self.find_node(name, create)
+        return node.find_rdataset(self.rdclass, rdtype, covers, create)
 
     def get_rdataset(
         self,
@@ -404,14 +404,14 @@ class Zone(dns.transaction.TransactionManager):
         types were aggregated into a single RRSIG rdataset.
         """
 
-        the_name = self._validate_name(name)
-        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-        the_covers = dns.rdatatype.RdataType.make(covers)
-        node = self.get_node(the_name)
+        name = self._validate_name(name)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        covers = dns.rdatatype.RdataType.make(covers)
+        node = self.get_node(name)
         if node is not None:
-            node.delete_rdataset(self.rdclass, the_rdtype, the_covers)
+            node.delete_rdataset(self.rdclass, rdtype, covers)
             if len(node) == 0:
-                self.delete_node(the_name)
+                self.delete_node(name)
 
     def replace_rdataset(
         self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
@@ -484,10 +484,10 @@ class Zone(dns.transaction.TransactionManager):
         """
 
         vname = self._validate_name(name)
-        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
-        the_covers = dns.rdatatype.RdataType.make(covers)
-        rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers)
-        rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+        covers = dns.rdatatype.RdataType.make(covers)
+        rdataset = self.nodes[vname].find_rdataset(self.rdclass, rdtype, covers)
+        rrset = dns.rrset.RRset(vname, self.rdclass, rdtype, covers)
         rrset.update(rdataset)
         return rrset
 
index fad78c3e6a1dc88add01e9588fdb45c62de871a6..48bedadb16344d35fe01efdcdbb43be5760adca6 100644 (file)
@@ -710,26 +710,26 @@ def read_rrsets(
     if isinstance(default_ttl, str):
         default_ttl = dns.ttl.from_text(default_ttl)
     if rdclass is not None:
-        the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+        rdclass = dns.rdataclass.RdataClass.make(rdclass)
     else:
-        the_rdclass = None
-    the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+        rdclass = None
+    default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
     if rdtype is not None:
-        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
     else:
-        the_rdtype = None
+        rdtype = None
     manager = RRSetsReaderManager(origin, relativize, default_rdclass)
     with manager.writer(True) as txn:
         tok = dns.tokenizer.Tokenizer(text, "<input>", idna_codec=idna_codec)
         reader = Reader(
             tok,
-            the_default_rdclass,
+            default_rdclass,
             txn,
             allow_directives=False,
             force_name=name,
             force_ttl=ttl,
-            force_rdclass=the_rdclass,
-            force_rdtype=the_rdtype,
+            force_rdclass=rdclass,
+            force_rdtype=rdtype,
             default_ttl=default_ttl,
         )
         reader.read()