]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Store a TSIG rrset on the message object.
authorBrian Wellington <bwelling@xbill.org>
Mon, 29 Jun 2020 21:55:35 +0000 (14:55 -0700)
committerBrian Wellington <bwelling@xbill.org>
Tue, 30 Jun 2020 15:33:19 +0000 (08:33 -0700)
dns/message.py
dns/rdtypes/ANY/TSIG.py
dns/renderer.py
dns/tsig.py
tests/test_renderer.py
tests/test_tsig.py

index 1676ba8ae1f58aef6d4f2867b7c05d8f8ee1da6d..a7c6bfaa077570e82cc1f9751072aa1eee22222d 100644 (file)
@@ -38,6 +38,7 @@ import dns.renderer
 import dns.tsig
 import dns.wiredata
 import dns.rdtypes.ANY.OPT
+import dns.rdtypes.ANY.TSIG
 
 
 class ShortHeader(dns.exception.FormError):
@@ -109,18 +110,11 @@ class Message:
         self.opt = None
         self.request_payload = 0
         self.keyring = None
-        self.keyname = None
-        self.keyalgorithm = dns.tsig.default_algorithm
+        self.tsig = None
         self.request_mac = b''
-        self.other_data = b''
-        self.tsig_error = 0
-        self.fudge = 300
-        self.original_id = self.id
-        self.mac = b''
         self.xfr = False
         self.origin = None
         self.tsig_ctx = None
-        self.had_tsig = False
         self.multi = False
         self.first = True
         self.index = {}
@@ -443,22 +437,32 @@ class Message:
         for rrset in self.additional:
             r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
         r.write_header()
-        if self.keyname is not None:
+        if self.tsig is not None:
+            (new_tsig, ctx) = dns.tsig.sign(r.get_wire(),
+                                            self.tsig.name,
+                                            self.tsig[0],
+                                            self.keyring[self.tsig.name],
+                                            int(time.time()),
+                                            self.request_mac,
+                                            tsig_ctx,
+                                            multi,
+                                            tsig_ctx is None)
+            self.tsig.clear()
+            self.tsig.add(new_tsig)
+            r.add_rrset(dns.renderer.ADDITIONAL, self.tsig)
+            r.write_header()
             if multi:
-                ctx = r.add_multi_tsig(tsig_ctx,
-                                       self.keyname, self.keyring[self.keyname],
-                                       self.fudge, self.original_id,
-                                       self.tsig_error, self.other_data,
-                                       self.request_mac, self.keyalgorithm)
                 self.tsig_ctx = ctx
-            else:
-                r.add_tsig(self.keyname, self.keyring[self.keyname],
-                           self.fudge, self.original_id, self.tsig_error,
-                           self.other_data, self.request_mac,
-                           self.keyalgorithm)
-            self.mac = r.mac
         return r.get_wire()
 
+    @staticmethod
+    def _make_tsig(keyname, algorithm, time_signed, fudge, mac, original_id,
+                   error, other):
+        tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
+                                         algorithm, time_signed, fudge, mac,
+                                         original_id, error, other)
+        return dns.rrset.from_rdata(keyname, 0, tsig)
+
     def use_tsig(self, keyring, keyname=None, fudge=300,
                  original_id=None, tsig_error=0, other_data=b'',
                  algorithm=dns.tsig.default_algorithm):
@@ -492,19 +496,45 @@ class Message:
 
         self.keyring = keyring
         if keyname is None:
-            self.keyname = list(self.keyring.keys())[0]
-        else:
-            if isinstance(keyname, str):
-                keyname = dns.name.from_text(keyname)
-            self.keyname = keyname
-        self.keyalgorithm = algorithm
-        self.fudge = fudge
+            keyname = list(self.keyring.keys())[0]
+        elif isinstance(keyname, str):
+            keyname = dns.name.from_text(keyname)
         if original_id is None:
-            self.original_id = self.id
+            original_id = self.id
+        self.tsig = self._make_tsig(keyname, algorithm, 0, fudge, b'',
+                                    original_id, tsig_error, other_data)
+
+    @property
+    def keyname(self):
+        if self.tsig:
+            return self.tsig.name
+        else:
+            return None
+
+    @property
+    def keyalgorithm(self):
+        if self.tsig:
+            return self.tsig[0].algorithm
         else:
-            self.original_id = original_id
-        self.tsig_error = tsig_error
-        self.other_data = other_data
+            return None
+
+    @property
+    def mac(self):
+        if self.tsig:
+            return self.tsig[0].mac
+        else:
+            return None
+
+    @property
+    def tsig_error(self):
+        if self.tsig:
+            return self.tsig[0].error
+        else:
+            return None
+
+    @property
+    def had_tsig(self):
+        return bool(self.tsig)
 
     @staticmethod
     def _make_opt(flags=0, payload=1280, options=None):
