]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Enum refactoring. 475/head
authorBrian Wellington <bwelling@xbill.org>
Tue, 19 May 2020 20:18:05 +0000 (13:18 -0700)
committerBrian Wellington <bwelling@xbill.org>
Tue, 19 May 2020 20:18:05 +0000 (13:18 -0700)
Consolidate the common methods related to enum classes.

dns/dnssec.py
dns/enum.py [new file with mode: 0644]
dns/opcode.py
dns/rcode.py
dns/rdataclass.py
dns/rdatatype.py

index c0050d2c5cff525ecc329d51adda0986ed1efc8f..dee290d094fadc2a2a0646288c5e535a4f9e56a9 100644 (file)
 
 """Common DNSSEC-related functions and constants."""
 
-import enum
 import hashlib
 import io
 import struct
 import time
 import base64
 
+import dns.enum
 import dns.exception
 import dns.name
 import dns.node
@@ -41,7 +41,7 @@ class ValidationFailure(dns.exception.DNSException):
     """The DNSSEC signature is invalid."""
 
 
-class Algorithm(enum.IntEnum):
+class Algorithm(dns.enum.IntEnum):
     RSAMD5 = 1
     DH = 2
     DSA = 3
@@ -60,6 +60,11 @@ class Algorithm(enum.IntEnum):
     PRIVATEDNS = 253
     PRIVATEOID = 254
 
+    @classmethod
+    def _maximum(cls):
+        return 255
+
+
 globals().update(Algorithm.__members__)
 
 
@@ -71,10 +76,7 @@ def algorithm_from_text(text):
     Returns an ``int``.
     """
 
-    try:
-        return Algorithm[text.upper()]
-    except KeyError:
-        return int(text)
+    return Algorithm.from_text(text)
 
 
 def algorithm_to_text(value):
@@ -85,10 +87,7 @@ def algorithm_to_text(value):
     Returns a ``str``, the name of a DNSSEC algorithm.
     """
 
-    try:
-        return Algorithm(value).name
-    except ValueError:
-        return str(value)
+    return Algorithm.to_text(value)
 
 
 def _to_rdata(record, origin):
@@ -118,13 +117,17 @@ def key_id(key):
         total += ((total >> 16) & 0xffff)
         return total & 0xffff
 
-class DSDigest(enum.IntEnum):
+class DSDigest(dns.enum.IntEnum):
     """DNSSEC Delgation Signer Digest Algorithm"""
 
     SHA1 = 1
     SHA256 = 2
     SHA384 = 4
 
+    @classmethod
+    def _maximum(cls):
+        return 255
+
 
 def make_ds(name, key, algorithm, origin=None):
     """Create a DS record for a DNSSEC key.
@@ -496,11 +499,14 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None):
     raise ValidationFailure("no RRSIGs validated")
 
 
-class NSEC3Hash(enum.IntEnum):
+class NSEC3Hash(dns.enum.IntEnum):
     """NSEC3 hash algorithm"""
 
     SHA1 = 1
 
+    @classmethod
+    def _maximum(cls):
+        return 255
 
 def nsec3_hash(domain, salt, iterations, algorithm):
     """
diff --git a/dns/enum.py b/dns/enum.py
new file mode 100644 (file)
index 0000000..62b4a44
--- /dev/null
@@ -0,0 +1,77 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import enum
+
+class IntEnum(enum.IntEnum):
+    @classmethod
+    def _check_value(cls, value):
+        max = cls._maximum()
+        if value < 0 or value > max:
+            name = cls._short_name()
+            raise ValueError(f"{name} must be between >= 0 and <= {max}")
+
+    @classmethod
+    def from_text(cls, text):
+        text = text.upper()
+        try:
+            return cls[text]
+        except KeyError:
+            pass
+        prefix = cls._prefix()
+        if text.startswith(prefix) and text[len(prefix):].isdigit():
+            value = int(text[len(prefix):])
+            cls._check_value(value)
+            try:
+                return cls(value)
+            except ValueError:
+                return value
+        raise cls._unknown_exception_class()
+
+    @classmethod
+    def to_text(cls, value):
+        cls._check_value(value)
+        try:
+            return cls(value).name
+        except ValueError:
+            return f"{cls._prefix()}{value}"
+
+    @classmethod
+    def to_enum(cls, value):
+        if isinstance(value, str):
+            return cls.from_text(value)
+        cls._check_value(value)
+        try:
+            return cls(value)
+        except ValueError:
+            return value
+
+    @classmethod
+    def _maximum(cls):
+        raise NotImplemented
+
+    @classmethod
+    def _short_name(cls):
+        return cls.__name__.lower()
+
+    @classmethod
+    def _prefix(cls):
+        return ''
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return ValueError
index d81d25588324c87dc7061d9ec0b17f3ac17aff66..509916e9e4ff8e8139c5560beadc519ecca6e433 100644 (file)
 
 """DNS Opcodes."""
 
-import enum
-
+import dns.enum
 import dns.exception
 
-class Opcode(enum.IntEnum):
+class Opcode(dns.enum.IntEnum):
     #: Query
     QUERY = 0
     #: Inverse Query (historical)
@@ -33,6 +32,14 @@ class Opcode(enum.IntEnum):
     #: Dynamic Update
     UPDATE = 5
 
+    @classmethod
+    def _maximum(cls):
+        return 15
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return UnknownOpcode
+
 globals().update(Opcode.__members__)
 
 
@@ -50,17 +57,7 @@ def from_text(text):
     Returns an ``int``.
     """
 
-    if text.isdigit():
-        value = int(text)
-        if value >= 0 and value <= 15:
-            try:
-                return Opcode(value)
-            except ValueError:
-                return value
-    try:
-        return Opcode[text.upper()]
-    except KeyError:
-        raise UnknownOpcode
+    return Opcode.from_text(text)
 
 
 def from_flags(flags):
@@ -96,10 +93,7 @@ def to_text(value):
     Returns a ``str``.
     """
 
-    try:
-        return Opcode(value).name
-    except ValueError:
-        return str(value)
+    return Opcode.to_text(value)
 
 
 def is_update(flags):
index 05a8f54a9887a7798ea5e1bb393308de9ce150ca..efefe737a9535de7df6111d3693c04271dfa8c1e 100644 (file)
 
 """DNS Result Codes."""
 
-import enum
-
+import dns.enum
 import dns.exception
 
-class Rcode(enum.IntEnum):
+class Rcode(dns.enum.IntEnum):
     #: No error
     NOERROR = 0
     #: Format error
@@ -47,6 +46,14 @@ class Rcode(enum.IntEnum):
     #: Bad EDNS version.
     BADVERS = 16
 
+    @classmethod
+    def _maximum(cls):
+        return 4095
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return UnknownRcode
+
 globals().update(Rcode.__members__)
 
 class UnknownRcode(dns.exception.DNSException):
@@ -63,17 +70,7 @@ def from_text(text):
     Returns an ``int``.
     """
 
-    if text.isdigit():
-        v = int(text)
-        if v >= 0 and v <= 4095:
-            try:
-                return Rcode(v)
-            except ValueError:
-                return v
-    try:
-        return Rcode[text.upper()]
-    except KeyError:
-        raise UnknownRcode
+    return Rcode.from_text(text)
 
 
 def from_flags(flags, ednsflags):
@@ -121,9 +118,4 @@ def to_text(value):
     Returns a ``str``.
     """
 
-    if value < 0 or value > 4095:
-        raise ValueError('rcode must be >= 0 and <= 4095')
-    try:
-        return Rcode(value).name
-    except ValueError:
-        return str(value)
+    return Rcode.to_text(value)
index 828812c711b77c58df861ffd0ada061504dd431d..e7c957be0a90594101a194d57460cb16c8b2ad01 100644 (file)
 
 """DNS Rdata Classes."""
 
-import enum
-import re
-
+import dns.enum
 import dns.exception
 
-class RdataClass(enum.IntEnum):
+class RdataClass(dns.enum.IntEnum):
     """DNS Rdata Class"""
     RESERVED0 = 0
     IN = 1
@@ -34,12 +32,26 @@ class RdataClass(enum.IntEnum):
     NONE = 254
     ANY = 255
 
+    @classmethod
+    def _maximum(cls):
+        return 65535
+
+    @classmethod
+    def _short_name(cls):
+        return "class"
+
+    @classmethod
+    def _prefix(cls):
+        return "CLASS"
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return UnknownRdataclass
+
 globals().update(RdataClass.__members__)
 
 _metaclasses = {NONE, ANY}
 
-_unknown_class_pattern = re.compile('CLASS([0-9]+)$', re.I)
-
 
 class UnknownRdataclass(dns.exception.DNSException):
     """A DNS class is unknown."""
