]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Remove all use of async_callback in tornado.auth.
authorBen Darnell <ben@bendarnell.com>
Sun, 1 Jun 2014 20:36:52 +0000 (16:36 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 1 Jun 2014 20:36:52 +0000 (16:36 -0400)
This was swallowing exceptions before they could be given to the returned
Future.

tornado/auth.py
tornado/test/auth_test.py

index 335b591b45e8de5049180cdbe0c26fed45a79d59..f15413e5db4809e35b162de2a6ee578037e4a8f3 100644 (file)
@@ -50,6 +50,13 @@ Example usage for Google OpenID::
                     scope=['profile', 'email'],
                     response_type='code',
                     extra_params={'approval_prompt': 'auto'})
+
+.. versionchanged:: 3.3
+   All of the callback interfaces in this module are now guaranteed
+   to run their callback with an argument of ``None`` on error.
+   Previously some functions would do this while others would simply
+   terminate the request on their own.  This change also ensures that
+   errors are more consistently reported through the ``Future`` interfaces.
 """
 
 from __future__ import absolute_import, division, print_function, with_statement
@@ -68,6 +75,7 @@ from tornado import httpclient
 from tornado import escape
 from tornado.httputil import url_concat
 from tornado.log import gen_log
+from tornado.stack_context import ExceptionStackContext
 from tornado.util import bytes_type, u, unicode_type, ArgReplacer
 
 try:
@@ -115,7 +123,14 @@ def _auth_return_future(f):
         if callback is not None:
             future.add_done_callback(
                 functools.partial(_auth_future_to_callback, callback))
-        f(*args, **kwargs)
+        def handle_exception(typ, value, tb):
+            if future.done():
+                return False
+            else:
+                future.set_exc_info((typ, value, tb))
+                return True
+        with ExceptionStackContext(handle_exception):
+            f(*args, **kwargs)
         return future
     return wrapper
 
@@ -173,7 +188,7 @@ class OpenIdMixin(object):
         url = self._OPENID_ENDPOINT
         if http_client is None:
             http_client = self.get_auth_http_client()
-        http_client.fetch(url, self.async_callback(
+        http_client.fetch(url, functools.partial(
             self._on_authentication_verified, callback),
             method="POST", body=urllib_parse.urlencode(args))
 
@@ -345,7 +360,7 @@ class OAuthMixin(object):
             http_client.fetch(
                 self._oauth_request_token_url(callback_uri=callback_uri,
                                               extra_params=extra_params),
-                self.async_callback(
+                functools.partial(
                     self._on_request_token,
                     self._OAUTH_AUTHORIZE_URL,
                     callback_uri,
@@ -353,7 +368,7 @@ class OAuthMixin(object):
         else:
             http_client.fetch(
                 self._oauth_request_token_url(),
-                self.async_callback(
+                functools.partial(
                     self._on_request_token, self._OAUTH_AUTHORIZE_URL,
                     callback_uri,
                     callback))
@@ -390,7 +405,7 @@ class OAuthMixin(object):
         if http_client is None:
             http_client = self.get_auth_http_client()
         http_client.fetch(self._oauth_access_token_url(token),
-                          self.async_callback(self._on_access_token, callback))
+                          functools.partial(self._on_access_token, callback))
 
     def _oauth_request_token_url(self, callback_uri=None, extra_params=None):
         consumer_token = self._oauth_consumer_token()
@@ -467,7 +482,7 @@ class OAuthMixin(object):
 
         access_token = _oauth_parse_response(response.body)
         self._oauth_get_user_future(access_token).add_done_callback(
-            self.async_callback(self._on_oauth_get_user, access_token, future))
+            functools.partial(self._on_oauth_get_user, access_token, future))
 
     def _oauth_consumer_token(self):
         """Subclasses must override this to return their OAuth consumer keys.
