]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a Future-based interface to tornado.auth's twitter_request.
authorBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 21:21:01 +0000 (16:21 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 21:21:01 +0000 (16:21 -0500)
This allows new clients to receive exception messages from failed
requests, without breaking compatibility with existing callback-based
clients.

tornado/auth.py
tornado/test/auth_test.py

index 52900cb1039527ee2d4b59e11cd8941d269878ea..00ed7e73d99161d2b51d4e9cfe90630623eb5a4d 100644 (file)
@@ -48,16 +48,18 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import base64
 import binascii
+import functools
 import hashlib
 import hmac
 import time
 import uuid
 
+from tornado.concurrent import Future
 from tornado import httpclient
 from tornado import escape
 from tornado.httputil import url_concat
 from tornado.log import gen_log
-from tornado.util import bytes_type, u, unicode_type
+from tornado.util import bytes_type, u, unicode_type, ArgReplacer
 
 try:
     import urlparse  # py2
@@ -69,6 +71,35 @@ try:
 except ImportError:
     import urllib as urllib_parse  # py2
 
+class AuthError(Exception):
+    pass
+
+def _auth_future_to_callback(callback, future):
+    try:
+        result = future.result()
+    except AuthError as e:
+        gen_log.warning(str(e))
+        result = None
+    callback(result)
+
+def _auth_return_future(f):
+    """Similar to tornado.concurrent.return_future, but uses the auth
+    module's legacy callback interface.
+
+    Note that when using this decorator the ``callback`` parameter
+    inside the function will actually be a future.
+    """
+    replacer = ArgReplacer(f, 'callback')
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        future = Future()
+        callback, args, kwargs = replacer.replace(future, args, kwargs)
+        if callback is not None:
+            future.add_done_callback(
+                functools.partial(_auth_future_to_callback, callback))
+        f(*args, **kwargs)
+        return future
+    return wrapper
 
 class OpenIdMixin(object):
     """Abstract implementation of OpenID and Attribute Exchange.
@@ -507,7 +538,8 @@ class TwitterMixin(OAuthMixin):
         http.fetch(self._oauth_request_token_url(callback_uri=callback_uri), self.async_callback(
             self._on_request_token, self._OAUTH_AUTHENTICATE_URL, None))
 
-    def twitter_request(self, path, callback, access_token=None,
+    @_auth_return_future
+    def twitter_request(self, path, callback=None, access_token=None,
                         post_args=None, **args):
         """Fetches the given API path, e.g., "/statuses/user_timeline/btaylor"
 
@@ -562,21 +594,21 @@ class TwitterMixin(OAuthMixin):
             args.update(oauth)
         if args:
             url += "?" + urllib_parse.urlencode(args)
-        callback = self.async_callback(self._on_twitter_request, callback)
         http = self.get_auth_http_client()
+        http_callback = self.async_callback(self._on_twitter_request, callback)
         if post_args is not None:
             http.fetch(url, method="POST", body=urllib_parse.urlencode(post_args),
-                       callback=callback)
+                       callback=http_callback)
         else:
-            http.fetch(url, callback=callback)
+            http.fetch(url, callback=http_callback)
 
-    def _on_twitter_request(self, callback, response):
+    def _on_twitter_request(self, future, response):
         if response.error:
-            gen_log.warning("Error response %s fetching %s", response.error,
-                            response.request.url)
-            callback(None)
+            future.set_exception(AuthError(
+                    "Error response %s fetching %s" % (response.error,
+                                                       response.request.url)))
             return
-        callback(escape.json_decode(response.body))
+        future.set_result(escape.json_decode(response.body))
 
     def _oauth_consumer_token(self):
         self.require_setting("twitter_consumer_key", "Twitter OAuth")
index 0c3a5a278c0548d212eb69f1d05af1467c08c03d..9808a0c289658a9c6ed35029759a31db13d3efc5 100644 (file)
@@ -5,12 +5,13 @@
 
 
 from __future__ import absolute_import, division, print_function, with_statement
