]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add prefer_truncation to Message.to_wire(). (#1023)
authorBrian Wellington <bwelling@xbill.org>
Wed, 20 Dec 2023 22:09:34 +0000 (14:09 -0800)
committerGitHub <noreply@github.com>
Wed, 20 Dec 2023 22:09:34 +0000 (14:09 -0800)
If a caller passes prefer_truncation=True, the message will be truncated
if it would otherwise exceed the maximum length.  If the truncation
occurs before the additional section, the TC bit will be set.

This behavior matches what a name server would do when generating a
response.

dns/message.py
tests/test_message.py

index fff06dfa4521f054b3e6b694d90d165be1a1c875..4e226160c4831eec64f1618b27540766396371ef 100644 (file)
@@ -528,6 +528,7 @@ class Message:
         multi: bool = False,
         tsig_ctx: Optional[Any] = None,
         prepend_length: bool = False,
+        prefer_truncation: bool = False,
         **kw: Dict[str, Any],
     ) -> bytes:
         """Return a string containing the message in DNS compressed wire
@@ -554,6 +555,11 @@ class Message:
         wants the message length prepended to the message itself.  This is
         useful for messages sent over TCP, TLS (DoT), or QUIC (DoQ).
 
+        *prefer_truncation*, a ``bool``, should be set to ``True`` if the caller
+        wants the message to be truncated if it would otherwise exceed the
+        maximum length.  If the truncation occurs before the additional section,
+        the TC bit will be set.
+
         Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
 
         Returns a ``bytes``.
@@ -575,14 +581,21 @@ class Message:
         r.reserve(opt_reserve)
         tsig_reserve = self._compute_tsig_reserve()
         r.reserve(tsig_reserve)
-        for rrset in self.question:
-            r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
-        for rrset in self.answer:
-            r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
-        for rrset in self.authority:
-            r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
-        for rrset in self.additional:
-            r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
+        try:
+            for rrset in self.question:
+                r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
+            for rrset in self.answer:
+                r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
+            for rrset in self.authority:
+                r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
+            for rrset in self.additional:
+                r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
+        except dns.exception.TooBig:
+            if prefer_truncation:
+                if r.section < dns.renderer.ADDITIONAL:
+                    r.flags |= dns.flags.TC
+            else:
+                raise
         r.release_reserved()
         if self.opt is not None:
             r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
index 59ca507b0306edaf1bac3e5c3830b7b90adbdcf6..356b8a0c2cf9b5033b3965a3382cc087db67d66e 100644 (file)
@@ -882,6 +882,66 @@ www.dnspython.org. 300 IN A 1.2.3.4
         self.assertIsNotNone(q2.tsig)
         self.assertEqual(q, q2)
 
+    def test_prefer_truncation_answer(self):
+        q = dns.message.make_query("www.example", "a")
+        rrs = [
+            dns.rrset.from_text("www.example.", 3600, "in", "a", f"1.2.3.{n}")
+            for n in range(32)
+        ]
+        r = dns.message.make_response(q)
+        r.answer.extend(rrs)
+
+        # Normally, we get an exception
+        with self.assertRaises(dns.exception.TooBig):
+            w1 = r.to_wire(max_size=512)
+
+        # With prefer_truncation, we get a truncated response where 1 record
+        # doesn't fit, and TC is set.
+        w2 = r.to_wire(max_size=512, prefer_truncation=True)
+        r2 = dns.message.from_wire(w2, one_rr_per_rrset=True)
+        self.assertNotEqual(r2.flags & dns.flags.TC, 0)
+        self.assertEqual(len(r2.answer), 30)
+
+    def test_prefer_truncation_edns(self):
+        q = dns.message.make_query("www.example", "a", payload=512)
+        rrs = [
+            dns.rrset.from_text("www.example.", 3600, "in", "a", f"1.2.3.{n}")
+            for n in range(32)
+        ]
+        r = dns.message.make_response(q)
+        r.answer.extend(rrs)
+
+        # Normally, we get an exception
+        with self.assertRaises(dns.exception.TooBig):
+            w1 = r.to_wire(max_size=512)
+
+        # With prefer_truncation, we get a truncated response where 2 records
+        # don't fit, and TC is set.
+        w2 = r.to_wire(max_size=512, prefer_truncation=True)
+        r2 = dns.message.from_wire(w2, one_rr_per_rrset=True)
+        self.assertNotEqual(r2.flags & dns.flags.TC, 0)
+        self.assertEqual(len(r2.answer), 29)
+
+    def test_prefer_truncation_additional(self):
+        q = dns.message.make_query("www.example", "a")
+        rrs = [
+            dns.rrset.from_text("www.example.", 3600, "in", "a", f"1.2.3.{n}")
+            for n in range(32)
+        ]
+        r = dns.message.make_response(q)
+        r.additional.extend(rrs)
+
+        # Normally, we get an exception
+        with self.assertRaises(dns.exception.TooBig):
+            w1 = r.to_wire(max_size=512)
+
+        # With prefer_truncation, we get a truncated response where 1 record
+        # doesn't fit, and TC is not set.
+        w2 = r.to_wire(max_size=512, prefer_truncation=True)
+        r2 = dns.message.from_wire(w2, one_rr_per_rrset=True)
+        self.assertEqual(r2.flags & dns.flags.TC, 0)
+        self.assertEqual(len(r2.additional), 30)
+
 
 if __name__ == "__main__":
     unittest.main()