From: Ben Darnell Date: Sun, 1 Jun 2014 20:36:52 +0000 (-0400) Subject: Remove all use of async_callback in tornado.auth. X-Git-Tag: v4.0.0b1~30 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1d1b0dec7226976d39e9ffc0cfc3cb88c78fdcdd;p=thirdparty%2Ftornado.git Remove all use of async_callback in tornado.auth. This was swallowing exceptions before they could be given to the returned Future. --- diff --git a/tornado/auth.py b/tornado/auth.py index 335b591b4..f15413e5d 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -50,6 +50,13 @@ Example usage for Google OpenID:: scope=['profile', 'email'], response_type='code', extra_params={'approval_prompt': 'auto'}) + +.. versionchanged:: 3.3 + All of the callback interfaces in this module are now guaranteed + to run their callback with an argument of ``None`` on error. + Previously some functions would do this while others would simply + terminate the request on their own. This change also ensures that + errors are more consistently reported through the ``Future`` interfaces. """ from __future__ import absolute_import, division, print_function, with_statement @@ -68,6 +75,7 @@ from tornado import httpclient from tornado import escape from tornado.httputil import url_concat from tornado.log import gen_log +from tornado.stack_context import ExceptionStackContext from tornado.util import bytes_type, u, unicode_type, ArgReplacer try: @@ -115,7 +123,14 @@ def _auth_return_future(f): if callback is not None: future.add_done_callback( functools.partial(_auth_future_to_callback, callback)) - f(*args, **kwargs) + def handle_exception(typ, value, tb): + if future.done(): + return False + else: + future.set_exc_info((typ, value, tb)) + return True + with ExceptionStackContext(handle_exception): + f(*args, **kwargs) return future return wrapper @@ -173,7 +188,7 @@ class OpenIdMixin(object): url = self._OPENID_ENDPOINT if http_client is None: http_client = self.get_auth_http_client() - http_client.fetch(url, self.async_callback( + http_client.fetch(url, functools.partial( self._on_authentication_verified, callback), method="POST", body=urllib_parse.urlencode(args)) @@ -345,7 +360,7 @@ class OAuthMixin(object): http_client.fetch( self._oauth_request_token_url(callback_uri=callback_uri, extra_params=extra_params), - self.async_callback( + functools.partial( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri, @@ -353,7 +368,7 @@ class OAuthMixin(object): else: http_client.fetch( self._oauth_request_token_url(), - self.async_callback( + functools.partial( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri, callback)) @@ -390,7 +405,7 @@ class OAuthMixin(object): if http_client is None: http_client = self.get_auth_http_client() http_client.fetch(self._oauth_access_token_url(token), - self.async_callback(self._on_access_token, callback)) + functools.partial(self._on_access_token, callback)) def _oauth_request_token_url(self, callback_uri=None, extra_params=None): consumer_token = self._oauth_consumer_token() @@ -467,7 +482,7 @@ class OAuthMixin(object): access_token = _oauth_parse_response(response.body) self._oauth_get_user_future(access_token).add_done_callback( - self.async_callback(self._on_oauth_get_user, access_token, future)) + functools.partial(self._on_oauth_get_user, access_token, future)) def _oauth_consumer_token(self): """Subclasses must override this to return their OAuth consumer keys. @@ -652,7 +667,7 @@ class TwitterMixin(OAuthMixin): """ http = self.get_auth_http_client() http.fetch(self._oauth_request_token_url(callback_uri=callback_uri), - self.async_callback( + functools.partial( self._on_request_token, self._OAUTH_AUTHENTICATE_URL, None, callback)) @@ -710,7 +725,7 @@ class TwitterMixin(OAuthMixin): if args: url += "?" + urllib_parse.urlencode(args) http = self.get_auth_http_client() - http_callback = self.async_callback(self._on_twitter_request, callback) + http_callback = functools.partial(self._on_twitter_request, callback) if post_args is not None: http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), callback=http_callback) @@ -827,7 +842,7 @@ class FriendFeedMixin(OAuthMixin): args.update(oauth) if args: url += "?" + urllib_parse.urlencode(args) - callback = self.async_callback(self._on_friendfeed_request, callback) + callback = functools.partial(self._on_friendfeed_request, callback) http = self.get_auth_http_client() if post_args is not None: http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), @@ -942,7 +957,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin): http = self.get_auth_http_client() token = dict(key=token, secret="") http.fetch(self._oauth_access_token_url(token), - self.async_callback(self._on_access_token, callback)) + functools.partial(self._on_access_token, callback)) else: chain_future(OpenIdMixin.get_authenticated_user(self), callback) @@ -1014,7 +1029,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin): }) http.fetch(self._OAUTH_ACCESS_TOKEN_URL, - self.async_callback(self._on_access_token, callback), + functools.partial(self._on_access_token, callback), method="POST", headers={'Content-Type': 'application/x-www-form-urlencoded'}, body=body) def _on_access_token(self, future, response): @@ -1055,7 +1070,7 @@ class FacebookMixin(object): @tornado.web.asynchronous def get(self): if self.get_argument("session", None): - self.get_authenticated_user(self.async_callback(self._on_auth)) + self.get_authenticated_user(self._on_auth) return yield self.authenticate_redirect() @@ -1141,7 +1156,7 @@ class FacebookMixin(object): session = escape.json_decode(self.get_argument("session")) self.facebook_request( method="facebook.users.getInfo", - callback=self.async_callback( + callback=functools.partial( self._on_get_user_info, callback, session), session_key=session["session_key"], uids=session["uid"], @@ -1167,7 +1182,7 @@ class FacebookMixin(object): def get(self): self.facebook_request( method="stream.get", - callback=self.async_callback(self._on_stream), + callback=self._on_stream, session_key=self.current_user["session_key"]) def _on_stream(self, stream): @@ -1191,7 +1206,7 @@ class FacebookMixin(object): url = "http://api.facebook.com/restserver.php?" + \ urllib_parse.urlencode(args) http = self.get_auth_http_client() - http.fetch(url, callback=self.async_callback( + http.fetch(url, callback=functools.partial( self._parse_response, callback)) def _on_get_user_info(self, callback, session, users): @@ -1289,7 +1304,7 @@ class FacebookGraphMixin(OAuth2Mixin): fields.update(extra_fields) http.fetch(self._oauth_request_token_url(**args), - self.async_callback(self._on_access_token, redirect_uri, client_id, + functools.partial(self._on_access_token, redirect_uri, client_id, client_secret, callback, fields)) def _on_access_token(self, redirect_uri, client_id, client_secret, @@ -1306,7 +1321,7 @@ class FacebookGraphMixin(OAuth2Mixin): self.facebook_request( path="/me", - callback=self.async_callback( + callback=functools.partial( self._on_get_user_info, future, session, fields), access_token=session["access_token"], fields=",".join(fields) @@ -1373,7 +1388,7 @@ class FacebookGraphMixin(OAuth2Mixin): if all_args: url += "?" + urllib_parse.urlencode(all_args) - callback = self.async_callback(self._on_facebook_request, callback) + callback = functools.partial(self._on_facebook_request, callback) http = self.get_auth_http_client() if post_args is not None: http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args), diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 1d6cb8392..254e1ae13 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -67,11 +67,29 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin): self.finish(user) def _oauth_get_user(self, access_token, callback): + if self.get_argument('fail_in_get_user', None): + raise Exception("failing in get_user") if access_token != dict(key='uiop', secret='5678'): raise Exception("incorrect access token %r" % access_token) callback(dict(email='foo@example.com')) +class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler): + """Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine.""" + @gen.coroutine + def get(self): + if self.get_argument('oauth_token', None): + # Ensure that any exceptions are set on the returned Future, + # not simply thrown into the surrounding StackContext. + try: + yield self.get_authenticated_user() + except Exception as e: + self.set_status(503) + self.write("got exception: %s" % e) + else: + yield self.authorize_redirect() + + class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin): def initialize(self, version): self._OAUTH_VERSION = version @@ -255,6 +273,9 @@ class AuthTest(AsyncHTTPTestCase): dict(version='1.0')), ('/oauth10a/client/login', OAuth1ClientLoginHandler, dict(test=self, version='1.0a')), + ('/oauth10a/client/login_coroutine', + OAuth1ClientLoginCoroutineHandler, + dict(test=self, version='1.0a')), ('/oauth10a/client/request_params', OAuth1ClientRequestParametersHandler, dict(version='1.0a')), @@ -348,6 +369,12 @@ class AuthTest(AsyncHTTPTestCase): self.assertTrue('oauth_nonce' in parsed) self.assertTrue('oauth_signature' in parsed) + def test_oauth10a_get_user_coroutine_exception(self): + response = self.fetch( + '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true', + headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + self.assertEqual(response.code, 503) + def test_oauth2_redirect(self): response = self.fetch('/oauth2/client/login', follow_redirects=False) self.assertEqual(response.code, 302)