]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
edns: implement Extended DNS Error Option support 741/head
authorTomas Krizek <tomas.krizek@nic.cz>
Sat, 18 Dec 2021 17:12:47 +0000 (18:12 +0100)
committerTomas Krizek <tomas.krizek@nic.cz>
Sat, 18 Dec 2021 18:00:06 +0000 (19:00 +0100)
This is quite minimalistic implementation of the Extended DNS Errors
(RFC 8914). It just allows access to code and text fields.

dns/edns.py
tests/test_edns.py

index 237178f24985c66dce4d803b2a636b3e907da163..5d85da9b65cba87bfffd539a29f26dbe83ff7f55 100644 (file)
@@ -47,6 +47,8 @@ class OptionType(dns.enum.IntEnum):
     PADDING = 12
     #: CHAIN
     CHAIN = 13
+    #: EDE (extended-dns-error)
+    EDE = 15
 
     @classmethod
     def _maximum(cls):
@@ -300,10 +302,63 @@ class ECSOption(Option):
         return cls(addr, src, scope)
 
 
+class EDEOption(Option):
+    """Extended DNS Error (EDE, RFC8914)"""
+
+    def __init__(self, code, text=None):
+        """*code*, an ``int``, the info code of the extended error.
+
+        *text*, a ``str``, optional field containing additional textual
+        information.
+        """
+
+        super().__init__(OptionType.EDE)
+
+        if code < 0 or code > 65535:
+            raise ValueError('code must be uint16')
+        if text is not None and not isinstance(text, str):
+            raise ValueError('text must be string or None')
+
+        self.code = code
+        self.text = text
+
+    def to_text(self):
+        output = "EDE {}".format(self.code)
+        if self.text is not None:
+            output += ': {}'.format(self.text)
+        return output
+
+    def to_wire(self, file=None):
+        value = struct.pack('!H', self.code)
+        if self.text is not None:
+            value += self.text.encode('utf8')
+
+        if file:
+            file.write(value)
+        else:
+            return value
+
+    @classmethod
+    def from_wire_parser(cls, otype, parser):
+        code = parser.get_uint16()
+        text = parser.get_remaining()
+
+        if text:
+            if text[-1] == 0:  # text MAY be null-terminated
+                text = text[:-1]
+            text = text.decode('utf8')
+        else:
+            text = None
+
+        return cls(code, text)
+
+
 _type_to_class = {
-    OptionType.ECS: ECSOption
+    OptionType.ECS: ECSOption,
+    OptionType.EDE: EDEOption,
 }
 
+
 def get_option_class(otype):
     """Return the class for the specified option type.
 
@@ -372,5 +427,6 @@ COOKIE = OptionType.COOKIE
 KEEPALIVE = OptionType.KEEPALIVE
 PADDING = OptionType.PADDING
 CHAIN = OptionType.CHAIN
+EDE = OptionType.EDE
 
 ### END generated OptionType constants
index 6ba0c995320a31bce2a094567e2ec93119e61db9..427eb29c25b2279bb25cbcc955c0291b747234e9 100644 (file)
@@ -134,6 +134,42 @@ class OptionTestCase(unittest.TestCase):
             opt = dns.edns.option_from_wire(dns.edns.ECS,
                                             b'\x00\xff\x18\x00\x01\x02\x03',
                                             0, 7)
+    def testEDEOption(self):
+        opt = dns.edns.EDEOption(3)
+        io = BytesIO()
+        opt.to_wire(io)
+        data = io.getvalue()
+        self.assertEqual(data, b'\x00\x03')
+        self.assertEqual(str(opt), 'EDE 3')
+        # with text
+        opt = dns.edns.EDEOption(16, 'test')
+        io = BytesIO()
+        opt.to_wire(io)
+        data = io.getvalue()
+        self.assertEqual(data, b'\x00\x10test')
+
+    def testEDEOption_invalid(self):
+        with self.assertRaises(ValueError):
+            opt = dns.edns.EDEOption(-1)
+        with self.assertRaises(ValueError):
+            opt = dns.edns.EDEOption(65536)
+        with self.assertRaises(ValueError):
+            opt = dns.edns.EDEOption(0, 0)
+
+    def testEDEOption_from_wire(self):
+        data = b'\x00\01'
+        self.assertEqual(
+            dns.edns.option_from_wire(dns.edns.EDE, data, 0, 2),
+            dns.edns.EDEOption(1))
+        data = b'\x00\01test'
+        self.assertEqual(
+            dns.edns.option_from_wire(dns.edns.EDE, data, 0, 6),
+            dns.edns.EDEOption(1, 'test'))
+        # utf-8 text MAY be null-terminated
+        data = b'\x00\01test\x00'
+        self.assertEqual(
+            dns.edns.option_from_wire(dns.edns.EDE, data, 0, 7),
+            dns.edns.EDEOption(1, 'test'))
 
     def test_basic_relations(self):
         o1 = dns.edns.ECSOption.from_text('1.2.3.0/24/0')