]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Cookie secret key versioning support
authorMartin Hoefling <martin.hoefling@tngtech.com>
Sun, 29 Mar 2015 15:00:59 +0000 (17:00 +0200)
committerMartin Hoefling <martin.hoefling@tngtech.com>
Sun, 29 Mar 2015 15:07:37 +0000 (17:07 +0200)
tornado/test/web_test.py
tornado/web.py

index 9c49ca7c026fc90c70f1f598dba6482725e9bdff..56701a992c0b65affb9ff93be88b339767cb3374 100644 (file)
@@ -11,7 +11,7 @@ from tornado.template import DictLoader
 from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest
 from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
-from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
+from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version
 
 import binascii
 import contextlib
@@ -71,10 +71,14 @@ class HelloHandler(RequestHandler):
 
 class CookieTestRequestHandler(RequestHandler):
     # stub out enough methods to make the secure_cookie functions work
-    def __init__(self):
+    def __init__(self, cookie_secret='0123456789', key_version=None):
         # don't call super.__init__
         self._cookies = {}
-        self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
+        if key_version is None:
+            self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
+        else:
+            self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
+                                                        key_version=key_version))
 
     def get_cookie(self, name):
         return self._cookies.get(name)
@@ -128,6 +132,44 @@ class SecureCookieV1Test(unittest.TestCase):
         self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
 
 
+# See SignedValueTest below for more.
+class SecureCookieV2Test(unittest.TestCase):
+    KEY_VERSIONS = {
+        0: 'ajklasdf0ojaisdf',
+        1: 'aslkjasaolwkjsdf'
+    }
+    def test_round_trip(self):
+        handler = CookieTestRequestHandler()
+        handler.set_secure_cookie('foo', b'bar', version=2)
+        self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
+
+    def test_key_version_roundtrip(self):
+        handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+                                           key_version=0)
+        handler.set_secure_cookie('foo', b'bar')
+        self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
+
+    def test_key_version_increment_version(self):
+        handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+                                           key_version=0)
+        handler.set_secure_cookie('foo', b'bar')
+        new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+                                               key_version=1)
+        new_handler._cookies = handler._cookies
+        self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
+
+    def test_key_version_invalidate_version(self):
+        handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+                                           key_version=1)
+        handler.set_secure_cookie('foo', b'bar')
+        new_key_versions = self.KEY_VERSIONS.copy()
+        new_key_versions.pop(1)
+        new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
+                                               key_version=1)
+        new_handler._cookies = handler._cookies
+        self.assertEqual(new_handler.get_secure_cookie('foo'), None)
+
+
 class CookieTest(WebTestCase):
     def get_handlers(self):
         class SetCookieHandler(RequestHandler):
@@ -2139,6 +2181,7 @@ class ClientCloseTest(SimpleHandlerTestCase):
 
 class SignedValueTest(unittest.TestCase):
     SECRET = "It's a secret to everybody"
+    SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}
 
     def past(self):
         return self.present() - 86400 * 32
@@ -2245,6 +2288,41 @@ class SignedValueTest(unittest.TestCase):
                                       clock=self.present)
         self.assertEqual(value, decoded)
 
