]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Split TSIG sign and validate. 531/head
authorBrian Wellington <bwelling@xbill.org>
Tue, 7 Jul 2020 16:39:23 +0000 (09:39 -0700)
committerBrian Wellington <bwelling@xbill.org>
Tue, 7 Jul 2020 16:39:23 +0000 (09:39 -0700)
dns/tsig.py
tests/test_tsig.py

index 08ab41e45689b1e9618b72887cc78ab4794a9c90..3273813795ec8e6267310715a486b43c63bfe7ac 100644 (file)
@@ -96,19 +96,17 @@ BADTIME = 18
 BADTRUNC = 22
 
 
-def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
-    """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
-    for the input parameters, the HMAC MAC calculated by applying the
-    TSIG signature algorithm, and the TSIG digest context.
-    @rtype: (string, hmac.HMAC object)
+def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None,
+            multi=None):
+    """Return a context containing the TSIG rdata for the input parameters
+    @rtype: hmac.HMAC object
     @raises ValueError: I{other_data} is too long
     @raises NotImplementedError: I{algorithm} is not supported
     """
 
     first = not (ctx and multi)
-    (algorithm_name, digestmod) = get_algorithm(key.algorithm)
     if first:
-        ctx = hmac.new(key.secret, digestmod=digestmod)
+        ctx = get_context(key)
         if request_mac:
             ctx.update(struct.pack('!H', len(request_mac)))
             ctx.update(request_mac)
@@ -127,23 +125,44 @@ def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
     if other_len > 65535:
         raise ValueError('TSIG Other Data is > 65535 bytes')
     if first:
-        ctx.update(algorithm_name + time_encoded)
+        ctx.update(key.algorithm.to_digestable() + time_encoded)
         ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other)
     else:
         ctx.update(time_encoded)
-    mac = ctx.digest()
+    return ctx
+
+
+def _maybe_start_digest(key, mac, multi):
+    """If this is the first message in a multi-message sequence,
+    start a new context.
+    @rtype: hmac.HMAC object
+    """
     if multi:
-        ctx = hmac.new(key.secret, digestmod=digestmod)
+        ctx = get_context(key)
         ctx.update(struct.pack('!H', len(mac)))
         ctx.update(mac)
+        return ctx
     else:
-        ctx = None
+        return None
+
+
+def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
+    """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
+    for the input parameters, the HMAC MAC calculated by applying the
+    TSIG signature algorithm, and the TSIG digest context.
+    @rtype: (string, hmac.HMAC object)
+    @raises ValueError: I{other_data} is too long
+    @raises NotImplementedError: I{algorithm} is not supported
+    """
+
+    ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi)
+    mac = ctx.digest()
     tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
                                      key.algorithm, time, rdata.fudge, mac,
                                      rdata.original_id, rdata.error,
                                      rdata.other)
 
-    return (tsig, ctx)
+    return (tsig, _maybe_start_digest(key, mac, multi))
 
 
 def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
@@ -178,28 +197,27 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
         raise BadKey
     if key.algorithm != rdata.algorithm:
         raise BadAlgorithm
-    (our_rdata, ctx) = sign(new_wire, key, rdata, None, request_mac, ctx, multi)
-    if our_rdata.mac != rdata.mac:
+    ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi)
+    mac = ctx.digest()
+    if not hmac.compare_digest(mac, rdata.mac):
         raise BadSignature
-    return ctx
+    return _maybe_start_digest(key, mac, multi)
 
 
-def get_algorithm(algorithm):
-    """Returns the wire format string and the hash module to use for the
-    specified TSIG algorithm
+def get_context(key):
+    """Returns an HMAC context foe the specified key.
 
-    @rtype: (string, hash constructor)
+    @rtype: HMAC context
     @raises NotImplementedError: I{algorithm} is not supported
     """
 
-    if isinstance(algorithm, str):
-        algorithm = dns.name.from_text(algorithm)
-
     try:
-        return (algorithm.to_digestable(), _hashes[algorithm])
+        digestmod = _hashes[key.algorithm]
     except KeyError:
-        raise NotImplementedError("TSIG algorithm " + str(algorithm) +
-                                  " is not supported")
+        raise NotImplementedError(f"TSIG algorithm {key.algorithm} " +
+                                  "is not supported")
+    return hmac.new(key.secret, digestmod=digestmod)
+
 
 class Key:
     def __init__(self, name, secret, algorithm=default_algorithm):
index 2722e154536999a75e14c91254fffbd291de5f2a..59b83be09e6d3c989ba85a4215937775a0ac1040 100644 (file)
@@ -18,14 +18,16 @@ keyname = dns.name.from_text('keyname')
 
 class TSIGTestCase(unittest.TestCase):
 
-    def test_get_algorithm(self):
-        n = dns.name.from_text('hmac-sha256')
-        (w, alg) = dns.tsig.get_algorithm(n)
-        self.assertEqual(alg, hashlib.sha256)
-        (w, alg) = dns.tsig.get_algorithm('hmac-sha256')
-        self.assertEqual(alg, hashlib.sha256)
-        self.assertRaises(NotImplementedError,
-                          lambda: dns.tsig.get_algorithm('bogus'))
+    def test_get_context(self):
+        key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha256')
+        ctx = dns.tsig.get_context(key)
+        self.assertEqual(ctx.name, 'hmac-sha256')
+        key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512')
+        ctx = dns.tsig.get_context(key)
+        self.assertEqual(ctx.name, 'hmac-sha512')
+        bogus = dns.tsig.Key('foo.com', 'abcd', 'bogus')
+        with self.assertRaises(NotImplementedError):
+            dns.tsig.get_context(bogus)
 
     def test_sign_and_validate(self):
         m = dns.message.make_query('example', 'a')