@@ -659,7 +689,7 @@ class Message:
             if section != MessageSection.ADDITIONAL or \
                rdclass != dns.rdatatype.ANY or \
                position != count - 1:
-                raise dns.error.FormError
+                raise BadTSIG
         return (rdclass, rdtype, None, False)
 
 
@@ -781,7 +811,6 @@ class _WireReader:
                 secret = self.message.keyring.get(absolute_name)
                 if secret is None:
                     raise UnknownTSIGKey("key '%s' unknown" % name)
-                self.message.keyname = absolute_name
                 self.message.tsig_ctx = \
                     dns.tsig.validate(self.wire,
                                       absolute_name,
@@ -793,9 +822,7 @@ class _WireReader:
                                       self.message.tsig_ctx,
                                       self.message.multi,
                                       self.message.first)
-                self.message.keyalgorithm = rd.algorithm
-                self.message.mac = rd.mac
-                self.message.had_tsig = True
+                self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
             else:
                 rrset = self.message.find_rrset(section, name,
                                                 rdclass, rdtype, covers,
@@ -1291,7 +1318,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
 
 
 def make_response(query, recursion_available=False, our_payload=8192,
-                  fudge=300):
+                  fudge=300, tsig_error=0):
     """Make a message which is a response for the specified query.
     The message returned is really a response skeleton; it has all
     of the infrastructure required of a response, but none of the
@@ -1310,6 +1337,8 @@ def make_response(query, recursion_available=False, our_payload=8192,
 
     *fudge*, an ``int``, the TSIG time fudge.
 
+    *tsig_error*, an ``int``, the TSIG error.
+
     Returns a ``dns.message.Message`` object whose specific class is
     appropriate for the query.  For example, if query is a
     ``dns.update.UpdateMessage``, response will be too.
@@ -1327,7 +1356,7 @@ def make_response(query, recursion_available=False, our_payload=8192,
     if query.edns >= 0:
         response.use_edns(0, 0, our_payload, query.payload)
     if query.had_tsig:
-        response.use_tsig(query.keyring, query.keyname, fudge, None, 0, b'',
-                          query.keyalgorithm)
+        response.use_tsig(query.keyring, query.keyname, fudge, None,
+                          tsig_error, b'', query.keyalgorithm)
         response.request_mac = query.mac
     return response
index e3937c6f6357db5b8645d1adb3f8277813e28cf3..c2a6ce01a06032b9f1e1cb61132b7bd2cf285ae5 100644 (file)
@@ -62,7 +62,7 @@ class TSIG(dns.rdata.Rdata):
 
     def to_text(self, origin=None, relativize=True, **kw):
         algorithm = self.algorithm.choose_relativity(origin, relativize)
-        return f"{algorithm} ... {self.fudge} {self.time_signed} " + \
+        return f"{algorithm} {self.fudge} {self.time_signed} " + \
                f"{len(self.mac)} {dns.rdata._base64ify(self.mac)} " + \
                f"{self.original_id} {self.error} " + \
                f"{len(self.other)} {dns.rdata._base64ify(self.other)}"
@@ -75,7 +75,7 @@ class TSIG(dns.rdata.Rdata):
                                self.fudge,
                                len(self.mac)))
         file.write(self.mac)
-        file.write(struct.pack('HHH', self.original_id, self.error,
+        file.write(struct.pack('!HHH', self.original_id, self.error,
                                len(self.other)))
         file.write(self.other)
 
index 8b2548760d374c540efa96cb9b1b966afc0da537..6e50d272af53648de57de7a8de8a0cd50ddcf750 100644 (file)
@@ -178,17 +178,12 @@ class Renderer:
         """Add a TSIG signature to the message."""
 
         s = self.output.getvalue()
-        (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
-                                                    keyname,
-                                                    secret,
-                                                    int(time.time()),
-                                                    fudge,
-                                                    id,
-                                                    tsig_error,
-                                                    other_data,
-                                                    request_mac,
-                                                    algorithm=algorithm)
-        self._write_tsig(tsig_rdata, keyname)
+
+        tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
+                                              b'', id, tsig_error, other_data)
+        (tsig, _) = dns.tsig.sign(s, keyname, tsig[0], secret,
+                                  int(time.time()), request_mac)
+        self._write_tsig(tsig, keyname)
 
     def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error,
                        other_data, request_mac,
@@ -202,30 +197,23 @@ class Renderer:
         add_multi_tsig() call for the previous message."""
 
         s = self.output.getvalue()
