]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Support registering new types with classes. (#1167)
authorBrian Wellington <bwelling@xbill.org>
Fri, 29 Nov 2024 20:23:24 +0000 (12:23 -0800)
committerGitHub <noreply@github.com>
Fri, 29 Nov 2024 20:23:24 +0000 (12:23 -0800)
* Support registering new types with classes.

Previously, dns.rdata.register_type() required passing a module which
contained the implementation of the new type, and it would extract the
class from the module.  This change allows passing the class directly.

dns/rdata.py
tests/test_rdata.py

index bcdac094e87f92e286b8c814c0136fc3f82d055d..1913dd6ca575c3c3436dce7270a4ece83e7baeca 100644 (file)
@@ -891,8 +891,8 @@ def register_type(
 ) -> None:
     """Dynamically register a module to handle an rdatatype.
 
-    *implementation*, a module implementing the type in the usual dnspython
-    way.
+    *implementation*, a subclass of ``dns.rdata.Rdata`` implementing the type,
+    or a module containing such a class named by its text form.
 
     *rdtype*, an ``int``, the rdatatype to register.
 
@@ -909,7 +909,9 @@ def register_type(
     existing_cls = get_rdata_class(rdclass, rdtype)
     if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
         raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
-    _rdata_classes[(rdclass, rdtype)] = getattr(
-        implementation, rdtype_text.replace("-", "_")
-    )
+    if isinstance(implementation, type) and issubclass(implementation, Rdata):
+        impclass = implementation
+    else:
+        impclass = getattr(implementation, rdtype_text.replace("-", "_"))
+    _rdata_classes[(rdclass, rdtype)] = impclass
     dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
index 4c62aa1d4c25b812301bb366e6ea91ea0900c6d3..c1d3416c6a09cec75bb35bf2bcd69ec2d8544b3e 100644 (file)
@@ -63,6 +63,20 @@ class RdataTestCase(unittest.TestCase):
         self.assertEqual(dns.rdatatype.from_text("ttxt"), TTXT)
         self.assertEqual(dns.rdatatype.RdataType.make("ttxt"), TTXT)
 
+    def test_class_registration(self):
+        CTXT = 64003
+        class CTXTImp(dns.rdtypes.txtbase.TXTBase):
+            """Test TXT-like record"""
+
+        dns.rdata.register_type(CTXTImp, CTXT, "CTXT")
+        rdata = dns.rdata.from_text(dns.rdataclass.IN, CTXT, "hello world")
+        self.assertEqual(rdata.strings, (b"hello", b"world"))
+        self.assertEqual(dns.rdatatype.to_text(CTXT), "CTXT")
+        self.assertEqual(dns.rdatatype.from_text("CTXT"), CTXT)
+        self.assertEqual(dns.rdatatype.RdataType.make("CTXT"), CTXT)
+        self.assertEqual(dns.rdatatype.from_text("ctxt"), CTXT)
+        self.assertEqual(dns.rdatatype.RdataType.make("ctxt"), CTXT)
+
     def test_module_reregistration(self):
         def bad():
             TTXTTWO = dns.rdatatype.TXT