]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
auth: Fix error handling in 5.1 2496/head
authorBen Darnell <ben@bendarnell.com>
Sun, 16 Sep 2018 16:41:25 +0000 (12:41 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 16 Sep 2018 16:41:25 +0000 (12:41 -0400)
In 5.1, callbacks in this module were moved from the http
client (which uses stack_context) to the Future (which does not).
These callbacks generally rely on the stack context for error
handling, so we must explicitly wrap everything.

Fixes #2483

tornado/auth.py
tornado/concurrent.py
tornado/test/auth_test.py

index 0f019d6fd00e06d3fba206dcae073f67861dada4..b79ad14bed2b016474ae14f90e74ddd3faa0623e 100644 (file)
@@ -82,7 +82,7 @@ from tornado import httpclient
 from tornado import escape
 from tornado.httputil import url_concat
 from tornado.log import gen_log
-from tornado.stack_context import ExceptionStackContext
+from tornado.stack_context import ExceptionStackContext, wrap
 from tornado.util import unicode_type, ArgReplacer, PY3
 
 if PY3:
@@ -127,7 +127,7 @@ def _auth_return_future(f):
             warnings.warn("callback arguments are deprecated, use the returned Future instead",
                           DeprecationWarning)
             future.add_done_callback(
-                functools.partial(_auth_future_to_callback, callback))
+                wrap(functools.partial(_auth_future_to_callback, callback)))
 
         def handle_exception(typ, value, tb):
             if future.done():
@@ -202,8 +202,8 @@ class OpenIdMixin(object):
         if http_client is None:
             http_client = self.get_auth_http_client()
         fut = http_client.fetch(url, method="POST", body=urllib_parse.urlencode(args))
-        fut.add_done_callback(functools.partial(
-            self._on_authentication_verified, callback))
+        fut.add_done_callback(wrap(functools.partial(
+            self._on_authentication_verified, callback)))
 
     def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None):
         url = urlparse.urljoin(self.request.full_url(), callback_uri)
@@ -381,18 +381,18 @@ class OAuthMixin(object):
             fut = http_client.fetch(
                 self._oauth_request_token_url(callback_uri=callback_uri,
                                               extra_params=extra_params))
-            fut.add_done_callback(functools.partial(
+            fut.add_done_callback(wrap(functools.partial(
                 self._on_request_token,
                 self._OAUTH_AUTHORIZE_URL,
                 callback_uri,
-                callback))
+                callback)))
         else:
             fut = http_client.fetch(self._oauth_request_token_url())
             fut.add_done_callback(
-                functools.partial(
+                wrap(functools.partial(
                     self._on_request_token, self._OAUTH_AUTHORIZE_URL,
                     callback_uri,
-                    callback))
+                    callback)))
 
     @_auth_return_future
     def get_authenticated_user(self, callback, http_client=None):
@@ -432,7 +432,7 @@ class OAuthMixin(object):
         if http_client is None:
             http_client = self.get_auth_http_client()
         fut = http_client.fetch(self._oauth_access_token_url(token))
-        fut.add_done_callback(functools.partial(self._on_access_token, callback))
+        fut.add_done_callback(wrap(functools.partial(self._on_access_token, callback)))
 
     def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
         consumer_token = self._oauth_consumer_token()
@@ -515,7 +515,7 @@ class OAuthMixin(object):
         fut = self._oauth_get_user_future(access_token)
         fut = gen.convert_yielded(fut)
         fut.add_done_callback(
-            functools.partial(self._on_oauth_get_user, access_token, future))
+            wrap(functools.partial(self._on_oauth_get_user, access_token, future)))
 
     def _oauth_consumer_token(self):
         """Subclasses must override this to return their OAuth consumer keys.
@@ -711,7 +711,7 @@ class OAuth2Mixin(object):
 
         if all_args:
             url += "?" + urllib_parse.urlencode(all_args)
-        callback = functools.partial(self._on_oauth2_request, callback)
+        callback = wrap(functools.partial(self._on_oauth2_request, callback))
         http = self.get_auth_http_client()
         if post_args is not None:
             fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args))
@@ -797,9 +797,9 @@ class TwitterMixin(OAuthMixin):
         """
         http = self.get_auth_http_client()
         fut = http.fetch(self._oauth_request_token_url(callback_uri=callback_uri))