-        (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
-                                                    keyname,
-                                                    secret,
-                                                    int(time.time()),
-                                                    fudge,
-                                                    id,
-                                                    tsig_error,
-                                                    other_data,
-                                                    request_mac,
-                                                    ctx=ctx,
-                                                    first=ctx is None,
-                                                    multi=True,
-                                                    algorithm=algorithm)
-        self._write_tsig(tsig_rdata, keyname)
+
+        tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
+                                              b'', id, tsig_error, other_data)
+        (tsig, ctx) = dns.tsig.sign(s, keyname, tsig[0], secret,
+                                    int(time.time()), request_mac,
+                                    ctx, True, ctx is None)
+        self._write_tsig(tsig, keyname)
         return ctx
 
-    def _write_tsig(self, tsig_rdata, keyname):
+    def _write_tsig(self, tsig, keyname):
         self._set_section(ADDITIONAL)
         with self._track_size():
             keyname.to_wire(self.output, self.compress, self.origin)
             self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
                                           dns.rdataclass.ANY, 0, 0))
             rdata_start = self.output.tell()
-            self.output.write(tsig_rdata)
+            tsig.to_wire(self.output)
 
         after = self.output.tell()
         self.output.seek(rdata_start - 2)
index a9d85de1f51e4f61c863dd0a2ccefc93097f8340..2780c3c15e957ee541ee7d5556a07150085a0d1b 100644 (file)
@@ -85,9 +85,8 @@ BADTIME = 18
 BADTRUNC = 22
 
 
-def sign(wire, keyname, secret, time, fudge, original_id, error,
-         other_data, request_mac, ctx=None, multi=False, first=True,
-         algorithm=default_algorithm):
+def sign(wire, keyname, rdata, secret, time=None, request_mac=None,
+         ctx=None, multi=False, first=True):
     """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
     for the input parameters, the HMAC MAC calculated by applying the
     TSIG signature algorithm, and the TSIG digest context.
@@ -96,46 +95,44 @@ def sign(wire, keyname, secret, time, fudge, original_id, error,
     @raises NotImplementedError: I{algorithm} is not supported
     """
 
-    if isinstance(other_data, str):
-        other_data = other_data.encode()
-    (algorithm_name, digestmod) = get_algorithm(algorithm)
+    (algorithm_name, digestmod) = get_algorithm(rdata.algorithm)
     if first:
         ctx = hmac.new(secret, digestmod=digestmod)
-        ml = len(request_mac)
-        if ml > 0:
-            ctx.update(struct.pack('!H', ml))
+        if request_mac:
+            ctx.update(struct.pack('!H', len(request_mac)))
             ctx.update(request_mac)
-    id = struct.pack('!H', original_id)
-    ctx.update(id)
+    ctx.update(struct.pack('!H', rdata.original_id))
     ctx.update(wire[2:])
     if first:
         ctx.update(keyname.to_digestable())
         ctx.update(struct.pack('!H', dns.rdataclass.ANY))
         ctx.update(struct.pack('!I', 0))
+    if time is None:
+        time = rdata.time_signed
     upper_time = (time >> 32) & 0xffff
     lower_time = time & 0xffffffff
-    time_mac = struct.pack('!HIH', upper_time, lower_time, fudge)
-    pre_mac = algorithm_name + time_mac
-    ol = len(other_data)
-    if ol > 65535:
+    time_encoded = struct.pack('!HIH', upper_time, lower_time, rdata.fudge)
+    other_len = len(rdata.other)
+    if other_len > 65535:
         raise ValueError('TSIG Other Data is > 65535 bytes')
-    post_mac = struct.pack('!HH', error, ol) + other_data
     if first:
-        ctx.update(pre_mac)
-        ctx.update(post_mac)
+        ctx.update(algorithm_name + time_encoded)
+        ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other)
     else:
-        ctx.update(time_mac)
+        ctx.update(time_encoded)
     mac = ctx.digest()
-    mpack = struct.pack('!H', len(mac))
-    tsig_rdata = pre_mac + mpack + mac + id + post_mac
     if multi:
         ctx = hmac.new(secret, digestmod=digestmod)
-        ml = len(mac)
-        ctx.update(struct.pack('!H', ml))
+        ctx.update(struct.pack('!H', len(mac)))
         ctx.update(mac)
     else:
         ctx = None
