From: Nick Hall Date: Tue, 4 Aug 2020 16:47:48 +0000 (+0100) Subject: Add a lightweight wrapper around the HMAC types and refactor the "is gss-api or not... X-Git-Tag: v2.1.0rc1~101^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=40bf9335e823a6760614b4a835ebd06af4279c66;p=thirdparty%2Fdnspython.git Add a lightweight wrapper around the HMAC types and refactor the "is gss-api or not" wrapper functions to just call the class methods --- diff --git a/dns/message.py b/dns/message.py index ceebdf9c..152fa506 100644 --- a/dns/message.py +++ b/dns/message.py @@ -424,8 +424,8 @@ class Message: *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple message sequence. - *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used - when signing zone transfers. + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the + ongoing TSIG context, used when signing zone transfers. Raises ``dns.exception.TooBig`` if *max_size* was exceeded. @@ -994,8 +994,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, of a zone transfer, *origin* should be the origin name of the zone. If not ``None``, names will be relativized to the origin. - *tsig_ctx*, a ``hmac.HMAC`` object, the ongoing TSIG context, used - when validating zone transfers. + *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the + ongoing TSIG context, used when validating zone transfers. *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple message sequence. diff --git a/dns/message.pyi b/dns/message.pyi index ca908f6f..252a4118 100644 --- a/dns/message.pyi +++ b/dns/message.pyi @@ -33,7 +33,7 @@ def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message: ... def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None, - tsig_ctx : Optional[hmac.HMAC] = None, multi=False, + tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False, question_only=False, one_rr_per_rrset=False, ignore_trailing=False) -> Message: ... diff --git a/dns/tsig.py b/dns/tsig.py index 51fffaed..93614121 100644 --- a/dns/tsig.py +++ b/dns/tsig.py @@ -89,13 +89,48 @@ class GSSTSig: def update(self, data): self.data += data - def digest(self): + def sign(self): # defer to the GSSAPI function to sign return self.gssapi_context.get_signature(self.data) - def verify(self, mac): - # defer to the GSSAPI function to verify - return self.gssapi_context.verify_signature(self.data, mac) + def verify(self, expected): + try: + # defer to the GSSAPI function to verify + return self.gssapi_context.verify_signature(self.data, expected) + except Exception: + # note the usage of a bare exception + raise BadSignature + + +class HMACTSig: + """ + HMAC TSIG implementation. This uses the HMAC python module to handle the + sign/verify operations. + """ + def __init__(self, key, algorithm): + try: + digestmod = _hashes[algorithm] + except KeyError: + raise NotImplementedError(f"TSIG algorithm {algorithm} " + + "is not supported") + + # create the HMAC context + self.hmac_context = hmac.new(key, digestmod=digestmod) + self.name = self.hmac_context.name + + def update(self, data): + return self.hmac_context.update(data) + + def sign(self): + # defer to the HMAC digest() function for that digestmod + return self.hmac_context.digest() + + def verify(self, expected): + # re-digest and compare the results + mac = self.hmac_context.digest() + if not hmac.compare_digest(mac, expected): + raise BadSignature + # TSIG Algorithms @@ -112,7 +147,6 @@ _hashes = { HMAC_SHA256: hashlib.sha256, HMAC_SHA384: hashlib.sha384, HMAC_SHA512: hashlib.sha512, - GSS_TSIG: GSSTSig, HMAC_SHA1: hashlib.sha1, HMAC_MD5: hashlib.md5, } @@ -123,7 +157,7 @@ default_algorithm = HMAC_SHA256 def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None): """Return a context containing the TSIG rdata for the input parameters - @rtype: hmac.HMAC object + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object @raises ValueError: I{other_data} is too long @raises NotImplementedError: I{algorithm} is not supported """ @@ -159,7 +193,7 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, def _maybe_start_digest(key, mac, multi): """If this is the first message in a multi-message sequence, start a new context. - @rtype: hmac.HMAC object + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object """ if multi: ctx = get_context(key) @@ -170,47 +204,23 @@ def _maybe_start_digest(key, mac, multi): return None -def _verify_mac_for_context(ctx, key, expected): - """Verifies a MAC for the specified context and key. - - @raises BadSignature: I{expected} does not match expected TSIG - """ - - try: - digestmod = _hashes[key.algorithm] - except KeyError: - raise NotImplementedError(f"TSIG algorithm {key.algorithm} " + - "is not supported") - - if digestmod == GSSTSig: - try: - ctx.verify(expected) - except Exception: - # note the usage of a bare exception - raise BadSignature - else: - mac = ctx.digest() - if not hmac.compare_digest(mac, expected): - raise BadSignature - - def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False): """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata for the input parameters, the HMAC MAC calculated by applying the TSIG signature algorithm, and the TSIG digest context. - @rtype: (string, hmac.HMAC object) + @rtype: (string, dns.tsig.HMACTSig or dns.tsig.GSSTSig object) @raises ValueError: I{other_data} is too long @raises NotImplementedError: I{algorithm} is not supported """ ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi) - mac = ctx.digest() + mac = ctx.sign() tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, key.algorithm, time, rdata.fudge, mac, rdata.original_id, rdata.error, rdata.other) - return (tsig, _maybe_start_digest(key, mac, multi)) + return tsig, _maybe_start_digest(key, mac, multi) def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, @@ -221,7 +231,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, @raises BadTime: There is too much time skew between the client and the server. @raises BadSignature: The TSIG signature did not validate - @rtype: hmac.HMAC object""" + @rtype: dns.tsig.HMACTSig or dns.tsig.GSSTSig object""" (adcount,) = struct.unpack("!H", wire[10:12]) if adcount == 0: @@ -246,7 +256,7 @@ def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None, if key.algorithm != rdata.algorithm: raise BadAlgorithm ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi) - _verify_mac_for_context(ctx, key, rdata.mac) + ctx.verify(rdata.mac) return _maybe_start_digest(key, rdata.mac, multi) @@ -257,16 +267,10 @@ def get_context(key): @raises NotImplementedError: I{algorithm} is not supported """ - try: - digestmod = _hashes[key.algorithm] - except KeyError: - raise NotImplementedError(f"TSIG algorithm {key.algorithm} " + - "is not supported") - - if digestmod == GSSTSig: + if key.algorithm == GSS_TSIG: return GSSTSig(key.secret) else: - return hmac.new(key.secret, digestmod=digestmod) + return HMACTSig(key.secret, key.algorithm) class Key: diff --git a/tests/test_tsig.py b/tests/test_tsig.py index 179c4f4b..c1cd5cbf 100644 --- a/tests/test_tsig.py +++ b/tests/test_tsig.py @@ -8,8 +8,7 @@ import dns.rcode import dns.tsig import dns.tsigkeyring import dns.message -from dns.rdatatype import RdataType -from dns.rdataclass import RdataClass +import dns.rdtypes.ANY.TKEY keyring = dns.tsigkeyring.from_text( { @@ -51,13 +50,13 @@ class TSIGTestCase(unittest.TestCase): dummy_expected = None key = dns.tsig.Key('foo.com', 'abcd', 'bogus') with self.assertRaises(NotImplementedError): - dns.tsig._verify_mac_for_context(dummy_ctx, key, dummy_expected) + dummy_ctx = dns.tsig.get_context(key) key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') ctx = dns.tsig.get_context(key) bad_expected = b'xxxxxxxxxx' with self.assertRaises(dns.tsig.BadSignature): - dns.tsig._verify_mac_for_context(ctx, key, bad_expected) + ctx.verify(bad_expected) def test_validate(self): # make message and grab the TSIG @@ -111,7 +110,7 @@ class TSIGTestCase(unittest.TestCase): # test exceptional case for _verify_mac_for_context with self.assertRaises(dns.tsig.BadSignature): ctx.update(b'throw') - dns.tsig._verify_mac_for_context(ctx, key, 'bogus') + ctx.verify(b'bogus') gssapi_context_mock.verify_signature.assert_called() self.assertEqual(gssapi_context_mock.verify_signature.call_count, 1)