From: Nick Hall Date: Fri, 7 Aug 2020 22:00:36 +0000 (+0100) Subject: Support callable() TSIG keyrings for use-cases like GSSTSig. X-Git-Tag: v2.1.0rc1~101^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F530%2Fhead;p=thirdparty%2Fdnspython.git Support callable() TSIG keyrings for use-cases like GSSTSig. --- diff --git a/dns/message.py b/dns/message.py index 152fa506..4b4eb4d3 100644 --- a/dns/message.py +++ b/dns/message.py @@ -488,8 +488,8 @@ class Message: *key*, a ``dns.tsig.Key`` is the key to use. If a key is specified, the *keyring* and *algorithm* fields are not used. - *keyring*, a ``dict`` or ``dns.tsig.Key``, is either the TSIG - keyring or key to use. + *keyring*, a ``dict``, ``callable`` or ``dns.tsig.Key``, is either + the TSIG keyring or key to use. The format of a keyring dict is a mapping from TSIG key name, as ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``. @@ -497,7 +497,9 @@ class Message: used will be the first key in the *keyring*. Note that the order of keys in a dictionary is not defined, so applications should supply a keyname when a ``dict`` keyring is used, unless they know the keyring - contains only one key. + contains only one key. If a ``callable`` keyring is specified, the + callable will be called with the message and the keyname, and is + expected to return a key. *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of thes TSIG key to use; defaults to ``None``. If *keyring* is a @@ -519,6 +521,8 @@ class Message: if isinstance(keyring, dns.tsig.Key): self.keyring = keyring + elif callable(keyring): + self.keyring = keyring(self, keyname) else: if isinstance(keyname, str): keyname = dns.name.from_text(keyname) @@ -920,6 +924,8 @@ class _WireReader: key = self.keyring.get(absolute_name) if isinstance(key, bytes): key = dns.tsig.Key(absolute_name, key, rd.algorithm) + elif callable(self.keyring): + key = self.keyring(self.message, absolute_name) else: key = self.keyring if key is None: diff --git a/dns/tsig.py b/dns/tsig.py index 93614121..ab459512 100644 --- a/dns/tsig.py +++ b/dns/tsig.py @@ -102,6 +102,38 @@ class GSSTSig: raise BadSignature +class GSSTSigAdapter: + def __init__(self, keyring): + self.keyring = keyring + + def __call__(self, message, keyname): + if keyname in self.keyring: + key = self.keyring[keyname] + if isinstance(key, Key) and key.algorithm == GSS_TSIG: + if message: + GSSTSigAdapter.parse_tkey_and_step(key, message, keyname) + return key + else: + return None + + @classmethod + def parse_tkey_and_step(cls, key, message, keyname): + # if the message is a TKEY type, absorb the key material + # into the context using step(); this is used to allow the + # client to complete the GSSAPI negotiation before attempting + # to verify the signed response to a TKEY message exchange + try: + rrset = message.find_rrset(message.answer, keyname, + dns.rdataclass.ANY, + dns.rdatatype.TKEY) + if rrset: + token = rrset[0].key + gssapi_context = key.secret + return gssapi_context.step(token) + except KeyError: + pass + + class HMACTSig: """ HMAC TSIG implementation. This uses the HMAC python module to handle the diff --git a/tests/test_tsig.py b/tests/test_tsig.py index c1cd5cbf..ec5a6ccc 100644 --- a/tests/test_tsig.py +++ b/tests/test_tsig.py @@ -3,6 +3,7 @@ import unittest from unittest.mock import Mock import time +import base64 import dns.rcode import dns.tsig @@ -46,12 +47,6 @@ class TSIGTestCase(unittest.TestCase): self.assertEqual(m.tsig_error, dns.rcode.BADKEY) def test_verify_mac_for_context(self): - dummy_ctx = None - dummy_expected = None - key = dns.tsig.Key('foo.com', 'abcd', 'bogus') - with self.assertRaises(NotImplementedError): - dummy_ctx = dns.tsig.get_context(key) - key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') ctx = dns.tsig.get_context(key) bad_expected = b'xxxxxxxxxx' @@ -97,10 +92,11 @@ class TSIGTestCase(unittest.TestCase): gssapi_context_mock.verify_signature.side_effect = verify_signature # create the key and add it to the keyring - key = dns.tsig.Key('gsstsigtest', gssapi_context_mock, 'gss-tsig') + keyname = 'gsstsigtest' + key = dns.tsig.Key(keyname, gssapi_context_mock, 'gss-tsig') ctx = dns.tsig.get_context(key) self.assertEqual(ctx.name, 'gss-tsig') - gsskeyname = dns.name.from_text('gsstsigtest') + gsskeyname = dns.name.from_text(keyname) keyring[gsskeyname] = key # make sure we can get the keyring (no exception == success) @@ -114,18 +110,65 @@ class TSIGTestCase(unittest.TestCase): gssapi_context_mock.verify_signature.assert_called() self.assertEqual(gssapi_context_mock.verify_signature.call_count, 1) - # create example message and go to/from wire to simulate sign/verify - m = dns.message.make_query('example', 'a') - m.use_tsig(keyring, gsskeyname) - w = m.to_wire() - # not raising is passing - dns.message.from_wire(w, keyring) + # simulate case where TKEY message is used to establish the context; + # first, the query from the client + tkey_message = dns.message.make_query(keyname, 'tkey', 'any') + + # test existent/non-existent keys in the keyring + adapted_keyring = dns.tsig.GSSTSigAdapter(keyring) + + fetched_key = adapted_keyring(tkey_message, gsskeyname) + self.assertEqual(fetched_key, key) + key = adapted_keyring(None, gsskeyname) + self.assertEqual(fetched_key, key) + key = adapted_keyring(tkey_message, "dummy") + self.assertEqual(key, None) + + # create a response, TKEY and turn it into bytes, simulating the server + # sending the response to the query + tkey_response = dns.message.make_response(tkey_message) + key = base64.b64decode('KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY') + tkey = dns.rdtypes.ANY.TKEY.TKEY(dns.rdataclass.ANY, + dns.rdatatype.TKEY, + dns.name.from_text('gss-tsig.'), + 1594203795, 1594206664, + 3, 0, key) + + # add the TKEY answer and sign it + tkey_response.set_rcode(dns.rcode.NOERROR) + tkey_response.answer = [ + dns.rrset.from_rdata(dns.name.from_text(keyname), 0, tkey)] + tkey_response.use_tsig(keyring=dns.tsig.GSSTSigAdapter(keyring), + keyname=gsskeyname, + algorithm=dns.tsig.GSS_TSIG) + + # "send" it to the client + tkey_wire = tkey_response.to_wire() + + # grab the response from the "server" and simulate the client side + dns.message.from_wire(tkey_wire, dns.tsig.GSSTSigAdapter(keyring)) # assertions to make sure the "gssapi" functions were called gssapi_context_mock.get_signature.assert_called() self.assertEqual(gssapi_context_mock.get_signature.call_count, 1) gssapi_context_mock.verify_signature.assert_called() self.assertEqual(gssapi_context_mock.verify_signature.call_count, 2) + gssapi_context_mock.step.assert_called() + self.assertEqual(gssapi_context_mock.step.call_count, 1) + + # create example message and go to/from wire to simulate sign/verify + # of regular messages + a_message = dns.message.make_query('example', 'a') + a_message.use_tsig(dns.tsig.GSSTSigAdapter(keyring), gsskeyname) + a_wire = a_message.to_wire() + # not raising is passing + dns.message.from_wire(a_wire, dns.tsig.GSSTSigAdapter(keyring)) + + # assertions to make sure the "gssapi" functions were called again + gssapi_context_mock.get_signature.assert_called() + self.assertEqual(gssapi_context_mock.get_signature.call_count, 2) + gssapi_context_mock.verify_signature.assert_called() + self.assertEqual(gssapi_context_mock.verify_signature.call_count, 3) def test_sign_and_validate(self): m = dns.message.make_query('example', 'a') @@ -134,6 +177,18 @@ class TSIGTestCase(unittest.TestCase): # not raising is passing dns.message.from_wire(w, keyring) + def test_validate_with_bad_keyring(self): + m = dns.message.make_query('example', 'a') + m.use_tsig(keyring, keyname) + w = m.to_wire() + + # keyring == None is an error + with self.assertRaises(dns.message.UnknownTSIGKey): + dns.message.from_wire(w, None) + # callable keyring that returns None is an error + with self.assertRaises(dns.message.UnknownTSIGKey): + dns.message.from_wire(w, lambda m, n: None) + def test_sign_and_validate_with_other_data(self): m = dns.message.make_query('example', 'a') m.use_tsig(keyring, keyname, other_data=b'other')