-    return (tsig_rdata, mac, ctx)
+    tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
+                                     rdata.algorithm, time, rdata.fudge, mac,
+                                     rdata.original_id, rdata.error,
+                                     rdata.other)
+
+    return (tsig, ctx)
 
 
 def validate(wire, keyname, rdata, secret, now, request_mac, tsig_start,
@@ -166,11 +163,9 @@ def validate(wire, keyname, rdata, secret, now, request_mac, tsig_start,
             raise PeerError('unknown TSIG error code %d' % rdata.error)
     if abs(rdata.time_signed - now) > rdata.fudge:
         raise BadTime
-    (junk, our_mac, ctx) = sign(new_wire, keyname, secret, rdata.time_signed,
-                                rdata.fudge, rdata.original_id, rdata.error,
-                                rdata.other, request_mac, ctx, multi, first,
-                                rdata.algorithm)
-    if our_mac != rdata.mac:
+    (our_rdata, ctx) = sign(new_wire, keyname, rdata, secret, None, request_mac,
+                            ctx, multi, first)
+    if our_rdata.mac != rdata.mac:
         raise BadSignature
     return ctx
 
@@ -191,20 +186,3 @@ def get_algorithm(algorithm):
     except KeyError:
         raise NotImplementedError("TSIG algorithm " + str(algorithm) +
                                   " is not supported")
-
-
-def get_algorithm_and_mac(wire, tsig_rdata, tsig_rdlen):
-    """Return the tsig algorithm for the specified tsig_rdata
-    @raises FormError: The TSIG is badly formed.
-    """
-    current = tsig_rdata
-    (aname, used) = dns.name.from_wire(wire, current)
-    current = current + used
-    (upper_time, lower_time, fudge, mac_size) = \
-        struct.unpack("!HIHH", wire[current:current + 10])
-    current += 10
-    mac = wire[current:current + mac_size]
-    current += mac_size
-    if current > tsig_rdata + tsig_rdlen:
-        raise dns.exception.FormError
-    return (aname, mac)
index 345ef82320642162cd1af6ec237fbf53f32f4c56..db9d0f3b1eb383d2c99f0029f40a7a105300998f 100644 (file)
@@ -3,9 +3,11 @@
 import unittest
 
 import dns.exception
+import dns.flags
 import dns.message
 import dns.renderer
-import dns.flags
+import dns.tsig
+import dns.tsigkeyring
 
 basic_answer = \
     """flags QR
@@ -35,6 +37,21 @@ class RendererTestCase(unittest.TestCase):
         expected.id = message.id
         self.assertEqual(message, expected)
 
+    def test_tsig(self):
+        r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+        qname = dns.name.from_text('foo.example')
+        r.add_question(qname, dns.rdatatype.A)
+        keyring = dns.tsigkeyring.from_text({'key' : '12345678'})
+        keyname = next(iter(keyring))
+        r.write_header()
+        r.add_tsig(keyname, keyring[keyname], 300, r.id, 0, b'', b'',
+                   dns.tsig.HMAC_SHA256)
+        wire = r.get_wire()
+        message = dns.message.from_wire(wire, keyring=keyring)
+        expected = dns.message.make_query(qname, dns.rdatatype.A)
+        expected.id = message.id
+        self.assertEqual(message, expected)
+
     def test_going_backwards_fails(self):
         r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
         qname = dns.name.from_text('foo.example')
index 037d5aa7ddc68d8d66e7abd29fb4de9267cd62a9..2722e154536999a75e14c91254fffbd291de5f2a 100644 (file)
@@ -42,12 +42,11 @@ class TSIGTestCase(unittest.TestCase):
         # not raising is passing
         dns.message.from_wire(w, keyring)
 
-    def make_message_pair(self, qname='example', rdtype='A'):
+    def make_message_pair(self, qname='example', rdtype='A', tsig_error=0):
         q = dns.message.make_query(qname, rdtype)
         q.use_tsig(keyring=keyring, keyname=keyname)
-        q.had_tsig = True  # so make_response() does the right thing
         q.to_wire()  # to set q.mac
-        r = dns.message.make_response(q)
+        r = dns.message.make_response(q, tsig_error=tsig_error)
         return(q, r)
 
     def test_peer_errors(self):
@@ -58,8 +57,7 @@ class TSIGTestCase(unittest.TestCase):
                  (99, dns.tsig.PeerError),
                  ]
         for err, ex in items:
-            q, r = self.make_message_pair()
-            r.tsig_error = err
+            q, r = self.make_message_pair(tsig_error=err)
             w = r.to_wire()
             def bad():
                 dns.message.from_wire(w, keyring=keyring, request_mac=q.mac)