from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
-from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
+from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version
import binascii
import contextlib
class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
- def __init__(self):
+ def __init__(self, cookie_secret='0123456789', key_version=None):
# don't call super.__init__
self._cookies = {}
- self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
+ if key_version is None:
+ self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
+ else:
+ self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
+ key_version=key_version))
def get_cookie(self, name):
return self._cookies.get(name)
self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
+# See SignedValueTest below for more.
+class SecureCookieV2Test(unittest.TestCase):
+ KEY_VERSIONS = {
+ 0: 'ajklasdf0ojaisdf',
+ 1: 'aslkjasaolwkjsdf'
+ }
+ def test_round_trip(self):
+ handler = CookieTestRequestHandler()
+ handler.set_secure_cookie('foo', b'bar', version=2)
+ self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
+
+ def test_key_version_roundtrip(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=0)
+ handler.set_secure_cookie('foo', b'bar')
+ self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
+
+ def test_key_version_increment_version(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=0)
+ handler.set_secure_cookie('foo', b'bar')
+ new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=1)
+ new_handler._cookies = handler._cookies
+ self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
+
+ def test_key_version_invalidate_version(self):
+ handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
+ key_version=1)
+ handler.set_secure_cookie('foo', b'bar')
+ new_key_versions = self.KEY_VERSIONS.copy()
+ new_key_versions.pop(1)
+ new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
+ key_version=1)
+ new_handler._cookies = handler._cookies
+ self.assertEqual(new_handler.get_secure_cookie('foo'), None)
+
+
class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
class SignedValueTest(unittest.TestCase):
SECRET = "It's a secret to everybody"
+ SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}
def past(self):
return self.present() - 86400 * 32
clock=self.present)
self.assertEqual(value, decoded)
+ def test_key_versioning_read_write_default_key(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present)
+ decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
+ "key", signed, clock=self.present)
+ self.assertEqual(value, decoded)
+
+ def test_key_versioning_read_write_non_default_key(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present,
+ key_version=1)
+ decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
+ "key", signed, clock=self.present)
+ self.assertEqual(value, decoded)
+
+ def test_key_versioning_invalid_key(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present)
+ newkeys = SignedValueTest.SECRET_DICT.copy()
+ newkeys.pop(0)
+ decoded = decode_signed_value(newkeys,
+ "key", signed, clock=self.present)
+ self.assertEqual(None, decoded)
+
+ def test_key_version_retreival(self):
+ value = b"\xe9"
+ signed = create_signed_value(SignedValueTest.SECRET_DICT,
+ "key", value, clock=self.present,
+ key_version=1)
+ key_version = get_signature_key_version(signed)
+ self.assertEqual(1, key_version)
+
@wsgi_safe
class XSRFTest(SimpleHandlerTestCase):
DEFAULT_SIGNED_VALUE_MIN_VERSION = 1
"""The oldest signed value accepted by `.RequestHandler.get_secure_cookie`.
-May be overrided by passing a ``min_version`` keyword argument.
+May be overridden by passing a ``min_version`` keyword argument.
.. versionadded:: 3.2.1
"""
+DEFAULT_SIGN_KEY_VERSION = 0
+"""The current key index used by `.RequestHandler.set_secure_cookie`.
+
+May be overridden by passing a ``key_version`` keyword argument.
+
+.. versionadded:: x.x.x
+"""
+
class RequestHandler(object):
"""Subclass this class and define `get()` or `post()` to make a handler.
and made it the default.
"""
self.require_setting("cookie_secret", "secure cookies")
- return create_signed_value(self.application.settings["cookie_secret"],
- name, value, version=version)
+ secret = self.application.settings["cookie_secret"]
+ key_version = None
+ if isinstance(secret, dict):
+ if self.application.settings.get("key_version") is None:
+ raise Exception("key_version setting must be used for secret_key dicts")
+ key_version = self.application.settings["key_version"]
+
+ return create_signed_value(secret, name, value, version=version,
+ key_version=key_version)
def get_secure_cookie(self, name, value=None, max_age_days=31,
min_version=None):
name, value, max_age_days=max_age_days,
min_version=min_version)
+ def get_secure_cookie_key_version(self, name, value=None):
+ """Returns the signing key version of the secure cookie.
+
+ The version is returned as int.
+ """
+ self.require_setting("cookie_secret", "secure cookies")
+ if value is None:
+ value = self.get_cookie(name)
+ return get_signature_key_version(value)
+
+
def redirect(self, url, permanent=False, status=None):
"""Sends a redirect to the given (optionally relative) URL.
return result == 0
-def create_signed_value(secret, name, value, version=None, clock=None):
+def create_signed_value(secret, name, value, version=None, clock=None,
+ key_version=None):
if version is None:
version = DEFAULT_SIGNED_VALUE_VERSION
if clock is None:
clock = time.time
+
+ if key_version is None:
+ key_version = DEFAULT_SIGN_KEY_VERSION
+ else:
+ assert version >= 2, 'Version must be at least 2 for key version support'
+
timestamp = utf8(str(int(clock())))
value = base64.b64encode(utf8(value))
if version == 1:
#
# The fields are:
# - format version (i.e. 2; no length prefix)
- # - key version (currently 0; reserved for future
- # key rotation features)
+ # - key version (integer, default is 0)
# - timestamp (integer seconds since epoch)
# - name (not encoded; assumed to be ~alphanumeric)
# - value (base64-encoded)
def format_field(s):
return utf8("%d:" % len(s)) + utf8(s)
to_sign = b"|".join([
- b"2|1:0",
+ b"2",
+ format_field(str(key_version)),
format_field(timestamp),
format_field(name),
format_field(value),
b''])
+
+ if isinstance(secret, dict):
+ secret = secret[key_version]
+
signature = _create_signature_v2(secret, to_sign)
return to_sign + signature
else:
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")
-def decode_signed_value(secret, name, value, max_age_days=31,
- clock=None, min_version=None):
- if clock is None:
- clock = time.time
- if min_version is None:
- min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
- if min_version > 2:
- raise ValueError("Unsupported min_version %d" % min_version)
- if not value:
- return None
-
- # Figure out what version this is. Version 1 did not include an
+def _get_version(value):
+ # Figures out what version value is. Version 1 did not include an
# explicit version field and started with arbitrary base64 data,
# which makes this tricky.
- value = utf8(value)
m = _signed_value_version_re.match(value)
if m is None:
version = 1
version = 1
except ValueError:
version = 1
+ return version
+
+
+def decode_signed_value(secret, name, value, max_age_days=31,
+ clock=None, min_version=None):
+ if clock is None:
+ clock = time.time
+ if min_version is None:
+ min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
+ if min_version > 2:
+ raise ValueError("Unsupported min_version %d" % min_version)
+ if not value:
+ return None
+
+ value = utf8(value)
+ version = _get_version(value)
if version < min_version:
return None
return None
-def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
+def _decode_fields_v2(value):
def _consume_field(s):
length, _, rest = s.partition(b':')
n = int(length)
raise ValueError("malformed v2 signed value field")
rest = rest[n + 1:]
return field_value, rest
+
rest = value[2:] # remove version number
+ key_version, rest = _consume_field(rest)
+ timestamp, rest = _consume_field(rest)
+ name_field, rest = _consume_field(rest)
+ value_field, passed_sig = _consume_field(rest)
+ return int(key_version), timestamp, name_field, value_field, passed_sig
+
+
+def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
try:
- key_version, rest = _consume_field(rest)
- timestamp, rest = _consume_field(rest)
- name_field, rest = _consume_field(rest)
- value_field, rest = _consume_field(rest)
+ key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value)
except ValueError:
return None
- passed_sig = rest
signed_string = value[:-len(passed_sig)]
+
+ if isinstance(secret, dict):
+ try:
+ secret = secret[key_version]
+ except KeyError:
+ return None
+
expected_sig = _create_signature_v2(secret, signed_string)
if not _time_independent_equals(passed_sig, expected_sig):
return None
return None
+def get_signature_key_version(value):
+ value = utf8(value)
+ version = _get_version(value)
+ if version < 2:
+ return None
+ try:
+ key_version, _, _, _, _ = _decode_fields_v2(value)
+ except ValueError:
+ return None
+
+ return key_version
+
+
def _create_signature_v1(secret, *parts):
hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
for part in parts: