]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add a lightweight wrapper around the HMAC types and refactor the "is gss-api or not...
authorNick Hall <nick.hall@deshaw.com>
Tue, 4 Aug 2020 16:47:48 +0000 (17:47 +0100)
committerNick Hall <nick.hall@deshaw.com>
Sat, 8 Aug 2020 00:21:24 +0000 (01:21 +0100)
dns/message.py
dns/message.pyi
dns/tsig.py
tests/test_tsig.py

index ceebdf9c1453d3c734f6f25274f02cefa80d9af4..152fa506bf66fc19a4955881481bd7ddb1fbcc51 100644 (file)
@@ -424,8 +424,8 @@ class Message:
         *multi*, a ``bool``, should be set to ``True`` if this message is
         part of a multiple message sequence.
 
-        *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used
-        when signing zone transfers.
+        *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
+        ongoing TSIG context, used when signing zone transfers.
 
         Raises ``dns.exception.TooBig`` if *max_size* was exceeded.
 
@@ -994,8 +994,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     of a zone transfer, *origin* should be the origin name of the
     zone.  If not ``None``, names will be relativized to the origin.
 
-    *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used
-    when validating zone transfers.
+    *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
+    ongoing TSIG context, used when validating zone transfers.
 
     *multi*, a ``bool``, should be set to ``True`` if this message is
     part of a multiple message sequence.
index ca908f6f7a7c308f8b9539a0b3d2d2f3f2dbfd76..252a4118540c877747f5e2610c84140f90899ee4 100644 (file)
@@ -33,7 +33,7 @@ def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
     ...
 
 def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
-              tsig_ctx : Optional[hmac.HMAC] = None, multi=False,
+              tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False,
               question_only=False, one_rr_per_rrset=False,
               ignore_trailing=False) -> Message:
     ...
index 51fffaedfd95fc69a1d6501f38a3799aeba00a64..93614121ca1c057ef0f79f980d4e1d43f0166029 100644 (file)
@@ -89,13 +89,48 @@ class GSSTSig:
     def update(self, data):
         self.data += data
 
-    def digest(self):
+    def sign(self):
         # defer to the GSSAPI function to sign
         return self.gssapi_context.get_signature(self.data)
 
-    def verify(self, mac):
-        # defer to the GSSAPI function to verify
-        return self.gssapi_context.verify_signature(self.data, mac)
+    def verify(self, expected):
+        try:
+            # defer to the GSSAPI function to verify
+            return self.gssapi_context.verify_signature(self.data, expected)
+        except Exception:
+            # note the usage of a bare exception
+            raise BadSignature
+
+
+class HMACTSig:
+    """
+    HMAC TSIG implementation.  This uses the HMAC python module to handle the
+    sign/verify operations.
+    """
+    def __init__(self, key, algorithm):
+        try:
+            digestmod = _hashes[algorithm]
+        except KeyError:
+            raise NotImplementedError(f"TSIG algorithm {algorithm} " +
+                                      "is not supported")
+
+        # create the HMAC context
+        self.hmac_context = hmac.new(key, digestmod=digestmod)
+        self.name = self.hmac_context.name
+
+    def update(self, data):
+        return self.hmac_context.update(data)
+
+    def sign(self):
+        # defer to the HMAC digest() function for that digestmod
+        return self.hmac_context.digest()
+
+    def verify(self, expected):
+        # re-digest and compare the results
+        mac = self.hmac_context.digest()
+        if not hmac.compare_digest(mac, expected):
+            raise BadSignature
+
 
 # TSIG Algorithms
 
@@ -112,7 +147,6 @@ _hashes = {
     HMAC_SHA256: hashlib.sha256,
     HMAC_SHA384: hashlib.sha384,
     HMAC_SHA512: hashlib.sha512,
-    GSS_TSIG: GSSTSig,
     HMAC_SHA1: hashlib.sha1,
     HMAC_MD5: hashlib.md5,
 }
@@ -123,7 +157,7 @@ default_algorithm = HMAC_SHA256
 def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None,
             multi=None):
     """Return a context containing the TSIG rdata for the input parameters
-    @rtype: hmac.HMAC object
+    @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object
     @raises ValueError: I{other_data} is too long
     @raises NotImplementedError: I{algorithm} is not supported
     """
@@ -159,7 +193,7 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None,
 def _maybe_start_digest(key, mac, multi):
     """If this is the first message in a multi-message sequence,
     start a new context.
-    @rtype: hmac.HMAC object
+    @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object
     """
     if multi:
         ctx = get_context(key)
