]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add crude tests for the auth module, and fix python3 issues with oauth1
authorBen Darnell <ben@bendarnell.com>
Mon, 12 Sep 2011 05:50:13 +0000 (22:50 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 12 Sep 2011 06:15:04 +0000 (23:15 -0700)
tornado/auth.py
tornado/test/auth_test.py [new file with mode: 0644]
tornado/test/runtests.py

index b21175c00f48640238c97e973a2d641902b9dd62..612f820decc20f3ffb7b518769b96bd8c42f44ec 100644 (file)
@@ -50,7 +50,6 @@ Example usage for Google OpenID::
 
 import base64
 import binascii
-import cgi
 import hashlib
 import hmac
 import logging
@@ -85,7 +84,7 @@ class OpenIdMixin(object):
         args = self._openid_args(callback_uri, ax_attrs=ax_attrs)
         self.redirect(self._OPENID_ENDPOINT + "?" + urllib.urlencode(args))
 
-    def get_authenticated_user(self, callback):
+    def get_authenticated_user(self, callback, http_client=None):
         """Fetches the authenticated user data upon redirect.
 
         This method should be called by the handler that receives the
@@ -96,8 +95,8 @@ class OpenIdMixin(object):
         args = dict((k, v[-1]) for k, v in self.request.arguments.iteritems())
         args["openid.mode"] = u"check_authentication"
         url = self._OPENID_ENDPOINT
-        http = httpclient.AsyncHTTPClient()
-        http.fetch(url, self.async_callback(
+        if http_client is None: http_client = httpclient.AsyncHTTPClient()
+        http_client.fetch(url, self.async_callback(
             self._on_authentication_verified, callback),
             method="POST", body=urllib.urlencode(args))
 
@@ -207,7 +206,8 @@ class OAuthMixin(object):
     See TwitterMixin and FriendFeedMixin below for example implementations.
     """
 
-    def authorize_redirect(self, callback_uri=None, extra_params=None):
+    def authorize_redirect(self, callback_uri=None, extra_params=None,
+                           http_client=None):
         """Redirects the user to obtain OAuth authorization for this service.
 
         Twitter and FriendFeed both require that you register a Callback
@@ -222,20 +222,25 @@ class OAuthMixin(object):
         """
         if callback_uri and getattr(self, "_OAUTH_NO_CALLBACKS", False):
             raise Exception("This service does not support oauth_callback")
-        http = httpclient.AsyncHTTPClient()
+        if http_client is None:
+            http_client = httpclient.AsyncHTTPClient()
         if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
-            http.fetch(self._oauth_request_token_url(callback_uri=callback_uri,
-                extra_params=extra_params),
+            http_client.fetch(
+                self._oauth_request_token_url(callback_uri=callback_uri,
+                                              extra_params=extra_params),
                 self.async_callback(
                     self._on_request_token,
                     self._OAUTH_AUTHORIZE_URL,
                 callback_uri))
         else:
-            http.fetch(self._oauth_request_token_url(), self.async_callback(
-                self._on_request_token, self._OAUTH_AUTHORIZE_URL, callback_uri))
+            http_client.fetch(
+                self._oauth_request_token_url(),
+                self.async_callback(
+                    self._on_request_token, self._OAUTH_AUTHORIZE_URL,
+                    callback_uri))
 
 
-    def get_authenticated_user(self, callback):
+    def get_authenticated_user(self, callback, http_client=None):
         """Gets the OAuth authorized user and access token on callback.
 
         This method should be called from the handler for your registered
@@ -246,7 +251,7 @@ class OAuthMixin(object):
         to this service on behalf of the user.
 
         """
-        request_key = self.get_argument("oauth_token")
+        request_key = escape.utf8(self.get_argument("oauth_token"))
         oauth_verifier = self.get_argument("oauth_verifier", None)
         request_cookie = self.get_cookie("_oauth_request_token")
         if not request_cookie:
@@ -254,17 +259,19 @@ class OAuthMixin(object):
             callback(None)
             return
         self.clear_cookie("_oauth_request_token")
-        cookie_key, cookie_secret = [base64.b64decode(i) for i in request_cookie.split("|")]
+        cookie_key, cookie_secret = [base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")]
         if cookie_key != request_key:
+            logging.info((cookie_key, request_key, request_cookie))
             logging.warning("Request token does not match cookie")
             callback(None)
             return
         token = dict(key=cookie_key, secret=cookie_secret)
         if oauth_verifier:
-          token["verifier"] = oauth_verifier
-        http = httpclient.AsyncHTTPClient()
-        http.fetch(self._oauth_access_token_url(token), self.async_callback(
-            self._on_access_token, callback))
+            token["verifier"] = oauth_verifier
+        if http_client is None:
+            http_client = httpclient.AsyncHTTPClient()
+        http_client.fetch(self._oauth_access_token_url(token),
+                          self.async_callback(self._on_access_token, callback))
 
     def _oauth_request_token_url(self, callback_uri= None, extra_params=None):
         consumer_token = self._oauth_consumer_token()
@@ -292,8 +299,8 @@ class OAuthMixin(object):
         if response.error:
             raise Exception("Could not get request token")
         request_token = _oauth_parse_response(response.body)
-        data = "|".join([base64.b64encode(request_token["key"]),
-            base64.b64encode(request_token["secret"])])
+        data = (base64.b64encode(request_token["key"]) + b("|") +
+                base64.b64encode(request_token["secret"]))
         self.set_cookie("_oauth_request_token", data)
         args = dict(oauth_token=request_token["key"])
         if callback_uri:
@@ -1078,11 +1085,11 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None):
                                for k, v in sorted(parameters.items())))
     base_string =  "&".join(_oauth_escape(e) for e in base_elems)
 
