From: Brian Wellington Date: Tue, 19 May 2020 20:18:05 +0000 (-0700) Subject: Enum refactoring. X-Git-Tag: v2.0.0rc1~188^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F475%2Fhead;p=thirdparty%2Fdnspython.git Enum refactoring. Consolidate the common methods related to enum classes. --- diff --git a/dns/dnssec.py b/dns/dnssec.py index c0050d2c..dee290d0 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -17,13 +17,13 @@ """Common DNSSEC-related functions and constants.""" -import enum import hashlib import io import struct import time import base64 +import dns.enum import dns.exception import dns.name import dns.node @@ -41,7 +41,7 @@ class ValidationFailure(dns.exception.DNSException): """The DNSSEC signature is invalid.""" -class Algorithm(enum.IntEnum): +class Algorithm(dns.enum.IntEnum): RSAMD5 = 1 DH = 2 DSA = 3 @@ -60,6 +60,11 @@ class Algorithm(enum.IntEnum): PRIVATEDNS = 253 PRIVATEOID = 254 + @classmethod + def _maximum(cls): + return 255 + + globals().update(Algorithm.__members__) @@ -71,10 +76,7 @@ def algorithm_from_text(text): Returns an ``int``. """ - try: - return Algorithm[text.upper()] - except KeyError: - return int(text) + return Algorithm.from_text(text) def algorithm_to_text(value): @@ -85,10 +87,7 @@ def algorithm_to_text(value): Returns a ``str``, the name of a DNSSEC algorithm. """ - try: - return Algorithm(value).name - except ValueError: - return str(value) + return Algorithm.to_text(value) def _to_rdata(record, origin): @@ -118,13 +117,17 @@ def key_id(key): total += ((total >> 16) & 0xffff) return total & 0xffff -class DSDigest(enum.IntEnum): +class DSDigest(dns.enum.IntEnum): """DNSSEC Delgation Signer Digest Algorithm""" SHA1 = 1 SHA256 = 2 SHA384 = 4 + @classmethod + def _maximum(cls): + return 255 + def make_ds(name, key, algorithm, origin=None): """Create a DS record for a DNSSEC key. @@ -496,11 +499,14 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): raise ValidationFailure("no RRSIGs validated") -class NSEC3Hash(enum.IntEnum): +class NSEC3Hash(dns.enum.IntEnum): """NSEC3 hash algorithm""" SHA1 = 1 + @classmethod + def _maximum(cls): + return 255 def nsec3_hash(domain, salt, iterations, algorithm): """ diff --git a/dns/enum.py b/dns/enum.py new file mode 100644 index 00000000..62b4a446 --- /dev/null +++ b/dns/enum.py @@ -0,0 +1,77 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import enum + +class IntEnum(enum.IntEnum): + @classmethod + def _check_value(cls, value): + max = cls._maximum() + if value < 0 or value > max: + name = cls._short_name() + raise ValueError(f"{name} must be between >= 0 and <= {max}") + + @classmethod + def from_text(cls, text): + text = text.upper() + try: + return cls[text] + except KeyError: + pass + prefix = cls._prefix() + if text.startswith(prefix) and text[len(prefix):].isdigit(): + value = int(text[len(prefix):]) + cls._check_value(value) + try: + return cls(value) + except ValueError: + return value + raise cls._unknown_exception_class() + + @classmethod + def to_text(cls, value): + cls._check_value(value) + try: + return cls(value).name + except ValueError: + return f"{cls._prefix()}{value}" + + @classmethod + def to_enum(cls, value): + if isinstance(value, str): + return cls.from_text(value) + cls._check_value(value) + try: + return cls(value) + except ValueError: + return value + + @classmethod + def _maximum(cls): + raise NotImplemented + + @classmethod + def _short_name(cls): + return cls.__name__.lower() + + @classmethod + def _prefix(cls): + return '' + + @classmethod + def _unknown_exception_class(cls): + return ValueError diff --git a/dns/opcode.py b/dns/opcode.py index d81d2558..509916e9 100644 --- a/dns/opcode.py +++ b/dns/opcode.py @@ -17,11 +17,10 @@ """DNS Opcodes.""" -import enum - +import dns.enum import dns.exception -class Opcode(enum.IntEnum): +class Opcode(dns.enum.IntEnum): #: Query QUERY = 0 #: Inverse Query (historical) @@ -33,6 +32,14 @@ class Opcode(enum.IntEnum): #: Dynamic Update UPDATE = 5 + @classmethod + def _maximum(cls): + return 15 + + @classmethod + def _unknown_exception_class(cls): + return UnknownOpcode + globals().update(Opcode.__members__) @@ -50,17 +57,7 @@ def from_text(text): Returns an ``int``. """ - if text.isdigit(): - value = int(text) - if value >= 0 and value <= 15: - try: - return Opcode(value) - except ValueError: - return value - try: - return Opcode[text.upper()] - except KeyError: - raise UnknownOpcode + return Opcode.from_text(text) def from_flags(flags): @@ -96,10 +93,7 @@ def to_text(value): Returns a ``str``. """ - try: - return Opcode(value).name - except ValueError: - return str(value) + return Opcode.to_text(value) def is_update(flags): diff --git a/dns/rcode.py b/dns/rcode.py index 05a8f54a..efefe737 100644 --- a/dns/rcode.py +++ b/dns/rcode.py @@ -17,11 +17,10 @@ """DNS Result Codes.""" -import enum - +import dns.enum import dns.exception -class Rcode(enum.IntEnum): +class Rcode(dns.enum.IntEnum): #: No error NOERROR = 0 #: Format error @@ -47,6 +46,14 @@ class Rcode(enum.IntEnum): #: Bad EDNS version. BADVERS = 16 + @classmethod + def _maximum(cls): + return 4095 + + @classmethod + def _unknown_exception_class(cls): + return UnknownRcode + globals().update(Rcode.__members__) class UnknownRcode(dns.exception.DNSException): @@ -63,17 +70,7 @@ def from_text(text): Returns an ``int``. """ - if text.isdigit(): - v = int(text) - if v >= 0 and v <= 4095: - try: - return Rcode(v) - except ValueError: - return v - try: - return Rcode[text.upper()] - except KeyError: - raise UnknownRcode + return Rcode.from_text(text) def from_flags(flags, ednsflags): @@ -121,9 +118,4 @@ def to_text(value): Returns a ``str``. """ - if value < 0 or value > 4095: - raise ValueError('rcode must be >= 0 and <= 4095') - try: - return Rcode(value).name - except ValueError: - return str(value) + return Rcode.to_text(value) diff --git a/dns/rdataclass.py b/dns/rdataclass.py index 828812c7..e7c957be 100644 --- a/dns/rdataclass.py +++ b/dns/rdataclass.py @@ -17,12 +17,10 @@ """DNS Rdata Classes.""" -import enum -import re - +import dns.enum import dns.exception -class RdataClass(enum.IntEnum): +class RdataClass(dns.enum.IntEnum): """DNS Rdata Class""" RESERVED0 = 0 IN = 1 @@ -34,12 +32,26 @@ class RdataClass(enum.IntEnum): NONE = 254 ANY = 255 + @classmethod + def _maximum(cls): + return 65535 + + @classmethod + def _short_name(cls): + return "class" + + @classmethod + def _prefix(cls): + return "CLASS" + + @classmethod + def _unknown_exception_class(cls): + return UnknownRdataclass + globals().update(RdataClass.__members__) _metaclasses = {NONE, ANY} -_unknown_class_pattern = re.compile('CLASS([0-9]+)$', re.I) - class UnknownRdataclass(dns.exception.DNSException): """A DNS class is unknown.""" @@ -60,16 +72,7 @@ def from_text(text): Returns an ``int``. """ - try: - value = RdataClass[text.upper()] - except KeyError: - match = _unknown_class_pattern.match(text) - if match is None: - raise UnknownRdataclass - value = int(match.group(1)) - if value < 0 or value > 65535: - raise ValueError("class must be between >= 0 and <= 65535") - return value + return RdataClass.from_text(text) def to_text(value): @@ -83,12 +86,7 @@ def to_text(value): Returns a ``str``. """ - if value < 0 or value > 65535: - raise ValueError("class must be between >= 0 and <= 65535") - try: - return RdataClass(value).name - except ValueError: - return f'CLASS{value}' + return RdataClass.to_text(value) def to_enum(value): @@ -99,14 +97,7 @@ def to_enum(value): Returns an ``int``. """ - if isinstance(value, str): - return from_text(value) - if value < 0 or value > 65535: - raise ValueError("class must be between >= 0 and <= 65535") - try: - return RdataClass(value) - except ValueError: - return value + return RdataClass.to_enum(value) def is_metaclass(rdclass): diff --git a/dns/rdatatype.py b/dns/rdatatype.py index 7bfa7645..ba2e2615 100644 --- a/dns/rdatatype.py +++ b/dns/rdatatype.py @@ -17,12 +17,10 @@ """DNS Rdata Types.""" -import enum -import re - +import dns.enum import dns.exception -class RdataType(enum.IntEnum): +class RdataType(dns.enum.IntEnum): """DNS Rdata Type""" TYPE0 = 0 NONE = 0 @@ -97,17 +95,31 @@ class RdataType(enum.IntEnum): TA = 32768 DLV = 32769 + @classmethod + def _maximum(cls): + return 65535 + + @classmethod + def _short_name(cls): + return "type" + + @classmethod + def _prefix(cls): + return "TYPE" + + @classmethod + def _unknown_exception_class(cls): + return UnknownRdatatype + _registered_by_text = {} _registered_by_value = {} -globals().update(RdataType.__members__.items()) +globals().update(RdataType.__members__) _metatypes = {OPT} _singletons = {SOA, NXT, DNAME, NSEC, CNAME} -_unknown_type_pattern = re.compile('TYPE([0-9]+)$', re.I) - class UnknownRdatatype(dns.exception.DNSException): """DNS resource record type is unknown.""" @@ -130,18 +142,12 @@ def from_text(text): text = text.upper().replace('-', '_') try: - return RdataType[text] - except KeyError: - pass - value = _registered_by_text.get(text.upper()) - if value is None: - match = _unknown_type_pattern.match(text) - if match is None: - raise UnknownRdatatype - value = int(match.group(1)) - if value < 0 or value > 65535: - raise ValueError("type must be between >= 0 and <= 65535") - return value + return RdataType.from_text(text) + except UnknownRdatatype: + registered_type = _registered_by_text.get(text) + if registered_type: + return registered_type + raise def to_text(value): @@ -155,15 +161,12 @@ def to_text(value): Returns a ``str``. """ - if value < 0 or value > 65535: - raise ValueError("type must be between >= 0 and <= 65535") - try: - return RdataType(value).name.replace('_', '-') - except ValueError: - text = _registered_by_value.get(value) - if text: - return text - return f'TYPE{value}' + text = RdataType.to_text(value) + if text.startswith("TYPE"): + registered_text = _registered_by_value.get(value) + if registered_text: + text = registered_text + return text.replace('_', '-') def to_enum(value): @@ -174,14 +177,7 @@ def to_enum(value): Returns an ``int``. """ - if isinstance(value, str): - return from_text(value) - if value < 0 or value > 65535: - raise ValueError("type must be between >= 0 and <= 65535") - try: - return RdataType(value) - except ValueError: - return value + return RdataType.to_enum(value) def is_metatype(rdtype):