@@ -652,7 +667,7 @@ class TwitterMixin(OAuthMixin):
         """
         http = self.get_auth_http_client()
         http.fetch(self._oauth_request_token_url(callback_uri=callback_uri),
-                   self.async_callback(
+                   functools.partial(
                        self._on_request_token, self._OAUTH_AUTHENTICATE_URL,
                        None, callback))
 
@@ -710,7 +725,7 @@ class TwitterMixin(OAuthMixin):
         if args:
             url += "?" + urllib_parse.urlencode(args)
         http = self.get_auth_http_client()
-        http_callback = self.async_callback(self._on_twitter_request, callback)
+        http_callback = functools.partial(self._on_twitter_request, callback)
         if post_args is not None:
             http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
                        callback=http_callback)
@@ -827,7 +842,7 @@ class FriendFeedMixin(OAuthMixin):
             args.update(oauth)
         if args:
             url += "?" + urllib_parse.urlencode(args)
-        callback = self.async_callback(self._on_friendfeed_request, callback)
+        callback = functools.partial(self._on_friendfeed_request, callback)
         http = self.get_auth_http_client()
         if post_args is not None:
             http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
@@ -942,7 +957,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
             http = self.get_auth_http_client()
             token = dict(key=token, secret="")
             http.fetch(self._oauth_access_token_url(token),
-                       self.async_callback(self._on_access_token, callback))
+                       functools.partial(self._on_access_token, callback))
         else:
             chain_future(OpenIdMixin.get_authenticated_user(self),
                          callback)
@@ -1014,7 +1029,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin):
         })
 
         http.fetch(self._OAUTH_ACCESS_TOKEN_URL,
-                   self.async_callback(self._on_access_token, callback),
+                   functools.partial(self._on_access_token, callback),
                    method="POST", headers={'Content-Type': 'application/x-www-form-urlencoded'}, body=body)
 
     def _on_access_token(self, future, response):
@@ -1055,7 +1070,7 @@ class FacebookMixin(object):
             @tornado.web.asynchronous
             def get(self):
                 if self.get_argument("session", None):
-                    self.get_authenticated_user(self.async_callback(self._on_auth))
+                    self.get_authenticated_user(self._on_auth)
                     return
                 yield self.authenticate_redirect()
 
@@ -1141,7 +1156,7 @@ class FacebookMixin(object):
         session = escape.json_decode(self.get_argument("session"))
         self.facebook_request(
             method="facebook.users.getInfo",
-            callback=self.async_callback(
+            callback=functools.partial(
                 self._on_get_user_info, callback, session),
             session_key=session["session_key"],
             uids=session["uid"],
@@ -1167,7 +1182,7 @@ class FacebookMixin(object):
                 def get(self):
                     self.facebook_request(
                         method="stream.get",
-                        callback=self.async_callback(self._on_stream),
+                        callback=self._on_stream,
                         session_key=self.current_user["session_key"])
 
                 def _on_stream(self, stream):
@@ -1191,7 +1206,7 @@ class FacebookMixin(object):
         url = "http://api.facebook.com/restserver.php?" + \
             urllib_parse.urlencode(args)
         http = self.get_auth_http_client()
-        http.fetch(url, callback=self.async_callback(
+        http.fetch(url, callback=functools.partial(
             self._parse_response, callback))
 
     def _on_get_user_info(self, callback, session, users):
@@ -1289,7 +1304,7 @@ class FacebookGraphMixin(OAuth2Mixin):
             fields.update(extra_fields)
 
         http.fetch(self._oauth_request_token_url(**args),
-                   self.async_callback(self._on_access_token, redirect_uri, client_id,
+                   functools.partial(self._on_access_token, redirect_uri, client_id,
                                        client_secret, callback, fields))
 
     def _on_access_token(self, redirect_uri, client_id, client_secret,
@@ -1306,7 +1321,7 @@ class FacebookGraphMixin(OAuth2Mixin):
 
         self.facebook_request(
             path="/me",
-            callback=self.async_callback(
+            callback=functools.partial(
                 self._on_get_user_info, future, session, fields),
             access_token=session["access_token"],
             fields=",".join(fields)
@@ -1373,7 +1388,7 @@ class FacebookGraphMixin(OAuth2Mixin):
 
         if all_args:
             url += "?" + urllib_parse.urlencode(all_args)
-        callback = self.async_callback(self._on_facebook_request, callback)
+        callback = functools.partial(self._on_facebook_request, callback)
         http = self.get_auth_http_client()
         if post_args is not None:
             http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
index 1d6cb8392d7a646a1844d27246e4825b55fe3310..254e1ae13c6862c319f1d2b2f952c309111b2765 100644 (file)
@@ -67,11 +67,29 @@ class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
         self.finish(user)
 
     def _oauth_get_user(self, access_token, callback):
+        if self.get_argument('fail_in_get_user', None):
+            raise Exception("failing in get_user")
         if access_token != dict(key='uiop', secret='5678'):
             raise Exception("incorrect access token %r" % access_token)
         callback(dict(email='foo@example.com'))
 
 
+class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
+    """Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
+    @gen.coroutine
+    def get(self):
+        if self.get_argument('oauth_token', None):
+            # Ensure that any exceptions are set on the returned Future,
+            # not simply thrown into the surrounding StackContext.
+            try:
+                yield self.get_authenticated_user()
+            except Exception as e:
+                self.set_status(503)
+                self.write("got exception: %s" % e)
+        else:
+            yield self.authorize_redirect()
+
+
 class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
     def initialize(self, version):
         self._OAUTH_VERSION = version
@@ -255,6 +273,9 @@ class AuthTest(AsyncHTTPTestCase):
                  dict(version='1.0')),
                 ('/oauth10a/client/login', OAuth1ClientLoginHandler,
                  dict(test=self, version='1.0a')),
+                ('/oauth10a/client/login_coroutine',
+                 OAuth1ClientLoginCoroutineHandler,
+                 dict(test=self, version='1.0a')),
                 ('/oauth10a/client/request_params',
                  OAuth1ClientRequestParametersHandler,
                  dict(version='1.0a')),
@@ -348,6 +369,12 @@ class AuthTest(AsyncHTTPTestCase):
         self.assertTrue('oauth_nonce' in parsed)
         self.assertTrue('oauth_signature' in parsed)
 
+    def test_oauth10a_get_user_coroutine_exception(self):
+        response = self.fetch(
+            '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
+            headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+        self.assertEqual(response.code, 503)
+
     def test_oauth2_redirect(self):
         response = self.fetch('/oauth2/client/login', follow_redirects=False)
         self.assertEqual(response.code, 302)