-    key_elems = [consumer_token["secret"]]
-    key_elems.append(token["secret"] if token else "")
-    key = "&".join(key_elems)
+    key_elems = [escape.utf8(consumer_token["secret"])]
+    key_elems.append(escape.utf8(token["secret"] if token else ""))
+    key = b("&").join(key_elems)
 
-    hash = hmac.new(key, base_string, hashlib.sha1)
+    hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
     return binascii.b2a_base64(hash.digest())[:-1]
 
 def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
@@ -1101,11 +1108,11 @@ def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None):
                                for k, v in sorted(parameters.items())))
 
     base_string =  "&".join(_oauth_escape(e) for e in base_elems)
-    key_elems = [urllib.quote(consumer_token["secret"], safe='~')]
-    key_elems.append(urllib.quote(token["secret"], safe='~') if token else "")
-    key = "&".join(key_elems)
+    key_elems = [escape.utf8(urllib.quote(consumer_token["secret"], safe='~'))]
+    key_elems.append(escape.utf8(urllib.quote(token["secret"], safe='~') if token else ""))
+    key = b("&").join(key_elems)
 
-    hash = hmac.new(key, base_string, hashlib.sha1)
+    hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
     return binascii.b2a_base64(hash.digest())[:-1]
 
 def _oauth_escape(val):
@@ -1115,11 +1122,11 @@ def _oauth_escape(val):
 
 
 def _oauth_parse_response(body):
-    p = cgi.parse_qs(body, keep_blank_values=False)
-    token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0])
+    p = escape.parse_qs(body, keep_blank_values=False)
+    token = dict(key=p[b("oauth_token")][0], secret=p[b("oauth_token_secret")][0])
 
     # Add the extra parameters the Provider included to the token
-    special = ("oauth_token", "oauth_token_secret")
+    special = (b("oauth_token"), b("oauth_token_secret"))
     token.update((k, p[k][0]) for k in p if k not in special)
     return token
 
diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py
new file mode 100644 (file)
index 0000000..2047904
--- /dev/null
@@ -0,0 +1,186 @@
+# These tests do not currently do much to verify the correct implementation
+# of the openid/oauth protocols, they just exercise the major code paths
+# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
+# python 3)
+
+from tornado.auth import OpenIdMixin, OAuthMixin, OAuth2Mixin
+from tornado.escape import json_decode
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.util import b
+from tornado.web import RequestHandler, Application, asynchronous
+
+class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
+    def initialize(self, test):
+        self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
+
+    @asynchronous
+    def get(self):
+        if self.get_argument('openid.mode', None):
+            self.get_authenticated_user(
+                self.on_user, http_client=self.settings['http_client'])
+            return
+        self.authenticate_redirect()
+
+    def on_user(self, user):
+        assert user is not None
+        self.finish(user)
+
+class OpenIdServerAuthenticateHandler(RequestHandler):
+    def post(self):
+        assert self.get_argument('openid.mode') == 'check_authentication'
+        self.write('is_valid:true')
+
+class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
+    def initialize(self, test, version):
+        self._OAUTH_VERSION = version
+        self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
+        self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
+        self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
+
+    def _oauth_consumer_token(self):
+        return dict(key='asdf', secret='qwer')
+
+    @asynchronous
+    def get(self):
+        if self.get_argument('oauth_token', None):
+            self.get_authenticated_user(
+                self.on_user, http_client=self.settings['http_client'])
+            return
+        self.authorize_redirect(http_client=self.settings['http_client'])
+
+    def on_user(self, user):
+        assert user is not None
+        self.finish(user)
+
+    def _oauth_get_user(self, access_token, callback):
+        assert access_token == dict(key=b('uiop'), secret=b('5678')), access_token
+        callback(dict(email='foo@example.com'))
+
+class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
+    def initialize(self, version):
+        self._OAUTH_VERSION = version
+
+    def _oauth_consumer_token(self):
+        return dict(key='asdf', secret='qwer')
+
+    def get(self):
+        params = self._oauth_request_parameters(
+            'http://www.example.com/api/asdf',
+            dict(key='uiop', secret='5678'),
+            parameters=dict(foo='bar'))
+        import urllib; urllib.urlencode(params)
+        self.write(params)
+
+class OAuth1ServerRequestTokenHandler(RequestHandler):
+    def get(self):
+        self.write('oauth_token=zxcv&oauth_token_secret=1234')
+
+class OAuth1ServerAccessTokenHandler(RequestHandler):
+    def get(self):
+        self.write('oauth_token=uiop&oauth_token_secret=5678')
+
+class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
+    def initialize(self, test):
+        self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
+
+    def get(self):
+        self.authorize_redirect()
+
+
+class AuthTest(AsyncHTTPTestCase, LogTrapTestCase):
+    def get_app(self):
+        return Application(
+            [
+                # test endpoints
+                ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
+                ('/oauth10/client/login', OAuth1ClientLoginHandler,
+                 dict(test=self, version='1.0')),
+                ('/oauth10/client/request_params',
+                 OAuth1ClientRequestParametersHandler,
+                 dict(version='1.0')),
+                ('/oauth10a/client/login', OAuth1ClientLoginHandler,
+                 dict(test=self, version='1.0a')),
+                ('/oauth10a/client/request_params',
+                 OAuth1ClientRequestParametersHandler,
+                 dict(version='1.0a')),
+                ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
+
+                # simulated servers
+                ('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
+                ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
+                ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
+                ],
+            http_client=self.http_client)
+
+    def test_openid_redirect(self):
+        response = self.fetch('/openid/client/login', follow_redirects=False)
+        self.assertEqual(response.code, 302)
+        self.assertTrue(
+            '/openid/server/authenticate?' in response.headers['Location'])
+
+    def test_openid_get_user(self):
+        response = self.fetch('/openid/client/login?openid.mode=blah&openid.ns.ax=http://openid.net/srv/ax/1.0&openid.ax.type.email=http://axschema.org/contact/email&openid.ax.value.email=foo@example.com')
+        response.rethrow()
+        parsed = json_decode(response.body)
+        self.assertEqual(parsed["email"], "foo@example.com")
+
+    def test_oauth10_redirect(self):
+        response = self.fetch('/oauth10/client/login', follow_redirects=False)
+        self.assertEqual(response.code, 302)
+        self.assertTrue(response.headers['Location'].endswith(
+            '/oauth1/server/authorize?oauth_token=zxcv'))
+        # the cookie is base64('zxcv')|base64('1234')
+        self.assertTrue(
+            '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
+            response.headers['Set-Cookie'])
+
+    def test_oauth10_get_user(self):
+        response = self.fetch(
+            '/oauth10/client/login?oauth_token=zxcv',
+            headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='})
+        response.rethrow()
+        parsed = json_decode(response.body)
+        self.assertEqual(parsed['email'], 'foo@example.com')
+        self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+
+    def test_oauth10_request_parameters(self):
+        response = self.fetch('/oauth10/client/request_params')
+        response.rethrow()
+        parsed = json_decode(response.body)
+        self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
+        self.assertEqual(parsed['oauth_token'], 'uiop')
+        self.assertTrue('oauth_nonce' in parsed)
+        self.assertTrue('oauth_signature' in parsed)
+
+    def test_oauth10a_redirect(self):
+        response = self.fetch('/oauth10a/client/login', follow_redirects=False)
+        self.assertEqual(response.code, 302)
+        self.assertTrue(response.headers['Location'].endswith(
+            '/oauth1/server/authorize?oauth_token=zxcv'))
+        # the cookie is base64('zxcv')|base64('1234')
+        self.assertTrue(
+            '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
+            response.headers['Set-Cookie'])
+
+    def test_oauth10a_get_user(self):
+        response = self.fetch(
+            '/oauth10a/client/login?oauth_token=zxcv',
+            headers={'Cookie':'_oauth_request_token=enhjdg==|MTIzNA=='})
+        response.rethrow()
+        parsed = json_decode(response.body)
+        self.assertEqual(parsed['email'], 'foo@example.com')
+        self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+
+    def test_oauth10a_request_parameters(self):
+        response = self.fetch('/oauth10a/client/request_params')
+        response.rethrow()
+        parsed = json_decode(response.body)
+        self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
+        self.assertEqual(parsed['oauth_token'], 'uiop')
+        self.assertTrue('oauth_nonce' in parsed)
+        self.assertTrue('oauth_signature' in parsed)
+
+    def test_oauth2_redirect(self):
+        response = self.fetch('/oauth2/client/login', follow_redirects=False)
+        self.assertEqual(response.code, 302)
+        self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
index 14f161a2b0a953166dbb0f85f64a2126362a33b5..188aee8233dec6656767184101d2ab2f6bf15b93 100755 (executable)
@@ -5,6 +5,7 @@ TEST_MODULES = [
     'tornado.httputil.doctests',
     'tornado.iostream.doctests',
     'tornado.util.doctests',
+    'tornado.test.auth_test',
     'tornado.test.curl_httpclient_test',
     'tornado.test.escape_test',
     'tornado.test.gen_test',