]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Test (and fix) renderer.add_multi_tsig().
authorBrian Wellington <bwelling@xbill.org>
Tue, 30 Jun 2020 16:34:41 +0000 (09:34 -0700)
committerBrian Wellington <bwelling@xbill.org>
Tue, 30 Jun 2020 16:34:41 +0000 (09:34 -0700)
dns/renderer.py
tests/test_renderer.py

index 6e50d272af53648de57de7a8de8a0cd50ddcf750..be57a62f9b893be29a0280089ec91a0eb5d93c41 100644 (file)
@@ -202,7 +202,7 @@ class Renderer:
                                               b'', id, tsig_error, other_data)
         (tsig, ctx) = dns.tsig.sign(s, keyname, tsig[0], secret,
                                     int(time.time()), request_mac,
-                                    ctx, True, ctx is None)
+                                    ctx, True)
         self._write_tsig(tsig, keyname)
         return ctx
 
index db9d0f3b1eb383d2c99f0029f40a7a105300998f..c60ccf9566c5f6202101bd08ecd53f303b6b4d0b 100644 (file)
@@ -52,6 +52,35 @@ class RendererTestCase(unittest.TestCase):
         expected.id = message.id
         self.assertEqual(message, expected)
 
+    def test_multi_tsig(self):
+        qname = dns.name.from_text('foo.example')
+        keyring = dns.tsigkeyring.from_text({'key' : '12345678'})
+        keyname = next(iter(keyring))
+
+        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+        r.add_question(qname, dns.rdatatype.A)
+        r.write_header()
+        ctx = r.add_multi_tsig(None, keyname, keyring[keyname], 300, r.id, 0,
+                               b'', b'', dns.tsig.HMAC_SHA256)
+        wire = r.get_wire()
+        message = dns.message.from_wire(wire, keyring=keyring, multi=True)
+        expected = dns.message.make_query(qname, dns.rdatatype.A)
+        expected.id = message.id
+        self.assertEqual(message, expected)
+
+        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+        r.add_question(qname, dns.rdatatype.A)
+        r.write_header()
+        ctx = r.add_multi_tsig(ctx, keyname, keyring[keyname], 300, r.id, 0,
+                               b'', b'', dns.tsig.HMAC_SHA256)
+        wire = r.get_wire()
+        message = dns.message.from_wire(wire, keyring=keyring,
+                                        tsig_ctx=message.tsig_ctx, multi=True)
+        expected = dns.message.make_query(qname, dns.rdatatype.A)
+        expected.id = message.id
+        self.assertEqual(message, expected)
+
+
     def test_going_backwards_fails(self):
         r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
         qname = dns.name.from_text('foo.example')