From: Ben Darnell Date: Sun, 7 Oct 2018 03:03:26 +0000 (-0400) Subject: *: Adopt black as code formatter X-Git-Tag: v6.0.0b1~28^2~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=72026c5eb8ff13d6bf671fa87ea89847e6485566;p=thirdparty%2Ftornado.git *: Adopt black as code formatter It occasionally makes some odd-looking decisions and uses a lot of vertical space but overall it's a big improvement, especially for the dense type signatures. --- diff --git a/.flake8 b/.flake8 index 1c2c768d1..18c72168c 100644 --- a/.flake8 +++ b/.flake8 @@ -10,4 +10,8 @@ ignore = E402, # E722 do not use bare except E722, + # flake8 and black disagree about + # W503 line break before binary operator + # E203 whitespace before ':' + W503,E203 doctests = true diff --git a/tornado/auth.py b/tornado/auth.py index 44ea5fbcf..115c00923 100644 --- a/tornado/auth.py +++ b/tornado/auth.py @@ -84,9 +84,12 @@ class OpenIdMixin(object): * ``_OPENID_ENDPOINT``: the identity provider's URI. """ + def authenticate_redirect( - self, callback_uri: str=None, - ax_attrs: List[str]=["name", "email", "language", "username"]) -> None: + self, + callback_uri: str = None, + ax_attrs: List[str] = ["name", "email", "language", "username"], + ) -> None: """Redirects to the authentication URL for this service. After authentication, the service will redirect back to the given @@ -111,7 +114,7 @@ class OpenIdMixin(object): handler.redirect(endpoint + "?" + urllib.parse.urlencode(args)) async def get_authenticated_user( - self, http_client: httpclient.AsyncHTTPClient=None + self, http_client: httpclient.AsyncHTTPClient = None ) -> Dict[str, Any]: """Fetches the authenticated user data upon redirect. @@ -130,47 +133,50 @@ class OpenIdMixin(object): """ handler = cast(RequestHandler, self) # Verify the OpenID response via direct request to the OP - args = dict((k, v[-1]) for k, v in handler.request.arguments.items()) \ - # type: Dict[str, Union[str, bytes]] + args = dict( + (k, v[-1]) for k, v in handler.request.arguments.items() + ) # type: Dict[str, Union[str, bytes]] args["openid.mode"] = u"check_authentication" url = self._OPENID_ENDPOINT # type: ignore if http_client is None: http_client = self.get_auth_http_client() - resp = await http_client.fetch(url, method="POST", body=urllib.parse.urlencode(args)) + resp = await http_client.fetch( + url, method="POST", body=urllib.parse.urlencode(args) + ) return self._on_authentication_verified(resp) - def _openid_args(self, callback_uri: str, ax_attrs: Iterable[str]=[], - oauth_scope: str=None) -> Dict[str, str]: + def _openid_args( + self, callback_uri: str, ax_attrs: Iterable[str] = [], oauth_scope: str = None + ) -> Dict[str, str]: handler = cast(RequestHandler, self) url = urllib.parse.urljoin(handler.request.full_url(), callback_uri) args = { "openid.ns": "http://specs.openid.net/auth/2.0", - "openid.claimed_id": - "http://specs.openid.net/auth/2.0/identifier_select", - "openid.identity": - "http://specs.openid.net/auth/2.0/identifier_select", + "openid.claimed_id": "http://specs.openid.net/auth/2.0/identifier_select", + "openid.identity": "http://specs.openid.net/auth/2.0/identifier_select", "openid.return_to": url, - "openid.realm": urllib.parse.urljoin(url, '/'), + "openid.realm": urllib.parse.urljoin(url, "/"), "openid.mode": "checkid_setup", } if ax_attrs: - args.update({ - "openid.ns.ax": "http://openid.net/srv/ax/1.0", - "openid.ax.mode": "fetch_request", - }) + args.update( + { + "openid.ns.ax": "http://openid.net/srv/ax/1.0", + "openid.ax.mode": "fetch_request", + } + ) ax_attrs = set(ax_attrs) required = [] # type: List[str] if "name" in ax_attrs: ax_attrs -= set(["name", "firstname", "fullname", "lastname"]) required += ["firstname", "fullname", "lastname"] - args.update({ - "openid.ax.type.firstname": - "http://axschema.org/namePerson/first", - "openid.ax.type.fullname": - "http://axschema.org/namePerson", - "openid.ax.type.lastname": - "http://axschema.org/namePerson/last", - }) + args.update( + { + "openid.ax.type.firstname": "http://axschema.org/namePerson/first", + "openid.ax.type.fullname": "http://axschema.org/namePerson", + "openid.ax.type.lastname": "http://axschema.org/namePerson/last", + } + ) known_attrs = { "email": "http://axschema.org/contact/email", "language": "http://axschema.org/pref/language", @@ -181,15 +187,18 @@ class OpenIdMixin(object): required.append(name) args["openid.ax.required"] = ",".join(required) if oauth_scope: - args.update({ - "openid.ns.oauth": - "http://specs.openid.net/extensions/oauth/1.0", - "openid.oauth.consumer": handler.request.host.split(":")[0], - "openid.oauth.scope": oauth_scope, - }) + args.update( + { + "openid.ns.oauth": "http://specs.openid.net/extensions/oauth/1.0", + "openid.oauth.consumer": handler.request.host.split(":")[0], + "openid.oauth.scope": oauth_scope, + } + ) return args - def _on_authentication_verified(self, response: httpclient.HTTPResponse) -> Dict[str, Any]: + def _on_authentication_verified( + self, response: httpclient.HTTPResponse + ) -> Dict[str, Any]: handler = cast(RequestHandler, self) if b"is_valid:true" not in response.body: raise AuthError("Invalid OpenID response: %s" % response.body) @@ -197,8 +206,10 @@ class OpenIdMixin(object): # Make sure we got back at least an email from attribute exchange ax_ns = None for key in handler.request.arguments: - if key.startswith("openid.ns.") and \ - handler.get_argument(key) == u"http://openid.net/srv/ax/1.0": + if ( + key.startswith("openid.ns.") + and handler.get_argument(key) == u"http://openid.net/srv/ax/1.0" + ): ax_ns = key[10:] break @@ -209,7 +220,7 @@ class OpenIdMixin(object): ax_name = None for name in handler.request.arguments.keys(): if handler.get_argument(name) == uri and name.startswith(prefix): - part = name[len(prefix):] + part = name[len(prefix) :] ax_name = "openid." + ax_ns + ".value." + part break if not ax_name: @@ -272,8 +283,13 @@ class OAuthMixin(object): Subclasses must also override the `_oauth_get_user_future` and `_oauth_consumer_token` methods. """ - async def authorize_redirect(self, callback_uri: str=None, extra_params: Dict[str, Any]=None, - http_client: httpclient.AsyncHTTPClient=None) -> None: + + async def authorize_redirect( + self, + callback_uri: str = None, + extra_params: Dict[str, Any] = None, + http_client: httpclient.AsyncHTTPClient = None, + ) -> None: """Redirects the user to obtain OAuth authorization for this service. The ``callback_uri`` may be omitted if you have previously @@ -308,15 +324,17 @@ class OAuthMixin(object): assert http_client is not None if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": response = await http_client.fetch( - self._oauth_request_token_url(callback_uri=callback_uri, - extra_params=extra_params)) + self._oauth_request_token_url( + callback_uri=callback_uri, extra_params=extra_params + ) + ) else: response = await http_client.fetch(self._oauth_request_token_url()) url = self._OAUTH_AUTHORIZE_URL # type: ignore self._on_request_token(url, callback_uri, response) async def get_authenticated_user( - self, http_client: httpclient.AsyncHTTPClient=None + self, http_client: httpclient.AsyncHTTPClient = None ) -> Dict[str, Any]: """Gets the OAuth authorized user and access token. @@ -341,10 +359,13 @@ class OAuthMixin(object): raise AuthError("Missing OAuth request token cookie") handler.clear_cookie("_oauth_request_token") cookie_key, cookie_secret = [ - base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")] + base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|") + ] if cookie_key != request_key: raise AuthError("Request token does not match cookie") - token = dict(key=cookie_key, secret=cookie_secret) # type: Dict[str, Union[str, bytes]] + token = dict( + key=cookie_key, secret=cookie_secret + ) # type: Dict[str, Union[str, bytes]] if oauth_verifier: token["verifier"] = oauth_verifier if http_client is None: @@ -358,8 +379,9 @@ class OAuthMixin(object): user["access_token"] = access_token return user - def _oauth_request_token_url(self, callback_uri: str=None, - extra_params: Dict[str, Any]=None) -> str: + def _oauth_request_token_url( + self, callback_uri: str = None, extra_params: Dict[str, Any] = None + ) -> str: handler = cast(RequestHandler, self) consumer_token = self._oauth_consumer_token() url = self._OAUTH_REQUEST_TOKEN_URL # type: ignore @@ -375,7 +397,8 @@ class OAuthMixin(object): args["oauth_callback"] = "oob" elif callback_uri: args["oauth_callback"] = urllib.parse.urljoin( - handler.request.full_url(), callback_uri) + handler.request.full_url(), callback_uri + ) if extra_params: args.update(extra_params) signature = _oauth10a_signature(consumer_token, "GET", url, args) @@ -385,12 +408,19 @@ class OAuthMixin(object): args["oauth_signature"] = signature return url + "?" + urllib.parse.urlencode(args) - def _on_request_token(self, authorize_url: str, callback_uri: Optional[str], - response: httpclient.HTTPResponse) -> None: + def _on_request_token( + self, + authorize_url: str, + callback_uri: Optional[str], + response: httpclient.HTTPResponse, + ) -> None: handler = cast(RequestHandler, self) request_token = _oauth_parse_response(response.body) - data = (base64.b64encode(escape.utf8(request_token["key"])) + b"|" + - base64.b64encode(escape.utf8(request_token["secret"]))) + data = ( + base64.b64encode(escape.utf8(request_token["key"])) + + b"|" + + base64.b64encode(escape.utf8(request_token["secret"])) + ) handler.set_cookie("_oauth_request_token", data) args = dict(oauth_token=request_token["key"]) if callback_uri == "oob": @@ -398,7 +428,8 @@ class OAuthMixin(object): return elif callback_uri: args["oauth_callback"] = urllib.parse.urljoin( - handler.request.full_url(), callback_uri) + handler.request.full_url(), callback_uri + ) handler.redirect(authorize_url + "?" + urllib.parse.urlencode(args)) def _oauth_access_token_url(self, request_token: Dict[str, Any]) -> str: @@ -416,11 +447,13 @@ class OAuthMixin(object): args["oauth_verifier"] = request_token["verifier"] if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": - signature = _oauth10a_signature(consumer_token, "GET", url, args, - request_token) + signature = _oauth10a_signature( + consumer_token, "GET", url, args, request_token + ) else: - signature = _oauth_signature(consumer_token, "GET", url, args, - request_token) + signature = _oauth_signature( + consumer_token, "GET", url, args, request_token + ) args["oauth_signature"] = signature return url + "?" + urllib.parse.urlencode(args) @@ -432,7 +465,9 @@ class OAuthMixin(object): """ raise NotImplementedError() - async def _oauth_get_user_future(self, access_token: Dict[str, Any]) -> Dict[str, Any]: + async def _oauth_get_user_future( + self, access_token: Dict[str, Any] + ) -> Dict[str, Any]: """Subclasses must override this to get basic information about the user. @@ -454,9 +489,13 @@ class OAuthMixin(object): """ raise NotImplementedError() - def _oauth_request_parameters(self, url: str, access_token: Dict[str, Any], - parameters: Dict[str, Any]={}, - method: str="GET") -> Dict[str, Any]: + def _oauth_request_parameters( + self, + url: str, + access_token: Dict[str, Any], + parameters: Dict[str, Any] = {}, + method: str = "GET", + ) -> Dict[str, Any]: """Returns the OAuth parameters as a dict for the given request. parameters should include all POST arguments and query string arguments @@ -475,11 +514,13 @@ class OAuthMixin(object): args.update(base_args) args.update(parameters) if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a": - signature = _oauth10a_signature(consumer_token, method, url, args, - access_token) + signature = _oauth10a_signature( + consumer_token, method, url, args, access_token + ) else: - signature = _oauth_signature(consumer_token, method, url, args, - access_token) + signature = _oauth_signature( + consumer_token, method, url, args, access_token + ) base_args["oauth_signature"] = escape.to_basestring(signature) return base_args @@ -503,9 +544,16 @@ class OAuth2Mixin(object): * ``_OAUTH_AUTHORIZE_URL``: The service's authorization url. * ``_OAUTH_ACCESS_TOKEN_URL``: The service's access token url. """ - def authorize_redirect(self, redirect_uri: str=None, client_id: str=None, - client_secret: str=None, extra_params: Dict[str, Any]=None, - scope: str=None, response_type: str="code") -> None: + + def authorize_redirect( + self, + redirect_uri: str = None, + client_id: str = None, + client_secret: str = None, + extra_params: Dict[str, Any] = None, + scope: str = None, + response_type: str = "code", + ) -> None: """Redirects the user to obtain OAuth authorization for this service. Some providers require that you register a redirect URL with @@ -520,9 +568,7 @@ class OAuth2Mixin(object): this is now an ordinary synchronous function. """ handler = cast(RequestHandler, self) - args = { - "response_type": response_type - } + args = {"response_type": response_type} if redirect_uri is not None: args["redirect_uri"] = redirect_uri if client_id is not None: @@ -530,13 +576,18 @@ class OAuth2Mixin(object): if extra_params: args.update(extra_params) if scope: - args['scope'] = ' '.join(scope) + args["scope"] = " ".join(scope) url = self._OAUTH_AUTHORIZE_URL # type: ignore handler.redirect(url_concat(url, args)) - def _oauth_request_token_url(self, redirect_uri: str=None, client_id: str=None, - client_secret: str=None, code: str=None, - extra_params: Dict[str, Any]=None) -> str: + def _oauth_request_token_url( + self, + redirect_uri: str = None, + client_id: str = None, + client_secret: str = None, + code: str = None, + extra_params: Dict[str, Any] = None, + ) -> str: url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore args = {} # type: Dict[str, str] if redirect_uri is not None: @@ -551,8 +602,13 @@ class OAuth2Mixin(object): args.update(extra_params) return url_concat(url, args) - async def oauth2_request(self, url: str, access_token: str=None, - post_args: Dict[str, Any]=None, **args: Any) -> Any: + async def oauth2_request( + self, + url: str, + access_token: str = None, + post_args: Dict[str, Any] = None, + **args: Any + ) -> Any: """Fetches the given URL auth an OAuth2 access token. If the request is a POST, ``post_args`` should be provided. Query @@ -595,7 +651,9 @@ class OAuth2Mixin(object): url += "?" + urllib.parse.urlencode(all_args) http = self.get_auth_http_client() if post_args is not None: - response = await http.fetch(url, method="POST", body=urllib.parse.urlencode(post_args)) + response = await http.fetch( + url, method="POST", body=urllib.parse.urlencode(post_args) + ) else: response = await http.fetch(url) return escape.json_decode(response.body) @@ -643,6 +701,7 @@ class TwitterMixin(OAuthMixin): and all of the custom Twitter user attributes described at https://dev.twitter.com/docs/api/1.1/get/users/show """ + _OAUTH_REQUEST_TOKEN_URL = "https://api.twitter.com/oauth/request_token" _OAUTH_ACCESS_TOKEN_URL = "https://api.twitter.com/oauth/access_token" _OAUTH_AUTHORIZE_URL = "https://api.twitter.com/oauth/authorize" @@ -650,7 +709,7 @@ class TwitterMixin(OAuthMixin): _OAUTH_NO_CALLBACKS = False _TWITTER_BASE_URL = "https://api.twitter.com/1.1" - async def authenticate_redirect(self, callback_uri: str=None) -> None: + async def authenticate_redirect(self, callback_uri: str = None) -> None: """Just like `~OAuthMixin.authorize_redirect`, but auto-redirects if authorized. @@ -667,11 +726,18 @@ class TwitterMixin(OAuthMixin): awaitable object instead. """ http = self.get_auth_http_client() - response = await http.fetch(self._oauth_request_token_url(callback_uri=callback_uri)) + response = await http.fetch( + self._oauth_request_token_url(callback_uri=callback_uri) + ) self._on_request_token(self._OAUTH_AUTHENTICATE_URL, None, response) - async def twitter_request(self, path: str, access_token: Dict[str, Any], - post_args: Dict[str, Any]=None, **args: Any) -> Any: + async def twitter_request( + self, + path: str, + access_token: Dict[str, Any], + post_args: Dict[str, Any] = None, + **args: Any + ) -> Any: """Fetches the given API path, e.g., ``statuses/user_timeline/btaylor`` The path should not include the format or API version number. @@ -713,7 +779,7 @@ class TwitterMixin(OAuthMixin): The ``callback`` argument was removed. Use the returned awaitable object instead. """ - if path.startswith('http:') or path.startswith('https:'): + if path.startswith("http:") or path.startswith("https:"): # Raw urls are useful for e.g. search which doesn't follow the # usual pattern: http://search.twitter.com/search.json url = path @@ -726,13 +792,16 @@ class TwitterMixin(OAuthMixin): all_args.update(post_args or {}) method = "POST" if post_args is not None else "GET" oauth = self._oauth_request_parameters( - url, access_token, all_args, method=method) + url, access_token, all_args, method=method + ) args.update(oauth) if args: url += "?" + urllib.parse.urlencode(args) http = self.get_auth_http_client() if post_args is not None: - response = await http.fetch(url, method="POST", body=urllib.parse.urlencode(post_args)) + response = await http.fetch( + url, method="POST", body=urllib.parse.urlencode(post_args) + ) else: response = await http.fetch(url) return escape.json_decode(response.body) @@ -743,12 +812,15 @@ class TwitterMixin(OAuthMixin): handler.require_setting("twitter_consumer_secret", "Twitter OAuth") return dict( key=handler.settings["twitter_consumer_key"], - secret=handler.settings["twitter_consumer_secret"]) + secret=handler.settings["twitter_consumer_secret"], + ) - async def _oauth_get_user_future(self, access_token: Dict[str, Any]) -> Dict[str, Any]: + async def _oauth_get_user_future( + self, access_token: Dict[str, Any] + ) -> Dict[str, Any]: user = await self.twitter_request( - "/account/verify_credentials", - access_token=access_token) + "/account/verify_credentials", access_token=access_token + ) if user: user["username"] = user["screen_name"] return user @@ -772,13 +844,16 @@ class GoogleOAuth2Mixin(OAuth2Mixin): .. versionadded:: 3.2 """ + _OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth" _OAUTH_ACCESS_TOKEN_URL = "https://www.googleapis.com/oauth2/v4/token" _OAUTH_USERINFO_URL = "https://www.googleapis.com/oauth2/v1/userinfo" _OAUTH_NO_CALLBACKS = False - _OAUTH_SETTINGS_KEY = 'google_oauth' + _OAUTH_SETTINGS_KEY = "google_oauth" - async def get_authenticated_user(self, redirect_uri: str, code: str) -> Dict[str, Any]: + async def get_authenticated_user( + self, redirect_uri: str, code: str + ) -> Dict[str, Any]: """Handles the login for the Google user, returning an access token. The result is a dictionary containing an ``access_token`` field @@ -822,31 +897,40 @@ class GoogleOAuth2Mixin(OAuth2Mixin): """ # noqa: E501 handler = cast(RequestHandler, self) http = self.get_auth_http_client() - body = urllib.parse.urlencode({ - "redirect_uri": redirect_uri, - "code": code, - "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]['key'], - "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]['secret'], - "grant_type": "authorization_code", - }) - - response = await http.fetch(self._OAUTH_ACCESS_TOKEN_URL, - method="POST", - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - body=body) + body = urllib.parse.urlencode( + { + "redirect_uri": redirect_uri, + "code": code, + "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]["key"], + "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]["secret"], + "grant_type": "authorization_code", + } + ) + + response = await http.fetch( + self._OAUTH_ACCESS_TOKEN_URL, + method="POST", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + body=body, + ) return escape.json_decode(response.body) class FacebookGraphMixin(OAuth2Mixin): """Facebook authentication using the new Graph API and OAuth2.""" + _OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?" _OAUTH_AUTHORIZE_URL = "https://www.facebook.com/dialog/oauth?" _OAUTH_NO_CALLBACKS = False _FACEBOOK_BASE_URL = "https://graph.facebook.com" async def get_authenticated_user( - self, redirect_uri: str, client_id: str, client_secret: str, - code: str, extra_fields: Dict[str, Any]=None + self, + redirect_uri: str, + client_id: str, + client_secret: str, + code: str, + extra_fields: Dict[str, Any] = None, ) -> Optional[Dict[str, Any]]: """Handles the login for the Facebook user, returning a user object. @@ -901,26 +985,31 @@ class FacebookGraphMixin(OAuth2Mixin): "client_secret": client_secret, } - fields = set(['id', 'name', 'first_name', 'last_name', - 'locale', 'picture', 'link']) + fields = set( + ["id", "name", "first_name", "last_name", "locale", "picture", "link"] + ) if extra_fields: fields.update(extra_fields) - response = await http.fetch(self._oauth_request_token_url(**args)) # type: ignore + response = await http.fetch( + self._oauth_request_token_url(**args) # type: ignore + ) args = escape.json_decode(response.body) session = { "access_token": args.get("access_token"), - "expires_in": args.get("expires_in") + "expires_in": args.get("expires_in"), } assert session["access_token"] is not None user = await self.facebook_request( path="/me", access_token=session["access_token"], - appsecret_proof=hmac.new(key=client_secret.encode('utf8'), - msg=session["access_token"].encode('utf8'), - digestmod=hashlib.sha256).hexdigest(), - fields=",".join(fields) + appsecret_proof=hmac.new( + key=client_secret.encode("utf8"), + msg=session["access_token"].encode("utf8"), + digestmod=hashlib.sha256, + ).hexdigest(), + fields=",".join(fields), ) if user is None: @@ -934,12 +1023,21 @@ class FacebookGraphMixin(OAuth2Mixin): # older versions in which the server used url-encoding and # this code simply returned the string verbatim. # This should change in Tornado 5.0. - fieldmap.update({"access_token": session["access_token"], - "session_expires": str(session.get("expires_in"))}) + fieldmap.update( + { + "access_token": session["access_token"], + "session_expires": str(session.get("expires_in")), + } + ) return fieldmap - async def facebook_request(self, path: str, access_token: str=None, - post_args: Dict[str, Any]=None, **args: Any) -> Any: + async def facebook_request( + self, + path: str, + access_token: str = None, + post_args: Dict[str, Any] = None, + **args: Any + ) -> Any: """Fetches the given relative API path, e.g., "/btaylor/picture" If the request is a POST, ``post_args`` should be provided. Query @@ -991,12 +1089,18 @@ class FacebookGraphMixin(OAuth2Mixin): The ``callback`` argument was removed. Use the returned awaitable object instead. """ url = self._FACEBOOK_BASE_URL + path - return await self.oauth2_request(url, access_token=access_token, - post_args=post_args, **args) + return await self.oauth2_request( + url, access_token=access_token, post_args=post_args, **args + ) -def _oauth_signature(consumer_token: Dict[str, Any], method: str, url: str, - parameters: Dict[str, Any]={}, token: Dict[str, Any]=None) -> bytes: +def _oauth_signature( + consumer_token: Dict[str, Any], + method: str, + url: str, + parameters: Dict[str, Any] = {}, + token: Dict[str, Any] = None, +) -> bytes: """Calculates the HMAC-SHA1 OAuth signature for the given request. See http://oauth.net/core/1.0/#signing_process @@ -1008,8 +1112,11 @@ def _oauth_signature(consumer_token: Dict[str, Any], method: str, url: str, base_elems = [] base_elems.append(method.upper()) base_elems.append(normalized_url) - base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v))) - for k, v in sorted(parameters.items()))) + base_elems.append( + "&".join( + "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items()) + ) + ) base_string = "&".join(_oauth_escape(e) for e in base_elems) key_elems = [escape.utf8(consumer_token["secret"])] @@ -1020,8 +1127,13 @@ def _oauth_signature(consumer_token: Dict[str, Any], method: str, url: str, return binascii.b2a_base64(hash.digest())[:-1] -def _oauth10a_signature(consumer_token: Dict[str, Any], method: str, url: str, - parameters: Dict[str, Any]={}, token: Dict[str, Any]=None) -> bytes: +def _oauth10a_signature( + consumer_token: Dict[str, Any], + method: str, + url: str, + parameters: Dict[str, Any] = {}, + token: Dict[str, Any] = None, +) -> bytes: """Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request. See http://oauth.net/core/1.0a/#signing_process @@ -1033,12 +1145,17 @@ def _oauth10a_signature(consumer_token: Dict[str, Any], method: str, url: str, base_elems = [] base_elems.append(method.upper()) base_elems.append(normalized_url) - base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v))) - for k, v in sorted(parameters.items()))) + base_elems.append( + "&".join( + "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items()) + ) + ) base_string = "&".join(_oauth_escape(e) for e in base_elems) - key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe='~'))] - key_elems.append(escape.utf8(urllib.parse.quote(token["secret"], safe='~') if token else "")) + key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe="~"))] + key_elems.append( + escape.utf8(urllib.parse.quote(token["secret"], safe="~") if token else "") + ) key = b"&".join(key_elems) hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1) diff --git a/tornado/autoreload.py b/tornado/autoreload.py index 339c60c55..1c47aaec1 100644 --- a/tornado/autoreload.py +++ b/tornado/autoreload.py @@ -97,13 +97,14 @@ except ImportError: import typing from typing import Callable, Dict + if typing.TYPE_CHECKING: from typing import List, Optional, Union # noqa: F401 # os.execv is broken on Windows and can't properly parse command line # arguments and executable name if they contain whitespaces. subprocess # fixes that behavior. -_has_execv = sys.platform != 'win32' +_has_execv = sys.platform != "win32" _watched_files = set() _reload_hooks = [] @@ -114,7 +115,7 @@ _original_argv = None # type: Optional[List[str]] _original_spec = None -def start(check_time: int=500) -> None: +def start(check_time: int = 500) -> None: """Begins watching source files for changes. .. versionchanged:: 5.0 @@ -224,16 +225,16 @@ def _reload() -> None: spec = _original_spec argv = _original_argv else: - spec = getattr(sys.modules['__main__'], '__spec__', None) + spec = getattr(sys.modules["__main__"], "__spec__", None) argv = sys.argv if spec: - argv = ['-m', spec.name] + argv[1:] + argv = ["-m", spec.name] + argv[1:] else: - path_prefix = '.' + os.pathsep - if (sys.path[0] == '' and - not os.environ.get("PYTHONPATH", "").startswith(path_prefix)): - os.environ["PYTHONPATH"] = (path_prefix + - os.environ.get("PYTHONPATH", "")) + path_prefix = "." + os.pathsep + if sys.path[0] == "" and not os.environ.get("PYTHONPATH", "").startswith( + path_prefix + ): + os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "") if not _has_execv: subprocess.Popen([sys.executable] + argv) os._exit(0) @@ -252,7 +253,9 @@ def _reload() -> None: # Unfortunately the errno returned in this case does not # appear to be consistent, so we can't easily check for # this error specifically. - os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + argv) # type: ignore + os.spawnv( # type: ignore + os.P_NOWAIT, sys.executable, [sys.executable] + argv + ) # At this point the IOLoop has been closed and finally # blocks will experience errors if we allow the stack to # unwind, so just exit uncleanly. @@ -283,12 +286,13 @@ def main() -> None: # The main module can be tricky; set the variables both in our globals # (which may be __main__) and the real importable version. import tornado.autoreload + global _autoreload_is_main global _original_argv, _original_spec tornado.autoreload._autoreload_is_main = _autoreload_is_main = True original_argv = sys.argv tornado.autoreload._original_argv = _original_argv = original_argv - original_spec = getattr(sys.modules['__main__'], '__spec__', None) + original_spec = getattr(sys.modules["__main__"], "__spec__", None) tornado.autoreload._original_spec = _original_spec = original_spec sys.argv = sys.argv[:] if len(sys.argv) >= 3 and sys.argv[1] == "-m": @@ -306,6 +310,7 @@ def main() -> None: try: if mode == "module": import runpy + runpy.run_module(module, run_name="__main__", alter_sys=True) elif mode == "script": with open(script) as f: @@ -343,7 +348,7 @@ def main() -> None: # restore sys.argv so subsequent executions will include autoreload sys.argv = original_argv - if mode == 'module': + if mode == "module": # runpy did a fake import of the module as __main__, but now it's # no longer in sys.modules. Figure out where it is and watch it. loader = pkgutil.get_loader(module) diff --git a/tornado/concurrent.py b/tornado/concurrent.py index d0dad7a34..ee5ca40a0 100644 --- a/tornado/concurrent.py +++ b/tornado/concurrent.py @@ -36,7 +36,7 @@ import types import typing from typing import Any, Callable, Optional, Tuple, Union -_T = typing.TypeVar('_T') +_T = typing.TypeVar("_T") class ReturnValueIgnoredError(Exception): @@ -54,7 +54,9 @@ def is_future(x: Any) -> bool: class DummyExecutor(futures.Executor): - def submit(self, fn: Callable[..., _T], *args: Any, **kwargs: Any) -> 'futures.Future[_T]': + def submit( + self, fn: Callable[..., _T], *args: Any, **kwargs: Any + ) -> "futures.Future[_T]": future = futures.Future() # type: futures.Future[_T] try: future_set_result_unless_cancelled(future, fn(*args, **kwargs)) @@ -62,7 +64,7 @@ class DummyExecutor(futures.Executor): future_set_exc_info(future, sys.exc_info()) return future - def shutdown(self, wait: bool=True) -> None: + def shutdown(self, wait: bool = True) -> None: pass @@ -121,7 +123,9 @@ def run_on_executor(*args: Any, **kwargs: Any) -> Callable: conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs) chain_future(conc_future, async_future) return async_future + return wrapper + if args and kwargs: raise ValueError("cannot combine positional and keyword args") if len(args) == 1: @@ -134,7 +138,7 @@ def run_on_executor(*args: Any, **kwargs: Any) -> Callable: _NO_RESULT = object() -def chain_future(a: 'Future[_T]', b: 'Future[_T]') -> None: +def chain_future(a: "Future[_T]", b: "Future[_T]") -> None: """Chain two futures together so that when one completes, so does the other. The result (success or failure) of ``a`` will be copied to ``b``, unless @@ -146,27 +150,30 @@ def chain_future(a: 'Future[_T]', b: 'Future[_T]') -> None: `concurrent.futures.Future`. """ - def copy(future: 'Future[_T]') -> None: + + def copy(future: "Future[_T]") -> None: assert future is a if b.done(): return - if (hasattr(a, 'exc_info') and - a.exc_info() is not None): # type: ignore + if hasattr(a, "exc_info") and a.exc_info() is not None: # type: ignore future_set_exc_info(b, a.exc_info()) # type: ignore elif a.exception() is not None: b.set_exception(a.exception()) else: b.set_result(a.result()) + if isinstance(a, Future): future_add_done_callback(a, copy) else: # concurrent.futures.Future from tornado.ioloop import IOLoop + IOLoop.current().add_future(a, copy) -def future_set_result_unless_cancelled(future: Union['futures.Future[_T]', 'Future[_T]'], - value: _T) -> None: +def future_set_result_unless_cancelled( + future: Union["futures.Future[_T]", "Future[_T]"], value: _T +) -> None: """Set the given ``value`` as the `Future`'s result, if not cancelled. Avoids asyncio.InvalidStateError when calling set_result() on @@ -178,9 +185,12 @@ def future_set_result_unless_cancelled(future: Union['futures.Future[_T]', 'Futu future.set_result(value) -def future_set_exc_info(future: Union['futures.Future[_T]', 'Future[_T]'], - exc_info: Tuple[Optional[type], Optional[BaseException], - Optional[types.TracebackType]]) -> None: +def future_set_exc_info( + future: Union["futures.Future[_T]", "Future[_T]"], + exc_info: Tuple[ + Optional[type], Optional[BaseException], Optional[types.TracebackType] + ], +) -> None: """Set the given ``exc_info`` as the `Future`'s exception. Understands both `asyncio.Future` and Tornado's extensions to @@ -188,7 +198,7 @@ def future_set_exc_info(future: Union['futures.Future[_T]', 'Future[_T]'], .. versionadded:: 5.0 """ - if hasattr(future, 'set_exc_info'): + if hasattr(future, "set_exc_info"): # Tornado's Future future.set_exc_info(exc_info) # type: ignore else: @@ -199,19 +209,22 @@ def future_set_exc_info(future: Union['futures.Future[_T]', 'Future[_T]'], @typing.overload -def future_add_done_callback(future: 'futures.Future[_T]', - callback: Callable[['futures.Future[_T]'], None]) -> None: +def future_add_done_callback( + future: "futures.Future[_T]", callback: Callable[["futures.Future[_T]"], None] +) -> None: pass @typing.overload # noqa: F811 -def future_add_done_callback(future: 'Future[_T]', - callback: Callable[['Future[_T]'], None]) -> None: +def future_add_done_callback( + future: "Future[_T]", callback: Callable[["Future[_T]"], None] +) -> None: pass -def future_add_done_callback(future: Union['futures.Future[_T]', 'Future[_T]'], # noqa: F811 - callback: Callable[..., None]) -> None: +def future_add_done_callback( # noqa: F811 + future: Union["futures.Future[_T]", "Future[_T]"], callback: Callable[..., None] +) -> None: """Arrange to call ``callback`` when ``future`` is complete. ``callback`` is invoked with one argument, the ``future``. diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index 7f13403cd..4119585fd 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -27,28 +27,37 @@ from tornado import httputil from tornado import ioloop from tornado.escape import utf8, native_str -from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main +from tornado.httpclient import ( + HTTPRequest, + HTTPResponse, + HTTPError, + AsyncHTTPClient, + main, +) from tornado.log import app_log from typing import Dict, Any, Callable, Union import typing + if typing.TYPE_CHECKING: from typing import Deque, Tuple, Optional # noqa: F401 -curl_log = logging.getLogger('tornado.curl_httpclient') +curl_log = logging.getLogger("tornado.curl_httpclient") class CurlAsyncHTTPClient(AsyncHTTPClient): - def initialize(self, max_clients: int=10, # type: ignore - defaults: Dict[str, Any]=None) -> None: + def initialize( # type: ignore + self, max_clients: int = 10, defaults: Dict[str, Any] = None + ) -> None: super(CurlAsyncHTTPClient, self).initialize(defaults=defaults) self._multi = pycurl.CurlMulti() self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout) self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket) self._curls = [self._curl_create() for i in range(max_clients)] self._free_list = self._curls[:] - self._requests = collections.deque() \ - # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]] + self._requests = ( + collections.deque() + ) # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]] self._fds = {} # type: Dict[int, int] self._timeout = None # type: Optional[object] @@ -57,7 +66,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): # SOCKETFUNCTION. Mitigate the effects of such bugs by # forcing a periodic scan of all active requests. self._force_timeout_callback = ioloop.PeriodicCallback( - self._handle_force_timeout, 1000) + self._handle_force_timeout, 1000 + ) self._force_timeout_callback.start() # Work around a bug in libcurl 7.29.0: Some fields in the curl @@ -84,7 +94,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): self._force_timeout_callback = None # type: ignore self._multi = None - def fetch_impl(self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]) -> None: + def fetch_impl( + self, request: HTTPRequest, callback: Callable[[HTTPResponse], None] + ) -> None: self._requests.append((request, callback, self.io_loop.time())) self._process_queue() self._set_timeout(0) @@ -97,7 +109,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): pycurl.POLL_NONE: ioloop.IOLoop.NONE, pycurl.POLL_IN: ioloop.IOLoop.READ, pycurl.POLL_OUT: ioloop.IOLoop.WRITE, - pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE + pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE, } if event == pycurl.POLL_REMOVE: if fd in self._fds: @@ -115,8 +127,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): # instead of update. if fd in self._fds: self.io_loop.remove_handler(fd) - self.io_loop.add_handler(fd, self._handle_events, - ioloop_event) + self.io_loop.add_handler(fd, self._handle_events, ioloop_event) self._fds[fd] = ioloop_event def _set_timeout(self, msecs: int) -> None: @@ -124,7 +135,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = self.io_loop.add_timeout( - self.io_loop.time() + msecs / 1000.0, self._handle_timeout) + self.io_loop.time() + msecs / 1000.0, self._handle_timeout + ) def _handle_events(self, fd: int, events: int) -> None: """Called by IOLoop when there is activity on one of our @@ -149,8 +161,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): self._timeout = None while True: try: - ret, num_handles = self._multi.socket_action( - pycurl.SOCKET_TIMEOUT, 0) + ret, num_handles = self._multi.socket_action(pycurl.SOCKET_TIMEOUT, 0) except pycurl.error as e: ret = e.args[0] if ret != pycurl.E_CALL_MULTI_PERFORM: @@ -219,8 +230,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): } try: self._curl_setup_request( - curl, request, curl.info["buffer"], - curl.info["headers"]) + curl, request, curl.info["buffer"], curl.info["headers"] + ) except Exception as e: # If there was an error in setup, pass it on # to the callback. Note that allowing the @@ -231,17 +242,16 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): # _finish_pending_requests the exceptions have # nowhere to go. self._free_list.append(curl) - callback(HTTPResponse( - request=request, - code=599, - error=e)) + callback(HTTPResponse(request=request, code=599, error=e)) else: self._multi.add_handle(curl) if not started: break - def _finish(self, curl: pycurl.Curl, curl_error: int=None, curl_message: str=None) -> None: + def _finish( + self, curl: pycurl.Curl, curl_error: int = None, curl_message: str = None + ) -> None: info = curl.info curl.info = None self._multi.remove_handle(curl) @@ -273,13 +283,20 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): redirect=curl.getinfo(pycurl.REDIRECT_TIME), ) try: - info["callback"](HTTPResponse( - request=info["request"], code=code, headers=info["headers"], - buffer=buffer, effective_url=effective_url, error=error, - reason=info['headers'].get("X-Http-Reason", None), - request_time=self.io_loop.time() - info["curl_start_ioloop_time"], - start_time=info["curl_start_time"], - time_info=time_info)) + info["callback"]( + HTTPResponse( + request=info["request"], + code=code, + headers=info["headers"], + buffer=buffer, + effective_url=effective_url, + error=error, + reason=info["headers"].get("X-Http-Reason", None), + request_time=self.io_loop.time() - info["curl_start_ioloop_time"], + start_time=info["curl_start_time"], + time_info=time_info, + ) + ) except Exception: self.handle_callback_exception(info["callback"]) @@ -291,13 +308,20 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): if curl_log.isEnabledFor(logging.DEBUG): curl.setopt(pycurl.VERBOSE, 1) curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug) - if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12) + if hasattr( + pycurl, "PROTOCOLS" + ): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12) curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS) curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS) return curl - def _curl_setup_request(self, curl: pycurl.Curl, request: HTTPRequest, - buffer: BytesIO, headers: httputil.HTTPHeaders) -> None: + def _curl_setup_request( + self, + curl: pycurl.Curl, + request: HTTPRequest, + buffer: BytesIO, + headers: httputil.HTTPHeaders, + ) -> None: curl.setopt(pycurl.URL, native_str(request.url)) # libcurl's magic "Expect: 100-continue" behavior causes delays @@ -315,18 +339,27 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): if "Pragma" not in request.headers: request.headers["Pragma"] = "" - curl.setopt(pycurl.HTTPHEADER, - ["%s: %s" % (native_str(k), native_str(v)) - for k, v in request.headers.get_all()]) + curl.setopt( + pycurl.HTTPHEADER, + [ + "%s: %s" % (native_str(k), native_str(v)) + for k, v in request.headers.get_all() + ], + ) - curl.setopt(pycurl.HEADERFUNCTION, - functools.partial(self._curl_header_callback, - headers, request.header_callback)) + curl.setopt( + pycurl.HEADERFUNCTION, + functools.partial( + self._curl_header_callback, headers, request.header_callback + ), + ) if request.streaming_callback: + def write_function(b: Union[bytes, bytearray]) -> int: assert request.streaming_callback is not None self.io_loop.add_callback(request.streaming_callback, b) return len(b) + else: write_function = buffer.write curl.setopt(pycurl.WRITEFUNCTION, write_function) @@ -351,20 +384,21 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): curl.setopt(pycurl.PROXYPORT, request.proxy_port) if request.proxy_username: assert request.proxy_password is not None - credentials = httputil.encode_username_password(request.proxy_username, - request.proxy_password) + credentials = httputil.encode_username_password( + request.proxy_username, request.proxy_password + ) curl.setopt(pycurl.PROXYUSERPWD, credentials) - if (request.proxy_auth_mode is None or - request.proxy_auth_mode == "basic"): + if request.proxy_auth_mode is None or request.proxy_auth_mode == "basic": curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC) elif request.proxy_auth_mode == "digest": curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST) else: raise ValueError( - "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode) + "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode + ) else: - curl.setopt(pycurl.PROXY, '') + curl.setopt(pycurl.PROXY, "") curl.unsetopt(pycurl.PROXYUSERPWD) if request.validate_cert: curl.setopt(pycurl.SSL_VERIFYPEER, 1) @@ -407,7 +441,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): elif request.allow_nonstandard_methods or request.method in custom_methods: curl.setopt(pycurl.CUSTOMREQUEST, request.method) else: - raise KeyError('unknown method ' + request.method) + raise KeyError("unknown method " + request.method) body_expected = request.method in ("POST", "PATCH", "PUT") body_present = request.body is not None @@ -415,12 +449,14 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): # Some HTTP methods nearly always have bodies while others # almost never do. Fail in this case unless the user has # opted out of sanity checks with allow_nonstandard_methods. - if ((body_expected and not body_present) or - (body_present and not body_expected)): + if (body_expected and not body_present) or ( + body_present and not body_expected + ): raise ValueError( - 'Body must %sbe None for method %s (unless ' - 'allow_nonstandard_methods is true)' % - ('not ' if body_expected else '', request.method)) + "Body must %sbe None for method %s (unless " + "allow_nonstandard_methods is true)" + % ("not " if body_expected else "", request.method) + ) if body_expected or body_present: if request.method == "GET": @@ -429,19 +465,20 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): # unless we use CUSTOMREQUEST). While the spec doesn't # forbid clients from sending a body, it arguably # disallows the server from doing anything with them. - raise ValueError('Body must be None for GET request') - request_buffer = BytesIO(utf8(request.body or '')) + raise ValueError("Body must be None for GET request") + request_buffer = BytesIO(utf8(request.body or "")) def ioctl(cmd: int) -> None: if cmd == curl.IOCMD_RESTARTREAD: request_buffer.seek(0) + curl.setopt(pycurl.READFUNCTION, request_buffer.read) curl.setopt(pycurl.IOCTLFUNCTION, ioctl) if request.method == "POST": - curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or '')) + curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or "")) else: curl.setopt(pycurl.UPLOAD, True) - curl.setopt(pycurl.INFILESIZE, len(request.body or '')) + curl.setopt(pycurl.INFILESIZE, len(request.body or "")) if request.auth_username is not None: assert request.auth_password is not None @@ -452,11 +489,16 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): else: raise ValueError("Unsupported auth_mode %s" % request.auth_mode) - userpwd = httputil.encode_username_password(request.auth_username, - request.auth_password) + userpwd = httputil.encode_username_password( + request.auth_username, request.auth_password + ) curl.setopt(pycurl.USERPWD, userpwd) - curl_log.debug("%s %s (username: %r)", request.method, request.url, - request.auth_username) + curl_log.debug( + "%s %s (username: %r)", + request.method, + request.url, + request.auth_username, + ) else: curl.unsetopt(pycurl.USERPWD) curl_log.debug("%s %s", request.method, request.url) @@ -483,10 +525,13 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): if request.prepare_curl_callback is not None: request.prepare_curl_callback(curl) - def _curl_header_callback(self, headers: httputil.HTTPHeaders, - header_callback: Callable[[str], None], - header_line_bytes: bytes) -> None: - header_line = native_str(header_line_bytes.decode('latin1')) + def _curl_header_callback( + self, + headers: httputil.HTTPHeaders, + header_callback: Callable[[str], None], + header_line_bytes: bytes, + ) -> None: + header_line = native_str(header_line_bytes.decode("latin1")) if header_callback is not None: self.io_loop.add_callback(header_callback, header_line) # header_line as returned by curl includes the end-of-line characters. @@ -504,16 +549,16 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): headers.parse_line(header_line) def _curl_debug(self, debug_type: int, debug_msg: str) -> None: - debug_types = ('I', '<', '>', '<', '>') + debug_types = ("I", "<", ">", "<", ">") if debug_type == 0: debug_msg = native_str(debug_msg) - curl_log.debug('%s', debug_msg.strip()) + curl_log.debug("%s", debug_msg.strip()) elif debug_type in (1, 2): debug_msg = native_str(debug_msg) for line in debug_msg.splitlines(): - curl_log.debug('%s %s', debug_types[debug_type], line) + curl_log.debug("%s %s", debug_types[debug_type], line) elif debug_type == 4: - curl_log.debug('%s %r', debug_types[debug_type], debug_msg) + curl_log.debug("%s %r", debug_types[debug_type], debug_msg) class CurlError(HTTPError): diff --git a/tornado/escape.py b/tornado/escape.py index 8bf4d1d53..bd73e305b 100644 --- a/tornado/escape.py +++ b/tornado/escape.py @@ -30,9 +30,14 @@ import typing from typing import Union, Any, Optional, Dict, List, Callable -_XHTML_ESCAPE_RE = re.compile('[&<>"\']') -_XHTML_ESCAPE_DICT = {'&': '&', '<': '<', '>': '>', '"': '"', - '\'': '''} +_XHTML_ESCAPE_RE = re.compile("[&<>\"']") +_XHTML_ESCAPE_DICT = { + "&": "&", + "<": "<", + ">": ">", + '"': """, + "'": "'", +} def xhtml_escape(value: Union[str, bytes]) -> str: @@ -46,8 +51,9 @@ def xhtml_escape(value: Union[str, bytes]) -> str: Added the single quote to the list of escaped characters. """ - return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)], - to_basestring(value)) + return _XHTML_ESCAPE_RE.sub( + lambda match: _XHTML_ESCAPE_DICT[match.group(0)], to_basestring(value) + ) def xhtml_unescape(value: Union[str, bytes]) -> str: @@ -79,7 +85,7 @@ def squeeze(value: str) -> str: return re.sub(r"[\x00-\x20]+", " ", value).strip() -def url_escape(value: Union[str, bytes], plus: bool=True) -> str: +def url_escape(value: Union[str, bytes], plus: bool = True) -> str: """Returns a URL-encoded version of the given value. If ``plus`` is true (the default), spaces will be represented @@ -95,17 +101,20 @@ def url_escape(value: Union[str, bytes], plus: bool=True) -> str: @typing.overload -def url_unescape(value: Union[str, bytes], encoding: None, plus: bool=True) -> bytes: +def url_unescape(value: Union[str, bytes], encoding: None, plus: bool = True) -> bytes: pass @typing.overload # noqa: F811 -def url_unescape(value: Union[str, bytes], encoding: str='utf-8', plus: bool=True) -> str: +def url_unescape( + value: Union[str, bytes], encoding: str = "utf-8", plus: bool = True +) -> str: pass -def url_unescape(value: Union[str, bytes], encoding: Optional[str]='utf-8', # noqa: F811 - plus: bool=True) -> Union[str, bytes]: +def url_unescape( # noqa: F811 + value: Union[str, bytes], encoding: Optional[str] = "utf-8", plus: bool = True +) -> Union[str, bytes]: """Decodes the given value from a URL. The argument may be either a byte or unicode string. @@ -125,16 +134,16 @@ def url_unescape(value: Union[str, bytes], encoding: Optional[str]='utf-8', # n if encoding is None: if plus: # unquote_to_bytes doesn't have a _plus variant - value = to_basestring(value).replace('+', ' ') + value = to_basestring(value).replace("+", " ") return urllib.parse.unquote_to_bytes(value) else: - unquote = (urllib.parse.unquote_plus if plus - else urllib.parse.unquote) + unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote return unquote(to_basestring(value), encoding=encoding) -def parse_qs_bytes(qs: str, keep_blank_values: bool=False, - strict_parsing: bool=False) -> Dict[str, List[bytes]]: +def parse_qs_bytes( + qs: str, keep_blank_values: bool = False, strict_parsing: bool = False +) -> Dict[str, List[bytes]]: """Parses a query string like urlparse.parse_qs, but returns the values as byte strings. @@ -144,11 +153,12 @@ def parse_qs_bytes(qs: str, keep_blank_values: bool=False, """ # This is gross, but python3 doesn't give us another way. # Latin1 is the universal donor of character encodings. - result = urllib.parse.parse_qs(qs, keep_blank_values, strict_parsing, - encoding='latin1', errors='strict') + result = urllib.parse.parse_qs( + qs, keep_blank_values, strict_parsing, encoding="latin1", errors="strict" + ) encoded = {} for k, v in result.items(): - encoded[k] = [i.encode('latin1') for i in v] + encoded[k] = [i.encode("latin1") for i in v] return encoded @@ -179,9 +189,7 @@ def utf8(value: Union[None, str, bytes]) -> Optional[bytes]: # noqa: F811 if isinstance(value, _UTF8_TYPES): return value if not isinstance(value, unicode_type): - raise TypeError( - "Expected bytes, unicode, or None; got %r" % type(value) - ) + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) return value.encode("utf-8") @@ -212,9 +220,7 @@ def to_unicode(value: Union[None, str, bytes]) -> Optional[str]: # noqa: F811 if isinstance(value, _TO_UNICODE_TYPES): return value if not isinstance(value, bytes): - raise TypeError( - "Expected bytes, unicode, or None; got %r" % type(value) - ) + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) return value.decode("utf-8") @@ -256,9 +262,7 @@ def to_basestring(value: Union[None, str, bytes]) -> Optional[str]: # noqa: F81 if isinstance(value, _BASESTRING_TYPES): return value if not isinstance(value, bytes): - raise TypeError( - "Expected bytes, unicode, or None; got %r" % type(value) - ) + raise TypeError("Expected bytes, unicode, or None; got %r" % type(value)) return value.decode("utf-8") @@ -268,7 +272,9 @@ def recursive_unicode(obj: Any) -> Any: Supports lists, tuples, and dictionaries. """ if isinstance(obj, dict): - return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items()) + return dict( + (recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items() + ) elif isinstance(obj, list): return list(recursive_unicode(i) for i in obj) elif isinstance(obj, tuple): @@ -286,14 +292,20 @@ def recursive_unicode(obj: Any) -> Any: # This regex should avoid those problems. # Use to_unicode instead of tornado.util.u - we don't want backslashes getting # processed as escapes. -_URL_RE = re.compile(to_unicode( - r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&|")*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&|")*\)))+)""" # noqa: E501 -)) - - -def linkify(text: Union[str, bytes], shorten: bool=False, - extra_params: Union[str, Callable[[str], str]]="", - require_protocol: bool=False, permitted_protocols: List[str]=["http", "https"]) -> str: +_URL_RE = re.compile( + to_unicode( + r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&|")*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&|")*\)))+)""" # noqa: E501 + ) +) + + +def linkify( + text: Union[str, bytes], + shorten: bool = False, + extra_params: Union[str, Callable[[str], str]] = "", + require_protocol: bool = False, + permitted_protocols: List[str] = ["http", "https"], +) -> str: """Converts plain text into HTML with links. For example: ``linkify("Hello http://tornadoweb.org!")`` would return @@ -337,7 +349,7 @@ def linkify(text: Union[str, bytes], shorten: bool=False, href = m.group(1) if not proto: - href = "http://" + href # no proto specified, use http + href = "http://" + href # no proto specified, use http if callable(extra_params): params = " " + extra_params(href).strip() @@ -359,14 +371,18 @@ def linkify(text: Union[str, bytes], shorten: bool=False, # The path is usually not that interesting once shortened # (no more slug, etc), so it really just provides a little # extra indication of shortening. - url = url[:proto_len] + parts[0] + "/" + \ - parts[1][:8].split('?')[0].split('.')[0] + url = ( + url[:proto_len] + + parts[0] + + "/" + + parts[1][:8].split("?")[0].split(".")[0] + ) if len(url) > max_len * 1.5: # still too long url = url[:max_len] if url != before_clip: - amp = url.rfind('&') + amp = url.rfind("&") # avoid splitting html char entities if amp > max_len - 5: url = url[:amp] @@ -391,7 +407,7 @@ def linkify(text: Union[str, bytes], shorten: bool=False, def _convert_entity(m: typing.Match) -> str: if m.group(1) == "#": try: - if m.group(2)[:1].lower() == 'x': + if m.group(2)[:1].lower() == "x": return chr(int(m.group(2)[1:], 16)) else: return chr(int(m.group(2))) diff --git a/tornado/gen.py b/tornado/gen.py index 346e897d6..27b504391 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -82,8 +82,14 @@ from inspect import isawaitable import sys import types -from tornado.concurrent import (Future, is_future, chain_future, future_set_exc_info, - future_add_done_callback, future_set_result_unless_cancelled) +from tornado.concurrent import ( + Future, + is_future, + chain_future, + future_set_exc_info, + future_add_done_callback, + future_set_result_unless_cancelled, +) from tornado.ioloop import IOLoop from tornado.log import app_log from tornado.util import TimeoutError @@ -94,10 +100,11 @@ from typing import Union, Any, Callable, List, Type, Tuple, Awaitable, Dict if typing.TYPE_CHECKING: from typing import Sequence, Deque, Optional, Set, Iterable # noqa: F401 -_T = typing.TypeVar('_T') +_T = typing.TypeVar("_T") -_Yieldable = Union[None, Awaitable, List[Awaitable], Dict[Any, Awaitable], - concurrent.futures.Future] +_Yieldable = Union[ + None, Awaitable, List[Awaitable], Dict[Any, Awaitable], concurrent.futures.Future +] class KeyReuseError(Exception): @@ -120,7 +127,7 @@ class ReturnValueIgnoredError(Exception): pass -def _value_from_stopiteration(e: Union[StopIteration, 'Return']) -> Any: +def _value_from_stopiteration(e: Union[StopIteration, "Return"]) -> Any: try: # StopIteration has a value attribute beginning in py33. # So does our Return class. @@ -150,7 +157,9 @@ def _create_future() -> Future: return future -def coroutine(func: Callable[..., 'Generator[Any, Any, _T]']) -> Callable[..., 'Future[_T]']: +def coroutine( + func: Callable[..., "Generator[Any, Any, _T]"] +) -> Callable[..., "Future[_T]"]: """Decorator for asynchronous generators. Any generator that yields objects from this module must be wrapped @@ -182,6 +191,7 @@ def coroutine(func: Callable[..., 'Generator[Any, Any, _T]']) -> Callable[..., ' awaitable object instead. """ + @functools.wraps(func) def wrapper(*args, **kwargs): # type: (*Any, **Any) -> Future[_T] @@ -209,7 +219,9 @@ def coroutine(func: Callable[..., 'Generator[Any, Any, _T]']) -> Callable[..., ' try: yielded = next(result) except (StopIteration, Return) as e: - future_set_result_unless_cancelled(future, _value_from_stopiteration(e)) + future_set_result_unless_cancelled( + future, _value_from_stopiteration(e) + ) except Exception: future_set_exc_info(future, sys.exc_info()) else: @@ -250,7 +262,7 @@ def is_coroutine_function(func: Any) -> bool: .. versionadded:: 4.5 """ - return getattr(func, '__tornado_coroutine__', False) + return getattr(func, "__tornado_coroutine__", False) class Return(Exception): @@ -273,7 +285,8 @@ class Return(Exception): but it is never necessary to ``raise gen.Return()``. The ``return`` statement can be used with no arguments instead. """ - def __init__(self, value: Any=None) -> None: + + def __init__(self, value: Any = None) -> None: super(Return, self).__init__() self.value = value # Cython recognizes subclasses of StopIteration with a .args tuple. @@ -338,8 +351,7 @@ class WaitIterator(object): def __init__(self, *args: Future, **kwargs: Future) -> None: if args and kwargs: - raise ValueError( - "You must provide args or kwargs, not both") + raise ValueError("You must provide args or kwargs, not both") if kwargs: self._unfinished = dict((f, k) for (k, f) in kwargs.items()) @@ -400,14 +412,14 @@ class WaitIterator(object): def __anext__(self) -> Future: if self.done(): # Lookup by name to silence pyflakes on older versions. - raise getattr(builtins, 'StopAsyncIteration')() + raise getattr(builtins, "StopAsyncIteration")() return self.next() def multi( - children: Union[List[_Yieldable], Dict[Any, _Yieldable]], - quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]]=(), -) -> Union['Future[List]', 'Future[Dict]']: + children: Union[List[_Yieldable], Dict[Any, _Yieldable]], + quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (), +) -> Union["Future[List]", "Future[Dict]"]: """Runs multiple asynchronous operations in parallel. ``children`` may either be a list or a dict whose values are @@ -459,9 +471,9 @@ Multi = multi def multi_future( - children: Union[List[_Yieldable], Dict[Any, _Yieldable]], - quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]]=(), -) -> Union['Future[List]', 'Future[Dict]']: + children: Union[List[_Yieldable], Dict[Any, _Yieldable]], + quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (), +) -> Union["Future[List]", "Future[Dict]"]: """Wait for multiple asynchronous futures in parallel. Since Tornado 6.0, this function is exactly the same as `multi`. @@ -488,8 +500,7 @@ def multi_future( future = _create_future() if not children_futs: - future_set_result_unless_cancelled(future, - {} if keys is not None else []) + future_set_result_unless_cancelled(future, {} if keys is not None else []) def callback(fut: Future) -> None: unfinished_children.remove(fut) @@ -501,14 +512,16 @@ def multi_future( except Exception as e: if future.done(): if not isinstance(e, quiet_exceptions): - app_log.error("Multiple exceptions in yield list", - exc_info=True) + app_log.error( + "Multiple exceptions in yield list", exc_info=True + ) else: future_set_exc_info(future, sys.exc_info()) if not future.done(): if keys is not None: - future_set_result_unless_cancelled(future, - dict(zip(keys, result_list))) + future_set_result_unless_cancelled( + future, dict(zip(keys, result_list)) + ) else: future_set_result_unless_cancelled(future, result_list) @@ -542,8 +555,9 @@ def maybe_future(x: Any) -> Future: def with_timeout( - timeout: Union[float, datetime.timedelta], future: _Yieldable, - quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]]=(), + timeout: Union[float, datetime.timedelta], + future: _Yieldable, + quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (), ) -> Future: """Wraps a `.Future` (or other yieldable object) in a timeout. @@ -585,31 +599,34 @@ def with_timeout( future.result() except Exception as e: if not isinstance(e, quiet_exceptions): - app_log.error("Exception in Future %r after timeout", - future, exc_info=True) + app_log.error( + "Exception in Future %r after timeout", future, exc_info=True + ) def timeout_callback() -> None: if not result.done(): result.set_exception(TimeoutError("Timeout")) # In case the wrapped future goes on to fail, log it. future_add_done_callback(future_converted, error_callback) - timeout_handle = io_loop.add_timeout( - timeout, timeout_callback) + + timeout_handle = io_loop.add_timeout(timeout, timeout_callback) if isinstance(future_converted, Future): # We know this future will resolve on the IOLoop, so we don't # need the extra thread-safety of IOLoop.add_future (and we also # don't care about StackContext here. future_add_done_callback( - future_converted, lambda future: io_loop.remove_timeout(timeout_handle)) + future_converted, lambda future: io_loop.remove_timeout(timeout_handle) + ) else: # concurrent.futures.Futures may resolve on any thread, so we # need to route them back to the IOLoop. io_loop.add_future( - future_converted, lambda future: io_loop.remove_timeout(timeout_handle)) + future_converted, lambda future: io_loop.remove_timeout(timeout_handle) + ) return result -def sleep(duration: float) -> 'Future[None]': +def sleep(duration: float) -> "Future[None]": """Return a `.Future` that resolves after the given number of seconds. When used with ``yield`` in a coroutine, this is a non-blocking @@ -624,8 +641,9 @@ def sleep(duration: float) -> 'Future[None]': .. versionadded:: 4.1 """ f = _create_future() - IOLoop.current().call_later(duration, - lambda: future_set_result_unless_cancelled(f, None)) + IOLoop.current().call_later( + duration, lambda: future_set_result_unless_cancelled(f, None) + ) return f @@ -641,6 +659,7 @@ class _NullFuture(object): a _NullFuture into a code path that doesn't understand what to do with it. """ + def result(self) -> None: return None @@ -654,8 +673,7 @@ class _NullFuture(object): _null_future = typing.cast(Future, _NullFuture()) moment = typing.cast(Future, _NullFuture()) -moment.__doc__ = \ - """A special object which may be yielded to allow the IOLoop to run for +moment.__doc__ = """A special object which may be yielded to allow the IOLoop to run for one iteration. This is not needed in normal use but it can be helpful in long-running @@ -679,8 +697,13 @@ class Runner(object): The results of the generator are stored in ``result_future`` (a `.Future`) """ - def __init__(self, gen: 'Generator[_Yieldable, Any, _T]', result_future: 'Future[_T]', - first_yielded: _Yieldable) -> None: + + def __init__( + self, + gen: "Generator[_Yieldable, Any, _T]", + result_future: "Future[_T]", + first_yielded: _Yieldable, + ) -> None: self.gen = gen self.result_future = result_future self.future = _null_future # type: Union[None, Future] @@ -728,8 +751,9 @@ class Runner(object): except (StopIteration, Return) as e: self.finished = True self.future = _null_future - future_set_result_unless_cancelled(self.result_future, - _value_from_stopiteration(e)) + future_set_result_unless_cancelled( + self.result_future, _value_from_stopiteration(e) + ) self.result_future = None # type: ignore return except Exception: @@ -757,17 +781,19 @@ class Runner(object): elif self.future is None: raise Exception("no pending future") elif not self.future.done(): + def inner(f: Any) -> None: # Break a reference cycle to speed GC. f = None # noqa self.run() - self.io_loop.add_future( - self.future, inner) + + self.io_loop.add_future(self.future, inner) return False return True - def handle_exception(self, typ: Type[Exception], value: Exception, - tb: types.TracebackType) -> bool: + def handle_exception( + self, typ: Type[Exception], value: Exception, tb: types.TracebackType + ) -> bool: if not self.running and not self.finished: self.future = Future() future_set_exc_info(self.future, (typ, value, tb)) @@ -783,7 +809,7 @@ try: except AttributeError: # asyncio.ensure_future was introduced in Python 3.4.4, but # Debian jessie still ships with 3.4.2 so try the old name. - _wrap_awaitable = getattr(asyncio, 'async') + _wrap_awaitable = getattr(asyncio, "async") def convert_yielded(yielded: _Yieldable) -> Future: diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 101166635..15c3c59b0 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -22,8 +22,11 @@ import logging import re import types -from tornado.concurrent import (Future, future_add_done_callback, - future_set_result_unless_cancelled) +from tornado.concurrent import ( + Future, + future_add_done_callback, + future_set_result_unless_cancelled, +) from tornado.escape import native_str, utf8 from tornado import gen from tornado import httputil @@ -32,7 +35,17 @@ from tornado.log import gen_log, app_log from tornado.util import GzipDecompressor -from typing import cast, Optional, Type, Awaitable, Generator, Any, Callable, Union, Tuple +from typing import ( + cast, + Optional, + Type, + Awaitable, + Generator, + Any, + Callable, + Union, + Tuple, +) class _QuietException(Exception): @@ -45,15 +58,19 @@ class _ExceptionLoggingContext(object): log any exceptions with the given logger. Any exceptions caught are converted to _QuietException """ + def __init__(self, logger: logging.Logger) -> None: self.logger = logger def __enter__(self) -> None: pass - def __exit__(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: types.TracebackType) -> None: + def __exit__( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: types.TracebackType, + ) -> None: if value is not None: assert typ is not None self.logger.error("Uncaught exception", exc_info=(typ, value, tb)) @@ -63,10 +80,17 @@ class _ExceptionLoggingContext(object): class HTTP1ConnectionParameters(object): """Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`. """ - def __init__(self, no_keep_alive: bool=False, chunk_size: int=None, - max_header_size: int=None, header_timeout: float=None, - max_body_size: int=None, body_timeout: float=None, - decompress: bool=False) -> None: + + def __init__( + self, + no_keep_alive: bool = False, + chunk_size: int = None, + max_header_size: int = None, + header_timeout: float = None, + max_body_size: int = None, + body_timeout: float = None, + decompress: bool = False, + ) -> None: """ :arg bool no_keep_alive: If true, always close the connection after one request. @@ -93,8 +117,14 @@ class HTTP1Connection(httputil.HTTPConnection): This class can be on its own for clients, or via `HTTP1ServerConnection` for servers. """ - def __init__(self, stream: iostream.IOStream, is_client: bool, - params: HTTP1ConnectionParameters=None, context: object=None) -> None: + + def __init__( + self, + stream: iostream.IOStream, + is_client: bool, + params: HTTP1ConnectionParameters = None, + context: object = None, + ) -> None: """ :arg stream: an `.IOStream` :arg bool is_client: client or server @@ -111,8 +141,7 @@ class HTTP1Connection(httputil.HTTPConnection): self.no_keep_alive = params.no_keep_alive # The body limits can be altered by the delegate, so save them # here instead of just referencing self.params later. - self._max_body_size = (self.params.max_body_size or - self.stream.max_buffer_size) + self._max_body_size = self.params.max_body_size or self.stream.max_buffer_size self._body_timeout = self.params.body_timeout # _write_finished is set to True when finish() has been called, # i.e. there will be no more data sent. Data may still be in the @@ -158,12 +187,14 @@ class HTTP1Connection(httputil.HTTPConnection): return self._read_message(delegate) @gen.coroutine - def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, bool]: + def _read_message( + self, delegate: httputil.HTTPMessageDelegate + ) -> Generator[Any, Any, bool]: need_delegate_close = False try: header_future = self.stream.read_until_regex( - b"\r?\n\r?\n", - max_bytes=self.params.max_header_size) + b"\r?\n\r?\n", max_bytes=self.params.max_header_size + ) if self.params.header_timeout is None: header_data = yield header_future else: @@ -171,7 +202,8 @@ class HTTP1Connection(httputil.HTTPConnection): header_data = yield gen.with_timeout( self.stream.io_loop.time() + self.params.header_timeout, header_future, - quiet_exceptions=iostream.StreamClosedError) + quiet_exceptions=iostream.StreamClosedError, + ) except gen.TimeoutError: self.close() return False @@ -179,8 +211,9 @@ class HTTP1Connection(httputil.HTTPConnection): if self.is_client: resp_start_line = httputil.parse_response_start_line(start_line_str) self._response_start_line = resp_start_line - start_line = resp_start_line \ - # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine] + start_line = ( + resp_start_line + ) # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine] # TODO: this will need to change to support client-side keepalive self._disconnect_on_finish = False else: @@ -189,7 +222,8 @@ class HTTP1Connection(httputil.HTTPConnection): self._request_headers = headers start_line = req_start_line self._disconnect_on_finish = not self._can_keep_alive( - req_start_line, headers) + req_start_line, headers + ) need_delegate_close = True with _ExceptionLoggingContext(app_log): header_recv_future = delegate.headers_received(start_line, headers) @@ -202,8 +236,10 @@ class HTTP1Connection(httputil.HTTPConnection): skip_body = False if self.is_client: assert isinstance(start_line, httputil.ResponseStartLine) - if (self._request_start_line is not None and - self._request_start_line.method == 'HEAD'): + if ( + self._request_start_line is not None + and self._request_start_line.method == "HEAD" + ): skip_body = True code = start_line.code if code == 304: @@ -214,20 +250,20 @@ class HTTP1Connection(httputil.HTTPConnection): if code >= 100 and code < 200: # 1xx responses should never indicate the presence of # a body. - if ('Content-Length' in headers or - 'Transfer-Encoding' in headers): + if "Content-Length" in headers or "Transfer-Encoding" in headers: raise httputil.HTTPInputError( - "Response code %d cannot have body" % code) + "Response code %d cannot have body" % code + ) # TODO: client delegates will get headers_received twice # in the case of a 100-continue. Document or change? yield self._read_message(delegate) else: - if (headers.get("Expect") == "100-continue" and - not self._write_finished): + if headers.get("Expect") == "100-continue" and not self._write_finished: self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") if not skip_body: body_future = self._read_body( - resp_start_line.code if self.is_client else 0, headers, delegate) + resp_start_line.code if self.is_client else 0, headers, delegate + ) if body_future is not None: if self._body_timeout is None: yield body_future @@ -236,10 +272,10 @@ class HTTP1Connection(httputil.HTTPConnection): yield gen.with_timeout( self.stream.io_loop.time() + self._body_timeout, body_future, - quiet_exceptions=iostream.StreamClosedError) + quiet_exceptions=iostream.StreamClosedError, + ) except gen.TimeoutError: - gen_log.info("Timeout reading body from %s", - self.context) + gen_log.info("Timeout reading body from %s", self.context) self.stream.close() return False self._read_finished = True @@ -250,9 +286,11 @@ class HTTP1Connection(httputil.HTTPConnection): # If we're waiting for the application to produce an asynchronous # response, and we're not detached, register a close callback # on the stream (we didn't need one while we were reading) - if (not self._finish_future.done() and - self.stream is not None and - not self.stream.closed()): + if ( + not self._finish_future.done() + and self.stream is not None + and not self.stream.closed() + ): self.stream.set_close_callback(self._on_connection_close) yield self._finish_future if self.is_client and self._disconnect_on_finish: @@ -260,10 +298,9 @@ class HTTP1Connection(httputil.HTTPConnection): if self.stream is None: return False except httputil.HTTPInputError as e: - gen_log.info("Malformed HTTP message from %s: %s", - self.context, e) + gen_log.info("Malformed HTTP message from %s: %s", self.context, e) if not self.is_client: - yield self.stream.write(b'HTTP/1.1 400 Bad Request\r\n\r\n') + yield self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n") self.close() return False finally: @@ -348,67 +385,83 @@ class HTTP1Connection(httputil.HTTPConnection): """ self._max_body_size = max_body_size - def write_headers(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders, chunk: bytes=None) -> 'Future[None]': + def write_headers( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + chunk: bytes = None, + ) -> "Future[None]": """Implements `.HTTPConnection.write_headers`.""" lines = [] if self.is_client: assert isinstance(start_line, httputil.RequestStartLine) self._request_start_line = start_line - lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1]))) + lines.append(utf8("%s %s HTTP/1.1" % (start_line[0], start_line[1]))) # Client requests with a non-empty body must have either a # Content-Length or a Transfer-Encoding. self._chunking_output = ( - start_line.method in ('POST', 'PUT', 'PATCH') and - 'Content-Length' not in headers and - 'Transfer-Encoding' not in headers) + start_line.method in ("POST", "PUT", "PATCH") + and "Content-Length" not in headers + and "Transfer-Encoding" not in headers + ) else: assert isinstance(start_line, httputil.ResponseStartLine) assert self._request_start_line is not None assert self._request_headers is not None self._response_start_line = start_line - lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2]))) + lines.append(utf8("HTTP/1.1 %d %s" % (start_line[1], start_line[2]))) self._chunking_output = ( # TODO: should this use # self._request_start_line.version or # start_line.version? - self._request_start_line.version == 'HTTP/1.1' and + self._request_start_line.version == "HTTP/1.1" + and # 1xx, 204 and 304 responses have no body (not even a zero-length # body), and so should not have either Content-Length or # Transfer-Encoding headers. - start_line.code not in (204, 304) and - (start_line.code < 100 or start_line.code >= 200) and + start_line.code not in (204, 304) + and (start_line.code < 100 or start_line.code >= 200) + and # No need to chunk the output if a Content-Length is specified. - 'Content-Length' not in headers and + "Content-Length" not in headers + and # Applications are discouraged from touching Transfer-Encoding, # but if they do, leave it alone. - 'Transfer-Encoding' not in headers) + "Transfer-Encoding" not in headers + ) # If connection to a 1.1 client will be closed, inform client - if (self._request_start_line.version == 'HTTP/1.1' and self._disconnect_on_finish): - headers['Connection'] = 'close' + if ( + self._request_start_line.version == "HTTP/1.1" + and self._disconnect_on_finish + ): + headers["Connection"] = "close" # If a 1.0 client asked for keep-alive, add the header. - if (self._request_start_line.version == 'HTTP/1.0' and - self._request_headers.get('Connection', '').lower() == 'keep-alive'): - headers['Connection'] = 'Keep-Alive' + if ( + self._request_start_line.version == "HTTP/1.0" + and self._request_headers.get("Connection", "").lower() == "keep-alive" + ): + headers["Connection"] = "Keep-Alive" if self._chunking_output: - headers['Transfer-Encoding'] = 'chunked' - if (not self.is_client and - (self._request_start_line.method == 'HEAD' or - cast(httputil.ResponseStartLine, start_line).code == 304)): + headers["Transfer-Encoding"] = "chunked" + if not self.is_client and ( + self._request_start_line.method == "HEAD" + or cast(httputil.ResponseStartLine, start_line).code == 304 + ): self._expected_content_remaining = 0 - elif 'Content-Length' in headers: - self._expected_content_remaining = int(headers['Content-Length']) + elif "Content-Length" in headers: + self._expected_content_remaining = int(headers["Content-Length"]) else: self._expected_content_remaining = None # TODO: headers are supposed to be of type str, but we still have some # cases that let bytes slip through. Remove these native_str calls when those # are fixed. - header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all()) - lines.extend(l.encode('latin1') for l in header_lines) + header_lines = ( + native_str(n) + ": " + native_str(v) for n, v in headers.get_all() + ) + lines.extend(l.encode("latin1") for l in header_lines) for line in lines: - if b'\n' in line: - raise ValueError('Newline in header: ' + repr(line)) + if b"\n" in line: + raise ValueError("Newline in header: " + repr(line)) future = None if self.stream.closed(): future = self._write_future = Future() @@ -430,7 +483,8 @@ class HTTP1Connection(httputil.HTTPConnection): # Close the stream now to stop further framing errors. self.stream.close() raise httputil.HTTPOutputError( - "Tried to write more data than Content-Length") + "Tried to write more data than Content-Length" + ) if self._chunking_output and chunk: # Don't write out empty chunks because that means END-OF-STREAM # with chunked encoding @@ -438,7 +492,7 @@ class HTTP1Connection(httputil.HTTPConnection): else: return chunk - def write(self, chunk: bytes) -> 'Future[None]': + def write(self, chunk: bytes) -> "Future[None]": """Implements `.HTTPConnection.write`. For backwards compatibility it is allowed but deprecated to @@ -458,13 +512,16 @@ class HTTP1Connection(httputil.HTTPConnection): def finish(self) -> None: """Implements `.HTTPConnection.finish`.""" - if (self._expected_content_remaining is not None and - self._expected_content_remaining != 0 and - not self.stream.closed()): + if ( + self._expected_content_remaining is not None + and self._expected_content_remaining != 0 + and not self.stream.closed() + ): self.stream.close() raise httputil.HTTPOutputError( - "Tried to write %d bytes less than Content-Length" % - self._expected_content_remaining) + "Tried to write %d bytes less than Content-Length" + % self._expected_content_remaining + ) if self._chunking_output: if not self.stream.closed(): self._pending_write = self.stream.write(b"0\r\n\r\n") @@ -485,7 +542,7 @@ class HTTP1Connection(httputil.HTTPConnection): else: future_add_done_callback(self._pending_write, self._finish_request) - def _on_write_complete(self, future: 'Future[None]') -> None: + def _on_write_complete(self, future: "Future[None]") -> None: exc = future.exception() if exc is not None and not isinstance(exc, iostream.StreamClosedError): future.result() @@ -498,8 +555,9 @@ class HTTP1Connection(httputil.HTTPConnection): self._write_future = None future_set_result_unless_cancelled(future, None) - def _can_keep_alive(self, start_line: httputil.RequestStartLine, - headers: httputil.HTTPHeaders) -> bool: + def _can_keep_alive( + self, start_line: httputil.RequestStartLine, headers: httputil.HTTPHeaders + ) -> bool: if self.params.no_keep_alive: return False connection_header = headers.get("Connection") @@ -507,15 +565,17 @@ class HTTP1Connection(httputil.HTTPConnection): connection_header = connection_header.lower() if start_line.version == "HTTP/1.1": return connection_header != "close" - elif ("Content-Length" in headers or - headers.get("Transfer-Encoding", "").lower() == "chunked" or - getattr(start_line, 'method', None) in ("HEAD", "GET")): + elif ( + "Content-Length" in headers + or headers.get("Transfer-Encoding", "").lower() == "chunked" + or getattr(start_line, "method", None) in ("HEAD", "GET") + ): # start_line may be a request or response start line; only # the former has a method attribute. return connection_header == "keep-alive" return False - def _finish_request(self, future: Optional['Future[None]']) -> None: + def _finish_request(self, future: Optional["Future[None]"]) -> None: self._clear_callbacks() if not self.is_client and self._disconnect_on_finish: self.close() @@ -531,31 +591,37 @@ class HTTP1Connection(httputil.HTTPConnection): # insert between messages of a reused connection. Per RFC 7230, # we SHOULD ignore at least one empty line before the request. # http://tools.ietf.org/html/rfc7230#section-3.5 - data_str = native_str(data.decode('latin1')).lstrip("\r\n") + data_str = native_str(data.decode("latin1")).lstrip("\r\n") # RFC 7230 section allows for both CRLF and bare LF. eol = data_str.find("\n") start_line = data_str[:eol].rstrip("\r") headers = httputil.HTTPHeaders.parse(data_str[eol:]) return start_line, headers - def _read_body(self, code: int, headers: httputil.HTTPHeaders, - delegate: httputil.HTTPMessageDelegate) -> Optional[Awaitable[None]]: + def _read_body( + self, + code: int, + headers: httputil.HTTPHeaders, + delegate: httputil.HTTPMessageDelegate, + ) -> Optional[Awaitable[None]]: if "Content-Length" in headers: if "Transfer-Encoding" in headers: # Response cannot contain both Content-Length and # Transfer-Encoding headers. # http://tools.ietf.org/html/rfc7230#section-3.3.3 raise httputil.HTTPInputError( - "Response with both Transfer-Encoding and Content-Length") + "Response with both Transfer-Encoding and Content-Length" + ) if "," in headers["Content-Length"]: # Proxies sometimes cause Content-Length headers to get # duplicated. If all the values are identical then we can # use them but if they differ it's an error. - pieces = re.split(r',\s*', headers["Content-Length"]) + pieces = re.split(r",\s*", headers["Content-Length"]) if any(i != pieces[0] for i in pieces): raise httputil.HTTPInputError( - "Multiple unequal Content-Lengths: %r" % - headers["Content-Length"]) + "Multiple unequal Content-Lengths: %r" + % headers["Content-Length"] + ) headers["Content-Length"] = pieces[0] try: @@ -563,7 +629,9 @@ class HTTP1Connection(httputil.HTTPConnection): except ValueError: # Handles non-integer Content-Length value. raise httputil.HTTPInputError( - "Only integer Content-Length is allowed: %s" % headers["Content-Length"]) + "Only integer Content-Length is allowed: %s" + % headers["Content-Length"] + ) if cast(int, content_length) > self._max_body_size: raise httputil.HTTPInputError("Content-Length too long") @@ -574,10 +642,10 @@ class HTTP1Connection(httputil.HTTPConnection): # This response code is not allowed to have a non-empty body, # and has an implicit length of zero instead of read-until-close. # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 - if ("Transfer-Encoding" in headers or - content_length not in (None, 0)): + if "Transfer-Encoding" in headers or content_length not in (None, 0): raise httputil.HTTPInputError( - "Response with code %d should not have body" % code) + "Response with code %d should not have body" % code + ) content_length = 0 if content_length is not None: @@ -589,11 +657,13 @@ class HTTP1Connection(httputil.HTTPConnection): return None @gen.coroutine - def _read_fixed_body(self, content_length: int, - delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, None]: + def _read_fixed_body( + self, content_length: int, delegate: httputil.HTTPMessageDelegate + ) -> Generator[Any, Any, None]: while content_length > 0: body = yield self.stream.read_bytes( - min(self.params.chunk_size, content_length), partial=True) + min(self.params.chunk_size, content_length), partial=True + ) content_length -= len(body) if not self._write_finished or self.is_client: with _ExceptionLoggingContext(app_log): @@ -603,7 +673,8 @@ class HTTP1Connection(httputil.HTTPConnection): @gen.coroutine def _read_chunked_body( - self, delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, None]: + self, delegate: httputil.HTTPMessageDelegate + ) -> Generator[Any, Any, None]: # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 total_size = 0 while True: @@ -611,8 +682,10 @@ class HTTP1Connection(httputil.HTTPConnection): chunk_len = int(chunk_len.strip(), 16) if chunk_len == 0: crlf = yield self.stream.read_bytes(2) - if crlf != b'\r\n': - raise httputil.HTTPInputError("improperly terminated chunked request") + if crlf != b"\r\n": + raise httputil.HTTPInputError( + "improperly terminated chunked request" + ) return total_size += chunk_len if total_size > self._max_body_size: @@ -620,7 +693,8 @@ class HTTP1Connection(httputil.HTTPConnection): bytes_to_read = chunk_len while bytes_to_read: chunk = yield self.stream.read_bytes( - min(bytes_to_read, self.params.chunk_size), partial=True) + min(bytes_to_read, self.params.chunk_size), partial=True + ) bytes_to_read -= len(chunk) if not self._write_finished or self.is_client: with _ExceptionLoggingContext(app_log): @@ -633,7 +707,8 @@ class HTTP1Connection(httputil.HTTPConnection): @gen.coroutine def _read_body_until_close( - self, delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, None]: + self, delegate: httputil.HTTPMessageDelegate + ) -> Generator[Any, Any, None]: body = yield self.stream.read_until_close() if not self._write_finished or self.is_client: with _ExceptionLoggingContext(app_log): @@ -643,21 +718,23 @@ class HTTP1Connection(httputil.HTTPConnection): class _GzipMessageDelegate(httputil.HTTPMessageDelegate): """Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``. """ + def __init__(self, delegate: httputil.HTTPMessageDelegate, chunk_size: int) -> None: self._delegate = delegate self._chunk_size = chunk_size self._decompressor = None # type: Optional[GzipDecompressor] - def headers_received(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: if headers.get("Content-Encoding") == "gzip": self._decompressor = GzipDecompressor() # Downstream delegates will only see uncompressed data, # so rename the content-encoding header. # (but note that curl_httpclient doesn't do this). - headers.add("X-Consumed-Content-Encoding", - headers["Content-Encoding"]) + headers.add("X-Consumed-Content-Encoding", headers["Content-Encoding"]) del headers["Content-Encoding"] return self._delegate.headers_received(start_line, headers) @@ -667,7 +744,8 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate): compressed_data = chunk while compressed_data: decompressed = self._decompressor.decompress( - compressed_data, self._chunk_size) + compressed_data, self._chunk_size + ) if decompressed: ret = self._delegate.data_received(decompressed) if ret is not None: @@ -696,8 +774,13 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate): class HTTP1ServerConnection(object): """An HTTP/1.x server.""" - def __init__(self, stream: iostream.IOStream, - params: HTTP1ConnectionParameters=None, context: object=None) -> None: + + def __init__( + self, + stream: iostream.IOStream, + params: HTTP1ConnectionParameters = None, + context: object = None, + ) -> None: """ :arg stream: an `.IOStream` :arg params: a `.HTTP1ConnectionParameters` or None @@ -738,16 +821,15 @@ class HTTP1ServerConnection(object): @gen.coroutine def _server_request_loop( - self, delegate: httputil.HTTPServerConnectionDelegate) -> Generator[Any, Any, None]: + self, delegate: httputil.HTTPServerConnectionDelegate + ) -> Generator[Any, Any, None]: try: while True: - conn = HTTP1Connection(self.stream, False, - self.params, self.context) + conn = HTTP1Connection(self.stream, False, self.params, self.context) request_delegate = delegate.start_request(self, conn) try: ret = yield conn.read_response(request_delegate) - except (iostream.StreamClosedError, - iostream.UnsatisfiableReadError): + except (iostream.StreamClosedError, iostream.UnsatisfiableReadError): return except _QuietException: # This exception was already logged. diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 0c2572dbb..e2623d532 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -83,7 +83,10 @@ class HTTPClient(object): Use `AsyncHTTPClient` instead. """ - def __init__(self, async_client_class: Type['AsyncHTTPClient']=None, **kwargs: Any) -> None: + + def __init__( + self, async_client_class: Type["AsyncHTTPClient"] = None, **kwargs: Any + ) -> None: # Initialize self._closed at the beginning of the constructor # so that an exception raised here doesn't lead to confusing # failures in __del__. @@ -94,10 +97,11 @@ class HTTPClient(object): # Create the client while our IOLoop is "current", without # clobbering the thread's real current IOLoop (if any). - async def make_client() -> 'AsyncHTTPClient': + async def make_client() -> "AsyncHTTPClient": await gen.sleep(0) assert async_client_class is not None return async_client_class(**kwargs) + self._async_client = self._io_loop.run_sync(make_client) self._closed = False @@ -111,7 +115,9 @@ class HTTPClient(object): self._io_loop.close() self._closed = True - def fetch(self, request: Union['HTTPRequest', str], **kwargs: Any) -> 'HTTPResponse': + def fetch( + self, request: Union["HTTPRequest", str], **kwargs: Any + ) -> "HTTPResponse": """Executes a request, returning an `HTTPResponse`. The request may be either a string URL or an `HTTPRequest` object. @@ -121,8 +127,9 @@ class HTTPClient(object): If an error occurs during the fetch, we raise an `HTTPError` unless the ``raise_error`` keyword argument is set to False. """ - response = self._io_loop.run_sync(functools.partial( - self._async_client.fetch, request, **kwargs)) + response = self._io_loop.run_sync( + functools.partial(self._async_client.fetch, request, **kwargs) + ) return response @@ -174,16 +181,17 @@ class AsyncHTTPClient(Configurable): @classmethod def configurable_default(cls) -> Type[Configurable]: from tornado.simple_httpclient import SimpleAsyncHTTPClient + return SimpleAsyncHTTPClient @classmethod - def _async_clients(cls) -> Dict[IOLoop, 'AsyncHTTPClient']: - attr_name = '_async_client_dict_' + cls.__name__ + def _async_clients(cls) -> Dict[IOLoop, "AsyncHTTPClient"]: + attr_name = "_async_client_dict_" + cls.__name__ if not hasattr(cls, attr_name): setattr(cls, attr_name, weakref.WeakKeyDictionary()) return getattr(cls, attr_name) - def __new__(cls, force_instance: bool=False, **kwargs: Any) -> 'AsyncHTTPClient': + def __new__(cls, force_instance: bool = False, **kwargs: Any) -> "AsyncHTTPClient": io_loop = IOLoop.current() if force_instance: instance_cache = None @@ -201,7 +209,7 @@ class AsyncHTTPClient(Configurable): instance_cache[instance.io_loop] = instance return instance - def initialize(self, defaults: Dict[str, Any]=None) -> None: + def initialize(self, defaults: Dict[str, Any] = None) -> None: self.io_loop = IOLoop.current() self.defaults = dict(HTTPRequest._DEFAULTS) if defaults is not None: @@ -229,8 +237,12 @@ class AsyncHTTPClient(Configurable): raise RuntimeError("inconsistent AsyncHTTPClient cache") del self._instance_cache[self.io_loop] - def fetch(self, request: Union[str, 'HTTPRequest'], - raise_error: bool=True, **kwargs: Any) -> 'Future[HTTPResponse]': + def fetch( + self, + request: Union[str, "HTTPRequest"], + raise_error: bool = True, + **kwargs: Any + ) -> "Future[HTTPResponse]": """Executes a request, asynchronously returning an `HTTPResponse`. The request may be either a string URL or an `HTTPRequest` object. @@ -265,7 +277,9 @@ class AsyncHTTPClient(Configurable): request = HTTPRequest(url=request, **kwargs) else: if kwargs: - raise ValueError("kwargs can't be used if request is an HTTPRequest object") + raise ValueError( + "kwargs can't be used if request is an HTTPRequest object" + ) # We may modify this (to add Host, Accept-Encoding, etc), # so make sure we don't modify the caller's object. This is also # where normal dicts get converted to HTTPHeaders objects. @@ -273,21 +287,25 @@ class AsyncHTTPClient(Configurable): request_proxy = _RequestProxy(request, self.defaults) future = Future() # type: Future[HTTPResponse] - def handle_response(response: 'HTTPResponse') -> None: + def handle_response(response: "HTTPResponse") -> None: if response.error: if raise_error or not response._error_is_response_code: future.set_exception(response.error) return future_set_result_unless_cancelled(future, response) + self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response) return future - def fetch_impl(self, request: 'HTTPRequest', - callback: Callable[['HTTPResponse'], None]) -> None: + def fetch_impl( + self, request: "HTTPRequest", callback: Callable[["HTTPResponse"], None] + ) -> None: raise NotImplementedError() @classmethod - def configure(cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any) -> None: + def configure( + cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any + ) -> None: """Configures the `AsyncHTTPClient` subclass to use. ``AsyncHTTPClient()`` actually creates an instance of a subclass. @@ -311,6 +329,7 @@ class AsyncHTTPClient(Configurable): class HTTPRequest(object): """HTTP client request object.""" + _headers = None # type: Union[Dict[str, str], httputil.HTTPHeaders] # Default values for HTTPRequest parameters. @@ -322,30 +341,47 @@ class HTTPRequest(object): follow_redirects=True, max_redirects=5, decompress_response=True, - proxy_password='', + proxy_password="", allow_nonstandard_methods=False, - validate_cert=True) - - def __init__(self, url: str, method: str="GET", - headers: Union[Dict[str, str], httputil.HTTPHeaders]=None, - body: Union[bytes, str]=None, - auth_username: str=None, auth_password: str=None, auth_mode: str=None, - connect_timeout: float=None, request_timeout: float=None, - if_modified_since: Union[float, datetime.datetime]=None, - follow_redirects: bool=None, - max_redirects: int=None, user_agent: str=None, use_gzip: bool=None, - network_interface: str=None, - streaming_callback: Callable[[bytes], None]=None, - header_callback: Callable[[str], None]=None, - prepare_curl_callback: Callable[[Any], None]=None, - proxy_host: str=None, proxy_port: int=None, proxy_username: str=None, - proxy_password: str=None, proxy_auth_mode: str=None, - allow_nonstandard_methods: bool=None, validate_cert: bool=None, - ca_certs: str=None, allow_ipv6: bool=None, client_key: str=None, - client_cert: str=None, - body_producer: Callable[[Callable[[bytes], None]], 'Future[None]']=None, - expect_100_continue: bool=False, decompress_response: bool=None, - ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None) -> None: + validate_cert=True, + ) + + def __init__( + self, + url: str, + method: str = "GET", + headers: Union[Dict[str, str], httputil.HTTPHeaders] = None, + body: Union[bytes, str] = None, + auth_username: str = None, + auth_password: str = None, + auth_mode: str = None, + connect_timeout: float = None, + request_timeout: float = None, + if_modified_since: Union[float, datetime.datetime] = None, + follow_redirects: bool = None, + max_redirects: int = None, + user_agent: str = None, + use_gzip: bool = None, + network_interface: str = None, + streaming_callback: Callable[[bytes], None] = None, + header_callback: Callable[[str], None] = None, + prepare_curl_callback: Callable[[Any], None] = None, + proxy_host: str = None, + proxy_port: int = None, + proxy_username: str = None, + proxy_password: str = None, + proxy_auth_mode: str = None, + allow_nonstandard_methods: bool = None, + validate_cert: bool = None, + ca_certs: str = None, + allow_ipv6: bool = None, + client_key: str = None, + client_cert: str = None, + body_producer: Callable[[Callable[[bytes], None]], "Future[None]"] = None, + expect_100_continue: bool = False, + decompress_response: bool = None, + ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, + ) -> None: r"""All parameters except ``url`` are optional. :arg str url: URL to fetch @@ -462,7 +498,8 @@ class HTTPRequest(object): self.headers = headers if if_modified_since: self.headers["If-Modified-Since"] = httputil.format_timestamp( - if_modified_since) + if_modified_since + ) self.proxy_host = proxy_host self.proxy_port = proxy_port self.proxy_username = proxy_username @@ -569,16 +606,25 @@ class HTTPResponse(object): is excluded in both implementations. ``request_time`` is now more accurate for ``curl_httpclient`` because it uses a monotonic clock when available. """ + # I'm not sure why these don't get type-inferred from the references in __init__. error = None # type: Optional[BaseException] _error_is_response_code = False request = None # type: HTTPRequest - def __init__(self, request: HTTPRequest, code: int, - headers: httputil.HTTPHeaders=None, buffer: BytesIO=None, - effective_url: str=None, error: BaseException=None, - request_time: float=None, time_info: Dict[str, float]=None, - reason: str=None, start_time: float=None) -> None: + def __init__( + self, + request: HTTPRequest, + code: int, + headers: httputil.HTTPHeaders = None, + buffer: BytesIO = None, + effective_url: str = None, + error: BaseException = None, + request_time: float = None, + time_info: Dict[str, float] = None, + reason: str = None, + start_time: float = None, + ) -> None: if isinstance(request, _RequestProxy): self.request = request.request else: @@ -599,8 +645,7 @@ class HTTPResponse(object): if error is None: if self.code < 200 or self.code >= 300: self._error_is_response_code = True - self.error = HTTPError(self.code, message=self.reason, - response=self) + self.error = HTTPError(self.code, message=self.reason, response=self) else: self.error = None else: @@ -648,7 +693,10 @@ class HTTPClientError(Exception): `tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains as an alias. """ - def __init__(self, code: int, message: str=None, response: HTTPResponse=None) -> None: + + def __init__( + self, code: int, message: str = None, response: HTTPResponse = None + ) -> None: self.code = code self.message = message or httputil.responses.get(code, "Unknown") self.response = response @@ -672,7 +720,10 @@ class _RequestProxy(object): Used internally by AsyncHTTPClient implementations. """ - def __init__(self, request: HTTPRequest, defaults: Optional[Dict[str, Any]]) -> None: + + def __init__( + self, request: HTTPRequest, defaults: Optional[Dict[str, Any]] + ) -> None: self.request = request self.defaults = defaults @@ -688,6 +739,7 @@ class _RequestProxy(object): def main() -> None: from tornado.options import define, options, parse_command_line + define("print_headers", type=bool, default=False) define("print_body", type=bool, default=True) define("follow_redirects", type=bool, default=True) @@ -698,12 +750,13 @@ def main() -> None: client = HTTPClient() for arg in args: try: - response = client.fetch(arg, - follow_redirects=options.follow_redirects, - validate_cert=options.validate_cert, - proxy_host=options.proxy_host, - proxy_port=options.proxy_port, - ) + response = client.fetch( + arg, + follow_redirects=options.follow_redirects, + validate_cert=options.validate_cert, + proxy_host=options.proxy_host, + proxy_port=options.proxy_port, + ) except HTTPError as e: if e.response is not None: response = e.response diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 75bcbedaf..0552ec91b 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -38,13 +38,24 @@ from tornado.tcpserver import TCPServer from tornado.util import Configurable import typing -from typing import Union, Any, Dict, Callable, List, Type, Generator, Tuple, Optional, Awaitable +from typing import ( + Union, + Any, + Dict, + Callable, + List, + Type, + Generator, + Tuple, + Optional, + Awaitable, +) + if typing.TYPE_CHECKING: from typing import Set # noqa: F401 -class HTTPServer(TCPServer, Configurable, - httputil.HTTPServerConnectionDelegate): +class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. A server is defined by a subclass of `.HTTPServerConnectionDelegate`, @@ -141,6 +152,7 @@ class HTTPServer(TCPServer, Configurable, .. versionchanged:: 5.0 The ``io_loop`` argument has been removed. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: # Ignore args to __init__; real initialization belongs in # initialize since we're Configurable. (there's something @@ -149,21 +161,25 @@ class HTTPServer(TCPServer, Configurable, # completely) pass - def initialize(self, # type: ignore - request_callback: Union[httputil.HTTPServerConnectionDelegate, - Callable[[httputil.HTTPServerRequest], None]], - no_keep_alive: bool=False, - xheaders: bool=False, - ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, - protocol: str=None, - decompress_request: bool=False, - chunk_size: int=None, - max_header_size: int=None, - idle_connection_timeout: float=None, - body_timeout: float=None, - max_body_size: int=None, - max_buffer_size: int=None, - trusted_downstream: List[str]=None) -> None: + def initialize( # type: ignore + self, + request_callback: Union[ + httputil.HTTPServerConnectionDelegate, + Callable[[httputil.HTTPServerRequest], None], + ], + no_keep_alive: bool = False, + xheaders: bool = False, + ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, + protocol: str = None, + decompress_request: bool = False, + chunk_size: int = None, + max_header_size: int = None, + idle_connection_timeout: float = None, + body_timeout: float = None, + max_body_size: int = None, + max_buffer_size: int = None, + trusted_downstream: List[str] = None, + ) -> None: self.request_callback = request_callback self.xheaders = xheaders self.protocol = protocol @@ -174,10 +190,14 @@ class HTTPServer(TCPServer, Configurable, header_timeout=idle_connection_timeout or 3600, max_body_size=max_body_size, body_timeout=body_timeout, - no_keep_alive=no_keep_alive) - TCPServer.__init__(self, ssl_options=ssl_options, - max_buffer_size=max_buffer_size, - read_chunk_size=chunk_size) + no_keep_alive=no_keep_alive, + ) + TCPServer.__init__( + self, + ssl_options=ssl_options, + max_buffer_size=max_buffer_size, + read_chunk_size=chunk_size, + ) self._connections = set() # type: Set[HTTP1ServerConnection] self.trusted_downstream = trusted_downstream @@ -197,16 +217,16 @@ class HTTPServer(TCPServer, Configurable, yield conn.close() def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None: - context = _HTTPRequestContext(stream, address, - self.protocol, - self.trusted_downstream) - conn = HTTP1ServerConnection( - stream, self.conn_params, context) + context = _HTTPRequestContext( + stream, address, self.protocol, self.trusted_downstream + ) + conn = HTTP1ServerConnection(stream, self.conn_params, context) self._connections.add(conn) conn.start_serving(self) - def start_request(self, server_conn: object, - request_conn: httputil.HTTPConnection) -> httputil.HTTPMessageDelegate: + def start_request( + self, server_conn: object, request_conn: httputil.HTTPConnection + ) -> httputil.HTTPMessageDelegate: if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate): delegate = self.request_callback.start_request(server_conn, request_conn) else: @@ -222,21 +242,27 @@ class HTTPServer(TCPServer, Configurable, class _CallableAdapter(httputil.HTTPMessageDelegate): - def __init__(self, request_callback: Callable[[httputil.HTTPServerRequest], None], - request_conn: httputil.HTTPConnection) -> None: + def __init__( + self, + request_callback: Callable[[httputil.HTTPServerRequest], None], + request_conn: httputil.HTTPConnection, + ) -> None: self.connection = request_conn self.request_callback = request_callback self.request = None # type: Optional[httputil.HTTPServerRequest] self.delegate = None self._chunks = [] # type: List[bytes] - def headers_received(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: self.request = httputil.HTTPServerRequest( connection=self.connection, start_line=typing.cast(httputil.RequestStartLine, start_line), - headers=headers) + headers=headers, + ) return None def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: @@ -245,7 +271,7 @@ class _CallableAdapter(httputil.HTTPMessageDelegate): def finish(self) -> None: assert self.request is not None - self.request.body = b''.join(self._chunks) + self.request.body = b"".join(self._chunks) self.request._parse_body() self.request_callback(self.request) @@ -254,9 +280,13 @@ class _CallableAdapter(httputil.HTTPMessageDelegate): class _HTTPRequestContext(object): - def __init__(self, stream: iostream.IOStream, address: Tuple, - protocol: Optional[str], - trusted_downstream: List[str]=None) -> None: + def __init__( + self, + stream: iostream.IOStream, + address: Tuple, + protocol: Optional[str], + trusted_downstream: List[str] = None, + ) -> None: self.address = address # Save the socket's address family now so we know how to # interpret self.address even after the stream is closed @@ -266,12 +296,14 @@ class _HTTPRequestContext(object): else: self.address_family = None # In HTTPServerRequest we want an IP, not a full socket address. - if (self.address_family in (socket.AF_INET, socket.AF_INET6) and - address is not None): + if ( + self.address_family in (socket.AF_INET, socket.AF_INET6) + and address is not None + ): self.remote_ip = address[0] else: # Unix (or other) socket; fake the remote address. - self.remote_ip = '0.0.0.0' + self.remote_ip = "0.0.0.0" if protocol: self.protocol = protocol elif isinstance(stream, iostream.SSLIOStream): @@ -298,7 +330,7 @@ class _HTTPRequestContext(object): # Squid uses X-Forwarded-For, others use X-Real-Ip ip = headers.get("X-Forwarded-For", self.remote_ip) # Skip trusted downstream hosts in X-Forwarded-For list - for ip in (cand.strip() for cand in reversed(ip.split(','))): + for ip in (cand.strip() for cand in reversed(ip.split(","))): if ip not in self.trusted_downstream: break ip = headers.get("X-Real-Ip", ip) @@ -306,12 +338,12 @@ class _HTTPRequestContext(object): self.remote_ip = ip # AWS uses X-Forwarded-Proto proto_header = headers.get( - "X-Scheme", headers.get("X-Forwarded-Proto", - self.protocol)) + "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol) + ) if proto_header: # use only the last proto entry if there is more than one # TODO: support trusting mutiple layers of proxied protocol - proto_header = proto_header.split(',')[-1].strip() + proto_header = proto_header.split(",")[-1].strip() if proto_header in ("http", "https"): self.protocol = proto_header @@ -326,14 +358,19 @@ class _HTTPRequestContext(object): class _ProxyAdapter(httputil.HTTPMessageDelegate): - def __init__(self, delegate: httputil.HTTPMessageDelegate, - request_conn: httputil.HTTPConnection) -> None: + def __init__( + self, + delegate: httputil.HTTPMessageDelegate, + request_conn: httputil.HTTPConnection, + ) -> None: self.connection = request_conn self.delegate = delegate - def headers_received(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: # TODO: either make context an official part of the # HTTPConnection interface or figure out some other way to do this. self.connection.context._apply_xheaders(headers) # type: ignore diff --git a/tornado/httputil.py b/tornado/httputil.py index c88e862a2..e3d63857b 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -42,8 +42,19 @@ from tornado.util import ObjectDict, unicode_type responses import typing -from typing import (Tuple, Iterable, List, Mapping, Iterator, Dict, Union, Optional, - Awaitable, Generator, AnyStr) +from typing import ( + Tuple, + Iterable, + List, + Mapping, + Iterator, + Dict, + Union, + Optional, + Awaitable, + Generator, + AnyStr, +) if typing.TYPE_CHECKING: from typing import Deque # noqa @@ -53,7 +64,7 @@ if typing.TYPE_CHECKING: # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line # terminator and ignore any preceding CR. -_CRLF_RE = re.compile(r'\r?\n') +_CRLF_RE = re.compile(r"\r?\n") class _NormalizedHeaderCache(dict): @@ -67,6 +78,7 @@ class _NormalizedHeaderCache(dict): >>> normalized_headers["coNtent-TYPE"] 'Content-Type' """ + def __init__(self, size: int) -> None: super(_NormalizedHeaderCache, self).__init__() self.size = size @@ -116,6 +128,7 @@ class HTTPHeaders(collections.MutableMapping): Set-Cookie: A=B Set-Cookie: C=D """ + @typing.overload def __init__(self, __arg: Mapping[str, List[str]]) -> None: pass @@ -136,8 +149,7 @@ class HTTPHeaders(collections.MutableMapping): self._dict = {} # type: typing.Dict[str, str] self._as_list = {} # type: typing.Dict[str, typing.List[str]] self._last_key = None - if (len(args) == 1 and len(kwargs) == 0 and - isinstance(args[0], HTTPHeaders)): + if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], HTTPHeaders): # Copy constructor for k, v in args[0].get_all(): self.add(k, v) @@ -152,8 +164,9 @@ class HTTPHeaders(collections.MutableMapping): norm_name = _normalized_headers[name] self._last_key = norm_name if norm_name in self: - self._dict[norm_name] = (native_str(self[norm_name]) + ',' + - native_str(value)) + self._dict[norm_name] = ( + native_str(self[norm_name]) + "," + native_str(value) + ) self._as_list[norm_name].append(value) else: self[norm_name] = value @@ -185,7 +198,7 @@ class HTTPHeaders(collections.MutableMapping): # continuation of a multi-line header if self._last_key is None: raise HTTPInputError("first header line cannot start with whitespace") - new_part = ' ' + line.lstrip() + new_part = " " + line.lstrip() self._as_list[self._last_key][-1] += new_part self._dict[self._last_key] += new_part else: @@ -196,7 +209,7 @@ class HTTPHeaders(collections.MutableMapping): self.add(name, value.strip()) @classmethod - def parse(cls, headers: str) -> 'HTTPHeaders': + def parse(cls, headers: str) -> "HTTPHeaders": """Returns a dictionary from HTTP header text. >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n") @@ -236,7 +249,7 @@ class HTTPHeaders(collections.MutableMapping): def __iter__(self) -> Iterator[typing.Any]: return iter(self._dict) - def copy(self) -> 'HTTPHeaders': + def copy(self) -> "HTTPHeaders": # defined in dict but not in MutableMapping. return HTTPHeaders(self) @@ -346,16 +359,26 @@ class HTTPServerRequest(object): .. versionchanged:: 4.0 Moved from ``tornado.httpserver.HTTPRequest``. """ + path = None # type: str query = None # type: str # HACK: Used for stream_request_body _body_future = None # type: Future[None] - def __init__(self, method: str=None, uri: str=None, version: str="HTTP/1.0", - headers: HTTPHeaders=None, body: bytes=None, host: str=None, - files: Dict[str, List['HTTPFile']]=None, connection: 'HTTPConnection'=None, - start_line: 'RequestStartLine'=None, server_connection: object=None) -> None: + def __init__( + self, + method: str = None, + uri: str = None, + version: str = "HTTP/1.0", + headers: HTTPHeaders = None, + body: bytes = None, + host: str = None, + files: Dict[str, List["HTTPFile"]] = None, + connection: "HTTPConnection" = None, + start_line: "RequestStartLine" = None, + server_connection: object = None, + ) -> None: if start_line is not None: method, uri, version = start_line self.method = method @@ -365,9 +388,9 @@ class HTTPServerRequest(object): self.body = body or b"" # set remote IP and protocol - context = getattr(connection, 'context', None) - self.remote_ip = getattr(context, 'remote_ip', None) - self.protocol = getattr(context, 'protocol', "http") + context = getattr(connection, "context", None) + self.remote_ip = getattr(context, "remote_ip", None) + self.protocol = getattr(context, "protocol", "http") self.host = host or self.headers.get("Host") or "127.0.0.1" self.host_name = split_host_and_port(self.host.lower())[0] @@ -378,7 +401,7 @@ class HTTPServerRequest(object): self._finish_time = None if uri is not None: - self.path, sep, self.query = uri.partition('?') + self.path, sep, self.query = uri.partition("?") self.arguments = parse_qs_bytes(self.query, keep_blank_values=True) self.query_arguments = copy.deepcopy(self.arguments) self.body_arguments = {} # type: Dict[str, List[bytes]] @@ -415,7 +438,9 @@ class HTTPServerRequest(object): else: return self._finish_time - self._start_time - def get_ssl_certificate(self, binary_form: bool=False) -> Union[None, Dict, bytes]: + def get_ssl_certificate( + self, binary_form: bool = False + ) -> Union[None, Dict, bytes]: """Returns the client's SSL certificate, if any. To use client certificates, the HTTPServer's @@ -439,15 +464,19 @@ class HTTPServerRequest(object): return None # TODO: add a method to HTTPConnection for this so it can work with HTTP/2 return self.connection.stream.socket.getpeercert( # type: ignore - binary_form=binary_form) + binary_form=binary_form + ) except SSLError: return None def _parse_body(self) -> None: parse_body_arguments( - self.headers.get("Content-Type", ""), self.body, - self.body_arguments, self.files, - self.headers) + self.headers.get("Content-Type", ""), + self.body, + self.body_arguments, + self.files, + self.headers, + ) for k, v in self.body_arguments.items(): self.arguments.setdefault(k, []).extend(v) @@ -464,6 +493,7 @@ class HTTPInputError(Exception): .. versionadded:: 4.0 """ + pass @@ -472,6 +502,7 @@ class HTTPOutputError(Exception): .. versionadded:: 4.0 """ + pass @@ -480,8 +511,10 @@ class HTTPServerConnectionDelegate(object): .. versionadded:: 4.0 """ - def start_request(self, server_conn: object, - request_conn: 'HTTPConnection')-> 'HTTPMessageDelegate': + + def start_request( + self, server_conn: object, request_conn: "HTTPConnection" + ) -> "HTTPMessageDelegate": """This method is called by the server when a new request has started. :arg server_conn: is an opaque object representing the long-lived @@ -507,9 +540,13 @@ class HTTPMessageDelegate(object): .. versionadded:: 4.0 """ + # TODO: genericize this class to avoid exposing the Union. - def headers_received(self, start_line: Union['RequestStartLine', 'ResponseStartLine'], - headers: HTTPHeaders) -> Optional[Awaitable[None]]: + def headers_received( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + ) -> Optional[Awaitable[None]]: """Called when the HTTP headers have been received and parsed. :arg start_line: a `.RequestStartLine` or `.ResponseStartLine` @@ -549,8 +586,13 @@ class HTTPConnection(object): .. versionadded:: 4.0 """ - def write_headers(self, start_line: Union['RequestStartLine', 'ResponseStartLine'], - headers: HTTPHeaders, chunk: bytes=None) -> 'Future[None]': + + def write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: bytes = None, + ) -> "Future[None]": """Write an HTTP header block. :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`. @@ -569,7 +611,7 @@ class HTTPConnection(object): """ raise NotImplementedError() - def write(self, chunk: bytes) -> 'Future[None]': + def write(self, chunk: bytes) -> "Future[None]": """Writes a chunk of body data. Returns a future for flow control. @@ -586,8 +628,12 @@ class HTTPConnection(object): raise NotImplementedError() -def url_concat(url: str, args: Union[None, Dict[str, str], List[Tuple[str, str]], - Tuple[Tuple[str, str], ...]]) -> str: +def url_concat( + url: str, + args: Union[ + None, Dict[str, str], List[Tuple[str, str]], Tuple[Tuple[str, str], ...] + ], +) -> str: """Concatenate url and arguments regardless of whether url has existing query parameters. @@ -612,16 +658,20 @@ def url_concat(url: str, args: Union[None, Dict[str, str], List[Tuple[str, str]] parsed_query.extend(args) else: err = "'args' parameter should be dict, list or tuple. Not {0}".format( - type(args)) + type(args) + ) raise TypeError(err) final_query = urlencode(parsed_query) - url = urlunparse(( - parsed_url[0], - parsed_url[1], - parsed_url[2], - parsed_url[3], - final_query, - parsed_url[5])) + url = urlunparse( + ( + parsed_url[0], + parsed_url[1], + parsed_url[2], + parsed_url[3], + final_query, + parsed_url[5], + ) + ) return url @@ -635,10 +685,13 @@ class HTTPFile(ObjectDict): * ``body`` * ``content_type`` """ + pass -def _parse_request_range(range_header: str) -> Optional[Tuple[Optional[int], Optional[int]]]: +def _parse_request_range( + range_header: str +) -> Optional[Tuple[Optional[int], Optional[int]]]: """Parses a Range header. Returns either ``None`` or tuple ``(start, end)``. @@ -709,8 +762,13 @@ def _int_or_none(val: str) -> Optional[int]: return int(val) -def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, List[bytes]], - files: Dict[str, List[HTTPFile]], headers: HTTPHeaders=None) -> None: +def parse_body_arguments( + content_type: str, + body: bytes, + arguments: Dict[str, List[bytes]], + files: Dict[str, List[HTTPFile]], + headers: HTTPHeaders = None, +) -> None: """Parses a form request body. Supports ``application/x-www-form-urlencoded`` and @@ -719,15 +777,14 @@ def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, Li and ``files`` parameters are dictionaries that will be updated with the parsed contents. """ - if headers and 'Content-Encoding' in headers: - gen_log.warning("Unsupported Content-Encoding: %s", - headers['Content-Encoding']) + if headers and "Content-Encoding" in headers: + gen_log.warning("Unsupported Content-Encoding: %s", headers["Content-Encoding"]) return if content_type.startswith("application/x-www-form-urlencoded"): try: uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True) except Exception as e: - gen_log.warning('Invalid x-www-form-urlencoded body: %s', e) + gen_log.warning("Invalid x-www-form-urlencoded body: %s", e) uri_arguments = {} for name, values in uri_arguments.items(): if values: @@ -746,8 +803,12 @@ def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, Li gen_log.warning("Invalid multipart/form-data: %s", e) -def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str, List[bytes]], - files: Dict[str, List[HTTPFile]]) -> None: +def parse_multipart_form_data( + boundary: bytes, + data: bytes, + arguments: Dict[str, List[bytes]], + files: Dict[str, List[HTTPFile]], +) -> None: """Parses a ``multipart/form-data`` body. The ``boundary`` and ``data`` parameters are both byte strings. @@ -784,21 +845,25 @@ def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str, if disposition != "form-data" or not part.endswith(b"\r\n"): gen_log.warning("Invalid multipart/form-data") continue - value = part[eoh + 4:-2] + value = part[eoh + 4 : -2] if not disp_params.get("name"): gen_log.warning("multipart/form-data value missing name") continue name = disp_params["name"] if disp_params.get("filename"): ctype = headers.get("Content-Type", "application/unknown") - files.setdefault(name, []).append(HTTPFile( - filename=disp_params["filename"], body=value, - content_type=ctype)) + files.setdefault(name, []).append( + HTTPFile( + filename=disp_params["filename"], body=value, content_type=ctype + ) + ) else: arguments.setdefault(name, []).append(value) -def format_timestamp(ts: Union[int, float, tuple, time.struct_time, datetime.datetime]) -> str: +def format_timestamp( + ts: Union[int, float, tuple, time.struct_time, datetime.datetime] +) -> str: """Formats a timestamp in the format used by HTTP. The argument may be a numeric timestamp as returned by `time.time`, @@ -820,7 +885,8 @@ def format_timestamp(ts: Union[int, float, tuple, time.struct_time, datetime.dat RequestStartLine = collections.namedtuple( - 'RequestStartLine', ['method', 'path', 'version']) + "RequestStartLine", ["method", "path", "version"] +) def parse_request_start_line(line: str) -> RequestStartLine: @@ -839,12 +905,14 @@ def parse_request_start_line(line: str) -> RequestStartLine: raise HTTPInputError("Malformed HTTP request line") if not re.match(r"^HTTP/1\.[0-9]$", version): raise HTTPInputError( - "Malformed HTTP version in HTTP Request-Line: %r" % version) + "Malformed HTTP version in HTTP Request-Line: %r" % version + ) return RequestStartLine(method, path, version) ResponseStartLine = collections.namedtuple( - 'ResponseStartLine', ['version', 'code', 'reason']) + "ResponseStartLine", ["version", "code", "reason"] +) def parse_response_start_line(line: str) -> ResponseStartLine: @@ -859,8 +927,8 @@ def parse_response_start_line(line: str) -> ResponseStartLine: match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line) if not match: raise HTTPInputError("Error parsing response start line") - return ResponseStartLine(match.group(1), int(match.group(2)), - match.group(3)) + return ResponseStartLine(match.group(1), int(match.group(2)), match.group(3)) + # _parseparam and _parse_header are copied and modified from python2.7's cgi.py # The original 2.7 version of this code did not correctly support some @@ -871,11 +939,11 @@ def parse_response_start_line(line: str) -> ResponseStartLine: def _parseparam(s: str) -> Generator[str, None, None]: - while s[:1] == ';': + while s[:1] == ";": s = s[1:] - end = s.find(';') + end = s.find(";") while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2: - end = s.find(';', end + 1) + end = s.find(";", end + 1) if end < 0: end = len(s) f = s[:end] @@ -897,15 +965,15 @@ def _parse_header(line: str) -> Tuple[str, Dict[str, str]]: >>> d['foo'] 'b\\a"r' """ - parts = _parseparam(';' + line) + parts = _parseparam(";" + line) key = next(parts) # decode_params treats first argument special, but we already stripped key - params = [('Dummy', 'value')] + params = [("Dummy", "value")] for p in parts: - i = p.find('=') + i = p.find("=") if i >= 0: name = p[:i].strip().lower() - value = p[i + 1:].strip() + value = p[i + 1 :].strip() params.append((name, native_str(value))) decoded_params = email.utils.decode_params(params) decoded_params.pop(0) # get rid of the dummy again @@ -934,11 +1002,13 @@ def _encode_header(key: str, pdict: Dict[str, str]) -> str: out.append(k) else: # TODO: quote if necessary. - out.append('%s=%s' % (k, v)) - return '; '.join(out) + out.append("%s=%s" % (k, v)) + return "; ".join(out) -def encode_username_password(username: Union[str, bytes], password: Union[str, bytes]) -> bytes: +def encode_username_password( + username: Union[str, bytes], password: Union[str, bytes] +) -> bytes: """Encodes a username/password pair in the format used by HTTP auth. The return value is a byte string in the form ``username:password``. @@ -946,15 +1016,16 @@ def encode_username_password(username: Union[str, bytes], password: Union[str, b .. versionadded:: 5.1 """ if isinstance(username, unicode_type): - username = unicodedata.normalize('NFC', username) + username = unicodedata.normalize("NFC", username) if isinstance(password, unicode_type): - password = unicodedata.normalize('NFC', password) + password = unicodedata.normalize("NFC", password) return utf8(username) + b":" + utf8(password) def doctests(): # type: () -> unittest.TestSuite import doctest + return doctest.DocTestSuite() @@ -965,7 +1036,7 @@ def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]: .. versionadded:: 4.1 """ - match = re.match(r'^(.+):(\d+)$', netloc) + match = re.match(r"^(.+):(\d+)$", netloc) if match: host = match.group(1) port = int(match.group(2)) # type: Optional[int] @@ -987,7 +1058,7 @@ def qs_to_qsl(qs: Dict[str, List[AnyStr]]) -> Iterable[Tuple[str, AnyStr]]: _OctalPatt = re.compile(r"\\[0-3][0-7][0-7]") _QuotePatt = re.compile(r"[\\].") -_nulljoin = ''.join +_nulljoin = "".join def _unquote_cookie(s: str) -> str: @@ -1020,7 +1091,7 @@ def _unquote_cookie(s: str) -> str: while 0 <= i < n: o_match = _OctalPatt.search(s, i) q_match = _QuotePatt.search(s, i) - if not o_match and not q_match: # Neither matched + if not o_match and not q_match: # Neither matched res.append(s[i:]) break # else: @@ -1029,13 +1100,13 @@ def _unquote_cookie(s: str) -> str: j = o_match.start(0) if q_match: k = q_match.start(0) - if q_match and (not o_match or k < j): # QuotePatt matched + if q_match and (not o_match or k < j): # QuotePatt matched res.append(s[i:k]) res.append(s[k + 1]) i = k + 2 - else: # OctalPatt matched + else: # OctalPatt matched res.append(s[i:j]) - res.append(chr(int(s[j + 1:j + 4], 8))) + res.append(chr(int(s[j + 1 : j + 4], 8))) i = j + 4 return _nulljoin(res) @@ -1052,13 +1123,13 @@ def parse_cookie(cookie: str) -> Dict[str, str]: .. versionadded:: 4.4.2 """ cookiedict = {} - for chunk in cookie.split(str(';')): - if str('=') in chunk: - key, val = chunk.split(str('='), 1) + for chunk in cookie.split(str(";")): + if str("=") in chunk: + key, val = chunk.split(str("="), 1) else: # Assume an empty name per # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 - key, val = str(''), chunk + key, val = str(""), chunk key, val = key.strip(), val.strip() if key or val: # unquote using Python's algorithm. diff --git a/tornado/ioloop.py b/tornado/ioloop.py index 3078d26ac..84c04fcd2 100644 --- a/tornado/ioloop.py +++ b/tornado/ioloop.py @@ -41,12 +41,19 @@ import time import math import random -from tornado.concurrent import Future, is_future, chain_future, future_set_exc_info, future_add_done_callback # noqa: E501 +from tornado.concurrent import ( + Future, + is_future, + chain_future, + future_set_exc_info, + future_add_done_callback, +) # noqa: E501 from tornado.log import app_log from tornado.util import Configurable, TimeoutError, import_object import typing from typing import Union, Any, Type, Optional, Callable, TypeVar, Tuple, Awaitable + if typing.TYPE_CHECKING: from typing import Dict, List # noqa: F401 @@ -63,8 +70,8 @@ class _Selectable(Protocol): pass -_T = TypeVar('_T') -_S = TypeVar('_S', bound=_Selectable) +_T = TypeVar("_T") +_S = TypeVar("_S", bound=_Selectable) class IOLoop(Configurable): @@ -149,6 +156,7 @@ class IOLoop(Configurable): to redundantly specify the `asyncio` event loop. """ + # These constants were originally based on constants from the epoll module. NONE = 0 READ = 0x001 @@ -159,7 +167,9 @@ class IOLoop(Configurable): _ioloop_for_asyncio = dict() # type: Dict[asyncio.AbstractEventLoop, IOLoop] @classmethod - def configure(cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any) -> None: + def configure( + cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any + ) -> None: if asyncio is not None: from tornado.platform.asyncio import BaseAsyncIOLoop @@ -167,11 +177,12 @@ class IOLoop(Configurable): impl = import_object(impl) if isinstance(impl, type) and not issubclass(impl, BaseAsyncIOLoop): raise RuntimeError( - "only AsyncIOLoop is allowed when asyncio is available") + "only AsyncIOLoop is allowed when asyncio is available" + ) super(IOLoop, cls).configure(impl, **kwargs) @staticmethod - def instance() -> 'IOLoop': + def instance() -> "IOLoop": """Deprecated alias for `IOLoop.current()`. .. versionchanged:: 5.0 @@ -224,16 +235,16 @@ class IOLoop(Configurable): @typing.overload @staticmethod - def current() -> 'IOLoop': + def current() -> "IOLoop": pass @typing.overload # noqa: F811 @staticmethod - def current(instance: bool=True) -> Optional['IOLoop']: + def current(instance: bool = True) -> Optional["IOLoop"]: pass @staticmethod # noqa: F811 - def current(instance: bool=True) -> Optional['IOLoop']: + def current(instance: bool = True) -> Optional["IOLoop"]: """Returns the current thread's `IOLoop`. If an `IOLoop` is currently running or has been marked as @@ -264,6 +275,7 @@ class IOLoop(Configurable): except KeyError: if instance: from tornado.platform.asyncio import AsyncIOMainLoop + current = AsyncIOMainLoop(make_current=True) # type: Optional[IOLoop] else: current = None @@ -317,9 +329,10 @@ class IOLoop(Configurable): @classmethod def configurable_default(cls) -> Type[Configurable]: from tornado.platform.asyncio import AsyncIOLoop + return AsyncIOLoop - def initialize(self, make_current: bool=None) -> None: + def initialize(self, make_current: bool = None) -> None: if make_current is None: if IOLoop.current(instance=False) is None: self.make_current() @@ -330,7 +343,7 @@ class IOLoop(Configurable): raise RuntimeError("current IOLoop already exists") self.make_current() - def close(self, all_fds: bool=False) -> None: + def close(self, all_fds: bool = False) -> None: """Closes the `IOLoop`, freeing any resources used. If ``all_fds`` is true, all file descriptors registered on the @@ -358,15 +371,20 @@ class IOLoop(Configurable): raise NotImplementedError() @typing.overload - def add_handler(self, fd: int, handler: Callable[[int, int], None], events: int) -> None: + def add_handler( + self, fd: int, handler: Callable[[int, int], None], events: int + ) -> None: pass @typing.overload # noqa: F811 - def add_handler(self, fd: _S, handler: Callable[[_S, int], None], events: int) -> None: + def add_handler( + self, fd: _S, handler: Callable[[_S, int], None], events: int + ) -> None: pass - def add_handler(self, fd: Union[int, _Selectable], # noqa: F811 - handler: Callable[..., None], events: int) -> None: + def add_handler( # noqa: F811 + self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int + ) -> None: """Registers the given handler to receive the given events for ``fd``. The ``fd`` argument may either be an integer file descriptor or @@ -420,9 +438,13 @@ class IOLoop(Configurable): This method should be called from start() in subclasses. """ - if not any([logging.getLogger().handlers, - logging.getLogger('tornado').handlers, - logging.getLogger('tornado.application').handlers]): + if not any( + [ + logging.getLogger().handlers, + logging.getLogger("tornado").handlers, + logging.getLogger("tornado.application").handlers, + ] + ): logging.basicConfig() def stop(self) -> None: @@ -438,7 +460,7 @@ class IOLoop(Configurable): """ raise NotImplementedError() - def run_sync(self, func: Callable, timeout: float=None) -> Any: + def run_sync(self, func: Callable, timeout: float = None) -> Any: """Starts the `IOLoop`, runs the given function, and stops the loop. The function must return either an awaitable object or @@ -475,6 +497,7 @@ class IOLoop(Configurable): result = func() if result is not None: from tornado.gen import convert_yielded + result = convert_yielded(result) except Exception: fut = Future() # type: Future[Any] @@ -489,8 +512,10 @@ class IOLoop(Configurable): fut.set_result(result) assert future_cell[0] is not None self.add_future(future_cell[0], lambda future: self.stop()) + self.add_callback(run) if timeout is not None: + def timeout_callback() -> None: # If we can cancel the future, do so and wait on it. If not, # Just stop the loop and return with the task still pending. @@ -499,13 +524,14 @@ class IOLoop(Configurable): assert future_cell[0] is not None if not future_cell[0].cancel(): self.stop() + timeout_handle = self.add_timeout(self.time() + timeout, timeout_callback) self.start() if timeout is not None: self.remove_timeout(timeout_handle) assert future_cell[0] is not None if future_cell[0].cancelled() or not future_cell[0].done(): - raise TimeoutError('Operation timed out after %s seconds' % timeout) + raise TimeoutError("Operation timed out after %s seconds" % timeout) return future_cell[0].result() def time(self) -> float: @@ -523,9 +549,13 @@ class IOLoop(Configurable): """ return time.time() - def add_timeout(self, deadline: Union[float, datetime.timedelta], - callback: Callable[..., None], - *args: Any, **kwargs: Any) -> object: + def add_timeout( + self, + deadline: Union[float, datetime.timedelta], + callback: Callable[..., None], + *args: Any, + **kwargs: Any + ) -> object: """Runs the ``callback`` at the time ``deadline`` from the I/O loop. Returns an opaque handle that may be passed to @@ -554,13 +584,15 @@ class IOLoop(Configurable): if isinstance(deadline, numbers.Real): return self.call_at(deadline, callback, *args, **kwargs) elif isinstance(deadline, datetime.timedelta): - return self.call_at(self.time() + deadline.total_seconds(), - callback, *args, **kwargs) + return self.call_at( + self.time() + deadline.total_seconds(), callback, *args, **kwargs + ) else: raise TypeError("Unsupported deadline %r" % deadline) - def call_later(self, delay: float, callback: Callable[..., None], - *args: Any, **kwargs: Any) -> object: + def call_later( + self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> object: """Runs the ``callback`` after ``delay`` seconds have passed. Returns an opaque handle that may be passed to `remove_timeout` @@ -573,8 +605,9 @@ class IOLoop(Configurable): """ return self.call_at(self.time() + delay, callback, *args, **kwargs) - def call_at(self, when: float, callback: Callable[..., None], - *args: Any, **kwargs: Any) -> object: + def call_at( + self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> object: """Runs the ``callback`` at the absolute time designated by ``when``. ``when`` must be a number using the same reference point as @@ -599,8 +632,7 @@ class IOLoop(Configurable): """ raise NotImplementedError() - def add_callback(self, callback: Callable, - *args: Any, **kwargs: Any) -> None: + def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: """Calls the given callback on the next I/O loop iteration. It is safe to call this method from any thread at any time, @@ -615,8 +647,9 @@ class IOLoop(Configurable): """ raise NotImplementedError() - def add_callback_from_signal(self, callback: Callable, - *args: Any, **kwargs: Any) -> None: + def add_callback_from_signal( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> None: """Calls the given callback on the next I/O loop iteration. Safe for use from a Python signal handler; should not be used @@ -624,8 +657,7 @@ class IOLoop(Configurable): """ raise NotImplementedError() - def spawn_callback(self, callback: Callable, - *args: Any, **kwargs: Any) -> None: + def spawn_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: """Calls the given callback on the next IOLoop iteration. As of Tornado 6.0, this method is equivalent to `add_callback`. @@ -634,8 +666,11 @@ class IOLoop(Configurable): """ self.add_callback(callback, *args, **kwargs) - def add_future(self, future: Union['Future[_T]', 'concurrent.futures.Future[_T]'], - callback: Callable[['Future[_T]'], None]) -> None: + def add_future( + self, + future: Union["Future[_T]", "concurrent.futures.Future[_T]"], + callback: Callable[["Future[_T]"], None], + ) -> None: """Schedules a callback on the ``IOLoop`` when the given `.Future` is finished. @@ -648,10 +683,15 @@ class IOLoop(Configurable): """ assert is_future(future) future_add_done_callback( - future, lambda future: self.add_callback(callback, future)) - - def run_in_executor(self, executor: Optional[concurrent.futures.Executor], - func: Callable[..., _T], *args: Any) -> Awaitable[_T]: + future, lambda future: self.add_callback(callback, future) + ) + + def run_in_executor( + self, + executor: Optional[concurrent.futures.Executor], + func: Callable[..., _T], + *args: Any + ) -> Awaitable[_T]: """Runs a function in a ``concurrent.futures.Executor``. If ``executor`` is ``None``, the IO loop's default executor will be used. @@ -660,10 +700,12 @@ class IOLoop(Configurable): .. versionadded:: 5.0 """ if executor is None: - if not hasattr(self, '_executor'): + if not hasattr(self, "_executor"): from tornado.process import cpu_count + self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=(cpu_count() * 5)) # type: concurrent.futures.Executor + max_workers=(cpu_count() * 5) + ) # type: concurrent.futures.Executor executor = self._executor c_future = executor.submit(func, *args) # Concurrent Futures are not usable with await. Wrap this in a @@ -688,6 +730,7 @@ class IOLoop(Configurable): ret = callback() if ret is not None: from tornado import gen + # Functions that return Futures typically swallow all # exceptions and store them in the Future. If a Future # makes it out to the IOLoop, ensure its exception (if any) @@ -708,7 +751,9 @@ class IOLoop(Configurable): """Avoid unhandled-exception warnings from spawned coroutines.""" future.result() - def split_fd(self, fd: Union[int, _Selectable]) -> Tuple[int, Union[int, _Selectable]]: + def split_fd( + self, fd: Union[int, _Selectable] + ) -> Tuple[int, Union[int, _Selectable]]: """Returns an (fd, obj) pair from an ``fd`` parameter. We accept both raw file descriptors and file-like objects as @@ -753,24 +798,28 @@ class _Timeout(object): """An IOLoop timeout, a UNIX timestamp and a callback""" # Reduce memory overhead when there are lots of pending callbacks - __slots__ = ['deadline', 'callback', 'tdeadline'] + __slots__ = ["deadline", "callback", "tdeadline"] - def __init__(self, deadline: float, callback: Callable[[], None], - io_loop: IOLoop) -> None: + def __init__( + self, deadline: float, callback: Callable[[], None], io_loop: IOLoop + ) -> None: if not isinstance(deadline, numbers.Real): raise TypeError("Unsupported deadline %r" % deadline) self.deadline = deadline self.callback = callback - self.tdeadline = (deadline, next(io_loop._timeout_counter)) # type: Tuple[float, int] + self.tdeadline = ( + deadline, + next(io_loop._timeout_counter), + ) # type: Tuple[float, int] # Comparison methods to sort by deadline, with object id as a tiebreaker # to guarantee a consistent ordering. The heapq module uses __le__ # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons # use __lt__). - def __lt__(self, other: '_Timeout') -> bool: + def __lt__(self, other: "_Timeout") -> bool: return self.tdeadline < other.tdeadline - def __le__(self, other: '_Timeout') -> bool: + def __le__(self, other: "_Timeout") -> bool: return self.tdeadline <= other.tdeadline @@ -800,8 +849,10 @@ class PeriodicCallback(object): .. versionchanged:: 5.1 The ``jitter`` argument is added. """ - def __init__(self, callback: Callable[[], None], - callback_time: float, jitter: float=0) -> None: + + def __init__( + self, callback: Callable[[], None], callback_time: float, jitter: float = 0 + ) -> None: self.callback = callback if callback_time <= 0: raise ValueError("Periodic callback must have a positive callback_time") @@ -859,8 +910,9 @@ class PeriodicCallback(object): # to the start of the next. If one call takes too long, # skip cycles to get back to a multiple of the original # schedule. - self._next_timeout += (math.floor((current_time - self._next_timeout) / - callback_time_sec) + 1) * callback_time_sec + self._next_timeout += ( + math.floor((current_time - self._next_timeout) / callback_time_sec) + 1 + ) * callback_time_sec else: # If the clock moved backwards, ensure we advance the next # timeout instead of recomputing the same value again. diff --git a/tornado/iostream.py b/tornado/iostream.py index aa179dc9a..6d5cae972 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -40,12 +40,24 @@ from tornado.netutil import ssl_wrap_socket, _client_ssl_defaults, _server_ssl_d from tornado.util import errno_from_exception import typing -from typing import Union, Optional, Awaitable, Callable, Type, Pattern, Any, Dict, TypeVar, Tuple +from typing import ( + Union, + Optional, + Awaitable, + Callable, + Type, + Pattern, + Any, + Dict, + TypeVar, + Tuple, +) from types import TracebackType + if typing.TYPE_CHECKING: from typing import Deque, List # noqa: F401 -_IOStreamType = TypeVar('_IOStreamType', bound='IOStream') +_IOStreamType = TypeVar("_IOStreamType", bound="IOStream") try: from tornado.platform.posix import _set_nonblocking @@ -62,13 +74,16 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # These errnos indicate that a connection has been abruptly terminated. # They should be caught and handled less noisily than other errors. -_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, - errno.ETIMEDOUT) +_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, errno.ETIMEDOUT) if hasattr(errno, "WSAECONNRESET"): - _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) # type: ignore # noqa: E501 + _ERRNO_CONNRESET += ( # type: ignore + errno.WSAECONNRESET, # type: ignore + errno.WSAECONNABORTED, # type: ignore + errno.WSAETIMEDOUT, # type: ignore + ) -if sys.platform == 'darwin': +if sys.platform == "darwin": # OSX appears to have a race condition that causes send(2) to return # EPROTOTYPE if called while a socket is being torn down: # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/ @@ -82,7 +97,7 @@ _ERRNO_INPROGRESS = (errno.EINPROGRESS,) if hasattr(errno, "WSAEINPROGRESS"): _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) # type: ignore -_WINDOWS = sys.platform.startswith('win') +_WINDOWS = sys.platform.startswith("win") class StreamClosedError(IOError): @@ -98,8 +113,9 @@ class StreamClosedError(IOError): .. versionchanged:: 4.3 Added the ``real_error`` attribute. """ - def __init__(self, real_error: BaseException=None) -> None: - super(StreamClosedError, self).__init__('Stream is closed') + + def __init__(self, real_error: BaseException = None) -> None: + super(StreamClosedError, self).__init__("Stream is closed") self.real_error = real_error @@ -109,6 +125,7 @@ class UnsatisfiableReadError(Exception): Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes`` argument. """ + pass @@ -125,8 +142,9 @@ class _StreamBuffer(object): def __init__(self) -> None: # A sequence of (False, bytearray) and (True, memoryview) objects - self._buffers = collections.deque() \ - # type: Deque[Tuple[bool, Union[bytearray, memoryview]]] + self._buffers = ( + collections.deque() + ) # type: Deque[Tuple[bool, Union[bytearray, memoryview]]] # Position in the first buffer self._first_pos = 0 self._size = 0 @@ -169,13 +187,13 @@ class _StreamBuffer(object): try: is_memview, b = self._buffers[0] except IndexError: - return memoryview(b'') + return memoryview(b"") pos = self._first_pos if is_memview: - return typing.cast(memoryview, b[pos:pos + size]) + return typing.cast(memoryview, b[pos : pos + size]) else: - return memoryview(b)[pos:pos + size] + return memoryview(b)[pos : pos + size] def advance(self, size: int) -> None: """ @@ -225,8 +243,13 @@ class BaseIOStream(object): `read_from_fd`, and optionally `get_fd_error`. """ - def __init__(self, max_buffer_size: int=None, - read_chunk_size: int=None, max_write_buffer_size: int=None) -> None: + + def __init__( + self, + max_buffer_size: int = None, + read_chunk_size: int = None, + max_write_buffer_size: int = None, + ) -> None: """`BaseIOStream` constructor. :arg max_buffer_size: Maximum amount of incoming data to buffer; @@ -247,8 +270,7 @@ class BaseIOStream(object): self.max_buffer_size = max_buffer_size or 104857600 # A chunk size that is too close to max_buffer_size can cause # spurious failures. - self.read_chunk_size = min(read_chunk_size or 65536, - self.max_buffer_size // 2) + self.read_chunk_size = min(read_chunk_size or 65536, self.max_buffer_size // 2) self.max_write_buffer_size = max_write_buffer_size self.error = None # type: Optional[BaseException] self._read_buffer = bytearray() @@ -266,7 +288,9 @@ class BaseIOStream(object): self._read_partial = False self._read_until_close = False self._read_future = None # type: Optional[Future] - self._write_futures = collections.deque() # type: Deque[Tuple[int, Future[None]]] + self._write_futures = ( + collections.deque() + ) # type: Deque[Tuple[int, Future[None]]] self._close_callback = None # type: Optional[Callable[[], None]] self._connect_future = None # type: Optional[Future[IOStream]] # _ssl_connect_future should be defined in SSLIOStream @@ -322,7 +346,7 @@ class BaseIOStream(object): """ return None - def read_until_regex(self, regex: bytes, max_bytes: int=None) -> Awaitable[bytes]: + def read_until_regex(self, regex: bytes, max_bytes: int = None) -> Awaitable[bytes]: """Asynchronously read until we have matched the given regex. The result includes the data that matches the regex and anything @@ -359,7 +383,7 @@ class BaseIOStream(object): raise return future - def read_until(self, delimiter: bytes, max_bytes: int=None) -> Awaitable[bytes]: + def read_until(self, delimiter: bytes, max_bytes: int = None) -> Awaitable[bytes]: """Asynchronously read until we have found the given delimiter. The result includes all the data read including the delimiter. @@ -392,7 +416,7 @@ class BaseIOStream(object): raise return future - def read_bytes(self, num_bytes: int, partial: bool=False) -> Awaitable[bytes]: + def read_bytes(self, num_bytes: int, partial: bool = False) -> Awaitable[bytes]: """Asynchronously read a number of bytes. If ``partial`` is true, data is returned as soon as we have @@ -420,7 +444,7 @@ class BaseIOStream(object): raise return future - def read_into(self, buf: bytearray, partial: bool=False) -> Awaitable[int]: + def read_into(self, buf: bytearray, partial: bool = False) -> Awaitable[int]: """Asynchronously read a number of bytes. ``buf`` must be a writable buffer into which data will be read. @@ -444,11 +468,13 @@ class BaseIOStream(object): n = len(buf) if available_bytes >= n: end = self._read_buffer_pos + n - buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end] + buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos : end] del self._read_buffer[:end] self._after_user_read_buffer = self._read_buffer elif available_bytes > 0: - buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:] + buf[:available_bytes] = memoryview(self._read_buffer)[ + self._read_buffer_pos : + ] # Set up the supplied buffer as our temporary read buffer. # The original (if it had any data remaining) has been @@ -497,7 +523,7 @@ class BaseIOStream(object): raise return future - def write(self, data: Union[bytes, memoryview]) -> 'Future[None]': + def write(self, data: Union[bytes, memoryview]) -> "Future[None]": """Asynchronously write the given data to this stream. This method returns a `.Future` that resolves (with a result @@ -519,8 +545,10 @@ class BaseIOStream(object): """ self._check_closed() if data: - if (self.max_write_buffer_size is not None and - len(self._write_buffer) + len(data) > self.max_write_buffer_size): + if ( + self.max_write_buffer_size is not None + and len(self._write_buffer) + len(data) > self.max_write_buffer_size + ): raise StreamBufferFullError("Reached maximum write buffer size") self._write_buffer.append(data) self._total_write_index += len(data) @@ -549,10 +577,19 @@ class BaseIOStream(object): self._close_callback = callback self._maybe_add_error_listener() - def close(self, exc_info: Union[None, bool, BaseException, - Tuple[Optional[Type[BaseException]], - Optional[BaseException], - Optional[TracebackType]]]=False) -> None: + def close( + self, + exc_info: Union[ + None, + bool, + BaseException, + Tuple[ + Optional[Type[BaseException]], + Optional[BaseException], + Optional[TracebackType], + ], + ] = False, + ) -> None: """Close this stream. If ``exc_info`` is true, set the ``error`` attribute to the current @@ -680,16 +717,16 @@ class BaseIOStream(object): # yet anyway, so we don't need to listen in this case. state |= self.io_loop.READ if state != self._state: - assert self._state is not None, \ - "shouldn't happen: _handle_events without self._state" + assert ( + self._state is not None + ), "shouldn't happen: _handle_events without self._state" self._state = state self.io_loop.update_handler(self.fileno(), self._state) except UnsatisfiableReadError as e: gen_log.info("Unsatisfiable read, closing connection: %s" % e) self.close(exc_info=e) except Exception as e: - gen_log.error("Uncaught exception, closing connection.", - exc_info=True) + gen_log.error("Uncaught exception, closing connection.", exc_info=True) self.close(exc_info=e) raise @@ -720,8 +757,7 @@ class BaseIOStream(object): # this loop. # If we've reached target_bytes, we know we're done. - if (target_bytes is not None and - self._read_buffer_size >= target_bytes): + if target_bytes is not None and self._read_buffer_size >= target_bytes: break # Otherwise, we need to call the more expensive find_read_pos. @@ -801,8 +837,9 @@ class BaseIOStream(object): while True: try: if self._user_read_buffer: - buf = memoryview(self._read_buffer)[self._read_buffer_size:] \ - # type: Union[memoryview, bytearray] + buf = memoryview(self._read_buffer)[ + self._read_buffer_size : + ] # type: Union[memoryview, bytearray] else: buf = bytearray(self.read_chunk_size) bytes_read = self.read_from_fd(buf) @@ -854,9 +891,10 @@ class BaseIOStream(object): Returns a position in the buffer if the current read can be satisfied, or None if it cannot. """ - if (self._read_bytes is not None and - (self._read_buffer_size >= self._read_bytes or - (self._read_partial and self._read_buffer_size > 0))): + if self._read_bytes is not None and ( + self._read_buffer_size >= self._read_bytes + or (self._read_partial and self._read_buffer_size > 0) + ): num_bytes = min(self._read_bytes, self._read_buffer_size) return num_bytes elif self._read_delimiter is not None: @@ -869,20 +907,18 @@ class BaseIOStream(object): # since large merges are relatively expensive and get undone in # _consume(). if self._read_buffer: - loc = self._read_buffer.find(self._read_delimiter, - self._read_buffer_pos) + loc = self._read_buffer.find( + self._read_delimiter, self._read_buffer_pos + ) if loc != -1: loc -= self._read_buffer_pos delimiter_len = len(self._read_delimiter) - self._check_max_bytes(self._read_delimiter, - loc + delimiter_len) + self._check_max_bytes(self._read_delimiter, loc + delimiter_len) return loc + delimiter_len - self._check_max_bytes(self._read_delimiter, - self._read_buffer_size) + self._check_max_bytes(self._read_delimiter, self._read_buffer_size) elif self._read_regex is not None: if self._read_buffer: - m = self._read_regex.search(self._read_buffer, - self._read_buffer_pos) + m = self._read_regex.search(self._read_buffer, self._read_buffer_pos) if m is not None: loc = m.end() - self._read_buffer_pos self._check_max_bytes(self._read_regex, loc) @@ -891,11 +927,11 @@ class BaseIOStream(object): return None def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None: - if (self._read_max_bytes is not None and - size > self._read_max_bytes): + if self._read_max_bytes is not None and size > self._read_max_bytes: raise UnsatisfiableReadError( - "delimiter %r not found within %d bytes" % ( - delimiter, self._read_max_bytes)) + "delimiter %r not found within %d bytes" + % (delimiter, self._read_max_bytes) + ) def _handle_write(self) -> None: while True: @@ -925,8 +961,7 @@ class BaseIOStream(object): # Broken pipe errors are usually caused by connection # reset, and its better to not log EPIPE errors to # minimize log spam - gen_log.warning("Write error on %s: %s", - self.fileno(), e) + gen_log.warning("Write error on %s: %s", self.fileno(), e) self.close(exc_info=e) return @@ -943,16 +978,18 @@ class BaseIOStream(object): return b"" assert loc <= self._read_buffer_size # Slice the bytearray buffer into bytes, without intermediate copying - b = (memoryview(self._read_buffer) - [self._read_buffer_pos:self._read_buffer_pos + loc] - ).tobytes() + b = ( + memoryview(self._read_buffer)[ + self._read_buffer_pos : self._read_buffer_pos + loc + ] + ).tobytes() self._read_buffer_pos += loc self._read_buffer_size -= loc # Amortized O(1) shrink # (this heuristic is implemented natively in Python 3.4+ # but is replicated here for Python 2) if self._read_buffer_pos > self._read_buffer_size: - del self._read_buffer[:self._read_buffer_pos] + del self._read_buffer[: self._read_buffer_pos] self._read_buffer_pos = 0 return b @@ -968,9 +1005,11 @@ class BaseIOStream(object): # immediately anyway. Instead, we insert checks at various times to # see if the connection is idle and add the read listener then. if self._state is None or self._state == ioloop.IOLoop.ERROR: - if (not self.closed() and - self._read_buffer_size == 0 and - self._close_callback is not None): + if ( + not self.closed() + and self._read_buffer_size == 0 + and self._close_callback is not None + ): self._add_io_state(ioloop.IOLoop.READ) def _add_io_state(self, state: int) -> None: @@ -997,8 +1036,7 @@ class BaseIOStream(object): return if self._state is None: self._state = ioloop.IOLoop.ERROR | state - self.io_loop.add_handler( - self.fileno(), self._handle_events, self._state) + self.io_loop.add_handler(self.fileno(), self._handle_events, self._state) elif not self._state & state: self._state = self._state | state self.io_loop.update_handler(self.fileno(), self._state) @@ -1008,8 +1046,10 @@ class BaseIOStream(object): May be overridden in subclasses. """ - return (isinstance(exc, (socket.error, IOError)) and - errno_from_exception(exc) in _ERRNO_CONNRESET) + return ( + isinstance(exc, (socket.error, IOError)) + and errno_from_exception(exc) in _ERRNO_CONNRESET + ) class IOStream(BaseIOStream): @@ -1059,6 +1099,7 @@ class IOStream(BaseIOStream): :hide: """ + def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None: self.socket = socket self.socket.setblocking(False) @@ -1072,8 +1113,7 @@ class IOStream(BaseIOStream): self.socket = None # type: ignore def get_fd_error(self) -> Optional[Exception]: - errno = self.socket.getsockopt(socket.SOL_SOCKET, - socket.SO_ERROR) + errno = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) return socket.error(errno, os.strerror(errno)) def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: @@ -1095,8 +1135,9 @@ class IOStream(BaseIOStream): # See https://github.com/tornadoweb/tornado/pull/2008 del data - def connect(self: _IOStreamType, address: tuple, - server_hostname: str=None) -> 'Future[_IOStreamType]': + def connect( + self: _IOStreamType, address: tuple, server_hostname: str = None + ) -> "Future[_IOStreamType]": """Connects the socket to a remote address without blocking. May only be called if the socket passed to the constructor was @@ -1143,7 +1184,7 @@ class IOStream(BaseIOStream): """ self._connecting = True future = Future() # type: Future[_IOStreamType] - self._connect_future = typing.cast('Future[IOStream]', future) + self._connect_future = typing.cast("Future[IOStream]", future) try: self.socket.connect(address) except socket.error as e: @@ -1154,19 +1195,25 @@ class IOStream(BaseIOStream): # returned immediately when attempting to connect to # localhost, so handle them the same way as an error # reported later in _handle_connect. - if (errno_from_exception(e) not in _ERRNO_INPROGRESS and - errno_from_exception(e) not in _ERRNO_WOULDBLOCK): + if ( + errno_from_exception(e) not in _ERRNO_INPROGRESS + and errno_from_exception(e) not in _ERRNO_WOULDBLOCK + ): if future is None: - gen_log.warning("Connect error on fd %s: %s", - self.socket.fileno(), e) + gen_log.warning( + "Connect error on fd %s: %s", self.socket.fileno(), e + ) self.close(exc_info=e) return future self._add_io_state(self.io_loop.WRITE) return future - def start_tls(self, server_side: bool, - ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, - server_hostname: str=None) -> Awaitable['SSLIOStream']: + def start_tls( + self, + server_side: bool, + ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, + server_hostname: str = None, + ) -> Awaitable["SSLIOStream"]: """Convert this `IOStream` to an `SSLIOStream`. This enables protocols that begin in clear-text mode and @@ -1201,11 +1248,14 @@ class IOStream(BaseIOStream): ``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a suitably-configured `ssl.SSLContext` to disable. """ - if (self._read_future or - self._write_futures or - self._connect_future or - self._closed or - self._read_buffer or self._write_buffer): + if ( + self._read_future + or self._write_futures + or self._connect_future + or self._closed + or self._read_buffer + or self._write_buffer + ): raise ValueError("IOStream is not idle; cannot convert to SSL") if ssl_options is None: if server_side: @@ -1216,10 +1266,13 @@ class IOStream(BaseIOStream): socket = self.socket self.io_loop.remove_handler(socket) self.socket = None # type: ignore - socket = ssl_wrap_socket(socket, ssl_options, - server_hostname=server_hostname, - server_side=server_side, - do_handshake_on_connect=False) + socket = ssl_wrap_socket( + socket, + ssl_options, + server_hostname=server_hostname, + server_side=server_side, + do_handshake_on_connect=False, + ) orig_close_callback = self._close_callback self._close_callback = None @@ -1246,8 +1299,11 @@ class IOStream(BaseIOStream): # in that case a connection failure would be handled by the # error path in _handle_events instead of here. if self._connect_future is None: - gen_log.warning("Connect error on fd %s: %s", - self.socket.fileno(), errno.errorcode[err]) + gen_log.warning( + "Connect error on fd %s: %s", + self.socket.fileno(), + errno.errorcode[err], + ) self.close() return if self._connect_future is not None: @@ -1257,11 +1313,14 @@ class IOStream(BaseIOStream): self._connecting = False def set_nodelay(self, value: bool) -> None: - if (self.socket is not None and - self.socket.family in (socket.AF_INET, socket.AF_INET6)): + if self.socket is not None and self.socket.family in ( + socket.AF_INET, + socket.AF_INET6, + ): try: - self.socket.setsockopt(socket.IPPROTO_TCP, - socket.TCP_NODELAY, 1 if value else 0) + self.socket.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 if value else 0 + ) except socket.error as e: # Sometimes setsockopt will fail if the socket is closed # at the wrong time. This can happen with HTTPServer @@ -1281,6 +1340,7 @@ class SSLIOStream(IOStream): before constructing the `SSLIOStream`. Unconnected sockets will be wrapped when `IOStream.connect` is finished. """ + socket = None # type: ssl.SSLSocket def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -1288,7 +1348,7 @@ class SSLIOStream(IOStream): `ssl.SSLContext` object or a dictionary of keywords arguments for `ssl.wrap_socket` """ - self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults) + self._ssl_options = kwargs.pop("ssl_options", _client_ssl_defaults) super(SSLIOStream, self).__init__(*args, **kwargs) self._ssl_accepting = True self._handshake_reading = False @@ -1325,16 +1385,16 @@ class SSLIOStream(IOStream): elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: self._handshake_writing = True return - elif err.args[0] in (ssl.SSL_ERROR_EOF, - ssl.SSL_ERROR_ZERO_RETURN): + elif err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): return self.close(exc_info=err) elif err.args[0] == ssl.SSL_ERROR_SSL: try: peer = self.socket.getpeername() except Exception: - peer = '(not connected)' - gen_log.warning("SSL Error on %s %s: %s", - self.socket.fileno(), peer, err) + peer = "(not connected)" + gen_log.warning( + "SSL Error on %s %s: %s", self.socket.fileno(), peer, err + ) return self.close(exc_info=err) raise except socket.error as err: @@ -1342,8 +1402,7 @@ class SSLIOStream(IOStream): # to cause do_handshake to raise EBADF and ENOTCONN, so make # those errors quiet as well. # https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0 - if (self._is_connreset(err) or - err.args[0] in (errno.EBADF, errno.ENOTCONN)): + if self._is_connreset(err) or err.args[0] in (errno.EBADF, errno.ENOTCONN): return self.close(exc_info=err) raise except AttributeError as err: @@ -1373,7 +1432,7 @@ class SSLIOStream(IOStream): the hostname. """ if isinstance(self._ssl_options, dict): - verify_mode = self._ssl_options.get('cert_reqs', ssl.CERT_NONE) + verify_mode = self._ssl_options.get("cert_reqs", ssl.CERT_NONE) elif isinstance(self._ssl_options, ssl.SSLContext): verify_mode = self._ssl_options.verify_mode assert verify_mode in (ssl.CERT_NONE, ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL) @@ -1403,7 +1462,9 @@ class SSLIOStream(IOStream): return super(SSLIOStream, self)._handle_write() - def connect(self, address: Tuple, server_hostname: str=None) -> 'Future[SSLIOStream]': + def connect( + self, address: Tuple, server_hostname: str = None + ) -> "Future[SSLIOStream]": self._server_hostname = server_hostname # Ignore the result of connect(). If it fails, # wait_for_handshake will raise an error too. This is @@ -1439,12 +1500,15 @@ class SSLIOStream(IOStream): old_state = self._state assert old_state is not None self._state = None - self.socket = ssl_wrap_socket(self.socket, self._ssl_options, - server_hostname=self._server_hostname, - do_handshake_on_connect=False) + self.socket = ssl_wrap_socket( + self.socket, + self._ssl_options, + server_hostname=self._server_hostname, + do_handshake_on_connect=False, + ) self._add_io_state(old_state) - def wait_for_handshake(self) -> 'Future[SSLIOStream]': + def wait_for_handshake(self) -> "Future[SSLIOStream]": """Wait for the initial SSL handshake to complete. If a ``callback`` is given, it will be called with no @@ -1532,6 +1596,7 @@ class PipeIOStream(BaseIOStream): one-way, so a `PipeIOStream` can be used for reading or writing but not both. """ + def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None: self.fd = fd self._fio = io.FileIO(self.fd, "r+") @@ -1569,4 +1634,5 @@ class PipeIOStream(BaseIOStream): def doctests() -> Any: import doctest + return doctest.DocTestSuite() diff --git a/tornado/locale.py b/tornado/locale.py index ae714af8b..85b0b2f08 100644 --- a/tornado/locale.py +++ b/tornado/locale.py @@ -60,7 +60,7 @@ _use_gettext = False CONTEXT_SEPARATOR = "\x04" -def get(*locale_codes: str) -> 'Locale': +def get(*locale_codes: str) -> "Locale": """Returns the closest match for the given locale codes. We iterate over all given locale codes in order. If we have a tight @@ -88,7 +88,7 @@ def set_default_locale(code: str) -> None: _supported_locales = frozenset(list(_translations.keys()) + [_default_locale]) -def load_translations(directory: str, encoding: str=None) -> None: +def load_translations(directory: str, encoding: str = None) -> None: """Loads translations from CSV files in a directory. Translations are strings with optional Python-style named placeholders @@ -131,21 +131,24 @@ def load_translations(directory: str, encoding: str=None) -> None: continue locale, extension = path.split(".") if not re.match("[a-z]+(_[A-Z]+)?$", locale): - gen_log.error("Unrecognized locale %r (path: %s)", locale, - os.path.join(directory, path)) + gen_log.error( + "Unrecognized locale %r (path: %s)", + locale, + os.path.join(directory, path), + ) continue full_path = os.path.join(directory, path) if encoding is None: # Try to autodetect encoding based on the BOM. - with open(full_path, 'rb') as bf: + with open(full_path, "rb") as bf: data = bf.read(len(codecs.BOM_UTF16_LE)) if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE): - encoding = 'utf-16' + encoding = "utf-16" else: # utf-8-sig is "utf-8 with optional BOM". It's discouraged # in most cases but is common with CSV files because Excel # cannot read utf-8 files without a BOM. - encoding = 'utf-8-sig' + encoding = "utf-8-sig" # python 3: csv.reader requires a file open in text mode. # Specify an encoding to avoid dependence on $LANG environment variable. f = open(full_path, "r", encoding=encoding) @@ -160,8 +163,12 @@ def load_translations(directory: str, encoding: str=None) -> None: else: plural = "unknown" if plural not in ("plural", "singular", "unknown"): - gen_log.error("Unrecognized plural indicator %r in %s line %d", - plural, path, i + 1) + gen_log.error( + "Unrecognized plural indicator %r in %s line %d", + plural, + path, + i + 1, + ) continue _translations[locale].setdefault(plural, {})[english] = translation f.close() @@ -191,19 +198,21 @@ def load_gettext_translations(directory: str, domain: str) -> None: msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo """ import gettext + global _translations global _supported_locales global _use_gettext _translations = {} for lang in os.listdir(directory): - if lang.startswith('.'): + if lang.startswith("."): continue # skip .svn, etc if os.path.isfile(os.path.join(directory, lang)): continue try: os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo")) - _translations[lang] = gettext.translation(domain, directory, - languages=[lang]) + _translations[lang] = gettext.translation( + domain, directory, languages=[lang] + ) except Exception as e: gen_log.error("Cannot load translation for '%s': %s", lang, str(e)) continue @@ -223,10 +232,11 @@ class Locale(object): After calling one of `load_translations` or `load_gettext_translations`, call `get` or `get_closest` to get a Locale object. """ + _cache = {} # type: Dict[str, Locale] @classmethod - def get_closest(cls, *locale_codes: str) -> 'Locale': + def get_closest(cls, *locale_codes: str) -> "Locale": """Returns the closest match for the given locale code.""" for code in locale_codes: if not code: @@ -244,7 +254,7 @@ class Locale(object): return cls.get(_default_locale) @classmethod - def get(cls, code: str) -> 'Locale': + def get(cls, code: str) -> "Locale": """Returns the Locale for the given locale code. If it is not supported, we raise an exception. @@ -273,14 +283,32 @@ class Locale(object): # Initialize strings for date formatting _ = self.translate self._months = [ - _("January"), _("February"), _("March"), _("April"), - _("May"), _("June"), _("July"), _("August"), - _("September"), _("October"), _("November"), _("December")] + _("January"), + _("February"), + _("March"), + _("April"), + _("May"), + _("June"), + _("July"), + _("August"), + _("September"), + _("October"), + _("November"), + _("December"), + ] self._weekdays = [ - _("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"), - _("Friday"), _("Saturday"), _("Sunday")] - - def translate(self, message: str, plural_message: str=None, count: int=None) -> str: + _("Monday"), + _("Tuesday"), + _("Wednesday"), + _("Thursday"), + _("Friday"), + _("Saturday"), + _("Sunday"), + ] + + def translate( + self, message: str, plural_message: str = None, count: int = None + ) -> str: """Returns the translation for the given message for this locale. If ``plural_message`` is given, you must also provide @@ -290,12 +318,19 @@ class Locale(object): """ raise NotImplementedError() - def pgettext(self, context: str, message: str, plural_message: str=None, - count: int=None) -> str: + def pgettext( + self, context: str, message: str, plural_message: str = None, count: int = None + ) -> str: raise NotImplementedError() - def format_date(self, date: Union[int, float, datetime.datetime], gmt_offset: int=0, - relative: bool=True, shorter: bool=False, full_format: bool=False) -> str: + def format_date( + self, + date: Union[int, float, datetime.datetime], + gmt_offset: int = 0, + relative: bool = True, + shorter: bool = False, + full_format: bool = False, + ) -> str: """Formats the given date (which should be GMT). By default, we return a relative time (e.g., "2 minutes ago"). You @@ -331,56 +366,66 @@ class Locale(object): if not full_format: if relative and days == 0: if seconds < 50: - return _("1 second ago", "%(seconds)d seconds ago", - seconds) % {"seconds": seconds} + return _("1 second ago", "%(seconds)d seconds ago", seconds) % { + "seconds": seconds + } if seconds < 50 * 60: minutes = round(seconds / 60.0) - return _("1 minute ago", "%(minutes)d minutes ago", - minutes) % {"minutes": minutes} + return _("1 minute ago", "%(minutes)d minutes ago", minutes) % { + "minutes": minutes + } hours = round(seconds / (60.0 * 60)) - return _("1 hour ago", "%(hours)d hours ago", - hours) % {"hours": hours} + return _("1 hour ago", "%(hours)d hours ago", hours) % {"hours": hours} if days == 0: format = _("%(time)s") - elif days == 1 and local_date.day == local_yesterday.day and \ - relative: - format = _("yesterday") if shorter else \ - _("yesterday at %(time)s") + elif days == 1 and local_date.day == local_yesterday.day and relative: + format = _("yesterday") if shorter else _("yesterday at %(time)s") elif days < 5: - format = _("%(weekday)s") if shorter else \ - _("%(weekday)s at %(time)s") + format = _("%(weekday)s") if shorter else _("%(weekday)s at %(time)s") elif days < 334: # 11mo, since confusing for same month last year - format = _("%(month_name)s %(day)s") if shorter else \ - _("%(month_name)s %(day)s at %(time)s") + format = ( + _("%(month_name)s %(day)s") + if shorter + else _("%(month_name)s %(day)s at %(time)s") + ) if format is None: - format = _("%(month_name)s %(day)s, %(year)s") if shorter else \ - _("%(month_name)s %(day)s, %(year)s at %(time)s") + format = ( + _("%(month_name)s %(day)s, %(year)s") + if shorter + else _("%(month_name)s %(day)s, %(year)s at %(time)s") + ) tfhour_clock = self.code not in ("en", "en_US", "zh_CN") if tfhour_clock: str_time = "%d:%02d" % (local_date.hour, local_date.minute) elif self.code == "zh_CN": str_time = "%s%d:%02d" % ( - (u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12], - local_date.hour % 12 or 12, local_date.minute) + (u"\u4e0a\u5348", u"\u4e0b\u5348")[local_date.hour >= 12], + local_date.hour % 12 or 12, + local_date.minute, + ) else: str_time = "%d:%02d %s" % ( - local_date.hour % 12 or 12, local_date.minute, - ("am", "pm")[local_date.hour >= 12]) + local_date.hour % 12 or 12, + local_date.minute, + ("am", "pm")[local_date.hour >= 12], + ) return format % { "month_name": self._months[local_date.month - 1], "weekday": self._weekdays[local_date.weekday()], "day": str(local_date.day), "year": str(local_date.year), - "time": str_time + "time": str_time, } - def format_day(self, date: datetime.datetime, gmt_offset: int=0, dow: bool=True) -> bool: + def format_day( + self, date: datetime.datetime, gmt_offset: int = 0, dow: bool = True + ) -> bool: """Formats the given date as a day of week. Example: "Monday, January 22". You can remove the day of week with @@ -411,7 +456,7 @@ class Locale(object): return "" if len(parts) == 1: return parts[0] - comma = u' \u0648 ' if self.code.startswith("fa") else u", " + comma = u" \u0648 " if self.code.startswith("fa") else u", " return _("%(commas)s and %(last)s") % { "commas": comma.join(parts[:-1]), "last": parts[len(parts) - 1], @@ -431,11 +476,14 @@ class Locale(object): class CSVLocale(Locale): """Locale implementation using tornado's CSV translation format.""" + def __init__(self, code: str, translations: Dict[str, Dict[str, str]]) -> None: self.translations = translations super(CSVLocale, self).__init__(code) - def translate(self, message: str, plural_message: str=None, count: int=None) -> str: + def translate( + self, message: str, plural_message: str = None, count: int = None + ) -> str: if plural_message is not None: assert count is not None if count != 1: @@ -447,15 +495,17 @@ class CSVLocale(Locale): message_dict = self.translations.get("unknown", {}) return message_dict.get(message, message) - def pgettext(self, context: str, message: str, plural_message: str=None, - count: int=None) -> str: + def pgettext( + self, context: str, message: str, plural_message: str = None, count: int = None + ) -> str: if self.translations: - gen_log.warning('pgettext is not supported by CSVLocale') + gen_log.warning("pgettext is not supported by CSVLocale") return self.translate(message, plural_message, count) class GettextLocale(Locale): """Locale implementation using the `gettext` module.""" + def __init__(self, code: str, translations: gettext.NullTranslations) -> None: self.ngettext = translations.ngettext self.gettext = translations.gettext @@ -463,15 +513,18 @@ class GettextLocale(Locale): # calls into self.translate super(GettextLocale, self).__init__(code) - def translate(self, message: str, plural_message: str=None, count: int=None) -> str: + def translate( + self, message: str, plural_message: str = None, count: int = None + ) -> str: if plural_message is not None: assert count is not None return self.ngettext(message, plural_message, count) else: return self.gettext(message) - def pgettext(self, context: str, message: str, plural_message: str=None, - count: int=None) -> str: + def pgettext( + self, context: str, message: str, plural_message: str = None, count: int = None + ) -> str: """Allows to set context for translation, accepts plural forms. Usage example:: @@ -493,9 +546,11 @@ class GettextLocale(Locale): """ if plural_message is not None: assert count is not None - msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message), - "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), - count) + msgs_with_ctxt = ( + "%s%s%s" % (context, CONTEXT_SEPARATOR, message), + "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message), + count, + ) result = self.ngettext(*msgs_with_ctxt) if CONTEXT_SEPARATOR in result: # Translation not found diff --git a/tornado/locks.py b/tornado/locks.py index e1b518aad..20d21b559 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -22,10 +22,11 @@ from tornado.concurrent import Future, future_set_result_unless_cancelled from typing import Union, Optional, Type, Any, Generator import typing + if typing.TYPE_CHECKING: from typing import Deque, Set # noqa: F401 -__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock'] +__all__ = ["Condition", "Event", "Semaphore", "BoundedSemaphore", "Lock"] class _TimeoutGarbageCollector(object): @@ -37,6 +38,7 @@ class _TimeoutGarbageCollector(object): yield condition.wait(short_timeout) print('looping....') """ + def __init__(self) -> None: self._waiters = collections.deque() # type: Deque[Future] self._timeouts = 0 @@ -46,8 +48,7 @@ class _TimeoutGarbageCollector(object): self._timeouts += 1 if self._timeouts > 100: self._timeouts = 0 - self._waiters = collections.deque( - w for w in self._waiters if not w.done()) + self._waiters = collections.deque(w for w in self._waiters if not w.done()) class Condition(_TimeoutGarbageCollector): @@ -115,12 +116,12 @@ class Condition(_TimeoutGarbageCollector): self.io_loop = ioloop.IOLoop.current() def __repr__(self) -> str: - result = '<%s' % (self.__class__.__name__, ) + result = "<%s" % (self.__class__.__name__,) if self._waiters: - result += ' waiters[%s]' % len(self._waiters) - return result + '>' + result += " waiters[%s]" % len(self._waiters) + return result + ">" - def wait(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[bool]': + def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[bool]": """Wait for `.notify`. Returns a `.Future` that resolves ``True`` if the condition is notified, @@ -129,17 +130,18 @@ class Condition(_TimeoutGarbageCollector): waiter = Future() # type: Future[bool] self._waiters.append(waiter) if timeout: + def on_timeout() -> None: if not waiter.done(): future_set_result_unless_cancelled(waiter, False) self._garbage_collect() + io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) - waiter.add_done_callback( - lambda _: io_loop.remove_timeout(timeout_handle)) + waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) return waiter - def notify(self, n: int=1) -> None: + def notify(self, n: int = 1) -> None: """Wake ``n`` waiters.""" waiters = [] # Waiters we plan to run right now. while n and self._waiters: @@ -195,13 +197,16 @@ class Event(object): Not waiting this time Done """ + def __init__(self) -> None: self._value = False self._waiters = set() # type: Set[Future[None]] def __repr__(self) -> str: - return '<%s %s>' % ( - self.__class__.__name__, 'set' if self.is_set() else 'clear') + return "<%s %s>" % ( + self.__class__.__name__, + "set" if self.is_set() else "clear", + ) def is_set(self) -> bool: """Return ``True`` if the internal flag is true.""" @@ -226,7 +231,7 @@ class Event(object): """ self._value = False - def wait(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]': + def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[None]": """Block until the internal flag is true. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -241,11 +246,15 @@ class Event(object): if timeout is None: return fut else: - timeout_fut = gen.with_timeout(timeout, fut, quiet_exceptions=(CancelledError,)) + timeout_fut = gen.with_timeout( + timeout, fut, quiet_exceptions=(CancelledError,) + ) # This is a slightly clumsy workaround for the fact that # gen.with_timeout doesn't cancel its futures. Cancelling # fut will remove it from the waiters list. - timeout_fut.add_done_callback(lambda tf: fut.cancel() if not fut.done() else None) + timeout_fut.add_done_callback( + lambda tf: fut.cancel() if not fut.done() else None + ) return timeout_fut @@ -257,15 +266,19 @@ class _ReleasingContextManager(object): # Now semaphore.release() has been called. """ + def __init__(self, obj: Any) -> None: self._obj = obj def __enter__(self) -> None: pass - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[types.TracebackType]) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> None: self._obj.release() @@ -364,20 +377,22 @@ class Semaphore(_TimeoutGarbageCollector): Added ``async with`` support in Python 3.5. """ - def __init__(self, value: int=1) -> None: + + def __init__(self, value: int = 1) -> None: super(Semaphore, self).__init__() if value < 0: - raise ValueError('semaphore initial value must be >= 0') + raise ValueError("semaphore initial value must be >= 0") self._value = value def __repr__(self) -> str: res = super(Semaphore, self).__repr__() - extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format( - self._value) + extra = ( + "locked" if self._value == 0 else "unlocked,value:{0}".format(self._value) + ) if self._waiters: - extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) - return '<{0} [{1}]>'.format(res[1:-1], extra) + extra = "{0},waiters:{1}".format(extra, len(self._waiters)) + return "<{0} [{1}]>".format(res[1:-1], extra) def release(self) -> None: """Increment the counter and wake one waiter.""" @@ -397,8 +412,8 @@ class Semaphore(_TimeoutGarbageCollector): break def acquire( - self, timeout: Union[float, datetime.timedelta]=None, - ) -> 'Future[_ReleasingContextManager]': + self, timeout: Union[float, datetime.timedelta] = None + ) -> "Future[_ReleasingContextManager]": """Decrement the counter. Returns a Future. Block if the counter is zero and wait for a `.release`. The Future @@ -411,33 +426,43 @@ class Semaphore(_TimeoutGarbageCollector): else: self._waiters.append(waiter) if timeout: + def on_timeout() -> None: if not waiter.done(): waiter.set_exception(gen.TimeoutError()) self._garbage_collect() + io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) waiter.add_done_callback( - lambda _: io_loop.remove_timeout(timeout_handle)) + lambda _: io_loop.remove_timeout(timeout_handle) + ) return waiter def __enter__(self) -> None: raise RuntimeError( "Use Semaphore like 'with (yield semaphore.acquire())', not like" - " 'with semaphore'") - - def __exit__(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - traceback: Optional[types.TracebackType]) -> None: + " 'with semaphore'" + ) + + def __exit__( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: self.__enter__() @gen.coroutine def __aenter__(self) -> Generator[Any, Any, None]: yield self.acquire() - async def __aexit__(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[types.TracebackType]) -> None: + async def __aexit__( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: self.release() @@ -449,7 +474,8 @@ class BoundedSemaphore(Semaphore): resources with limited capacity, so a semaphore released too many times is a sign of a bug. """ - def __init__(self, value: int=1) -> None: + + def __init__(self, value: int = 1) -> None: super(BoundedSemaphore, self).__init__(value=value) self._initial_value = value @@ -496,17 +522,16 @@ class Lock(object): Added ``async with`` support in Python 3.5. """ + def __init__(self) -> None: self._block = BoundedSemaphore(value=1) def __repr__(self) -> str: - return "<%s _block=%s>" % ( - self.__class__.__name__, - self._block) + return "<%s _block=%s>" % (self.__class__.__name__, self._block) def acquire( - self, timeout: Union[float, datetime.timedelta]=None, - ) -> 'Future[_ReleasingContextManager]': + self, timeout: Union[float, datetime.timedelta] = None + ) -> "Future[_ReleasingContextManager]": """Attempt to lock. Returns a Future. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -524,22 +549,27 @@ class Lock(object): try: self._block.release() except ValueError: - raise RuntimeError('release unlocked lock') + raise RuntimeError("release unlocked lock") def __enter__(self) -> None: - raise RuntimeError( - "Use Lock like 'with (yield lock)', not like 'with lock'") - - def __exit__(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[types.TracebackType]) -> None: + raise RuntimeError("Use Lock like 'with (yield lock)', not like 'with lock'") + + def __exit__( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: self.__enter__() @gen.coroutine def __aenter__(self) -> Generator[Any, Any, None]: yield self.acquire() - async def __aexit__(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[types.TracebackType]) -> None: + async def __aexit__( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[types.TracebackType], + ) -> None: self.release() diff --git a/tornado/log.py b/tornado/log.py index 5948d43ba..2ac2e4678 100644 --- a/tornado/log.py +++ b/tornado/log.py @@ -54,14 +54,15 @@ gen_log = logging.getLogger("tornado.general") def _stderr_supports_color() -> bool: try: - if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty(): + if hasattr(sys.stderr, "isatty") and sys.stderr.isatty(): if curses: curses.setupterm() if curses.tigetnum("colors") > 0: return True elif colorama: - if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr', - object()): + if sys.stderr is getattr( + colorama.initialise, "wrapped_stderr", object() + ): return True except Exception: # Very broad exception handling because it's always better to @@ -101,9 +102,9 @@ class LogFormatter(logging.Formatter): Added support for ``colorama``. Changed the constructor signature to be compatible with `logging.config.dictConfig`. """ - DEFAULT_FORMAT = \ - '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s' - DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S' + + DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" # noqa: E501 + DEFAULT_DATE_FORMAT = "%y%m%d %H:%M:%S" DEFAULT_COLORS = { logging.DEBUG: 4, # Blue logging.INFO: 2, # Green @@ -111,8 +112,14 @@ class LogFormatter(logging.Formatter): logging.ERROR: 1, # Red } - def __init__(self, fmt: str=DEFAULT_FORMAT, datefmt: str=DEFAULT_DATE_FORMAT, - style: str='%', color: bool=True, colors: Dict[int, int]=DEFAULT_COLORS) -> None: + def __init__( + self, + fmt: str = DEFAULT_FORMAT, + datefmt: str = DEFAULT_DATE_FORMAT, + style: str = "%", + color: bool = True, + colors: Dict[int, int] = DEFAULT_COLORS, + ) -> None: r""" :arg bool color: Enables color support. :arg str fmt: Log message format. @@ -134,23 +141,24 @@ class LogFormatter(logging.Formatter): self._colors = {} # type: Dict[int, str] if color and _stderr_supports_color(): if curses is not None: - fg_color = (curses.tigetstr("setaf") or - curses.tigetstr("setf") or b"") + fg_color = curses.tigetstr("setaf") or curses.tigetstr("setf") or b"" for levelno, code in colors.items(): # Convert the terminal control characters from # bytes to unicode strings for easier use with the # logging module. - self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii") + self._colors[levelno] = unicode_type( + curses.tparm(fg_color, code), "ascii" + ) self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii") else: # If curses is not present (currently we'll only get here for # colorama on windows), assume hard-coded ANSI color codes. for levelno, code in colors.items(): - self._colors[levelno] = '\033[2;3%dm' % code - self._normal = '\033[0m' + self._colors[levelno] = "\033[2;3%dm" % code + self._normal = "\033[0m" else: - self._normal = '' + self._normal = "" def format(self, record: Any) -> str: try: @@ -182,7 +190,7 @@ class LogFormatter(logging.Formatter): record.color = self._colors[record.levelno] record.end_color = self._normal else: - record.color = record.end_color = '' + record.color = record.end_color = "" formatted = self._fmt % record.__dict__ @@ -194,13 +202,12 @@ class LogFormatter(logging.Formatter): # each line separately so that non-utf8 bytes don't cause # all the newlines to turn into '\n'. lines = [formatted.rstrip()] - lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n')) - formatted = '\n'.join(lines) + lines.extend(_safe_unicode(ln) for ln in record.exc_text.split("\n")) + formatted = "\n".join(lines) return formatted.replace("\n", "\n ") -def enable_pretty_logging(options: Any=None, - logger: logging.Logger=None) -> None: +def enable_pretty_logging(options: Any = None, logger: logging.Logger = None) -> None: """Turns on formatted logging output as configured. This is called automatically by `tornado.options.parse_command_line` @@ -208,41 +215,45 @@ def enable_pretty_logging(options: Any=None, """ if options is None: import tornado.options + options = tornado.options.options - if options.logging is None or options.logging.lower() == 'none': + if options.logging is None or options.logging.lower() == "none": return if logger is None: logger = logging.getLogger() logger.setLevel(getattr(logging, options.logging.upper())) if options.log_file_prefix: rotate_mode = options.log_rotate_mode - if rotate_mode == 'size': + if rotate_mode == "size": channel = logging.handlers.RotatingFileHandler( filename=options.log_file_prefix, maxBytes=options.log_file_max_size, - backupCount=options.log_file_num_backups) # type: logging.Handler - elif rotate_mode == 'time': + backupCount=options.log_file_num_backups, + ) # type: logging.Handler + elif rotate_mode == "time": channel = logging.handlers.TimedRotatingFileHandler( filename=options.log_file_prefix, when=options.log_rotate_when, interval=options.log_rotate_interval, - backupCount=options.log_file_num_backups) + backupCount=options.log_file_num_backups, + ) else: - error_message = 'The value of log_rotate_mode option should be ' +\ - '"size" or "time", not "%s".' % rotate_mode + error_message = ( + "The value of log_rotate_mode option should be " + + '"size" or "time", not "%s".' % rotate_mode + ) raise ValueError(error_message) channel.setFormatter(LogFormatter(color=False)) logger.addHandler(channel) - if (options.log_to_stderr or - (options.log_to_stderr is None and not logger.handlers)): + if options.log_to_stderr or (options.log_to_stderr is None and not logger.handlers): # Set up color if we are in a tty and curses is installed channel = logging.StreamHandler() channel.setFormatter(LogFormatter()) logger.addHandler(channel) -def define_logging_options(options: Any=None) -> None: +def define_logging_options(options: Any = None) -> None: """Add logging-related flags to ``options``. These options are present automatically on the default options instance; @@ -254,32 +265,70 @@ def define_logging_options(options: Any=None) -> None: if options is None: # late import to prevent cycle import tornado.options + options = tornado.options.options - options.define("logging", default="info", - help=("Set the Python log level. If 'none', tornado won't touch the " - "logging configuration."), - metavar="debug|info|warning|error|none") - options.define("log_to_stderr", type=bool, default=None, - help=("Send log output to stderr (colorized if possible). " - "By default use stderr if --log_file_prefix is not set and " - "no other logging is configured.")) - options.define("log_file_prefix", type=str, default=None, metavar="PATH", - help=("Path prefix for log files. " - "Note that if you are running multiple tornado processes, " - "log_file_prefix must be different for each of them (e.g. " - "include the port number)")) - options.define("log_file_max_size", type=int, default=100 * 1000 * 1000, - help="max size of log files before rollover") - options.define("log_file_num_backups", type=int, default=10, - help="number of log files to keep") - - options.define("log_rotate_when", type=str, default='midnight', - help=("specify the type of TimedRotatingFileHandler interval " - "other options:('S', 'M', 'H', 'D', 'W0'-'W6')")) - options.define("log_rotate_interval", type=int, default=1, - help="The interval value of timed rotating") - - options.define("log_rotate_mode", type=str, default='size', - help="The mode of rotating files(time or size)") + options.define( + "logging", + default="info", + help=( + "Set the Python log level. If 'none', tornado won't touch the " + "logging configuration." + ), + metavar="debug|info|warning|error|none", + ) + options.define( + "log_to_stderr", + type=bool, + default=None, + help=( + "Send log output to stderr (colorized if possible). " + "By default use stderr if --log_file_prefix is not set and " + "no other logging is configured." + ), + ) + options.define( + "log_file_prefix", + type=str, + default=None, + metavar="PATH", + help=( + "Path prefix for log files. " + "Note that if you are running multiple tornado processes, " + "log_file_prefix must be different for each of them (e.g. " + "include the port number)" + ), + ) + options.define( + "log_file_max_size", + type=int, + default=100 * 1000 * 1000, + help="max size of log files before rollover", + ) + options.define( + "log_file_num_backups", type=int, default=10, help="number of log files to keep" + ) + + options.define( + "log_rotate_when", + type=str, + default="midnight", + help=( + "specify the type of TimedRotatingFileHandler interval " + "other options:('S', 'M', 'H', 'D', 'W0'-'W6')" + ), + ) + options.define( + "log_rotate_interval", + type=int, + default=1, + help="The interval value of timed rotating", + ) + + options.define( + "log_rotate_mode", + type=str, + default="size", + help="The mode of rotating files(time or size)", + ) options.add_parse_callback(lambda: enable_pretty_logging(options)) diff --git a/tornado/netutil.py b/tornado/netutil.py index 57e2b2b81..5a92690fa 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -38,11 +38,9 @@ if typing.TYPE_CHECKING: # Note that the naming of ssl.Purpose is confusing; the purpose # of a context is to authentiate the opposite side of the connection. -_client_ssl_defaults = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH) -_server_ssl_defaults = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH) -if hasattr(ssl, 'OP_NO_COMPRESSION'): +_client_ssl_defaults = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) +_server_ssl_defaults = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +if hasattr(ssl, "OP_NO_COMPRESSION"): # See netutil.ssl_options_to_context _client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION _server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION @@ -52,10 +50,10 @@ if hasattr(ssl, 'OP_NO_COMPRESSION'): # module-import time, the import lock is already held by the main thread, # leading to deadlock. Avoid it by caching the idna encoder on the main # thread now. -u'foo'.encode('idna') +u"foo".encode("idna") # For undiagnosed reasons, 'latin1' codec may also need to be preloaded. -u'foo'.encode('latin1') +u"foo".encode("latin1") # These errnos indicate that a non-blocking operation must be retried # at a later time. On most platforms they're the same value, but on @@ -69,10 +67,14 @@ if hasattr(errno, "WSAEWOULDBLOCK"): _DEFAULT_BACKLOG = 128 -def bind_sockets(port: int, address: str=None, - family: socket.AddressFamily=socket.AF_UNSPEC, - backlog: int=_DEFAULT_BACKLOG, flags: int=None, - reuse_port: bool=False) -> List[socket.socket]: +def bind_sockets( + port: int, + address: str = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = _DEFAULT_BACKLOG, + flags: int = None, + reuse_port: bool = False, +) -> List[socket.socket]: """Creates listening sockets bound to the given port and address. Returns a list of socket objects (multiple sockets are returned if @@ -113,16 +115,22 @@ def bind_sockets(port: int, address: str=None, flags = socket.AI_PASSIVE bound_port = None unique_addresses = set() # type: set - for res in sorted(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, - 0, flags), key=lambda x: x[0]): + for res in sorted( + socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags), + key=lambda x: x[0], + ): if res in unique_addresses: continue unique_addresses.add(res) af, socktype, proto, canonname, sockaddr = res - if (sys.platform == 'darwin' and address == 'localhost' and - af == socket.AF_INET6 and sockaddr[3] != 0): + if ( + sys.platform == "darwin" + and address == "localhost" + and af == socket.AF_INET6 + and sockaddr[3] != 0 + ): # Mac OS X includes a link-local address fe80::1%lo0 in the # getaddrinfo results for 'localhost'. However, the firewall # doesn't understand that this is a local address and will @@ -137,7 +145,7 @@ def bind_sockets(port: int, address: str=None, continue raise set_close_exec(sock.fileno()) - if os.name != 'nt': + if os.name != "nt": try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) except socket.error as e: @@ -172,9 +180,11 @@ def bind_sockets(port: int, address: str=None, return sockets -if hasattr(socket, 'AF_UNIX'): - def bind_unix_socket(file: str, mode: int=0o600, - backlog: int=_DEFAULT_BACKLOG) -> socket.socket: +if hasattr(socket, "AF_UNIX"): + + def bind_unix_socket( + file: str, mode: int = 0o600, backlog: int = _DEFAULT_BACKLOG + ) -> socket.socket: """Creates a listening unix socket. If a socket with the given name already exists, it will be deleted. @@ -209,8 +219,9 @@ if hasattr(socket, 'AF_UNIX'): return sock -def add_accept_handler(sock: socket.socket, - callback: Callable[[socket.socket, Any], None]) -> Callable[[], None]: +def add_accept_handler( + sock: socket.socket, callback: Callable[[socket.socket, Any], None] +) -> Callable[[], None]: """Adds an `.IOLoop` event handler to accept new connections on ``sock``. When a connection is accepted, ``callback(connection, address)`` will @@ -276,14 +287,14 @@ def is_valid_ip(ip: str) -> bool: Supports IPv4 and IPv6. """ - if not ip or '\x00' in ip: + if not ip or "\x00" in ip: # getaddrinfo resolves empty strings to localhost, and truncates # on zero bytes. return False try: - res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC, - socket.SOCK_STREAM, - 0, socket.AI_NUMERICHOST) + res = socket.getaddrinfo( + ip, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST + ) return bool(res) except socket.gaierror as e: if e.args[0] == socket.EAI_NONAME: @@ -315,17 +326,18 @@ class Resolver(Configurable): The default implementation has changed from `BlockingResolver` to `DefaultExecutorResolver`. """ + @classmethod - def configurable_base(cls) -> Type['Resolver']: + def configurable_base(cls) -> Type["Resolver"]: return Resolver @classmethod - def configurable_default(cls) -> Type['Resolver']: + def configurable_default(cls) -> Type["Resolver"]: return DefaultExecutorResolver def resolve( - self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC, - ) -> 'Future[List[Tuple[int, Any]]]': + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> "Future[List[Tuple[int, Any]]]": """Resolves an address. The ``host`` argument is a string which may be a hostname or a @@ -359,7 +371,7 @@ class Resolver(Configurable): def _resolve_addr( - host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC, + host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC ) -> List[Tuple[int, Any]]: # On Solaris, getaddrinfo fails if the given port is not found # in /etc/services and no socket type is given, so we must pass @@ -378,12 +390,14 @@ class DefaultExecutorResolver(Resolver): .. versionadded:: 5.0 """ + @gen.coroutine def resolve( - self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC, + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC ) -> Generator[Any, Any, List[Tuple[int, Any]]]: result = yield IOLoop.current().run_in_executor( - None, _resolve_addr, host, port, family) + None, _resolve_addr, host, port, family + ) return result @@ -404,8 +418,10 @@ class ExecutorResolver(Resolver): The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead of this class. """ - def initialize(self, executor: concurrent.futures.Executor=None, - close_executor: bool=True) -> None: + + def initialize( + self, executor: concurrent.futures.Executor = None, close_executor: bool = True + ) -> None: self.io_loop = IOLoop.current() if executor is not None: self.executor = executor @@ -421,7 +437,7 @@ class ExecutorResolver(Resolver): @run_on_executor def resolve( - self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC, + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC ) -> List[Tuple[int, Any]]: return _resolve_addr(host, port, family) @@ -436,6 +452,7 @@ class BlockingResolver(ExecutorResolver): The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead of this class. """ + def initialize(self) -> None: # type: ignore super(BlockingResolver, self).initialize() @@ -460,16 +477,20 @@ class ThreadedResolver(ExecutorResolver): The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead of this class. """ + _threadpool = None # type: ignore _threadpool_pid = None # type: int - def initialize(self, num_threads: int=10) -> None: # type: ignore + def initialize(self, num_threads: int = 10) -> None: # type: ignore threadpool = ThreadedResolver._create_threadpool(num_threads) super(ThreadedResolver, self).initialize( - executor=threadpool, close_executor=False) + executor=threadpool, close_executor=False + ) @classmethod - def _create_threadpool(cls, num_threads: int) -> concurrent.futures.ThreadPoolExecutor: + def _create_threadpool( + cls, num_threads: int + ) -> concurrent.futures.ThreadPoolExecutor: pid = os.getpid() if cls._threadpool_pid != pid: # Threads cannot survive after a fork, so if our pid isn't what it @@ -503,6 +524,7 @@ class OverrideResolver(Resolver): .. versionchanged:: 5.0 Added support for host-port-family triplets. """ + def initialize(self, resolver: Resolver, mapping: dict) -> None: # type: ignore self.resolver = resolver self.mapping = mapping @@ -511,8 +533,8 @@ class OverrideResolver(Resolver): self.resolver.close() def resolve( - self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC, - ) -> 'Future[List[Tuple[int, Any]]]': + self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC + ) -> "Future[List[Tuple[int, Any]]]": if (host, port, family) in self.mapping: host, port = self.mapping[(host, port, family)] elif (host, port) in self.mapping: @@ -525,11 +547,14 @@ class OverrideResolver(Resolver): # These are the keyword arguments to ssl.wrap_socket that must be translated # to their SSLContext equivalents (the other arguments are still passed # to SSLContext.wrap_socket). -_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile', - 'cert_reqs', 'ca_certs', 'ciphers']) +_SSL_CONTEXT_KEYWORDS = frozenset( + ["ssl_version", "certfile", "keyfile", "cert_reqs", "ca_certs", "ciphers"] +) -def ssl_options_to_context(ssl_options: Union[Dict[str, Any], ssl.SSLContext]) -> ssl.SSLContext: +def ssl_options_to_context( + ssl_options: Union[Dict[str, Any], ssl.SSLContext] +) -> ssl.SSLContext: """Try to convert an ``ssl_options`` dictionary to an `~ssl.SSLContext` object. @@ -546,17 +571,18 @@ def ssl_options_to_context(ssl_options: Union[Dict[str, Any], ssl.SSLContext]) - assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options # Can't use create_default_context since this interface doesn't # tell us client vs server. - context = ssl.SSLContext( - ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23)) - if 'certfile' in ssl_options: - context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None)) - if 'cert_reqs' in ssl_options: - context.verify_mode = ssl_options['cert_reqs'] - if 'ca_certs' in ssl_options: - context.load_verify_locations(ssl_options['ca_certs']) - if 'ciphers' in ssl_options: - context.set_ciphers(ssl_options['ciphers']) - if hasattr(ssl, 'OP_NO_COMPRESSION'): + context = ssl.SSLContext(ssl_options.get("ssl_version", ssl.PROTOCOL_SSLv23)) + if "certfile" in ssl_options: + context.load_cert_chain( + ssl_options["certfile"], ssl_options.get("keyfile", None) + ) + if "cert_reqs" in ssl_options: + context.verify_mode = ssl_options["cert_reqs"] + if "ca_certs" in ssl_options: + context.load_verify_locations(ssl_options["ca_certs"]) + if "ciphers" in ssl_options: + context.set_ciphers(ssl_options["ciphers"]) + if hasattr(ssl, "OP_NO_COMPRESSION"): # Disable TLS compression to avoid CRIME and related attacks. # This constant depends on openssl version 1.0. # TODO: Do we need to do this ourselves or can we trust @@ -565,8 +591,12 @@ def ssl_options_to_context(ssl_options: Union[Dict[str, Any], ssl.SSLContext]) - return context -def ssl_wrap_socket(socket: socket.socket, ssl_options: Union[Dict[str, Any], ssl.SSLContext], - server_hostname: str=None, **kwargs: Any) -> ssl.SSLSocket: +def ssl_wrap_socket( + socket: socket.socket, + ssl_options: Union[Dict[str, Any], ssl.SSLContext], + server_hostname: str = None, + **kwargs: Any +) -> ssl.SSLSocket: """Returns an ``ssl.SSLSocket`` wrapping the given socket. ``ssl_options`` may be either an `ssl.SSLContext` object or a @@ -582,7 +612,6 @@ def ssl_wrap_socket(socket: socket.socket, ssl_options: Union[Dict[str, Any], ss # TODO: add a unittest (python added server-side SNI support in 3.4) # In the meantime it can be manually tested with # python3 -m tornado.httpclient https://sni.velox.ch - return context.wrap_socket(socket, server_hostname=server_hostname, - **kwargs) + return context.wrap_socket(socket, server_hostname=server_hostname, **kwargs) else: return context.wrap_socket(socket, **kwargs) diff --git a/tornado/options.py b/tornado/options.py index 5129fbf6f..4449ffa55 100644 --- a/tornado/options.py +++ b/tornado/options.py @@ -113,6 +113,7 @@ if typing.TYPE_CHECKING: class Error(Exception): """Exception raised by errors in the options module.""" + pass @@ -122,15 +123,20 @@ class OptionParser(object): Normally accessed via static functions in the `tornado.options` module, which reference a global instance. """ + def __init__(self) -> None: # we have to use self.__dict__ because we override setattr. - self.__dict__['_options'] = {} - self.__dict__['_parse_callbacks'] = [] - self.define("help", type=bool, help="show this help information", - callback=self._help_callback) + self.__dict__["_options"] = {} + self.__dict__["_parse_callbacks"] = [] + self.define( + "help", + type=bool, + help="show this help information", + callback=self._help_callback, + ) def _normalize_name(self, name: str) -> str: - return name.replace('_', '-') + return name.replace("_", "-") def __getattr__(self, name: str) -> Any: name = self._normalize_name(name) @@ -189,20 +195,29 @@ class OptionParser(object): .. versionadded:: 3.1 """ return dict( - (opt.name, opt.value()) for name, opt in self._options.items() - if not group or group == opt.group_name) + (opt.name, opt.value()) + for name, opt in self._options.items() + if not group or group == opt.group_name + ) def as_dict(self) -> Dict[str, Any]: """The names and values of all options. .. versionadded:: 3.1 """ - return dict( - (opt.name, opt.value()) for name, opt in self._options.items()) - - def define(self, name: str, default: Any=None, type: type=None, - help: str=None, metavar: str=None, - multiple: bool=False, group: str=None, callback: Callable[[Any], None]=None) -> None: + return dict((opt.name, opt.value()) for name, opt in self._options.items()) + + def define( + self, + name: str, + default: Any = None, + type: type = None, + help: str = None, + metavar: str = None, + multiple: bool = False, + group: str = None, + callback: Callable[[Any], None] = None, + ) -> None: """Defines a new command line option. ``type`` can be any of `str`, `int`, `float`, `bool`, @@ -239,15 +254,19 @@ class OptionParser(object): """ normalized = self._normalize_name(name) if normalized in self._options: - raise Error("Option %r already defined in %s" % - (normalized, self._options[normalized].file_name)) + raise Error( + "Option %r already defined in %s" + % (normalized, self._options[normalized].file_name) + ) frame = sys._getframe(0) options_file = frame.f_code.co_filename # Can be called directly, or through top level define() fn, in which # case, step up above that frame to look for real caller. - if (frame.f_back.f_code.co_filename == options_file and - frame.f_back.f_code.co_name == 'define'): + if ( + frame.f_back.f_code.co_filename == options_file + and frame.f_back.f_code.co_name == "define" + ): frame = frame.f_back file_name = frame.f_back.f_code.co_filename @@ -262,14 +281,22 @@ class OptionParser(object): group_name = group # type: Optional[str] else: group_name = file_name - option = _Option(name, file_name=file_name, - default=default, type=type, help=help, - metavar=metavar, multiple=multiple, - group_name=group_name, - callback=callback) + option = _Option( + name, + file_name=file_name, + default=default, + type=type, + help=help, + metavar=metavar, + multiple=multiple, + group_name=group_name, + callback=callback, + ) self._options[normalized] = option - def parse_command_line(self, args: List[str]=None, final: bool=True) -> List[str]: + def parse_command_line( + self, args: List[str] = None, final: bool = True + ) -> List[str]: """Parses all options given on the command line (defaults to `sys.argv`). @@ -300,20 +327,20 @@ class OptionParser(object): remaining = args[i:] break if args[i] == "--": - remaining = args[i + 1:] + remaining = args[i + 1 :] break arg = args[i].lstrip("-") name, equals, value = arg.partition("=") name = self._normalize_name(name) if name not in self._options: self.print_help() - raise Error('Unrecognized command line option: %r' % name) + raise Error("Unrecognized command line option: %r" % name) option = self._options[name] if not equals: if option.type == bool: value = "true" else: - raise Error('Option %r requires a value' % name) + raise Error("Option %r requires a value" % name) option.parse(value) if final: @@ -321,7 +348,7 @@ class OptionParser(object): return remaining - def parse_config_file(self, path: str, final: bool=True) -> None: + def parse_config_file(self, path: str, final: bool = True) -> None: """Parses and loads the config file at the given path. The config file contains Python code that will be executed (so @@ -367,8 +394,8 @@ class OptionParser(object): Added the ability to set options via strings in config files. """ - config = {'__file__': os.path.abspath(path)} - with open(path, 'rb') as f: + config = {"__file__": os.path.abspath(path)} + with open(path, "rb") as f: exec_in(native_str(f.read()), config, config) for name in config: normalized = self._normalize_name(name) @@ -376,9 +403,11 @@ class OptionParser(object): option = self._options[normalized] if option.multiple: if not isinstance(config[name], (list, str)): - raise Error("Option %r is required to be a list of %s " - "or a comma-separated string" % - (option.name, option.type.__name__)) + raise Error( + "Option %r is required to be a list of %s " + "or a comma-separated string" + % (option.name, option.type.__name__) + ) if type(config[name]) == str and option.type != str: option.parse(config[name]) @@ -388,7 +417,7 @@ class OptionParser(object): if final: self.run_parse_callbacks() - def print_help(self, file: TextIO=None) -> None: + def print_help(self, file: TextIO = None) -> None: """Prints all the command line options to stderr (or another file).""" if file is None: file = sys.stderr @@ -408,14 +437,14 @@ class OptionParser(object): if option.metavar: prefix += "=" + option.metavar description = option.help or "" - if option.default is not None and option.default != '': + if option.default is not None and option.default != "": description += " (default %s)" % option.default lines = textwrap.wrap(description, 79 - 35) if len(prefix) > 30 or len(lines) == 0: - lines.insert(0, '') + lines.insert(0, "") print(" --%-30s %s" % (prefix, lines[0]), file=file) for line in lines[1:]: - print("%-34s %s" % (' ', line), file=file) + print("%-34s %s" % (" ", line), file=file) print(file=file) def _help_callback(self, value: bool) -> None: @@ -431,7 +460,7 @@ class OptionParser(object): for callback in self._parse_callbacks: callback() - def mockable(self) -> '_Mockable': + def mockable(self) -> "_Mockable": """Returns a wrapper around self that is compatible with `mock.patch `. @@ -461,10 +490,11 @@ class _Mockable(object): _Mockable's getattr and setattr pass through to the underlying OptionParser, and delattr undoes the effect of a previous setattr. """ + def __init__(self, options: OptionParser) -> None: # Modify __dict__ directly to bypass __setattr__ - self.__dict__['_options'] = options - self.__dict__['_originals'] = {} + self.__dict__["_options"] = options + self.__dict__["_originals"] = {} def __getattr__(self, name: str) -> Any: return getattr(self._options, name) @@ -484,9 +514,18 @@ class _Option(object): # and the callback use List[T], but type is still Type[T]). UNSET = object() - def __init__(self, name: str, default: Any=None, type: type=None, - help: str=None, metavar: str=None, multiple: bool=False, file_name: str=None, - group_name: str=None, callback: Callable[[Any], None]=None) -> None: + def __init__( + self, + name: str, + default: Any = None, + type: type = None, + help: str = None, + metavar: str = None, + multiple: bool = False, + file_name: str = None, + group_name: str = None, + callback: Callable[[Any], None] = None, + ) -> None: if default is None and multiple: default = [] self.name = name @@ -511,7 +550,9 @@ class _Option(object): datetime.timedelta: self._parse_timedelta, bool: self._parse_bool, basestring_type: self._parse_string, - }.get(self.type, self.type) # type: Callable[[str], Any] + }.get( + self.type, self.type + ) # type: Callable[[str], Any] if self.multiple: self._value = [] for part in value.split(","): @@ -532,16 +573,22 @@ class _Option(object): def set(self, value: Any) -> None: if self.multiple: if not isinstance(value, list): - raise Error("Option %r is required to be a list of %s" % - (self.name, self.type.__name__)) + raise Error( + "Option %r is required to be a list of %s" + % (self.name, self.type.__name__) + ) for item in value: if item is not None and not isinstance(item, self.type): - raise Error("Option %r is required to be a list of %s" % - (self.name, self.type.__name__)) + raise Error( + "Option %r is required to be a list of %s" + % (self.name, self.type.__name__) + ) else: if value is not None and not isinstance(value, self.type): - raise Error("Option %r is required to be a %s (%s given)" % - (self.name, self.type.__name__, type(value))) + raise Error( + "Option %r is required to be a %s (%s given)" + % (self.name, self.type.__name__, type(value)) + ) self._value = value if self.callback is not None: self.callback(self._value) @@ -566,24 +613,25 @@ class _Option(object): return datetime.datetime.strptime(value, format) except ValueError: pass - raise Error('Unrecognized date/time format: %r' % value) + raise Error("Unrecognized date/time format: %r" % value) _TIMEDELTA_ABBREV_DICT = { - 'h': 'hours', - 'm': 'minutes', - 'min': 'minutes', - 's': 'seconds', - 'sec': 'seconds', - 'ms': 'milliseconds', - 'us': 'microseconds', - 'd': 'days', - 'w': 'weeks', + "h": "hours", + "m": "minutes", + "min": "minutes", + "s": "seconds", + "sec": "seconds", + "ms": "milliseconds", + "us": "microseconds", + "d": "days", + "w": "weeks", } - _FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?' + _FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?" _TIMEDELTA_PATTERN = re.compile( - r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE) + r"\s*(%s)\s*(\w*)\s*" % _FLOAT_PATTERN, re.IGNORECASE + ) def _parse_timedelta(self, value: str) -> datetime.timedelta: try: @@ -594,7 +642,7 @@ class _Option(object): if not m: raise Exception() num = float(m.group(1)) - units = m.group(2) or 'seconds' + units = m.group(2) or "seconds" units = self._TIMEDELTA_ABBREV_DICT.get(units, units) sum += datetime.timedelta(**{units: num}) start = m.end() @@ -616,19 +664,33 @@ All defined options are available as attributes on this object. """ -def define(name: str, default: Any=None, type: type=None, help: str=None, - metavar: str=None, multiple: bool=False, group: str=None, - callback: Callable[[Any], None]=None) -> None: +def define( + name: str, + default: Any = None, + type: type = None, + help: str = None, + metavar: str = None, + multiple: bool = False, + group: str = None, + callback: Callable[[Any], None] = None, +) -> None: """Defines an option in the global namespace. See `OptionParser.define`. """ - return options.define(name, default=default, type=type, help=help, - metavar=metavar, multiple=multiple, group=group, - callback=callback) - - -def parse_command_line(args: List[str]=None, final: bool=True) -> List[str]: + return options.define( + name, + default=default, + type=type, + help=help, + metavar=metavar, + multiple=multiple, + group=group, + callback=callback, + ) + + +def parse_command_line(args: List[str] = None, final: bool = True) -> List[str]: """Parses global options from the command line. See `OptionParser.parse_command_line`. @@ -636,7 +698,7 @@ def parse_command_line(args: List[str]=None, final: bool=True) -> List[str]: return options.parse_command_line(args, final=final) -def parse_config_file(path: str, final: bool=True) -> None: +def parse_config_file(path: str, final: bool = True) -> None: """Parses global options from a config file. See `OptionParser.parse_config_file`. @@ -644,7 +706,7 @@ def parse_config_file(path: str, final: bool=True) -> None: return options.parse_config_file(path, final=final) -def print_help(file: TextIO=None) -> None: +def print_help(file: TextIO = None) -> None: """Prints all the command line options to stderr (or another file). See `OptionParser.print_help`. diff --git a/tornado/platform/asyncio.py b/tornado/platform/asyncio.py index f1b072c7b..aa701d6b3 100644 --- a/tornado/platform/asyncio.py +++ b/tornado/platform/asyncio.py @@ -29,15 +29,17 @@ import asyncio import typing from typing import Any, TypeVar, Awaitable, Callable, Union, Optional + if typing.TYPE_CHECKING: from typing import Set, Dict, Tuple # noqa: F401 -_T = TypeVar('_T') +_T = TypeVar("_T") class BaseAsyncIOLoop(IOLoop): - def initialize(self, asyncio_loop: asyncio.AbstractEventLoop, # type: ignore - **kwargs: Any) -> None: + def initialize( # type: ignore + self, asyncio_loop: asyncio.AbstractEventLoop, **kwargs: Any + ) -> None: self.asyncio_loop = asyncio_loop # Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler) self.handlers = {} # type: Dict[int, Tuple[Union[int, _Selectable], Callable]] @@ -62,7 +64,7 @@ class BaseAsyncIOLoop(IOLoop): IOLoop._ioloop_for_asyncio[asyncio_loop] = self super(BaseAsyncIOLoop, self).initialize(**kwargs) - def close(self, all_fds: bool=False) -> None: + def close(self, all_fds: bool = False) -> None: self.closing = True for fd in list(self.handlers): fileobj, handler_func = self.handlers[fd] @@ -77,27 +79,25 @@ class BaseAsyncIOLoop(IOLoop): del IOLoop._ioloop_for_asyncio[self.asyncio_loop] self.asyncio_loop.close() - def add_handler(self, fd: Union[int, _Selectable], - handler: Callable[..., None], events: int) -> None: + def add_handler( + self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int + ) -> None: fd, fileobj = self.split_fd(fd) if fd in self.handlers: raise ValueError("fd %s added twice" % fd) self.handlers[fd] = (fileobj, handler) if events & IOLoop.READ: - self.asyncio_loop.add_reader( - fd, self._handle_events, fd, IOLoop.READ) + self.asyncio_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ) self.readers.add(fd) if events & IOLoop.WRITE: - self.asyncio_loop.add_writer( - fd, self._handle_events, fd, IOLoop.WRITE) + self.asyncio_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE) self.writers.add(fd) def update_handler(self, fd: Union[int, _Selectable], events: int) -> None: fd, fileobj = self.split_fd(fd) if events & IOLoop.READ: if fd not in self.readers: - self.asyncio_loop.add_reader( - fd, self._handle_events, fd, IOLoop.READ) + self.asyncio_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ) self.readers.add(fd) else: if fd in self.readers: @@ -105,8 +105,7 @@ class BaseAsyncIOLoop(IOLoop): self.readers.remove(fd) if events & IOLoop.WRITE: if fd not in self.writers: - self.asyncio_loop.add_writer( - fd, self._handle_events, fd, IOLoop.WRITE) + self.asyncio_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE) self.writers.add(fd) else: if fd in self.writers: @@ -144,14 +143,17 @@ class BaseAsyncIOLoop(IOLoop): def stop(self) -> None: self.asyncio_loop.stop() - def call_at(self, when: float, callback: Callable[..., None], - *args: Any, **kwargs: Any) -> object: + def call_at( + self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> object: # asyncio.call_at supports *args but not **kwargs, so bind them here. # We do not synchronize self.time and asyncio_loop.time, so # convert from absolute to relative. return self.asyncio_loop.call_later( - max(0, when - self.time()), self._run_callback, - functools.partial(callback, *args, **kwargs)) + max(0, when - self.time()), + self._run_callback, + functools.partial(callback, *args, **kwargs), + ) def remove_timeout(self, timeout: object) -> None: timeout.cancel() # type: ignore @@ -159,8 +161,8 @@ class BaseAsyncIOLoop(IOLoop): def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None: try: self.asyncio_loop.call_soon_threadsafe( - self._run_callback, - functools.partial(callback, *args, **kwargs)) + self._run_callback, functools.partial(callback, *args, **kwargs) + ) except RuntimeError: # "Event loop is closed". Swallow the exception for # consistency with PollIOLoop (and logical consistency @@ -171,8 +173,12 @@ class BaseAsyncIOLoop(IOLoop): add_callback_from_signal = add_callback - def run_in_executor(self, executor: Optional[concurrent.futures.Executor], - func: Callable[..., _T], *args: Any) -> Awaitable[_T]: + def run_in_executor( + self, + executor: Optional[concurrent.futures.Executor], + func: Callable[..., _T], + *args: Any + ) -> Awaitable[_T]: return self.asyncio_loop.run_in_executor(executor, func, *args) def set_default_executor(self, executor: concurrent.futures.Executor) -> None: @@ -193,6 +199,7 @@ class AsyncIOMainLoop(BaseAsyncIOLoop): Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop. """ + def initialize(self, **kwargs: Any) -> None: # type: ignore super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs) @@ -221,6 +228,7 @@ class AsyncIOLoop(BaseAsyncIOLoop): Now used automatically when appropriate; it is no longer necessary to refer to this class directly. """ + def initialize(self, **kwargs: Any) -> None: # type: ignore self.is_current = False loop = asyncio.new_event_loop() @@ -232,7 +240,7 @@ class AsyncIOLoop(BaseAsyncIOLoop): loop.close() raise - def close(self, all_fds: bool=False) -> None: + def close(self, all_fds: bool = False) -> None: if self.is_current: self.clear_current() super(AsyncIOLoop, self).close(all_fds=all_fds) @@ -297,6 +305,7 @@ class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore .. versionadded:: 5.0 """ + def get_event_loop(self) -> asyncio.AbstractEventLoop: try: return super().get_event_loop() diff --git a/tornado/platform/auto.py b/tornado/platform/auto.py index 6c3100170..4f1b6ac39 100644 --- a/tornado/platform/auto.py +++ b/tornado/platform/auto.py @@ -24,9 +24,9 @@ Most code that needs access to this functionality should do e.g.:: import os -if os.name == 'nt': +if os.name == "nt": from tornado.platform.windows import set_close_exec else: from tornado.platform.posix import set_close_exec -__all__ = ['set_close_exec'] +__all__ = ["set_close_exec"] diff --git a/tornado/platform/caresresolver.py b/tornado/platform/caresresolver.py index b23614b67..e2c5009ac 100644 --- a/tornado/platform/caresresolver.py +++ b/tornado/platform/caresresolver.py @@ -7,6 +7,7 @@ from tornado.ioloop import IOLoop from tornado.netutil import Resolver, is_valid_ip import typing + if typing.TYPE_CHECKING: from typing import Generator, Any, List, Tuple, Dict # noqa: F401 @@ -26,14 +27,14 @@ class CaresResolver(Resolver): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ + def initialize(self) -> None: self.io_loop = IOLoop.current() self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb) self.fds = {} # type: Dict[int, int] def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None: - state = ((IOLoop.READ if readable else 0) | - (IOLoop.WRITE if writable else 0)) + state = (IOLoop.READ if readable else 0) | (IOLoop.WRITE if writable else 0) if not state: self.io_loop.remove_handler(fd) del self.fds[fd] @@ -55,30 +56,34 @@ class CaresResolver(Resolver): @gen.coroutine def resolve( - self, host: str, port: int, family: int=0, - ) -> 'Generator[Any, Any, List[Tuple[int, Any]]]': + self, host: str, port: int, family: int = 0 + ) -> "Generator[Any, Any, List[Tuple[int, Any]]]": if is_valid_ip(host): addresses = [host] else: # gethostbyname doesn't take callback as a kwarg fut = Future() # type: Future[Tuple[Any, Any]] - self.channel.gethostbyname(host, family, - lambda result, error: fut.set_result((result, error))) + self.channel.gethostbyname( + host, family, lambda result, error: fut.set_result((result, error)) + ) result, error = yield fut if error: - raise IOError('C-Ares returned error %s: %s while resolving %s' % - (error, pycares.errno.strerror(error), host)) + raise IOError( + "C-Ares returned error %s: %s while resolving %s" + % (error, pycares.errno.strerror(error), host) + ) addresses = result.addresses addrinfo = [] for address in addresses: - if '.' in address: + if "." in address: address_family = socket.AF_INET - elif ':' in address: + elif ":" in address: address_family = socket.AF_INET6 else: address_family = socket.AF_UNSPEC if family != socket.AF_UNSPEC and family != address_family: - raise IOError('Requested socket family %d but got %d' % - (family, address_family)) + raise IOError( + "Requested socket family %d but got %d" % (family, address_family) + ) addrinfo.append((typing.cast(int, address_family), (address, port))) return addrinfo diff --git a/tornado/platform/twisted.py b/tornado/platform/twisted.py index bf73e4aab..388163130 100644 --- a/tornado/platform/twisted.py +++ b/tornado/platform/twisted.py @@ -40,6 +40,7 @@ from tornado import gen from tornado.netutil import Resolver import typing + if typing.TYPE_CHECKING: from typing import Generator, Any, List, Tuple # noqa: F401 @@ -61,22 +62,25 @@ class TwistedResolver(Resolver): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ + def initialize(self) -> None: # partial copy of twisted.names.client.createResolver, which doesn't # allow for a reactor to be passed in. self.reactor = twisted.internet.asyncioreactor.AsyncioSelectorReactor() - host_resolver = twisted.names.hosts.Resolver('/etc/hosts') + host_resolver = twisted.names.hosts.Resolver("/etc/hosts") cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor) - real_resolver = twisted.names.client.Resolver('/etc/resolv.conf', - reactor=self.reactor) + real_resolver = twisted.names.client.Resolver( + "/etc/resolv.conf", reactor=self.reactor + ) self.resolver = twisted.names.resolve.ResolverChain( - [host_resolver, cache_resolver, real_resolver]) + [host_resolver, cache_resolver, real_resolver] + ) @gen.coroutine def resolve( - self, host: str, port: int, family: int=0, - ) -> 'Generator[Any, Any, List[Tuple[int, Any]]]': + self, host: str, port: int, family: int = 0 + ) -> "Generator[Any, Any, List[Tuple[int, Any]]]": # getHostByName doesn't accept IP addresses, so if the input # looks like an IP address just return it immediately. if twisted.internet.abstract.isIPAddress(host): @@ -102,15 +106,15 @@ class TwistedResolver(Resolver): else: resolved_family = socket.AF_UNSPEC if family != socket.AF_UNSPEC and family != resolved_family: - raise Exception('Requested socket family %d but got %d' % - (family, resolved_family)) - result = [ - (typing.cast(int, resolved_family), (resolved, port)), - ] + raise Exception( + "Requested socket family %d but got %d" % (family, resolved_family) + ) + result = [(typing.cast(int, resolved_family), (resolved, port))] return result -if hasattr(gen.convert_yielded, 'register'): +if hasattr(gen.convert_yielded, "register"): + @gen.convert_yielded.register(Deferred) # type: ignore def _(d: Deferred) -> Future: f = Future() # type: Future[Any] @@ -122,5 +126,6 @@ if hasattr(gen.convert_yielded, 'register'): raise Exception("errback called without error") except: future_set_exc_info(f, sys.exc_info()) + d.addCallbacks(f.set_result, errback) return f diff --git a/tornado/platform/windows.py b/tornado/platform/windows.py index 6d6ebaf04..82f0118cb 100644 --- a/tornado/platform/windows.py +++ b/tornado/platform/windows.py @@ -6,7 +6,11 @@ import ctypes.wintypes # See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation # type: ignore -SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) # noqa: E501 +SetHandleInformation.argtypes = ( + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, +) # noqa: E501 SetHandleInformation.restype = ctypes.wintypes.BOOL HANDLE_FLAG_INHERIT = 0x00000001 diff --git a/tornado/process.py b/tornado/process.py index 5df6aa7ad..9e1a64500 100644 --- a/tornado/process.py +++ b/tornado/process.py @@ -36,6 +36,7 @@ from tornado.util import errno_from_exception import typing from typing import Tuple, Optional, Any, Callable + if typing.TYPE_CHECKING: from typing import List # noqa: F401 @@ -60,9 +61,10 @@ def cpu_count() -> int: def _reseed_random() -> None: - if 'random' not in sys.modules: + if "random" not in sys.modules: return import random + # If os.urandom is available, this method does the same thing as # random.seed (at least as of python 2.6). If os.urandom is not # available, we mix in the pid in addition to a timestamp. @@ -83,7 +85,7 @@ def _pipe_cloexec() -> Tuple[int, int]: _task_id = None -def fork_processes(num_processes: Optional[int], max_restarts: int=100) -> int: +def fork_processes(num_processes: Optional[int], max_restarts: int = 100) -> int: """Starts multiple worker processes. If ``num_processes`` is None or <= 0, we detect the number of cores @@ -143,11 +145,19 @@ def fork_processes(num_processes: Optional[int], max_restarts: int=100) -> int: continue id = children.pop(pid) if os.WIFSIGNALED(status): - gen_log.warning("child %d (pid %d) killed by signal %d, restarting", - id, pid, os.WTERMSIG(status)) + gen_log.warning( + "child %d (pid %d) killed by signal %d, restarting", + id, + pid, + os.WTERMSIG(status), + ) elif os.WEXITSTATUS(status) != 0: - gen_log.warning("child %d (pid %d) exited with status %d, restarting", - id, pid, os.WEXITSTATUS(status)) + gen_log.warning( + "child %d (pid %d) exited with status %d, restarting", + id, + pid, + os.WEXITSTATUS(status), + ) else: gen_log.info("child %d (pid %d) exited normally", id, pid) continue @@ -194,6 +204,7 @@ class Subprocess(object): The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ + STREAM = object() _initialized = False @@ -206,21 +217,21 @@ class Subprocess(object): # should be closed in the parent process on success. pipe_fds = [] # type: List[int] to_close = [] # type: List[int] - if kwargs.get('stdin') is Subprocess.STREAM: + if kwargs.get("stdin") is Subprocess.STREAM: in_r, in_w = _pipe_cloexec() - kwargs['stdin'] = in_r + kwargs["stdin"] = in_r pipe_fds.extend((in_r, in_w)) to_close.append(in_r) self.stdin = PipeIOStream(in_w) - if kwargs.get('stdout') is Subprocess.STREAM: + if kwargs.get("stdout") is Subprocess.STREAM: out_r, out_w = _pipe_cloexec() - kwargs['stdout'] = out_w + kwargs["stdout"] = out_w pipe_fds.extend((out_r, out_w)) to_close.append(out_w) self.stdout = PipeIOStream(out_r) - if kwargs.get('stderr') is Subprocess.STREAM: + if kwargs.get("stderr") is Subprocess.STREAM: err_r, err_w = _pipe_cloexec() - kwargs['stderr'] = err_w + kwargs["stderr"] = err_w pipe_fds.extend((err_r, err_w)) to_close.append(err_w) self.stderr = PipeIOStream(err_r) @@ -233,7 +244,7 @@ class Subprocess(object): for fd in to_close: os.close(fd) self.pid = self.proc.pid - for attr in ['stdin', 'stdout', 'stderr']: + for attr in ["stdin", "stdout", "stderr"]: if not hasattr(self, attr): # don't clobber streams set above setattr(self, attr, getattr(self.proc, attr)) self._exit_callback = None # type: Optional[Callable[[int], None]] @@ -259,7 +270,7 @@ class Subprocess(object): Subprocess._waiting[self.pid] = self Subprocess._try_cleanup_process(self.pid) - def wait_for_exit(self, raise_error: bool=True) -> 'Future[int]': + def wait_for_exit(self, raise_error: bool = True) -> "Future[int]": """Returns a `.Future` which resolves when the process exits. Usage:: @@ -280,9 +291,10 @@ class Subprocess(object): def callback(ret: int) -> None: if ret != 0 and raise_error: # Unfortunately we don't have the original args any more. - future.set_exception(CalledProcessError(ret, 'unknown')) + future.set_exception(CalledProcessError(ret, "unknown")) else: future_set_result_unless_cancelled(future, ret) + self.set_exit_callback(callback) return future @@ -304,7 +316,8 @@ class Subprocess(object): io_loop = ioloop.IOLoop.current() cls._old_sigchld = signal.signal( signal.SIGCHLD, - lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup)) + lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup), + ) cls._initialized = True @classmethod @@ -331,8 +344,7 @@ class Subprocess(object): return assert ret_pid == pid subproc = cls._waiting.pop(pid) - subproc.io_loop.add_callback_from_signal( - subproc._set_returncode, status) + subproc.io_loop.add_callback_from_signal(subproc._set_returncode, status) def _set_returncode(self, status: int) -> None: if os.WIFSIGNALED(status): diff --git a/tornado/queues.py b/tornado/queues.py index 49fa9e602..6c4fbd755 100644 --- a/tornado/queues.py +++ b/tornado/queues.py @@ -35,37 +35,43 @@ from tornado.locks import Event from typing import Union, TypeVar, Generic, Awaitable import typing + if typing.TYPE_CHECKING: from typing import Deque, Tuple, List, Any # noqa: F401 -_T = TypeVar('_T') +_T = TypeVar("_T") -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] +__all__ = ["Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty"] class QueueEmpty(Exception): """Raised by `.Queue.get_nowait` when the queue has no items.""" + pass class QueueFull(Exception): """Raised by `.Queue.put_nowait` when a queue is at its maximum size.""" + pass -def _set_timeout(future: Future, timeout: Union[None, float, datetime.timedelta]) -> None: +def _set_timeout( + future: Future, timeout: Union[None, float, datetime.timedelta] +) -> None: if timeout: + def on_timeout() -> None: if not future.done(): future.set_exception(gen.TimeoutError()) + io_loop = ioloop.IOLoop.current() timeout_handle = io_loop.add_timeout(timeout, on_timeout) - future.add_done_callback( - lambda _: io_loop.remove_timeout(timeout_handle)) + future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle)) class _QueueIterator(Generic[_T]): - def __init__(self, q: 'Queue[_T]') -> None: + def __init__(self, q: "Queue[_T]") -> None: self.q = q def __anext__(self) -> Awaitable[_T]: @@ -139,11 +145,12 @@ class Queue(Generic[_T]): Added ``async for`` support in Python 3.5. """ + # Exact type depends on subclass. Could be another generic # parameter and use protocols to be more precise here. _queue = None # type: Any - def __init__(self, maxsize: int=0) -> None: + def __init__(self, maxsize: int = 0) -> None: if maxsize is None: raise TypeError("maxsize can't be None") @@ -176,7 +183,9 @@ class Queue(Generic[_T]): else: return self.qsize() >= self.maxsize - def put(self, item: _T, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]': + def put( + self, item: _T, timeout: Union[float, datetime.timedelta] = None + ) -> "Future[None]": """Put an item into the queue, perhaps waiting until there is room. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -213,7 +222,7 @@ class Queue(Generic[_T]): else: self.__put_internal(item) - def get(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[_T]': + def get(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[_T]": """Remove and return an item from the queue. Returns a Future which resolves once an item is available, or raises @@ -263,12 +272,12 @@ class Queue(Generic[_T]): Raises `ValueError` if called more times than `.put`. """ if self._unfinished_tasks <= 0: - raise ValueError('task_done() called too many times') + raise ValueError("task_done() called too many times") self._unfinished_tasks -= 1 if self._unfinished_tasks == 0: self._finished.set() - def join(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]': + def join(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[None]": """Block until all items in the queue are processed. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -288,6 +297,7 @@ class Queue(Generic[_T]): def _put(self, item: _T) -> None: self._queue.append(item) + # End of the overridable methods. def __put_internal(self, item: _T) -> None: @@ -304,22 +314,21 @@ class Queue(Generic[_T]): self._getters.popleft() def __repr__(self) -> str: - return '<%s at %s %s>' % ( - type(self).__name__, hex(id(self)), self._format()) + return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format()) def __str__(self) -> str: - return '<%s %s>' % (type(self).__name__, self._format()) + return "<%s %s>" % (type(self).__name__, self._format()) def _format(self) -> str: - result = 'maxsize=%r' % (self.maxsize, ) - if getattr(self, '_queue', None): - result += ' queue=%r' % self._queue + result = "maxsize=%r" % (self.maxsize,) + if getattr(self, "_queue", None): + result += " queue=%r" % self._queue if self._getters: - result += ' getters[%s]' % len(self._getters) + result += " getters[%s]" % len(self._getters) if self._putters: - result += ' putters[%s]' % len(self._putters) + result += " putters[%s]" % len(self._putters) if self._unfinished_tasks: - result += ' tasks=%s' % self._unfinished_tasks + result += " tasks=%s" % self._unfinished_tasks return result @@ -347,6 +356,7 @@ class PriorityQueue(Queue): (1, 'medium-priority item') (10, 'low-priority item') """ + def _init(self) -> None: self._queue = [] @@ -379,6 +389,7 @@ class LifoQueue(Queue): 2 3 """ + def _init(self) -> None: self._queue = [] diff --git a/tornado/routing.py b/tornado/routing.py index 973b30473..ae7abcdcf 100644 --- a/tornado/routing.py +++ b/tornado/routing.py @@ -190,8 +190,9 @@ from typing import Any, Union, Optional, Awaitable, List, Dict, Pattern, Tuple, class Router(httputil.HTTPServerConnectionDelegate): """Abstract router interface.""" - def find_handler(self, request: httputil.HTTPServerRequest, - **kwargs: Any) -> Optional[httputil.HTTPMessageDelegate]: + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> Optional[httputil.HTTPMessageDelegate]: """Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate` that can serve the request. Routing implementations may pass additional kwargs to extend the routing logic. @@ -203,8 +204,9 @@ class Router(httputil.HTTPServerConnectionDelegate): """ raise NotImplementedError() - def start_request(self, server_conn: object, - request_conn: httputil.HTTPConnection) -> httputil.HTTPMessageDelegate: + def start_request( + self, server_conn: object, request_conn: httputil.HTTPConnection + ) -> httputil.HTTPMessageDelegate: return _RoutingDelegate(self, server_conn, request_conn) @@ -225,26 +227,34 @@ class ReversibleRouter(Router): class _RoutingDelegate(httputil.HTTPMessageDelegate): - def __init__(self, router: Router, server_conn: object, - request_conn: httputil.HTTPConnection) -> None: + def __init__( + self, router: Router, server_conn: object, request_conn: httputil.HTTPConnection + ) -> None: self.server_conn = server_conn self.request_conn = request_conn self.delegate = None # type: Optional[httputil.HTTPMessageDelegate] self.router = router # type: Router - def headers_received(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: assert isinstance(start_line, httputil.RequestStartLine) request = httputil.HTTPServerRequest( connection=self.request_conn, server_connection=self.server_conn, - start_line=start_line, headers=headers) + start_line=start_line, + headers=headers, + ) self.delegate = self.router.find_handler(request) if self.delegate is None: - app_log.debug("Delegate for %s %s request not found", - start_line.method, start_line.path) + app_log.debug( + "Delegate for %s %s request not found", + start_line.method, + start_line.path, + ) self.delegate = _DefaultMessageDelegate(self.request_conn) return self.delegate.headers_received(start_line, headers) @@ -268,23 +278,29 @@ class _DefaultMessageDelegate(httputil.HTTPMessageDelegate): def finish(self) -> None: self.connection.write_headers( - httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders()) + httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), + httputil.HTTPHeaders(), + ) self.connection.finish() # _RuleList can either contain pre-constructed Rules or a sequence of # arguments to be passed to the Rule constructor. -_RuleList = List[Union['Rule', - List[Any], # Can't do detailed typechecking of lists. - Tuple[Union[str, 'Matcher'], Any], - Tuple[Union[str, 'Matcher'], Any, Dict[str, Any]], - Tuple[Union[str, 'Matcher'], Any, Dict[str, Any], str]]] +_RuleList = List[ + Union[ + "Rule", + List[Any], # Can't do detailed typechecking of lists. + Tuple[Union[str, "Matcher"], Any], + Tuple[Union[str, "Matcher"], Any, Dict[str, Any]], + Tuple[Union[str, "Matcher"], Any, Dict[str, Any], str], + ] +] class RuleRouter(Router): """Rule-based router implementation.""" - def __init__(self, rules: _RuleList=None) -> None: + def __init__(self, rules: _RuleList = None) -> None: """Constructs a router from an ordered list of rules:: RuleRouter([ @@ -331,7 +347,7 @@ class RuleRouter(Router): self.rules.append(self.process_rule(rule)) - def process_rule(self, rule: 'Rule') -> 'Rule': + def process_rule(self, rule: "Rule") -> "Rule": """Override this method for additional preprocessing of each rule. :arg Rule rule: a rule to be processed. @@ -339,24 +355,27 @@ class RuleRouter(Router): """ return rule - def find_handler(self, request: httputil.HTTPServerRequest, - **kwargs: Any) -> Optional[httputil.HTTPMessageDelegate]: + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> Optional[httputil.HTTPMessageDelegate]: for rule in self.rules: target_params = rule.matcher.match(request) if target_params is not None: if rule.target_kwargs: - target_params['target_kwargs'] = rule.target_kwargs + target_params["target_kwargs"] = rule.target_kwargs delegate = self.get_target_delegate( - rule.target, request, **target_params) + rule.target, request, **target_params + ) if delegate is not None: return delegate return None - def get_target_delegate(self, target: Any, request: httputil.HTTPServerRequest, - **target_params: Any) -> Optional[httputil.HTTPMessageDelegate]: + def get_target_delegate( + self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any + ) -> Optional[httputil.HTTPMessageDelegate]: """Returns an instance of `~.httputil.HTTPMessageDelegate` for a Rule's target. This method is called by `~.find_handler` and can be extended to provide additional target types. @@ -390,18 +409,18 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter): in a rule's matcher (see `Matcher.reverse`). """ - def __init__(self, rules: _RuleList=None) -> None: + def __init__(self, rules: _RuleList = None) -> None: self.named_rules = {} # type: Dict[str, Any] super(ReversibleRuleRouter, self).__init__(rules) - def process_rule(self, rule: 'Rule') -> 'Rule': + def process_rule(self, rule: "Rule") -> "Rule": rule = super(ReversibleRuleRouter, self).process_rule(rule) if rule.name: if rule.name in self.named_rules: app_log.warning( - "Multiple handlers named %s; replacing previous value", - rule.name) + "Multiple handlers named %s; replacing previous value", rule.name + ) self.named_rules[rule.name] = rule return rule @@ -422,8 +441,13 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter): class Rule(object): """A routing rule.""" - def __init__(self, matcher: 'Matcher', target: Any, - target_kwargs: Dict[str, Any]=None, name: str=None) -> None: + def __init__( + self, + matcher: "Matcher", + target: Any, + target_kwargs: Dict[str, Any] = None, + name: str = None, + ) -> None: """Constructs a Rule instance. :arg Matcher matcher: a `Matcher` instance used for determining @@ -454,9 +478,13 @@ class Rule(object): return self.matcher.reverse(*args) def __repr__(self) -> str: - return '%s(%r, %s, kwargs=%r, name=%r)' % \ - (self.__class__.__name__, self.matcher, - self.target, self.target_kwargs, self.name) + return "%s(%r, %s, kwargs=%r, name=%r)" % ( + self.__class__.__name__, + self.matcher, + self.target, + self.target_kwargs, + self.name, + ) class Matcher(object): @@ -526,15 +554,16 @@ class PathMatches(Matcher): def __init__(self, path_pattern: Union[str, Pattern]) -> None: if isinstance(path_pattern, basestring_type): - if not path_pattern.endswith('$'): - path_pattern += '$' + if not path_pattern.endswith("$"): + path_pattern += "$" self.regex = re.compile(path_pattern) else: self.regex = path_pattern - assert len(self.regex.groupindex) in (0, self.regex.groups), \ - ("groups in url regexes must either be all named or all " - "positional: %r" % self.regex.pattern) + assert len(self.regex.groupindex) in (0, self.regex.groups), ( + "groups in url regexes must either be all named or all " + "positional: %r" % self.regex.pattern + ) self._path, self._group_count = self._find_groups() @@ -554,8 +583,8 @@ class PathMatches(Matcher): # or groupdict but not both. if self.regex.groupindex: path_kwargs = dict( - (str(k), _unquote_or_none(v)) - for (k, v) in match.groupdict().items()) + (str(k), _unquote_or_none(v)) for (k, v) in match.groupdict().items() + ) else: path_args = [_unquote_or_none(s) for s in match.groups()] @@ -564,8 +593,9 @@ class PathMatches(Matcher): def reverse(self, *args: Any) -> Optional[str]: if self._path is None: raise ValueError("Cannot reverse url regex " + self.regex.pattern) - assert len(args) == self._group_count, "required number of arguments " \ - "not found" + assert len(args) == self._group_count, ( + "required number of arguments " "not found" + ) if not len(args): return self._path converted_args = [] @@ -582,22 +612,22 @@ class PathMatches(Matcher): would return ('/%s/%s/', 2). """ pattern = self.regex.pattern - if pattern.startswith('^'): + if pattern.startswith("^"): pattern = pattern[1:] - if pattern.endswith('$'): + if pattern.endswith("$"): pattern = pattern[:-1] - if self.regex.groups != pattern.count('('): + if self.regex.groups != pattern.count("("): # The pattern is too complicated for our simplistic matching, # so we can't support reversing it. return None, None pieces = [] - for fragment in pattern.split('('): - if ')' in fragment: - paren_loc = fragment.index(')') + for fragment in pattern.split("("): + if ")" in fragment: + paren_loc = fragment.index(")") if paren_loc >= 0: - pieces.append('%s' + fragment[paren_loc + 1:]) + pieces.append("%s" + fragment[paren_loc + 1 :]) else: try: unescaped_fragment = re_unescape(fragment) @@ -607,7 +637,7 @@ class PathMatches(Matcher): return (None, None) pieces.append(unescaped_fragment) - return ''.join(pieces), self.regex.groups + return "".join(pieces), self.regex.groups class URLSpec(Rule): @@ -617,8 +647,14 @@ class URLSpec(Rule): `URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for backwards compatibility. """ - def __init__(self, pattern: Union[str, Pattern], handler: Any, - kwargs: Dict[str, Any]=None, name: str=None) -> None: + + def __init__( + self, + pattern: Union[str, Pattern], + handler: Any, + kwargs: Dict[str, Any] = None, + name: str = None, + ) -> None: """Parameters: * ``pattern``: Regular expression to be matched. Any capturing @@ -644,9 +680,13 @@ class URLSpec(Rule): self.kwargs = kwargs def __repr__(self) -> str: - return '%s(%r, %s, kwargs=%r, name=%r)' % \ - (self.__class__.__name__, self.regex.pattern, - self.handler_class, self.kwargs, self.name) + return "%s(%r, %s, kwargs=%r, name=%r)" % ( + self.__class__.__name__, + self.regex.pattern, + self.handler_class, + self.kwargs, + self.name, + ) @overload diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 15ab6e10c..e8474cb9d 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -1,7 +1,13 @@ from tornado.escape import _unicode from tornado import gen -from tornado.httpclient import (HTTPResponse, HTTPError, AsyncHTTPClient, main, - _RequestProxy, HTTPRequest) +from tornado.httpclient import ( + HTTPResponse, + HTTPError, + AsyncHTTPClient, + main, + _RequestProxy, + HTTPRequest, +) from tornado import httputil from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters from tornado.ioloop import IOLoop @@ -25,6 +31,7 @@ import urllib.parse from typing import Dict, Any, Generator, Callable, Optional, Type, Union from types import TracebackType import typing + if typing.TYPE_CHECKING: from typing import Deque, Tuple, List # noqa: F401 @@ -37,6 +44,7 @@ class HTTPTimeoutError(HTTPError): .. versionadded:: 5.1 """ + def __init__(self, message: str) -> None: super(HTTPTimeoutError, self).__init__(599, message=message) @@ -55,6 +63,7 @@ class HTTPStreamClosedError(HTTPError): .. versionadded:: 5.1 """ + def __init__(self, message: str) -> None: super(HTTPStreamClosedError, self).__init__(599, message=message) @@ -71,10 +80,17 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): are not reused, and callers cannot select the network interface to be used. """ - def initialize(self, max_clients: int=10, # type: ignore - hostname_mapping: Dict[str, str]=None, max_buffer_size: int=104857600, - resolver: Resolver=None, defaults: Dict[str, Any]=None, - max_header_size: int=None, max_body_size: int=None) -> None: + + def initialize( # type: ignore + self, + max_clients: int = 10, + hostname_mapping: Dict[str, str] = None, + max_buffer_size: int = 104857600, + resolver: Resolver = None, + defaults: Dict[str, Any] = None, + max_header_size: int = None, + max_body_size: int = None, + ) -> None: """Creates a AsyncHTTPClient. Only a single AsyncHTTPClient instance exists per IOLoop @@ -109,11 +125,15 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): """ super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults) self.max_clients = max_clients - self.queue = collections.deque() \ - # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]] - self.active = {} # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]] - self.waiting = {} \ - # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]] + self.queue = ( + collections.deque() + ) # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]] + self.active = ( + {} + ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]] + self.waiting = ( + {} + ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]] self.max_buffer_size = max_buffer_size self.max_header_size = max_header_size self.max_body_size = max_body_size @@ -126,8 +146,9 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.resolver = Resolver() self.own_resolver = True if hostname_mapping is not None: - self.resolver = OverrideResolver(resolver=self.resolver, - mapping=hostname_mapping) + self.resolver = OverrideResolver( + resolver=self.resolver, mapping=hostname_mapping + ) self.tcp_client = TCPClient(resolver=self.resolver) def close(self) -> None: @@ -136,24 +157,28 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.resolver.close() self.tcp_client.close() - def fetch_impl(self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]) -> None: + def fetch_impl( + self, request: HTTPRequest, callback: Callable[[HTTPResponse], None] + ) -> None: key = object() self.queue.append((key, request, callback)) if not len(self.active) < self.max_clients: assert request.connect_timeout is not None assert request.request_timeout is not None timeout_handle = self.io_loop.add_timeout( - self.io_loop.time() + min(request.connect_timeout, - request.request_timeout), - functools.partial(self._on_timeout, key, "in request queue")) + self.io_loop.time() + + min(request.connect_timeout, request.request_timeout), + functools.partial(self._on_timeout, key, "in request queue"), + ) else: timeout_handle = None self.waiting[key] = (request, callback, timeout_handle) self._process_queue() if self.queue: - gen_log.debug("max_clients limit reached, request queued. " - "%d active, %d queued requests." % ( - len(self.active), len(self.queue))) + gen_log.debug( + "max_clients limit reached, request queued. " + "%d active, %d queued requests." % (len(self.active), len(self.queue)) + ) def _process_queue(self) -> None: while self.queue and len(self.active) < self.max_clients: @@ -168,12 +193,22 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): def _connection_class(self) -> type: return _HTTPConnection - def _handle_request(self, request: HTTPRequest, release_callback: Callable[[], None], - final_callback: Callable[[HTTPResponse], None]) -> None: + def _handle_request( + self, + request: HTTPRequest, + release_callback: Callable[[], None], + final_callback: Callable[[HTTPResponse], None], + ) -> None: self._connection_class()( - self, request, release_callback, - final_callback, self.max_buffer_size, self.tcp_client, - self.max_header_size, self.max_body_size) + self, + request, + release_callback, + final_callback, + self.max_buffer_size, + self.tcp_client, + self.max_header_size, + self.max_body_size, + ) def _release_fetch(self, key: object) -> None: del self.active[key] @@ -186,7 +221,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.io_loop.remove_timeout(timeout_handle) del self.waiting[key] - def _on_timeout(self, key: object, info: str=None) -> None: + def _on_timeout(self, key: object, info: str = None) -> None: """Timeout callback of request. Construct a timeout HTTPResponse when a timeout occurs. @@ -199,19 +234,31 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): error_message = "Timeout {0}".format(info) if info else "Timeout" timeout_response = HTTPResponse( - request, 599, error=HTTPTimeoutError(error_message), - request_time=self.io_loop.time() - request.start_time) + request, + 599, + error=HTTPTimeoutError(error_message), + request_time=self.io_loop.time() - request.start_time, + ) self.io_loop.add_callback(callback, timeout_response) del self.waiting[key] class _HTTPConnection(httputil.HTTPMessageDelegate): - _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) - - def __init__(self, client: Optional[SimpleAsyncHTTPClient], request: HTTPRequest, - release_callback: Callable[[], None], - final_callback: Callable[[HTTPResponse], None], max_buffer_size: int, - tcp_client: TCPClient, max_header_size: int, max_body_size: int) -> None: + _SUPPORTED_METHODS = set( + ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"] + ) + + def __init__( + self, + client: Optional[SimpleAsyncHTTPClient], + request: HTTPRequest, + release_callback: Callable[[], None], + final_callback: Callable[[HTTPResponse], None], + max_buffer_size: int, + tcp_client: TCPClient, + max_header_size: int, + max_body_size: int, + ) -> None: self.io_loop = IOLoop.current() self.start_time = self.io_loop.time() self.start_wall_time = time.time() @@ -237,8 +284,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): try: self.parsed = urllib.parse.urlsplit(_unicode(self.request.url)) if self.parsed.scheme not in ("http", "https"): - raise ValueError("Unsupported url scheme: %s" % - self.request.url) + raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. netloc = self.parsed.netloc @@ -247,7 +293,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): host, port = httputil.split_host_and_port(netloc) if port is None: port = 443 if self.parsed.scheme == "https" else 80 - if re.match(r'^\[.*\]$', host): + if re.match(r"^\[.*\]$", host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] self.parsed_hostname = host # save final host for _on_connect @@ -263,11 +309,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): if timeout: self._timeout = self.io_loop.add_timeout( self.start_time + timeout, - functools.partial(self._on_timeout, "while connecting")) + functools.partial(self._on_timeout, "while connecting"), + ) stream = yield self.tcp_client.connect( - host, port, af=af, + host, + port, + af=af, ssl_options=ssl_options, - max_buffer_size=self.max_buffer_size) + max_buffer_size=self.max_buffer_size, + ) if self.final_callback is None: # final_callback is cleared if we've hit our timeout. @@ -281,21 +331,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): if self.request.request_timeout: self._timeout = self.io_loop.add_timeout( self.start_time + self.request.request_timeout, - functools.partial(self._on_timeout, "during request")) - if (self.request.method not in self._SUPPORTED_METHODS and - not self.request.allow_nonstandard_methods): + functools.partial(self._on_timeout, "during request"), + ) + if ( + self.request.method not in self._SUPPORTED_METHODS + and not self.request.allow_nonstandard_methods + ): raise KeyError("unknown method %s" % self.request.method) - for key in ('network_interface', - 'proxy_host', 'proxy_port', - 'proxy_username', 'proxy_password', - 'proxy_auth_mode'): + for key in ( + "network_interface", + "proxy_host", + "proxy_port", + "proxy_username", + "proxy_password", + "proxy_auth_mode", + ): if getattr(self.request, key, None): - raise NotImplementedError('%s not supported' % key) + raise NotImplementedError("%s not supported" % key) if "Connection" not in self.request.headers: self.request.headers["Connection"] = "close" if "Host" not in self.request.headers: - if '@' in self.parsed.netloc: - self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1] + if "@" in self.parsed.netloc: + self.request.headers["Host"] = self.parsed.netloc.rpartition( + "@" + )[-1] else: self.request.headers["Host"] = self.parsed.netloc username, password = None, None @@ -303,15 +362,18 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): username, password = self.parsed.username, self.parsed.password elif self.request.auth_username is not None: username = self.request.auth_username - password = self.request.auth_password or '' + password = self.request.auth_password or "" if username is not None: assert password is not None if self.request.auth_mode not in (None, "basic"): - raise ValueError("unsupported auth_mode %s", - self.request.auth_mode) - self.request.headers["Authorization"] = ( - "Basic " + _unicode(base64.b64encode( - httputil.encode_username_password(username, password)))) + raise ValueError( + "unsupported auth_mode %s", self.request.auth_mode + ) + self.request.headers["Authorization"] = "Basic " + _unicode( + base64.b64encode( + httputil.encode_username_password(username, password) + ) + ) if self.request.user_agent: self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: @@ -319,31 +381,40 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): # almost never do. Fail in this case unless the user has # opted out of sanity checks with allow_nonstandard_methods. body_expected = self.request.method in ("POST", "PATCH", "PUT") - body_present = (self.request.body is not None or - self.request.body_producer is not None) - if ((body_expected and not body_present) or - (body_present and not body_expected)): + body_present = ( + self.request.body is not None + or self.request.body_producer is not None + ) + if (body_expected and not body_present) or ( + body_present and not body_expected + ): raise ValueError( - 'Body must %sbe None for method %s (unless ' - 'allow_nonstandard_methods is true)' % - ('not ' if body_expected else '', self.request.method)) + "Body must %sbe None for method %s (unless " + "allow_nonstandard_methods is true)" + % ("not " if body_expected else "", self.request.method) + ) if self.request.expect_100_continue: self.request.headers["Expect"] = "100-continue" if self.request.body is not None: # When body_producer is used the caller is responsible for # setting Content-Length (or else chunked encoding will be used). - self.request.headers["Content-Length"] = str(len( - self.request.body)) - if (self.request.method == "POST" and - "Content-Type" not in self.request.headers): - self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" + self.request.headers["Content-Length"] = str(len(self.request.body)) + if ( + self.request.method == "POST" + and "Content-Type" not in self.request.headers + ): + self.request.headers[ + "Content-Type" + ] = "application/x-www-form-urlencoded" if self.request.decompress_response: self.request.headers["Accept-Encoding"] = "gzip" - req_path = ((self.parsed.path or '/') + - (('?' + self.parsed.query) if self.parsed.query else '')) + req_path = (self.parsed.path or "/") + ( + ("?" + self.parsed.query) if self.parsed.query else "" + ) self.connection = self._create_connection(stream) - start_line = httputil.RequestStartLine(self.request.method, - req_path, '') + start_line = httputil.RequestStartLine( + self.request.method, req_path, "" + ) self.connection.write_headers(start_line, self.request.headers) if self.request.expect_100_continue: yield self.connection.read_response(self) @@ -353,33 +424,38 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): if not self._handle_exception(*sys.exc_info()): raise - def _get_ssl_options(self, scheme: str) -> Union[None, Dict[str, Any], ssl.SSLContext]: + def _get_ssl_options( + self, scheme: str + ) -> Union[None, Dict[str, Any], ssl.SSLContext]: if scheme == "https": if self.request.ssl_options is not None: return self.request.ssl_options # If we are using the defaults, don't construct a # new SSLContext. - if (self.request.validate_cert and - self.request.ca_certs is None and - self.request.client_cert is None and - self.request.client_key is None): + if ( + self.request.validate_cert + and self.request.ca_certs is None + and self.request.client_cert is None + and self.request.client_key is None + ): return _client_ssl_defaults ssl_ctx = ssl.create_default_context( - ssl.Purpose.SERVER_AUTH, - cafile=self.request.ca_certs) + ssl.Purpose.SERVER_AUTH, cafile=self.request.ca_certs + ) if not self.request.validate_cert: ssl_ctx.check_hostname = False ssl_ctx.verify_mode = ssl.CERT_NONE if self.request.client_cert is not None: - ssl_ctx.load_cert_chain(self.request.client_cert, - self.request.client_key) - if hasattr(ssl, 'OP_NO_COMPRESSION'): + ssl_ctx.load_cert_chain( + self.request.client_cert, self.request.client_key + ) + if hasattr(ssl, "OP_NO_COMPRESSION"): # See netutil.ssl_options_to_context ssl_ctx.options |= ssl.OP_NO_COMPRESSION return ssl_ctx return None - def _on_timeout(self, info: str=None) -> None: + def _on_timeout(self, info: str = None) -> None: """Timeout callback of _HTTPConnection instance. Raise a `HTTPTimeoutError` when a timeout occurs. @@ -389,8 +465,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self._timeout = None error_message = "Timeout {0}".format(info) if info else "Timeout" if self.final_callback is not None: - self._handle_exception(HTTPTimeoutError, HTTPTimeoutError(error_message), - None) + self._handle_exception( + HTTPTimeoutError, HTTPTimeoutError(error_message), None + ) def _remove_timeout(self) -> None: if self._timeout is not None: @@ -400,13 +477,16 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): def _create_connection(self, stream: IOStream) -> HTTP1Connection: stream.set_nodelay(True) connection = HTTP1Connection( - stream, True, + stream, + True, HTTP1ConnectionParameters( no_keep_alive=True, max_header_size=self.max_header_size, max_body_size=self.max_body_size, - decompress=bool(self.request.decompress_response)), - self._sockaddr) + decompress=bool(self.request.decompress_response), + ), + self._sockaddr, + ) return connection @gen.coroutine @@ -438,9 +518,12 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.final_callback = None # type: ignore self.io_loop.add_callback(final_callback, response) - def _handle_exception(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[TracebackType]) -> bool: + def _handle_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> bool: if self.final_callback: self._remove_timeout() if isinstance(value, StreamClosedError): @@ -448,10 +531,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): value = HTTPStreamClosedError("Stream closed") else: value = value.real_error - self._run_callback(HTTPResponse(self.request, 599, error=value, - request_time=self.io_loop.time() - self.start_time, - start_time=self.start_wall_time, - )) + self._run_callback( + HTTPResponse( + self.request, + 599, + error=value, + request_time=self.io_loop.time() - self.start_time, + start_time=self.start_wall_time, + ) + ) if hasattr(self, "stream"): # TODO: this may cause a StreamClosedError to be raised @@ -476,9 +564,11 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): except HTTPStreamClosedError: self._handle_exception(*sys.exc_info()) - def headers_received(self, first_line: Union[httputil.ResponseStartLine, - httputil.RequestStartLine], - headers: httputil.HTTPHeaders) -> None: + def headers_received( + self, + first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine], + headers: httputil.HTTPHeaders, + ) -> None: assert isinstance(first_line, httputil.ResponseStartLine) if self.request.expect_100_continue and first_line.code == 100: self._write_body(False) @@ -492,29 +582,31 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): if self.request.header_callback is not None: # Reassemble the start line. - self.request.header_callback('%s %s %s\r\n' % first_line) + self.request.header_callback("%s %s %s\r\n" % first_line) for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) - self.request.header_callback('\r\n') + self.request.header_callback("\r\n") def _should_follow_redirect(self) -> bool: if self.request.follow_redirects: assert self.request.max_redirects is not None - return (self.code in (301, 302, 303, 307, 308) and - self.request.max_redirects > 0) + return ( + self.code in (301, 302, 303, 307, 308) + and self.request.max_redirects > 0 + ) return False def finish(self) -> None: assert self.code is not None - data = b''.join(self.chunks) + data = b"".join(self.chunks) self._remove_timeout() - original_request = getattr(self.request, "original_request", - self.request) + original_request = getattr(self.request, "original_request", self.request) if self._should_follow_redirect(): assert isinstance(self.request, _RequestProxy) new_request = copy.copy(self.request.request) - new_request.url = urllib.parse.urljoin(self.request.url, - self.headers["Location"]) + new_request.url = urllib.parse.urljoin( + self.request.url, self.headers["Location"] + ) new_request.max_redirects = self.request.max_redirects - 1 del new_request.headers["Host"] # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4 @@ -527,8 +619,12 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): if self.code in (302, 303): new_request.method = "GET" new_request.body = None - for h in ["Content-Length", "Content-Type", - "Content-Encoding", "Transfer-Encoding"]: + for h in [ + "Content-Length", + "Content-Type", + "Content-Encoding", + "Transfer-Encoding", + ]: try: del self.request.headers[h] except KeyError: @@ -545,13 +641,16 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? - response = HTTPResponse(original_request, - self.code, reason=getattr(self, 'reason', None), - headers=self.headers, - request_time=self.io_loop.time() - self.start_time, - start_time=self.start_wall_time, - buffer=buffer, - effective_url=self.request.url) + response = HTTPResponse( + original_request, + self.code, + reason=getattr(self, "reason", None), + headers=self.headers, + request_time=self.io_loop.time() - self.start_time, + start_time=self.start_wall_time, + buffer=buffer, + effective_url=self.request.url, + ) self._run_callback(response) self._on_end_request() diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 902280ddf..5c2560d2c 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -32,6 +32,7 @@ from tornado.gen import TimeoutError import typing from typing import Generator, Any, Union, Dict, Tuple, List, Callable, Iterator + if typing.TYPE_CHECKING: from typing import Optional, Set # noqa: F401 @@ -55,13 +56,20 @@ class _Connector(object): http://tools.ietf.org/html/rfc6555 """ - def __init__(self, addrinfo: List[Tuple], - connect: Callable[[socket.AddressFamily, Tuple], - Tuple[IOStream, 'Future[IOStream]']]) -> None: + + def __init__( + self, + addrinfo: List[Tuple], + connect: Callable[ + [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"] + ], + ) -> None: self.io_loop = IOLoop.current() self.connect = connect - self.future = Future() # type: Future[Tuple[socket.AddressFamily, Any, IOStream]] + self.future = ( + Future() + ) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]] self.timeout = None # type: Optional[object] self.connect_timeout = None # type: Optional[object] self.last_error = None # type: Optional[Exception] @@ -70,8 +78,12 @@ class _Connector(object): self.streams = set() # type: Set[IOStream] @staticmethod - def split(addrinfo: List[Tuple]) -> Tuple[List[Tuple[socket.AddressFamily, Tuple]], - List[Tuple[socket.AddressFamily, Tuple]]]: + def split( + addrinfo: List[Tuple] + ) -> Tuple[ + List[Tuple[socket.AddressFamily, Tuple]], + List[Tuple[socket.AddressFamily, Tuple]], + ]: """Partition the ``addrinfo`` list by address family. Returns two lists. The first list contains the first entry from @@ -91,9 +103,10 @@ class _Connector(object): return primary, secondary def start( - self, timeout: float=_INITIAL_CONNECT_TIMEOUT, - connect_timeout: Union[float, datetime.timedelta]=None, - ) -> 'Future[Tuple[socket.AddressFamily, Any, IOStream]]': + self, + timeout: float = _INITIAL_CONNECT_TIMEOUT, + connect_timeout: Union[float, datetime.timedelta] = None, + ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]": self.try_connect(iter(self.primary_addrs)) self.set_timeout(timeout) if connect_timeout is not None: @@ -108,17 +121,23 @@ class _Connector(object): # might still be working. Send a final error on the future # only when both queues are finished. if self.remaining == 0 and not self.future.done(): - self.future.set_exception(self.last_error or - IOError("connection failed")) + self.future.set_exception( + self.last_error or IOError("connection failed") + ) return stream, future = self.connect(af, addr) self.streams.add(stream) future_add_done_callback( - future, functools.partial(self.on_connect_done, addrs, af, addr)) + future, functools.partial(self.on_connect_done, addrs, af, addr) + ) - def on_connect_done(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]], - af: socket.AddressFamily, addr: Tuple, - future: 'Future[IOStream]') -> None: + def on_connect_done( + self, + addrs: Iterator[Tuple[socket.AddressFamily, Tuple]], + af: socket.AddressFamily, + addr: Tuple, + future: "Future[IOStream]", + ) -> None: self.remaining -= 1 try: stream = future.result() @@ -145,8 +164,9 @@ class _Connector(object): self.close_streams() def set_timeout(self, timeout: float) -> None: - self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, - self.on_timeout) + self.timeout = self.io_loop.add_timeout( + self.io_loop.time() + timeout, self.on_timeout + ) def on_timeout(self) -> None: self.timeout = None @@ -157,9 +177,12 @@ class _Connector(object): if self.timeout is not None: self.io_loop.remove_timeout(self.timeout) - def set_connect_timeout(self, connect_timeout: Union[float, datetime.timedelta]) -> None: + def set_connect_timeout( + self, connect_timeout: Union[float, datetime.timedelta] + ) -> None: self.connect_timeout = self.io_loop.add_timeout( - connect_timeout, self.on_connect_timeout) + connect_timeout, self.on_connect_timeout + ) def on_connect_timeout(self) -> None: if not self.future.done(): @@ -183,7 +206,8 @@ class TCPClient(object): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ - def __init__(self, resolver: Resolver=None) -> None: + + def __init__(self, resolver: Resolver = None) -> None: if resolver is not None: self.resolver = resolver self._own_resolver = False @@ -196,10 +220,17 @@ class TCPClient(object): self.resolver.close() @gen.coroutine - def connect(self, host: str, port: int, af: socket.AddressFamily=socket.AF_UNSPEC, - ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, - max_buffer_size: int=None, source_ip: str=None, source_port: int=None, - timeout: Union[float, datetime.timedelta]=None) -> Generator[Any, Any, IOStream]: + def connect( + self, + host: str, + port: int, + af: socket.AddressFamily = socket.AF_UNSPEC, + ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, + max_buffer_size: int = None, + source_ip: str = None, + source_port: int = None, + timeout: Union[float, datetime.timedelta] = None, + ) -> Generator[Any, Any, IOStream]: """Connect to the given host and port. Asynchronously returns an `.IOStream` (or `.SSLIOStream` if @@ -234,13 +265,18 @@ class TCPClient(object): raise TypeError("Unsupported timeout %r" % timeout) if timeout is not None: addrinfo = yield gen.with_timeout( - timeout, self.resolver.resolve(host, port, af)) + timeout, self.resolver.resolve(host, port, af) + ) else: addrinfo = yield self.resolver.resolve(host, port, af) connector = _Connector( addrinfo, - functools.partial(self._create_stream, max_buffer_size, - source_ip=source_ip, source_port=source_port) + functools.partial( + self._create_stream, + max_buffer_size, + source_ip=source_ip, + source_port=source_port, + ), ) af, addr, stream = yield connector.start(connect_timeout=timeout) # TODO: For better performance we could cache the (af, addr) @@ -248,16 +284,26 @@ class TCPClient(object): # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) if ssl_options is not None: if timeout is not None: - stream = yield gen.with_timeout(timeout, stream.start_tls( - False, ssl_options=ssl_options, server_hostname=host)) + stream = yield gen.with_timeout( + timeout, + stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host + ), + ) else: - stream = yield stream.start_tls(False, ssl_options=ssl_options, - server_hostname=host) + stream = yield stream.start_tls( + False, ssl_options=ssl_options, server_hostname=host + ) return stream - def _create_stream(self, max_buffer_size: int, af: socket.AddressFamily, - addr: Tuple, source_ip: str=None, - source_port: int=None) -> Tuple[IOStream, 'Future[IOStream]']: + def _create_stream( + self, + max_buffer_size: int, + af: socket.AddressFamily, + addr: Tuple, + source_ip: str = None, + source_port: int = None, + ) -> Tuple[IOStream, "Future[IOStream]"]: # Always connect in plaintext; we'll convert to ssl if necessary # after one connection has completed. source_port_bind = source_port if isinstance(source_port, int) else 0 @@ -265,7 +311,7 @@ class TCPClient(object): if source_port_bind and not source_ip: # User required a specific port, but did not specify # a certain source IP, will bind to the default loopback. - source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1' + source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1" # Trying to use the same address family as the requested af socket: # - 127.0.0.1 for IPv4 # - ::1 for IPv6 @@ -280,8 +326,7 @@ class TCPClient(object): # Fail loudly if unable to use the IP/port. raise try: - stream = IOStream(socket_obj, - max_buffer_size=max_buffer_size) + stream = IOStream(socket_obj, max_buffer_size=max_buffer_size) except socket.error as e: fu = Future() # type: Future[IOStream] fu.set_exception(e) diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 4d59f1ebe..26aec8ae9 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -30,6 +30,7 @@ from tornado.util import errno_from_exception import typing from typing import Union, Dict, Any, Iterable, Optional, Awaitable + if typing.TYPE_CHECKING: from typing import Callable, List # noqa: F401 @@ -103,10 +104,15 @@ class TCPServer(object): .. versionchanged:: 5.0 The ``io_loop`` argument has been removed. """ - def __init__(self, ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, - max_buffer_size: int=None, read_chunk_size: int=None) -> None: + + def __init__( + self, + ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None, + max_buffer_size: int = None, + read_chunk_size: int = None, + ) -> None: self.ssl_options = ssl_options - self._sockets = {} # type: Dict[int, socket.socket] + self._sockets = {} # type: Dict[int, socket.socket] self._handlers = {} # type: Dict[int, Callable[[], None]] self._pending_sockets = [] # type: List[socket.socket] self._started = False @@ -120,18 +126,21 @@ class TCPServer(object): # which seems like too much work if self.ssl_options is not None and isinstance(self.ssl_options, dict): # Only certfile is required: it can contain both keys - if 'certfile' not in self.ssl_options: + if "certfile" not in self.ssl_options: raise KeyError('missing key "certfile" in ssl_options') - if not os.path.exists(self.ssl_options['certfile']): - raise ValueError('certfile "%s" does not exist' % - self.ssl_options['certfile']) - if ('keyfile' in self.ssl_options and - not os.path.exists(self.ssl_options['keyfile'])): - raise ValueError('keyfile "%s" does not exist' % - self.ssl_options['keyfile']) - - def listen(self, port: int, address: str="") -> None: + if not os.path.exists(self.ssl_options["certfile"]): + raise ValueError( + 'certfile "%s" does not exist' % self.ssl_options["certfile"] + ) + if "keyfile" in self.ssl_options and not os.path.exists( + self.ssl_options["keyfile"] + ): + raise ValueError( + 'keyfile "%s" does not exist' % self.ssl_options["keyfile"] + ) + + def listen(self, port: int, address: str = "") -> None: """Starts accepting connections on the given port. This method may be called more than once to listen on multiple ports. @@ -154,15 +163,21 @@ class TCPServer(object): for sock in sockets: self._sockets[sock.fileno()] = sock self._handlers[sock.fileno()] = add_accept_handler( - sock, self._handle_connection) + sock, self._handle_connection + ) def add_socket(self, socket: socket.socket) -> None: """Singular version of `add_sockets`. Takes a single socket object.""" self.add_sockets([socket]) - def bind(self, port: int, address: str=None, - family: socket.AddressFamily=socket.AF_UNSPEC, - backlog: int=128, reuse_port: bool=False) -> None: + def bind( + self, + port: int, + address: str = None, + family: socket.AddressFamily = socket.AF_UNSPEC, + backlog: int = 128, + reuse_port: bool = False, + ) -> None: """Binds this server to the given port on the given address. To start the server, call `start`. If you want to run this server @@ -186,14 +201,15 @@ class TCPServer(object): .. versionchanged:: 4.4 Added the ``reuse_port`` argument. """ - sockets = bind_sockets(port, address=address, family=family, - backlog=backlog, reuse_port=reuse_port) + sockets = bind_sockets( + port, address=address, family=family, backlog=backlog, reuse_port=reuse_port + ) if self._started: self.add_sockets(sockets) else: self._pending_sockets.extend(sockets) - def start(self, num_processes: Optional[int]=1) -> None: + def start(self, num_processes: Optional[int] = 1) -> None: """Starts this server in the `.IOLoop`. By default, we run the server in this process and do not fork any @@ -236,7 +252,9 @@ class TCPServer(object): self._handlers.pop(fd)() sock.close() - def handle_stream(self, stream: IOStream, address: tuple) -> Optional[Awaitable[None]]: + def handle_stream( + self, stream: IOStream, address: tuple + ) -> Optional[Awaitable[None]]: """Override to handle a new `.IOStream` from an incoming connection. This method may be a coroutine; if so any exceptions it raises @@ -257,10 +275,12 @@ class TCPServer(object): if self.ssl_options is not None: assert ssl, "Python 2.6+ and OpenSSL required for SSL" try: - connection = ssl_wrap_socket(connection, - self.ssl_options, - server_side=True, - do_handshake_on_connect=False) + connection = ssl_wrap_socket( + connection, + self.ssl_options, + server_side=True, + do_handshake_on_connect=False, + ) except ssl.SSLError as err: if err.args[0] == ssl.SSL_ERROR_EOF: return connection.close() @@ -286,15 +306,19 @@ class TCPServer(object): stream = SSLIOStream( connection, max_buffer_size=self.max_buffer_size, - read_chunk_size=self.read_chunk_size) # type: IOStream + read_chunk_size=self.read_chunk_size, + ) # type: IOStream else: - stream = IOStream(connection, - max_buffer_size=self.max_buffer_size, - read_chunk_size=self.read_chunk_size) + stream = IOStream( + connection, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size, + ) future = self.handle_stream(stream, address) if future is not None: - IOLoop.current().add_future(gen.convert_yielded(future), - lambda f: f.result()) + IOLoop.current().add_future( + gen.convert_yielded(future), lambda f: f.result() + ) except Exception: app_log.error("Error in connection callback", exc_info=True) diff --git a/tornado/template.py b/tornado/template.py index 6e3c1552f..4ef65beac 100644 --- a/tornado/template.py +++ b/tornado/template.py @@ -207,8 +207,19 @@ from tornado import escape from tornado.log import app_log from tornado.util import ObjectDict, exec_in, unicode_type -from typing import Any, Union, Callable, List, Dict, Iterable, Optional, TextIO, ContextManager +from typing import ( + Any, + Union, + Callable, + List, + Dict, + Iterable, + Optional, + TextIO, + ContextManager, +) import typing + if typing.TYPE_CHECKING: from typing import Tuple # noqa: F401 @@ -235,13 +246,13 @@ def filter_whitespace(mode: str, text: str) -> str: .. versionadded:: 4.3 """ - if mode == 'all': + if mode == "all": return text - elif mode == 'single': + elif mode == "single": text = re.sub(r"([\t ]+)", " ", text) text = re.sub(r"(\s*\n\s*)", "\n", text) return text - elif mode == 'oneline': + elif mode == "oneline": return re.sub(r"(\s+)", " ", text) else: raise Exception("invalid whitespace mode %s" % mode) @@ -253,13 +264,19 @@ class Template(object): We compile into Python from the given template_string. You can generate the template from variables with generate(). """ + # note that the constructor's signature is not extracted with # autodoc because _UNSET looks like garbage. When changing # this signature update website/sphinx/template.rst too. - def __init__(self, template_string: Union[str, bytes], name: str="", - loader: 'BaseLoader'=None, compress_whitespace: Union[bool, _UnsetMarker]=_UNSET, - autoescape: Union[str, _UnsetMarker]=_UNSET, - whitespace: str=None) -> None: + def __init__( + self, + template_string: Union[str, bytes], + name: str = "", + loader: "BaseLoader" = None, + compress_whitespace: Union[bool, _UnsetMarker] = _UNSET, + autoescape: Union[str, _UnsetMarker] = _UNSET, + whitespace: str = None, + ) -> None: """Construct a Template. :arg str template_string: the contents of the template file. @@ -296,7 +313,7 @@ class Template(object): whitespace = "all" # Validate the whitespace setting. assert whitespace is not None - filter_whitespace(whitespace, '') + filter_whitespace(whitespace, "") if not isinstance(autoescape, _UnsetMarker): self.autoescape = autoescape # type: Optional[str] @@ -306,8 +323,7 @@ class Template(object): self.autoescape = _DEFAULT_AUTOESCAPE self.namespace = loader.namespace if loader else {} - reader = _TemplateReader(name, escape.native_str(template_string), - whitespace) + reader = _TemplateReader(name, escape.native_str(template_string), whitespace) self.file = _File(self, _parse(reader, self)) self.code = self._generate_python(loader) self.loader = loader @@ -318,8 +334,10 @@ class Template(object): # from being applied to the generated code. self.compiled = compile( escape.to_unicode(self.code), - "%s.generated.py" % self.name.replace('.', '_'), - "exec", dont_inherit=True) + "%s.generated.py" % self.name.replace(".", "_"), + "exec", + dont_inherit=True, + ) except Exception: formatted_code = _format_code(self.code).rstrip() app_log.error("%s code:\n%s", self.name, formatted_code) @@ -339,7 +357,7 @@ class Template(object): "_tt_string_types": (unicode_type, bytes), # __name__ and __loader__ allow the traceback mechanism to find # the generated source code. - "__name__": self.name.replace('.', '_'), + "__name__": self.name.replace(".", "_"), "__loader__": ObjectDict(get_source=lambda name: self.code), } namespace.update(self.namespace) @@ -352,7 +370,7 @@ class Template(object): linecache.clearcache() return execute() - def _generate_python(self, loader: Optional['BaseLoader']) -> str: + def _generate_python(self, loader: Optional["BaseLoader"]) -> str: buffer = StringIO() try: # named_blocks maps from names to _NamedBlock objects @@ -361,20 +379,20 @@ class Template(object): ancestors.reverse() for ancestor in ancestors: ancestor.find_named_blocks(loader, named_blocks) - writer = _CodeWriter(buffer, named_blocks, loader, - ancestors[0].template) + writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template) ancestors[0].generate(writer) return buffer.getvalue() finally: buffer.close() - def _get_ancestors(self, loader: Optional['BaseLoader']) -> List['_File']: + def _get_ancestors(self, loader: Optional["BaseLoader"]) -> List["_File"]: ancestors = [self.file] for chunk in self.file.body.chunks: if isinstance(chunk, _ExtendsBlock): if not loader: - raise ParseError("{% extends %} block found, but no " - "template loader") + raise ParseError( + "{% extends %} block found, but no " "template loader" + ) template = loader.load(chunk.name, self.name) ancestors.extend(template._get_ancestors(loader)) return ancestors @@ -387,9 +405,13 @@ class BaseLoader(object): ``{% extends %}`` and ``{% include %}``. The loader caches all templates after they are loaded the first time. """ - def __init__(self, autoescape: str=_DEFAULT_AUTOESCAPE, - namespace: Dict[str, Any]=None, - whitespace: str=None) -> None: + + def __init__( + self, + autoescape: str = _DEFAULT_AUTOESCAPE, + namespace: Dict[str, Any] = None, + whitespace: str = None, + ) -> None: """Construct a template loader. :arg str autoescape: The name of a function in the template @@ -421,11 +443,11 @@ class BaseLoader(object): with self.lock: self.templates = {} - def resolve_path(self, name: str, parent_path: str=None) -> str: + def resolve_path(self, name: str, parent_path: str = None) -> str: """Converts a possibly-relative path to absolute (used internally).""" raise NotImplementedError() - def load(self, name: str, parent_path: str=None) -> Template: + def load(self, name: str, parent_path: str = None) -> Template: """Loads a template.""" name = self.resolve_path(name, parent_path=parent_path) with self.lock: @@ -440,19 +462,23 @@ class BaseLoader(object): class Loader(BaseLoader): """A template loader that loads from a single root directory. """ + def __init__(self, root_directory: str, **kwargs: Any) -> None: super(Loader, self).__init__(**kwargs) self.root = os.path.abspath(root_directory) - def resolve_path(self, name: str, parent_path: str=None) -> str: - if parent_path and not parent_path.startswith("<") and \ - not parent_path.startswith("/") and \ - not name.startswith("/"): + def resolve_path(self, name: str, parent_path: str = None) -> str: + if ( + parent_path + and not parent_path.startswith("<") + and not parent_path.startswith("/") + and not name.startswith("/") + ): current_path = os.path.join(self.root, parent_path) file_dir = os.path.dirname(os.path.abspath(current_path)) relative_path = os.path.abspath(os.path.join(file_dir, name)) if relative_path.startswith(self.root): - name = relative_path[len(self.root) + 1:] + name = relative_path[len(self.root) + 1 :] return name def _create_template(self, name: str) -> Template: @@ -464,14 +490,18 @@ class Loader(BaseLoader): class DictLoader(BaseLoader): """A template loader that loads from a dictionary.""" + def __init__(self, dict: Dict[str, str], **kwargs: Any) -> None: super(DictLoader, self).__init__(**kwargs) self.dict = dict - def resolve_path(self, name: str, parent_path: str=None) -> str: - if parent_path and not parent_path.startswith("<") and \ - not parent_path.startswith("/") and \ - not name.startswith("/"): + def resolve_path(self, name: str, parent_path: str = None) -> str: + if ( + parent_path + and not parent_path.startswith("<") + and not parent_path.startswith("/") + and not name.startswith("/") + ): file_dir = posixpath.dirname(parent_path) name = posixpath.normpath(posixpath.join(file_dir, name)) return name @@ -481,25 +511,26 @@ class DictLoader(BaseLoader): class _Node(object): - def each_child(self) -> Iterable['_Node']: + def each_child(self) -> Iterable["_Node"]: return () - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: raise NotImplementedError() - def find_named_blocks(self, loader: Optional[BaseLoader], - named_blocks: Dict[str, '_NamedBlock']) -> None: + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"] + ) -> None: for child in self.each_child(): child.find_named_blocks(loader, named_blocks) class _File(_Node): - def __init__(self, template: Template, body: '_ChunkList') -> None: + def __init__(self, template: Template, body: "_ChunkList") -> None: self.template = template self.body = body self.line = 0 - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: writer.write_line("def _tt_execute():", self.line) with writer.indent(): writer.write_line("_tt_buffer = []", self.line) @@ -507,7 +538,7 @@ class _File(_Node): self.body.generate(writer) writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line) - def each_child(self) -> Iterable['_Node']: + def each_child(self) -> Iterable["_Node"]: return (self.body,) @@ -515,11 +546,11 @@ class _ChunkList(_Node): def __init__(self, chunks: List[_Node]) -> None: self.chunks = chunks - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: for chunk in self.chunks: chunk.generate(writer) - def each_child(self) -> Iterable['_Node']: + def each_child(self) -> Iterable["_Node"]: return self.chunks @@ -530,16 +561,17 @@ class _NamedBlock(_Node): self.template = template self.line = line - def each_child(self) -> Iterable['_Node']: + def each_child(self) -> Iterable["_Node"]: return (self.body,) - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: block = writer.named_blocks[self.name] with writer.include(block.template, self.line): block.body.generate(writer) - def find_named_blocks(self, loader: Optional[BaseLoader], - named_blocks: Dict[str, '_NamedBlock']) -> None: + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"] + ) -> None: named_blocks[self.name] = self _Node.find_named_blocks(self, loader, named_blocks) @@ -550,18 +582,19 @@ class _ExtendsBlock(_Node): class _IncludeBlock(_Node): - def __init__(self, name: str, reader: '_TemplateReader', line: int) -> None: + def __init__(self, name: str, reader: "_TemplateReader", line: int) -> None: self.name = name self.template_name = reader.name self.line = line - def find_named_blocks(self, loader: Optional[BaseLoader], - named_blocks: Dict[str, _NamedBlock]) -> None: + def find_named_blocks( + self, loader: Optional[BaseLoader], named_blocks: Dict[str, _NamedBlock] + ) -> None: assert loader is not None included = loader.load(self.name, self.template_name) included.file.find_named_blocks(loader, named_blocks) - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: assert writer.loader is not None included = writer.loader.load(self.name, self.template_name) with writer.include(included, self.line): @@ -574,10 +607,10 @@ class _ApplyBlock(_Node): self.line = line self.body = body - def each_child(self) -> Iterable['_Node']: + def each_child(self) -> Iterable["_Node"]: return (self.body,) - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: method_name = "_tt_apply%d" % writer.apply_counter writer.apply_counter += 1 writer.write_line("def %s():" % method_name, self.line) @@ -586,8 +619,9 @@ class _ApplyBlock(_Node): writer.write_line("_tt_append = _tt_buffer.append", self.line) self.body.generate(writer) writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line) - writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % ( - self.method, method_name), self.line) + writer.write_line( + "_tt_append(_tt_utf8(%s(%s())))" % (self.method, method_name), self.line + ) class _ControlBlock(_Node): @@ -599,7 +633,7 @@ class _ControlBlock(_Node): def each_child(self) -> Iterable[_Node]: return (self.body,) - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: writer.write_line("%s:" % self.statement, self.line) with writer.indent(): self.body.generate(writer) @@ -612,7 +646,7 @@ class _IntermediateControlBlock(_Node): self.statement = statement self.line = line - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: # In case the previous block was empty writer.write_line("pass", self.line) writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1) @@ -623,33 +657,36 @@ class _Statement(_Node): self.statement = statement self.line = line - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: writer.write_line(self.statement, self.line) class _Expression(_Node): - def __init__(self, expression: str, line: int, raw: bool=False) -> None: + def __init__(self, expression: str, line: int, raw: bool = False) -> None: self.expression = expression self.line = line self.raw = raw - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: writer.write_line("_tt_tmp = %s" % self.expression, self.line) - writer.write_line("if isinstance(_tt_tmp, _tt_string_types):" - " _tt_tmp = _tt_utf8(_tt_tmp)", self.line) + writer.write_line( + "if isinstance(_tt_tmp, _tt_string_types):" " _tt_tmp = _tt_utf8(_tt_tmp)", + self.line, + ) writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line) if not self.raw and writer.current_template.autoescape is not None: # In python3 functions like xhtml_escape return unicode, # so we have to convert to utf8 again. - writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" % - writer.current_template.autoescape, self.line) + writer.write_line( + "_tt_tmp = _tt_utf8(%s(_tt_tmp))" % writer.current_template.autoescape, + self.line, + ) writer.write_line("_tt_append(_tt_tmp)", self.line) class _Module(_Expression): def __init__(self, expression: str, line: int) -> None: - super(_Module, self).__init__("_tt_modules." + expression, line, - raw=True) + super(_Module, self).__init__("_tt_modules." + expression, line, raw=True) class _Text(_Node): @@ -658,7 +695,7 @@ class _Text(_Node): self.line = line self.whitespace = whitespace - def generate(self, writer: '_CodeWriter') -> None: + def generate(self, writer: "_CodeWriter") -> None: value = self.value # Compress whitespace if requested, with a crude heuristic to avoid @@ -667,7 +704,7 @@ class _Text(_Node): value = filter_whitespace(self.whitespace, value) if value: - writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line) + writer.write_line("_tt_append(%r)" % escape.utf8(value), self.line) class ParseError(Exception): @@ -679,7 +716,8 @@ class ParseError(Exception): .. versionchanged:: 4.3 Added ``filename`` and ``lineno`` attributes. """ - def __init__(self, message: str, filename: str=None, lineno: int=0) -> None: + + def __init__(self, message: str, filename: str = None, lineno: int = 0) -> None: self.message = message # The names "filename" and "lineno" are chosen for consistency # with python SyntaxError. @@ -687,12 +725,17 @@ class ParseError(Exception): self.lineno = lineno def __str__(self) -> str: - return '%s at %s:%d' % (self.message, self.filename, self.lineno) + return "%s at %s:%d" % (self.message, self.filename, self.lineno) class _CodeWriter(object): - def __init__(self, file: TextIO, named_blocks: Dict[str, _NamedBlock], - loader: Optional[BaseLoader], current_template: Template) -> None: + def __init__( + self, + file: TextIO, + named_blocks: Dict[str, _NamedBlock], + loader: Optional[BaseLoader], + current_template: Template, + ) -> None: self.file = file self.named_blocks = named_blocks self.loader = loader @@ -706,7 +749,7 @@ class _CodeWriter(object): def indent(self) -> ContextManager: class Indenter(object): - def __enter__(_) -> '_CodeWriter': + def __enter__(_) -> "_CodeWriter": self._indent += 1 return self @@ -721,7 +764,7 @@ class _CodeWriter(object): self.current_template = template class IncludeTemplate(object): - def __enter__(_) -> '_CodeWriter': + def __enter__(_) -> "_CodeWriter": return self def __exit__(_, *args: Any) -> None: @@ -729,14 +772,15 @@ class _CodeWriter(object): return IncludeTemplate() - def write_line(self, line: str, line_number: int, indent: int=None) -> None: + def write_line(self, line: str, line_number: int, indent: int = None) -> None: if indent is None: indent = self._indent - line_comment = ' # %s:%d' % (self.current_template.name, line_number) + line_comment = " # %s:%d" % (self.current_template.name, line_number) if self.include_stack: - ancestors = ["%s:%d" % (tmpl.name, lineno) - for (tmpl, lineno) in self.include_stack] - line_comment += ' (via %s)' % ', '.join(reversed(ancestors)) + ancestors = [ + "%s:%d" % (tmpl.name, lineno) for (tmpl, lineno) in self.include_stack + ] + line_comment += " (via %s)" % ", ".join(reversed(ancestors)) print(" " * indent + line + line_comment, file=self.file) @@ -748,7 +792,7 @@ class _TemplateReader(object): self.line = 1 self.pos = 0 - def find(self, needle: str, start: int=0, end: int=None) -> int: + def find(self, needle: str, start: int = 0, end: int = None) -> int: assert start >= 0, start pos = self.pos start += pos @@ -762,12 +806,12 @@ class _TemplateReader(object): index -= pos return index - def consume(self, count: int=None) -> str: + def consume(self, count: int = None) -> str: if count is None: count = len(self.text) - self.pos newpos = self.pos + count self.line += self.text.count("\n", self.pos, newpos) - s = self.text[self.pos:newpos] + s = self.text[self.pos : newpos] self.pos = newpos return s @@ -794,7 +838,7 @@ class _TemplateReader(object): return self.text[self.pos + key] def __str__(self) -> str: - return self.text[self.pos:] + return self.text[self.pos :] def raise_parse_error(self, msg: str) -> None: raise ParseError(msg, self.name, self.line) @@ -806,8 +850,12 @@ def _format_code(code: str) -> str: return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)]) -def _parse(reader: _TemplateReader, template: Template, - in_block: str=None, in_loop: str=None) -> _ChunkList: +def _parse( + reader: _TemplateReader, + template: Template, + in_block: str = None, + in_loop: str = None, +) -> _ChunkList: body = _ChunkList([]) while True: # Find next template directive @@ -818,9 +866,11 @@ def _parse(reader: _TemplateReader, template: Template, # EOF if in_block: reader.raise_parse_error( - "Missing {%% end %%} block for %s" % in_block) - body.chunks.append(_Text(reader.consume(), reader.line, - reader.whitespace)) + "Missing {%% end %%} block for %s" % in_block + ) + body.chunks.append( + _Text(reader.consume(), reader.line, reader.whitespace) + ) return body # If the first curly brace is not the start of a special token, # start searching from the character after it @@ -830,8 +880,11 @@ def _parse(reader: _TemplateReader, template: Template, # When there are more than 2 curlies in a row, use the # innermost ones. This is useful when generating languages # like latex where curlies are also meaningful - if (curly + 2 < reader.remaining() and - reader[curly + 1] == '{' and reader[curly + 2] == '{'): + if ( + curly + 2 < reader.remaining() + and reader[curly + 1] == "{" + and reader[curly + 2] == "{" + ): curly += 1 continue break @@ -839,8 +892,7 @@ def _parse(reader: _TemplateReader, template: Template, # Append any text before the special token if curly > 0: cons = reader.consume(curly) - body.chunks.append(_Text(cons, reader.line, - reader.whitespace)) + body.chunks.append(_Text(cons, reader.line, reader.whitespace)) start_brace = reader.consume(2) line = reader.line @@ -851,8 +903,7 @@ def _parse(reader: _TemplateReader, template: Template, # which also use double braces. if reader.remaining() and reader[0] == "!": reader.consume(1) - body.chunks.append(_Text(start_brace, line, - reader.whitespace)) + body.chunks.append(_Text(start_brace, line, reader.whitespace)) continue # Comment @@ -899,12 +950,13 @@ def _parse(reader: _TemplateReader, template: Template, allowed_parents = intermediate_blocks.get(operator) if allowed_parents is not None: if not in_block: - reader.raise_parse_error("%s outside %s block" % - (operator, allowed_parents)) + reader.raise_parse_error( + "%s outside %s block" % (operator, allowed_parents) + ) if in_block not in allowed_parents: reader.raise_parse_error( - "%s block cannot be attached to %s block" % - (operator, in_block)) + "%s block cannot be attached to %s block" % (operator, in_block) + ) body.chunks.append(_IntermediateControlBlock(contents, line)) continue @@ -914,9 +966,18 @@ def _parse(reader: _TemplateReader, template: Template, reader.raise_parse_error("Extra {% end %} block") return body - elif operator in ("extends", "include", "set", "import", "from", - "comment", "autoescape", "whitespace", "raw", - "module"): + elif operator in ( + "extends", + "include", + "set", + "import", + "from", + "comment", + "autoescape", + "whitespace", + "raw", + "module", + ): if operator == "comment": continue if operator == "extends": @@ -946,7 +1007,7 @@ def _parse(reader: _TemplateReader, template: Template, elif operator == "whitespace": mode = suffix.strip() # Validate the selected mode - filter_whitespace(mode, '') + filter_whitespace(mode, "") reader.whitespace = mode continue elif operator == "raw": @@ -982,8 +1043,9 @@ def _parse(reader: _TemplateReader, template: Template, elif operator in ("break", "continue"): if not in_loop: - reader.raise_parse_error("%s outside %s block" % - (operator, set(["for", "while"]))) + reader.raise_parse_error( + "%s outside %s block" % (operator, set(["for", "while"])) + ) body.chunks.append(_Statement(contents, line)) continue diff --git a/tornado/test/asyncio_test.py b/tornado/test/asyncio_test.py index 990ecc906..aa9c2f9f1 100644 --- a/tornado/test/asyncio_test.py +++ b/tornado/test/asyncio_test.py @@ -16,7 +16,11 @@ import unittest from concurrent.futures import ThreadPoolExecutor from tornado import gen from tornado.ioloop import IOLoop -from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy +from tornado.platform.asyncio import ( + AsyncIOLoop, + to_asyncio_future, + AnyThreadEventLoopPolicy, +) from tornado.testing import AsyncTestCase, gen_test @@ -35,14 +39,15 @@ class AsyncIOLoopTest(AsyncTestCase): # Test that we can yield an asyncio future from a tornado coroutine. # Without 'yield from', we must wrap coroutines in ensure_future, # which was introduced during Python 3.4, deprecating the prior "async". - if hasattr(asyncio, 'ensure_future'): + if hasattr(asyncio, "ensure_future"): ensure_future = asyncio.ensure_future else: # async is a reserved word in Python 3.7 - ensure_future = getattr(asyncio, 'async') + ensure_future = getattr(asyncio, "async") x = yield ensure_future( - asyncio.get_event_loop().run_in_executor(None, lambda: 42)) + asyncio.get_event_loop().run_in_executor(None, lambda: 42) + ) self.assertEqual(x, 42) @gen_test @@ -52,6 +57,7 @@ class AsyncIOLoopTest(AsyncTestCase): event_loop = asyncio.get_event_loop() x = yield from event_loop.run_in_executor(None, lambda: 42) return x + result = yield f() self.assertEqual(result, 42) @@ -76,30 +82,30 @@ class AsyncIOLoopTest(AsyncTestCase): return await to_asyncio_future(native_coroutine_without_adapter()) # Tornado supports native coroutines both with and without adapters - self.assertEqual( - self.io_loop.run_sync(native_coroutine_without_adapter), - 42) - self.assertEqual( - self.io_loop.run_sync(native_coroutine_with_adapter), - 42) - self.assertEqual( - self.io_loop.run_sync(native_coroutine_with_adapter2), - 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42) + self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42) # Asyncio only supports coroutines that yield asyncio-compatible # Futures (which our Future is since 5.0). self.assertEqual( asyncio.get_event_loop().run_until_complete( - native_coroutine_without_adapter()), - 42) + native_coroutine_without_adapter() + ), + 42, + ) self.assertEqual( asyncio.get_event_loop().run_until_complete( - native_coroutine_with_adapter()), - 42) + native_coroutine_with_adapter() + ), + 42, + ) self.assertEqual( asyncio.get_event_loop().run_until_complete( - native_coroutine_with_adapter2()), - 42) + native_coroutine_with_adapter2() + ), + 42, + ) class LeakTest(unittest.TestCase): @@ -160,19 +166,19 @@ class AnyThreadEventLoopPolicyTest(unittest.TestCase): loop = asyncio.get_event_loop() loop.close() return loop + future = self.executor.submit(get_and_close_event_loop) return future.result() def run_policy_test(self, accessor, expected_type): # With the default policy, non-main threads don't get an event # loop. - self.assertRaises((RuntimeError, AssertionError), - self.executor.submit(accessor).result) + self.assertRaises( + (RuntimeError, AssertionError), self.executor.submit(accessor).result + ) # Set the policy and we can get a loop. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - self.assertIsInstance( - self.executor.submit(accessor).result(), - expected_type) + self.assertIsInstance(self.executor.submit(accessor).result(), expected_type) # Clean up to silence leak warnings. Always use asyncio since # IOLoop doesn't (currently) close the underlying loop. self.executor.submit(lambda: asyncio.get_event_loop().close()).result() diff --git a/tornado/test/auth_test.py b/tornado/test/auth_test.py index 6fee29fcf..7807ae56e 100644 --- a/tornado/test/auth_test.py +++ b/tornado/test/auth_test.py @@ -6,8 +6,12 @@ import unittest from tornado.auth import ( - OpenIdMixin, OAuthMixin, OAuth2Mixin, - GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin, + OpenIdMixin, + OAuthMixin, + OAuth2Mixin, + GoogleOAuth2Mixin, + FacebookGraphMixin, + TwitterMixin, ) from tornado.escape import json_decode from tornado import gen @@ -25,12 +29,14 @@ except ImportError: class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin): def initialize(self, test): - self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate') + self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate") @gen.coroutine def get(self): - if self.get_argument('openid.mode', None): - user = yield self.get_authenticated_user(http_client=self.settings['http_client']) + if self.get_argument("openid.mode", None): + user = yield self.get_authenticated_user( + http_client=self.settings["http_client"] + ) if user is None: raise Exception("user is None") self.finish(user) @@ -41,45 +47,48 @@ class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin): class OpenIdServerAuthenticateHandler(RequestHandler): def post(self): - if self.get_argument('openid.mode') != 'check_authentication': + if self.get_argument("openid.mode") != "check_authentication": raise Exception("incorrect openid.mode %r") - self.write('is_valid:true') + self.write("is_valid:true") class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin): def initialize(self, test, version): self._OAUTH_VERSION = version - self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token') - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token') + self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token") + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token") def _oauth_consumer_token(self): - return dict(key='asdf', secret='qwer') + return dict(key="asdf", secret="qwer") @gen.coroutine def get(self): - if self.get_argument('oauth_token', None): - user = yield self.get_authenticated_user(http_client=self.settings['http_client']) + if self.get_argument("oauth_token", None): + user = yield self.get_authenticated_user( + http_client=self.settings["http_client"] + ) if user is None: raise Exception("user is None") self.finish(user) return - yield self.authorize_redirect(http_client=self.settings['http_client']) + yield self.authorize_redirect(http_client=self.settings["http_client"]) @gen.coroutine def _oauth_get_user_future(self, access_token): - if self.get_argument('fail_in_get_user', None): + if self.get_argument("fail_in_get_user", None): raise Exception("failing in get_user") - if access_token != dict(key='uiop', secret='5678'): + if access_token != dict(key="uiop", secret="5678"): raise Exception("incorrect access token %r" % access_token) - return dict(email='foo@example.com') + return dict(email="foo@example.com") class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler): """Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine.""" + @gen.coroutine def get(self): - if self.get_argument('oauth_token', None): + if self.get_argument("oauth_token", None): # Ensure that any exceptions are set on the returned Future, # not simply thrown into the surrounding StackContext. try: @@ -96,29 +105,30 @@ class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin): self._OAUTH_VERSION = version def _oauth_consumer_token(self): - return dict(key='asdf', secret='qwer') + return dict(key="asdf", secret="qwer") def get(self): params = self._oauth_request_parameters( - 'http://www.example.com/api/asdf', - dict(key='uiop', secret='5678'), - parameters=dict(foo='bar')) + "http://www.example.com/api/asdf", + dict(key="uiop", secret="5678"), + parameters=dict(foo="bar"), + ) self.write(params) class OAuth1ServerRequestTokenHandler(RequestHandler): def get(self): - self.write('oauth_token=zxcv&oauth_token_secret=1234') + self.write("oauth_token=zxcv&oauth_token_secret=1234") class OAuth1ServerAccessTokenHandler(RequestHandler): def get(self): - self.write('oauth_token=uiop&oauth_token_secret=5678') + self.write("oauth_token=uiop&oauth_token_secret=5678") class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin): def initialize(self, test): - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize') + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize") def get(self): res = self.authorize_redirect() @@ -127,9 +137,9 @@ class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin): class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin): def initialize(self, test): - self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token') - self._FACEBOOK_BASE_URL = test.get_url('/facebook/server') + self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token") + self._FACEBOOK_BASE_URL = test.get_url("/facebook/server") @gen.coroutine def get(self): @@ -138,13 +148,15 @@ class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin): redirect_uri=self.request.full_url(), client_id=self.settings["facebook_api_key"], client_secret=self.settings["facebook_secret"], - code=self.get_argument("code")) + code=self.get_argument("code"), + ) self.write(user) else: yield self.authorize_redirect( redirect_uri=self.request.full_url(), client_id=self.settings["facebook_api_key"], - extra_params={"scope": "read_stream,offline_access"}) + extra_params={"scope": "read_stream,offline_access"}, + ) class FacebookServerAccessTokenHandler(RequestHandler): @@ -154,19 +166,19 @@ class FacebookServerAccessTokenHandler(RequestHandler): class FacebookServerMeHandler(RequestHandler): def get(self): - self.write('{}') + self.write("{}") class TwitterClientHandler(RequestHandler, TwitterMixin): def initialize(self, test): - self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token') - self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize') - self._OAUTH_AUTHENTICATE_URL = test.get_url('/twitter/server/authenticate') - self._TWITTER_BASE_URL = test.get_url('/twitter/api') + self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token") + self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize") + self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate") + self._TWITTER_BASE_URL = test.get_url("/twitter/api") def get_auth_http_client(self): - return self.settings['http_client'] + return self.settings["http_client"] class TwitterClientLoginHandler(TwitterClientHandler): @@ -214,46 +226,47 @@ class TwitterClientShowUserHandler(TwitterClientHandler): # cheating with a hard-coded access token. try: response = yield self.twitter_request( - '/users/show/%s' % self.get_argument('name'), - access_token=dict(key='hjkl', secret='vbnm')) + "/users/show/%s" % self.get_argument("name"), + access_token=dict(key="hjkl", secret="vbnm"), + ) except HTTPClientError: # TODO(bdarnell): Should we catch HTTP errors and # transform some of them (like 403s) into AuthError? self.set_status(500) - self.finish('error from twitter request') + self.finish("error from twitter request") else: self.finish(response) class TwitterServerAccessTokenHandler(RequestHandler): def get(self): - self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo') + self.write("oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo") class TwitterServerShowUserHandler(RequestHandler): def get(self, screen_name): - if screen_name == 'error': + if screen_name == "error": raise HTTPError(500) - assert 'oauth_nonce' in self.request.arguments - assert 'oauth_timestamp' in self.request.arguments - assert 'oauth_signature' in self.request.arguments - assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key' - assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1' - assert self.get_argument('oauth_version') == '1.0' - assert self.get_argument('oauth_token') == 'hjkl' + assert "oauth_nonce" in self.request.arguments + assert "oauth_timestamp" in self.request.arguments + assert "oauth_signature" in self.request.arguments + assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key" + assert self.get_argument("oauth_signature_method") == "HMAC-SHA1" + assert self.get_argument("oauth_version") == "1.0" + assert self.get_argument("oauth_token") == "hjkl" self.write(dict(screen_name=screen_name, name=screen_name.capitalize())) class TwitterServerVerifyCredentialsHandler(RequestHandler): def get(self): - assert 'oauth_nonce' in self.request.arguments - assert 'oauth_timestamp' in self.request.arguments - assert 'oauth_signature' in self.request.arguments - assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key' - assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1' - assert self.get_argument('oauth_version') == '1.0' - assert self.get_argument('oauth_token') == 'hjkl' - self.write(dict(screen_name='foo', name='Foo')) + assert "oauth_nonce" in self.request.arguments + assert "oauth_timestamp" in self.request.arguments + assert "oauth_signature" in self.request.arguments + assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key" + assert self.get_argument("oauth_signature_method") == "HMAC-SHA1" + assert self.get_argument("oauth_version") == "1.0" + assert self.get_argument("oauth_token") == "hjkl" + self.write(dict(screen_name="foo", name="Foo")) class AuthTest(AsyncHTTPTestCase): @@ -261,258 +274,310 @@ class AuthTest(AsyncHTTPTestCase): return Application( [ # test endpoints - ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)), - ('/oauth10/client/login', OAuth1ClientLoginHandler, - dict(test=self, version='1.0')), - ('/oauth10/client/request_params', - OAuth1ClientRequestParametersHandler, - dict(version='1.0')), - ('/oauth10a/client/login', OAuth1ClientLoginHandler, - dict(test=self, version='1.0a')), - ('/oauth10a/client/login_coroutine', - OAuth1ClientLoginCoroutineHandler, - dict(test=self, version='1.0a')), - ('/oauth10a/client/request_params', - OAuth1ClientRequestParametersHandler, - dict(version='1.0a')), - ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)), - - ('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)), - - ('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)), - ('/twitter/client/authenticate', TwitterClientAuthenticateHandler, dict(test=self)), - ('/twitter/client/login_gen_coroutine', - TwitterClientLoginGenCoroutineHandler, dict(test=self)), - ('/twitter/client/show_user', - TwitterClientShowUserHandler, dict(test=self)), - + ("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)), + ( + "/oauth10/client/login", + OAuth1ClientLoginHandler, + dict(test=self, version="1.0"), + ), + ( + "/oauth10/client/request_params", + OAuth1ClientRequestParametersHandler, + dict(version="1.0"), + ), + ( + "/oauth10a/client/login", + OAuth1ClientLoginHandler, + dict(test=self, version="1.0a"), + ), + ( + "/oauth10a/client/login_coroutine", + OAuth1ClientLoginCoroutineHandler, + dict(test=self, version="1.0a"), + ), + ( + "/oauth10a/client/request_params", + OAuth1ClientRequestParametersHandler, + dict(version="1.0a"), + ), + ("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)), + ("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)), + ("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)), + ( + "/twitter/client/authenticate", + TwitterClientAuthenticateHandler, + dict(test=self), + ), + ( + "/twitter/client/login_gen_coroutine", + TwitterClientLoginGenCoroutineHandler, + dict(test=self), + ), + ( + "/twitter/client/show_user", + TwitterClientShowUserHandler, + dict(test=self), + ), # simulated servers - ('/openid/server/authenticate', OpenIdServerAuthenticateHandler), - ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler), - ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler), - - ('/facebook/server/access_token', FacebookServerAccessTokenHandler), - ('/facebook/server/me', FacebookServerMeHandler), - ('/twitter/server/access_token', TwitterServerAccessTokenHandler), - (r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler), - (r'/twitter/api/account/verify_credentials\.json', - TwitterServerVerifyCredentialsHandler), + ("/openid/server/authenticate", OpenIdServerAuthenticateHandler), + ("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler), + ("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler), + ("/facebook/server/access_token", FacebookServerAccessTokenHandler), + ("/facebook/server/me", FacebookServerMeHandler), + ("/twitter/server/access_token", TwitterServerAccessTokenHandler), + (r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler), + ( + r"/twitter/api/account/verify_credentials\.json", + TwitterServerVerifyCredentialsHandler, + ), ], http_client=self.http_client, - twitter_consumer_key='test_twitter_consumer_key', - twitter_consumer_secret='test_twitter_consumer_secret', - facebook_api_key='test_facebook_api_key', - facebook_secret='test_facebook_secret') + twitter_consumer_key="test_twitter_consumer_key", + twitter_consumer_secret="test_twitter_consumer_secret", + facebook_api_key="test_facebook_api_key", + facebook_secret="test_facebook_secret", + ) def test_openid_redirect(self): - response = self.fetch('/openid/client/login', follow_redirects=False) + response = self.fetch("/openid/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue( - '/openid/server/authenticate?' in response.headers['Location']) + self.assertTrue("/openid/server/authenticate?" in response.headers["Location"]) def test_openid_get_user(self): - response = self.fetch('/openid/client/login?openid.mode=blah' - '&openid.ns.ax=http://openid.net/srv/ax/1.0' - '&openid.ax.type.email=http://axschema.org/contact/email' - '&openid.ax.value.email=foo@example.com') + response = self.fetch( + "/openid/client/login?openid.mode=blah" + "&openid.ns.ax=http://openid.net/srv/ax/1.0" + "&openid.ax.type.email=http://axschema.org/contact/email" + "&openid.ax.value.email=foo@example.com" + ) response.rethrow() parsed = json_decode(response.body) self.assertEqual(parsed["email"], "foo@example.com") def test_oauth10_redirect(self): - response = self.fetch('/oauth10/client/login', follow_redirects=False) + response = self.fetch("/oauth10/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) def test_oauth10_get_user(self): response = self.fetch( - '/oauth10/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/oauth10/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['email'], 'foo@example.com') - self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + self.assertEqual(parsed["email"], "foo@example.com") + self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678")) def test_oauth10_request_parameters(self): - response = self.fetch('/oauth10/client/request_params') + response = self.fetch("/oauth10/client/request_params") response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['oauth_consumer_key'], 'asdf') - self.assertEqual(parsed['oauth_token'], 'uiop') - self.assertTrue('oauth_nonce' in parsed) - self.assertTrue('oauth_signature' in parsed) + self.assertEqual(parsed["oauth_consumer_key"], "asdf") + self.assertEqual(parsed["oauth_token"], "uiop") + self.assertTrue("oauth_nonce" in parsed) + self.assertTrue("oauth_signature" in parsed) def test_oauth10a_redirect(self): - response = self.fetch('/oauth10a/client/login', follow_redirects=False) + response = self.fetch("/oauth10a/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) - @unittest.skipIf(mock is None, 'mock package not present') + @unittest.skipIf(mock is None, "mock package not present") def test_oauth10a_redirect_error(self): - with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get: + with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get: get.side_effect = Exception("boom") with ExpectLog(app_log, "Uncaught exception"): - response = self.fetch('/oauth10a/client/login', follow_redirects=False) + response = self.fetch("/oauth10a/client/login", follow_redirects=False) self.assertEqual(response.code, 500) def test_oauth10a_get_user(self): response = self.fetch( - '/oauth10a/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/oauth10a/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['email'], 'foo@example.com') - self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678')) + self.assertEqual(parsed["email"], "foo@example.com") + self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678")) def test_oauth10a_request_parameters(self): - response = self.fetch('/oauth10a/client/request_params') + response = self.fetch("/oauth10a/client/request_params") response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed['oauth_consumer_key'], 'asdf') - self.assertEqual(parsed['oauth_token'], 'uiop') - self.assertTrue('oauth_nonce' in parsed) - self.assertTrue('oauth_signature' in parsed) + self.assertEqual(parsed["oauth_consumer_key"], "asdf") + self.assertEqual(parsed["oauth_token"], "uiop") + self.assertTrue("oauth_nonce" in parsed) + self.assertTrue("oauth_signature" in parsed) def test_oauth10a_get_user_coroutine_exception(self): response = self.fetch( - '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) self.assertEqual(response.code, 503) def test_oauth2_redirect(self): - response = self.fetch('/oauth2/client/login', follow_redirects=False) + response = self.fetch("/oauth2/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue('/oauth2/server/authorize?' in response.headers['Location']) + self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"]) def test_facebook_login(self): - response = self.fetch('/facebook/client/login', follow_redirects=False) + response = self.fetch("/facebook/client/login", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue('/facebook/server/authorize?' in response.headers['Location']) - response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False) + self.assertTrue("/facebook/server/authorize?" in response.headers["Location"]) + response = self.fetch( + "/facebook/client/login?code=1234", follow_redirects=False + ) self.assertEqual(response.code, 200) user = json_decode(response.body) - self.assertEqual(user['access_token'], 'asdf') - self.assertEqual(user['session_expires'], '3600') + self.assertEqual(user["access_token"], "asdf") + self.assertEqual(user["session_expires"], "3600") def base_twitter_redirect(self, url): # Same as test_oauth10a_redirect response = self.fetch(url, follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/oauth1/server/authorize?oauth_token=zxcv')) + self.assertTrue( + response.headers["Location"].endswith( + "/oauth1/server/authorize?oauth_token=zxcv" + ) + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) def test_twitter_redirect(self): - self.base_twitter_redirect('/twitter/client/login') + self.base_twitter_redirect("/twitter/client/login") def test_twitter_redirect_gen_coroutine(self): - self.base_twitter_redirect('/twitter/client/login_gen_coroutine') + self.base_twitter_redirect("/twitter/client/login_gen_coroutine") def test_twitter_authenticate_redirect(self): - response = self.fetch('/twitter/client/authenticate', follow_redirects=False) + response = self.fetch("/twitter/client/authenticate", follow_redirects=False) self.assertEqual(response.code, 302) - self.assertTrue(response.headers['Location'].endswith( - '/twitter/server/authenticate?oauth_token=zxcv'), response.headers['Location']) + self.assertTrue( + response.headers["Location"].endswith( + "/twitter/server/authenticate?oauth_token=zxcv" + ), + response.headers["Location"], + ) # the cookie is base64('zxcv')|base64('1234') self.assertTrue( - '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'], - response.headers['Set-Cookie']) + '_oauth_request_token="enhjdg==|MTIzNA=="' + in response.headers["Set-Cookie"], + response.headers["Set-Cookie"], + ) def test_twitter_get_user(self): response = self.fetch( - '/twitter/client/login?oauth_token=zxcv', - headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='}) + "/twitter/client/login?oauth_token=zxcv", + headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="}, + ) response.rethrow() parsed = json_decode(response.body) - self.assertEqual(parsed, - {u'access_token': {u'key': u'hjkl', - u'screen_name': u'foo', - u'secret': u'vbnm'}, - u'name': u'Foo', - u'screen_name': u'foo', - u'username': u'foo'}) + self.assertEqual( + parsed, + { + u"access_token": { + u"key": u"hjkl", + u"screen_name": u"foo", + u"secret": u"vbnm", + }, + u"name": u"Foo", + u"screen_name": u"foo", + u"username": u"foo", + }, + ) def test_twitter_show_user(self): - response = self.fetch('/twitter/client/show_user?name=somebody') + response = self.fetch("/twitter/client/show_user?name=somebody") response.rethrow() - self.assertEqual(json_decode(response.body), - {'name': 'Somebody', 'screen_name': 'somebody'}) + self.assertEqual( + json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"} + ) def test_twitter_show_user_error(self): - response = self.fetch('/twitter/client/show_user?name=error') + response = self.fetch("/twitter/client/show_user?name=error") self.assertEqual(response.code, 500) - self.assertEqual(response.body, b'error from twitter request') + self.assertEqual(response.body, b"error from twitter request") class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin): def initialize(self, test): self.test = test - self._OAUTH_REDIRECT_URI = test.get_url('/client/login') - self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize') - self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token') + self._OAUTH_REDIRECT_URI = test.get_url("/client/login") + self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize") + self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token") @gen.coroutine def get(self): - code = self.get_argument('code', None) + code = self.get_argument("code", None) if code is not None: # retrieve authenticate google user - access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, - code) + access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code) user = yield self.oauth2_request( self.test.get_url("/google/oauth2/userinfo"), - access_token=access["access_token"]) + access_token=access["access_token"], + ) # return the user and access token as json user["access_token"] = access["access_token"] self.write(user) else: yield self.authorize_redirect( redirect_uri=self._OAUTH_REDIRECT_URI, - client_id=self.settings['google_oauth']['key'], - client_secret=self.settings['google_oauth']['secret'], - scope=['profile', 'email'], - response_type='code', - extra_params={'prompt': 'select_account'}) + client_id=self.settings["google_oauth"]["key"], + client_secret=self.settings["google_oauth"]["secret"], + scope=["profile", "email"], + response_type="code", + extra_params={"prompt": "select_account"}, + ) class GoogleOAuth2AuthorizeHandler(RequestHandler): def get(self): # issue a fake auth code and redirect to redirect_uri - code = 'fake-authorization-code' - self.redirect(url_concat(self.get_argument('redirect_uri'), - dict(code=code))) + code = "fake-authorization-code" + self.redirect(url_concat(self.get_argument("redirect_uri"), dict(code=code))) class GoogleOAuth2TokenHandler(RequestHandler): def post(self): - assert self.get_argument('code') == 'fake-authorization-code' + assert self.get_argument("code") == "fake-authorization-code" # issue a fake token - self.finish({ - 'access_token': 'fake-access-token', - 'expires_in': 'never-expires' - }) + self.finish( + {"access_token": "fake-access-token", "expires_in": "never-expires"} + ) class GoogleOAuth2UserinfoHandler(RequestHandler): def get(self): - assert self.get_argument('access_token') == 'fake-access-token' + assert self.get_argument("access_token") == "fake-access-token" # return a fake user - self.finish({ - 'name': 'Foo', - 'email': 'foo@example.com' - }) + self.finish({"name": "Foo", "email": "foo@example.com"}) class GoogleOAuth2Test(AsyncHTTPTestCase): @@ -520,22 +585,25 @@ class GoogleOAuth2Test(AsyncHTTPTestCase): return Application( [ # test endpoints - ('/client/login', GoogleLoginHandler, dict(test=self)), - + ("/client/login", GoogleLoginHandler, dict(test=self)), # simulated google authorization server endpoints - ('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler), - ('/google/oauth2/token', GoogleOAuth2TokenHandler), - ('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler), + ("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler), + ("/google/oauth2/token", GoogleOAuth2TokenHandler), + ("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler), ], google_oauth={ - "key": 'fake_google_client_id', - "secret": 'fake_google_client_secret' - }) + "key": "fake_google_client_id", + "secret": "fake_google_client_secret", + }, + ) def test_google_login(self): - response = self.fetch('/client/login') - self.assertDictEqual({ - u'name': u'Foo', - u'email': u'foo@example.com', - u'access_token': u'fake-access-token', - }, json_decode(response.body)) + response = self.fetch("/client/login") + self.assertDictEqual( + { + u"name": u"Foo", + u"email": u"foo@example.com", + u"access_token": u"fake-access-token", + }, + json_decode(response.body), + ) diff --git a/tornado/test/autoreload_test.py b/tornado/test/autoreload_test.py index 1c8ffbb88..be481e106 100644 --- a/tornado/test/autoreload_test.py +++ b/tornado/test/autoreload_test.py @@ -42,23 +42,26 @@ if 'TESTAPP_STARTED' not in os.environ: """ # Create temporary test application - os.mkdir(os.path.join(self.path, 'testapp')) - open(os.path.join(self.path, 'testapp/__init__.py'), 'w').close() - with open(os.path.join(self.path, 'testapp/__main__.py'), 'w') as f: + os.mkdir(os.path.join(self.path, "testapp")) + open(os.path.join(self.path, "testapp/__init__.py"), "w").close() + with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f: f.write(main) # Make sure the tornado module under test is available to the test # application pythonpath = os.getcwd() - if 'PYTHONPATH' in os.environ: - pythonpath += os.pathsep + os.environ['PYTHONPATH'] + if "PYTHONPATH" in os.environ: + pythonpath += os.pathsep + os.environ["PYTHONPATH"] p = Popen( - [sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE, - cwd=self.path, env=dict(os.environ, PYTHONPATH=pythonpath), - universal_newlines=True) + [sys.executable, "-m", "testapp"], + stdout=subprocess.PIPE, + cwd=self.path, + env=dict(os.environ, PYTHONPATH=pythonpath), + universal_newlines=True, + ) out = p.communicate()[0] - self.assertEqual(out, 'Starting\nStarting\n') + self.assertEqual(out, "Starting\nStarting\n") def test_reload_wrapper_preservation(self): # This test verifies that when `python -m tornado.autoreload` @@ -89,24 +92,26 @@ else: """ # Create temporary test application - os.mkdir(os.path.join(self.path, 'testapp')) - init_file = os.path.join(self.path, 'testapp', '__init__.py') - open(init_file, 'w').close() - main_file = os.path.join(self.path, 'testapp', '__main__.py') - with open(main_file, 'w') as f: + os.mkdir(os.path.join(self.path, "testapp")) + init_file = os.path.join(self.path, "testapp", "__init__.py") + open(init_file, "w").close() + main_file = os.path.join(self.path, "testapp", "__main__.py") + with open(main_file, "w") as f: f.write(main) # Make sure the tornado module under test is available to the test # application pythonpath = os.getcwd() - if 'PYTHONPATH' in os.environ: - pythonpath += os.pathsep + os.environ['PYTHONPATH'] + if "PYTHONPATH" in os.environ: + pythonpath += os.pathsep + os.environ["PYTHONPATH"] autoreload_proc = Popen( - [sys.executable, '-m', 'tornado.autoreload', '-m', 'testapp'], - stdout=subprocess.PIPE, cwd=self.path, + [sys.executable, "-m", "tornado.autoreload", "-m", "testapp"], + stdout=subprocess.PIPE, + cwd=self.path, env=dict(os.environ, PYTHONPATH=pythonpath), - universal_newlines=True) + universal_newlines=True, + ) # This timeout needs to be fairly generous for pypy due to jit # warmup costs. @@ -119,4 +124,4 @@ else: raise Exception("subprocess failed to terminate") out = autoreload_proc.communicate()[0] - self.assertEqual(out, 'Starting\n' * 2) + self.assertEqual(out, "Starting\n" * 2) diff --git a/tornado/test/concurrent_test.py b/tornado/test/concurrent_test.py index db263d621..488b7ddd6 100644 --- a/tornado/test/concurrent_test.py +++ b/tornado/test/concurrent_test.py @@ -18,7 +18,11 @@ import re import socket import unittest -from tornado.concurrent import Future, run_on_executor, future_set_result_unless_cancelled +from tornado.concurrent import ( + Future, + run_on_executor, + future_set_result_unless_cancelled, +) from tornado.escape import utf8, to_unicode from tornado import gen from tornado.iostream import IOStream @@ -27,7 +31,6 @@ from tornado.testing import AsyncTestCase, bind_unused_port, gen_test class MiscFutureTest(AsyncTestCase): - def test_future_set_result_unless_cancelled(self): fut = Future() # type: Future[int] future_set_result_unless_cancelled(fut, 42) @@ -69,11 +72,11 @@ class BaseCapClient(object): self.port = port def process_response(self, data): - m = re.match('(.*)\t(.*)\n', to_unicode(data)) + m = re.match("(.*)\t(.*)\n", to_unicode(data)) if m is None: raise Exception("did not match") status, message = m.groups() - if status == 'ok': + if status == "ok": return message else: raise CapError(message) @@ -82,14 +85,14 @@ class BaseCapClient(object): class GeneratorCapClient(BaseCapClient): @gen.coroutine def capitalize(self, request_data): - logging.debug('capitalize') + logging.debug("capitalize") stream = IOStream(socket.socket()) - logging.debug('connecting') - yield stream.connect(('127.0.0.1', self.port)) - stream.write(utf8(request_data + '\n')) - logging.debug('reading') - data = yield stream.read_until(b'\n') - logging.debug('returning') + logging.debug("connecting") + yield stream.connect(("127.0.0.1", self.port)) + stream.write(utf8(request_data + "\n")) + logging.debug("reading") + data = yield stream.read_until(b"\n") + logging.debug("returning") stream.close() raise gen.Return(self.process_response(data)) @@ -123,6 +126,7 @@ class ClientTestMixin(object): def f(): result = yield self.client.capitalize("hello") self.assertEqual(result, "HELLO") + self.io_loop.run_sync(f) def test_generator_error(self): @@ -130,6 +134,7 @@ class ClientTestMixin(object): def f(): with self.assertRaisesRegexp(CapError, "already capitalized"): yield self.client.capitalize("HELLO") + self.io_loop.run_sync(f) @@ -172,7 +177,7 @@ class RunOnExecutorTest(AsyncTestCase): def __init__(self): self.__executor = futures.thread.ThreadPoolExecutor(1) - @run_on_executor(executor='_Object__executor') + @run_on_executor(executor="_Object__executor") def f(self): return 42 @@ -195,9 +200,10 @@ class RunOnExecutorTest(AsyncTestCase): async def f(): answer = await o.f() return answer + result = yield f() self.assertEqual(result, 42) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py index b93e97123..ccb19fecd 100644 --- a/tornado/test/curl_httpclient_test.py +++ b/tornado/test/curl_httpclient_test.py @@ -32,39 +32,43 @@ class DigestAuthHandler(RequestHandler): self.password = password def get(self): - realm = 'test' - opaque = 'asdf' + realm = "test" + opaque = "asdf" # Real implementations would use a random nonce. nonce = "1234" - auth_header = self.request.headers.get('Authorization', None) + auth_header = self.request.headers.get("Authorization", None) if auth_header is not None: - auth_mode, params = auth_header.split(' ', 1) - assert auth_mode == 'Digest' + auth_mode, params = auth_header.split(" ", 1) + assert auth_mode == "Digest" param_dict = {} - for pair in params.split(','): - k, v = pair.strip().split('=', 1) + for pair in params.split(","): + k, v = pair.strip().split("=", 1) if v[0] == '"' and v[-1] == '"': v = v[1:-1] param_dict[k] = v - assert param_dict['realm'] == realm - assert param_dict['opaque'] == opaque - assert param_dict['nonce'] == nonce - assert param_dict['username'] == self.username - assert param_dict['uri'] == self.request.path - h1 = md5(utf8('%s:%s:%s' % (self.username, realm, self.password))).hexdigest() - h2 = md5(utf8('%s:%s' % (self.request.method, - self.request.path))).hexdigest() - digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest() - if digest == param_dict['response']: - self.write('ok') + assert param_dict["realm"] == realm + assert param_dict["opaque"] == opaque + assert param_dict["nonce"] == nonce + assert param_dict["username"] == self.username + assert param_dict["uri"] == self.request.path + h1 = md5( + utf8("%s:%s:%s" % (self.username, realm, self.password)) + ).hexdigest() + h2 = md5( + utf8("%s:%s" % (self.request.method, self.request.path)) + ).hexdigest() + digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest() + if digest == param_dict["response"]: + self.write("ok") else: - self.write('fail') + self.write("fail") else: self.set_status(401) - self.set_header('WWW-Authenticate', - 'Digest realm="%s", nonce="%s", opaque="%s"' % - (realm, nonce, opaque)) + self.set_header( + "WWW-Authenticate", + 'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque), + ) class CustomReasonHandler(RequestHandler): @@ -84,32 +88,43 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase): self.http_client = self.create_client() def get_app(self): - return Application([ - ('/digest', DigestAuthHandler, {'username': 'foo', 'password': 'bar'}), - ('/digest_non_ascii', DigestAuthHandler, {'username': 'foo', 'password': 'barユ£'}), - ('/custom_reason', CustomReasonHandler), - ('/custom_fail_reason', CustomFailReasonHandler), - ]) + return Application( + [ + ("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}), + ( + "/digest_non_ascii", + DigestAuthHandler, + {"username": "foo", "password": "barユ£"}, + ), + ("/custom_reason", CustomReasonHandler), + ("/custom_fail_reason", CustomFailReasonHandler), + ] + ) def create_client(self, **kwargs): - return CurlAsyncHTTPClient(force_instance=True, - defaults=dict(allow_ipv6=False), - **kwargs) + return CurlAsyncHTTPClient( + force_instance=True, defaults=dict(allow_ipv6=False), **kwargs + ) def test_digest_auth(self): - response = self.fetch('/digest', auth_mode='digest', - auth_username='foo', auth_password='bar') - self.assertEqual(response.body, b'ok') + response = self.fetch( + "/digest", auth_mode="digest", auth_username="foo", auth_password="bar" + ) + self.assertEqual(response.body, b"ok") def test_custom_reason(self): - response = self.fetch('/custom_reason') + response = self.fetch("/custom_reason") self.assertEqual(response.reason, "Custom reason") def test_fail_custom_reason(self): - response = self.fetch('/custom_fail_reason') + response = self.fetch("/custom_fail_reason") self.assertEqual(str(response.error), "HTTP 400: Custom reason") def test_digest_auth_non_ascii(self): - response = self.fetch('/digest_non_ascii', auth_mode='digest', - auth_username='foo', auth_password='barユ£') - self.assertEqual(response.body, b'ok') + response = self.fetch( + "/digest_non_ascii", + auth_mode="digest", + auth_username="foo", + auth_password="barユ£", + ) + self.assertEqual(response.body, b"ok") diff --git a/tornado/test/escape_test.py b/tornado/test/escape_test.py index 5c458f08b..06c043aac 100644 --- a/tornado/test/escape_test.py +++ b/tornado/test/escape_test.py @@ -2,8 +2,16 @@ import unittest import tornado.escape from tornado.escape import ( - utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape, - to_unicode, json_decode, json_encode, squeeze, recursive_unicode, + utf8, + xhtml_escape, + xhtml_unescape, + url_escape, + url_unescape, + to_unicode, + json_decode, + json_encode, + squeeze, + recursive_unicode, ) from tornado.util import unicode_type @@ -11,129 +19,193 @@ from typing import List, Tuple, Union, Dict, Any # noqa linkify_tests = [ # (input, linkify_kwargs, expected_output) - - ("hello http://world.com/!", {}, - u'hello http://world.com/!'), - - ("hello http://world.com/with?param=true&stuff=yes", {}, - u'hello http://world.com/with?param=true&stuff=yes'), # noqa: E501 - + ( + "hello http://world.com/!", + {}, + u'hello http://world.com/!', + ), + ( + "hello http://world.com/with?param=true&stuff=yes", + {}, + u'hello http://world.com/with?param=true&stuff=yes', # noqa: E501 + ), # an opened paren followed by many chars killed Gruber's regex - ("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {}, - u'http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501 - + ( + "http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + {}, + u'http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501 + ), # as did too many dots at the end - ("http://url.com/withmany.......................................", {}, - u'http://url.com/withmany.......................................'), # noqa: E501 - - ("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {}, - u'http://url.com/withmany((((((((((((((((((((((((((((((((((a)'), # noqa: E501 - + ( + "http://url.com/withmany.......................................", + {}, + u'http://url.com/withmany.......................................', # noqa: E501 + ), + ( + "http://url.com/withmany((((((((((((((((((((((((((((((((((a)", + {}, + u'http://url.com/withmany((((((((((((((((((((((((((((((((((a)', # noqa: E501 + ), # some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls # plus a fex extras (such as multiple parentheses). - ("http://foo.com/blah_blah", {}, - u'http://foo.com/blah_blah'), - - ("http://foo.com/blah_blah/", {}, - u'http://foo.com/blah_blah/'), - - ("(Something like http://foo.com/blah_blah)", {}, - u'(Something like http://foo.com/blah_blah)'), - - ("http://foo.com/blah_blah_(wikipedia)", {}, - u'http://foo.com/blah_blah_(wikipedia)'), - - ("http://foo.com/blah_(blah)_(wikipedia)_blah", {}, - u'http://foo.com/blah_(blah)_(wikipedia)_blah'), # noqa: E501 - - ("(Something like http://foo.com/blah_blah_(wikipedia))", {}, - u'(Something like http://foo.com/blah_blah_(wikipedia))'), # noqa: E501 - - ("http://foo.com/blah_blah.", {}, - u'http://foo.com/blah_blah.'), - - ("http://foo.com/blah_blah/.", {}, - u'http://foo.com/blah_blah/.'), - - ("", {}, - u'<http://foo.com/blah_blah>'), - - ("", {}, - u'<http://foo.com/blah_blah/>'), - - ("http://foo.com/blah_blah,", {}, - u'http://foo.com/blah_blah,'), - - ("http://www.example.com/wpstyle/?p=364.", {}, - u'http://www.example.com/wpstyle/?p=364.'), - - ("rdar://1234", - {"permitted_protocols": ["http", "rdar"]}, - u'rdar://1234'), - - ("rdar:/1234", - {"permitted_protocols": ["rdar"]}, - u'rdar:/1234'), - - ("http://userid:password@example.com:8080", {}, - u'http://userid:password@example.com:8080'), # noqa: E501 - - ("http://userid@example.com", {}, - u'http://userid@example.com'), - - ("http://userid@example.com:8080", {}, - u'http://userid@example.com:8080'), - - ("http://userid:password@example.com", {}, - u'http://userid:password@example.com'), - - ("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", - {"permitted_protocols": ["http", "message"]}, - u'' - u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e'), - - (u"http://\u27a1.ws/\u4a39", {}, - u'http://\u27a1.ws/\u4a39'), - - ("http://example.com", {}, - u'<tag>http://example.com</tag>'), - - ("Just a www.example.com link.", {}, - u'Just a www.example.com link.'), - - ("Just a www.example.com link.", - {"require_protocol": True}, - u'Just a www.example.com link.'), - - ("A http://reallylong.com/link/that/exceedsthelenglimit.html", - {"require_protocol": True, "shorten": True}, - u'A http://reallylong.com/link...'), # noqa: E501 - - ("A http://reallylongdomainnamethatwillbetoolong.com/hi!", - {"shorten": True}, - u'A http://reallylongdomainnametha...!'), # noqa: E501 - - ("A file:///passwords.txt and http://web.com link", {}, - u'A file:///passwords.txt and http://web.com link'), - - ("A file:///passwords.txt and http://web.com link", - {"permitted_protocols": ["file"]}, - u'A file:///passwords.txt and http://web.com link'), - - ("www.external-link.com", - {"extra_params": 'rel="nofollow" class="external"'}, - u'www.external-link.com'), # noqa: E501 - - ("www.external-link.com and www.internal-link.com/blogs extra", - {"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501 - u'www.external-link.com' # noqa: E501 - u' and www.internal-link.com/blogs extra'), # noqa: E501 - - ("www.external-link.com", - {"extra_params": lambda href: ' rel="nofollow" class="external" '}, - u'www.external-link.com'), # noqa: E501 + ( + "http://foo.com/blah_blah", + {}, + u'http://foo.com/blah_blah', + ), + ( + "http://foo.com/blah_blah/", + {}, + u'http://foo.com/blah_blah/', + ), + ( + "(Something like http://foo.com/blah_blah)", + {}, + u'(Something like http://foo.com/blah_blah)', + ), + ( + "http://foo.com/blah_blah_(wikipedia)", + {}, + u'http://foo.com/blah_blah_(wikipedia)', + ), + ( + "http://foo.com/blah_(blah)_(wikipedia)_blah", + {}, + u'http://foo.com/blah_(blah)_(wikipedia)_blah', # noqa: E501 + ), + ( + "(Something like http://foo.com/blah_blah_(wikipedia))", + {}, + u'(Something like http://foo.com/blah_blah_(wikipedia))', # noqa: E501 + ), + ( + "http://foo.com/blah_blah.", + {}, + u'http://foo.com/blah_blah.', + ), + ( + "http://foo.com/blah_blah/.", + {}, + u'http://foo.com/blah_blah/.', + ), + ( + "", + {}, + u'<http://foo.com/blah_blah>', + ), + ( + "", + {}, + u'<http://foo.com/blah_blah/>', + ), + ( + "http://foo.com/blah_blah,", + {}, + u'http://foo.com/blah_blah,', + ), + ( + "http://www.example.com/wpstyle/?p=364.", + {}, + u'http://www.example.com/wpstyle/?p=364.', # noqa: E501 + ), + ( + "rdar://1234", + {"permitted_protocols": ["http", "rdar"]}, + u'rdar://1234', + ), + ( + "rdar:/1234", + {"permitted_protocols": ["rdar"]}, + u'rdar:/1234', + ), + ( + "http://userid:password@example.com:8080", + {}, + u'http://userid:password@example.com:8080', # noqa: E501 + ), + ( + "http://userid@example.com", + {}, + u'http://userid@example.com', + ), + ( + "http://userid@example.com:8080", + {}, + u'http://userid@example.com:8080', + ), + ( + "http://userid:password@example.com", + {}, + u'http://userid:password@example.com', + ), + ( + "message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", + {"permitted_protocols": ["http", "message"]}, + u'' + u"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e", + ), + ( + u"http://\u27a1.ws/\u4a39", + {}, + u'http://\u27a1.ws/\u4a39', + ), + ( + "http://example.com", + {}, + u'<tag>http://example.com</tag>', + ), + ( + "Just a www.example.com link.", + {}, + u'Just a www.example.com link.', + ), + ( + "Just a www.example.com link.", + {"require_protocol": True}, + u"Just a www.example.com link.", + ), + ( + "A http://reallylong.com/link/that/exceedsthelenglimit.html", + {"require_protocol": True, "shorten": True}, + u'A http://reallylong.com/link...', # noqa: E501 + ), + ( + "A http://reallylongdomainnamethatwillbetoolong.com/hi!", + {"shorten": True}, + u'A http://reallylongdomainnametha...!', # noqa: E501 + ), + ( + "A file:///passwords.txt and http://web.com link", + {}, + u'A file:///passwords.txt and http://web.com link', + ), + ( + "A file:///passwords.txt and http://web.com link", + {"permitted_protocols": ["file"]}, + u'A file:///passwords.txt and http://web.com link', + ), + ( + "www.external-link.com", + {"extra_params": 'rel="nofollow" class="external"'}, + u'www.external-link.com', # noqa: E501 + ), + ( + "www.external-link.com and www.internal-link.com/blogs extra", + { + "extra_params": lambda href: 'class="internal"' + if href.startswith("http://www.internal-link.com") + else 'rel="nofollow" class="external"' + }, + u'www.external-link.com' # noqa: E501 + u' and www.internal-link.com/blogs extra', # noqa: E501 + ), + ( + "www.external-link.com", + {"extra_params": lambda href: ' rel="nofollow" class="external" '}, + u'www.external-link.com', # noqa: E501 + ), ] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]] @@ -148,10 +220,8 @@ class EscapeTestCase(unittest.TestCase): ("", "<foo>"), (u"", u"<foo>"), (b"", b"<foo>"), - ("<>&\"'", "<>&"'"), ("&", "&amp;"), - (u"<\u00e9>", u"<\u00e9>"), (b"<\xc3\xa9>", b"<\xc3\xa9>"), ] # type: List[Tuple[Union[str, bytes], Union[str, bytes]]] @@ -161,13 +231,13 @@ class EscapeTestCase(unittest.TestCase): def test_xhtml_unescape_numeric(self): tests = [ - ('foo bar', 'foo bar'), - ('foo bar', 'foo bar'), - ('foo bar', 'foo bar'), - ('foo઼bar', u'foo\u0abcbar'), - ('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding - ('foo&#;bar', 'foo&#;bar'), # invalid encoding - ('foo&#x;bar', 'foo&#x;bar'), # invalid encoding + ("foo bar", "foo bar"), + ("foo bar", "foo bar"), + ("foo bar", "foo bar"), + ("foo઼bar", u"foo\u0abcbar"), + ("foo&#xyz;bar", "foo&#xyz;bar"), # invalid encoding + ("foo&#;bar", "foo&#;bar"), # invalid encoding + ("foo&#x;bar", "foo&#x;bar"), # invalid encoding ] for escaped, unescaped in tests: self.assertEqual(unescaped, xhtml_unescape(escaped)) @@ -175,20 +245,19 @@ class EscapeTestCase(unittest.TestCase): def test_url_escape_unicode(self): tests = [ # byte strings are passed through as-is - (u'\u00e9'.encode('utf8'), '%C3%A9'), - (u'\u00e9'.encode('latin1'), '%E9'), - + (u"\u00e9".encode("utf8"), "%C3%A9"), + (u"\u00e9".encode("latin1"), "%E9"), # unicode strings become utf8 - (u'\u00e9', '%C3%A9'), + (u"\u00e9", "%C3%A9"), ] # type: List[Tuple[Union[str, bytes], str]] for unescaped, escaped in tests: self.assertEqual(url_escape(unescaped), escaped) def test_url_unescape_unicode(self): tests = [ - ('%C3%A9', u'\u00e9', 'utf8'), - ('%C3%A9', u'\u00c3\u00a9', 'latin1'), - ('%C3%A9', utf8(u'\u00e9'), None), + ("%C3%A9", u"\u00e9", "utf8"), + ("%C3%A9", u"\u00c3\u00a9", "latin1"), + ("%C3%A9", utf8(u"\u00e9"), None), ] for escaped, unescaped, encoding in tests: # input strings to url_unescape should only contain ascii @@ -198,17 +267,17 @@ class EscapeTestCase(unittest.TestCase): self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped) def test_url_escape_quote_plus(self): - unescaped = '+ #%' - plus_escaped = '%2B+%23%25' - escaped = '%2B%20%23%25' + unescaped = "+ #%" + plus_escaped = "%2B+%23%25" + escaped = "%2B%20%23%25" self.assertEqual(url_escape(unescaped), plus_escaped) self.assertEqual(url_escape(unescaped, plus=False), escaped) self.assertEqual(url_unescape(plus_escaped), unescaped) self.assertEqual(url_unescape(escaped, plus=False), unescaped) - self.assertEqual(url_unescape(plus_escaped, encoding=None), - utf8(unescaped)) - self.assertEqual(url_unescape(escaped, encoding=None, plus=False), - utf8(unescaped)) + self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped)) + self.assertEqual( + url_unescape(escaped, encoding=None, plus=False), utf8(unescaped) + ) def test_escape_return_types(self): # On python2 the escape methods should generally return the same @@ -235,17 +304,19 @@ class EscapeTestCase(unittest.TestCase): self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9") def test_squeeze(self): - self.assertEqual(squeeze(u'sequences of whitespace chars'), - u'sequences of whitespace chars') + self.assertEqual( + squeeze(u"sequences of whitespace chars"), + u"sequences of whitespace chars", + ) def test_recursive_unicode(self): tests = { - 'dict': {b"foo": b"bar"}, - 'list': [b"foo", b"bar"], - 'tuple': (b"foo", b"bar"), - 'bytes': b"foo" + "dict": {b"foo": b"bar"}, + "list": [b"foo", b"bar"], + "tuple": (b"foo", b"bar"), + "bytes": b"foo", } - self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"}) - self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"]) - self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar")) - self.assertEqual(recursive_unicode(tests['bytes']), u"foo") + self.assertEqual(recursive_unicode(tests["dict"]), {u"foo": u"bar"}) + self.assertEqual(recursive_unicode(tests["list"]), [u"foo", u"bar"]) + self.assertEqual(recursive_unicode(tests["tuple"]), (u"foo", u"bar")) + self.assertEqual(recursive_unicode(tests["bytes"]), u"foo") diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index dcc50583b..8a84e4384 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -49,12 +49,14 @@ class GenBasicTest(AsyncTestCase): @gen.coroutine def f(): pass + self.io_loop.run_sync(f) def test_exception_phase1(self): @gen.coroutine def f(): 1 / 0 + self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f) def test_exception_phase2(self): @@ -62,24 +64,28 @@ class GenBasicTest(AsyncTestCase): def f(): yield gen.moment 1 / 0 + self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f) def test_bogus_yield(self): @gen.coroutine def f(): yield 42 + self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f) def test_bogus_yield_tuple(self): @gen.coroutine def f(): yield (1, 2) + self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f) def test_reuse(self): @gen.coroutine def f(): yield gen.moment + self.io_loop.run_sync(f) self.io_loop.run_sync(f) @@ -87,6 +93,7 @@ class GenBasicTest(AsyncTestCase): @gen.coroutine def f(): yield None + self.io_loop.run_sync(f) def test_multi(self): @@ -94,6 +101,7 @@ class GenBasicTest(AsyncTestCase): def f(): results = yield [self.add_one_async(1), self.add_one_async(2)] self.assertEqual(results, [2, 3]) + self.io_loop.run_sync(f) def test_multi_dict(self): @@ -101,28 +109,29 @@ class GenBasicTest(AsyncTestCase): def f(): results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2)) self.assertEqual(results, dict(foo=2, bar=3)) + self.io_loop.run_sync(f) def test_multi_delayed(self): @gen.coroutine def f(): # callbacks run at different times - responses = yield gen.multi_future([ - self.delay(3, "v1"), - self.delay(1, "v2"), - ]) + responses = yield gen.multi_future( + [self.delay(3, "v1"), self.delay(1, "v2")] + ) self.assertEqual(responses, ["v1", "v2"]) + self.io_loop.run_sync(f) def test_multi_dict_delayed(self): @gen.coroutine def f(): # callbacks run at different times - responses = yield gen.multi_future(dict( - foo=self.delay(3, "v1"), - bar=self.delay(1, "v2"), - )) + responses = yield gen.multi_future( + dict(foo=self.delay(3, "v1"), bar=self.delay(1, "v2")) + ) self.assertEqual(responses, dict(foo="v1", bar="v2")) + self.io_loop.run_sync(f) @skipOnTravis @@ -171,40 +180,53 @@ class GenBasicTest(AsyncTestCase): def test_multi_exceptions(self): with ExpectLog(app_log, "Multiple exceptions in yield list"): with self.assertRaises(RuntimeError) as cm: - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))]) + yield gen.Multi( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ] + ) self.assertEqual(str(cm.exception), "error 1") # With only one exception, no error is logged. with self.assertRaises(RuntimeError): - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_future(2)]) + yield gen.Multi( + [self.async_exception(RuntimeError("error 1")), self.async_future(2)] + ) # Exception logging may be explicitly quieted. with self.assertRaises(RuntimeError): - yield gen.Multi([self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))], - quiet_exceptions=RuntimeError) + yield gen.Multi( + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ], + quiet_exceptions=RuntimeError, + ) @gen_test def test_multi_future_exceptions(self): with ExpectLog(app_log, "Multiple exceptions in yield list"): with self.assertRaises(RuntimeError) as cm: - yield [self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))] + yield [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ] self.assertEqual(str(cm.exception), "error 1") # With only one exception, no error is logged. with self.assertRaises(RuntimeError): - yield [self.async_exception(RuntimeError("error 1")), - self.async_future(2)] + yield [self.async_exception(RuntimeError("error 1")), self.async_future(2)] # Exception logging may be explicitly quieted. with self.assertRaises(RuntimeError): yield gen.multi_future( - [self.async_exception(RuntimeError("error 1")), - self.async_exception(RuntimeError("error 2"))], - quiet_exceptions=RuntimeError) + [ + self.async_exception(RuntimeError("error 1")), + self.async_exception(RuntimeError("error 2")), + ], + quiet_exceptions=RuntimeError, + ) def test_sync_raise_return(self): @gen.coroutine @@ -291,6 +313,7 @@ class GenCoroutineTest(AsyncTestCase): @gen.coroutine def f(): raise gen.Return(42) + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -301,6 +324,7 @@ class GenCoroutineTest(AsyncTestCase): def f(): yield gen.moment raise gen.Return(42) + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -310,6 +334,7 @@ class GenCoroutineTest(AsyncTestCase): @gen.coroutine def f(): return 42 + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -320,6 +345,7 @@ class GenCoroutineTest(AsyncTestCase): def f(): yield gen.moment return 42 + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -334,6 +360,7 @@ class GenCoroutineTest(AsyncTestCase): if True: return 42 yield gen.Task(self.io_loop.add_callback) + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -351,6 +378,7 @@ class GenCoroutineTest(AsyncTestCase): async def f2(): result = await f1() return result + result = yield f2() self.assertEqual(result, 42) self.finished = True @@ -361,8 +389,10 @@ class GenCoroutineTest(AsyncTestCase): # `yield None`) async def f(): import asyncio + await asyncio.sleep(0) return 42 + result = yield f() self.assertEqual(result, 42) self.finished = True @@ -400,6 +430,7 @@ class GenCoroutineTest(AsyncTestCase): @gen.coroutine def f(): return + result = yield f() self.assertEqual(result, None) self.finished = True @@ -411,6 +442,7 @@ class GenCoroutineTest(AsyncTestCase): def f(): yield gen.moment return + result = yield f() self.assertEqual(result, None) self.finished = True @@ -420,6 +452,7 @@ class GenCoroutineTest(AsyncTestCase): @gen.coroutine def f(): 1 / 0 + # The exception is raised when the future is yielded # (or equivalently when its result method is called), # not when the function itself is called). @@ -434,6 +467,7 @@ class GenCoroutineTest(AsyncTestCase): def f(): yield gen.moment 1 / 0 + future = f() with self.assertRaises(ZeroDivisionError): yield future @@ -487,22 +521,23 @@ class GenCoroutineTest(AsyncTestCase): for i in range(5): calls.append(name) yield yieldable + # First, confirm the behavior without moment: each coroutine # monopolizes the event loop until it finishes. immediate = Future() # type: Future[None] immediate.set_result(None) - yield [f('a', immediate), f('b', immediate)] - self.assertEqual(''.join(calls), 'aaaaabbbbb') + yield [f("a", immediate), f("b", immediate)] + self.assertEqual("".join(calls), "aaaaabbbbb") # With moment, they take turns. calls = [] - yield [f('a', gen.moment), f('b', gen.moment)] - self.assertEqual(''.join(calls), 'ababababab') + yield [f("a", gen.moment), f("b", gen.moment)] + self.assertEqual("".join(calls), "ababababab") self.finished = True calls = [] - yield [f('a', gen.moment), f('b', immediate)] - self.assertEqual(''.join(calls), 'abbbbbaaaa') + yield [f("a", gen.moment), f("b", immediate)] + self.assertEqual("".join(calls), "abbbbbaaaa") @gen_test def test_sleep(self): @@ -533,8 +568,9 @@ class GenCoroutineTest(AsyncTestCase): self.finished = True @skipNotCPython - @unittest.skipIf((3,) < sys.version_info < (3, 6), - "asyncio.Future has reference cycles") + @unittest.skipIf( + (3,) < sys.version_info < (3, 6), "asyncio.Future has reference cycles" + ) def test_coroutine_refcounting(self): # On CPython, tasks and their arguments should be released immediately # without waiting for garbage collection. @@ -542,13 +578,15 @@ class GenCoroutineTest(AsyncTestCase): def inner(): class Foo(object): pass + local_var = Foo() self.local_ref = weakref.ref(local_var) def dummy(): pass + yield gen.coroutine(dummy)() - raise ValueError('Some error') + raise ValueError("Some error") @gen.coroutine def inner2(): @@ -576,8 +614,7 @@ class GenCoroutineTest(AsyncTestCase): self.assertIsInstance(coro, asyncio.Future) # We expect the coroutine repr() to show the place where # it was instantiated - expected = ("created at %s:%d" - % (__file__, f.__code__.co_firstlineno + 3)) + expected = "created at %s:%d" % (__file__, f.__code__.co_firstlineno + 3) actual = repr(coro) self.assertIn(expected, actual) @@ -624,15 +661,15 @@ class UndecoratedCoroutinesHandler(RequestHandler): def prepare(self): self.chunks = [] # type: List[str] yield gen.moment - self.chunks.append('1') + self.chunks.append("1") @gen.coroutine def get(self): - self.chunks.append('2') + self.chunks.append("2") yield gen.moment - self.chunks.append('3') + self.chunks.append("3") yield gen.moment - self.write(''.join(self.chunks)) + self.write("".join(self.chunks)) class AsyncPrepareErrorHandler(RequestHandler): @@ -642,7 +679,7 @@ class AsyncPrepareErrorHandler(RequestHandler): raise HTTPError(403) def get(self): - self.finish('ok') + self.finish("ok") class NativeCoroutineHandler(RequestHandler): @@ -653,86 +690,91 @@ class NativeCoroutineHandler(RequestHandler): class GenWebTest(AsyncHTTPTestCase): def get_app(self): - return Application([ - ('/coroutine_sequence', GenCoroutineSequenceHandler), - ('/coroutine_unfinished_sequence', - GenCoroutineUnfinishedSequenceHandler), - ('/undecorated_coroutine', UndecoratedCoroutinesHandler), - ('/async_prepare_error', AsyncPrepareErrorHandler), - ('/native_coroutine', NativeCoroutineHandler), - ]) + return Application( + [ + ("/coroutine_sequence", GenCoroutineSequenceHandler), + ( + "/coroutine_unfinished_sequence", + GenCoroutineUnfinishedSequenceHandler, + ), + ("/undecorated_coroutine", UndecoratedCoroutinesHandler), + ("/async_prepare_error", AsyncPrepareErrorHandler), + ("/native_coroutine", NativeCoroutineHandler), + ] + ) def test_coroutine_sequence_handler(self): - response = self.fetch('/coroutine_sequence') + response = self.fetch("/coroutine_sequence") self.assertEqual(response.body, b"123") def test_coroutine_unfinished_sequence_handler(self): - response = self.fetch('/coroutine_unfinished_sequence') + response = self.fetch("/coroutine_unfinished_sequence") self.assertEqual(response.body, b"123") def test_undecorated_coroutines(self): - response = self.fetch('/undecorated_coroutine') - self.assertEqual(response.body, b'123') + response = self.fetch("/undecorated_coroutine") + self.assertEqual(response.body, b"123") def test_async_prepare_error_handler(self): - response = self.fetch('/async_prepare_error') + response = self.fetch("/async_prepare_error") self.assertEqual(response.code, 403) def test_native_coroutine_handler(self): - response = self.fetch('/native_coroutine') + response = self.fetch("/native_coroutine") self.assertEqual(response.code, 200) - self.assertEqual(response.body, b'ok') + self.assertEqual(response.body, b"ok") class WithTimeoutTest(AsyncTestCase): @gen_test def test_timeout(self): with self.assertRaises(gen.TimeoutError): - yield gen.with_timeout(datetime.timedelta(seconds=0.1), - Future()) + yield gen.with_timeout(datetime.timedelta(seconds=0.1), Future()) @gen_test def test_completes_before_timeout(self): future = Future() # type: Future[str] - self.io_loop.add_timeout(datetime.timedelta(seconds=0.1), - lambda: future.set_result('asdf')) - result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) - self.assertEqual(result, 'asdf') + self.io_loop.add_timeout( + datetime.timedelta(seconds=0.1), lambda: future.set_result("asdf") + ) + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + self.assertEqual(result, "asdf") @gen_test def test_fails_before_timeout(self): future = Future() # type: Future[str] self.io_loop.add_timeout( datetime.timedelta(seconds=0.1), - lambda: future.set_exception(ZeroDivisionError())) + lambda: future.set_exception(ZeroDivisionError()), + ) with self.assertRaises(ZeroDivisionError): - yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) + yield gen.with_timeout(datetime.timedelta(seconds=3600), future) @gen_test def test_already_resolved(self): future = Future() # type: Future[str] - future.set_result('asdf') - result = yield gen.with_timeout(datetime.timedelta(seconds=3600), - future) - self.assertEqual(result, 'asdf') + future.set_result("asdf") + result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future) + self.assertEqual(result, "asdf") @gen_test def test_timeout_concurrent_future(self): # A concurrent future that does not resolve before the timeout. with futures.ThreadPoolExecutor(1) as executor: with self.assertRaises(gen.TimeoutError): - yield gen.with_timeout(self.io_loop.time(), - executor.submit(time.sleep, 0.1)) + yield gen.with_timeout( + self.io_loop.time(), executor.submit(time.sleep, 0.1) + ) @gen_test def test_completed_concurrent_future(self): # A concurrent future that is resolved before we even submit it # to with_timeout. with futures.ThreadPoolExecutor(1) as executor: + def dummy(): pass + f = executor.submit(dummy) f.result() # wait for completion yield gen.with_timeout(datetime.timedelta(seconds=3600), f) @@ -741,15 +783,17 @@ class WithTimeoutTest(AsyncTestCase): def test_normal_concurrent_future(self): # A conccurrent future that resolves while waiting for the timeout. with futures.ThreadPoolExecutor(1) as executor: - yield gen.with_timeout(datetime.timedelta(seconds=3600), - executor.submit(lambda: time.sleep(0.01))) + yield gen.with_timeout( + datetime.timedelta(seconds=3600), + executor.submit(lambda: time.sleep(0.01)), + ) class WaitIteratorTest(AsyncTestCase): @gen_test def test_empty_iterator(self): g = gen.WaitIterator() - self.assertTrue(g.done(), 'empty generator iterated') + self.assertTrue(g.done(), "empty generator iterated") with self.assertRaises(ValueError): g = gen.WaitIterator(Future(), bar=Future()) @@ -794,14 +838,17 @@ class WaitIteratorTest(AsyncTestCase): while not dg.done(): dr = yield dg.next() if dg.current_index == "f1": - self.assertTrue(dg.current_future == f1 and dr == 24, - "WaitIterator dict status incorrect") + self.assertTrue( + dg.current_future == f1 and dr == 24, + "WaitIterator dict status incorrect", + ) elif dg.current_index == "f2": - self.assertTrue(dg.current_future == f2 and dr == 42, - "WaitIterator dict status incorrect") + self.assertTrue( + dg.current_future == f2 and dr == 42, + "WaitIterator dict status incorrect", + ) else: - self.fail("got bad WaitIterator index {}".format( - dg.current_index)) + self.fail("got bad WaitIterator index {}".format(dg.current_index)) i += 1 @@ -833,18 +880,17 @@ class WaitIteratorTest(AsyncTestCase): try: r = yield g.next() except ZeroDivisionError: - self.assertIs(g.current_future, futures[0], - 'exception future invalid') + self.assertIs(g.current_future, futures[0], "exception future invalid") else: if i == 0: - self.assertEqual(r, 24, 'iterator value incorrect') - self.assertEqual(g.current_index, 2, 'wrong index') + self.assertEqual(r, 24, "iterator value incorrect") + self.assertEqual(g.current_index, 2, "wrong index") elif i == 2: - self.assertEqual(r, 42, 'iterator value incorrect') - self.assertEqual(g.current_index, 1, 'wrong index') + self.assertEqual(r, 42, "iterator value incorrect") + self.assertEqual(g.current_index, 1, "wrong index") elif i == 3: - self.assertEqual(r, 84, 'iterator value incorrect') - self.assertEqual(g.current_index, 3, 'wrong index') + self.assertEqual(r, 84, "iterator value incorrect") + self.assertEqual(g.current_index, 3, "wrong index") i += 1 @gen_test @@ -862,8 +908,8 @@ class WaitIteratorTest(AsyncTestCase): try: async for r in g: if i == 0: - self.assertEqual(r, 24, 'iterator value incorrect') - self.assertEqual(g.current_index, 2, 'wrong index') + self.assertEqual(r, 24, "iterator value incorrect") + self.assertEqual(g.current_index, 2, "wrong index") else: raise Exception("expected exception on iteration 1") i += 1 @@ -871,15 +917,16 @@ class WaitIteratorTest(AsyncTestCase): i += 1 async for r in g: if i == 2: - self.assertEqual(r, 42, 'iterator value incorrect') - self.assertEqual(g.current_index, 1, 'wrong index') + self.assertEqual(r, 42, "iterator value incorrect") + self.assertEqual(g.current_index, 1, "wrong index") elif i == 3: - self.assertEqual(r, 84, 'iterator value incorrect') - self.assertEqual(g.current_index, 3, 'wrong index') + self.assertEqual(r, 84, "iterator value incorrect") + self.assertEqual(g.current_index, 3, "wrong index") else: raise Exception("didn't expect iteration %d" % i) i += 1 self.finished = True + yield f() self.assertTrue(self.finished) @@ -889,14 +936,14 @@ class WaitIteratorTest(AsyncTestCase): # WaitIterator itself, only the Future it returns. Since # WaitIterator uses weak references internally to improve GC # performance, this used to cause problems. - yield gen.with_timeout(datetime.timedelta(seconds=0.1), - gen.WaitIterator(gen.sleep(0)).next()) + yield gen.with_timeout( + datetime.timedelta(seconds=0.1), gen.WaitIterator(gen.sleep(0)).next() + ) class RunnerGCTest(AsyncTestCase): def is_pypy3(self): - return (platform.python_implementation() == 'PyPy' and - sys.version_info > (3,)) + return platform.python_implementation() == "PyPy" and sys.version_info > (3,) @gen_test def test_gc(self): @@ -915,10 +962,7 @@ class RunnerGCTest(AsyncTestCase): self.io_loop.add_callback(callback) yield fut - yield gen.with_timeout( - datetime.timedelta(seconds=0.2), - tester() - ) + yield gen.with_timeout(datetime.timedelta(seconds=0.2), tester()) def test_gc_infinite_coro(self): # Github issue 2229: suspended coroutines should be GCed when @@ -981,7 +1025,7 @@ class RunnerGCTest(AsyncTestCase): yield gen.sleep(0.2) loop.run_sync(do_something) - with ExpectLog('asyncio', "Task was destroyed but it is pending"): + with ExpectLog("asyncio", "Task was destroyed but it is pending"): loop.close() gc.collect() # Future was collected @@ -1005,5 +1049,5 @@ class RunnerGCTest(AsyncTestCase): self.assertEqual(result, [None, None]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tornado/test/http1connection_test.py b/tornado/test/http1connection_test.py index 4d8c04bc5..023c02a30 100644 --- a/tornado/test/http1connection_test.py +++ b/tornado/test/http1connection_test.py @@ -26,8 +26,7 @@ class HTTP1ConnectionTest(AsyncTestCase): add_accept_handler(listener, accept_callback) self.client_stream = IOStream(socket.socket()) self.addCleanup(self.client_stream.close) - yield [self.client_stream.connect(('127.0.0.1', port)), - event.wait()] + yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()] self.io_loop.remove_handler(listener) listener.close() @@ -56,4 +55,4 @@ class HTTP1ConnectionTest(AsyncTestCase): yield conn.read_response(Delegate()) yield event.wait() self.assertEqual(self.code, 200) - self.assertEqual(b''.join(body), b'hello') + self.assertEqual(b"".join(body), b"hello") diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index d6a519d33..d0b46ebec 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -13,7 +13,13 @@ import unittest from tornado.escape import utf8, native_str from tornado import gen -from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient +from tornado.httpclient import ( + HTTPRequest, + HTTPResponse, + _RequestProxy, + HTTPError, + HTTPClient, +) from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop from tornado.iostream import IOStream @@ -34,8 +40,10 @@ class HelloWorldHandler(RequestHandler): class PostHandler(RequestHandler): def post(self): - self.finish("Post arg1: %s, arg2: %s" % ( - self.get_argument("arg1"), self.get_argument("arg2"))) + self.finish( + "Post arg1: %s, arg2: %s" + % (self.get_argument("arg1"), self.get_argument("arg2")) + ) class PutHandler(RequestHandler): @@ -46,9 +54,10 @@ class PutHandler(RequestHandler): class RedirectHandler(RequestHandler): def prepare(self): - self.write('redirects can have bodies too') - self.redirect(self.get_argument("url"), - status=int(self.get_argument("status", "302"))) + self.write("redirects can have bodies too") + self.redirect( + self.get_argument("url"), status=int(self.get_argument("status", "302")) + ) class ChunkHandler(RequestHandler): @@ -82,13 +91,13 @@ class EchoPostHandler(RequestHandler): class UserAgentHandler(RequestHandler): def get(self): - self.write(self.request.headers.get('User-Agent', 'User agent not set')) + self.write(self.request.headers.get("User-Agent", "User agent not set")) class ContentLength304Handler(RequestHandler): def get(self): self.set_status(304) - self.set_header('Content-Length', 42) + self.set_header("Content-Length", 42) def _clear_headers_for_304(self): # Tornado strips content-length from 304 responses, but here we @@ -97,14 +106,13 @@ class ContentLength304Handler(RequestHandler): class PatchHandler(RequestHandler): - def patch(self): "Return the request payload - so we can check it is being kept" self.write(self.request.body) class AllMethodsHandler(RequestHandler): - SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',) # type: ignore + SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",) # type: ignore def method(self): self.write(self.request.method) @@ -116,10 +124,10 @@ class SetHeaderHandler(RequestHandler): def get(self): # Use get_arguments for keys to get strings, but # request.arguments for values to get bytes. - for k, v in zip(self.get_arguments('k'), - self.request.arguments['v']): + for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]): self.set_header(k, v) + # These tests end up getting run redundantly: once here with the default # HTTPClient implementation, and then again in each implementation's own # test suite. @@ -127,25 +135,28 @@ class SetHeaderHandler(RequestHandler): class HTTPClientCommonTestCase(AsyncHTTPTestCase): def get_app(self): - return Application([ - url("/hello", HelloWorldHandler), - url("/post", PostHandler), - url("/put", PutHandler), - url("/redirect", RedirectHandler), - url("/chunk", ChunkHandler), - url("/auth", AuthHandler), - url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), - url("/echopost", EchoPostHandler), - url("/user_agent", UserAgentHandler), - url("/304_with_content_length", ContentLength304Handler), - url("/all_methods", AllMethodsHandler), - url('/patch', PatchHandler), - url('/set_header', SetHeaderHandler), - ], gzip=True) + return Application( + [ + url("/hello", HelloWorldHandler), + url("/post", PostHandler), + url("/put", PutHandler), + url("/redirect", RedirectHandler), + url("/chunk", ChunkHandler), + url("/auth", AuthHandler), + url("/countdown/([0-9]+)", CountdownHandler, name="countdown"), + url("/echopost", EchoPostHandler), + url("/user_agent", UserAgentHandler), + url("/304_with_content_length", ContentLength304Handler), + url("/all_methods", AllMethodsHandler), + url("/patch", PatchHandler), + url("/set_header", SetHeaderHandler), + ], + gzip=True, + ) def test_patch_receives_payload(self): body = b"some patch data" - response = self.fetch("/patch", method='PATCH', body=body) + response = self.fetch("/patch", method="PATCH", body=body) self.assertEqual(response.code, 200) self.assertEqual(response.body, body) @@ -163,15 +174,13 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): def test_streaming_callback(self): # streaming_callback is also tested in test_chunked chunks = [] # type: typing.List[bytes] - response = self.fetch("/hello", - streaming_callback=chunks.append) + response = self.fetch("/hello", streaming_callback=chunks.append) # with streaming_callback, data goes to the callback and not response.body self.assertEqual(chunks, [b"Hello world!"]) self.assertFalse(response.body) def test_post(self): - response = self.fetch("/post", method="POST", - body="arg1=foo&arg2=bar") + response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar") self.assertEqual(response.code, 200) self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") @@ -180,8 +189,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): self.assertEqual(response.body, b"asdfqwer") chunks = [] # type: typing.List[bytes] - response = self.fetch("/chunk", - streaming_callback=chunks.append) + response = self.fetch("/chunk", streaming_callback=chunks.append) self.assertEqual(chunks, [b"asdf", b"qwer"]) self.assertFalse(response.body) @@ -190,6 +198,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): # over several ioloop iterations, but the connection is already closed. sock, port = bind_unused_port() with closing(sock): + @gen.coroutine def accept_callback(conn, address): # fake an HTTP server using chunked encoding where the final chunks @@ -198,7 +207,8 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase): request_data = yield stream.read_until(b"\r\n\r\n") if b"HTTP/1." not in request_data: self.skipTest("requires HTTP/1.x") - yield stream.write(b"""\ + yield stream.write( + b"""\ HTTP/1.1 200 OK Transfer-Encoding: chunked @@ -208,8 +218,12 @@ Transfer-Encoding: chunked 2 0 -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) stream.close() + netutil.add_accept_handler(sock, accept_callback) # type: ignore resp = self.fetch("http://127.0.0.1:%d/" % port) resp.rethrow() @@ -218,29 +232,38 @@ Transfer-Encoding: chunked def test_basic_auth(self): # This test data appears in section 2 of RFC 7617. - self.assertEqual(self.fetch("/auth", auth_username="Aladdin", - auth_password="open sesame").body, - b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + self.assertEqual( + self.fetch( + "/auth", auth_username="Aladdin", auth_password="open sesame" + ).body, + b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + ) def test_basic_auth_explicit_mode(self): - self.assertEqual(self.fetch("/auth", auth_username="Aladdin", - auth_password="open sesame", - auth_mode="basic").body, - b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + self.assertEqual( + self.fetch( + "/auth", + auth_username="Aladdin", + auth_password="open sesame", + auth_mode="basic", + ).body, + b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + ) def test_basic_auth_unicode(self): # This test data appears in section 2.1 of RFC 7617. - self.assertEqual(self.fetch("/auth", auth_username="test", - auth_password="123£").body, - b"Basic dGVzdDoxMjPCow==") + self.assertEqual( + self.fetch("/auth", auth_username="test", auth_password="123£").body, + b"Basic dGVzdDoxMjPCow==", + ) # The standard mandates NFC. Give it a decomposed username # and ensure it is normalized to composed form. username = unicodedata.normalize("NFD", u"josé") - self.assertEqual(self.fetch("/auth", - auth_username=username, - auth_password="səcrət").body, - b"Basic am9zw6k6c8mZY3LJmXQ=") + self.assertEqual( + self.fetch("/auth", auth_username=username, auth_password="səcrət").body, + b"Basic am9zw6k6c8mZY3LJmXQ=", + ) def test_unsupported_auth_mode(self): # curl and simple clients handle errors a bit differently; the @@ -248,10 +271,13 @@ Transfer-Encoding: chunked # on an unknown mode. with ExpectLog(gen_log, "uncaught exception", required=False): with self.assertRaises((ValueError, HTTPError)): - self.fetch("/auth", auth_username="Aladdin", - auth_password="open sesame", - auth_mode="asdf", - raise_error=True) + self.fetch( + "/auth", + auth_username="Aladdin", + auth_password="open sesame", + auth_mode="asdf", + raise_error=True, + ) def test_follow_redirect(self): response = self.fetch("/countdown/2", follow_redirects=False) @@ -266,31 +292,41 @@ Transfer-Encoding: chunked def test_credentials_in_url(self): url = self.get_url("/auth").replace("http://", "http://me:secret@") response = self.fetch(url) - self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), - response.body) + self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body) def test_body_encoding(self): unicode_body = u"\xe9" byte_body = binascii.a2b_hex(b"e9") # unicode string in body gets converted to utf8 - response = self.fetch("/echopost", method="POST", body=unicode_body, - headers={"Content-Type": "application/blah"}) + response = self.fetch( + "/echopost", + method="POST", + body=unicode_body, + headers={"Content-Type": "application/blah"}, + ) self.assertEqual(response.headers["Content-Length"], "2") self.assertEqual(response.body, utf8(unicode_body)) # byte strings pass through directly - response = self.fetch("/echopost", method="POST", - body=byte_body, - headers={"Content-Type": "application/blah"}) + response = self.fetch( + "/echopost", + method="POST", + body=byte_body, + headers={"Content-Type": "application/blah"}, + ) self.assertEqual(response.headers["Content-Length"], "1") self.assertEqual(response.body, byte_body) # Mixing unicode in headers and byte string bodies shouldn't # break anything - response = self.fetch("/echopost", method="POST", body=byte_body, - headers={"Content-Type": "application/blah"}, - user_agent=u"foo") + response = self.fetch( + "/echopost", + method="POST", + body=byte_body, + headers={"Content-Type": "application/blah"}, + user_agent=u"foo", + ) self.assertEqual(response.headers["Content-Length"], "1") self.assertEqual(response.body, byte_body) @@ -307,37 +343,39 @@ Transfer-Encoding: chunked chunks = [] def header_callback(header_line): - if header_line.startswith('HTTP/1.1 101'): + if header_line.startswith("HTTP/1.1 101"): # Upgrading to HTTP/2 pass - elif header_line.startswith('HTTP/'): + elif header_line.startswith("HTTP/"): first_line.append(header_line) - elif header_line != '\r\n': - k, v = header_line.split(':', 1) + elif header_line != "\r\n": + k, v = header_line.split(":", 1) headers[k.lower()] = v.strip() def streaming_callback(chunk): # All header callbacks are run before any streaming callbacks, # so the header data is available to process the data as it # comes in. - self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8') + self.assertEqual(headers["content-type"], "text/html; charset=UTF-8") chunks.append(chunk) - self.fetch('/chunk', header_callback=header_callback, - streaming_callback=streaming_callback) + self.fetch( + "/chunk", + header_callback=header_callback, + streaming_callback=streaming_callback, + ) self.assertEqual(len(first_line), 1, first_line) - self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n') - self.assertEqual(chunks, [b'asdf', b'qwer']) + self.assertRegexpMatches(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n") + self.assertEqual(chunks, [b"asdf", b"qwer"]) @gen_test def test_configure_defaults(self): - defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False) + defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False) # Construct a new instance of the configured client class - client = self.http_client.__class__(force_instance=True, - defaults=defaults) + client = self.http_client.__class__(force_instance=True, defaults=defaults) try: - response = yield client.fetch(self.get_url('/user_agent')) - self.assertEqual(response.body, b'TestDefaultUserAgent') + response = yield client.fetch(self.get_url("/user_agent")) + self.assertEqual(response.body, b"TestDefaultUserAgent") finally: client.close() @@ -349,36 +387,43 @@ Transfer-Encoding: chunked for value in [u"MyUserAgent", b"MyUserAgent"]: for container in [dict, HTTPHeaders]: headers = container() - headers['User-Agent'] = value - resp = self.fetch('/user_agent', headers=headers) + headers["User-Agent"] = value + resp = self.fetch("/user_agent", headers=headers) self.assertEqual( - resp.body, b"MyUserAgent", - "response=%r, value=%r, container=%r" % - (resp.body, value, container)) + resp.body, + b"MyUserAgent", + "response=%r, value=%r, container=%r" + % (resp.body, value, container), + ) def test_multi_line_headers(self): # Multi-line http headers are rare but rfc-allowed # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 sock, port = bind_unused_port() with closing(sock): + @gen.coroutine def accept_callback(conn, address): stream = IOStream(conn) request_data = yield stream.read_until(b"\r\n\r\n") if b"HTTP/1." not in request_data: self.skipTest("requires HTTP/1.x") - yield stream.write(b"""\ + yield stream.write( + b"""\ HTTP/1.1 200 OK X-XSS-Protection: 1; \tmode=block -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) stream.close() netutil.add_accept_handler(sock, accept_callback) # type: ignore resp = self.fetch("http://127.0.0.1:%d/" % port) resp.rethrow() - self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block") + self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block") self.io_loop.remove_handler(sock.fileno()) def test_304_with_content_length(self): @@ -386,25 +431,27 @@ X-XSS-Protection: 1; # Content-Length or other entity headers, but some servers do it # anyway. # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5 - response = self.fetch('/304_with_content_length') + response = self.fetch("/304_with_content_length") self.assertEqual(response.code, 304) - self.assertEqual(response.headers['Content-Length'], '42') + self.assertEqual(response.headers["Content-Length"], "42") @gen_test def test_future_interface(self): - response = yield self.http_client.fetch(self.get_url('/hello')) - self.assertEqual(response.body, b'Hello world!') + response = yield self.http_client.fetch(self.get_url("/hello")) + self.assertEqual(response.body, b"Hello world!") @gen_test def test_future_http_error(self): with self.assertRaises(HTTPError) as context: - yield self.http_client.fetch(self.get_url('/notfound')) + yield self.http_client.fetch(self.get_url("/notfound")) self.assertEqual(context.exception.code, 404) self.assertEqual(context.exception.response.code, 404) @gen_test def test_future_http_error_no_raise(self): - response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False) + response = yield self.http_client.fetch( + self.get_url("/notfound"), raise_error=False + ) self.assertEqual(response.code, 404) @gen_test @@ -413,48 +460,57 @@ X-XSS-Protection: 1; # a _RequestProxy. # This test uses self.http_client.fetch because self.fetch calls # self.get_url on the input unconditionally. - url = self.get_url('/hello') + url = self.get_url("/hello") response = yield self.http_client.fetch(url) self.assertEqual(response.request.url, url) self.assertTrue(isinstance(response.request, HTTPRequest)) response2 = yield self.http_client.fetch(response.request) - self.assertEqual(response2.body, b'Hello world!') + self.assertEqual(response2.body, b"Hello world!") def test_all_methods(self): - for method in ['GET', 'DELETE', 'OPTIONS']: - response = self.fetch('/all_methods', method=method) + for method in ["GET", "DELETE", "OPTIONS"]: + response = self.fetch("/all_methods", method=method) self.assertEqual(response.body, utf8(method)) - for method in ['POST', 'PUT', 'PATCH']: - response = self.fetch('/all_methods', method=method, body=b'') + for method in ["POST", "PUT", "PATCH"]: + response = self.fetch("/all_methods", method=method, body=b"") self.assertEqual(response.body, utf8(method)) - response = self.fetch('/all_methods', method='HEAD') - self.assertEqual(response.body, b'') - response = self.fetch('/all_methods', method='OTHER', - allow_nonstandard_methods=True) - self.assertEqual(response.body, b'OTHER') + response = self.fetch("/all_methods", method="HEAD") + self.assertEqual(response.body, b"") + response = self.fetch( + "/all_methods", method="OTHER", allow_nonstandard_methods=True + ) + self.assertEqual(response.body, b"OTHER") def test_body_sanity_checks(self): # These methods require a body. - for method in ('POST', 'PUT', 'PATCH'): + for method in ("POST", "PUT", "PATCH"): with self.assertRaises(ValueError) as context: - self.fetch('/all_methods', method=method, raise_error=True) - self.assertIn('must not be None', str(context.exception)) + self.fetch("/all_methods", method=method, raise_error=True) + self.assertIn("must not be None", str(context.exception)) - resp = self.fetch('/all_methods', method=method, - allow_nonstandard_methods=True) + resp = self.fetch( + "/all_methods", method=method, allow_nonstandard_methods=True + ) self.assertEqual(resp.code, 200) # These methods don't allow a body. - for method in ('GET', 'DELETE', 'OPTIONS'): + for method in ("GET", "DELETE", "OPTIONS"): with self.assertRaises(ValueError) as context: - self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True) - self.assertIn('must be None', str(context.exception)) + self.fetch( + "/all_methods", method=method, body=b"asdf", raise_error=True + ) + self.assertIn("must be None", str(context.exception)) # In most cases this can be overridden, but curl_httpclient # does not allow body with a GET at all. - if method != 'GET': - self.fetch('/all_methods', method=method, body=b'asdf', - allow_nonstandard_methods=True, raise_error=True) + if method != "GET": + self.fetch( + "/all_methods", + method=method, + body=b"asdf", + allow_nonstandard_methods=True, + raise_error=True, + ) self.assertEqual(resp.code, 200) # This test causes odd failures with the combination of @@ -474,8 +530,9 @@ X-XSS-Protection: 1; # self.assertEqual(response.body, b"Post arg1: foo, arg2: bar") def test_put_307(self): - response = self.fetch("/redirect?status=307&url=/put", - method="PUT", body=b"hello") + response = self.fetch( + "/redirect?status=307&url=/put", method="PUT", body=b"hello" + ) response.rethrow() self.assertEqual(response.body, b"Put body: hello") @@ -504,45 +561,45 @@ X-XSS-Protection: 1; class RequestProxyTest(unittest.TestCase): def test_request_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/', - user_agent='foo'), - dict()) - self.assertEqual(proxy.user_agent, 'foo') + proxy = _RequestProxy( + HTTPRequest("http://example.com/", user_agent="foo"), dict() + ) + self.assertEqual(proxy.user_agent, "foo") def test_default_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), - dict(network_interface='foo')) - self.assertEqual(proxy.network_interface, 'foo') + proxy = _RequestProxy( + HTTPRequest("http://example.com/"), dict(network_interface="foo") + ) + self.assertEqual(proxy.network_interface, "foo") def test_both_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/', - proxy_host='foo'), - dict(proxy_host='bar')) - self.assertEqual(proxy.proxy_host, 'foo') + proxy = _RequestProxy( + HTTPRequest("http://example.com/", proxy_host="foo"), dict(proxy_host="bar") + ) + self.assertEqual(proxy.proxy_host, "foo") def test_neither_set(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), - dict()) + proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict()) self.assertIs(proxy.auth_username, None) def test_bad_attribute(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), - dict()) + proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict()) with self.assertRaises(AttributeError): proxy.foo def test_defaults_none(self): - proxy = _RequestProxy(HTTPRequest('http://example.com/'), None) + proxy = _RequestProxy(HTTPRequest("http://example.com/"), None) self.assertIs(proxy.auth_username, None) class HTTPResponseTestCase(unittest.TestCase): def test_str(self): - response = HTTPResponse(HTTPRequest('http://example.com'), # type: ignore - 200, headers={}, buffer=BytesIO()) + response = HTTPResponse( # type: ignore + HTTPRequest("http://example.com"), 200, headers={}, buffer=BytesIO() + ) s = str(response) - self.assertTrue(s.startswith('HTTPResponse(')) - self.assertIn('code=200', s) + self.assertTrue(s.startswith("HTTPResponse(")) + self.assertIn("code=200", s) class SyncHTTPClientTest(unittest.TestCase): @@ -552,9 +609,10 @@ class SyncHTTPClientTest(unittest.TestCase): @gen.coroutine def init_server(): sock, self.port = bind_unused_port() - app = Application([('/', HelloWorldHandler)]) + app = Application([("/", HelloWorldHandler)]) self.server = HTTPServer(app) self.server.add_socket(sock) + self.server_ioloop.run_sync(init_server) self.server_thread = threading.Thread(target=self.server_ioloop.start) @@ -578,56 +636,59 @@ class SyncHTTPClientTest(unittest.TestCase): for i in range(5): yield self.server_ioloop.stop() + self.server_ioloop.add_callback(slow_stop) + self.server_ioloop.add_callback(stop_server) self.server_thread.join() self.http_client.close() self.server_ioloop.close(all_fds=True) def get_url(self, path): - return 'http://127.0.0.1:%d%s' % (self.port, path) + return "http://127.0.0.1:%d%s" % (self.port, path) def test_sync_client(self): - response = self.http_client.fetch(self.get_url('/')) - self.assertEqual(b'Hello world!', response.body) + response = self.http_client.fetch(self.get_url("/")) + self.assertEqual(b"Hello world!", response.body) def test_sync_client_error(self): # Synchronous HTTPClient raises errors directly; no need for # response.rethrow() with self.assertRaises(HTTPError) as assertion: - self.http_client.fetch(self.get_url('/notfound')) + self.http_client.fetch(self.get_url("/notfound")) self.assertEqual(assertion.exception.code, 404) class HTTPRequestTestCase(unittest.TestCase): def test_headers(self): - request = HTTPRequest('http://example.com', headers={'foo': 'bar'}) - self.assertEqual(request.headers, {'foo': 'bar'}) + request = HTTPRequest("http://example.com", headers={"foo": "bar"}) + self.assertEqual(request.headers, {"foo": "bar"}) def test_headers_setter(self): - request = HTTPRequest('http://example.com') - request.headers = {'bar': 'baz'} # type: ignore - self.assertEqual(request.headers, {'bar': 'baz'}) + request = HTTPRequest("http://example.com") + request.headers = {"bar": "baz"} # type: ignore + self.assertEqual(request.headers, {"bar": "baz"}) def test_null_headers_setter(self): - request = HTTPRequest('http://example.com') + request = HTTPRequest("http://example.com") request.headers = None # type: ignore self.assertEqual(request.headers, {}) def test_body(self): - request = HTTPRequest('http://example.com', body='foo') - self.assertEqual(request.body, utf8('foo')) + request = HTTPRequest("http://example.com", body="foo") + self.assertEqual(request.body, utf8("foo")) def test_body_setter(self): - request = HTTPRequest('http://example.com') - request.body = 'foo' # type: ignore - self.assertEqual(request.body, utf8('foo')) + request = HTTPRequest("http://example.com") + request.body = "foo" # type: ignore + self.assertEqual(request.body, utf8("foo")) def test_if_modified_since(self): http_date = datetime.datetime.utcnow() - request = HTTPRequest('http://example.com', if_modified_since=http_date) - self.assertEqual(request.headers, - {'If-Modified-Since': format_timestamp(http_date)}) + request = HTTPRequest("http://example.com", if_modified_since=http_date) + self.assertEqual( + request.headers, {"If-Modified-Since": format_timestamp(http_date)} + ) class HTTPErrorTestCase(unittest.TestCase): @@ -643,7 +704,7 @@ class HTTPErrorTestCase(unittest.TestCase): self.assertEqual(repr(e), "HTTP 403: Forbidden") def test_error_with_response(self): - resp = HTTPResponse(HTTPRequest('http://example.com/'), 403) + resp = HTTPResponse(HTTPRequest("http://example.com/"), 403) with self.assertRaises(HTTPError) as cm: resp.rethrow() e = cm.exception diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 675631cce..6f7f2a500 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -1,16 +1,34 @@ from tornado import gen, netutil from tornado.concurrent import Future -from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str +from tornado.escape import ( + json_decode, + json_encode, + utf8, + _unicode, + recursive_unicode, + native_str, +) from tornado.http1connection import HTTP1Connection from tornado.httpclient import HTTPError from tornado.httpserver import HTTPServer -from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501 +from tornado.httputil import ( + HTTPHeaders, + HTTPMessageDelegate, + HTTPServerConnectionDelegate, + ResponseStartLine, +) # noqa: E501 from tornado.iostream import IOStream from tornado.locks import Event from tornado.log import gen_log from tornado.netutil import ssl_options_to_context from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test # noqa: E501 +from tornado.testing import ( + AsyncHTTPTestCase, + AsyncHTTPSTestCase, + AsyncTestCase, + ExpectLog, + gen_test, +) # noqa: E501 from tornado.test.util import skipOnTravis from tornado.web import Application, RequestHandler, stream_request_body @@ -27,6 +45,7 @@ import unittest from io import BytesIO import typing + if typing.TYPE_CHECKING: from typing import Dict, List # noqa: F401 @@ -46,14 +65,15 @@ def read_stream_body(stream, callback): def finish(self): conn.detach() # type: ignore - callback((self.start_line, self.headers, b''.join(chunks))) + callback((self.start_line, self.headers, b"".join(chunks))) + conn = HTTP1Connection(stream, True) conn.read_response(Delegate()) class HandlerBaseTestCase(AsyncHTTPTestCase): def get_app(self): - return Application([('/', self.__class__.Handler)]) + return Application([("/", self.__class__.Handler)]) def fetch_json(self, *args, **kwargs): response = self.fetch(*args, **kwargs) @@ -80,55 +100,58 @@ class HelloWorldRequestHandler(RequestHandler): # introduced in python3.2, it was present but undocumented in # python 2.7 skipIfOldSSL = unittest.skipIf( - getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0), - "old version of ssl module and/or openssl") + getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)) < (1, 0), + "old version of ssl module and/or openssl", +) class BaseSSLTest(AsyncHTTPSTestCase): def get_app(self): - return Application([('/', HelloWorldRequestHandler, - dict(protocol="https"))]) + return Application([("/", HelloWorldRequestHandler, dict(protocol="https"))]) class SSLTestMixin(object): def get_ssl_options(self): - return dict(ssl_version=self.get_ssl_version(), - **AsyncHTTPSTestCase.default_ssl_options()) + return dict( + ssl_version=self.get_ssl_version(), + **AsyncHTTPSTestCase.default_ssl_options() + ) def get_ssl_version(self): raise NotImplementedError() def test_ssl(self): - response = self.fetch('/') + response = self.fetch("/") self.assertEqual(response.body, b"Hello world") def test_large_post(self): - response = self.fetch('/', - method='POST', - body='A' * 5000) + response = self.fetch("/", method="POST", body="A" * 5000) self.assertEqual(response.body, b"Got 5000 bytes in POST") def test_non_ssl_request(self): # Make sure the server closes the connection when it gets a non-ssl # connection, rather than waiting for a timeout or otherwise # misbehaving. - with ExpectLog(gen_log, '(SSL Error|uncaught exception)'): - with ExpectLog(gen_log, 'Uncaught exception', required=False): + with ExpectLog(gen_log, "(SSL Error|uncaught exception)"): + with ExpectLog(gen_log, "Uncaught exception", required=False): with self.assertRaises((IOError, HTTPError)): self.fetch( - self.get_url("/").replace('https:', 'http:'), + self.get_url("/").replace("https:", "http:"), request_timeout=3600, connect_timeout=3600, - raise_error=True) + raise_error=True, + ) def test_error_logging(self): # No stack traces are logged for SSL errors. - with ExpectLog(gen_log, 'SSL Error') as expect_log: + with ExpectLog(gen_log, "SSL Error") as expect_log: with self.assertRaises((IOError, HTTPError)): - self.fetch(self.get_url("/").replace("https:", "http:"), - raise_error=True) + self.fetch( + self.get_url("/").replace("https:", "http:"), raise_error=True + ) self.assertFalse(expect_log.logged_stack) + # Python's SSL implementation differs significantly between versions. # For example, SSLv3 and TLSv1 throw an exception if you try to read # from the socket before the handshake is complete, but the default @@ -154,8 +177,7 @@ class TLSv1Test(BaseSSLTest, SSLTestMixin): class SSLContextTest(BaseSSLTest, SSLTestMixin): def get_ssl_options(self): - context = ssl_options_to_context( - AsyncHTTPSTestCase.get_ssl_options(self)) + context = ssl_options_to_context(AsyncHTTPSTestCase.get_ssl_options(self)) assert isinstance(context, ssl.SSLContext) return context @@ -163,60 +185,78 @@ class SSLContextTest(BaseSSLTest, SSLTestMixin): class BadSSLOptionsTest(unittest.TestCase): def test_missing_arguments(self): application = Application() - self.assertRaises(KeyError, HTTPServer, application, ssl_options={ - "keyfile": "/__missing__.crt", - }) + self.assertRaises( + KeyError, + HTTPServer, + application, + ssl_options={"keyfile": "/__missing__.crt"}, + ) def test_missing_key(self): """A missing SSL key should cause an immediate exception.""" application = Application() module_dir = os.path.dirname(__file__) - existing_certificate = os.path.join(module_dir, 'test.crt') - existing_key = os.path.join(module_dir, 'test.key') - - self.assertRaises((ValueError, IOError), - HTTPServer, application, ssl_options={ - "certfile": "/__mising__.crt", - }) - self.assertRaises((ValueError, IOError), - HTTPServer, application, ssl_options={ - "certfile": existing_certificate, - "keyfile": "/__missing__.key" - }) + existing_certificate = os.path.join(module_dir, "test.crt") + existing_key = os.path.join(module_dir, "test.key") + + self.assertRaises( + (ValueError, IOError), + HTTPServer, + application, + ssl_options={"certfile": "/__mising__.crt"}, + ) + self.assertRaises( + (ValueError, IOError), + HTTPServer, + application, + ssl_options={ + "certfile": existing_certificate, + "keyfile": "/__missing__.key", + }, + ) # This actually works because both files exist - HTTPServer(application, ssl_options={ - "certfile": existing_certificate, - "keyfile": existing_key, - }) + HTTPServer( + application, + ssl_options={"certfile": existing_certificate, "keyfile": existing_key}, + ) class MultipartTestHandler(RequestHandler): def post(self): - self.finish({"header": self.request.headers["X-Header-Encoding-Test"], - "argument": self.get_argument("argument"), - "filename": self.request.files["files"][0].filename, - "filebody": _unicode(self.request.files["files"][0]["body"]), - }) + self.finish( + { + "header": self.request.headers["X-Header-Encoding-Test"], + "argument": self.get_argument("argument"), + "filename": self.request.files["files"][0].filename, + "filebody": _unicode(self.request.files["files"][0]["body"]), + } + ) # This test is also called from wsgi_test class HTTPConnectionTest(AsyncHTTPTestCase): def get_handlers(self): - return [("/multipart", MultipartTestHandler), - ("/hello", HelloWorldRequestHandler)] + return [ + ("/multipart", MultipartTestHandler), + ("/hello", HelloWorldRequestHandler), + ] def get_app(self): return Application(self.get_handlers()) def raw_fetch(self, headers, body, newline=b"\r\n"): with closing(IOStream(socket.socket())) as stream: - self.io_loop.run_sync(lambda: stream.connect(('127.0.0.1', self.get_http_port()))) + self.io_loop.run_sync( + lambda: stream.connect(("127.0.0.1", self.get_http_port())) + ) stream.write( - newline.join(headers + - [utf8("Content-Length: %d" % len(body))]) + - newline + newline + body) + newline.join(headers + [utf8("Content-Length: %d" % len(body))]) + + newline + + newline + + body + ) read_stream_body(stream, self.stop) start_line, headers, body = self.wait() return body @@ -224,22 +264,28 @@ class HTTPConnectionTest(AsyncHTTPTestCase): def test_multipart_form(self): # Encodings here are tricky: Headers are latin1, bodies can be # anything (we use utf8 by default). - response = self.raw_fetch([ - b"POST /multipart HTTP/1.0", - b"Content-Type: multipart/form-data; boundary=1234567890", - b"X-Header-encoding-test: \xe9", - ], - b"\r\n".join([ - b"Content-Disposition: form-data; name=argument", - b"", - u"\u00e1".encode("utf-8"), - b"--1234567890", - u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"), - b"", - u"\u00fa".encode("utf-8"), - b"--1234567890--", - b"", - ])) + response = self.raw_fetch( + [ + b"POST /multipart HTTP/1.0", + b"Content-Type: multipart/form-data; boundary=1234567890", + b"X-Header-encoding-test: \xe9", + ], + b"\r\n".join( + [ + b"Content-Disposition: form-data; name=argument", + b"", + u"\u00e1".encode("utf-8"), + b"--1234567890", + u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode( + "utf8" + ), + b"", + u"\u00fa".encode("utf-8"), + b"--1234567890--", + b"", + ] + ), + ) data = json_decode(response) self.assertEqual(u"\u00e9", data["header"]) self.assertEqual(u"\u00e1", data["argument"]) @@ -249,9 +295,8 @@ class HTTPConnectionTest(AsyncHTTPTestCase): def test_newlines(self): # We support both CRLF and bare LF as line separators. for newline in (b"\r\n", b"\n"): - response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", - newline=newline) - self.assertEqual(response, b'Hello world') + response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline) + self.assertEqual(response, b"Hello world") @gen_test def test_100_continue(self): @@ -260,19 +305,24 @@ class HTTPConnectionTest(AsyncHTTPTestCase): # headers, and then the real response after the body. stream = IOStream(socket.socket()) yield stream.connect(("127.0.0.1", self.get_http_port())) - yield stream.write(b"\r\n".join([ - b"POST /hello HTTP/1.1", - b"Content-Length: 1024", - b"Expect: 100-continue", - b"Connection: close", - b"\r\n"])) + yield stream.write( + b"\r\n".join( + [ + b"POST /hello HTTP/1.1", + b"Content-Length: 1024", + b"Expect: 100-continue", + b"Connection: close", + b"\r\n", + ] + ) + ) data = yield stream.read_until(b"\r\n\r\n") self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data) stream.write(b"a" * 1024) first_line = yield stream.read_until(b"\r\n") self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line) header_data = yield stream.read_until(b"\r\n\r\n") - headers = HTTPHeaders.parse(native_str(header_data.decode('latin1'))) + headers = HTTPHeaders.parse(native_str(header_data.decode("latin1"))) body = yield stream.read_bytes(int(headers["Content-Length"])) self.assertEqual(body, b"Got 1024 bytes in POST") stream.close() @@ -290,30 +340,32 @@ class TypeCheckHandler(RequestHandler): def prepare(self): self.errors = {} # type: Dict[str, str] fields = [ - ('method', str), - ('uri', str), - ('version', str), - ('remote_ip', str), - ('protocol', str), - ('host', str), - ('path', str), - ('query', str), + ("method", str), + ("uri", str), + ("version", str), + ("remote_ip", str), + ("protocol", str), + ("host", str), + ("path", str), + ("query", str), ] for field, expected_type in fields: self.check_type(field, getattr(self.request, field), expected_type) - self.check_type('header_key', list(self.request.headers.keys())[0], str) - self.check_type('header_value', list(self.request.headers.values())[0], str) + self.check_type("header_key", list(self.request.headers.keys())[0], str) + self.check_type("header_value", list(self.request.headers.values())[0], str) - self.check_type('cookie_key', list(self.request.cookies.keys())[0], str) - self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str) + self.check_type("cookie_key", list(self.request.cookies.keys())[0], str) + self.check_type( + "cookie_value", list(self.request.cookies.values())[0].value, str + ) # secure cookies - self.check_type('arg_key', list(self.request.arguments.keys())[0], str) - self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes) + self.check_type("arg_key", list(self.request.arguments.keys())[0], str) + self.check_type("arg_value", list(self.request.arguments.values())[0][0], bytes) def post(self): - self.check_type('body', self.request.body, bytes) + self.check_type("body", self.request.body, bytes) self.write(self.errors) def get(self): @@ -322,16 +374,18 @@ class TypeCheckHandler(RequestHandler): def check_type(self, name, obj, expected_type): actual_type = type(obj) if expected_type != actual_type: - self.errors[name] = "expected %s, got %s" % (expected_type, - actual_type) + self.errors[name] = "expected %s, got %s" % (expected_type, actual_type) class HTTPServerTest(AsyncHTTPTestCase): def get_app(self): - return Application([("/echo", EchoHandler), - ("/typecheck", TypeCheckHandler), - ("//doubleslash", EchoHandler), - ]) + return Application( + [ + ("/echo", EchoHandler), + ("/typecheck", TypeCheckHandler), + ("//doubleslash", EchoHandler), + ] + ) def test_query_string_encoding(self): response = self.fetch("/echo?foo=%C3%A9") @@ -354,7 +408,9 @@ class HTTPServerTest(AsyncHTTPTestCase): data = json_decode(response.body) self.assertEqual(data, {}) - response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers) + response = self.fetch( + "/typecheck", method="POST", body="foo=bar", headers=headers + ) data = json_decode(response.body) self.assertEqual(data, {}) @@ -369,25 +425,27 @@ class HTTPServerTest(AsyncHTTPTestCase): def test_malformed_body(self): # parse_qs is pretty forgiving, but it will fail on python 3 # if the data is not utf8. - with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'): + with ExpectLog(gen_log, "Invalid x-www-form-urlencoded body"): response = self.fetch( - '/echo', method="POST", - headers={'Content-Type': 'application/x-www-form-urlencoded'}, - body=b'\xe9') + "/echo", + method="POST", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + body=b"\xe9", + ) self.assertEqual(200, response.code) - self.assertEqual(b'{}', response.body) + self.assertEqual(b"{}", response.body) class HTTPServerRawTest(AsyncHTTPTestCase): def get_app(self): - return Application([ - ('/echo', EchoHandler), - ]) + return Application([("/echo", EchoHandler)]) def setUp(self): super(HTTPServerRawTest, self).setUp() self.stream = IOStream(socket.socket()) - self.io_loop.run_sync(lambda: self.stream.connect(('127.0.0.1', self.get_http_port()))) + self.io_loop.run_sync( + lambda: self.stream.connect(("127.0.0.1", self.get_http_port())) + ) def tearDown(self): self.stream.close() @@ -399,34 +457,33 @@ class HTTPServerRawTest(AsyncHTTPTestCase): self.wait() def test_malformed_first_line_response(self): - with ExpectLog(gen_log, '.*Malformed HTTP request line'): - self.stream.write(b'asdf\r\n\r\n') + with ExpectLog(gen_log, ".*Malformed HTTP request line"): + self.stream.write(b"asdf\r\n\r\n") read_stream_body(self.stream, self.stop) start_line, headers, response = self.wait() - self.assertEqual('HTTP/1.1', start_line.version) + self.assertEqual("HTTP/1.1", start_line.version) self.assertEqual(400, start_line.code) - self.assertEqual('Bad Request', start_line.reason) + self.assertEqual("Bad Request", start_line.reason) def test_malformed_first_line_log(self): - with ExpectLog(gen_log, '.*Malformed HTTP request line'): - self.stream.write(b'asdf\r\n\r\n') + with ExpectLog(gen_log, ".*Malformed HTTP request line"): + self.stream.write(b"asdf\r\n\r\n") # TODO: need an async version of ExpectLog so we don't need # hard-coded timeouts here. - self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), - self.stop) + self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) self.wait() def test_malformed_headers(self): - with ExpectLog(gen_log, '.*Malformed HTTP message.*no colon in header line'): - self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n') - self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), - self.stop) + with ExpectLog(gen_log, ".*Malformed HTTP message.*no colon in header line"): + self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n") + self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) self.wait() def test_chunked_request_body(self): # Chunked requests are not widely supported and we don't have a way # to generate them in AsyncHTTPClient, but HTTPServer will read them. - self.stream.write(b"""\ + self.stream.write( + b"""\ POST /echo HTTP/1.1 Transfer-Encoding: chunked Content-Type: application/x-www-form-urlencoded @@ -437,15 +494,19 @@ foo= bar 0 -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) read_stream_body(self.stream, self.stop) start_line, headers, response = self.wait() - self.assertEqual(json_decode(response), {u'foo': [u'bar']}) + self.assertEqual(json_decode(response), {u"foo": [u"bar"]}) def test_chunked_request_uppercase(self): # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is # case-insensitive. - self.stream.write(b"""\ + self.stream.write( + b"""\ POST /echo HTTP/1.1 Transfer-Encoding: Chunked Content-Type: application/x-www-form-urlencoded @@ -456,118 +517,133 @@ foo= bar 0 -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) read_stream_body(self.stream, self.stop) start_line, headers, response = self.wait() - self.assertEqual(json_decode(response), {u'foo': [u'bar']}) + self.assertEqual(json_decode(response), {u"foo": [u"bar"]}) @gen_test def test_invalid_content_length(self): - with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'): - self.stream.write(b"""\ + with ExpectLog(gen_log, ".*Only integer Content-Length is allowed"): + self.stream.write( + b"""\ POST /echo HTTP/1.1 Content-Length: foo bar -""".replace(b"\n", b"\r\n")) +""".replace( + b"\n", b"\r\n" + ) + ) yield self.stream.read_until_close() class XHeaderTest(HandlerBaseTestCase): class Handler(RequestHandler): def get(self): - self.set_header('request-version', self.request.version) - self.write(dict(remote_ip=self.request.remote_ip, - remote_protocol=self.request.protocol)) + self.set_header("request-version", self.request.version) + self.write( + dict( + remote_ip=self.request.remote_ip, + remote_protocol=self.request.protocol, + ) + ) def get_httpserver_options(self): - return dict(xheaders=True, trusted_downstream=['5.5.5.5']) + return dict(xheaders=True, trusted_downstream=["5.5.5.5"]) def test_ip_headers(self): self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1") valid_ipv4 = {"X-Real-IP": "4.4.4.4"} self.assertEqual( - self.fetch_json("/", headers=valid_ipv4)["remote_ip"], - "4.4.4.4") + self.fetch_json("/", headers=valid_ipv4)["remote_ip"], "4.4.4.4" + ) valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"} self.assertEqual( - self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], - "4.4.4.4") + self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "4.4.4.4" + ) valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"} self.assertEqual( self.fetch_json("/", headers=valid_ipv6)["remote_ip"], - "2620:0:1cfe:face:b00c::3") + "2620:0:1cfe:face:b00c::3", + ) valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"} self.assertEqual( self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"], - "2620:0:1cfe:face:b00c::3") + "2620:0:1cfe:face:b00c::3", + ) invalid_chars = {"X-Real-IP": "4.4.4.4 ' - for p in paths) + return "".join( + '' + for p in paths + ) def render_embed_js(self, js_embed: Iterable[bytes]) -> bytes: """Default method used to render the final embedded js for the @@ -877,8 +952,11 @@ class RequestHandler(object): Override this method in a sub-classed controller to change the output. """ - return b'' + return ( + b'" + ) def render_linked_css(self, css_files: Iterable[str]) -> str: """Default method used to render the final css links for the @@ -896,9 +974,11 @@ class RequestHandler(object): paths.append(path) unique_paths.add(path) - return ''.join('' - for p in paths) + return "".join( + '' + for p in paths + ) def render_embed_css(self, css_embed: Iterable[bytes]) -> bytes: """Default method used to render the final embedded css for the @@ -906,8 +986,7 @@ class RequestHandler(object): Override this method in a sub-classed controller to change the output. """ - return b'' + return b'" def render_string(self, template_name: str, **kwargs: Any) -> bytes: """Generate the given template with the given arguments. @@ -953,7 +1032,7 @@ class RequestHandler(object): pgettext=self.locale.pgettext, static_url=self.static_url, xsrf_form_html=self.xsrf_form_html, - reverse_url=self.reverse_url + reverse_url=self.reverse_url, ) namespace.update(self.ui) return namespace @@ -979,7 +1058,7 @@ class RequestHandler(object): kwargs["whitespace"] = settings["template_whitespace"] return template.Loader(template_path, **kwargs) - def flush(self, include_footers: bool=False) -> 'Future[None]': + def flush(self, include_footers: bool = False) -> "Future[None]": """Flushes the current output buffer to the network. The ``callback`` argument, if given, can be used for flow control: @@ -1002,13 +1081,12 @@ class RequestHandler(object): self._headers_written = True for transform in self._transforms: assert chunk is not None - self._status_code, self._headers, chunk = \ - transform.transform_first_chunk( - self._status_code, self._headers, - chunk, include_footers) + self._status_code, self._headers, chunk = transform.transform_first_chunk( + self._status_code, self._headers, chunk, include_footers + ) # Ignore the chunk and only write the headers for HEAD requests if self.request.method == "HEAD": - chunk = b'' + chunk = b"" # Finalize the cookie headers (which have been stored in a side # object so an outgoing cookie could be overwritten before it @@ -1017,11 +1095,10 @@ class RequestHandler(object): for cookie in self._new_cookie.values(): self.add_header("Set-Cookie", cookie.OutputString(None)) - start_line = httputil.ResponseStartLine('', - self._status_code, - self._reason) + start_line = httputil.ResponseStartLine("", self._status_code, self._reason) return self.request.connection.write_headers( - start_line, self._headers, chunk) + start_line, self._headers, chunk + ) else: for transform in self._transforms: chunk = transform.transform_chunk(chunk, include_footers) @@ -1033,7 +1110,7 @@ class RequestHandler(object): future.set_result(None) return future - def finish(self, chunk: Union[str, bytes, dict]=None) -> 'Future[None]': + def finish(self, chunk: Union[str, bytes, dict] = None) -> "Future[None]": """Finishes this response, ending the HTTP request. Passing a ``chunk`` to ``finish()`` is equivalent to passing that @@ -1057,16 +1134,21 @@ class RequestHandler(object): # Automatically support ETags and add the Content-Length header if # we have not flushed any content yet. if not self._headers_written: - if (self._status_code == 200 and - self.request.method in ("GET", "HEAD") and - "Etag" not in self._headers): + if ( + self._status_code == 200 + and self.request.method in ("GET", "HEAD") + and "Etag" not in self._headers + ): self.set_etag_header() if self.check_etag_header(): self._write_buffer = [] self.set_status(304) - if (self._status_code in (204, 304) or - (self._status_code >= 100 and self._status_code < 200)): - assert not self._write_buffer, "Cannot send body with %s" % self._status_code + if self._status_code in (204, 304) or ( + self._status_code >= 100 and self._status_code < 200 + ): + assert not self._write_buffer, ( + "Cannot send body with %s" % self._status_code + ) self._clear_headers_for_304() elif "Content-Length" not in self._headers: content_length = sum(len(part) for part in self._write_buffer) @@ -1107,7 +1189,7 @@ class RequestHandler(object): # _ui_module closures to allow for faster GC on CPython. self.ui = None # type: ignore - def send_error(self, status_code: int=500, **kwargs: Any) -> None: + def send_error(self, status_code: int = 500, **kwargs: Any) -> None: """Sends the given HTTP error code to the browser. If `flush()` has already been called, it is not possible to send @@ -1128,14 +1210,13 @@ class RequestHandler(object): try: self.finish() except Exception: - gen_log.error("Failed to flush partial response", - exc_info=True) + gen_log.error("Failed to flush partial response", exc_info=True) return self.clear() - reason = kwargs.get('reason') - if 'exc_info' in kwargs: - exception = kwargs['exc_info'][1] + reason = kwargs.get("reason") + if "exc_info" in kwargs: + exception = kwargs["exc_info"][1] if isinstance(exception, HTTPError) and exception.reason: reason = exception.reason self.set_status(status_code, reason=reason) @@ -1160,16 +1241,16 @@ class RequestHandler(object): """ if self.settings.get("serve_traceback") and "exc_info" in kwargs: # in debug mode, try to send a traceback - self.set_header('Content-Type', 'text/plain') + self.set_header("Content-Type", "text/plain") for line in traceback.format_exception(*kwargs["exc_info"]): self.write(line) self.finish() else: - self.finish("%(code)d: %(message)s" - "%(code)d: %(message)s" % { - "code": status_code, - "message": self._reason, - }) + self.finish( + "%(code)d: %(message)s" + "%(code)d: %(message)s" + % {"code": status_code, "message": self._reason} + ) @property def locale(self) -> tornado.locale.Locale: @@ -1206,7 +1287,7 @@ class RequestHandler(object): """ return None - def get_browser_locale(self, default: str="en_US") -> tornado.locale.Locale: + def get_browser_locale(self, default: str = "en_US") -> tornado.locale.Locale: """Determines the user's locale from ``Accept-Language`` header. See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4 @@ -1334,19 +1415,24 @@ class RequestHandler(object): self._xsrf_token = binascii.b2a_hex(token) elif output_version == 2: mask = os.urandom(4) - self._xsrf_token = b"|".join([ - b"2", - binascii.b2a_hex(mask), - binascii.b2a_hex(_websocket_mask(mask, token)), - utf8(str(int(timestamp)))]) + self._xsrf_token = b"|".join( + [ + b"2", + binascii.b2a_hex(mask), + binascii.b2a_hex(_websocket_mask(mask, token)), + utf8(str(int(timestamp))), + ] + ) else: - raise ValueError("unknown xsrf cookie version %d", - output_version) + raise ValueError("unknown xsrf cookie version %d", output_version) if version is None: expires_days = 30 if self.current_user else None - self.set_cookie("_xsrf", self._xsrf_token, - expires_days=expires_days, - **cookie_kwargs) + self.set_cookie( + "_xsrf", + self._xsrf_token, + expires_days=expires_days, + **cookie_kwargs + ) return self._xsrf_token def _get_raw_xsrf_token(self) -> Tuple[Optional[int], bytes, float]: @@ -1360,7 +1446,7 @@ class RequestHandler(object): * timestamp: the time this token was generated (will not be accurate for version 1 cookies) """ - if not hasattr(self, '_raw_xsrf_token'): + if not hasattr(self, "_raw_xsrf_token"): cookie = self.get_cookie("_xsrf") if cookie: version, token, timestamp = self._decode_xsrf_token(cookie) @@ -1375,8 +1461,9 @@ class RequestHandler(object): self._raw_xsrf_token = (version, token, timestamp) return self._raw_xsrf_token - def _decode_xsrf_token(self, cookie: str) -> Tuple[ - Optional[int], Optional[bytes], Optional[float]]: + def _decode_xsrf_token( + self, cookie: str + ) -> Tuple[Optional[int], Optional[bytes], Optional[float]]: """Convert a cookie string into a the tuple form returned by _get_raw_xsrf_token. """ @@ -1390,8 +1477,7 @@ class RequestHandler(object): _, mask_str, masked_token, timestamp_str = cookie.split("|") mask = binascii.a2b_hex(utf8(mask_str)) - token = _websocket_mask( - mask, binascii.a2b_hex(utf8(masked_token))) + token = _websocket_mask(mask, binascii.a2b_hex(utf8(masked_token))) timestamp = int(timestamp_str) return version, token, timestamp else: @@ -1408,8 +1494,7 @@ class RequestHandler(object): return (version, token, timestamp) except Exception: # Catch exceptions and return nothing instead of failing. - gen_log.debug("Uncaught exception in _decode_xsrf_token", - exc_info=True) + gen_log.debug("Uncaught exception in _decode_xsrf_token", exc_info=True) return None, None, None def check_xsrf_cookie(self) -> None: @@ -1437,9 +1522,11 @@ class RequestHandler(object): Added support for cookie version 2. Both versions 1 and 2 are supported. """ - token = (self.get_argument("_xsrf", None) or - self.request.headers.get("X-Xsrftoken") or - self.request.headers.get("X-Csrftoken")) + token = ( + self.get_argument("_xsrf", None) + or self.request.headers.get("X-Xsrftoken") + or self.request.headers.get("X-Csrftoken") + ) if not token: raise HTTPError(403, "'_xsrf' argument missing from POST") _, token, _ = self._decode_xsrf_token(token) @@ -1462,10 +1549,13 @@ class RequestHandler(object): See `check_xsrf_cookie()` above for more information. """ - return '' + return ( + '' + ) - def static_url(self, path: str, include_host: bool=None, **kwargs: Any) -> str: + def static_url(self, path: str, include_host: bool = None, **kwargs: Any) -> str: """Returns a static URL for the given relative static file path. This method requires you set the ``static_path`` setting in your @@ -1487,8 +1577,9 @@ class RequestHandler(object): """ self.require_setting("static_path", "static_url") - get_url = self.settings.get("static_handler_class", - StaticFileHandler).make_static_url + get_url = self.settings.get( + "static_handler_class", StaticFileHandler + ).make_static_url if include_host is None: include_host = getattr(self, "include_host", False) @@ -1500,11 +1591,13 @@ class RequestHandler(object): return base + get_url(self.settings, path, **kwargs) - def require_setting(self, name: str, feature: str="this feature") -> None: + def require_setting(self, name: str, feature: str = "this feature") -> None: """Raises an exception if the given app setting is not defined.""" if not self.application.settings.get(name): - raise Exception("You must define the '%s' setting in your " - "application to use %s" % (name, feature)) + raise Exception( + "You must define the '%s' setting in your " + "application to use %s" % (name, feature) + ) def reverse_url(self, name: str, *args: Any) -> str: """Alias for `Application.reverse_url`.""" @@ -1555,19 +1648,18 @@ class RequestHandler(object): # Find all weak and strong etag values from If-None-Match header # because RFC 7232 allows multiple etag values in a single header. etags = re.findall( - br'\*|(?:W/)?"[^"]*"', - utf8(self.request.headers.get("If-None-Match", "")) + br'\*|(?:W/)?"[^"]*"', utf8(self.request.headers.get("If-None-Match", "")) ) if not computed_etag or not etags: return False match = False - if etags[0] == b'*': + if etags[0] == b"*": match = True else: # Use a weak comparison when comparing entity-tags. def val(x: bytes) -> bytes: - return x[2:] if x.startswith(b'W/') else x + return x[2:] if x.startswith(b"W/") else x for etag in etags: if val(etag) == val(computed_etag): @@ -1576,20 +1668,25 @@ class RequestHandler(object): return match @gen.coroutine - def _execute(self, transforms: List['OutputTransform'], *args: bytes, - **kwargs: bytes) -> Generator[Any, Any, None]: + def _execute( + self, transforms: List["OutputTransform"], *args: bytes, **kwargs: bytes + ) -> Generator[Any, Any, None]: """Executes this request with the given output transforms.""" self._transforms = transforms try: if self.request.method not in self.SUPPORTED_METHODS: raise HTTPError(405) self.path_args = [self.decode_argument(arg) for arg in args] - self.path_kwargs = dict((k, self.decode_argument(v, name=k)) - for (k, v) in kwargs.items()) + self.path_kwargs = dict( + (k, self.decode_argument(v, name=k)) for (k, v) in kwargs.items() + ) # If XSRF cookies are turned on, reject form submissions without # the proper cookie - if self.request.method not in ("GET", "HEAD", "OPTIONS") and \ - self.application.settings.get("xsrf_cookies"): + if self.request.method not in ( + "GET", + "HEAD", + "OPTIONS", + ) and self.application.settings.get("xsrf_cookies"): self.check_xsrf_cookie() result = self.prepare() @@ -1626,8 +1723,7 @@ class RequestHandler(object): finally: # Unset result to avoid circular references result = None - if (self._prepared_future is not None and - not self._prepared_future.done()): + if self._prepared_future is not None and not self._prepared_future.done(): # In case we failed before setting _prepared_future, do it # now (to unblock the HTTP server). Note that this is not # in a finally block to avoid GC issues prior to Python 3.4. @@ -1650,8 +1746,11 @@ class RequestHandler(object): self.application.log_request(self) def _request_summary(self) -> str: - return "%s %s (%s)" % (self.request.method, self.request.uri, - self.request.remote_ip) + return "%s %s (%s)" % ( + self.request.method, + self.request.uri, + self.request.remote_ip, + ) def _handle_request_exception(self, e: BaseException) -> None: if isinstance(e, Finish): @@ -1675,9 +1774,12 @@ class RequestHandler(object): else: self.send_error(500, exc_info=sys.exc_info()) - def log_exception(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[TracebackType]) -> None: + def log_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: """Override to customize logging of uncaught exceptions. By default logs instances of `HTTPError` as warnings without @@ -1690,14 +1792,17 @@ class RequestHandler(object): if isinstance(value, HTTPError): if value.log_message: format = "%d %s: " + value.log_message - args = ([value.status_code, self._request_summary()] + - list(value.args)) + args = [value.status_code, self._request_summary()] + list(value.args) gen_log.warning(format, *args) else: - app_log.error("Uncaught exception %s\n%r", self._request_summary(), # type: ignore - self.request, exc_info=(typ, value, tb)) - - def _ui_module(self, name: str, module: Type['UIModule']) -> Callable[..., str]: + app_log.error( # type: ignore + "Uncaught exception %s\n%r", + self._request_summary(), + self.request, + exc_info=(typ, value, tb), + ) + + def _ui_module(self, name: str, module: Type["UIModule"]) -> Callable[..., str]: def render(*args: Any, **kwargs: Any) -> str: if not hasattr(self, "_active_modules"): self._active_modules = {} # type: Dict[str, UIModule] @@ -1705,6 +1810,7 @@ class RequestHandler(object): self._active_modules[name] = module(self) rendered = self._active_modules[name].render(*args, **kwargs) return rendered + return render def _ui_method(self, method: Callable[..., str]) -> Callable[..., str]: @@ -1715,9 +1821,16 @@ class RequestHandler(object): # http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.1) # not explicitly allowed by # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5 - headers = ["Allow", "Content-Encoding", "Content-Language", - "Content-Length", "Content-MD5", "Content-Range", - "Content-Type", "Last-Modified"] + headers = [ + "Allow", + "Content-Encoding", + "Content-Language", + "Content-Length", + "Content-MD5", + "Content-Range", + "Content-Type", + "Last-Modified", + ] for h in headers: self.clear_header(h) @@ -1756,7 +1869,7 @@ def _has_stream_request_body(cls: Type[RequestHandler]) -> bool: def removeslash( - method: Callable[..., Optional[Awaitable[None]]] + method: Callable[..., Optional[Awaitable[None]]] ) -> Callable[..., Optional[Awaitable[None]]]: """Use this decorator to remove trailing slashes from the request path. @@ -1764,8 +1877,11 @@ def removeslash( decorator. Your request handler mapping should use a regular expression like ``r'/foo/*'`` in conjunction with using the decorator. """ + @functools.wraps(method) - def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]: + def wrapper( + self: RequestHandler, *args: Any, **kwargs: Any + ) -> Optional[Awaitable[None]]: if self.request.path.endswith("/"): if self.request.method in ("GET", "HEAD"): uri = self.request.path.rstrip("/") @@ -1777,11 +1893,12 @@ def removeslash( else: raise HTTPError(404) return method(self, *args, **kwargs) + return wrapper def addslash( - method: Callable[..., Optional[Awaitable[None]]] + method: Callable[..., Optional[Awaitable[None]]] ) -> Callable[..., Optional[Awaitable[None]]]: """Use this decorator to add a missing trailing slash to the request path. @@ -1789,8 +1906,11 @@ def addslash( decorator. Your request handler mapping should use a regular expression like ``r'/foo/?'`` in conjunction with using the decorator. """ + @functools.wraps(method) - def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]: + def wrapper( + self: RequestHandler, *args: Any, **kwargs: Any + ) -> Optional[Awaitable[None]]: if not self.request.path.endswith("/"): if self.request.method in ("GET", "HEAD"): uri = self.request.path + "/" @@ -1800,6 +1920,7 @@ def addslash( return None raise HTTPError(404) return method(self, *args, **kwargs) + return wrapper @@ -1814,7 +1935,7 @@ class _ApplicationRouter(ReversibleRuleRouter): `_ApplicationRouter` instance. """ - def __init__(self, application: 'Application', rules: _RuleList=None) -> None: + def __init__(self, application: "Application", rules: _RuleList = None) -> None: assert isinstance(application, Application) self.application = application super(_ApplicationRouter, self).__init__(rules) @@ -1823,16 +1944,23 @@ class _ApplicationRouter(ReversibleRuleRouter): rule = super(_ApplicationRouter, self).process_rule(rule) if isinstance(rule.target, (list, tuple)): - rule.target = _ApplicationRouter(self.application, rule.target) # type: ignore + rule.target = _ApplicationRouter( # type: ignore + self.application, rule.target + ) return rule - def get_target_delegate(self, target: Any, request: httputil.HTTPServerRequest, - **target_params: Any) -> Optional[httputil.HTTPMessageDelegate]: + def get_target_delegate( + self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any + ) -> Optional[httputil.HTTPMessageDelegate]: if isclass(target) and issubclass(target, RequestHandler): - return self.application.get_handler_delegate(request, target, **target_params) + return self.application.get_handler_delegate( + request, target, **target_params + ) - return super(_ApplicationRouter, self).get_target_delegate(target, request, **target_params) + return super(_ApplicationRouter, self).get_target_delegate( + target, request, **target_params + ) class Application(ReversibleRouter): @@ -1918,8 +2046,14 @@ class Application(ReversibleRouter): Integration with the new `tornado.routing` module. """ - def __init__(self, handlers: _RuleList=None, default_host: str=None, - transforms: List[Type['OutputTransform']]=None, **settings: Any) -> None: + + def __init__( + self, + handlers: _RuleList = None, + default_host: str = None, + transforms: List[Type["OutputTransform"]] = None, + **settings: Any + ) -> None: if transforms is None: self.transforms = [] # type: List[Type[OutputTransform]] if settings.get("compress_response") or settings.get("gzip"): @@ -1928,44 +2062,48 @@ class Application(ReversibleRouter): self.transforms = transforms self.default_host = default_host self.settings = settings - self.ui_modules = {'linkify': _linkify, - 'xsrf_form_html': _xsrf_form_html, - 'Template': TemplateModule, - } + self.ui_modules = { + "linkify": _linkify, + "xsrf_form_html": _xsrf_form_html, + "Template": TemplateModule, + } self.ui_methods = {} # type: Dict[str, Callable[..., str]] self._load_ui_modules(settings.get("ui_modules", {})) self._load_ui_methods(settings.get("ui_methods", {})) if self.settings.get("static_path"): path = self.settings["static_path"] handlers = list(handlers or []) - static_url_prefix = settings.get("static_url_prefix", - "/static/") - static_handler_class = settings.get("static_handler_class", - StaticFileHandler) + static_url_prefix = settings.get("static_url_prefix", "/static/") + static_handler_class = settings.get( + "static_handler_class", StaticFileHandler + ) static_handler_args = settings.get("static_handler_args", {}) - static_handler_args['path'] = path - for pattern in [re.escape(static_url_prefix) + r"(.*)", - r"/(favicon\.ico)", r"/(robots\.txt)"]: - handlers.insert(0, (pattern, static_handler_class, - static_handler_args)) - - if self.settings.get('debug'): - self.settings.setdefault('autoreload', True) - self.settings.setdefault('compiled_template_cache', False) - self.settings.setdefault('static_hash_cache', False) - self.settings.setdefault('serve_traceback', True) + static_handler_args["path"] = path + for pattern in [ + re.escape(static_url_prefix) + r"(.*)", + r"/(favicon\.ico)", + r"/(robots\.txt)", + ]: + handlers.insert(0, (pattern, static_handler_class, static_handler_args)) + + if self.settings.get("debug"): + self.settings.setdefault("autoreload", True) + self.settings.setdefault("compiled_template_cache", False) + self.settings.setdefault("static_hash_cache", False) + self.settings.setdefault("serve_traceback", True) self.wildcard_router = _ApplicationRouter(self, handlers) - self.default_router = _ApplicationRouter(self, [ - Rule(AnyMatches(), self.wildcard_router) - ]) + self.default_router = _ApplicationRouter( + self, [Rule(AnyMatches(), self.wildcard_router)] + ) # Automatically reload modified modules - if self.settings.get('autoreload'): + if self.settings.get("autoreload"): from tornado import autoreload + autoreload.start() - def listen(self, port: int, address: str="", **kwargs: Any) -> HTTPServer: + def listen(self, port: int, address: str = "", **kwargs: Any) -> HTTPServer: """Starts an HTTP server for this application on the given port. This is a convenience alias for creating an `.HTTPServer` @@ -2000,31 +2138,31 @@ class Application(ReversibleRouter): self.default_router.rules.insert(-1, rule) if self.default_host is not None: - self.wildcard_router.add_rules([( - DefaultHostMatches(self, host_matcher.host_pattern), - host_handlers - )]) + self.wildcard_router.add_rules( + [(DefaultHostMatches(self, host_matcher.host_pattern), host_handlers)] + ) - def add_transform(self, transform_class: Type['OutputTransform']) -> None: + def add_transform(self, transform_class: Type["OutputTransform"]) -> None: self.transforms.append(transform_class) def _load_ui_methods(self, methods: Any) -> None: if isinstance(methods, types.ModuleType): - self._load_ui_methods(dict((n, getattr(methods, n)) - for n in dir(methods))) + self._load_ui_methods(dict((n, getattr(methods, n)) for n in dir(methods))) elif isinstance(methods, list): for m in methods: self._load_ui_methods(m) else: for name, fn in methods.items(): - if not name.startswith("_") and hasattr(fn, "__call__") \ - and name[0].lower() == name[0]: + if ( + not name.startswith("_") + and hasattr(fn, "__call__") + and name[0].lower() == name[0] + ): self.ui_methods[name] = fn def _load_ui_modules(self, modules: Any) -> None: if isinstance(modules, types.ModuleType): - self._load_ui_modules(dict((n, getattr(modules, n)) - for n in dir(modules))) + self._load_ui_modules(dict((n, getattr(modules, n)) for n in dir(modules))) elif isinstance(modules, list): for m in modules: self._load_ui_modules(m) @@ -2037,31 +2175,37 @@ class Application(ReversibleRouter): except TypeError: pass - def __call__(self, request: httputil.HTTPServerRequest) -> Optional[Awaitable[None]]: + def __call__( + self, request: httputil.HTTPServerRequest + ) -> Optional[Awaitable[None]]: # Legacy HTTPServer interface dispatcher = self.find_handler(request) return dispatcher.execute() - def find_handler(self, request: httputil.HTTPServerRequest, - **kwargs: Any) -> '_HandlerDelegate': + def find_handler( + self, request: httputil.HTTPServerRequest, **kwargs: Any + ) -> "_HandlerDelegate": route = self.default_router.find_handler(request) if route is not None: - return cast('_HandlerDelegate', route) + return cast("_HandlerDelegate", route) - if self.settings.get('default_handler_class'): + if self.settings.get("default_handler_class"): return self.get_handler_delegate( request, - self.settings['default_handler_class'], - self.settings.get('default_handler_args', {})) - - return self.get_handler_delegate( - request, ErrorHandler, {'status_code': 404}) - - def get_handler_delegate(self, request: httputil.HTTPServerRequest, - target_class: Type[RequestHandler], - target_kwargs: Dict[str, Any]=None, - path_args: List[bytes]=None, - path_kwargs: Dict[str, bytes]=None) -> '_HandlerDelegate': + self.settings["default_handler_class"], + self.settings.get("default_handler_args", {}), + ) + + return self.get_handler_delegate(request, ErrorHandler, {"status_code": 404}) + + def get_handler_delegate( + self, + request: httputil.HTTPServerRequest, + target_class: Type[RequestHandler], + target_kwargs: Dict[str, Any] = None, + path_args: List[bytes] = None, + path_kwargs: Dict[str, bytes] = None, + ) -> "_HandlerDelegate": """Returns `~.httputil.HTTPMessageDelegate` that can serve a request for application and `RequestHandler` subclass. @@ -2073,7 +2217,8 @@ class Application(ReversibleRouter): :arg dict path_kwargs: keyword arguments for ``target_class`` HTTP method. """ return _HandlerDelegate( - self, request, target_class, target_kwargs, path_args, path_kwargs) + self, request, target_class, target_kwargs, path_args, path_kwargs + ) def reverse_url(self, name: str, *args: Any) -> str: """Returns a URL path for handler named ``name`` @@ -2108,14 +2253,24 @@ class Application(ReversibleRouter): else: log_method = access_log.error request_time = 1000.0 * handler.request.request_time() - log_method("%d %s %.2fms", handler.get_status(), - handler._request_summary(), request_time) + log_method( + "%d %s %.2fms", + handler.get_status(), + handler._request_summary(), + request_time, + ) class _HandlerDelegate(httputil.HTTPMessageDelegate): - def __init__(self, application: Application, request: httputil.HTTPServerRequest, - handler_class: Type[RequestHandler], handler_kwargs: Optional[Dict[str, Any]], - path_args: Optional[List[bytes]], path_kwargs: Optional[Dict[str, bytes]]) -> None: + def __init__( + self, + application: Application, + request: httputil.HTTPServerRequest, + handler_class: Type[RequestHandler], + handler_kwargs: Optional[Dict[str, Any]], + path_args: Optional[List[bytes]], + path_kwargs: Optional[Dict[str, bytes]], + ) -> None: self.application = application self.connection = request.connection self.request = request @@ -2126,9 +2281,11 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate): self.chunks = [] # type: List[bytes] self.stream_request_body = _has_stream_request_body(self.handler_class) - def headers_received(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> Optional[Awaitable[None]]: if self.stream_request_body: self.request._body_future = Future() return self.execute() @@ -2145,7 +2302,7 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate): if self.stream_request_body: future_set_result_unless_cancelled(self.request._body_future, None) else: - self.request.body = b''.join(self.chunks) + self.request.body = b"".join(self.chunks) self.request._parse_body() self.execute() @@ -2163,11 +2320,12 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate): with RequestHandler._template_loader_lock: for loader in RequestHandler._template_loaders.values(): loader.reset() - if not self.application.settings.get('static_hash_cache', True): + if not self.application.settings.get("static_hash_cache", True): StaticFileHandler.reset() - self.handler = self.handler_class(self.application, self.request, - **self.handler_kwargs) + self.handler = self.handler_class( + self.application, self.request, **self.handler_kwargs + ) transforms = [t(self.request) for t in self.application.transforms] if self.stream_request_body: @@ -2179,8 +2337,7 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate): # except handler, and we cannot easily access the IOLoop here to # call add_future (because of the requirement to remain compatible # with WSGI) - self.handler._execute(transforms, *self.path_args, - **self.path_kwargs) + self.handler._execute(transforms, *self.path_args, **self.path_kwargs) # If we are streaming the request body, then execute() is finished # when the handler has prepared to receive the body. If not, # it doesn't matter when execute() finishes (so we return None) @@ -2209,19 +2366,22 @@ class HTTPError(Exception): determined automatically from ``status_code``, but can be used to use a non-standard numeric code. """ - def __init__(self, status_code: int=500, log_message: str=None, - *args: Any, **kwargs: Any) -> None: + + def __init__( + self, status_code: int = 500, log_message: str = None, *args: Any, **kwargs: Any + ) -> None: self.status_code = status_code self.log_message = log_message self.args = args - self.reason = kwargs.get('reason', None) + self.reason = kwargs.get("reason", None) if log_message and not args: - self.log_message = log_message.replace('%', '%%') + self.log_message = log_message.replace("%", "%%") def __str__(self) -> str: message = "HTTP %d: %s" % ( self.status_code, - self.reason or httputil.responses.get(self.status_code, 'Unknown')) + self.reason or httputil.responses.get(self.status_code, "Unknown"), + ) if self.log_message: return message + " (" + (self.log_message % self.args) + ")" else: @@ -2252,6 +2412,7 @@ class Finish(Exception): Arguments passed to ``Finish()`` will be passed on to `RequestHandler.finish`. """ + pass @@ -2263,14 +2424,17 @@ class MissingArgumentError(HTTPError): .. versionadded:: 3.1 """ + def __init__(self, arg_name: str) -> None: super(MissingArgumentError, self).__init__( - 400, 'Missing argument %s' % arg_name) + 400, "Missing argument %s" % arg_name + ) self.arg_name = arg_name class ErrorHandler(RequestHandler): """Generates an error response with ``status_code`` for all requests.""" + def initialize(self, status_code: int) -> None: # type: ignore self.set_status(status_code) @@ -2316,7 +2480,8 @@ class RedirectHandler(RequestHandler): If any query arguments are present, they will be copied to the destination URL. """ - def initialize(self, url: str, permanent: bool=True) -> None: # type: ignore + + def initialize(self, url: str, permanent: bool = True) -> None: # type: ignore self._url = url self._permanent = permanent @@ -2325,7 +2490,9 @@ class RedirectHandler(RequestHandler): if self.request.query_arguments: # TODO: figure out typing for the next line. to_url = httputil.url_concat( - to_url, list(httputil.qs_to_qsl(self.request.query_arguments))) # type: ignore + to_url, + list(httputil.qs_to_qsl(self.request.query_arguments)), # type: ignore + ) self.redirect(to_url, permanent=self._permanent) @@ -2395,12 +2562,15 @@ class StaticFileHandler(RequestHandler): .. versionchanged:: 3.1 Many of the methods for subclasses were added in Tornado 3.1. """ + CACHE_MAX_AGE = 86400 * 365 * 10 # 10 years _static_hashes = {} # type: Dict[str, Optional[str]] _lock = threading.Lock() # protects _static_hashes - def initialize(self, path: str, default_filename: str=None) -> None: # type: ignore + def initialize( # type: ignore + self, path: str, default_filename: str = None + ) -> None: self.root = path self.default_filename = default_filename @@ -2409,17 +2579,16 @@ class StaticFileHandler(RequestHandler): with cls._lock: cls._static_hashes = {} - def head(self, path: str) -> 'Future[None]': # type: ignore + def head(self, path: str) -> "Future[None]": # type: ignore return self.get(path, include_body=False) @gen.coroutine - def get(self, path: str, include_body: bool=True) -> Generator[Any, Any, None]: + def get(self, path: str, include_body: bool = True) -> Generator[Any, Any, None]: # Set up our path instance variables. self.path = self.parse_url_path(path) del path # make sure we don't refer to path instead of self.path again absolute_path = self.get_absolute_path(self.root, self.path) - self.absolute_path = self.validate_absolute_path( - self.root, absolute_path) + self.absolute_path = self.validate_absolute_path(self.root, absolute_path) if self.absolute_path is None: return @@ -2446,7 +2615,7 @@ class StaticFileHandler(RequestHandler): # content, or when a suffix with length 0 is specified self.set_status(416) # Range Not Satisfiable self.set_header("Content-Type", "text/plain") - self.set_header("Content-Range", "bytes */%s" % (size, )) + self.set_header("Content-Range", "bytes */%s" % (size,)) return if start is not None and start < 0: start += size @@ -2460,8 +2629,9 @@ class StaticFileHandler(RequestHandler): # ``Range: bytes=0-``. if size != (end or size) - (start or 0): self.set_status(206) # Partial Content - self.set_header("Content-Range", - httputil._get_content_range(start, end, size)) + self.set_header( + "Content-Range", httputil._get_content_range(start, end, size) + ) else: start = end = None @@ -2501,7 +2671,7 @@ class StaticFileHandler(RequestHandler): version_hash = self._get_cached_version(self.absolute_path) if not version_hash: return None - return '"%s"' % (version_hash, ) + return '"%s"' % (version_hash,) def set_headers(self) -> None: """Sets the content and caching headers on the response. @@ -2518,11 +2688,12 @@ class StaticFileHandler(RequestHandler): if content_type: self.set_header("Content-Type", content_type) - cache_time = self.get_cache_time(self.path, self.modified, - content_type) + cache_time = self.get_cache_time(self.path, self.modified, content_type) if cache_time > 0: - self.set_header("Expires", datetime.datetime.utcnow() + - datetime.timedelta(seconds=cache_time)) + self.set_header( + "Expires", + datetime.datetime.utcnow() + datetime.timedelta(seconds=cache_time), + ) self.set_header("Cache-Control", "max-age=" + str(cache_time)) self.set_extra_headers(self.path) @@ -2533,7 +2704,7 @@ class StaticFileHandler(RequestHandler): .. versionadded:: 3.1 """ # If client sent If-None-Match, use it, ignore If-Modified-Since - if self.request.headers.get('If-None-Match'): + if self.request.headers.get("If-None-Match"): return self.check_etag_header() # Check the If-Modified-Since, and don't send the result if the @@ -2601,10 +2772,8 @@ class StaticFileHandler(RequestHandler): # The trailing slash also needs to be temporarily added back # the requested path so a request to root/ will match. if not (absolute_path + os.path.sep).startswith(root): - raise HTTPError(403, "%s is not in root static directory", - self.path) - if (os.path.isdir(absolute_path) and - self.default_filename is not None): + raise HTTPError(403, "%s is not in root static directory", self.path) + if os.path.isdir(absolute_path) and self.default_filename is not None: # need to look at the request.path here for when path is empty # but there is some prefix to the path that was already # trimmed by the routing @@ -2619,8 +2788,9 @@ class StaticFileHandler(RequestHandler): return absolute_path @classmethod - def get_content(cls, abspath: str, - start: int=None, end: int=None) -> Generator[bytes, None, None]: + def get_content( + cls, abspath: str, start: int = None, end: int = None + ) -> Generator[bytes, None, None]: """Retrieve the content of the requested resource which is located at the given absolute path. @@ -2676,7 +2846,7 @@ class StaticFileHandler(RequestHandler): def _stat(self) -> os.stat_result: assert self.absolute_path is not None - if not hasattr(self, '_stat_result'): + if not hasattr(self, "_stat_result"): self._stat_result = os.stat(self.absolute_path) return self._stat_result @@ -2711,8 +2881,7 @@ class StaticFileHandler(RequestHandler): # consistency with the past (and because we have a unit test # that relies on this), we truncate the float here, although # I'm not sure that's the right thing to do. - modified = datetime.datetime.utcfromtimestamp( - int(stat_result.st_mtime)) + modified = datetime.datetime.utcfromtimestamp(int(stat_result.st_mtime)) return modified def get_content_type(self) -> str: @@ -2740,8 +2909,9 @@ class StaticFileHandler(RequestHandler): """For subclass to add extra headers to the response""" pass - def get_cache_time(self, path: str, modified: Optional[datetime.datetime], - mime_type: str) -> int: + def get_cache_time( + self, path: str, modified: Optional[datetime.datetime], mime_type: str + ) -> int: """Override to customize cache control behavior. Return a positive number of seconds to make the result @@ -2755,8 +2925,9 @@ class StaticFileHandler(RequestHandler): return self.CACHE_MAX_AGE if "v" in self.request.arguments else 0 @classmethod - def make_static_url(cls, settings: Dict[str, Any], path: str, - include_version: bool=True) -> str: + def make_static_url( + cls, settings: Dict[str, Any], path: str, include_version: bool = True + ) -> str: """Constructs a versioned url for the given path. This method may be overridden in subclasses (but note that it @@ -2775,7 +2946,7 @@ class StaticFileHandler(RequestHandler): file corresponding to the given ``path``. """ - url = settings.get('static_url_prefix', '/static/') + path + url = settings.get("static_url_prefix", "/static/") + path if not include_version: return url @@ -2783,7 +2954,7 @@ class StaticFileHandler(RequestHandler): if not version_hash: return url - return '%s?v=%s' % (url, version_hash) + return "%s?v=%s" % (url, version_hash) def parse_url_path(self, url_path: str) -> str: """Converts a static URL path into a filesystem path. @@ -2812,7 +2983,7 @@ class StaticFileHandler(RequestHandler): `get_content_version` is now preferred as it allows the base class to handle caching of the result. """ - abs_path = cls.get_absolute_path(settings['static_path'], path) + abs_path = cls.get_absolute_path(settings["static_path"], path) return cls._get_cached_version(abs_path) @classmethod @@ -2847,8 +3018,10 @@ class FallbackHandler(RequestHandler): (r".*", FallbackHandler, dict(fallback=wsgi_app), ]) """ - def initialize(self, # type: ignore - fallback: Callable[[httputil.HTTPServerRequest], None]) -> None: + + def initialize( # type: ignore + self, fallback: Callable[[httputil.HTTPServerRequest], None] + ) -> None: self.fallback = fallback def prepare(self) -> None: @@ -2864,12 +3037,16 @@ class OutputTransform(object): or interact with them directly; the framework chooses which transforms (if any) to apply. """ + def __init__(self, request: httputil.HTTPServerRequest) -> None: pass def transform_first_chunk( - self, status_code: int, headers: httputil.HTTPHeaders, - chunk: bytes, finishing: bool + self, + status_code: int, + headers: httputil.HTTPHeaders, + chunk: bytes, + finishing: bool, ) -> Tuple[int, httputil.HTTPHeaders, bytes]: return status_code, headers, chunk @@ -2887,12 +3064,20 @@ class GZipContentEncoding(OutputTransform): of just a whitelist. (the whitelist is still used for certain non-text mime types). """ + # Whitelist of compressible mime types (in addition to any types # beginning with "text/"). - CONTENT_TYPES = set(["application/javascript", "application/x-javascript", - "application/xml", "application/atom+xml", - "application/json", "application/xhtml+xml", - "image/svg+xml"]) + CONTENT_TYPES = set( + [ + "application/javascript", + "application/x-javascript", + "application/xml", + "application/atom+xml", + "application/json", + "application/xhtml+xml", + "image/svg+xml", + ] + ) # Python's GzipFile defaults to level 9, while most other gzip # tools (including gzip itself) default to 6, which is probably a # better CPU/size tradeoff. @@ -2908,27 +3093,33 @@ class GZipContentEncoding(OutputTransform): self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "") def _compressible_type(self, ctype: str) -> bool: - return ctype.startswith('text/') or ctype in self.CONTENT_TYPES + return ctype.startswith("text/") or ctype in self.CONTENT_TYPES def transform_first_chunk( - self, status_code: int, headers: httputil.HTTPHeaders, - chunk: bytes, finishing: bool + self, + status_code: int, + headers: httputil.HTTPHeaders, + chunk: bytes, + finishing: bool, ) -> Tuple[int, httputil.HTTPHeaders, bytes]: # TODO: can/should this type be inherited from the superclass? - if 'Vary' in headers: - headers['Vary'] += ', Accept-Encoding' + if "Vary" in headers: + headers["Vary"] += ", Accept-Encoding" else: - headers['Vary'] = 'Accept-Encoding' + headers["Vary"] = "Accept-Encoding" if self._gzipping: ctype = _unicode(headers.get("Content-Type", "")).split(";")[0] - self._gzipping = self._compressible_type(ctype) and \ - (not finishing or len(chunk) >= self.MIN_LENGTH) and \ - ("Content-Encoding" not in headers) + self._gzipping = ( + self._compressible_type(ctype) + and (not finishing or len(chunk) >= self.MIN_LENGTH) + and ("Content-Encoding" not in headers) + ) if self._gzipping: headers["Content-Encoding"] = "gzip" self._gzip_value = BytesIO() - self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value, - compresslevel=self.GZIP_LEVEL) + self._gzip_file = gzip.GzipFile( + mode="w", fileobj=self._gzip_value, compresslevel=self.GZIP_LEVEL + ) chunk = self.transform_chunk(chunk, finishing) if "Content-Length" in headers: # The original content length is no longer correct. @@ -2955,7 +3146,7 @@ class GZipContentEncoding(OutputTransform): def authenticated( - method: Callable[..., Optional[Awaitable[None]]] + method: Callable[..., Optional[Awaitable[None]]] ) -> Callable[..., Optional[Awaitable[None]]]: """Decorate methods with this to require that the user be logged in. @@ -2967,8 +3158,11 @@ def authenticated( will add a `next` parameter so the login page knows where to send you once you're logged in. """ + @functools.wraps(method) - def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]: + def wrapper( + self: RequestHandler, *args: Any, **kwargs: Any + ) -> Optional[Awaitable[None]]: if not self.current_user: if self.request.method in ("GET", "HEAD"): url = self.get_login_url() @@ -2984,6 +3178,7 @@ def authenticated( return None raise HTTPError(403) return method(self, *args, **kwargs) + return wrapper @@ -2996,6 +3191,7 @@ class UIModule(object): Subclasses of UIModule must override the `render` method. """ + def __init__(self, handler: RequestHandler) -> None: self.handler = handler self.request = handler.request @@ -3078,6 +3274,7 @@ class TemplateModule(UIModule): per instantiation of the template, so they must not depend on any arguments to the template. """ + def __init__(self, handler: RequestHandler) -> None: super(TemplateModule, self).__init__(handler) # keep resources in both a list and a dict to preserve order @@ -3091,11 +3288,13 @@ class TemplateModule(UIModule): self._resource_dict[path] = kwargs else: if self._resource_dict[path] != kwargs: - raise ValueError("set_resources called with different " - "resources for the same template") + raise ValueError( + "set_resources called with different " + "resources for the same template" + ) return "" - return self.render_string(path, set_resources=set_resources, - **kwargs) + + return self.render_string(path, set_resources=set_resources, **kwargs) def _get_resources(self, key: str) -> Iterable[str]: return (r[key] for r in self._resource_list if key in r) @@ -3133,7 +3332,10 @@ class TemplateModule(UIModule): class _UIModuleNamespace(object): """Lazy namespace which creates UIModule proxies bound to a handler.""" - def __init__(self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]]) -> None: + + def __init__( + self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]] + ) -> None: self.handler = handler self.ui_modules = ui_modules @@ -3147,10 +3349,14 @@ class _UIModuleNamespace(object): raise AttributeError(str(e)) -def create_signed_value(secret: _CookieSecretTypes, - name: str, value: Union[str, bytes], - version: int=None, clock: Callable[[], float]=None, - key_version: int=None) -> bytes: +def create_signed_value( + secret: _CookieSecretTypes, + name: str, + value: Union[str, bytes], + version: int = None, + clock: Callable[[], float] = None, + key_version: int = None, +) -> bytes: if version is None: version = DEFAULT_SIGNED_VALUE_VERSION if clock is None: @@ -3180,17 +3386,23 @@ def create_signed_value(secret: _CookieSecretTypes, # - signature (hex-encoded; no length prefix) def format_field(s: Union[str, bytes]) -> bytes: return utf8("%d:" % len(s)) + utf8(s) - to_sign = b"|".join([ - b"2", - format_field(str(key_version or 0)), - format_field(timestamp), - format_field(name), - format_field(value), - b'']) + + to_sign = b"|".join( + [ + b"2", + format_field(str(key_version or 0)), + format_field(timestamp), + format_field(name), + format_field(value), + b"", + ] + ) if isinstance(secret, dict): - assert key_version is not None, 'Key version must be set when sign key dict is used' - assert version >= 2, 'Version must be at least 2 for key version support' + assert ( + key_version is not None + ), "Key version must be set when sign key dict is used" + assert version >= 2, "Version must be at least 2 for key version support" secret = secret[key_version] signature = _create_signature_v2(secret, to_sign) @@ -3227,9 +3439,14 @@ def _get_version(value: bytes) -> int: return version -def decode_signed_value(secret: _CookieSecretTypes, - name: str, value: Union[None, str, bytes], max_age_days: int=31, - clock: Callable[[], float]=None, min_version: int=None) -> Optional[bytes]: +def decode_signed_value( + secret: _CookieSecretTypes, + name: str, + value: Union[None, str, bytes], + max_age_days: int = 31, + clock: Callable[[], float] = None, + min_version: int = None, +) -> Optional[bytes]: if clock is None: clock = time.time if min_version is None: @@ -3246,17 +3463,20 @@ def decode_signed_value(secret: _CookieSecretTypes, return None if version == 1: assert not isinstance(secret, dict) - return _decode_signed_value_v1(secret, name, value, - max_age_days, clock) + return _decode_signed_value_v1(secret, name, value, max_age_days, clock) elif version == 2: - return _decode_signed_value_v2(secret, name, value, - max_age_days, clock) + return _decode_signed_value_v2(secret, name, value, max_age_days, clock) else: return None -def _decode_signed_value_v1(secret: Union[str, bytes], name: str, value: bytes, max_age_days: int, - clock: Callable[[], float]) -> Optional[bytes]: +def _decode_signed_value_v1( + secret: Union[str, bytes], + name: str, + value: bytes, + max_age_days: int, + clock: Callable[[], float], +) -> Optional[bytes]: parts = utf8(value).split(b"|") if len(parts) != 3: return None @@ -3274,8 +3494,7 @@ def _decode_signed_value_v1(secret: Union[str, bytes], name: str, value: bytes, # digits from the payload to the timestamp without altering the # signature. For backwards compatibility, sanity-check timestamp # here instead of modifying _cookie_signature. - gen_log.warning("Cookie timestamp in future; possible tampering %r", - value) + gen_log.warning("Cookie timestamp in future; possible tampering %r", value) return None if parts[1].startswith(b"0"): gen_log.warning("Tampered cookie %r", value) @@ -3288,14 +3507,14 @@ def _decode_signed_value_v1(secret: Union[str, bytes], name: str, value: bytes, def _decode_fields_v2(value: bytes) -> Tuple[int, bytes, bytes, bytes, bytes]: def _consume_field(s: bytes) -> Tuple[bytes, bytes]: - length, _, rest = s.partition(b':') + length, _, rest = s.partition(b":") n = int(length) field_value = rest[:n] # In python 3, indexing bytes returns small integers; we must # use a slice to get a byte string as in python 2. - if rest[n:n + 1] != b'|': + if rest[n : n + 1] != b"|": raise ValueError("malformed v2 signed value field") - rest = rest[n + 1:] + rest = rest[n + 1 :] return field_value, rest rest = value[2:] # remove version number @@ -3306,14 +3525,20 @@ def _decode_fields_v2(value: bytes) -> Tuple[int, bytes, bytes, bytes, bytes]: return int(key_version), timestamp, name_field, value_field, passed_sig -def _decode_signed_value_v2(secret: _CookieSecretTypes, - name: str, value: bytes, max_age_days: int, - clock: Callable[[], float]) -> Optional[bytes]: +def _decode_signed_value_v2( + secret: _CookieSecretTypes, + name: str, + value: bytes, + max_age_days: int, + clock: Callable[[], float], +) -> Optional[bytes]: try: - key_version, timestamp_bytes, name_field, value_field, passed_sig = _decode_fields_v2(value) + key_version, timestamp_bytes, name_field, value_field, passed_sig = _decode_fields_v2( + value + ) except ValueError: return None - signed_string = value[:-len(passed_sig)] + signed_string = value[: -len(passed_sig)] if isinstance(secret, dict): try: diff --git a/tornado/websocket.py b/tornado/websocket.py index 6600d8ca4..d100b9805 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -38,9 +38,22 @@ from tornado.queues import Queue from tornado.tcpclient import TCPClient from tornado.util import _websocket_mask -from typing import (TYPE_CHECKING, cast, Any, Optional, Dict, Union, List, Awaitable, - Callable, Generator, Tuple, Type) +from typing import ( + TYPE_CHECKING, + cast, + Any, + Optional, + Dict, + Union, + List, + Awaitable, + Callable, + Generator, + Tuple, + Type, +) from types import TracebackType + if TYPE_CHECKING: from tornado.iostream import IOStream # noqa: F401 from typing_extensions import Protocol @@ -55,7 +68,7 @@ if TYPE_CHECKING: pass class _Decompressor(Protocol): - unconsumed_tail = b'' # type: bytes + unconsumed_tail = b"" # type: bytes def decompress(self, data: bytes, max_length: int) -> bytes: pass @@ -96,7 +109,7 @@ if TYPE_CHECKING: def close_reason(self, value: Optional[str]) -> None: pass - def on_message(self, message: Union[str, bytes]) -> Optional['Awaitable[None]']: + def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]: pass def on_ping(self, data: bytes) -> None: @@ -105,9 +118,12 @@ if TYPE_CHECKING: def on_pong(self, data: bytes) -> None: pass - def log_exception(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[TracebackType]) -> None: + def log_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: pass @@ -123,6 +139,7 @@ class WebSocketClosedError(WebSocketError): .. versionadded:: 3.2 """ + pass @@ -211,8 +228,13 @@ class WebSocketHandler(tornado.web.RequestHandler): Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and ``websocket_max_message_size``. """ - def __init__(self, application: tornado.web.Application, request: httputil.HTTPServerRequest, - **kwargs: Any) -> None: + + def __init__( + self, + application: tornado.web.Application, + request: httputil.HTTPServerRequest, + **kwargs: Any + ) -> None: super(WebSocketHandler, self).__init__(application, request, **kwargs) self.ws_connection = None # type: Optional[WebSocketProtocol] self.close_code = None # type: Optional[int] @@ -225,9 +247,9 @@ class WebSocketHandler(tornado.web.RequestHandler): self.open_kwargs = kwargs # Upgrade header should be present and should be equal to WebSocket - if self.request.headers.get("Upgrade", "").lower() != 'websocket': + if self.request.headers.get("Upgrade", "").lower() != "websocket": self.set_status(400) - log_msg = "Can \"Upgrade\" only to \"WebSocket\"." + log_msg = 'Can "Upgrade" only to "WebSocket".' self.finish(log_msg) gen_log.debug(log_msg) return @@ -236,11 +258,12 @@ class WebSocketHandler(tornado.web.RequestHandler): # Some proxy servers/load balancers # might mess with it. headers = self.request.headers - connection = map(lambda s: s.strip().lower(), - headers.get("Connection", "").split(",")) - if 'upgrade' not in connection: + connection = map( + lambda s: s.strip().lower(), headers.get("Connection", "").split(",") + ) + if "upgrade" not in connection: self.set_status(400) - log_msg = "\"Connection\" must be \"Upgrade\"." + log_msg = '"Connection" must be "Upgrade".' self.finish(log_msg) gen_log.debug(log_msg) return @@ -280,7 +303,7 @@ class WebSocketHandler(tornado.web.RequestHandler): Set websocket_ping_interval = 0 to disable pings. """ - return self.settings.get('websocket_ping_interval', None) + return self.settings.get("websocket_ping_interval", None) @property def ping_timeout(self) -> Optional[float]: @@ -288,7 +311,7 @@ class WebSocketHandler(tornado.web.RequestHandler): close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). Default is max of 3 pings or 30 seconds. """ - return self.settings.get('websocket_ping_timeout', None) + return self.settings.get("websocket_ping_timeout", None) @property def max_message_size(self) -> int: @@ -299,10 +322,13 @@ class WebSocketHandler(tornado.web.RequestHandler): Default is 10MiB. """ - return self.settings.get('websocket_max_message_size', _default_max_message_size) + return self.settings.get( + "websocket_max_message_size", _default_max_message_size + ) - def write_message(self, message: Union[bytes, str, Dict[str, Any]], - binary: bool=False) -> 'Future[None]': + def write_message( + self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False + ) -> "Future[None]": """Sends the given message to the client of this Web Socket. The message may be either a string or a dict (which will be @@ -415,7 +441,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ raise NotImplementedError - def ping(self, data: Union[str, bytes]=b'') -> None: + def ping(self, data: Union[str, bytes] = b"") -> None: """Send ping frame to the remote end. The data argument allows a small amount of data (up to 125 @@ -457,7 +483,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ pass - def close(self, code: int=None, reason: str=None) -> None: + def close(self, code: int = None, reason: str = None) -> None: """Closes this Web Socket. Once the close handshake is successful the socket will be closed. @@ -578,19 +604,27 @@ class WebSocketHandler(tornado.web.RequestHandler): # we can close the connection more gracefully. self.stream.close() - def get_websocket_protocol(self) -> Optional['WebSocketProtocol']: + def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]: websocket_version = self.request.headers.get("Sec-WebSocket-Version") if websocket_version in ("7", "8", "13"): return WebSocketProtocol13( - self, compression_options=self.get_compression_options()) + self, compression_options=self.get_compression_options() + ) return None def _attach_stream(self) -> None: self.stream = self.detach() self.stream.set_close_callback(self.on_connection_close) # disable non-WS methods - for method in ["write", "redirect", "set_header", "set_cookie", - "set_status", "flush", "finish"]: + for method in [ + "write", + "redirect", + "set_header", + "set_cookie", + "set_status", + "flush", + "finish", + ]: setattr(self, method, _raise_not_supported_for_websockets) @@ -601,14 +635,16 @@ def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None: class WebSocketProtocol(abc.ABC): """Base class for WebSocket protocol versions. """ - def __init__(self, handler: '_WebSocketConnection') -> None: + + def __init__(self, handler: "_WebSocketConnection") -> None: self.handler = handler self.stream = handler.stream self.client_terminated = False self.server_terminated = False - def _run_callback(self, callback: Callable, - *args: Any, **kwargs: Any) -> Optional['Future[Any]']: + def _run_callback( + self, callback: Callable, *args: Any, **kwargs: Any + ) -> Optional["Future[Any]"]: """Runs the given callback with exception handling. If the callback is a coroutine, returns its Future. On error, aborts the @@ -639,7 +675,7 @@ class WebSocketProtocol(abc.ABC): self.close() # let the subclass cleanup @abc.abstractmethod - def close(self, code: int=None, reason: str=None) -> None: + def close(self, code: int = None, reason: str = None) -> None: raise NotImplementedError() @abc.abstractmethod @@ -651,7 +687,9 @@ class WebSocketProtocol(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]': + def write_message( + self, message: Union[str, bytes], binary: bool = False + ) -> "Future[None]": raise NotImplementedError() @property @@ -668,8 +706,9 @@ class WebSocketProtocol(abc.ABC): # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13 # boundary is currently pretty ad-hoc. @abc.abstractmethod - def _process_server_headers(self, key: Union[str, bytes], - headers: httputil.HTTPHeaders) -> None: + def _process_server_headers( + self, key: Union[str, bytes], headers: httputil.HTTPHeaders + ) -> None: raise NotImplementedError() @abc.abstractmethod @@ -677,69 +716,91 @@ class WebSocketProtocol(abc.ABC): raise NotImplementedError() @abc.abstractmethod - def _receive_frame_loop(self) -> 'Future[None]': + def _receive_frame_loop(self) -> "Future[None]": raise NotImplementedError() class _PerMessageDeflateCompressor(object): - def __init__(self, persistent: bool, max_wbits: Optional[int], - compression_options: Dict[str, Any]=None) -> None: + def __init__( + self, + persistent: bool, + max_wbits: Optional[int], + compression_options: Dict[str, Any] = None, + ) -> None: if max_wbits is None: max_wbits = zlib.MAX_WBITS # There is no symbolic constant for the minimum wbits value. if not (8 <= max_wbits <= zlib.MAX_WBITS): - raise ValueError("Invalid max_wbits value %r; allowed range 8-%d", - max_wbits, zlib.MAX_WBITS) + raise ValueError( + "Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, + zlib.MAX_WBITS, + ) self._max_wbits = max_wbits - if compression_options is None or 'compression_level' not in compression_options: + if ( + compression_options is None + or "compression_level" not in compression_options + ): self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL else: - self._compression_level = compression_options['compression_level'] + self._compression_level = compression_options["compression_level"] - if compression_options is None or 'mem_level' not in compression_options: + if compression_options is None or "mem_level" not in compression_options: self._mem_level = 8 else: - self._mem_level = compression_options['mem_level'] + self._mem_level = compression_options["mem_level"] if persistent: self._compressor = self._create_compressor() # type: Optional[_Compressor] else: self._compressor = None - def _create_compressor(self) -> '_Compressor': - return zlib.compressobj(self._compression_level, - zlib.DEFLATED, -self._max_wbits, self._mem_level) + def _create_compressor(self) -> "_Compressor": + return zlib.compressobj( + self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level + ) def compress(self, data: bytes) -> bytes: compressor = self._compressor or self._create_compressor() - data = (compressor.compress(data) + - compressor.flush(zlib.Z_SYNC_FLUSH)) - assert data.endswith(b'\x00\x00\xff\xff') + data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH) + assert data.endswith(b"\x00\x00\xff\xff") return data[:-4] class _PerMessageDeflateDecompressor(object): - def __init__(self, persistent: bool, max_wbits: Optional[int], max_message_size: int, - compression_options: Dict[str, Any]=None) -> None: + def __init__( + self, + persistent: bool, + max_wbits: Optional[int], + max_message_size: int, + compression_options: Dict[str, Any] = None, + ) -> None: self._max_message_size = max_message_size if max_wbits is None: max_wbits = zlib.MAX_WBITS if not (8 <= max_wbits <= zlib.MAX_WBITS): - raise ValueError("Invalid max_wbits value %r; allowed range 8-%d", - max_wbits, zlib.MAX_WBITS) + raise ValueError( + "Invalid max_wbits value %r; allowed range 8-%d", + max_wbits, + zlib.MAX_WBITS, + ) self._max_wbits = max_wbits if persistent: - self._decompressor = self._create_decompressor() # type: Optional[_Decompressor] + self._decompressor = ( + self._create_decompressor() + ) # type: Optional[_Decompressor] else: self._decompressor = None - def _create_decompressor(self) -> '_Decompressor': + def _create_decompressor(self) -> "_Decompressor": return zlib.decompressobj(-self._max_wbits) def decompress(self, data: bytes) -> bytes: decompressor = self._decompressor or self._create_decompressor() - result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size) + result = decompressor.decompress( + data + b"\x00\x00\xff\xff", self._max_message_size + ) if decompressor.unconsumed_tail: raise _DecompressTooLargeError() return result @@ -751,18 +812,23 @@ class WebSocketProtocol13(WebSocketProtocol): This class supports versions 7 and 8 of the protocol in addition to the final version 13. """ + # Bit masks for the first byte of a frame. FIN = 0x80 RSV1 = 0x40 RSV2 = 0x20 RSV3 = 0x10 RSV_MASK = RSV1 | RSV2 | RSV3 - OPCODE_MASK = 0x0f + OPCODE_MASK = 0x0F stream = None # type: IOStream - def __init__(self, handler: '_WebSocketConnection', mask_outgoing: bool=False, - compression_options: Dict[str, Any]=None) -> None: + def __init__( + self, + handler: "_WebSocketConnection", + mask_outgoing: bool = False, + compression_options: Dict[str, Any] = None, + ) -> None: WebSocketProtocol.__init__(self, handler) self.mask_outgoing = mask_outgoing self._final_frame = False @@ -812,8 +878,7 @@ class WebSocketProtocol13(WebSocketProtocol): try: self._accept_connection(handler) except ValueError: - gen_log.debug("Malformed WebSocket request received", - exc_info=True) + gen_log.debug("Malformed WebSocket request received", exc_info=True) self._abort() return @@ -839,13 +904,16 @@ class WebSocketProtocol13(WebSocketProtocol): def _challenge_response(self, handler: WebSocketHandler) -> str: return WebSocketProtocol13.compute_accept_value( - cast(str, handler.request.headers.get("Sec-Websocket-Key"))) + cast(str, handler.request.headers.get("Sec-Websocket-Key")) + ) @gen.coroutine - def _accept_connection(self, handler: WebSocketHandler) -> Generator[Any, Any, None]: + def _accept_connection( + self, handler: WebSocketHandler + ) -> Generator[Any, Any, None]: subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") if subprotocol_header: - subprotocols = [s.strip() for s in subprotocol_header.split(',')] + subprotocols = [s.strip() for s in subprotocol_header.split(",")] else: subprotocols = [] self.selected_subprotocol = handler.select_subprotocol(subprotocols) @@ -855,19 +923,21 @@ class WebSocketProtocol13(WebSocketProtocol): extensions = self._parse_extensions_header(handler.request.headers) for ext in extensions: - if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): + if ext[0] == "permessage-deflate" and self._compression_options is not None: # TODO: negotiate parameters if compression_options # specifies limits. - self._create_compressors('server', ext[1], self._compression_options) - if ('client_max_window_bits' in ext[1] and - ext[1]['client_max_window_bits'] is None): + self._create_compressors("server", ext[1], self._compression_options) + if ( + "client_max_window_bits" in ext[1] + and ext[1]["client_max_window_bits"] is None + ): # Don't echo an offered client_max_window_bits # parameter with no value. - del ext[1]['client_max_window_bits'] - handler.set_header("Sec-WebSocket-Extensions", - httputil._encode_header( - 'permessage-deflate', ext[1])) + del ext[1]["client_max_window_bits"] + handler.set_header( + "Sec-WebSocket-Extensions", + httputil._encode_header("permessage-deflate", ext[1]), + ) break handler.clear_header("Content-Type") @@ -882,75 +952,94 @@ class WebSocketProtocol13(WebSocketProtocol): self.stream = handler.stream self.start_pinging() - open_result = self._run_callback(handler.open, *handler.open_args, - **handler.open_kwargs) + open_result = self._run_callback( + handler.open, *handler.open_args, **handler.open_kwargs + ) if open_result is not None: yield open_result yield self._receive_frame_loop() def _parse_extensions_header( - self, headers: httputil.HTTPHeaders + self, headers: httputil.HTTPHeaders ) -> List[Tuple[str, Dict[str, str]]]: - extensions = headers.get("Sec-WebSocket-Extensions", '') + extensions = headers.get("Sec-WebSocket-Extensions", "") if extensions: - return [httputil._parse_header(e.strip()) - for e in extensions.split(',')] + return [httputil._parse_header(e.strip()) for e in extensions.split(",")] return [] - def _process_server_headers(self, key: Union[str, bytes], - headers: httputil.HTTPHeaders) -> None: + def _process_server_headers( + self, key: Union[str, bytes], headers: httputil.HTTPHeaders + ) -> None: """Process the headers sent by the server to this client connection. 'key' is the websocket handshake challenge/response key. """ - assert headers['Upgrade'].lower() == 'websocket' - assert headers['Connection'].lower() == 'upgrade' + assert headers["Upgrade"].lower() == "websocket" + assert headers["Connection"].lower() == "upgrade" accept = self.compute_accept_value(key) - assert headers['Sec-Websocket-Accept'] == accept + assert headers["Sec-Websocket-Accept"] == accept extensions = self._parse_extensions_header(headers) for ext in extensions: - if (ext[0] == 'permessage-deflate' and - self._compression_options is not None): - self._create_compressors('client', ext[1]) + if ext[0] == "permessage-deflate" and self._compression_options is not None: + self._create_compressors("client", ext[1]) else: raise ValueError("unsupported extension %r", ext) - self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None) + self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None) - def _get_compressor_options(self, side: str, agreed_parameters: Dict[str, Any], - compression_options: Dict[str, Any]=None) -> Dict[str, Any]: + def _get_compressor_options( + self, + side: str, + agreed_parameters: Dict[str, Any], + compression_options: Dict[str, Any] = None, + ) -> Dict[str, Any]: """Converts a websocket agreed_parameters set to keyword arguments for our compressor objects. """ - options = dict(persistent=(side + '_no_context_takeover') not in agreed_parameters) \ - # type: Dict[str, Any] - wbits_header = agreed_parameters.get(side + '_max_window_bits', None) + options = dict( + persistent=(side + "_no_context_takeover") not in agreed_parameters + ) # type: Dict[str, Any] + wbits_header = agreed_parameters.get(side + "_max_window_bits", None) if wbits_header is None: - options['max_wbits'] = zlib.MAX_WBITS + options["max_wbits"] = zlib.MAX_WBITS else: - options['max_wbits'] = int(wbits_header) - options['compression_options'] = compression_options + options["max_wbits"] = int(wbits_header) + options["compression_options"] = compression_options return options - def _create_compressors(self, side: str, agreed_parameters: Dict[str, Any], - compression_options: Dict[str, Any]=None) -> None: + def _create_compressors( + self, + side: str, + agreed_parameters: Dict[str, Any], + compression_options: Dict[str, Any] = None, + ) -> None: # TODO: handle invalid parameters gracefully - allowed_keys = set(['server_no_context_takeover', - 'client_no_context_takeover', - 'server_max_window_bits', - 'client_max_window_bits']) + allowed_keys = set( + [ + "server_no_context_takeover", + "client_no_context_takeover", + "server_max_window_bits", + "client_max_window_bits", + ] + ) for key in agreed_parameters: if key not in allowed_keys: raise ValueError("unsupported compression parameter %r" % key) - other_side = 'client' if (side == 'server') else 'server' + other_side = "client" if (side == "server") else "server" self._compressor = _PerMessageDeflateCompressor( - **self._get_compressor_options(side, agreed_parameters, compression_options)) + **self._get_compressor_options(side, agreed_parameters, compression_options) + ) self._decompressor = _PerMessageDeflateDecompressor( max_message_size=self.handler.max_message_size, - **self._get_compressor_options(other_side, agreed_parameters, compression_options)) - - def _write_frame(self, fin: bool, opcode: int, data: bytes, flags: int=0) -> 'Future[None]': + **self._get_compressor_options( + other_side, agreed_parameters, compression_options + ) + ) + + def _write_frame( + self, fin: bool, opcode: int, data: bytes, flags: int = 0 + ) -> "Future[None]": data_len = len(data) if opcode & 0x8: # All control frames MUST have a payload length of 125 @@ -981,7 +1070,9 @@ class WebSocketProtocol13(WebSocketProtocol): self._wire_bytes_out += len(frame) return self.stream.write(frame) - def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]': + def write_message( + self, message: Union[str, bytes], binary: bool = False + ) -> "Future[None]": """Sends the given message to the client of this Web Socket.""" if binary: opcode = 0x2 @@ -1010,6 +1101,7 @@ class WebSocketProtocol13(WebSocketProtocol): yield fut except StreamClosedError: raise WebSocketClosedError() + return wrapper() def write_ping(self, data: bytes) -> None: @@ -1049,7 +1141,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() return is_masked = bool(mask_payloadlen & 0x80) - payloadlen = mask_payloadlen & 0x7f + payloadlen = mask_payloadlen & 0x7F # Parse and validate the length. if opcode_is_control and payloadlen >= 126: @@ -1113,7 +1205,7 @@ class WebSocketProtocol13(WebSocketProtocol): if handled_future is not None: yield handled_future - def _handle_message(self, opcode: int, data: bytes) -> Optional['Future[None]']: + def _handle_message(self, opcode: int, data: bytes) -> Optional["Future[None]"]: """Execute on_message, returning its Future if it is a coroutine.""" if self.client_terminated: return None @@ -1144,7 +1236,7 @@ class WebSocketProtocol13(WebSocketProtocol): # Close self.client_terminated = True if len(data) >= 2: - self.handler.close_code = struct.unpack('>H', data[:2])[0] + self.handler.close_code = struct.unpack(">H", data[:2])[0] if len(data) > 2: self.handler.close_reason = to_unicode(data[2:]) # Echo the received close code, if any (RFC 6455 section 5.5.1). @@ -1164,16 +1256,16 @@ class WebSocketProtocol13(WebSocketProtocol): self._abort() return None - def close(self, code: int=None, reason: str=None) -> None: + def close(self, code: int = None, reason: str = None) -> None: """Closes the WebSocket connection.""" if not self.server_terminated: if not self.stream.closed(): if code is None and reason is not None: code = 1000 # "normal closure" status code if code is None: - close_data = b'' + close_data = b"" else: - close_data = struct.pack('>H', code) + close_data = struct.pack(">H", code) if reason is not None: close_data += utf8(reason) try: @@ -1190,7 +1282,8 @@ class WebSocketProtocol13(WebSocketProtocol): # Give the client a few seconds to complete a clean shutdown, # otherwise just close the connection. self._waiting = self.stream.io_loop.add_timeout( - self.stream.io_loop.time() + 5, self._abort) + self.stream.io_loop.time() + 5, self._abort + ) def is_closing(self) -> bool: """Return true if this connection is closing. @@ -1199,9 +1292,7 @@ class WebSocketProtocol13(WebSocketProtocol): initiated its closing handshake or if the stream has been shut down uncleanly. """ - return (self.stream.closed() or - self.client_terminated or - self.server_terminated) + return self.stream.closed() or self.client_terminated or self.server_terminated @property def ping_interval(self) -> Optional[float]: @@ -1224,7 +1315,8 @@ class WebSocketProtocol13(WebSocketProtocol): if self.ping_interval > 0: self.last_ping = self.last_pong = IOLoop.current().time() self.ping_callback = PeriodicCallback( - self.periodic_ping, self.ping_interval * 1000) + self.periodic_ping, self.ping_interval * 1000 + ) self.ping_callback.start() def periodic_ping(self) -> None: @@ -1244,12 +1336,14 @@ class WebSocketProtocol13(WebSocketProtocol): since_last_ping = now - self.last_ping assert self.ping_interval is not None assert self.ping_timeout is not None - if (since_last_ping < 2 * self.ping_interval and - since_last_pong > self.ping_timeout): + if ( + since_last_ping < 2 * self.ping_interval + and since_last_pong > self.ping_timeout + ): self.close() return - self.write_ping(b'') + self.write_ping(b"") self.last_ping = now @@ -1259,14 +1353,19 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): This class should not be instantiated directly; use the `websocket_connect` function instead. """ + protocol = None # type: WebSocketProtocol - def __init__(self, request: httpclient.HTTPRequest, - on_message_callback: Callable[[Union[None, str, bytes]], None]=None, - compression_options: Dict[str, Any]=None, - ping_interval: float=None, ping_timeout: float=None, - max_message_size: int=_default_max_message_size, - subprotocols: Optional[List[str]]=[]) -> None: + def __init__( + self, + request: httpclient.HTTPRequest, + on_message_callback: Callable[[Union[None, str, bytes]], None] = None, + compression_options: Dict[str, Any] = None, + ping_interval: float = None, + ping_timeout: float = None, + max_message_size: int = _default_max_message_size, + subprotocols: Optional[List[str]] = [], + ) -> None: self.compression_options = compression_options self.connect_future = Future() # type: Future[WebSocketClientConnection] self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] @@ -1278,32 +1377,42 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.ping_timeout = ping_timeout self.max_message_size = max_message_size - scheme, sep, rest = request.url.partition(':') - scheme = {'ws': 'http', 'wss': 'https'}[scheme] + scheme, sep, rest = request.url.partition(":") + scheme = {"ws": "http", "wss": "https"}[scheme] request.url = scheme + sep + rest - request.headers.update({ - 'Upgrade': 'websocket', - 'Connection': 'Upgrade', - 'Sec-WebSocket-Key': self.key, - 'Sec-WebSocket-Version': '13', - }) + request.headers.update( + { + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": self.key, + "Sec-WebSocket-Version": "13", + } + ) if subprotocols is not None: - request.headers['Sec-WebSocket-Protocol'] = ','.join(subprotocols) + request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols) if self.compression_options is not None: # Always offer to let the server set our max_wbits (and even though # we don't offer it, we will accept a client_no_context_takeover # from the server). # TODO: set server parameters for deflate extension # if requested in self.compression_options. - request.headers['Sec-WebSocket-Extensions'] = ( - 'permessage-deflate; client_max_window_bits') + request.headers[ + "Sec-WebSocket-Extensions" + ] = "permessage-deflate; client_max_window_bits" self.tcp_client = TCPClient() super(WebSocketClientConnection, self).__init__( - None, request, lambda: None, self._on_http_response, - 104857600, self.tcp_client, 65536, 104857600) - - def close(self, code: int=None, reason: str=None) -> None: + None, + request, + lambda: None, + self._on_http_response, + 104857600, + self.tcp_client, + 65536, + 104857600, + ) + + def close(self, code: int = None, reason: str = None) -> None: """Closes the websocket connection. ``code`` and ``reason`` are documented under @@ -1331,16 +1440,20 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): if response.error: self.connect_future.set_exception(response.error) else: - self.connect_future.set_exception(WebSocketError( - "Non-websocket response")) - - def headers_received(self, start_line: Union[httputil.RequestStartLine, - httputil.ResponseStartLine], - headers: httputil.HTTPHeaders) -> None: + self.connect_future.set_exception( + WebSocketError("Non-websocket response") + ) + + def headers_received( + self, + start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], + headers: httputil.HTTPHeaders, + ) -> None: assert isinstance(start_line, httputil.ResponseStartLine) if start_line.code != 101: return super(WebSocketClientConnection, self).headers_received( - start_line, headers) + start_line, headers + ) self.headers = headers self.protocol = self.get_websocket_protocol() @@ -1362,7 +1475,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): future_set_result_unless_cancelled(self.connect_future, self) - def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]': + def write_message( + self, message: Union[str, bytes], binary: bool = False + ) -> "Future[None]": """Sends a message to the WebSocket server. If the stream is closed, raises `WebSocketClosedError`. @@ -1375,8 +1490,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): return self.protocol.write_message(message, binary=binary) def read_message( - self, callback: Callable[['Future[Union[None, str, bytes]]'], None]=None - ) -> 'Future[Union[None, str, bytes]]': + self, callback: Callable[["Future[Union[None, str, bytes]]"], None] = None + ) -> "Future[Union[None, str, bytes]]": """Reads a message from the WebSocket server. If on_message_callback was specified at WebSocket @@ -1393,17 +1508,17 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.io_loop.add_future(future, callback) return future - def on_message(self, message: Union[str, bytes]) -> Optional['Future[None]']: + def on_message(self, message: Union[str, bytes]) -> Optional["Future[None]"]: return self._on_message(message) - def _on_message(self, message: Union[None, str, bytes]) -> Optional['Future[None]']: + def _on_message(self, message: Union[None, str, bytes]) -> Optional["Future[None]"]: if self._on_message_callback: self._on_message_callback(message) return None else: return self.read_queue.put(message) - def ping(self, data: bytes=b'') -> None: + def ping(self, data: bytes = b"") -> None: """Send ping frame to the remote end. The data argument allows a small amount of data (up to 125 @@ -1429,8 +1544,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): pass def get_websocket_protocol(self) -> WebSocketProtocol: - return WebSocketProtocol13(self, mask_outgoing=True, - compression_options=self.compression_options) + return WebSocketProtocol13( + self, mask_outgoing=True, compression_options=self.compression_options + ) @property def selected_subprotocol(self) -> Optional[str]: @@ -1440,23 +1556,28 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): """ return self.protocol.selected_subprotocol - def log_exception(self, typ: Optional[Type[BaseException]], - value: Optional[BaseException], - tb: Optional[TracebackType]) -> None: + def log_exception( + self, + typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: assert typ is not None assert value is not None app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb)) def websocket_connect( - url: Union[str, httpclient.HTTPRequest], - callback: Callable[['Future[WebSocketClientConnection]'], None]=None, - connect_timeout: float=None, - on_message_callback: Callable[[Union[None, str, bytes]], None]=None, - compression_options: Dict[str, Any]=None, - ping_interval: float=None, ping_timeout: float=None, - max_message_size: int=_default_max_message_size, subprotocols: List[str]=None -) -> 'Future[WebSocketClientConnection]': + url: Union[str, httpclient.HTTPRequest], + callback: Callable[["Future[WebSocketClientConnection]"], None] = None, + connect_timeout: float = None, + on_message_callback: Callable[[Union[None, str, bytes]], None] = None, + compression_options: Dict[str, Any] = None, + ping_interval: float = None, + ping_timeout: float = None, + max_message_size: int = _default_max_message_size, + subprotocols: List[str] = None, +) -> "Future[WebSocketClientConnection]": """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -1508,15 +1629,19 @@ def websocket_connect( request.headers = httputil.HTTPHeaders(request.headers) else: request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) - request = cast(httpclient.HTTPRequest, httpclient._RequestProxy( - request, httpclient.HTTPRequest._DEFAULTS)) - conn = WebSocketClientConnection(request, - on_message_callback=on_message_callback, - compression_options=compression_options, - ping_interval=ping_interval, - ping_timeout=ping_timeout, - max_message_size=max_message_size, - subprotocols=subprotocols) + request = cast( + httpclient.HTTPRequest, + httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS), + ) + conn = WebSocketClientConnection( + request, + on_message_callback=on_message_callback, + compression_options=compression_options, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + max_message_size=max_message_size, + subprotocols=subprotocols, + ) if callback is not None: IOLoop.current().add_future(conn.connect_future, callback) return conn.connect_future diff --git a/tornado/wsgi.py b/tornado/wsgi.py index 053864877..55e0da201 100644 --- a/tornado/wsgi.py +++ b/tornado/wsgi.py @@ -38,6 +38,7 @@ from tornado.log import access_log from typing import List, Tuple, Optional, Type, Callable, Any, Dict, Text from types import TracebackType import typing + if typing.TYPE_CHECKING: from wsgiref.types import WSGIApplication as WSGIAppType # noqa: F401 @@ -48,7 +49,7 @@ if typing.TYPE_CHECKING: # here to minimize the temptation to use it in non-wsgi contexts. def to_wsgi_str(s: bytes) -> str: assert isinstance(s, bytes) - return s.decode('latin1') + return s.decode("latin1") class WSGIContainer(object): @@ -85,7 +86,8 @@ class WSGIContainer(object): Tornado and WSGI apps in the same server. See https://github.com/bdarnell/django-tornado-demo for a complete example. """ - def __init__(self, wsgi_application: 'WSGIAppType') -> None: + + def __init__(self, wsgi_application: "WSGIAppType") -> None: self.wsgi_application = wsgi_application def __call__(self, request: httputil.HTTPServerRequest) -> None: @@ -93,16 +95,23 @@ class WSGIContainer(object): response = [] # type: List[bytes] def start_response( - status: str, headers: List[Tuple[str, str]], - exc_info: Optional[Tuple[Optional[Type[BaseException]], - Optional[BaseException], - Optional[TracebackType]]]=None + status: str, + headers: List[Tuple[str, str]], + exc_info: Optional[ + Tuple[ + Optional[Type[BaseException]], + Optional[BaseException], + Optional[TracebackType], + ] + ] = None, ) -> Callable[[bytes], Any]: data["status"] = status data["headers"] = headers return response.append + app_response = self.wsgi_application( - WSGIContainer.environ(request), start_response) + WSGIContainer.environ(request), start_response + ) try: response.extend(app_response) body = b"".join(response) @@ -112,7 +121,7 @@ class WSGIContainer(object): if not data: raise Exception("WSGI app did not call start_response") - status_code_str, reason = data["status"].split(' ', 1) + status_code_str, reason = data["status"].split(" ", 1) status_code = int(status_code_str) headers = data["headers"] # type: List[Tuple[str, str]] header_set = set(k.lower() for (k, v) in headers) @@ -148,8 +157,9 @@ class WSGIContainer(object): environ = { "REQUEST_METHOD": request.method, "SCRIPT_NAME": "", - "PATH_INFO": to_wsgi_str(escape.url_unescape( - request.path, encoding=None, plus=False)), + "PATH_INFO": to_wsgi_str( + escape.url_unescape(request.path, encoding=None, plus=False) + ), "QUERY_STRING": request.query, "REMOTE_ADDR": request.remote_ip, "SERVER_NAME": host, @@ -181,8 +191,7 @@ class WSGIContainer(object): request_time = 1000.0 * request.request_time() assert request.method is not None assert request.uri is not None - summary = request.method + " " + request.uri + " (" + \ - request.remote_ip + ")" + summary = request.method + " " + request.uri + " (" + request.remote_ip + ")" log_method("%d %s %.2fms", status_code, summary, request_time)