-        fut.add_done_callback(functools.partial(
+        fut.add_done_callback(wrap(functools.partial(
             self._on_request_token, self._OAUTH_AUTHENTICATE_URL,
-            None, callback))
+            None, callback)))
 
     @_auth_return_future
     def twitter_request(self, path, callback=None, access_token=None,
@@ -863,7 +863,7 @@ class TwitterMixin(OAuthMixin):
         if args:
             url += "?" + urllib_parse.urlencode(args)
         http = self.get_auth_http_client()
-        http_callback = functools.partial(self._on_twitter_request, callback, url)
+        http_callback = wrap(functools.partial(self._on_twitter_request, callback, url))
         if post_args is not None:
             fut = http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args))
         else:
@@ -977,7 +977,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
                          method="POST",
                          headers={'Content-Type': 'application/x-www-form-urlencoded'},
                          body=body)
-        fut.add_done_callback(functools.partial(self._on_access_token, callback))
+        fut.add_done_callback(wrap(functools.partial(self._on_access_token, callback)))
 
     def _on_access_token(self, future, response_fut):
         """Callback function for the exchange to the access token."""
@@ -1061,8 +1061,8 @@ class FacebookGraphMixin(OAuth2Mixin):
             fields.update(extra_fields)
 
         fut = http.fetch(self._oauth_request_token_url(**args))
-        fut.add_done_callback(functools.partial(self._on_access_token, redirect_uri, client_id,
-                                                client_secret, callback, fields))
+        fut.add_done_callback(wrap(functools.partial(self._on_access_token, redirect_uri, client_id,
+                                                     client_secret, callback, fields)))
 
     @gen.coroutine
     def _on_access_token(self, redirect_uri, client_id, client_secret,
index 78b20919b90d47e9b28ba6babbb711ea6f4411ae..f7e6bcccb19058d2ed18273ce49b7815d1f20412 100644 (file)
@@ -519,10 +519,10 @@ def return_future(f):
     """
     warnings.warn("@return_future is deprecated, use coroutines instead",
                   DeprecationWarning)
-    return _non_deprecated_return_future(f)
+    return _non_deprecated_return_future(f, warn=True)
 
 
-def _non_deprecated_return_future(f):
+def _non_deprecated_return_future(f, warn=False):
     # Allow auth.py to use this decorator without triggering
     # deprecation warnings. This will go away once auth.py has removed
     # its legacy interfaces in 6.0.
@@ -539,7 +539,15 @@ def _non_deprecated_return_future(f):
             future_set_exc_info(future, (typ, value, tb))
             return True
         exc_info = None
-        with ExceptionStackContext(handle_error, delay_warning=True):
+        esc = ExceptionStackContext(handle_error, delay_warning=True)
+        with esc:
+            if not warn:
+                # HACK: In non-deprecated mode (only used in auth.py),
+                # suppress the warning entirely. Since this is added
+                # in a 5.1 patch release and already removed in 6.0
+                # I'm prioritizing a minimial change instead of a
+                # clean solution.
+                esc.delay_warning = False
             try:
                 result = f(*args, **kwargs)
                 if result is not None:
index 41993b1f6bdf17f66127b4ebffe05f1285d8a97e..14bc3353cdd7387b165599be0d8d0785c1af0775 100644 (file)
@@ -6,6 +6,7 @@
 
 from __future__ import absolute_import, division, print_function
 
+import unittest
 import warnings
 
 from tornado.auth import (
@@ -16,11 +17,16 @@ from tornado.concurrent import Future
 from tornado.escape import json_decode
 from tornado import gen
 from tornado.httputil import url_concat
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
 from tornado.testing import AsyncHTTPTestCase, ExpectLog
 from tornado.test.util import ignore_deprecation
 from tornado.web import RequestHandler, Application, asynchronous, HTTPError
 
+try:
+    from unittest import mock
+except ImportError:
+    mock = None
+
 
 class OpenIdClientLoginHandlerLegacy(RequestHandler, OpenIdMixin):
     def initialize(self, test):
@@ -527,6 +533,14 @@ class AuthTest(AsyncHTTPTestCase):
             '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
             response.headers['Set-Cookie'])
 
+    @unittest.skipIf(mock is None, 'mock package not present')
+    def test_oauth10a_redirect_error(self):
+        with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get:
+            get.side_effect = Exception("boom")
+            with ExpectLog(app_log, "Uncaught exception"):
+                response = self.fetch('/oauth10a/client/login', follow_redirects=False)
+            self.assertEqual(response.code, 500)
+
     def test_oauth10a_get_user(self):
         response = self.fetch(
             '/oauth10a/client/login?oauth_token=zxcv',