BADTRUNC = 22
-def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
- """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
- for the input parameters, the HMAC MAC calculated by applying the
- TSIG signature algorithm, and the TSIG digest context.
- @rtype: (string, hmac.HMAC object)
+def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None,
+ multi=None):
+ """Return a context containing the TSIG rdata for the input parameters
+ @rtype: hmac.HMAC object
@raises ValueError: I{other_data} is too long
@raises NotImplementedError: I{algorithm} is not supported
"""
first = not (ctx and multi)
- (algorithm_name, digestmod) = get_algorithm(key.algorithm)
if first:
- ctx = hmac.new(key.secret, digestmod=digestmod)
+ ctx = get_context(key)
if request_mac:
ctx.update(struct.pack('!H', len(request_mac)))
ctx.update(request_mac)
if other_len > 65535:
raise ValueError('TSIG Other Data is > 65535 bytes')
if first:
- ctx.update(algorithm_name + time_encoded)
+ ctx.update(key.algorithm.to_digestable() + time_encoded)
ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other)
else:
ctx.update(time_encoded)
- mac = ctx.digest()
+ return ctx
+
+
+def _maybe_start_digest(key, mac, multi):
+ """If this is the first message in a multi-message sequence,
+ start a new context.
+ @rtype: hmac.HMAC object
+ """
if multi:
- ctx = hmac.new(key.secret, digestmod=digestmod)
+ ctx = get_context(key)
ctx.update(struct.pack('!H', len(mac)))
ctx.update(mac)
+ return ctx
else:
- ctx = None
+ return None
+
+
+def sign(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=False):
+ """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
+ for the input parameters, the HMAC MAC calculated by applying the
+ TSIG signature algorithm, and the TSIG digest context.
+ @rtype: (string, hmac.HMAC object)
+ @raises ValueError: I{other_data} is too long
+ @raises NotImplementedError: I{algorithm} is not supported
+ """
+
+ ctx = _digest(wire, key, rdata, time, request_mac, ctx, multi)
+ mac = ctx.digest()
tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
key.algorithm, time, rdata.fudge, mac,
rdata.original_id, rdata.error,
rdata.other)
- return (tsig, ctx)
+ return (tsig, _maybe_start_digest(key, mac, multi))
def validate(wire, key, owner, rdata, now, request_mac, tsig_start, ctx=None,
raise BadKey
if key.algorithm != rdata.algorithm:
raise BadAlgorithm
- (our_rdata, ctx) = sign(new_wire, key, rdata, None, request_mac, ctx, multi)
- if our_rdata.mac != rdata.mac:
+ ctx = _digest(new_wire, key, rdata, None, request_mac, ctx, multi)
+ mac = ctx.digest()
+ if not hmac.compare_digest(mac, rdata.mac):
raise BadSignature
- return ctx
+ return _maybe_start_digest(key, mac, multi)
-def get_algorithm(algorithm):
- """Returns the wire format string and the hash module to use for the
- specified TSIG algorithm
+def get_context(key):
+ """Returns an HMAC context foe the specified key.
- @rtype: (string, hash constructor)
+ @rtype: HMAC context
@raises NotImplementedError: I{algorithm} is not supported
"""
- if isinstance(algorithm, str):
- algorithm = dns.name.from_text(algorithm)
-
try:
- return (algorithm.to_digestable(), _hashes[algorithm])
+ digestmod = _hashes[key.algorithm]
except KeyError:
- raise NotImplementedError("TSIG algorithm " + str(algorithm) +
- " is not supported")
+ raise NotImplementedError(f"TSIG algorithm {key.algorithm} " +
+ "is not supported")
+ return hmac.new(key.secret, digestmod=digestmod)
+
class Key:
def __init__(self, name, secret, algorithm=default_algorithm):
class TSIGTestCase(unittest.TestCase):
- def test_get_algorithm(self):
- n = dns.name.from_text('hmac-sha256')
- (w, alg) = dns.tsig.get_algorithm(n)
- self.assertEqual(alg, hashlib.sha256)
- (w, alg) = dns.tsig.get_algorithm('hmac-sha256')
- self.assertEqual(alg, hashlib.sha256)
- self.assertRaises(NotImplementedError,
- lambda: dns.tsig.get_algorithm('bogus'))
+ def test_get_context(self):
+ key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha256')
+ ctx = dns.tsig.get_context(key)
+ self.assertEqual(ctx.name, 'hmac-sha256')
+ key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512')
+ ctx = dns.tsig.get_context(key)
+ self.assertEqual(ctx.name, 'hmac-sha512')
+ bogus = dns.tsig.Key('foo.com', 'abcd', 'bogus')
+ with self.assertRaises(NotImplementedError):
+ dns.tsig.get_context(bogus)
def test_sign_and_validate(self):
m = dns.message.make_query('example', 'a')