]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add future interface to most of the rest of the auth module.
authorBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 22:13:45 +0000 (17:13 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 22:43:45 +0000 (17:43 -0500)
_oauth_get_user still uses callbacks internally, and I haven't
changed the old deprecated FacebookMixin.

tornado/auth.py
tornado/concurrent.py

index 00ed7e73d99161d2b51d4e9cfe90630623eb5a4d..96d6edb7245d60bea2543e4d8c771c45d50db593 100644 (file)
@@ -54,7 +54,7 @@ import hmac
 import time
 import uuid
 
-from tornado.concurrent import Future
+from tornado.concurrent import Future, chain_future
 from tornado import httpclient
 from tornado import escape
 from tornado.httputil import url_concat
@@ -122,6 +122,7 @@ class OpenIdMixin(object):
         args = self._openid_args(callback_uri, ax_attrs=ax_attrs)
         self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
 
+    @_auth_return_future
     def get_authenticated_user(self, callback, http_client=None):
         """Fetches the authenticated user data upon redirect.
 
@@ -187,11 +188,11 @@ class OpenIdMixin(object):
             })
         return args
 
-    def _on_authentication_verified(self, callback, response):
+    def _on_authentication_verified(self, future, response):
         if response.error or b"is_valid:true" not in response.body:
-            gen_log.warning("Invalid OpenID response: %s", response.error or
-                            response.body)
-            callback(None)
+            future.set_exception(AuthError(
+                    "Invalid OpenID response: %s" % (response.error or
+                                                     response.body)))
             return
 
         # Make sure we got back at least an email from attribute exchange
@@ -245,7 +246,7 @@ class OpenIdMixin(object):
         claimed_id = self.get_argument("openid.claimed_id", None)
         if claimed_id:
             user["claimed_id"] = claimed_id
-        callback(user)
+        future.set_result(user)
 
     def get_auth_http_client(self):
         """Returns the AsyncHTTPClient instance to be used for auth requests.
@@ -295,6 +296,7 @@ class OAuthMixin(object):
                     self._on_request_token, self._OAUTH_AUTHORIZE_URL,
                     callback_uri))
 
+    @_auth_return_future
     def get_authenticated_user(self, callback, http_client=None):
         """Gets the OAuth authorized user and access token on callback.
 
@@ -306,19 +308,19 @@ class OAuthMixin(object):
         to this service on behalf of the user.
 
         """
+        future = callback
         request_key = escape.utf8(self.get_argument("oauth_token"))
         oauth_verifier = self.get_argument("oauth_verifier", None)
         request_cookie = self.get_cookie("_oauth_request_token")
         if not request_cookie:
-            gen_log.warning("Missing OAuth request token cookie")
-            callback(None)
+            future.set_exception(AuthError(
+                    "Missing OAuth request token cookie"))
             return
         self.clear_cookie("_oauth_request_token")
         cookie_key, cookie_secret = [base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")]
         if cookie_key != request_key:
-            gen_log.info((cookie_key, request_key, request_cookie))
-            gen_log.warning("Request token does not match cookie")
-            callback(None)
+            future.set_exception(AuthError(
+                    "Request token does not match cookie"))
             return
         token = dict(key=cookie_key, secret=cookie_secret)
         if oauth_verifier:
@@ -393,25 +395,24 @@ class OAuthMixin(object):
         args["oauth_signature"] = signature
         return url + "?" + urllib_parse.urlencode(args)
 
-    def _on_access_token(self, callback, response):
+    def _on_access_token(self, future, response):
         if response.error:
-            gen_log.warning("Could not fetch access token")
-            callback(None)
+            future.set_exception(AuthError("Could not fetch access token"))
             return
 
         access_token = _oauth_parse_response(response.body)
         self._oauth_get_user(access_token, self.async_callback(
-                             self._on_oauth_get_user, access_token, callback))
+                             self._on_oauth_get_user, access_token, future))
 
     def _oauth_get_user(self, access_token, callback):
         raise NotImplementedError()
 
-    def _on_oauth_get_user(self, access_token, callback, user):
+    def _on_oauth_get_user(self, access_token, future, user):
         if not user:
-            callback(None)
+            future.set_exception(AuthError("Error getting user"))
             return
         user["access_token"] = access_token
-        callback(user)
+        future.set_result(user)
 
     def _oauth_request_parameters(self, url, access_token, parameters={},
                                   method="GET"):
@@ -669,6 +670,7 @@ class FriendFeedMixin(OAuthMixin):
     _OAUTH_NO_CALLBACKS = True
     _OAUTH_VERSION = "1.0"
 
+    @_auth_return_future
     def friendfeed_request(self, path, callback, access_token=None,
                            post_args=None, **args):
         """Fetches the given relative API path, e.g., "/bret/friends"
@@ -724,13 +726,13 @@ class FriendFeedMixin(OAuthMixin):
         else:
             http.fetch(url, callback=callback)
 
-    def _on_friendfeed_request(self, callback, response):
+    def _on_friendfeed_request(self, future, response):
         if response.error:
-            gen_log.warning("Error response %s fetching %s", response.error,
-                            response.request.url)
-            callback(None)
+            future.set_exception(AuthError(
+                    "Error response %s fetching %s" % (response.error,
+                                                       response.request.url)))
             return
-        callback(escape.json_decode(response.body))
+        future.set_result(escape.json_decode(response.body))
 
     def _oauth_consumer_token(self):
         self.require_setting("friendfeed_consumer_key", "FriendFeed OAuth")
@@ -797,6 +799,7 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
                                  oauth_scope=oauth_scope)
         self.redirect(self._OPENID_ENDPOINT + "?" + urllib_parse.urlencode(args))
 
+    @_auth_return_future
     def get_authenticated_user(self, callback):
         """Fetches the authenticated user data upon redirect."""
         # Look to see if we are doing combined OpenID/OAuth
@@ -813,7 +816,8 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
             http.fetch(self._oauth_access_token_url(token),
                        self.async_callback(self._on_access_token, callback))
         else:
-            OpenIdMixin.get_authenticated_user(self, callback)
+            chain_future(OpenIdMixin.get_authenticated_user(self),
+                         callback)
 
     def _oauth_consumer_token(self):
         self.require_setting("google_consumer_key", "Google OAuth")
@@ -829,8 +833,9 @@ class GoogleMixin(OpenIdMixin, OAuthMixin):
 class FacebookMixin(object):
     """Facebook Connect authentication.
 
-    New applications should consider using `FacebookGraphMixin` below instead
-    of this class.
+    *Deprecated:* New applications should use `FacebookGraphMixin`
+    below instead of this class.  This class does not support the
+    Future-based interface seen on other classes in this module.
 
     To authenticate with Facebook, register your application with
     Facebook at http://www.facebook.com/developers/apps.php. Then
index d73a59c9e9af174f8add34b7ddf6f0f3db07c119..0e78b916d4eaa13a3b9aa2a62b9c71bfe78aebf9 100644 (file)
@@ -169,3 +169,16 @@ def return_future(f):
             raise_exc_info(exc_info)
         return future
     return wrapper
+
+def chain_future(a, b):
+    """Chain two futures together so that when one completes, so does the other.
+
+    The result (success or failure) of ``a`` will be copied to ``b``.
+    """
+    def copy(future):
+        assert future is a
+        if a.exception() is not None:
+            b.set_exception(a.exception())
+        else:
+            b.set_result(a.result())
+    a.add_done_callback(copy)