@@ -170,47 +204,23 @@ def _maybe_start_digest(key, mac, multi):
         return None
 
 
-def _verify_mac_for_context(ctx, key, expected):
-    """Verifies a MAC for the specified context and key.
-
-    @raises BadSignature: I{expected} does not match expected TSIG
-    """
-
-    try:
-        digestmod = _hashes[key.algorithm]
-    except KeyError:
-        raise NotImplementedError(f"TSIG algorithm {key.algorithm} " +
-                                  "is not supported")
-
-    if digestmod == GSSTSig:
-        try:
-            ctx.verify(expected)
-        except Exception:
-            # note the usage of a bare exception
-            raise BadSignature
-    else:
-        mac = ctx.digest()
-        if not hmac.compare_digest(mac, expected):
-            raise BadSignature
-
-
 def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
     """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
     for the input parameters, the HMAC MAC calculated by applying the
     TSIG signature algorithm, and the TSIG digest context.
-    @rtype: (string, hmac.HMAC object)
+    @rtype: (string, dns.tsig.HMACTSig or dns.tsig.GSSTSig object)
     @raises ValueError: I{other_data} is too long
     @raises NotImplementedError: I{algorithm} is not supported
     """
 
     ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi)
-    mac = ctx.digest()
+    mac = ctx.sign()
     tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
                                      key.algorithm, time, rdata.fudge, mac,
                                      rdata.original_id, rdata.error,
                                      rdata.other)
 
-    return (tsig, _maybe_start_digest(key, mac, multi))
+    return tsig, _maybe_start_digest(key, mac, multi)
 
 
 def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
@@ -221,7 +231,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
     @raises BadTime: There is too much time skew between the client and the
     server.
     @raises BadSignature: The TSIG signature did not validate
-    @rtype: hmac.HMAC object"""
+    @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object"""
 
     (adcount,) = struct.unpack("!H", wire[10:12])
     if adcount == 0:
@@ -246,7 +256,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
     if key.algorithm != rdata.algorithm:
         raise BadAlgorithm
     ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi)
-    _verify_mac_for_context(ctx, key, rdata.mac)
+    ctx.verify(rdata.mac)
     return _maybe_start_digest(key, rdata.mac, multi)
 
 
@@ -257,16 +267,10 @@ def get_context(key):
     @raises NotImplementedError: I{algorithm} is not supported
     """
 
-    try:
-        digestmod = _hashes[key.algorithm]
-    except KeyError:
-        raise NotImplementedError(f"TSIG algorithm {key.algorithm} " +
-                                  "is not supported")
-
-    if digestmod == GSSTSig:
+    if key.algorithm == GSS_TSIG:
         return GSSTSig(key.secret)
     else:
-        return hmac.new(key.secret, digestmod=digestmod)
+        return HMACTSig(key.secret, key.algorithm)
 
 
 class Key:
index 179c4f4b929e263e722098abba733e4b746b9601..c1cd5cbf681ac849087a58b2b8c34904a4334e3c 100644 (file)
@@ -8,8 +8,7 @@ import dns.rcode
 import dns.tsig
 import dns.tsigkeyring
 import dns.message
-from dns.rdatatype import RdataType
-from dns.rdataclass import RdataClass
+import dns.rdtypes.ANY.TKEY
 
 keyring = dns.tsigkeyring.from_text(
     {
@@ -51,13 +50,13 @@ class TSIGTestCase(unittest.TestCase):
         dummy_expected = None
         key = dns.tsig.Key('foo.com', 'abcd', 'bogus')
         with self.assertRaises(NotImplementedError):
-            dns.tsig._verify_mac_for_context(dummy_ctx, key, dummy_expected)
+            dummy_ctx = dns.tsig.get_context(key)
 
         key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512')
         ctx = dns.tsig.get_context(key)
         bad_expected = b'xxxxxxxxxx'
         with self.assertRaises(dns.tsig.BadSignature):
-            dns.tsig._verify_mac_for_context(ctx, key, bad_expected)
+            ctx.verify(bad_expected)
 
     def test_validate(self):
         # make message and grab the TSIG
@@ -111,7 +110,7 @@ class TSIGTestCase(unittest.TestCase):
         # test exceptional case for _verify_mac_for_context
         with self.assertRaises(dns.tsig.BadSignature):
             ctx.update(b'throw')
-            dns.tsig._verify_mac_for_context(ctx, key, 'bogus')
+            ctx.verify(b'bogus')
         gssapi_context_mock.verify_signature.assert_called()
         self.assertEqual(gssapi_context_mock.verify_signature.call_count, 1)