]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Support callable() TSIG keyrings for use-cases like GSSTSig. 530/head
authorNick Hall <nick.hall@deshaw.com>
Fri, 7 Aug 2020 22:00:36 +0000 (23:00 +0100)
committerNick Hall <nick.hall@deshaw.com>
Sat, 8 Aug 2020 00:21:24 +0000 (01:21 +0100)
dns/message.py
dns/tsig.py
tests/test_tsig.py

index 152fa506bf66fc19a4955881481bd7ddb1fbcc51..4b4eb4d33b64238b741174fb9a73f58d3d6a7937 100644 (file)
@@ -488,8 +488,8 @@ class Message:
         *key*, a ``dns.tsig.Key`` is the key to use.  If a key is specified,
         the *keyring* and *algorithm* fields are not used.
 
-        *keyring*, a ``dict`` or ``dns.tsig.Key``, is either the TSIG
-        keyring or key to use.
+        *keyring*, a ``dict``, ``callable`` or ``dns.tsig.Key``, is either
+        the TSIG keyring or key to use.
 
         The format of a keyring dict is a mapping from TSIG key name, as
         ``dns.name.Name`` to ``dns.tsig.Key`` or a TSIG secret, a ``bytes``.
@@ -497,7 +497,9 @@ class Message:
         used will be the first key in the *keyring*.  Note that the order of
         keys in a dictionary is not defined, so applications should supply a
         keyname when a ``dict`` keyring is used, unless they know the keyring
-        contains only one key.
+        contains only one key.  If a ``callable`` keyring is specified, the
+        callable will be called with the message and the keyname, and is
+        expected to return a key.
 
         *keyname*, a ``dns.name.Name``, ``str`` or ``None``, the name of
         thes TSIG key to use; defaults to ``None``.  If *keyring* is a
@@ -519,6 +521,8 @@ class Message:
 
         if isinstance(keyring, dns.tsig.Key):
             self.keyring = keyring
+        elif callable(keyring):
+            self.keyring = keyring(self, keyname)
         else:
             if isinstance(keyname, str):
                 keyname = dns.name.from_text(keyname)
@@ -920,6 +924,8 @@ class _WireReader:
                     key = self.keyring.get(absolute_name)
                     if isinstance(key, bytes):
                         key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+                elif callable(self.keyring):
+                    key = self.keyring(self.message, absolute_name)
                 else:
                     key = self.keyring
                 if key is None:
index 93614121ca1c057ef0f79f980d4e1d43f0166029..ab45951235b459191ecfd38de9b3537f382e8ba7 100644 (file)
@@ -102,6 +102,38 @@ class GSSTSig:
             raise BadSignature
 
 
+class GSSTSigAdapter:
+    def __init__(self, keyring):
+        self.keyring = keyring
+
+    def __call__(self, message, keyname):
+        if keyname in self.keyring:
+            key = self.keyring[keyname]
+            if isinstance(key, Key) and key.algorithm == GSS_TSIG:
+                if message:
+                    GSSTSigAdapter.parse_tkey_and_step(key, message, keyname)
+            return key
+        else:
+            return None
+
+    @classmethod
+    def parse_tkey_and_step(cls, key, message, keyname):
+        # if the message is a TKEY type, absorb the key material
+        # into the context using step(); this is used to allow the
+        # client to complete the GSSAPI negotiation before attempting
+        # to verify the signed response to a TKEY message exchange
+        try:
+            rrset = message.find_rrset(message.answer, keyname,
+                                       dns.rdataclass.ANY,
+                                       dns.rdatatype.TKEY)
+            if rrset:
+                token = rrset[0].key
+                gssapi_context = key.secret
+                return gssapi_context.step(token)
+        except KeyError:
+            pass
+
+
 class HMACTSig:
     """
     HMAC TSIG implementation.  This uses the HMAC python module to handle the
index c1cd5cbf681ac849087a58b2b8c34904a4334e3c..ec5a6ccc308ea3608244ab7a26772d67a6902e73 100644 (file)
@@ -3,6 +3,7 @@
 import unittest
 from unittest.mock import Mock
 import time
+import base64
 
 import dns.rcode
 import dns.tsig
@@ -46,12 +47,6 @@ class TSIGTestCase(unittest.TestCase):
         self.assertEqual(m.tsig_error, dns.rcode.BADKEY)
 
     def test_verify_mac_for_context(self):
-        dummy_ctx = None
-        dummy_expected = None
-        key = dns.tsig.Key('foo.com', 'abcd', 'bogus')
-        with self.assertRaises(NotImplementedError):
-            dummy_ctx = dns.tsig.get_context(key)
-
         key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512')
         ctx = dns.tsig.get_context(key)
         bad_expected = b'xxxxxxxxxx'
@@ -97,10 +92,11 @@ class TSIGTestCase(unittest.TestCase):
         gssapi_context_mock.verify_signature.side_effect = verify_signature
 
         # create the key and add it to the keyring
-        key = dns.tsig.Key('gsstsigtest', gssapi_context_mock, 'gss-tsig')
+        keyname = 'gsstsigtest'
+        key = dns.tsig.Key(keyname, gssapi_context_mock, 'gss-tsig')
         ctx = dns.tsig.get_context(key)
         self.assertEqual(ctx.name, 'gss-tsig')
-        gsskeyname = dns.name.from_text('gsstsigtest')
+        gsskeyname = dns.name.from_text(keyname)
         keyring[gsskeyname] = key
 
         # make sure we can get the keyring (no exception == success)
@@ -114,18 +110,65 @@ class TSIGTestCase(unittest.TestCase):
         gssapi_context_mock.verify_signature.assert_called()
         self.assertEqual(gssapi_context_mock.verify_signature.call_count, 1)
 
-        # create example message and go to/from wire to simulate sign/verify
-        m = dns.message.make_query('example', 'a')
-        m.use_tsig(keyring, gsskeyname)
-        w = m.to_wire()
-        # not raising is passing
-        dns.message.from_wire(w, keyring)
+        # simulate case where TKEY message is used to establish the context;
+        # first, the query from the client
+        tkey_message = dns.message.make_query(keyname, 'tkey', 'any')
+
+        # test existent/non-existent keys in the keyring
+        adapted_keyring = dns.tsig.GSSTSigAdapter(keyring)
+
+        fetched_key = adapted_keyring(tkey_message, gsskeyname)
+        self.assertEqual(fetched_key, key)
+        key = adapted_keyring(None, gsskeyname)
+        self.assertEqual(fetched_key, key)
+        key = adapted_keyring(tkey_message, "dummy")
+        self.assertEqual(key, None)
+
+        # create a response, TKEY and turn it into bytes, simulating the server
+        # sending the response to the query
+        tkey_response = dns.message.make_response(tkey_message)
+        key = base64.b64decode('KEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEYKEY')
+        tkey = dns.rdtypes.ANY.TKEY.TKEY(dns.rdataclass.ANY,
+                                         dns.rdatatype.TKEY,
+                                         dns.name.from_text('gss-tsig.'),
+                                         1594203795, 1594206664,
+                                         3, 0, key)
+
+        # add the TKEY answer and sign it
+        tkey_response.set_rcode(dns.rcode.NOERROR)
+        tkey_response.answer = [
+            dns.rrset.from_rdata(dns.name.from_text(keyname), 0, tkey)]
+        tkey_response.use_tsig(keyring=dns.tsig.GSSTSigAdapter(keyring),
+                               keyname=gsskeyname,
+                               algorithm=dns.tsig.GSS_TSIG)
+
+        # "send" it to the client
+        tkey_wire = tkey_response.to_wire()
+
+        # grab the response from the "server" and simulate the client side
+        dns.message.from_wire(tkey_wire, dns.tsig.GSSTSigAdapter(keyring))
 
         # assertions to make sure the "gssapi" functions were called
         gssapi_context_mock.get_signature.assert_called()
         self.assertEqual(gssapi_context_mock.get_signature.call_count, 1)
         gssapi_context_mock.verify_signature.assert_called()
         self.assertEqual(gssapi_context_mock.verify_signature.call_count, 2)
+        gssapi_context_mock.step.assert_called()
+        self.assertEqual(gssapi_context_mock.step.call_count, 1)
+
+        # create example message and go to/from wire to simulate sign/verify
+        # of regular messages
+        a_message = dns.message.make_query('example', 'a')
+        a_message.use_tsig(dns.tsig.GSSTSigAdapter(keyring), gsskeyname)
+        a_wire = a_message.to_wire()
+        # not raising is passing
+        dns.message.from_wire(a_wire, dns.tsig.GSSTSigAdapter(keyring))
+
+        # assertions to make sure the "gssapi" functions were called again
+        gssapi_context_mock.get_signature.assert_called()
+        self.assertEqual(gssapi_context_mock.get_signature.call_count, 2)
+        gssapi_context_mock.verify_signature.assert_called()
+        self.assertEqual(gssapi_context_mock.verify_signature.call_count, 3)
 
     def test_sign_and_validate(self):
         m = dns.message.make_query('example', 'a')
@@ -134,6 +177,18 @@ class TSIGTestCase(unittest.TestCase):
         # not raising is passing
         dns.message.from_wire(w, keyring)
 
+    def test_validate_with_bad_keyring(self):
+        m = dns.message.make_query('example', 'a')
+        m.use_tsig(keyring, keyname)
+        w = m.to_wire()
+
+        # keyring == None is an error
+        with self.assertRaises(dns.message.UnknownTSIGKey):
+            dns.message.from_wire(w, None)
+        # callable keyring that returns None is an error
+        with self.assertRaises(dns.message.UnknownTSIGKey):
+            dns.message.from_wire(w, lambda m, n: None)
+
     def test_sign_and_validate_with_other_data(self):
         m = dns.message.make_query('example', 'a')
         m.use_tsig(keyring, keyname, other_data=b'other')