@@ -60,16 +72,7 @@ def from_text(text):
     Returns an ``int``.
     """
 
-    try:
-        value = RdataClass[text.upper()]
-    except KeyError:
-        match = _unknown_class_pattern.match(text)
-        if match is None:
-            raise UnknownRdataclass
-        value = int(match.group(1))
-        if value < 0 or value > 65535:
-            raise ValueError("class must be between >= 0 and <= 65535")
-    return value
+    return RdataClass.from_text(text)
 
 
 def to_text(value):
@@ -83,12 +86,7 @@ def to_text(value):
     Returns a ``str``.
     """
 
-    if value < 0 or value > 65535:
-        raise ValueError("class must be between >= 0 and <= 65535")
-    try:
-        return RdataClass(value).name
-    except ValueError:
-        return f'CLASS{value}'
+    return RdataClass.to_text(value)
 
 
 def to_enum(value):
@@ -99,14 +97,7 @@ def to_enum(value):
     Returns an ``int``.
     """
 
-    if isinstance(value, str):
-        return from_text(value)
-    if value < 0 or value > 65535:
-        raise ValueError("class must be between >= 0 and <= 65535")
-    try:
-        return RdataClass(value)
-    except ValueError:
-        return value
+    return RdataClass.to_enum(value)
 
 
 def is_metaclass(rdclass):
index 7bfa76459067c127a158b7f0c72e90c519168820..ba2e261509cc908806c83c8f8e405ae5698398b6 100644 (file)
 
 """DNS Rdata Types."""
 
-import enum
-import re
-
+import dns.enum
 import dns.exception
 
-class RdataType(enum.IntEnum):
+class RdataType(dns.enum.IntEnum):
     """DNS Rdata Type"""
     TYPE0 = 0
     NONE = 0
@@ -97,17 +95,31 @@ class RdataType(enum.IntEnum):
     TA = 32768
     DLV = 32769
 
+    @classmethod
+    def _maximum(cls):
+        return 65535
+
+    @classmethod
+    def _short_name(cls):
+        return "type"
+
+    @classmethod
+    def _prefix(cls):
+        return "TYPE"
+
+    @classmethod
+    def _unknown_exception_class(cls):
+        return UnknownRdatatype
+
 _registered_by_text = {}
 _registered_by_value = {}
 
-globals().update(RdataType.__members__.items())
+globals().update(RdataType.__members__)
 
 _metatypes = {OPT}
 
 _singletons = {SOA, NXT, DNAME, NSEC, CNAME}
 
-_unknown_type_pattern = re.compile('TYPE([0-9]+)$', re.I)
-
 
 class UnknownRdatatype(dns.exception.DNSException):
     """DNS resource record type is unknown."""
@@ -130,18 +142,12 @@ def from_text(text):
 
     text = text.upper().replace('-', '_')
     try:
-        return RdataType[text]
-    except KeyError:
-        pass
-    value = _registered_by_text.get(text.upper())
-    if value is None:
-        match = _unknown_type_pattern.match(text)
-        if match is None:
-            raise UnknownRdatatype
-        value = int(match.group(1))
-        if value < 0 or value > 65535:
-            raise ValueError("type must be between >= 0 and <= 65535")
-    return value
+        return RdataType.from_text(text)
+    except UnknownRdatatype:
+        registered_type = _registered_by_text.get(text)
+        if registered_type:
+            return registered_type
+        raise
 
 
 def to_text(value):
@@ -155,15 +161,12 @@ def to_text(value):
     Returns a ``str``.
     """
 
-    if value < 0 or value > 65535:
-        raise ValueError("type must be between >= 0 and <= 65535")
-    try:
-        return RdataType(value).name.replace('_', '-')
-    except ValueError:
-        text = _registered_by_value.get(value)
-        if text:
-            return text
-        return f'TYPE{value}'
+    text = RdataType.to_text(value)
+    if text.startswith("TYPE"):
+        registered_text = _registered_by_value.get(value)
+        if registered_text:
+            text = registered_text
+    return text.replace('_', '-')
 
 
 def to_enum(value):
@@ -174,14 +177,7 @@ def to_enum(value):
     Returns an ``int``.
     """
 
-    if isinstance(value, str):
-        return from_text(value)
-    if value < 0 or value > 65535:
-        raise ValueError("type must be between >= 0 and <= 65535")
-    try:
-        return RdataType(value)
-    except ValueError:
-        return value
+    return RdataType.to_enum(value)
 
 
 def is_metatype(rdtype):