]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix dns.rdatatype special cases. 822/head
authorBrian Wellington <bwelling@xbill.org>
Fri, 15 Jul 2022 16:01:17 +0000 (09:01 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 15 Jul 2022 16:01:17 +0000 (09:01 -0700)
Prior to this change, there was logic in dns.rdatatype.from_text() and
to_text() to deal with types not handled by the RdataType enum;
specifically, the NSAP-PTR type (the enum value has a different name,
because of the hyphen) and user-registered types.

This was fine when internal code called these methods, but most callers
of from_text() were converted to dns.rdatatype.RdataType.make(), which
supports both integer and text input, and it doesn't handle the special
cases.

This change adds more hooks into the enum wrapper and moves the special
case handling for RdataType into them.

dns/enum.py
dns/rdatatype.py
tests/test_rdata.py

index 9c674883ccfc51a6045ba13038315661b305d7f8..0eeafd4bf728378fa5503c3344322eb396018951 100644 (file)
@@ -33,6 +33,9 @@ class IntEnum(enum.IntEnum):
             return cls[text]
         except KeyError:
             pass
+        value = cls._extra_from_text(text)
+        if value:
+            return value
         prefix = cls._prefix()
         if text.startswith(prefix) and text[len(prefix) :].isdigit():
             value = int(text[len(prefix) :])
@@ -47,9 +50,13 @@ class IntEnum(enum.IntEnum):
     def to_text(cls, value):
         cls._check_value(value)
         try:
-            return cls(value).name
+            text = cls(value).name
         except ValueError:
-            return f"{cls._prefix()}{value}"
+            text = None
+        text = cls._extra_to_text(value, text)
+        if text is None:
+            text = f"{cls._prefix()}{value}"
+        return text
 
     @classmethod
     def make(cls, value):
@@ -86,6 +93,14 @@ class IntEnum(enum.IntEnum):
     def _prefix(cls):
         return ""
 
+    @classmethod
+    def _extra_from_text(cls, text):
+        return None
+
+    @classmethod
+    def _extra_to_text(cls, value, current_text):
+        return current_text
+
     @classmethod
     def _unknown_exception_class(cls):
         return ValueError
index 0a2854da0d48a45dfeb68cba65a57365c1a5ceed..e6c581867bcc7cd7e806ea92c3dab28f0d021d3e 100644 (file)
@@ -120,6 +120,23 @@ class RdataType(dns.enum.IntEnum):
     def _prefix(cls):
         return "TYPE"
 
+    @classmethod
+    def _extra_from_text(cls, text):
+        if text.find("-") >= 0:
+            try:
+                return cls[text.replace("-", "_")]
+            except KeyError:
+                pass
+        return _registered_by_text.get(text)
+
+    @classmethod
+    def _extra_to_text(cls, value, current_text):
+        if current_text is None:
+            return _registered_by_value.get(value)
+        if current_text.find("_") >= 0:
+            return current_text.replace("_", "-")
+        return current_text
+
     @classmethod
     def _unknown_exception_class(cls):
         return UnknownRdatatype
@@ -158,14 +175,7 @@ def from_text(text: str) -> RdataType:
     Returns a ``dns.rdatatype.RdataType``.
     """
 
-    text = text.upper().replace("-", "_")
-    try:
-        return RdataType.from_text(text)
-    except UnknownRdatatype:
-        registered_type = _registered_by_text.get(text)
-        if registered_type:
-            return registered_type
-        raise
+    return RdataType.from_text(text)
 
 
 def to_text(value: RdataType) -> str:
@@ -179,12 +189,7 @@ def to_text(value: RdataType) -> str:
     Returns a ``str``.
     """
 
-    text = RdataType.to_text(value)
-    if text.startswith("TYPE"):
-        registered_text = _registered_by_value.get(value)
-        if registered_text:
-            text = registered_text
-    return text.replace("_", "-")
+    return RdataType.to_text(value)
 
 
 def is_metatype(rdtype: RdataType) -> bool:
index a1c066afacd9885a6a527d09c6e00658d0ad8c3a..73023693a74292af78196e491dca0de2bfd94561 100644 (file)
@@ -61,6 +61,9 @@ class RdataTestCase(unittest.TestCase):
         self.assertEqual(rdata.strings, (b"hello", b"world"))
         self.assertEqual(dns.rdatatype.to_text(TTXT), "TTXT")
         self.assertEqual(dns.rdatatype.from_text("TTXT"), TTXT)
+        self.assertEqual(dns.rdatatype.RdataType.make("TTXT"), TTXT)
+        self.assertEqual(dns.rdatatype.from_text("ttxt"), TTXT)
+        self.assertEqual(dns.rdatatype.RdataType.make("ttxt"), TTXT)
 
     def test_module_reregistration(self):
         def bad():
@@ -937,6 +940,14 @@ class UtilTestCase(unittest.TestCase):
         with self.assertRaises(dns.ttl.BadTTL):
             dns.rdataset.from_text("in", "a", "10.0.0.1", "10.0.0.2")
 
+    def test_nsap_ptr_type(self):
+        # The NSAP-PTR type is special because it contains a dash, which means
+        # that its enum value is not the same as its string value.
+        self.assertEqual(dns.rdatatype.from_text("NSAP-PTR"), dns.rdatatype.NSAP_PTR)
+        self.assertEqual(
+            dns.rdatatype.RdataType.make("NSAP-PTR"), dns.rdatatype.NSAP_PTR
+        )
+
 
 Rdata = dns.rdata.Rdata