-from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin
+from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin, TwitterMixin, GoogleMixin, AuthError
 from tornado.escape import json_decode
 from tornado import gen
-from tornado.testing import AsyncHTTPTestCase
+from tornado.log import gen_log
+from tornado.testing import AsyncHTTPTestCase, ExpectLog
 from tornado.util import u
-from tornado.web import RequestHandler, Application, asynchronous
+from tornado.web import RequestHandler, Application, asynchronous, HTTPError
 
 
 class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
@@ -132,11 +133,32 @@ class TwitterClientShowUserHandler(TwitterClientHandler):
     def get(self):
         # TODO: would be nice to go through the login flow instead of
         # cheating with a hard-coded access token.
-        response = yield gen.Task(self.twitter_request, '/users/show/somebody',
+        response = yield gen.Task(self.twitter_request,
+                                  '/users/show/%s' % self.get_argument('name'),
                                   access_token=dict(key='hjkl', secret='vbnm'))
+        if response is None:
+            self.set_status(500)
+            self.finish('error from twitter request')
+        else:
+            self.finish(response)
+
+
+class TwitterClientShowUserFutureHandler(TwitterClientHandler):
+    @asynchronous
+    @gen.engine
+    def get(self):
+        try:
+            response = yield self.twitter_request(
+                '/users/show/%s' % self.get_argument('name'),
+                access_token=dict(key='hjkl', secret='vbnm'))
+        except AuthError as e:
+            self.set_status(500)
+            self.finish(str(e))
+            return
         assert response is not None
         self.finish(response)
 
+
 class TwitterServerAccessTokenHandler(RequestHandler):
     def get(self):
         self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
@@ -144,6 +166,8 @@ class TwitterServerAccessTokenHandler(RequestHandler):
 
 class TwitterServerShowUserHandler(RequestHandler):
     def get(self, screen_name):
+        if screen_name == 'error':
+            raise HTTPError(500)
         assert 'oauth_nonce' in self.request.arguments
         assert 'oauth_timestamp' in self.request.arguments
         assert 'oauth_signature' in self.request.arguments
@@ -194,6 +218,7 @@ class AuthTest(AsyncHTTPTestCase):
 
                 ('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
                 ('/twitter/client/show_user', TwitterClientShowUserHandler, dict(test=self)),
+                ('/twitter/client/show_user_future', TwitterClientShowUserFutureHandler, dict(test=self)),
                 ('/google/client/openid_login', GoogleOpenIdClientLoginHandler, dict(test=self)),
 
                 # simulated servers
@@ -307,11 +332,28 @@ class AuthTest(AsyncHTTPTestCase):
                           u('username'): u('foo')})
 
     def test_twitter_show_user(self):
-        response = self.fetch('/twitter/client/show_user')
+        response = self.fetch('/twitter/client/show_user?name=somebody')
         response.rethrow()
         self.assertEqual(json_decode(response.body),
                          {'name': 'Somebody', 'screen_name': 'somebody'})
 
+    def test_twitter_show_user_error(self):
+        with ExpectLog(gen_log, 'Error response HTTP 500'):
+            response = self.fetch('/twitter/client/show_user?name=error')
+        self.assertEqual(response.code, 500)
+        self.assertEqual(response.body, b'error from twitter request')
+
+    def test_twitter_show_user_future(self):
+        response = self.fetch('/twitter/client/show_user_future?name=somebody')
+        response.rethrow()
+        self.assertEqual(json_decode(response.body),
+                         {'name': 'Somebody', 'screen_name': 'somebody'})
+
+    def test_twitter_show_user_future_error(self):
+        response = self.fetch('/twitter/client/show_user_future?name=error')
+        self.assertEqual(response.code, 500)
+        self.assertIn(b'Error response HTTP 500', response.body)
+
     def test_google_redirect(self):
         # same as test_openid_redirect
         response = self.fetch('/google/client/openid_login', follow_redirects=False)