From: Ben Darnell Date: Sun, 30 Sep 2018 16:34:05 +0000 (-0400) Subject: auth: Add type annotations X-Git-Tag: v6.0.0b1~28^2~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7b846ea56bff1892a4d4d05206210b4d234e292b;p=thirdparty%2Ftornado.git auth: Add type annotations This is a bit hacky and limited because mypy doesn't handle mixins very well and there's a lot of untypeable json here, but it's better than nothing. --- diff --git a/setup.cfg b/setup.cfg index 213bc1186..4eb80d92d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ python_version = 3.5 [mypy-tornado.*,tornado.platform.*] disallow_untyped_defs = True -[mypy-tornado.auth,tornado.websocket,tornado.wsgi] +[mypy-tornado.websocket,tornado.wsgi] disallow_untyped_defs = False # It's generally too tedious to require type annotations in tests, but diff --git a/tornado/auth.py b/tornado/auth.py index 35d5ee203..44ea5fbcf 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -68,6 +68,9 @@ from tornado import httpclient from tornado import escape from tornado.httputil import url_concat from tornado.util import unicode_type +from tornado.web import RequestHandler + +from typing import List, Any, Dict, cast, Iterable, Union, Optional class AuthError(Exception): @@ -81,8 +84,9 @@ class OpenIdMixin(object): * ``_OPENID_ENDPOINT``: the identity provider's URI. """ - def authenticate_redirect(self, callback_uri=None, - ax_attrs=["name", "email", "language", "username"]): + def authenticate_redirect( + self, callback_uri: str=None, + ax_attrs: List[str]=["name", "email", "language", "username"]) -> None: """Redirects to the authentication URL for this service. After authentication, the service will redirect back to the given @@ -99,11 +103,16 @@ class OpenIdMixin(object): longer returns an awaitable object. It is now an ordinary synchronous function. """ - callback_uri = callback_uri or self.request.uri + handler = cast(RequestHandler, self) + callback_uri = callback_uri or handler.request.uri + assert callback_uri is not None args = self._openid_args(callback_uri, ax_attrs=ax_attrs) - self.redirect(self._OPENID_ENDPOINT + "?" + urllib.parse.urlencode(args)) + endpoint = self._OPENID_ENDPOINT # type: ignore + handler.redirect(endpoint + "?" + urllib.parse.urlencode(args)) - async def get_authenticated_user(self, http_client=None): + async def get_authenticated_user( + self, http_client: httpclient.AsyncHTTPClient=None + ) -> Dict[str, Any]: """Fetches the authenticated user data upon redirect. This method should be called by the handler that receives the @@ -119,17 +128,21 @@ class OpenIdMixin(object): The ``callback`` argument was removed. Use the returned awaitable object instead. """ + handler = cast(RequestHandler, self) # Verify the OpenID response via direct request to the OP - args = dict((k, v[-1]) for k, v in self.request.arguments.items()) + args = dict((k, v[-1]) for k, v in handler.request.arguments.items()) \ + # type: Dict[str, Union[str, bytes]] args["openid.mode"] = u"check_authentication" - url = self._OPENID_ENDPOINT + url = self._OPENID_ENDPOINT # type: ignore if http_client is None: http_client = self.get_auth_http_client() resp = await http_client.fetch(url, method="POST", body=urllib.parse.urlencode(args)) return self._on_authentication_verified(resp) - def _openid_args(self, callback_uri, ax_attrs=[], oauth_scope=None): - url = urllib.parse.urljoin(self.request.full_url(), callback_uri) + def _openid_args(self, callback_uri: str, ax_attrs: Iterable[str]=[], + oauth_scope: str=None) -> Dict[str, str]: + handler = cast(RequestHandler, self) + url = urllib.parse.urljoin(handler.request.full_url(), callback_uri) args = { "openid.ns": "http://specs.openid.net/auth/2.0", "openid.claimed_id": @@ -146,7 +159,7 @@ class OpenIdMixin(object): "openid.ax.mode": "fetch_request", }) ax_attrs = set(ax_attrs) - required = [] + required = [] # type: List[str] if "name" in ax_attrs: ax_attrs -= set(["name", "firstname", "fullname", "lastname"]) required += ["firstname", "fullname", "lastname"] @@ -171,36 +184,37 @@ class OpenIdMixin(object): args.update({ "openid.ns.oauth": "http://specs.openid.net/extensions/oauth/1.0", - "openid.oauth.consumer": self.request.host.split(":")[0], + "openid.oauth.consumer": handler.request.host.split(":")[0], "openid.oauth.scope": oauth_scope, }) return args - def _on_authentication_verified(self, response): + def _on_authentication_verified(self, response: httpclient.HTTPResponse) -> Dict[str, Any]: + handler = cast(RequestHandler, self) if b"is_valid:true" not in response.body: raise AuthError("Invalid OpenID response: %s" % response.body) # Make sure we got back at least an email from attribute exchange ax_ns = None - for name in self.request.arguments: - if name.startswith("openid.ns.") and \ - self.get_argument(name) == u"http://openid.net/srv/ax/1.0": - ax_ns = name[10:] + for key in handler.request.arguments: + if key.startswith("openid.ns.") and \ + handler.get_argument(key) == u"http://openid.net/srv/ax/1.0": + ax_ns = key[10:] break - def get_ax_arg(uri): + def get_ax_arg(uri: str) -> str: if not ax_ns: return u"" prefix = "openid." + ax_ns + ".type." ax_name = None - for name in self.request.arguments.keys(): - if self.get_argument(name) == uri and name.startswith(prefix): + for name in handler.request.arguments.keys(): + if handler.get_argument(name) == uri and name.startswith(prefix): part = name[len(prefix):] ax_name = "openid." + ax_ns + ".value." + part break if not ax_name: return u"" - return self.get_argument(ax_name, u"") + return handler.get_argument(ax_name, u"") email = get_ax_arg("http://axschema.org/contact/email") name = get_ax_arg("http://axschema.org/namePerson") @@ -228,12 +242,12 @@ class OpenIdMixin(object): user["locale"] = locale if username: user["username"] = username - claimed_id = self.get_argument("openid.claimed_id", None) + claimed_id = handler.get_argument("openid.claimed_id", None) if claimed_id: user["claimed_id"] = claimed_id return user - def get_auth_http_client(self): + def get_auth_http_client(self) -> httpclient.AsyncHTTPClient: """Returns the `.AsyncHTTPClient` instance to be used for auth requests. May be overridden by subclasses to use an HTTP client other than @@ -258,8 +272,8 @@ class OAuthMixin(object): Subclasses must also override the `_oauth_get_user_future` and `_oauth_consumer_token` methods. """ - async def authorize_redirect(self, callback_uri=None, extra_params=None, - http_client=None): + async def authorize_redirect(self, callback_uri: str=None, extra_params: Dict[str, Any]=None, + http_client: httpclient.AsyncHTTPClient=None) -> None: """Redirects the user to obtain OAuth authorization for this service. The ``callback_uri`` may be omitted if you have previously @@ -291,15 +305,19 @@ class OAuthMixin(object): raise Exception("This service does not support oauth_callback") if http_client is None: http_client = self.get_auth_http_client() + assert http_client is not None if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": response = await http_client.fetch( self._oauth_request_token_url(callback_uri=callback_uri, extra_params=extra_params)) else: response = await http_client.fetch(self._oauth_request_token_url()) - self._on_request_token(self._OAUTH_AUTHORIZE_URL, callback_uri, response) + url = self._OAUTH_AUTHORIZE_URL # type: ignore + self._on_request_token(url, callback_uri, response) - async def get_authenticated_user(self, http_client=None): + async def get_authenticated_user( + self, http_client: httpclient.AsyncHTTPClient=None + ) -> Dict[str, Any]: """Gets the OAuth authorized user and access token. This method should be called from the handler for your @@ -315,21 +333,23 @@ class OAuthMixin(object): The ``callback`` argument was removed. Use the returned awaitable object instead. """ - request_key = escape.utf8(self.get_argument("oauth_token")) - oauth_verifier = self.get_argument("oauth_verifier", None) - request_cookie = self.get_cookie("_oauth_request_token") + handler = cast(RequestHandler, self) + request_key = escape.utf8(handler.get_argument("oauth_token")) + oauth_verifier = handler.get_argument("oauth_verifier", None) + request_cookie = handler.get_cookie("_oauth_request_token") if not request_cookie: raise AuthError("Missing OAuth request token cookie") - self.clear_cookie("_oauth_request_token") + handler.clear_cookie("_oauth_request_token") cookie_key, cookie_secret = [ base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")] if cookie_key != request_key: raise AuthError("Request token does not match cookie") - token = dict(key=cookie_key, secret=cookie_secret) + token = dict(key=cookie_key, secret=cookie_secret) # type: Dict[str, Union[str, bytes]] if oauth_verifier: token["verifier"] = oauth_verifier if http_client is None: http_client = self.get_auth_http_client() + assert http_client is not None response = await http_client.fetch(self._oauth_access_token_url(token)) access_token = _oauth_parse_response(response.body) user = await self._oauth_get_user_future(access_token) @@ -338,9 +358,11 @@ class OAuthMixin(object): user["access_token"] = access_token return user - def _oauth_request_token_url(self, callback_uri=None, extra_params=None): + def _oauth_request_token_url(self, callback_uri: str=None, + extra_params: Dict[str, Any]=None) -> str: + handler = cast(RequestHandler, self) consumer_token = self._oauth_consumer_token() - url = self._OAUTH_REQUEST_TOKEN_URL + url = self._OAUTH_REQUEST_TOKEN_URL # type: ignore args = dict( oauth_consumer_key=escape.to_basestring(consumer_token["key"]), oauth_signature_method="HMAC-SHA1", @@ -353,7 +375,7 @@ class OAuthMixin(object): args["oauth_callback"] = "oob" elif callback_uri: args["oauth_callback"] = urllib.parse.urljoin( - self.request.full_url(), callback_uri) + handler.request.full_url(), callback_uri) if extra_params: args.update(extra_params) signature = _oauth10a_signature(consumer_token, "GET", url, args) @@ -363,23 +385,25 @@ class OAuthMixin(object): args["oauth_signature"] = signature return url + "?" + urllib.parse.urlencode(args) - def _on_request_token(self, authorize_url, callback_uri, response): + def _on_request_token(self, authorize_url: str, callback_uri: Optional[str], + response: httpclient.HTTPResponse) -> None: + handler = cast(RequestHandler, self) request_token = _oauth_parse_response(response.body) data = (base64.b64encode(escape.utf8(request_token["key"])) + b"|" + base64.b64encode(escape.utf8(request_token["secret"]))) - self.set_cookie("_oauth_request_token", data) + handler.set_cookie("_oauth_request_token", data) args = dict(oauth_token=request_token["key"]) if callback_uri == "oob": - self.finish(authorize_url + "?" + urllib.parse.urlencode(args)) + handler.finish(authorize_url + "?" + urllib.parse.urlencode(args)) return elif callback_uri: args["oauth_callback"] = urllib.parse.urljoin( - self.request.full_url(), callback_uri) - self.redirect(authorize_url + "?" + urllib.parse.urlencode(args)) + handler.request.full_url(), callback_uri) + handler.redirect(authorize_url + "?" + urllib.parse.urlencode(args)) - def _oauth_access_token_url(self, request_token): + def _oauth_access_token_url(self, request_token: Dict[str, Any]) -> str: consumer_token = self._oauth_consumer_token() - url = self._OAUTH_ACCESS_TOKEN_URL + url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore args = dict( oauth_consumer_key=escape.to_basestring(consumer_token["key"]), oauth_token=escape.to_basestring(request_token["key"]), @@ -401,14 +425,14 @@ class OAuthMixin(object): args["oauth_signature"] = signature return url + "?" + urllib.parse.urlencode(args) - def _oauth_consumer_token(self): + def _oauth_consumer_token(self) -> Dict[str, Any]: """Subclasses must override this to return their OAuth consumer keys. The return value should be a `dict` with keys ``key`` and ``secret``. """ raise NotImplementedError() - async def _oauth_get_user_future(self, access_token): + async def _oauth_get_user_future(self, access_token: Dict[str, Any]) -> Dict[str, Any]: """Subclasses must override this to get basic information about the user. @@ -430,8 +454,9 @@ class OAuthMixin(object): """ raise NotImplementedError() - def _oauth_request_parameters(self, url, access_token, parameters={}, - method="GET"): + def _oauth_request_parameters(self, url: str, access_token: Dict[str, Any], + parameters: Dict[str, Any]={}, + method: str="GET") -> Dict[str, Any]: """Returns the OAuth parameters as a dict for the given request. parameters should include all POST arguments and query string arguments @@ -458,7 +483,7 @@ class OAuthMixin(object): base_args["oauth_signature"] = escape.to_basestring(signature) return base_args - def get_auth_http_client(self): + def get_auth_http_client(self) -> httpclient.AsyncHTTPClient: """Returns the `.AsyncHTTPClient` instance to be used for auth requests. May be overridden by subclasses to use an HTTP client other than @@ -478,9 +503,9 @@ class OAuth2Mixin(object): * ``_OAUTH_AUTHORIZE_URL``: The service's authorization url. * ``_OAUTH_ACCESS_TOKEN_URL``: The service's access token url. """ - def authorize_redirect(self, redirect_uri=None, client_id=None, - client_secret=None, extra_params=None, - scope=None, response_type="code"): + def authorize_redirect(self, redirect_uri: str=None, client_id: str=None, + client_secret: str=None, extra_params: Dict[str, Any]=None, + scope: str=None, response_type: str="code") -> None: """Redirects the user to obtain OAuth authorization for this service. Some providers require that you register a redirect URL with @@ -494,34 +519,40 @@ class OAuth2Mixin(object): The ``callback`` argument and returned awaitable were removed; this is now an ordinary synchronous function. """ + handler = cast(RequestHandler, self) args = { - "redirect_uri": redirect_uri, - "client_id": client_id, "response_type": response_type } + if redirect_uri is not None: + args["redirect_uri"] = redirect_uri + if client_id is not None: + args["client_id"] = client_id if extra_params: args.update(extra_params) if scope: args['scope'] = ' '.join(scope) - self.redirect( - url_concat(self._OAUTH_AUTHORIZE_URL, args)) - - def _oauth_request_token_url(self, redirect_uri=None, client_id=None, - client_secret=None, code=None, - extra_params=None): - url = self._OAUTH_ACCESS_TOKEN_URL - args = dict( - redirect_uri=redirect_uri, - code=code, - client_id=client_id, - client_secret=client_secret, - ) + url = self._OAUTH_AUTHORIZE_URL # type: ignore + handler.redirect(url_concat(url, args)) + + def _oauth_request_token_url(self, redirect_uri: str=None, client_id: str=None, + client_secret: str=None, code: str=None, + extra_params: Dict[str, Any]=None) -> str: + url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore + args = {} # type: Dict[str, str] + if redirect_uri is not None: + args["redirect_uri"] = redirect_uri + if code is not None: + args["code"] = code + if client_id is not None: + args["client_id"] = client_id + if client_secret is not None: + args["client_secret"] = client_secret if extra_params: args.update(extra_params) return url_concat(url, args) - async def oauth2_request(self, url, access_token=None, - post_args=None, **args): + async def oauth2_request(self, url: str, access_token: str=None, + post_args: Dict[str, Any]=None, **args: Any) -> Any: """Fetches the given URL auth an OAuth2 access token. If the request is a POST, ``post_args`` should be provided. Query @@ -569,7 +600,7 @@ class OAuth2Mixin(object): response = await http.fetch(url) return escape.json_decode(response.body) - def get_auth_http_client(self): + def get_auth_http_client(self) -> httpclient.AsyncHTTPClient: """Returns the `.AsyncHTTPClient` instance to be used for auth requests. May be overridden by subclasses to use an HTTP client other than @@ -619,7 +650,7 @@ class TwitterMixin(OAuthMixin): _OAUTH_NO_CALLBACKS = False _TWITTER_BASE_URL = "https://api.twitter.com/1.1" - async def authenticate_redirect(self, callback_uri=None): + async def authenticate_redirect(self, callback_uri: str=None) -> None: """Just like `~OAuthMixin.authorize_redirect`, but auto-redirects if authorized. @@ -639,7 +670,8 @@ class TwitterMixin(OAuthMixin): response = await http.fetch(self._oauth_request_token_url(callback_uri=callback_uri)) self._on_request_token(self._OAUTH_AUTHENTICATE_URL, None, response) - async def twitter_request(self, path, access_token=None, post_args=None, **args): + async def twitter_request(self, path: str, access_token: Dict[str, Any], + post_args: Dict[str, Any]=None, **args: Any) -> Any: """Fetches the given API path, e.g., ``statuses/user_timeline/btaylor`` The path should not include the format or API version number. @@ -705,14 +737,15 @@ class TwitterMixin(OAuthMixin): response = await http.fetch(url) return escape.json_decode(response.body) - def _oauth_consumer_token(self): - self.require_setting("twitter_consumer_key", "Twitter OAuth") - self.require_setting("twitter_consumer_secret", "Twitter OAuth") + def _oauth_consumer_token(self) -> Dict[str, Any]: + handler = cast(RequestHandler, self) + handler.require_setting("twitter_consumer_key", "Twitter OAuth") + handler.require_setting("twitter_consumer_secret", "Twitter OAuth") return dict( - key=self.settings["twitter_consumer_key"], - secret=self.settings["twitter_consumer_secret"]) + key=handler.settings["twitter_consumer_key"], + secret=handler.settings["twitter_consumer_secret"]) - async def _oauth_get_user_future(self, access_token): + async def _oauth_get_user_future(self, access_token: Dict[str, Any]) -> Dict[str, Any]: user = await self.twitter_request( "/account/verify_credentials", access_token=access_token) @@ -745,7 +778,7 @@ class GoogleOAuth2Mixin(OAuth2Mixin): _OAUTH_NO_CALLBACKS = False _OAUTH_SETTINGS_KEY = 'google_oauth' - async def get_authenticated_user(self, redirect_uri, code): + async def get_authenticated_user(self, redirect_uri: str, code: str) -> Dict[str, Any]: """Handles the login for the Google user, returning an access token. The result is a dictionary containing an ``access_token`` field @@ -787,12 +820,13 @@ class GoogleOAuth2Mixin(OAuth2Mixin): The ``callback`` argument was removed. Use the returned awaitable object instead. """ # noqa: E501 + handler = cast(RequestHandler, self) http = self.get_auth_http_client() body = urllib.parse.urlencode({ "redirect_uri": redirect_uri, "code": code, - "client_id": self.settings[self._OAUTH_SETTINGS_KEY]['key'], - "client_secret": self.settings[self._OAUTH_SETTINGS_KEY]['secret'], + "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]['key'], + "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]['secret'], "grant_type": "authorization_code", }) @@ -810,8 +844,10 @@ class FacebookGraphMixin(OAuth2Mixin): _OAUTH_NO_CALLBACKS = False _FACEBOOK_BASE_URL = "https://graph.facebook.com" - async def get_authenticated_user(self, redirect_uri, client_id, client_secret, - code, extra_fields=None): + async def get_authenticated_user( + self, redirect_uri: str, client_id: str, client_secret: str, + code: str, extra_fields: Dict[str, Any]=None + ) -> Optional[Dict[str, Any]]: """Handles the login for the Facebook user, returning a user object. Example usage: @@ -870,12 +906,13 @@ class FacebookGraphMixin(OAuth2Mixin): if extra_fields: fields.update(extra_fields) - response = await http.fetch(self._oauth_request_token_url(**args)) + response = await http.fetch(self._oauth_request_token_url(**args)) # type: ignore args = escape.json_decode(response.body) session = { "access_token": args.get("access_token"), "expires_in": args.get("expires_in") } + assert session["access_token"] is not None user = await self.facebook_request( path="/me", @@ -901,7 +938,8 @@ class FacebookGraphMixin(OAuth2Mixin): "session_expires": str(session.get("expires_in"))}) return fieldmap - async def facebook_request(self, path, access_token=None, post_args=None, **args): + async def facebook_request(self, path: str, access_token: str=None, + post_args: Dict[str, Any]=None, **args: Any) -> Any: """Fetches the given relative API path, e.g., "/btaylor/picture" If the request is a POST, ``post_args`` should be provided. Query @@ -957,7 +995,8 @@ class FacebookGraphMixin(OAuth2Mixin): post_args=post_args, **args) -def _oauth_signature(consumer_token, method, url, parameters={}, token=None): +def _oauth_signature(consumer_token: Dict[str, Any], method: str, url: str, + parameters: Dict[str, Any]={}, token: Dict[str, Any]=None) -> bytes: """Calculates the HMAC-SHA1 OAuth signature for the given request. See http://oauth.net/core/1.0/#signing_process @@ -981,7 +1020,8 @@ def _oauth_signature(consumer_token, method, url, parameters={}, token=None): return binascii.b2a_base64(hash.digest())[:-1] -def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None): +def _oauth10a_signature(consumer_token: Dict[str, Any], method: str, url: str, + parameters: Dict[str, Any]={}, token: Dict[str, Any]=None) -> bytes: """Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request. See http://oauth.net/core/1.0a/#signing_process @@ -1005,18 +1045,18 @@ def _oauth10a_signature(consumer_token, method, url, parameters={}, token=None): return binascii.b2a_base64(hash.digest())[:-1] -def _oauth_escape(val): +def _oauth_escape(val: Union[str, bytes]) -> str: if isinstance(val, unicode_type): val = val.encode("utf-8") return urllib.parse.quote(val, safe="~") -def _oauth_parse_response(body): +def _oauth_parse_response(body: bytes) -> Dict[str, Any]: # I can't find an officially-defined encoding for oauth responses and # have never seen anyone use non-ascii. Leave the response in a byte # string for python 2, and use utf8 on python 3. - body = escape.native_str(body) - p = urllib.parse.parse_qs(body, keep_blank_values=False) + body_str = escape.native_str(body) + p = urllib.parse.parse_qs(body_str, keep_blank_values=False) token = dict(key=p["oauth_token"][0], secret=p["oauth_token_secret"][0]) # Add the extra parameters the Provider included to the token diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 1a47ae9e2..0c2572dbb 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -610,9 +610,9 @@ class HTTPResponse(object): self.time_info = time_info or {} @property - def body(self) -> Optional[bytes]: + def body(self) -> bytes: if self.buffer is None: - return None + raise ValueError("body not set") elif self._body is None: self._body = self.buffer.getvalue() diff --git a/tornado/web.py b/tornado/web.py index f54f6f187..4d797eb95 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -100,7 +100,7 @@ from tornado.util import ObjectDict, unicode_type, _websocket_mask url = URLSpec from typing import (Dict, Any, Union, Optional, Awaitable, Tuple, List, Callable, Iterable, - Generator, Type, cast) + Generator, Type, cast, overload) from types import TracebackType import typing if typing.TYPE_CHECKING: @@ -396,7 +396,21 @@ class RequestHandler(object): raise ValueError("Unsafe header value %r", retval) return retval - def get_argument(self, name: str, default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT, + @overload + def get_argument(self, name: str, default: str, strip: bool=True) -> str: + pass + + @overload # noqa: F811 + def get_argument(self, name: str, default: _ArgDefaultMarker=_ARG_DEFAULT, + strip: bool=True) -> str: + pass + + @overload # noqa: F811 + def get_argument(self, name: str, default: None, strip: bool=True) -> Optional[str]: + pass + + def get_argument(self, name: str, # noqa: F811 + default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT, strip: bool=True) -> Optional[str]: """Returns the value of the argument with the given name.