From: Brian Wellington Date: Fri, 15 Jul 2022 16:01:17 +0000 (-0700) Subject: Fix dns.rdatatype special cases. X-Git-Tag: v2.3.0rc1~58^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F822%2Fhead;p=thirdparty%2Fdnspython.git Fix dns.rdatatype special cases. Prior to this change, there was logic in dns.rdatatype.from_text() and to_text() to deal with types not handled by the RdataType enum; specifically, the NSAP-PTR type (the enum value has a different name, because of the hyphen) and user-registered types. This was fine when internal code called these methods, but most callers of from_text() were converted to dns.rdatatype.RdataType.make(), which supports both integer and text input, and it doesn't handle the special cases. This change adds more hooks into the enum wrapper and moves the special case handling for RdataType into them. --- diff --git a/dns/enum.py b/dns/enum.py index 9c674883..0eeafd4b 100644 --- a/dns/enum.py +++ b/dns/enum.py @@ -33,6 +33,9 @@ class IntEnum(enum.IntEnum): return cls[text] except KeyError: pass + value = cls._extra_from_text(text) + if value: + return value prefix = cls._prefix() if text.startswith(prefix) and text[len(prefix) :].isdigit(): value = int(text[len(prefix) :]) @@ -47,9 +50,13 @@ class IntEnum(enum.IntEnum): def to_text(cls, value): cls._check_value(value) try: - return cls(value).name + text = cls(value).name except ValueError: - return f"{cls._prefix()}{value}" + text = None + text = cls._extra_to_text(value, text) + if text is None: + text = f"{cls._prefix()}{value}" + return text @classmethod def make(cls, value): @@ -86,6 +93,14 @@ class IntEnum(enum.IntEnum): def _prefix(cls): return "" + @classmethod + def _extra_from_text(cls, text): + return None + + @classmethod + def _extra_to_text(cls, value, current_text): + return current_text + @classmethod def _unknown_exception_class(cls): return ValueError diff --git a/dns/rdatatype.py b/dns/rdatatype.py index 0a2854da..e6c58186 100644 --- a/dns/rdatatype.py +++ b/dns/rdatatype.py @@ -120,6 +120,23 @@ class RdataType(dns.enum.IntEnum): def _prefix(cls): return "TYPE" + @classmethod + def _extra_from_text(cls, text): + if text.find("-") >= 0: + try: + return cls[text.replace("-", "_")] + except KeyError: + pass + return _registered_by_text.get(text) + + @classmethod + def _extra_to_text(cls, value, current_text): + if current_text is None: + return _registered_by_value.get(value) + if current_text.find("_") >= 0: + return current_text.replace("_", "-") + return current_text + @classmethod def _unknown_exception_class(cls): return UnknownRdatatype @@ -158,14 +175,7 @@ def from_text(text: str) -> RdataType: Returns a ``dns.rdatatype.RdataType``. """ - text = text.upper().replace("-", "_") - try: - return RdataType.from_text(text) - except UnknownRdatatype: - registered_type = _registered_by_text.get(text) - if registered_type: - return registered_type - raise + return RdataType.from_text(text) def to_text(value: RdataType) -> str: @@ -179,12 +189,7 @@ def to_text(value: RdataType) -> str: Returns a ``str``. """ - text = RdataType.to_text(value) - if text.startswith("TYPE"): - registered_text = _registered_by_value.get(value) - if registered_text: - text = registered_text - return text.replace("_", "-") + return RdataType.to_text(value) def is_metatype(rdtype: RdataType) -> bool: diff --git a/tests/test_rdata.py b/tests/test_rdata.py index a1c066af..73023693 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -61,6 +61,9 @@ class RdataTestCase(unittest.TestCase): self.assertEqual(rdata.strings, (b"hello", b"world")) self.assertEqual(dns.rdatatype.to_text(TTXT), "TTXT") self.assertEqual(dns.rdatatype.from_text("TTXT"), TTXT) + self.assertEqual(dns.rdatatype.RdataType.make("TTXT"), TTXT) + self.assertEqual(dns.rdatatype.from_text("ttxt"), TTXT) + self.assertEqual(dns.rdatatype.RdataType.make("ttxt"), TTXT) def test_module_reregistration(self): def bad(): @@ -937,6 +940,14 @@ class UtilTestCase(unittest.TestCase): with self.assertRaises(dns.ttl.BadTTL): dns.rdataset.from_text("in", "a", "10.0.0.1", "10.0.0.2") + def test_nsap_ptr_type(self): + # The NSAP-PTR type is special because it contains a dash, which means + # that its enum value is not the same as its string value. + self.assertEqual(dns.rdatatype.from_text("NSAP-PTR"), dns.rdatatype.NSAP_PTR) + self.assertEqual( + dns.rdatatype.RdataType.make("NSAP-PTR"), dns.rdatatype.NSAP_PTR + ) + Rdata = dns.rdata.Rdata