From: Ben Darnell Date: Mon, 12 Sep 2011 05:50:13 +0000 (-0700) Subject: Add crude tests for the auth module, and fix python3 issues with oauth1 X-Git-Tag: v2.1.0~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=eb5f2ceb01d58cb469ce9bf61005807ee09ce67b;p=thirdparty%2Ftornado.git Add crude tests for the auth module, and fix python3 issues with oauth1 --- diff --git a/tornado/auth.py b/tornado/auth.py index b21175c00..612f820de 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -50,7 +50,6 @@ Example usage for Google OpenID:: import base64 import binascii -import cgi import hashlib import hmac import logging @@ -85,7 +84,7 @@ class OpenIdMixin(object): args = self._openid_args(callback_uri, ax_attrs=ax_attrs) self.redirect(self._OPENID_ENDPOINT + "?" + urllib.urlencode(args)) - def get_authenticated_user(self, callback): + def get_authenticated_user(self, callback, http_client=None): """Fetches the authenticated user data upon redirect. This method should be called by the handler that receives the @@ -96,8 +95,8 @@ class OpenIdMixin(object): args = dict((k, v[-1]) for k, v in self.request.arguments.iteritems()) args["openid.mode"] = u"check_authentication" url = self._OPENID_ENDPOINT - http = httpclient.AsyncHTTPClient() - http.fetch(url, self.async_callback( + if http_client is None: http_client = httpclient.AsyncHTTPClient() + http_client.fetch(url, self.async_callback( self._on_authentication_verified, callback), method="POST", body=urllib.urlencode(args)) @@ -207,7 +206,8 @@ class OAuthMixin(object): See TwitterMixin and FriendFeedMixin below for example implementations. """ - def authorize_redirect(self, callback_uri=None, extra_params=None): + def authorize_redirect(self, callback_uri=None, extra_params=None, + http_client=None): """Redirects the user to obtain OAuth authorization for this service. Twitter and FriendFeed both require that you register a Callback @@ -222,20 +222,25 @@ class OAuthMixin(object): """ if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False): raise Exception("This service does not support oauth_callback") - http = httpclient.AsyncHTTPClient() + if http_client is None: + http_client = httpclient.AsyncHTTPClient() if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": - http.fetch(self._oauth_request_token_url(callback_uri=callback_uri, - extra_params=extra_params), + http_client.fetch( + self._oauth_request_token_url(callback_uri=callback_uri, + extra_params=extra_params), self.async_callback( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri)) else: - http.fetch(self._oauth_request_token_url(), self.async_callback( - self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri)) + http_client.fetch( + self._oauth_request_token_url(), + self.async_callback( + self._on_request_token, self._OAUTH_AUTHORIZE_URL, + callback_uri)) - def get_authenticated_user(self, callback): + def get_authenticated_user(self, callback, http_client=None): """Gets the OAuth authorized user and access token on callback. This method should be called from the handler for your registered @@ -246,7 +251,7 @@ class OAuthMixin(object): to this service on behalf of the user. """ - request_key = self.get_argument("oauth_token") + request_key = escape.utf8(self.get_argument("oauth_token")) oauth_verifier = self.get_argument("oauth_verifier", None) request_cookie = self.get_cookie("_oauth_request_token") if not request_cookie: @@ -254,17 +259,19 @@ class OAuthMixin(object): callback(None) return self.clear_cookie("_oauth_request_token") - cookie_key, cookie_secret = [base64.b64decode(i) for i in request_cookie.split("|")] + cookie_key, cookie_secret = [base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")] if cookie_key != request_key: + logging.info((cookie_key, request_key, request_cookie)) logging.warning("Request token does not match cookie") callback(None) return token = dict(key=cookie_key, secret=cookie_secret) if oauth_verifier: - token["verifier"] = oauth_verifier - http = httpclient.AsyncHTTPClient() - http.fetch(self._oauth_access_token_url(token), self.async_callback( - self._on_access_token, callback)) + token["verifier"] = oauth_verifier + if http_client is None: + http_client = httpclient.AsyncHTTPClient() + http_client.fetch(self._oauth_access_token_url(token), + self.async_callback(self._on_access_token, callback)) def _oauth_request_token_url(self, callback_uri= None, extra_params=None): consumer_token = self._oauth_consumer_token() @@ -292,8 +299,8 @@ class OAuthMixin(object): if response.error: raise Exception("Could not get request token") request_token = _oauth_parse_response(response.body) - data = "|".join([base64.b64encode(request_token["key"]), - base64.b64encode(request_token["secret"])]) + data = (base64.b64encode(request_token["key"]) + b("|") + + base64.b64encode(request_token["secret"])) self.set_cookie("_oauth_request_token", data) args = dict(oauth_token=request_token["key"]) if callback_uri: @@ -1078,11 +1085,11 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None): for k, v in sorted(parameters.items()))) base_string = "&".join(_oauth_escape(e) for e in base_elems) - key_elems = [consumer_token["secret"]] - key_elems.append(token["secret"] if token else "") - key = "&".join(key_elems) + key_elems = [escape.utf8(consumer_token["secret"])] + key_elems.append(escape.utf8(token["secret"] if token else "")) + key = b("&").join(key_elems) - hash = hmac.new(key, base_string, hashlib.sha1) + hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1) return binascii.b2a_base64(hash.digest())[:-1] def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None): @@ -1101,11 +1108,11 @@ def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None): for k, v in sorted(parameters.items()))) base_string = "&".join(_oauth_escape(e) for e in base_elems) - key_elems = [urllib.quote(consumer_token["secret"], safe='~')] - key_elems.append(urllib.quote(token["secret"], safe='~') if token else "") - key = "&".join(key_elems) + key_elems = [escape.utf8(urllib.quote(consumer_token["secret"], safe='~'))] + key_elems.append(escape.utf8(urllib.quote(token["secret"], safe='~') if token else "")) + key = b("&").join(key_elems) - hash = hmac.new(key, base_string, hashlib.sha1) + hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1) return binascii.b2a_base64(hash.digest())[:-1] def _oauth_escape(val): @@ -1115,11 +1122,11 @@ def _oauth_escape(val): def _oauth_parse_response(body): - p = cgi.parse_qs(body, keep_blank_values=False) - token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0]) + p = escape.parse_qs(body, keep_blank_values=False) + token = dict(key=p[b("oauth_token")][0], secret=p[b("oauth_token_secret")][0]) # Add the extra parameters the Provider included to the token - special = ("oauth_token", "oauth_token_secret") + special = (b("oauth_token"), b("oauth_token_secret")) token.update((k, p[k][0]) for k in p if k not in special) return token diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py new file mode 100644 index 000000000..204790409 --- /dev/null +++ b/tornado/test/auth_test.py @@ -0,0 +1,186 @@ +# These tests do not currently do much to verify the correct implementation +# of the openid/oauth protocols, they just exercise the major code paths +# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in +# python 3) + +from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin +from tornado.escape import json_decode +from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase +from tornado.util import b +from tornado.web import RequestHandler, Application, asynchronous + +class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin): + def initialize(self, test): + self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate') + + @asynchronous + def get(self): + if self.get_argument('openid.mode', None): + self.get_authenticated_user( + self.on_user, http_client=self.settings['http_client']) + return + self.authenticate_redirect() + + def on_user(self, user): + assert user is not None + self.finish(user) + +class OpenIdServerAuthenticateHandler(RequestHandler): + def post(self): + assert self.get_argument('openid.mode') == 'check_authentication' + self.write('is_valid:true') + +class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin): + def initialize(self, test, version): + self._OAUTH_VERSION = version + self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token') + self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize') + self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token') + + def _oauth_consumer_token(self): + return dict(key='asdf', secret='qwer') + + @asynchronous + def get(self): + if self.get_argument('oauth_token', None): + self.get_authenticated_user( + self.on_user, http_client=self.settings['http_client']) + return + self.authorize_redirect(http_client=self.settings['http_client']) + + def on_user(self, user): + assert user is not None + self.finish(user) + + def _oauth_get_user(self, access_token, callback): + assert access_token == dict(key=b('uiop'), secret=b('5678')), access_token + callback(dict(email='foo@example.com')) + +class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin): + def initialize(self, version): + self._OAUTH_VERSION = version + + def _oauth_consumer_token(self): + return dict(key='asdf', secret='qwer') + + def get(self): + params = self._oauth_request_parameters( + 'http://www.example.com/api/asdf', + dict(key='uiop', secret='5678'), + parameters=dict(foo='bar')) + import urllib; urllib.urlencode(params) + self.write(params) + +class OAuth1ServerRequestTokenHandler(RequestHandler): + def get(self): + self.write('oauth_token=zxcv&oauth_token_secret=1234') + +class OAuth1ServerAccessTokenHandler(RequestHandler): + def get(self): + self.write('oauth_token=uiop&oauth_token_secret=5678') + +class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin): + def initialize(self, test): + self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize') + + def get(self): + self.authorize_redirect() + + +class AuthTest(AsyncHTTPTestCase, LogTrapTestCase): + def get_app(self): + return Application( + [ + # test endpoints + ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)), + ('/oauth10/client/login', OAuth1ClientLoginHandler, + dict(test=self, version='1.0')), + ('/oauth10/client/request_params', + OAuth1ClientRequestParametersHandler, + dict(version='1.0')), + ('/oauth10a/client/login', OAuth1ClientLoginHandler, + dict(test=self, version='1.0a')), + ('/oauth10a/client/request_params', + OAuth1ClientRequestParametersHandler, + dict(version='1.0a')), + ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)), + + # simulated servers + ('/openid/server/authenticate', OpenIdServerAuthenticateHandler), + ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler), + ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler), + ], + http_client=self.http_client) + + def test_openid_redirect(self): + response = self.fetch('/openid/client/login', follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue( + '/openid/server/authenticate?' in response.headers['Location']) + + def test_openid_get_user(self): + response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com') + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed["email"], "foo@example.com") + + def test_oauth10_redirect(self): + response = self.fetch('/oauth10/client/login', follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue(response.headers['Location'].endswith( + '/oauth1/server/authorize?oauth_token=zxcv')) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], + response.headers['Set-Cookie']) + + def test_oauth10_get_user(self): + response = self.fetch( + '/oauth10/client/login?oauth_token=zxcv', + headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='}) + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed['email'], 'foo@example.com') + self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + + def test_oauth10_request_parameters(self): + response = self.fetch('/oauth10/client/request_params') + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed['oauth_consumer_key'], 'asdf') + self.assertEqual(parsed['oauth_token'], 'uiop') + self.assertTrue('oauth_nonce' in parsed) + self.assertTrue('oauth_signature' in parsed) + + def test_oauth10a_redirect(self): + response = self.fetch('/oauth10a/client/login', follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue(response.headers['Location'].endswith( + '/oauth1/server/authorize?oauth_token=zxcv')) + # the cookie is base64('zxcv')|base64('1234') + self.assertTrue( + '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], + response.headers['Set-Cookie']) + + def test_oauth10a_get_user(self): + response = self.fetch( + '/oauth10a/client/login?oauth_token=zxcv', + headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='}) + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed['email'], 'foo@example.com') + self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + + def test_oauth10a_request_parameters(self): + response = self.fetch('/oauth10a/client/request_params') + response.rethrow() + parsed = json_decode(response.body) + self.assertEqual(parsed['oauth_consumer_key'], 'asdf') + self.assertEqual(parsed['oauth_token'], 'uiop') + self.assertTrue('oauth_nonce' in parsed) + self.assertTrue('oauth_signature' in parsed) + + def test_oauth2_redirect(self): + response = self.fetch('/oauth2/client/login', follow_redirects=False) + self.assertEqual(response.code, 302) + self.assertTrue('/oauth2/server/authorize?' in response.headers['Location']) diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index 14f161a2b..188aee823 100755 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -5,6 +5,7 @@ TEST_MODULES = [ 'tornado.httputil.doctests', 'tornado.iostream.doctests', 'tornado.util.doctests', + 'tornado.test.auth_test', 'tornado.test.curl_httpclient_test', 'tornado.test.escape_test', 'tornado.test.gen_test',