]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
More TSIG fixes. (#1207)
authorBrian Wellington <bwelling@xbill.org>
Thu, 10 Jul 2025 20:55:55 +0000 (13:55 -0700)
committerGitHub <noreply@github.com>
Thu, 10 Jul 2025 20:55:55 +0000 (13:55 -0700)
* More TSIG fixes.

This attempts to fix the bug that when a message containing a TSIG
record is parsed from wire format, attempting to render it back to wire
format will either regenerate the TSIG (if the TSIG was verified) or
throw an exception (if the TSIG was not verified).  In either case,
the rendered message should contain the TSIG record that was parsed.

* Fix setting tsig_ctx.

dns/message.py
tests/test_tsig.py

index e29941d439548704f27db59c62f6511c2ae9ba7a..77bec1e94cbc00acc47c178aa32c4a82c8cd018c 100644 (file)
@@ -157,6 +157,7 @@ class Message:
         self.pad = 0
         self.keyring: Any = None
         self.tsig: Optional[dns.rrset.RRset] = None
+        self.want_tsig_sign = False
         self.request_mac = b""
         self.xfr = False
         self.origin: Optional[dns.name.Name] = None
@@ -637,21 +638,22 @@ class Message:
             r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
         r.write_header()
         if self.tsig is not None:
-            (new_tsig, ctx) = dns.tsig.sign(
-                r.get_wire(),
-                self.keyring,
-                self.tsig[0],
-                int(time.time()),
-                self.request_mac,
-                tsig_ctx,
-                multi,
-            )
-            self.tsig.clear()
-            self.tsig.add(new_tsig)
+            if self.want_tsig_sign:
+                (new_tsig, ctx) = dns.tsig.sign(
+                    r.get_wire(),
+                    self.keyring,
+                    self.tsig[0],
+                    int(time.time()),
+                    self.request_mac,
+                    tsig_ctx,
+                    multi,
+                )
+                self.tsig.clear()
+                self.tsig.add(new_tsig)
+                if multi:
+                    self.tsig_ctx = ctx
             r.add_rrset(dns.renderer.ADDITIONAL, self.tsig)
             r.write_header()
-            if multi:
-                self.tsig_ctx = ctx
         wire = r.get_wire()
         self.wire = wire
         if prepend_length:
@@ -688,9 +690,6 @@ class Message:
         """When sending, a TSIG signature using the specified key
         should be added.
 
-        *key*, a ``dns.tsig.Key`` is the key to use.  If a key is specified,
-        the *keyring* and *algorithm* fields are not used.
-
         *keyring*, a ``dict``, ``callable`` or ``dns.tsig.Key``, is either
         the TSIG keyring or key to use.
 
@@ -748,6 +747,7 @@ class Message:
             tsig_error,
             other_data,
         )
+        self.want_tsig_sign = True
 
     @property
     def keyname(self) -> Optional[dns.name.Name]:
index 4715e393d002e2289874f6d2fee8fccbe67edb88..c8c1efe53225fe3e0987da8e681cb65aa7303323 100644 (file)
@@ -378,3 +378,27 @@ example. 300 IN SOA . . 1 2 3 4 5
         m = dns.message.make_response(q2)
 
         self.assertIsNone(m.tsig)
+
+    def test_render_message_with_existing_tsig(self):
+        q1 = dns.message.make_query("example", "a")
+        q1.use_tsig(keyring, keyname)
+        wire = q1.to_wire()
+
+        # The TSIG record parsed from wire format is rendered if
+        # it is not verified.
+        q2 = dns.message.from_wire(wire, keyring=False)
+        self.assertEqual(q2.to_wire(), wire)
+
+        # The TSIG record parsed from wire format is rendered if
+        # it is verified.
+        q3 = dns.message.from_wire(wire, keyring=keyring)
+        self.assertEqual(q3.to_wire(), wire)
+
+        # A new TSIG record is generated if we call use_tsig().
+        #
+        # Note that this specifies a new key because if the same key were used,
+        # the same TSIG record might be generated, and the wire format would
+        # be identical.
+        q4 = dns.message.from_wire(wire, keyring=keyring)
+        q4.use_tsig(dns.tsig.Key(keyname, "abcd"))
+        self.assertNotEqual(q4.to_wire(), wire)