+    def test_key_versioning_read_write_default_key(self):
+        value = b"\xe9"
+        signed = create_signed_value(SignedValueTest.SECRET_DICT,
+                                     "key", value, clock=self.present)
+        decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
+                                      "key", signed, clock=self.present)
+        self.assertEqual(value, decoded)
+
+    def test_key_versioning_read_write_non_default_key(self):
+        value = b"\xe9"
+        signed = create_signed_value(SignedValueTest.SECRET_DICT,
+                                     "key", value, clock=self.present,
+                                     key_version=1)
+        decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
+                                      "key", signed, clock=self.present)
+        self.assertEqual(value, decoded)
+
+    def test_key_versioning_invalid_key(self):
+        value = b"\xe9"
+        signed = create_signed_value(SignedValueTest.SECRET_DICT,
+                                     "key", value, clock=self.present)
+        newkeys = SignedValueTest.SECRET_DICT.copy()
+        newkeys.pop(0)
+        decoded = decode_signed_value(newkeys,
+                                      "key", signed, clock=self.present)
+        self.assertEqual(None, decoded)
+
+    def test_key_version_retreival(self):
+        value = b"\xe9"
+        signed = create_signed_value(SignedValueTest.SECRET_DICT,
+                                     "key", value, clock=self.present,
+                                     key_version=1)
+        key_version = get_signature_key_version(signed)
+        self.assertEqual(1, key_version)
+
 
 @wsgi_safe
 class XSRFTest(SimpleHandlerTestCase):
index 6b76e6742f5c8ae1640a1898a3b697d863aa3138..457fc6faf9b4fac9a14227183671cf7d1cf7f849 100644 (file)
@@ -139,11 +139,19 @@ May be overridden by passing a ``version`` keyword argument.
 DEFAULT_SIGNED_VALUE_MIN_VERSION = 1
 """The oldest signed value accepted by `.RequestHandler.get_secure_cookie`.
 
-May be overrided by passing a ``min_version`` keyword argument.
+May be overridden by passing a ``min_version`` keyword argument.
 
 .. versionadded:: 3.2.1
 """
 
+DEFAULT_SIGN_KEY_VERSION = 0
+"""The current key index used by `.RequestHandler.set_secure_cookie`.
+
+May be overridden by passing a ``key_version`` keyword argument.
+
+.. versionadded:: x.x.x
+"""
+
 
 class RequestHandler(object):
     """Subclass this class and define `get()` or `post()` to make a handler.
@@ -613,8 +621,15 @@ class RequestHandler(object):
            and made it the default.
         """
         self.require_setting("cookie_secret", "secure cookies")
-        return create_signed_value(self.application.settings["cookie_secret"],
-                                   name, value, version=version)
+        secret = self.application.settings["cookie_secret"]
+        key_version = None
+        if isinstance(secret, dict):
+            if self.application.settings.get("key_version") is None:
+                raise Exception("key_version setting must be used for secret_key dicts")
+            key_version = self.application.settings["key_version"]
+
+        return create_signed_value(secret, name, value, version=version,
+                                   key_version=key_version)
 
     def get_secure_cookie(self, name, value=None, max_age_days=31,
                           min_version=None):
@@ -635,6 +650,17 @@ class RequestHandler(object):
                                    name, value, max_age_days=max_age_days,
                                    min_version=min_version)
 
