From: Ben Darnell Date: Sun, 16 Sep 2018 16:41:25 +0000 (-0400) Subject: auth: Fix error handling in 5.1 X-Git-Tag: v5.1.1~2^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fcc0f4ea055d896538983bd7af1784fa9c301ffc;p=thirdparty%2Ftornado.git auth: Fix error handling in 5.1 In 5.1, callbacks in this module were moved from the http client (which uses stack_context) to the Future (which does not). These callbacks generally rely on the stack context for error handling, so we must explicitly wrap everything. Fixes #2483 --- diff --git a/tornado/auth.py b/tornado/auth.py index 0f019d6fd..b79ad14be 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -82,7 +82,7 @@ from tornado import httpclient from tornado import escape from tornado.httputil import url_concat from tornado.log import gen_log -from tornado.stack_context import ExceptionStackContext +from tornado.stack_context import ExceptionStackContext, wrap from tornado.util import unicode_type, ArgReplacer, PY3 if PY3: @@ -127,7 +127,7 @@ def _auth_return_future(f): warnings.warn("callback arguments are deprecated, use the returned Future instead", DeprecationWarning) future.add_done_callback( - functools.partial(_auth_future_to_callback, callback)) + wrap(functools.partial(_auth_future_to_callback, callback))) def handle_exception(typ, value, tb): if future.done(): @@ -202,8 +202,8 @@ class OpenIdMixin(object): if http_client is None: http_client = self.get_auth_http_client() fut = http_client.fetch(url, method="POST", body=urllib_parse.urlencode(args)) - fut.add_done_callback(functools.partial( - self._on_authentication_verified, callback)) + fut.add_done_callback(wrap(functools.partial( + self._on_authentication_verified, callback))) def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None): url = urlparse.urljoin(self.request.full_url(), callback_uri) @@ -381,18 +381,18 @@ class OAuthMixin(object): fut = http_client.fetch( self._oauth_request_token_url(callback_uri=callback_uri, extra_params=extra_params)) - fut.add_done_callback(functools.partial( + fut.add_done_callback(wrap(functools.partial( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri, - callback)) + callback))) else: fut = http_client.fetch(self._oauth_request_token_url()) fut.add_done_callback( - functools.partial( + wrap(functools.partial( self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri, - callback)) + callback))) @_auth_return_future def get_authenticated_user(self, callback, http_client=None): @@ -432,7 +432,7 @@ class OAuthMixin(object): if http_client is None: http_client = self.get_auth_http_client() fut = http_client.fetch(self._oauth_access_token_url(token)) - fut.add_done_callback(functools.partial(self._on_access_token, callback)) + fut.add_done_callback(wrap(functools.partial(self._on_access_token, callback))) def _oauth_request_token_url(self, callback_uri=None, extra_params=None): consumer_token = self._oauth_consumer_token() @@ -515,7 +515,7 @@ class OAuthMixin(object): fut = self._oauth_get_user_future(access_token) fut = gen.convert_yielded(fut) fut.add_done_callback( - functools.partial(self._on_oauth_get_user, access_token, future)) + wrap(functools.partial(self._on_oauth_get_user, access_token, future))) def _oauth_consumer_token(self): """Subclasses must override this to return their OAuth consumer keys. @@ -711,7 +711,7 @@ class OAuth2Mixin(object): if all_args: url += "?" + urllib_parse.urlencode(all_args) - callback = functools.partial(self._on_oauth2_request, callback) + callback = wrap(functools.partial(self._on_oauth2_request, callback)) http = self.get_auth_http_client() if post_args is not None: fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args)) @@ -797,9 +797,9 @@ class TwitterMixin(OAuthMixin): """ http = self.get_auth_http_client() fut = http.fetch(self._oauth_request_token_url(callback_uri=callback_uri)) - fut.add_done_callback(functools.partial( + fut.add_done_callback(wrap(functools.partial( self._on_request_token, self._OAUTH_AUTHENTICATE_URL, - None, callback)) + None, callback))) @_auth_return_future def twitter_request(self, path, callback=None, access_token=None, @@ -863,7 +863,7 @@ class TwitterMixin(OAuthMixin): if args: url += "?" + urllib_parse.urlencode(args) http = self.get_auth_http_client() - http_callback = functools.partial(self._on_twitter_request, callback, url) + http_callback = wrap(functools.partial(self._on_twitter_request, callback, url)) if post_args is not None: fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args)) else: @@ -977,7 +977,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin): method="POST", headers={'Content-Type': 'application/x-www-form-urlencoded'}, body=body) - fut.add_done_callback(functools.partial(self._on_access_token, callback)) + fut.add_done_callback(wrap(functools.partial(self._on_access_token, callback))) def _on_access_token(self, future, response_fut): """Callback function for the exchange to the access token.""" @@ -1061,8 +1061,8 @@ class FacebookGraphMixin(OAuth2Mixin): fields.update(extra_fields) fut = http.fetch(self._oauth_request_token_url(**args)) - fut.add_done_callback(functools.partial(self._on_access_token, redirect_uri, client_id, - client_secret, callback, fields)) + fut.add_done_callback(wrap(functools.partial(self._on_access_token, redirect_uri, client_id, + client_secret, callback, fields))) @gen.coroutine def _on_access_token(self, redirect_uri, client_id, client_secret, diff --git a/tornado/concurrent.py b/tornado/concurrent.py index 78b20919b..f7e6bcccb 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -519,10 +519,10 @@ def return_future(f): """ warnings.warn("@return_future is deprecated, use coroutines instead", DeprecationWarning) - return _non_deprecated_return_future(f) + return _non_deprecated_return_future(f, warn=True) -def _non_deprecated_return_future(f): +def _non_deprecated_return_future(f, warn=False): # Allow auth.py to use this decorator without triggering # deprecation warnings. This will go away once auth.py has removed # its legacy interfaces in 6.0. @@ -539,7 +539,15 @@ def _non_deprecated_return_future(f): future_set_exc_info(future, (typ, value, tb)) return True exc_info = None - with ExceptionStackContext(handle_error, delay_warning=True): + esc = ExceptionStackContext(handle_error, delay_warning=True) + with esc: + if not warn: + # HACK: In non-deprecated mode (only used in auth.py), + # suppress the warning entirely. Since this is added + # in a 5.1 patch release and already removed in 6.0 + # I'm prioritizing a minimial change instead of a + # clean solution. + esc.delay_warning = False try: result = f(*args, **kwargs) if result is not None: diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 41993b1f6..14bc3353c 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -6,6 +6,7 @@ from __future__ import absolute_import, division, print_function +import unittest import warnings from tornado.auth import ( @@ -16,11 +17,16 @@ from tornado.concurrent import Future from tornado.escape import json_decode from tornado import gen from tornado.httputil import url_concat -from tornado.log import gen_log +from tornado.log import gen_log, app_log from tornado.testing import AsyncHTTPTestCase, ExpectLog from tornado.test.util import ignore_deprecation from tornado.web import RequestHandler, Application, asynchronous, HTTPError +try: + from unittest import mock +except ImportError: + mock = None + class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin): def initialize(self, test): @@ -527,6 +533,14 @@ class AuthTest(AsyncHTTPTestCase): '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], response.headers['Set-Cookie']) + @unittest.skipIf(mock is None, 'mock package not present') + def test_oauth10a_redirect_error(self): + with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get: + get.side_effect = Exception("boom") + with ExpectLog(app_log, "Uncaught exception"): + response = self.fetch('/oauth10a/client/login', follow_redirects=False) + self.assertEqual(response.code, 500) + def test_oauth10a_get_user(self): response = self.fetch( '/oauth10a/client/login?oauth_token=zxcv',