]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add EDE retrieval helper [#969] and a get_options() helper. (#1056)
authorBob Halley <halley@dnspython.org>
Tue, 20 Feb 2024 22:01:39 +0000 (14:01 -0800)
committerGitHub <noreply@github.com>
Tue, 20 Feb 2024 22:01:39 +0000 (14:01 -0800)
dns/message.py
tests/test_message.py

index 44cacbd9c0a8f85c820247e7b765c3ddb8b3ec3e..8513db950a5802a972aa3b016d3c9297d7ca3f77 100644 (file)
@@ -20,7 +20,7 @@
 import contextlib
 import io
 import time
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 import dns.edns
 import dns.entropy
@@ -912,6 +912,14 @@ class Message:
         self.flags &= 0x87FF
         self.flags |= dns.opcode.to_flags(opcode)
 
+    def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]:
+        """Return the list of options of the specified type."""
+        return [option for option in self.options if option.otype == otype]
+
+    def extended_errors(self) -> List[dns.edns.EDEOption]:
+        """Return the list of Extended DNS Error (EDE) options in the message"""
+        return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE))
+
     def _get_one_rr_per_rrset(self, value):
         # What the caller picked is fine.
         return value
index 93c8aafd0f5f898e0718c4a1c87ea5787a77bb08..bbd457183c1e597a7705a551adafe8ac180ed6d1 100644 (file)
@@ -991,6 +991,15 @@ www.dnspython.org. 300 IN A 1.2.3.4
         self.assertEqual(update.section_count(dns.update.UpdateSection.PREREQ), 5)
         self.assertEqual(update.section_count(dns.update.UpdateSection.UPDATE), 7)
 
+    def test_extended_errors(self):
+        options = [
+            dns.edns.EDEOption(dns.edns.EDECode.NETWORK_ERROR, "tubes not tubing"),
+            dns.edns.EDEOption(dns.edns.EDECode.OTHER, "catch all code"),
+        ]
+        r = dns.message.make_query("example", "A", use_edns=0, options=options)
+        r.flags |= dns.flags.QR
+        self.assertEqual(r.extended_errors(), options)
+
 
 if __name__ == "__main__":
     unittest.main()