+    def get_secure_cookie_key_version(self, name, value=None):
+        """Returns the signing key version of the secure cookie.
+
+        The version is returned as int.
+        """
+        self.require_setting("cookie_secret", "secure cookies")
+        if value is None:
+            value = self.get_cookie(name)
+        return get_signature_key_version(value)
+
+
     def redirect(self, url, permanent=False, status=None):
         """Sends a redirect to the given (optionally relative) URL.
 
@@ -2961,11 +2987,18 @@ else:
         return result == 0
 
 
-def create_signed_value(secret, name, value, version=None, clock=None):
+def create_signed_value(secret, name, value, version=None, clock=None,
+                        key_version=None):
     if version is None:
         version = DEFAULT_SIGNED_VALUE_VERSION
     if clock is None:
         clock = time.time
+
+    if key_version is None:
+        key_version = DEFAULT_SIGN_KEY_VERSION
+    else:
+        assert version >= 2, 'Version must be at least 2 for key version support'
+
     timestamp = utf8(str(int(clock())))
     value = base64.b64encode(utf8(value))
     if version == 1:
@@ -2982,8 +3015,7 @@ def create_signed_value(secret, name, value, version=None, clock=None):
         #
         # The fields are:
         # - format version (i.e. 2; no length prefix)
-        # - key version (currently 0; reserved for future
-        #       key rotation features)
+        # - key version (integer, default is 0)
         # - timestamp (integer seconds since epoch)
         # - name (not encoded; assumed to be ~alphanumeric)
         # - value (base64-encoded)
@@ -2991,11 +3023,16 @@ def create_signed_value(secret, name, value, version=None, clock=None):
         def format_field(s):
             return utf8("%d:" % len(s)) + utf8(s)
         to_sign = b"|".join([
-            b"2|1:0",
+            b"2",
+            format_field(str(key_version)),
             format_field(timestamp),
             format_field(name),
             format_field(value),
             b''])
+
+        if isinstance(secret, dict):
+            secret = secret[key_version]
+
         signature = _create_signature_v2(secret, to_sign)
         return to_sign + signature
     else:
@@ -3006,21 +3043,10 @@ def create_signed_value(secret, name, value, version=None, clock=None):
 _signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
 
 
-def decode_signed_value(secret, name, value, max_age_days=31,
-                        clock=None, min_version=None):
-    if clock is None:
-        clock = time.time
-    if min_version is None:
-        min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
-    if min_version > 2:
-        raise ValueError("Unsupported min_version %d" % min_version)
-    if not value:
-        return None
-
-    # Figure out what version this is.  Version 1 did not include an
+def _get_version(value):
+    # Figures out what version value is.  Version 1 did not include an
     # explicit version field and started with arbitrary base64 data,
     # which makes this tricky.
-    value = utf8(value)
     m = _signed_value_version_re.match(value)
     if m is None:
         version = 1
@@ -3037,6 +3063,22 @@ def decode_signed_value(secret, name, value, max_age_days=31,
                 version = 1
         except ValueError:
             version = 1
+    return version
+
+
+def decode_signed_value(secret, name, value, max_age_days=31,
+                        clock=None, min_version=None):
+    if clock is None:
+        clock = time.time
+    if min_version is None:
+        min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
+    if min_version > 2:
+        raise ValueError("Unsupported min_version %d" % min_version)
+    if not value:
+        return None
+
+    value = utf8(value)
+    version = _get_version(value)
 
     if version < min_version:
         return None
@@ -3080,7 +3122,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
         return None
 
 
-def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
+def _decode_fields_v2(value):
     def _consume_field(s):
         length, _, rest = s.partition(b':')
         n = int(length)
@@ -3091,16 +3133,28 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
             raise ValueError("malformed v2 signed value field")
         rest = rest[n + 1:]
         return field_value, rest
+
     rest = value[2:]  # remove version number
+    key_version, rest = _consume_field(rest)
+    timestamp, rest = _consume_field(rest)
+    name_field, rest = _consume_field(rest)
+    value_field, passed_sig = _consume_field(rest)
+    return int(key_version), timestamp, name_field, value_field, passed_sig
+
+
+def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
     try:
-        key_version, rest = _consume_field(rest)
-        timestamp, rest = _consume_field(rest)
-        name_field, rest = _consume_field(rest)
-        value_field, rest = _consume_field(rest)
+        key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value)
     except ValueError:
         return None
-    passed_sig = rest
     signed_string = value[:-len(passed_sig)]
+
+    if isinstance(secret, dict):
+        try:
+            secret = secret[key_version]
+        except KeyError:
+            return None
+
     expected_sig = _create_signature_v2(secret, signed_string)
     if not _time_independent_equals(passed_sig, expected_sig):
         return None
@@ -3116,6 +3170,19 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
         return None
 
 
+def get_signature_key_version(value):
+    value = utf8(value)
+    version = _get_version(value)
+    if version < 2:
+        return None
+    try:
+        key_version, _, _, _, _ = _decode_fields_v2(value)
+    except ValueError:
+        return None
+
+    return key_version
+
+
 def _create_signature_v1(secret, *parts):
     hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
     for part in parts: