E402,
# E722 do not use bare except
E722,
+ # flake8 and black disagree about
+ # W503 line break before binary operator
+ # E203 whitespace before ':'
+ W503,E203
doctests = true
* ``_OPENID_ENDPOINT``: the identity provider's URI.
"""
+
def authenticate_redirect(
- self, callback_uri: str=None,
- ax_attrs: List[str]=["name", "email", "language", "username"]) -> None:
+ self,
+ callback_uri: str = None,
+ ax_attrs: List[str] = ["name", "email", "language", "username"],
+ ) -> None:
"""Redirects to the authentication URL for this service.
After authentication, the service will redirect back to the given
handler.redirect(endpoint + "?" + urllib.parse.urlencode(args))
async def get_authenticated_user(
- self, http_client: httpclient.AsyncHTTPClient=None
+ self, http_client: httpclient.AsyncHTTPClient = None
) -> Dict[str, Any]:
"""Fetches the authenticated user data upon redirect.
"""
handler = cast(RequestHandler, self)
# Verify the OpenID response via direct request to the OP
- args = dict((k, v[-1]) for k, v in handler.request.arguments.items()) \
- # type: Dict[str, Union[str, bytes]]
+ args = dict(
+ (k, v[-1]) for k, v in handler.request.arguments.items()
+ ) # type: Dict[str, Union[str, bytes]]
args["openid.mode"] = u"check_authentication"
url = self._OPENID_ENDPOINT # type: ignore
if http_client is None:
http_client = self.get_auth_http_client()
- resp = await http_client.fetch(url, method="POST", body=urllib.parse.urlencode(args))
+ resp = await http_client.fetch(
+ url, method="POST", body=urllib.parse.urlencode(args)
+ )
return self._on_authentication_verified(resp)
- def _openid_args(self, callback_uri: str, ax_attrs: Iterable[str]=[],
- oauth_scope: str=None) -> Dict[str, str]:
+ def _openid_args(
+ self, callback_uri: str, ax_attrs: Iterable[str] = [], oauth_scope: str = None
+ ) -> Dict[str, str]:
handler = cast(RequestHandler, self)
url = urllib.parse.urljoin(handler.request.full_url(), callback_uri)
args = {
"openid.ns": "http://specs.openid.net/auth/2.0",
- "openid.claimed_id":
- "http://specs.openid.net/auth/2.0/identifier_select",
- "openid.identity":
- "http://specs.openid.net/auth/2.0/identifier_select",
+ "openid.claimed_id": "http://specs.openid.net/auth/2.0/identifier_select",
+ "openid.identity": "http://specs.openid.net/auth/2.0/identifier_select",
"openid.return_to": url,
- "openid.realm": urllib.parse.urljoin(url, '/'),
+ "openid.realm": urllib.parse.urljoin(url, "/"),
"openid.mode": "checkid_setup",
}
if ax_attrs:
- args.update({
- "openid.ns.ax": "http://openid.net/srv/ax/1.0",
- "openid.ax.mode": "fetch_request",
- })
+ args.update(
+ {
+ "openid.ns.ax": "http://openid.net/srv/ax/1.0",
+ "openid.ax.mode": "fetch_request",
+ }
+ )
ax_attrs = set(ax_attrs)
required = [] # type: List[str]
if "name" in ax_attrs:
ax_attrs -= set(["name", "firstname", "fullname", "lastname"])
required += ["firstname", "fullname", "lastname"]
- args.update({
- "openid.ax.type.firstname":
- "http://axschema.org/namePerson/first",
- "openid.ax.type.fullname":
- "http://axschema.org/namePerson",
- "openid.ax.type.lastname":
- "http://axschema.org/namePerson/last",
- })
+ args.update(
+ {
+ "openid.ax.type.firstname": "http://axschema.org/namePerson/first",
+ "openid.ax.type.fullname": "http://axschema.org/namePerson",
+ "openid.ax.type.lastname": "http://axschema.org/namePerson/last",
+ }
+ )
known_attrs = {
"email": "http://axschema.org/contact/email",
"language": "http://axschema.org/pref/language",
required.append(name)
args["openid.ax.required"] = ",".join(required)
if oauth_scope:
- args.update({
- "openid.ns.oauth":
- "http://specs.openid.net/extensions/oauth/1.0",
- "openid.oauth.consumer": handler.request.host.split(":")[0],
- "openid.oauth.scope": oauth_scope,
- })
+ args.update(
+ {
+ "openid.ns.oauth": "http://specs.openid.net/extensions/oauth/1.0",
+ "openid.oauth.consumer": handler.request.host.split(":")[0],
+ "openid.oauth.scope": oauth_scope,
+ }
+ )
return args
- def _on_authentication_verified(self, response: httpclient.HTTPResponse) -> Dict[str, Any]:
+ def _on_authentication_verified(
+ self, response: httpclient.HTTPResponse
+ ) -> Dict[str, Any]:
handler = cast(RequestHandler, self)
if b"is_valid:true" not in response.body:
raise AuthError("Invalid OpenID response: %s" % response.body)
# Make sure we got back at least an email from attribute exchange
ax_ns = None
for key in handler.request.arguments:
- if key.startswith("openid.ns.") and \
- handler.get_argument(key) == u"http://openid.net/srv/ax/1.0":
+ if (
+ key.startswith("openid.ns.")
+ and handler.get_argument(key) == u"http://openid.net/srv/ax/1.0"
+ ):
ax_ns = key[10:]
break
ax_name = None
for name in handler.request.arguments.keys():
if handler.get_argument(name) == uri and name.startswith(prefix):
- part = name[len(prefix):]
+ part = name[len(prefix) :]
ax_name = "openid." + ax_ns + ".value." + part
break
if not ax_name:
Subclasses must also override the `_oauth_get_user_future` and
`_oauth_consumer_token` methods.
"""
- async def authorize_redirect(self, callback_uri: str=None, extra_params: Dict[str, Any]=None,
- http_client: httpclient.AsyncHTTPClient=None) -> None:
+
+ async def authorize_redirect(
+ self,
+ callback_uri: str = None,
+ extra_params: Dict[str, Any] = None,
+ http_client: httpclient.AsyncHTTPClient = None,
+ ) -> None:
"""Redirects the user to obtain OAuth authorization for this service.
The ``callback_uri`` may be omitted if you have previously
assert http_client is not None
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
response = await http_client.fetch(
- self._oauth_request_token_url(callback_uri=callback_uri,
- extra_params=extra_params))
+ self._oauth_request_token_url(
+ callback_uri=callback_uri, extra_params=extra_params
+ )
+ )
else:
response = await http_client.fetch(self._oauth_request_token_url())
url = self._OAUTH_AUTHORIZE_URL # type: ignore
self._on_request_token(url, callback_uri, response)
async def get_authenticated_user(
- self, http_client: httpclient.AsyncHTTPClient=None
+ self, http_client: httpclient.AsyncHTTPClient = None
) -> Dict[str, Any]:
"""Gets the OAuth authorized user and access token.
raise AuthError("Missing OAuth request token cookie")
handler.clear_cookie("_oauth_request_token")
cookie_key, cookie_secret = [
- base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")]
+ base64.b64decode(escape.utf8(i)) for i in request_cookie.split("|")
+ ]
if cookie_key != request_key:
raise AuthError("Request token does not match cookie")
- token = dict(key=cookie_key, secret=cookie_secret) # type: Dict[str, Union[str, bytes]]
+ token = dict(
+ key=cookie_key, secret=cookie_secret
+ ) # type: Dict[str, Union[str, bytes]]
if oauth_verifier:
token["verifier"] = oauth_verifier
if http_client is None:
user["access_token"] = access_token
return user
- def _oauth_request_token_url(self, callback_uri: str=None,
- extra_params: Dict[str, Any]=None) -> str:
+ def _oauth_request_token_url(
+ self, callback_uri: str = None, extra_params: Dict[str, Any] = None
+ ) -> str:
handler = cast(RequestHandler, self)
consumer_token = self._oauth_consumer_token()
url = self._OAUTH_REQUEST_TOKEN_URL # type: ignore
args["oauth_callback"] = "oob"
elif callback_uri:
args["oauth_callback"] = urllib.parse.urljoin(
- handler.request.full_url(), callback_uri)
+ handler.request.full_url(), callback_uri
+ )
if extra_params:
args.update(extra_params)
signature = _oauth10a_signature(consumer_token, "GET", url, args)
args["oauth_signature"] = signature
return url + "?" + urllib.parse.urlencode(args)
- def _on_request_token(self, authorize_url: str, callback_uri: Optional[str],
- response: httpclient.HTTPResponse) -> None:
+ def _on_request_token(
+ self,
+ authorize_url: str,
+ callback_uri: Optional[str],
+ response: httpclient.HTTPResponse,
+ ) -> None:
handler = cast(RequestHandler, self)
request_token = _oauth_parse_response(response.body)
- data = (base64.b64encode(escape.utf8(request_token["key"])) + b"|" +
- base64.b64encode(escape.utf8(request_token["secret"])))
+ data = (
+ base64.b64encode(escape.utf8(request_token["key"]))
+ + b"|"
+ + base64.b64encode(escape.utf8(request_token["secret"]))
+ )
handler.set_cookie("_oauth_request_token", data)
args = dict(oauth_token=request_token["key"])
if callback_uri == "oob":
return
elif callback_uri:
args["oauth_callback"] = urllib.parse.urljoin(
- handler.request.full_url(), callback_uri)
+ handler.request.full_url(), callback_uri
+ )
handler.redirect(authorize_url + "?" + urllib.parse.urlencode(args))
def _oauth_access_token_url(self, request_token: Dict[str, Any]) -> str:
args["oauth_verifier"] = request_token["verifier"]
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
- signature = _oauth10a_signature(consumer_token, "GET", url, args,
- request_token)
+ signature = _oauth10a_signature(
+ consumer_token, "GET", url, args, request_token
+ )
else:
- signature = _oauth_signature(consumer_token, "GET", url, args,
- request_token)
+ signature = _oauth_signature(
+ consumer_token, "GET", url, args, request_token
+ )
args["oauth_signature"] = signature
return url + "?" + urllib.parse.urlencode(args)
"""
raise NotImplementedError()
- async def _oauth_get_user_future(self, access_token: Dict[str, Any]) -> Dict[str, Any]:
+ async def _oauth_get_user_future(
+ self, access_token: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""Subclasses must override this to get basic information about the
user.
"""
raise NotImplementedError()
- def _oauth_request_parameters(self, url: str, access_token: Dict[str, Any],
- parameters: Dict[str, Any]={},
- method: str="GET") -> Dict[str, Any]:
+ def _oauth_request_parameters(
+ self,
+ url: str,
+ access_token: Dict[str, Any],
+ parameters: Dict[str, Any] = {},
+ method: str = "GET",
+ ) -> Dict[str, Any]:
"""Returns the OAuth parameters as a dict for the given request.
parameters should include all POST arguments and query string arguments
args.update(base_args)
args.update(parameters)
if getattr(self, "_OAUTH_VERSION", "1.0a") == "1.0a":
- signature = _oauth10a_signature(consumer_token, method, url, args,
- access_token)
+ signature = _oauth10a_signature(
+ consumer_token, method, url, args, access_token
+ )
else:
- signature = _oauth_signature(consumer_token, method, url, args,
- access_token)
+ signature = _oauth_signature(
+ consumer_token, method, url, args, access_token
+ )
base_args["oauth_signature"] = escape.to_basestring(signature)
return base_args
* ``_OAUTH_AUTHORIZE_URL``: The service's authorization url.
* ``_OAUTH_ACCESS_TOKEN_URL``: The service's access token url.
"""
- def authorize_redirect(self, redirect_uri: str=None, client_id: str=None,
- client_secret: str=None, extra_params: Dict[str, Any]=None,
- scope: str=None, response_type: str="code") -> None:
+
+ def authorize_redirect(
+ self,
+ redirect_uri: str = None,
+ client_id: str = None,
+ client_secret: str = None,
+ extra_params: Dict[str, Any] = None,
+ scope: str = None,
+ response_type: str = "code",
+ ) -> None:
"""Redirects the user to obtain OAuth authorization for this service.
Some providers require that you register a redirect URL with
this is now an ordinary synchronous function.
"""
handler = cast(RequestHandler, self)
- args = {
- "response_type": response_type
- }
+ args = {"response_type": response_type}
if redirect_uri is not None:
args["redirect_uri"] = redirect_uri
if client_id is not None:
if extra_params:
args.update(extra_params)
if scope:
- args['scope'] = ' '.join(scope)
+ args["scope"] = " ".join(scope)
url = self._OAUTH_AUTHORIZE_URL # type: ignore
handler.redirect(url_concat(url, args))
- def _oauth_request_token_url(self, redirect_uri: str=None, client_id: str=None,
- client_secret: str=None, code: str=None,
- extra_params: Dict[str, Any]=None) -> str:
+ def _oauth_request_token_url(
+ self,
+ redirect_uri: str = None,
+ client_id: str = None,
+ client_secret: str = None,
+ code: str = None,
+ extra_params: Dict[str, Any] = None,
+ ) -> str:
url = self._OAUTH_ACCESS_TOKEN_URL # type: ignore
args = {} # type: Dict[str, str]
if redirect_uri is not None:
args.update(extra_params)
return url_concat(url, args)
- async def oauth2_request(self, url: str, access_token: str=None,
- post_args: Dict[str, Any]=None, **args: Any) -> Any:
+ async def oauth2_request(
+ self,
+ url: str,
+ access_token: str = None,
+ post_args: Dict[str, Any] = None,
+ **args: Any
+ ) -> Any:
"""Fetches the given URL auth an OAuth2 access token.
If the request is a POST, ``post_args`` should be provided. Query
url += "?" + urllib.parse.urlencode(all_args)
http = self.get_auth_http_client()
if post_args is not None:
- response = await http.fetch(url, method="POST", body=urllib.parse.urlencode(post_args))
+ response = await http.fetch(
+ url, method="POST", body=urllib.parse.urlencode(post_args)
+ )
else:
response = await http.fetch(url)
return escape.json_decode(response.body)
and all of the custom Twitter user attributes described at
https://dev.twitter.com/docs/api/1.1/get/users/show
"""
+
_OAUTH_REQUEST_TOKEN_URL = "https://api.twitter.com/oauth/request_token"
_OAUTH_ACCESS_TOKEN_URL = "https://api.twitter.com/oauth/access_token"
_OAUTH_AUTHORIZE_URL = "https://api.twitter.com/oauth/authorize"
_OAUTH_NO_CALLBACKS = False
_TWITTER_BASE_URL = "https://api.twitter.com/1.1"
- async def authenticate_redirect(self, callback_uri: str=None) -> None:
+ async def authenticate_redirect(self, callback_uri: str = None) -> None:
"""Just like `~OAuthMixin.authorize_redirect`, but
auto-redirects if authorized.
awaitable object instead.
"""
http = self.get_auth_http_client()
- response = await http.fetch(self._oauth_request_token_url(callback_uri=callback_uri))
+ response = await http.fetch(
+ self._oauth_request_token_url(callback_uri=callback_uri)
+ )
self._on_request_token(self._OAUTH_AUTHENTICATE_URL, None, response)
- async def twitter_request(self, path: str, access_token: Dict[str, Any],
- post_args: Dict[str, Any]=None, **args: Any) -> Any:
+ async def twitter_request(
+ self,
+ path: str,
+ access_token: Dict[str, Any],
+ post_args: Dict[str, Any] = None,
+ **args: Any
+ ) -> Any:
"""Fetches the given API path, e.g., ``statuses/user_timeline/btaylor``
The path should not include the format or API version number.
The ``callback`` argument was removed. Use the returned
awaitable object instead.
"""
- if path.startswith('http:') or path.startswith('https:'):
+ if path.startswith("http:") or path.startswith("https:"):
# Raw urls are useful for e.g. search which doesn't follow the
# usual pattern: http://search.twitter.com/search.json
url = path
all_args.update(post_args or {})
method = "POST" if post_args is not None else "GET"
oauth = self._oauth_request_parameters(
- url, access_token, all_args, method=method)
+ url, access_token, all_args, method=method
+ )
args.update(oauth)
if args:
url += "?" + urllib.parse.urlencode(args)
http = self.get_auth_http_client()
if post_args is not None:
- response = await http.fetch(url, method="POST", body=urllib.parse.urlencode(post_args))
+ response = await http.fetch(
+ url, method="POST", body=urllib.parse.urlencode(post_args)
+ )
else:
response = await http.fetch(url)
return escape.json_decode(response.body)
handler.require_setting("twitter_consumer_secret", "Twitter OAuth")
return dict(
key=handler.settings["twitter_consumer_key"],
- secret=handler.settings["twitter_consumer_secret"])
+ secret=handler.settings["twitter_consumer_secret"],
+ )
- async def _oauth_get_user_future(self, access_token: Dict[str, Any]) -> Dict[str, Any]:
+ async def _oauth_get_user_future(
+ self, access_token: Dict[str, Any]
+ ) -> Dict[str, Any]:
user = await self.twitter_request(
- "/account/verify_credentials",
- access_token=access_token)
+ "/account/verify_credentials", access_token=access_token
+ )
if user:
user["username"] = user["screen_name"]
return user
.. versionadded:: 3.2
"""
+
_OAUTH_AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_OAUTH_ACCESS_TOKEN_URL = "https://www.googleapis.com/oauth2/v4/token"
_OAUTH_USERINFO_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
_OAUTH_NO_CALLBACKS = False
- _OAUTH_SETTINGS_KEY = 'google_oauth'
+ _OAUTH_SETTINGS_KEY = "google_oauth"
- async def get_authenticated_user(self, redirect_uri: str, code: str) -> Dict[str, Any]:
+ async def get_authenticated_user(
+ self, redirect_uri: str, code: str
+ ) -> Dict[str, Any]:
"""Handles the login for the Google user, returning an access token.
The result is a dictionary containing an ``access_token`` field
""" # noqa: E501
handler = cast(RequestHandler, self)
http = self.get_auth_http_client()
- body = urllib.parse.urlencode({
- "redirect_uri": redirect_uri,
- "code": code,
- "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]['key'],
- "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]['secret'],
- "grant_type": "authorization_code",
- })
-
- response = await http.fetch(self._OAUTH_ACCESS_TOKEN_URL,
- method="POST",
- headers={'Content-Type': 'application/x-www-form-urlencoded'},
- body=body)
+ body = urllib.parse.urlencode(
+ {
+ "redirect_uri": redirect_uri,
+ "code": code,
+ "client_id": handler.settings[self._OAUTH_SETTINGS_KEY]["key"],
+ "client_secret": handler.settings[self._OAUTH_SETTINGS_KEY]["secret"],
+ "grant_type": "authorization_code",
+ }
+ )
+
+ response = await http.fetch(
+ self._OAUTH_ACCESS_TOKEN_URL,
+ method="POST",
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ body=body,
+ )
return escape.json_decode(response.body)
class FacebookGraphMixin(OAuth2Mixin):
"""Facebook authentication using the new Graph API and OAuth2."""
+
_OAUTH_ACCESS_TOKEN_URL = "https://graph.facebook.com/oauth/access_token?"
_OAUTH_AUTHORIZE_URL = "https://www.facebook.com/dialog/oauth?"
_OAUTH_NO_CALLBACKS = False
_FACEBOOK_BASE_URL = "https://graph.facebook.com"
async def get_authenticated_user(
- self, redirect_uri: str, client_id: str, client_secret: str,
- code: str, extra_fields: Dict[str, Any]=None
+ self,
+ redirect_uri: str,
+ client_id: str,
+ client_secret: str,
+ code: str,
+ extra_fields: Dict[str, Any] = None,
) -> Optional[Dict[str, Any]]:
"""Handles the login for the Facebook user, returning a user object.
"client_secret": client_secret,
}
- fields = set(['id', 'name', 'first_name', 'last_name',
- 'locale', 'picture', 'link'])
+ fields = set(
+ ["id", "name", "first_name", "last_name", "locale", "picture", "link"]
+ )
if extra_fields:
fields.update(extra_fields)
- response = await http.fetch(self._oauth_request_token_url(**args)) # type: ignore
+ response = await http.fetch(
+ self._oauth_request_token_url(**args) # type: ignore
+ )
args = escape.json_decode(response.body)
session = {
"access_token": args.get("access_token"),
- "expires_in": args.get("expires_in")
+ "expires_in": args.get("expires_in"),
}
assert session["access_token"] is not None
user = await self.facebook_request(
path="/me",
access_token=session["access_token"],
- appsecret_proof=hmac.new(key=client_secret.encode('utf8'),
- msg=session["access_token"].encode('utf8'),
- digestmod=hashlib.sha256).hexdigest(),
- fields=",".join(fields)
+ appsecret_proof=hmac.new(
+ key=client_secret.encode("utf8"),
+ msg=session["access_token"].encode("utf8"),
+ digestmod=hashlib.sha256,
+ ).hexdigest(),
+ fields=",".join(fields),
)
if user is None:
# older versions in which the server used url-encoding and
# this code simply returned the string verbatim.
# This should change in Tornado 5.0.
- fieldmap.update({"access_token": session["access_token"],
- "session_expires": str(session.get("expires_in"))})
+ fieldmap.update(
+ {
+ "access_token": session["access_token"],
+ "session_expires": str(session.get("expires_in")),
+ }
+ )
return fieldmap
- async def facebook_request(self, path: str, access_token: str=None,
- post_args: Dict[str, Any]=None, **args: Any) -> Any:
+ async def facebook_request(
+ self,
+ path: str,
+ access_token: str = None,
+ post_args: Dict[str, Any] = None,
+ **args: Any
+ ) -> Any:
"""Fetches the given relative API path, e.g., "/btaylor/picture"
If the request is a POST, ``post_args`` should be provided. Query
The ``callback`` argument was removed. Use the returned awaitable object instead.
"""
url = self._FACEBOOK_BASE_URL + path
- return await self.oauth2_request(url, access_token=access_token,
- post_args=post_args, **args)
+ return await self.oauth2_request(
+ url, access_token=access_token, post_args=post_args, **args
+ )
-def _oauth_signature(consumer_token: Dict[str, Any], method: str, url: str,
- parameters: Dict[str, Any]={}, token: Dict[str, Any]=None) -> bytes:
+def _oauth_signature(
+ consumer_token: Dict[str, Any],
+ method: str,
+ url: str,
+ parameters: Dict[str, Any] = {},
+ token: Dict[str, Any] = None,
+) -> bytes:
"""Calculates the HMAC-SHA1 OAuth signature for the given request.
See http://oauth.net/core/1.0/#signing_process
base_elems = []
base_elems.append(method.upper())
base_elems.append(normalized_url)
- base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v)))
- for k, v in sorted(parameters.items())))
+ base_elems.append(
+ "&".join(
+ "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items())
+ )
+ )
base_string = "&".join(_oauth_escape(e) for e in base_elems)
key_elems = [escape.utf8(consumer_token["secret"])]
return binascii.b2a_base64(hash.digest())[:-1]
-def _oauth10a_signature(consumer_token: Dict[str, Any], method: str, url: str,
- parameters: Dict[str, Any]={}, token: Dict[str, Any]=None) -> bytes:
+def _oauth10a_signature(
+ consumer_token: Dict[str, Any],
+ method: str,
+ url: str,
+ parameters: Dict[str, Any] = {},
+ token: Dict[str, Any] = None,
+) -> bytes:
"""Calculates the HMAC-SHA1 OAuth 1.0a signature for the given request.
See http://oauth.net/core/1.0a/#signing_process
base_elems = []
base_elems.append(method.upper())
base_elems.append(normalized_url)
- base_elems.append("&".join("%s=%s" % (k, _oauth_escape(str(v)))
- for k, v in sorted(parameters.items())))
+ base_elems.append(
+ "&".join(
+ "%s=%s" % (k, _oauth_escape(str(v))) for k, v in sorted(parameters.items())
+ )
+ )
base_string = "&".join(_oauth_escape(e) for e in base_elems)
- key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe='~'))]
- key_elems.append(escape.utf8(urllib.parse.quote(token["secret"], safe='~') if token else ""))
+ key_elems = [escape.utf8(urllib.parse.quote(consumer_token["secret"], safe="~"))]
+ key_elems.append(
+ escape.utf8(urllib.parse.quote(token["secret"], safe="~") if token else "")
+ )
key = b"&".join(key_elems)
hash = hmac.new(key, escape.utf8(base_string), hashlib.sha1)
import typing
from typing import Callable, Dict
+
if typing.TYPE_CHECKING:
from typing import List, Optional, Union # noqa: F401
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
-_has_execv = sys.platform != 'win32'
+_has_execv = sys.platform != "win32"
_watched_files = set()
_reload_hooks = []
_original_spec = None
-def start(check_time: int=500) -> None:
+def start(check_time: int = 500) -> None:
"""Begins watching source files for changes.
.. versionchanged:: 5.0
spec = _original_spec
argv = _original_argv
else:
- spec = getattr(sys.modules['__main__'], '__spec__', None)
+ spec = getattr(sys.modules["__main__"], "__spec__", None)
argv = sys.argv
if spec:
- argv = ['-m', spec.name] + argv[1:]
+ argv = ["-m", spec.name] + argv[1:]
else:
- path_prefix = '.' + os.pathsep
- if (sys.path[0] == '' and
- not os.environ.get("PYTHONPATH", "").startswith(path_prefix)):
- os.environ["PYTHONPATH"] = (path_prefix +
- os.environ.get("PYTHONPATH", ""))
+ path_prefix = "." + os.pathsep
+ if sys.path[0] == "" and not os.environ.get("PYTHONPATH", "").startswith(
+ path_prefix
+ ):
+ os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "")
if not _has_execv:
subprocess.Popen([sys.executable] + argv)
os._exit(0)
# Unfortunately the errno returned in this case does not
# appear to be consistent, so we can't easily check for
# this error specifically.
- os.spawnv(os.P_NOWAIT, sys.executable, [sys.executable] + argv) # type: ignore
+ os.spawnv( # type: ignore
+ os.P_NOWAIT, sys.executable, [sys.executable] + argv
+ )
# At this point the IOLoop has been closed and finally
# blocks will experience errors if we allow the stack to
# unwind, so just exit uncleanly.
# The main module can be tricky; set the variables both in our globals
# (which may be __main__) and the real importable version.
import tornado.autoreload
+
global _autoreload_is_main
global _original_argv, _original_spec
tornado.autoreload._autoreload_is_main = _autoreload_is_main = True
original_argv = sys.argv
tornado.autoreload._original_argv = _original_argv = original_argv
- original_spec = getattr(sys.modules['__main__'], '__spec__', None)
+ original_spec = getattr(sys.modules["__main__"], "__spec__", None)
tornado.autoreload._original_spec = _original_spec = original_spec
sys.argv = sys.argv[:]
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
try:
if mode == "module":
import runpy
+
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
# restore sys.argv so subsequent executions will include autoreload
sys.argv = original_argv
- if mode == 'module':
+ if mode == "module":
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
loader = pkgutil.get_loader(module)
import typing
from typing import Any, Callable, Optional, Tuple, Union
-_T = typing.TypeVar('_T')
+_T = typing.TypeVar("_T")
class ReturnValueIgnoredError(Exception):
class DummyExecutor(futures.Executor):
- def submit(self, fn: Callable[..., _T], *args: Any, **kwargs: Any) -> 'futures.Future[_T]':
+ def submit(
+ self, fn: Callable[..., _T], *args: Any, **kwargs: Any
+ ) -> "futures.Future[_T]":
future = futures.Future() # type: futures.Future[_T]
try:
future_set_result_unless_cancelled(future, fn(*args, **kwargs))
future_set_exc_info(future, sys.exc_info())
return future
- def shutdown(self, wait: bool=True) -> None:
+ def shutdown(self, wait: bool = True) -> None:
pass
conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs)
chain_future(conc_future, async_future)
return async_future
+
return wrapper
+
if args and kwargs:
raise ValueError("cannot combine positional and keyword args")
if len(args) == 1:
_NO_RESULT = object()
-def chain_future(a: 'Future[_T]', b: 'Future[_T]') -> None:
+def chain_future(a: "Future[_T]", b: "Future[_T]") -> None:
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``, unless
`concurrent.futures.Future`.
"""
- def copy(future: 'Future[_T]') -> None:
+
+ def copy(future: "Future[_T]") -> None:
assert future is a
if b.done():
return
- if (hasattr(a, 'exc_info') and
- a.exc_info() is not None): # type: ignore
+ if hasattr(a, "exc_info") and a.exc_info() is not None: # type: ignore
future_set_exc_info(b, a.exc_info()) # type: ignore
elif a.exception() is not None:
b.set_exception(a.exception())
else:
b.set_result(a.result())
+
if isinstance(a, Future):
future_add_done_callback(a, copy)
else:
# concurrent.futures.Future
from tornado.ioloop import IOLoop
+
IOLoop.current().add_future(a, copy)
-def future_set_result_unless_cancelled(future: Union['futures.Future[_T]', 'Future[_T]'],
- value: _T) -> None:
+def future_set_result_unless_cancelled(
+ future: Union["futures.Future[_T]", "Future[_T]"], value: _T
+) -> None:
"""Set the given ``value`` as the `Future`'s result, if not cancelled.
Avoids asyncio.InvalidStateError when calling set_result() on
future.set_result(value)
-def future_set_exc_info(future: Union['futures.Future[_T]', 'Future[_T]'],
- exc_info: Tuple[Optional[type], Optional[BaseException],
- Optional[types.TracebackType]]) -> None:
+def future_set_exc_info(
+ future: Union["futures.Future[_T]", "Future[_T]"],
+ exc_info: Tuple[
+ Optional[type], Optional[BaseException], Optional[types.TracebackType]
+ ],
+) -> None:
"""Set the given ``exc_info`` as the `Future`'s exception.
Understands both `asyncio.Future` and Tornado's extensions to
.. versionadded:: 5.0
"""
- if hasattr(future, 'set_exc_info'):
+ if hasattr(future, "set_exc_info"):
# Tornado's Future
future.set_exc_info(exc_info) # type: ignore
else:
@typing.overload
-def future_add_done_callback(future: 'futures.Future[_T]',
- callback: Callable[['futures.Future[_T]'], None]) -> None:
+def future_add_done_callback(
+ future: "futures.Future[_T]", callback: Callable[["futures.Future[_T]"], None]
+) -> None:
pass
@typing.overload # noqa: F811
-def future_add_done_callback(future: 'Future[_T]',
- callback: Callable[['Future[_T]'], None]) -> None:
+def future_add_done_callback(
+ future: "Future[_T]", callback: Callable[["Future[_T]"], None]
+) -> None:
pass
-def future_add_done_callback(future: Union['futures.Future[_T]', 'Future[_T]'], # noqa: F811
- callback: Callable[..., None]) -> None:
+def future_add_done_callback( # noqa: F811
+ future: Union["futures.Future[_T]", "Future[_T]"], callback: Callable[..., None]
+) -> None:
"""Arrange to call ``callback`` when ``future`` is complete.
``callback`` is invoked with one argument, the ``future``.
from tornado import ioloop
from tornado.escape import utf8, native_str
-from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main
+from tornado.httpclient import (
+ HTTPRequest,
+ HTTPResponse,
+ HTTPError,
+ AsyncHTTPClient,
+ main,
+)
from tornado.log import app_log
from typing import Dict, Any, Callable, Union
import typing
+
if typing.TYPE_CHECKING:
from typing import Deque, Tuple, Optional # noqa: F401
-curl_log = logging.getLogger('tornado.curl_httpclient')
+curl_log = logging.getLogger("tornado.curl_httpclient")
class CurlAsyncHTTPClient(AsyncHTTPClient):
- def initialize(self, max_clients: int=10, # type: ignore
- defaults: Dict[str, Any]=None) -> None:
+ def initialize( # type: ignore
+ self, max_clients: int = 10, defaults: Dict[str, Any] = None
+ ) -> None:
super(CurlAsyncHTTPClient, self).initialize(defaults=defaults)
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
- self._requests = collections.deque() \
- # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]]
+ self._requests = (
+ collections.deque()
+ ) # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]]
self._fds = {} # type: Dict[int, int]
self._timeout = None # type: Optional[object]
# SOCKETFUNCTION. Mitigate the effects of such bugs by
# forcing a periodic scan of all active requests.
self._force_timeout_callback = ioloop.PeriodicCallback(
- self._handle_force_timeout, 1000)
+ self._handle_force_timeout, 1000
+ )
self._force_timeout_callback.start()
# Work around a bug in libcurl 7.29.0: Some fields in the curl
self._force_timeout_callback = None # type: ignore
self._multi = None
- def fetch_impl(self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]) -> None:
+ def fetch_impl(
+ self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
+ ) -> None:
self._requests.append((request, callback, self.io_loop.time()))
self._process_queue()
self._set_timeout(0)
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
pycurl.POLL_IN: ioloop.IOLoop.READ,
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
- pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE
+ pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE,
}
if event == pycurl.POLL_REMOVE:
if fd in self._fds:
# instead of update.
if fd in self._fds:
self.io_loop.remove_handler(fd)
- self.io_loop.add_handler(fd, self._handle_events,
- ioloop_event)
+ self.io_loop.add_handler(fd, self._handle_events, ioloop_event)
self._fds[fd] = ioloop_event
def _set_timeout(self, msecs: int) -> None:
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
- self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
+ self.io_loop.time() + msecs / 1000.0, self._handle_timeout
+ )
def _handle_events(self, fd: int, events: int) -> None:
"""Called by IOLoop when there is activity on one of our
self._timeout = None
while True:
try:
- ret, num_handles = self._multi.socket_action(
- pycurl.SOCKET_TIMEOUT, 0)
+ ret, num_handles = self._multi.socket_action(pycurl.SOCKET_TIMEOUT, 0)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
}
try:
self._curl_setup_request(
- curl, request, curl.info["buffer"],
- curl.info["headers"])
+ curl, request, curl.info["buffer"], curl.info["headers"]
+ )
except Exception as e:
# If there was an error in setup, pass it on
# to the callback. Note that allowing the
# _finish_pending_requests the exceptions have
# nowhere to go.
self._free_list.append(curl)
- callback(HTTPResponse(
- request=request,
- code=599,
- error=e))
+ callback(HTTPResponse(request=request, code=599, error=e))
else:
self._multi.add_handle(curl)
if not started:
break
- def _finish(self, curl: pycurl.Curl, curl_error: int=None, curl_message: str=None) -> None:
+ def _finish(
+ self, curl: pycurl.Curl, curl_error: int = None, curl_message: str = None
+ ) -> None:
info = curl.info
curl.info = None
self._multi.remove_handle(curl)
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
)
try:
- info["callback"](HTTPResponse(
- request=info["request"], code=code, headers=info["headers"],
- buffer=buffer, effective_url=effective_url, error=error,
- reason=info['headers'].get("X-Http-Reason", None),
- request_time=self.io_loop.time() - info["curl_start_ioloop_time"],
- start_time=info["curl_start_time"],
- time_info=time_info))
+ info["callback"](
+ HTTPResponse(
+ request=info["request"],
+ code=code,
+ headers=info["headers"],
+ buffer=buffer,
+ effective_url=effective_url,
+ error=error,
+ reason=info["headers"].get("X-Http-Reason", None),
+ request_time=self.io_loop.time() - info["curl_start_ioloop_time"],
+ start_time=info["curl_start_time"],
+ time_info=time_info,
+ )
+ )
except Exception:
self.handle_callback_exception(info["callback"])
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
- if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
+ if hasattr(
+ pycurl, "PROTOCOLS"
+ ): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
return curl
- def _curl_setup_request(self, curl: pycurl.Curl, request: HTTPRequest,
- buffer: BytesIO, headers: httputil.HTTPHeaders) -> None:
+ def _curl_setup_request(
+ self,
+ curl: pycurl.Curl,
+ request: HTTPRequest,
+ buffer: BytesIO,
+ headers: httputil.HTTPHeaders,
+ ) -> None:
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
- curl.setopt(pycurl.HTTPHEADER,
- ["%s: %s" % (native_str(k), native_str(v))
- for k, v in request.headers.get_all()])
+ curl.setopt(
+ pycurl.HTTPHEADER,
+ [
+ "%s: %s" % (native_str(k), native_str(v))
+ for k, v in request.headers.get_all()
+ ],
+ )
- curl.setopt(pycurl.HEADERFUNCTION,
- functools.partial(self._curl_header_callback,
- headers, request.header_callback))
+ curl.setopt(
+ pycurl.HEADERFUNCTION,
+ functools.partial(
+ self._curl_header_callback, headers, request.header_callback
+ ),
+ )
if request.streaming_callback:
+
def write_function(b: Union[bytes, bytearray]) -> int:
assert request.streaming_callback is not None
self.io_loop.add_callback(request.streaming_callback, b)
return len(b)
+
else:
write_function = buffer.write
curl.setopt(pycurl.WRITEFUNCTION, write_function)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
assert request.proxy_password is not None
- credentials = httputil.encode_username_password(request.proxy_username,
- request.proxy_password)
+ credentials = httputil.encode_username_password(
+ request.proxy_username, request.proxy_password
+ )
curl.setopt(pycurl.PROXYUSERPWD, credentials)
- if (request.proxy_auth_mode is None or
- request.proxy_auth_mode == "basic"):
+ if request.proxy_auth_mode is None or request.proxy_auth_mode == "basic":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
elif request.proxy_auth_mode == "digest":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError(
- "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
+ "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode
+ )
else:
- curl.setopt(pycurl.PROXY, '')
+ curl.setopt(pycurl.PROXY, "")
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
- raise KeyError('unknown method ' + request.method)
+ raise KeyError("unknown method " + request.method)
body_expected = request.method in ("POST", "PATCH", "PUT")
body_present = request.body is not None
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
- if ((body_expected and not body_present) or
- (body_present and not body_expected)):
+ if (body_expected and not body_present) or (
+ body_present and not body_expected
+ ):
raise ValueError(
- 'Body must %sbe None for method %s (unless '
- 'allow_nonstandard_methods is true)' %
- ('not ' if body_expected else '', request.method))
+ "Body must %sbe None for method %s (unless "
+ "allow_nonstandard_methods is true)"
+ % ("not " if body_expected else "", request.method)
+ )
if body_expected or body_present:
if request.method == "GET":
# unless we use CUSTOMREQUEST). While the spec doesn't
# forbid clients from sending a body, it arguably
# disallows the server from doing anything with them.
- raise ValueError('Body must be None for GET request')
- request_buffer = BytesIO(utf8(request.body or ''))
+ raise ValueError("Body must be None for GET request")
+ request_buffer = BytesIO(utf8(request.body or ""))
def ioctl(cmd: int) -> None:
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
+
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
- curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
+ curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ""))
else:
curl.setopt(pycurl.UPLOAD, True)
- curl.setopt(pycurl.INFILESIZE, len(request.body or ''))
+ curl.setopt(pycurl.INFILESIZE, len(request.body or ""))
if request.auth_username is not None:
assert request.auth_password is not None
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
- userpwd = httputil.encode_username_password(request.auth_username,
- request.auth_password)
+ userpwd = httputil.encode_username_password(
+ request.auth_username, request.auth_password
+ )
curl.setopt(pycurl.USERPWD, userpwd)
- curl_log.debug("%s %s (username: %r)", request.method, request.url,
- request.auth_username)
+ curl_log.debug(
+ "%s %s (username: %r)",
+ request.method,
+ request.url,
+ request.auth_username,
+ )
else:
curl.unsetopt(pycurl.USERPWD)
curl_log.debug("%s %s", request.method, request.url)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
- def _curl_header_callback(self, headers: httputil.HTTPHeaders,
- header_callback: Callable[[str], None],
- header_line_bytes: bytes) -> None:
- header_line = native_str(header_line_bytes.decode('latin1'))
+ def _curl_header_callback(
+ self,
+ headers: httputil.HTTPHeaders,
+ header_callback: Callable[[str], None],
+ header_line_bytes: bytes,
+ ) -> None:
+ header_line = native_str(header_line_bytes.decode("latin1"))
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
headers.parse_line(header_line)
def _curl_debug(self, debug_type: int, debug_msg: str) -> None:
- debug_types = ('I', '<', '>', '<', '>')
+ debug_types = ("I", "<", ">", "<", ">")
if debug_type == 0:
debug_msg = native_str(debug_msg)
- curl_log.debug('%s', debug_msg.strip())
+ curl_log.debug("%s", debug_msg.strip())
elif debug_type in (1, 2):
debug_msg = native_str(debug_msg)
for line in debug_msg.splitlines():
- curl_log.debug('%s %s', debug_types[debug_type], line)
+ curl_log.debug("%s %s", debug_types[debug_type], line)
elif debug_type == 4:
- curl_log.debug('%s %r', debug_types[debug_type], debug_msg)
+ curl_log.debug("%s %r", debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
from typing import Union, Any, Optional, Dict, List, Callable
-_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
-_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',
- '\'': '&#39;'}
+_XHTML_ESCAPE_RE = re.compile("[&<>\"']")
+_XHTML_ESCAPE_DICT = {
+ "&": "&amp;",
+ "<": "&lt;",
+ ">": "&gt;",
+ '"': "&quot;",
+ "'": "&#39;",
+}
def xhtml_escape(value: Union[str, bytes]) -> str:
Added the single quote to the list of escaped characters.
"""
- return _XHTML_ESCAPE_RE.sub(lambda match: _XHTML_ESCAPE_DICT[match.group(0)],
- to_basestring(value))
+ return _XHTML_ESCAPE_RE.sub(
+ lambda match: _XHTML_ESCAPE_DICT[match.group(0)], to_basestring(value)
+ )
def xhtml_unescape(value: Union[str, bytes]) -> str:
return re.sub(r"[\x00-\x20]+", " ", value).strip()
-def url_escape(value: Union[str, bytes], plus: bool=True) -> str:
+def url_escape(value: Union[str, bytes], plus: bool = True) -> str:
"""Returns a URL-encoded version of the given value.
If ``plus`` is true (the default), spaces will be represented
@typing.overload
-def url_unescape(value: Union[str, bytes], encoding: None, plus: bool=True) -> bytes:
+def url_unescape(value: Union[str, bytes], encoding: None, plus: bool = True) -> bytes:
pass
@typing.overload # noqa: F811
-def url_unescape(value: Union[str, bytes], encoding: str='utf-8', plus: bool=True) -> str:
+def url_unescape(
+ value: Union[str, bytes], encoding: str = "utf-8", plus: bool = True
+) -> str:
pass
-def url_unescape(value: Union[str, bytes], encoding: Optional[str]='utf-8', # noqa: F811
- plus: bool=True) -> Union[str, bytes]:
+def url_unescape( # noqa: F811
+ value: Union[str, bytes], encoding: Optional[str] = "utf-8", plus: bool = True
+) -> Union[str, bytes]:
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
if encoding is None:
if plus:
# unquote_to_bytes doesn't have a _plus variant
- value = to_basestring(value).replace('+', ' ')
+ value = to_basestring(value).replace("+", " ")
return urllib.parse.unquote_to_bytes(value)
else:
- unquote = (urllib.parse.unquote_plus if plus
- else urllib.parse.unquote)
+ unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote
return unquote(to_basestring(value), encoding=encoding)
-def parse_qs_bytes(qs: str, keep_blank_values: bool=False,
- strict_parsing: bool=False) -> Dict[str, List[bytes]]:
+def parse_qs_bytes(
+ qs: str, keep_blank_values: bool = False, strict_parsing: bool = False
+) -> Dict[str, List[bytes]]:
"""Parses a query string like urlparse.parse_qs, but returns the
values as byte strings.
"""
# This is gross, but python3 doesn't give us another way.
# Latin1 is the universal donor of character encodings.
- result = urllib.parse.parse_qs(qs, keep_blank_values, strict_parsing,
- encoding='latin1', errors='strict')
+ result = urllib.parse.parse_qs(
+ qs, keep_blank_values, strict_parsing, encoding="latin1", errors="strict"
+ )
encoded = {}
for k, v in result.items():
- encoded[k] = [i.encode('latin1') for i in v]
+ encoded[k] = [i.encode("latin1") for i in v]
return encoded
if isinstance(value, _UTF8_TYPES):
return value
if not isinstance(value, unicode_type):
- raise TypeError(
- "Expected bytes, unicode, or None; got %r" % type(value)
- )
+ raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.encode("utf-8")
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes):
- raise TypeError(
- "Expected bytes, unicode, or None; got %r" % type(value)
- )
+ raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.decode("utf-8")
if isinstance(value, _BASESTRING_TYPES):
return value
if not isinstance(value, bytes):
- raise TypeError(
- "Expected bytes, unicode, or None; got %r" % type(value)
- )
+ raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.decode("utf-8")
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
- return dict((recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items())
+ return dict(
+ (recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items()
+ )
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
# This regex should avoid those problems.
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
# processed as escapes.
-_URL_RE = re.compile(to_unicode(
- r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501
-))
-
-
-def linkify(text: Union[str, bytes], shorten: bool=False,
- extra_params: Union[str, Callable[[str], str]]="",
- require_protocol: bool=False, permitted_protocols: List[str]=["http", "https"]) -> str:
+_URL_RE = re.compile(
+ to_unicode(
+ r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501
+ )
+)
+
+
+def linkify(
+ text: Union[str, bytes],
+ shorten: bool = False,
+ extra_params: Union[str, Callable[[str], str]] = "",
+ require_protocol: bool = False,
+ permitted_protocols: List[str] = ["http", "https"],
+) -> str:
"""Converts plain text into HTML with links.
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
href = m.group(1)
if not proto:
- href = "http://" + href # no proto specified, use http
+ href = "http://" + href # no proto specified, use http
if callable(extra_params):
params = " " + extra_params(href).strip()
# The path is usually not that interesting once shortened
# (no more slug, etc), so it really just provides a little
# extra indication of shortening.
- url = url[:proto_len] + parts[0] + "/" + \
- parts[1][:8].split('?')[0].split('.')[0]
+ url = (
+ url[:proto_len]
+ + parts[0]
+ + "/"
+ + parts[1][:8].split("?")[0].split(".")[0]
+ )
if len(url) > max_len * 1.5: # still too long
url = url[:max_len]
if url != before_clip:
- amp = url.rfind('&')
+ amp = url.rfind("&")
# avoid splitting html char entities
if amp > max_len - 5:
url = url[:amp]
def _convert_entity(m: typing.Match) -> str:
if m.group(1) == "#":
try:
- if m.group(2)[:1].lower() == 'x':
+ if m.group(2)[:1].lower() == "x":
return chr(int(m.group(2)[1:], 16))
else:
return chr(int(m.group(2)))
import sys
import types
-from tornado.concurrent import (Future, is_future, chain_future, future_set_exc_info,
- future_add_done_callback, future_set_result_unless_cancelled)
+from tornado.concurrent import (
+ Future,
+ is_future,
+ chain_future,
+ future_set_exc_info,
+ future_add_done_callback,
+ future_set_result_unless_cancelled,
+)
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado.util import TimeoutError
if typing.TYPE_CHECKING:
from typing import Sequence, Deque, Optional, Set, Iterable # noqa: F401
-_T = typing.TypeVar('_T')
+_T = typing.TypeVar("_T")
-_Yieldable = Union[None, Awaitable, List[Awaitable], Dict[Any, Awaitable],
- concurrent.futures.Future]
+_Yieldable = Union[
+ None, Awaitable, List[Awaitable], Dict[Any, Awaitable], concurrent.futures.Future
+]
class KeyReuseError(Exception):
pass
-def _value_from_stopiteration(e: Union[StopIteration, 'Return']) -> Any:
+def _value_from_stopiteration(e: Union[StopIteration, "Return"]) -> Any:
try:
# StopIteration has a value attribute beginning in py33.
# So does our Return class.
return future
-def coroutine(func: Callable[..., 'Generator[Any, Any, _T]']) -> Callable[..., 'Future[_T]']:
+def coroutine(
+ func: Callable[..., "Generator[Any, Any, _T]"]
+) -> Callable[..., "Future[_T]"]:
"""Decorator for asynchronous generators.
Any generator that yields objects from this module must be wrapped
awaitable object instead.
"""
+
@functools.wraps(func)
def wrapper(*args, **kwargs):
# type: (*Any, **Any) -> Future[_T]
try:
yielded = next(result)
except (StopIteration, Return) as e:
- future_set_result_unless_cancelled(future, _value_from_stopiteration(e))
+ future_set_result_unless_cancelled(
+ future, _value_from_stopiteration(e)
+ )
except Exception:
future_set_exc_info(future, sys.exc_info())
else:
.. versionadded:: 4.5
"""
- return getattr(func, '__tornado_coroutine__', False)
+ return getattr(func, "__tornado_coroutine__", False)
class Return(Exception):
but it is never necessary to ``raise gen.Return()``. The ``return``
statement can be used with no arguments instead.
"""
- def __init__(self, value: Any=None) -> None:
+
+ def __init__(self, value: Any = None) -> None:
super(Return, self).__init__()
self.value = value
# Cython recognizes subclasses of StopIteration with a .args tuple.
def __init__(self, *args: Future, **kwargs: Future) -> None:
if args and kwargs:
- raise ValueError(
- "You must provide args or kwargs, not both")
+ raise ValueError("You must provide args or kwargs, not both")
if kwargs:
self._unfinished = dict((f, k) for (k, f) in kwargs.items())
def __anext__(self) -> Future:
if self.done():
# Lookup by name to silence pyflakes on older versions.
- raise getattr(builtins, 'StopAsyncIteration')()
+ raise getattr(builtins, "StopAsyncIteration")()
return self.next()
def multi(
- children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
- quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]]=(),
-) -> Union['Future[List]', 'Future[Dict]']:
+ children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
+ quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (),
+) -> Union["Future[List]", "Future[Dict]"]:
"""Runs multiple asynchronous operations in parallel.
``children`` may either be a list or a dict whose values are
def multi_future(
- children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
- quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]]=(),
-) -> Union['Future[List]', 'Future[Dict]']:
+ children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
+ quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (),
+) -> Union["Future[List]", "Future[Dict]"]:
"""Wait for multiple asynchronous futures in parallel.
Since Tornado 6.0, this function is exactly the same as `multi`.
future = _create_future()
if not children_futs:
- future_set_result_unless_cancelled(future,
- {} if keys is not None else [])
+ future_set_result_unless_cancelled(future, {} if keys is not None else [])
def callback(fut: Future) -> None:
unfinished_children.remove(fut)
except Exception as e:
if future.done():
if not isinstance(e, quiet_exceptions):
- app_log.error("Multiple exceptions in yield list",
- exc_info=True)
+ app_log.error(
+ "Multiple exceptions in yield list", exc_info=True
+ )
else:
future_set_exc_info(future, sys.exc_info())
if not future.done():
if keys is not None:
- future_set_result_unless_cancelled(future,
- dict(zip(keys, result_list)))
+ future_set_result_unless_cancelled(
+ future, dict(zip(keys, result_list))
+ )
else:
future_set_result_unless_cancelled(future, result_list)
def with_timeout(
- timeout: Union[float, datetime.timedelta], future: _Yieldable,
- quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]]=(),
+ timeout: Union[float, datetime.timedelta],
+ future: _Yieldable,
+ quiet_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (),
) -> Future:
"""Wraps a `.Future` (or other yieldable object) in a timeout.
future.result()
except Exception as e:
if not isinstance(e, quiet_exceptions):
- app_log.error("Exception in Future %r after timeout",
- future, exc_info=True)
+ app_log.error(
+ "Exception in Future %r after timeout", future, exc_info=True
+ )
def timeout_callback() -> None:
if not result.done():
result.set_exception(TimeoutError("Timeout"))
# In case the wrapped future goes on to fail, log it.
future_add_done_callback(future_converted, error_callback)
- timeout_handle = io_loop.add_timeout(
- timeout, timeout_callback)
+
+ timeout_handle = io_loop.add_timeout(timeout, timeout_callback)
if isinstance(future_converted, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
# don't care about StackContext here.
future_add_done_callback(
- future_converted, lambda future: io_loop.remove_timeout(timeout_handle))
+ future_converted, lambda future: io_loop.remove_timeout(timeout_handle)
+ )
else:
# concurrent.futures.Futures may resolve on any thread, so we
# need to route them back to the IOLoop.
io_loop.add_future(
- future_converted, lambda future: io_loop.remove_timeout(timeout_handle))
+ future_converted, lambda future: io_loop.remove_timeout(timeout_handle)
+ )
return result
-def sleep(duration: float) -> 'Future[None]':
+def sleep(duration: float) -> "Future[None]":
"""Return a `.Future` that resolves after the given number of seconds.
When used with ``yield`` in a coroutine, this is a non-blocking
.. versionadded:: 4.1
"""
f = _create_future()
- IOLoop.current().call_later(duration,
- lambda: future_set_result_unless_cancelled(f, None))
+ IOLoop.current().call_later(
+ duration, lambda: future_set_result_unless_cancelled(f, None)
+ )
return f
a _NullFuture into a code path that doesn't understand what to do
with it.
"""
+
def result(self) -> None:
return None
_null_future = typing.cast(Future, _NullFuture())
moment = typing.cast(Future, _NullFuture())
-moment.__doc__ = \
- """A special object which may be yielded to allow the IOLoop to run for
+moment.__doc__ = """A special object which may be yielded to allow the IOLoop to run for
one iteration.
This is not needed in normal use but it can be helpful in long-running
The results of the generator are stored in ``result_future`` (a
`.Future`)
"""
- def __init__(self, gen: 'Generator[_Yieldable, Any, _T]', result_future: 'Future[_T]',
- first_yielded: _Yieldable) -> None:
+
+ def __init__(
+ self,
+ gen: "Generator[_Yieldable, Any, _T]",
+ result_future: "Future[_T]",
+ first_yielded: _Yieldable,
+ ) -> None:
self.gen = gen
self.result_future = result_future
self.future = _null_future # type: Union[None, Future]
except (StopIteration, Return) as e:
self.finished = True
self.future = _null_future
- future_set_result_unless_cancelled(self.result_future,
- _value_from_stopiteration(e))
+ future_set_result_unless_cancelled(
+ self.result_future, _value_from_stopiteration(e)
+ )
self.result_future = None # type: ignore
return
except Exception:
elif self.future is None:
raise Exception("no pending future")
elif not self.future.done():
+
def inner(f: Any) -> None:
# Break a reference cycle to speed GC.
f = None # noqa
self.run()
- self.io_loop.add_future(
- self.future, inner)
+
+ self.io_loop.add_future(self.future, inner)
return False
return True
- def handle_exception(self, typ: Type[Exception], value: Exception,
- tb: types.TracebackType) -> bool:
+ def handle_exception(
+ self, typ: Type[Exception], value: Exception, tb: types.TracebackType
+ ) -> bool:
if not self.running and not self.finished:
self.future = Future()
future_set_exc_info(self.future, (typ, value, tb))
except AttributeError:
# asyncio.ensure_future was introduced in Python 3.4.4, but
# Debian jessie still ships with 3.4.2 so try the old name.
- _wrap_awaitable = getattr(asyncio, 'async')
+ _wrap_awaitable = getattr(asyncio, "async")
def convert_yielded(yielded: _Yieldable) -> Future:
import re
import types
-from tornado.concurrent import (Future, future_add_done_callback,
- future_set_result_unless_cancelled)
+from tornado.concurrent import (
+ Future,
+ future_add_done_callback,
+ future_set_result_unless_cancelled,
+)
from tornado.escape import native_str, utf8
from tornado import gen
from tornado import httputil
from tornado.util import GzipDecompressor
-from typing import cast, Optional, Type, Awaitable, Generator, Any, Callable, Union, Tuple
+from typing import (
+ cast,
+ Optional,
+ Type,
+ Awaitable,
+ Generator,
+ Any,
+ Callable,
+ Union,
+ Tuple,
+)
class _QuietException(Exception):
log any exceptions with the given logger. Any exceptions caught are
converted to _QuietException
"""
+
def __init__(self, logger: logging.Logger) -> None:
self.logger = logger
def __enter__(self) -> None:
pass
- def __exit__(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: types.TracebackType) -> None:
+ def __exit__(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: types.TracebackType,
+ ) -> None:
if value is not None:
assert typ is not None
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
- def __init__(self, no_keep_alive: bool=False, chunk_size: int=None,
- max_header_size: int=None, header_timeout: float=None,
- max_body_size: int=None, body_timeout: float=None,
- decompress: bool=False) -> None:
+
+ def __init__(
+ self,
+ no_keep_alive: bool = False,
+ chunk_size: int = None,
+ max_header_size: int = None,
+ header_timeout: float = None,
+ max_body_size: int = None,
+ body_timeout: float = None,
+ decompress: bool = False,
+ ) -> None:
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
This class can be on its own for clients, or via `HTTP1ServerConnection`
for servers.
"""
- def __init__(self, stream: iostream.IOStream, is_client: bool,
- params: HTTP1ConnectionParameters=None, context: object=None) -> None:
+
+ def __init__(
+ self,
+ stream: iostream.IOStream,
+ is_client: bool,
+ params: HTTP1ConnectionParameters = None,
+ context: object = None,
+ ) -> None:
"""
:arg stream: an `.IOStream`
:arg bool is_client: client or server
self.no_keep_alive = params.no_keep_alive
# The body limits can be altered by the delegate, so save them
# here instead of just referencing self.params later.
- self._max_body_size = (self.params.max_body_size or
- self.stream.max_buffer_size)
+ self._max_body_size = self.params.max_body_size or self.stream.max_buffer_size
self._body_timeout = self.params.body_timeout
# _write_finished is set to True when finish() has been called,
# i.e. there will be no more data sent. Data may still be in the
return self._read_message(delegate)
@gen.coroutine
- def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, bool]:
+ def _read_message(
+ self, delegate: httputil.HTTPMessageDelegate
+ ) -> Generator[Any, Any, bool]:
need_delegate_close = False
try:
header_future = self.stream.read_until_regex(
- b"\r?\n\r?\n",
- max_bytes=self.params.max_header_size)
+ b"\r?\n\r?\n", max_bytes=self.params.max_header_size
+ )
if self.params.header_timeout is None:
header_data = yield header_future
else:
header_data = yield gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
- quiet_exceptions=iostream.StreamClosedError)
+ quiet_exceptions=iostream.StreamClosedError,
+ )
except gen.TimeoutError:
self.close()
return False
if self.is_client:
resp_start_line = httputil.parse_response_start_line(start_line_str)
self._response_start_line = resp_start_line
- start_line = resp_start_line \
- # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine]
+ start_line = (
+ resp_start_line
+ ) # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine]
# TODO: this will need to change to support client-side keepalive
self._disconnect_on_finish = False
else:
self._request_headers = headers
start_line = req_start_line
self._disconnect_on_finish = not self._can_keep_alive(
- req_start_line, headers)
+ req_start_line, headers
+ )
need_delegate_close = True
with _ExceptionLoggingContext(app_log):
header_recv_future = delegate.headers_received(start_line, headers)
skip_body = False
if self.is_client:
assert isinstance(start_line, httputil.ResponseStartLine)
- if (self._request_start_line is not None and
- self._request_start_line.method == 'HEAD'):
+ if (
+ self._request_start_line is not None
+ and self._request_start_line.method == "HEAD"
+ ):
skip_body = True
code = start_line.code
if code == 304:
if code >= 100 and code < 200:
# 1xx responses should never indicate the presence of
# a body.
- if ('Content-Length' in headers or
- 'Transfer-Encoding' in headers):
+ if "Content-Length" in headers or "Transfer-Encoding" in headers:
raise httputil.HTTPInputError(
- "Response code %d cannot have body" % code)
+ "Response code %d cannot have body" % code
+ )
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(delegate)
else:
- if (headers.get("Expect") == "100-continue" and
- not self._write_finished):
+ if headers.get("Expect") == "100-continue" and not self._write_finished:
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(
- resp_start_line.code if self.is_client else 0, headers, delegate)
+ resp_start_line.code if self.is_client else 0, headers, delegate
+ )
if body_future is not None:
if self._body_timeout is None:
yield body_future
yield gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future,
- quiet_exceptions=iostream.StreamClosedError)
+ quiet_exceptions=iostream.StreamClosedError,
+ )
except gen.TimeoutError:
- gen_log.info("Timeout reading body from %s",
- self.context)
+ gen_log.info("Timeout reading body from %s", self.context)
self.stream.close()
return False
self._read_finished = True
# If we're waiting for the application to produce an asynchronous
# response, and we're not detached, register a close callback
# on the stream (we didn't need one while we were reading)
- if (not self._finish_future.done() and
- self.stream is not None and
- not self.stream.closed()):
+ if (
+ not self._finish_future.done()
+ and self.stream is not None
+ and not self.stream.closed()
+ ):
self.stream.set_close_callback(self._on_connection_close)
yield self._finish_future
if self.is_client and self._disconnect_on_finish:
if self.stream is None:
return False
except httputil.HTTPInputError as e:
- gen_log.info("Malformed HTTP message from %s: %s",
- self.context, e)
+ gen_log.info("Malformed HTTP message from %s: %s", self.context, e)
if not self.is_client:
- yield self.stream.write(b'HTTP/1.1 400 Bad Request\r\n\r\n')
+ yield self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
self.close()
return False
finally:
"""
self._max_body_size = max_body_size
- def write_headers(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders, chunk: bytes=None) -> 'Future[None]':
+ def write_headers(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ chunk: bytes = None,
+ ) -> "Future[None]":
"""Implements `.HTTPConnection.write_headers`."""
lines = []
if self.is_client:
assert isinstance(start_line, httputil.RequestStartLine)
self._request_start_line = start_line
- lines.append(utf8('%s %s HTTP/1.1' % (start_line[0], start_line[1])))
+ lines.append(utf8("%s %s HTTP/1.1" % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
- start_line.method in ('POST', 'PUT', 'PATCH') and
- 'Content-Length' not in headers and
- 'Transfer-Encoding' not in headers)
+ start_line.method in ("POST", "PUT", "PATCH")
+ and "Content-Length" not in headers
+ and "Transfer-Encoding" not in headers
+ )
else:
assert isinstance(start_line, httputil.ResponseStartLine)
assert self._request_start_line is not None
assert self._request_headers is not None
self._response_start_line = start_line
- lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
+ lines.append(utf8("HTTP/1.1 %d %s" % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
# start_line.version?
- self._request_start_line.version == 'HTTP/1.1' and
+ self._request_start_line.version == "HTTP/1.1"
+ and
# 1xx, 204 and 304 responses have no body (not even a zero-length
# body), and so should not have either Content-Length or
# Transfer-Encoding headers.
- start_line.code not in (204, 304) and
- (start_line.code < 100 or start_line.code >= 200) and
+ start_line.code not in (204, 304)
+ and (start_line.code < 100 or start_line.code >= 200)
+ and
# No need to chunk the output if a Content-Length is specified.
- 'Content-Length' not in headers and
+ "Content-Length" not in headers
+ and
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
- 'Transfer-Encoding' not in headers)
+ "Transfer-Encoding" not in headers
+ )
# If connection to a 1.1 client will be closed, inform client
- if (self._request_start_line.version == 'HTTP/1.1' and self._disconnect_on_finish):
- headers['Connection'] = 'close'
+ if (
+ self._request_start_line.version == "HTTP/1.1"
+ and self._disconnect_on_finish
+ ):
+ headers["Connection"] = "close"
# If a 1.0 client asked for keep-alive, add the header.
- if (self._request_start_line.version == 'HTTP/1.0' and
- self._request_headers.get('Connection', '').lower() == 'keep-alive'):
- headers['Connection'] = 'Keep-Alive'
+ if (
+ self._request_start_line.version == "HTTP/1.0"
+ and self._request_headers.get("Connection", "").lower() == "keep-alive"
+ ):
+ headers["Connection"] = "Keep-Alive"
if self._chunking_output:
- headers['Transfer-Encoding'] = 'chunked'
- if (not self.is_client and
- (self._request_start_line.method == 'HEAD' or
- cast(httputil.ResponseStartLine, start_line).code == 304)):
+ headers["Transfer-Encoding"] = "chunked"
+ if not self.is_client and (
+ self._request_start_line.method == "HEAD"
+ or cast(httputil.ResponseStartLine, start_line).code == 304
+ ):
self._expected_content_remaining = 0
- elif 'Content-Length' in headers:
- self._expected_content_remaining = int(headers['Content-Length'])
+ elif "Content-Length" in headers:
+ self._expected_content_remaining = int(headers["Content-Length"])
else:
self._expected_content_remaining = None
# TODO: headers are supposed to be of type str, but we still have some
# cases that let bytes slip through. Remove these native_str calls when those
# are fixed.
- header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
- lines.extend(l.encode('latin1') for l in header_lines)
+ header_lines = (
+ native_str(n) + ": " + native_str(v) for n, v in headers.get_all()
+ )
+ lines.extend(l.encode("latin1") for l in header_lines)
for line in lines:
- if b'\n' in line:
- raise ValueError('Newline in header: ' + repr(line))
+ if b"\n" in line:
+ raise ValueError("Newline in header: " + repr(line))
future = None
if self.stream.closed():
future = self._write_future = Future()
# Close the stream now to stop further framing errors.
self.stream.close()
raise httputil.HTTPOutputError(
- "Tried to write more data than Content-Length")
+ "Tried to write more data than Content-Length"
+ )
if self._chunking_output and chunk:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
else:
return chunk
- def write(self, chunk: bytes) -> 'Future[None]':
+ def write(self, chunk: bytes) -> "Future[None]":
"""Implements `.HTTPConnection.write`.
For backwards compatibility it is allowed but deprecated to
def finish(self) -> None:
"""Implements `.HTTPConnection.finish`."""
- if (self._expected_content_remaining is not None and
- self._expected_content_remaining != 0 and
- not self.stream.closed()):
+ if (
+ self._expected_content_remaining is not None
+ and self._expected_content_remaining != 0
+ and not self.stream.closed()
+ ):
self.stream.close()
raise httputil.HTTPOutputError(
- "Tried to write %d bytes less than Content-Length" %
- self._expected_content_remaining)
+ "Tried to write %d bytes less than Content-Length"
+ % self._expected_content_remaining
+ )
if self._chunking_output:
if not self.stream.closed():
self._pending_write = self.stream.write(b"0\r\n\r\n")
else:
future_add_done_callback(self._pending_write, self._finish_request)
- def _on_write_complete(self, future: 'Future[None]') -> None:
+ def _on_write_complete(self, future: "Future[None]") -> None:
exc = future.exception()
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
future.result()
self._write_future = None
future_set_result_unless_cancelled(future, None)
- def _can_keep_alive(self, start_line: httputil.RequestStartLine,
- headers: httputil.HTTPHeaders) -> bool:
+ def _can_keep_alive(
+ self, start_line: httputil.RequestStartLine, headers: httputil.HTTPHeaders
+ ) -> bool:
if self.params.no_keep_alive:
return False
connection_header = headers.get("Connection")
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
- elif ("Content-Length" in headers or
- headers.get("Transfer-Encoding", "").lower() == "chunked" or
- getattr(start_line, 'method', None) in ("HEAD", "GET")):
+ elif (
+ "Content-Length" in headers
+ or headers.get("Transfer-Encoding", "").lower() == "chunked"
+ or getattr(start_line, "method", None) in ("HEAD", "GET")
+ ):
# start_line may be a request or response start line; only
# the former has a method attribute.
return connection_header == "keep-alive"
return False
- def _finish_request(self, future: Optional['Future[None]']) -> None:
+ def _finish_request(self, future: Optional["Future[None]"]) -> None:
self._clear_callbacks()
if not self.is_client and self._disconnect_on_finish:
self.close()
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
- data_str = native_str(data.decode('latin1')).lstrip("\r\n")
+ data_str = native_str(data.decode("latin1")).lstrip("\r\n")
# RFC 7230 section allows for both CRLF and bare LF.
eol = data_str.find("\n")
start_line = data_str[:eol].rstrip("\r")
headers = httputil.HTTPHeaders.parse(data_str[eol:])
return start_line, headers
- def _read_body(self, code: int, headers: httputil.HTTPHeaders,
- delegate: httputil.HTTPMessageDelegate) -> Optional[Awaitable[None]]:
+ def _read_body(
+ self,
+ code: int,
+ headers: httputil.HTTPHeaders,
+ delegate: httputil.HTTPMessageDelegate,
+ ) -> Optional[Awaitable[None]]:
if "Content-Length" in headers:
if "Transfer-Encoding" in headers:
# Response cannot contain both Content-Length and
# Transfer-Encoding headers.
# http://tools.ietf.org/html/rfc7230#section-3.3.3
raise httputil.HTTPInputError(
- "Response with both Transfer-Encoding and Content-Length")
+ "Response with both Transfer-Encoding and Content-Length"
+ )
if "," in headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
- pieces = re.split(r',\s*', headers["Content-Length"])
+ pieces = re.split(r",\s*", headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
- "Multiple unequal Content-Lengths: %r" %
- headers["Content-Length"])
+ "Multiple unequal Content-Lengths: %r"
+ % headers["Content-Length"]
+ )
headers["Content-Length"] = pieces[0]
try:
except ValueError:
# Handles non-integer Content-Length value.
raise httputil.HTTPInputError(
- "Only integer Content-Length is allowed: %s" % headers["Content-Length"])
+ "Only integer Content-Length is allowed: %s"
+ % headers["Content-Length"]
+ )
if cast(int, content_length) > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
- if ("Transfer-Encoding" in headers or
- content_length not in (None, 0)):
+ if "Transfer-Encoding" in headers or content_length not in (None, 0):
raise httputil.HTTPInputError(
- "Response with code %d should not have body" % code)
+ "Response with code %d should not have body" % code
+ )
content_length = 0
if content_length is not None:
return None
@gen.coroutine
- def _read_fixed_body(self, content_length: int,
- delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, None]:
+ def _read_fixed_body(
+ self, content_length: int, delegate: httputil.HTTPMessageDelegate
+ ) -> Generator[Any, Any, None]:
while content_length > 0:
body = yield self.stream.read_bytes(
- min(self.params.chunk_size, content_length), partial=True)
+ min(self.params.chunk_size, content_length), partial=True
+ )
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
@gen.coroutine
def _read_chunked_body(
- self, delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, None]:
+ self, delegate: httputil.HTTPMessageDelegate
+ ) -> Generator[Any, Any, None]:
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
total_size = 0
while True:
chunk_len = int(chunk_len.strip(), 16)
if chunk_len == 0:
crlf = yield self.stream.read_bytes(2)
- if crlf != b'\r\n':
- raise httputil.HTTPInputError("improperly terminated chunked request")
+ if crlf != b"\r\n":
+ raise httputil.HTTPInputError(
+ "improperly terminated chunked request"
+ )
return
total_size += chunk_len
if total_size > self._max_body_size:
bytes_to_read = chunk_len
while bytes_to_read:
chunk = yield self.stream.read_bytes(
- min(bytes_to_read, self.params.chunk_size), partial=True)
+ min(bytes_to_read, self.params.chunk_size), partial=True
+ )
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
@gen.coroutine
def _read_body_until_close(
- self, delegate: httputil.HTTPMessageDelegate) -> Generator[Any, Any, None]:
+ self, delegate: httputil.HTTPMessageDelegate
+ ) -> Generator[Any, Any, None]:
body = yield self.stream.read_until_close()
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
"""Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
"""
+
def __init__(self, delegate: httputil.HTTPMessageDelegate, chunk_size: int) -> None:
self._delegate = delegate
self._chunk_size = chunk_size
self._decompressor = None # type: Optional[GzipDecompressor]
- def headers_received(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]:
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
if headers.get("Content-Encoding") == "gzip":
self._decompressor = GzipDecompressor()
# Downstream delegates will only see uncompressed data,
# so rename the content-encoding header.
# (but note that curl_httpclient doesn't do this).
- headers.add("X-Consumed-Content-Encoding",
- headers["Content-Encoding"])
+ headers.add("X-Consumed-Content-Encoding", headers["Content-Encoding"])
del headers["Content-Encoding"]
return self._delegate.headers_received(start_line, headers)
compressed_data = chunk
while compressed_data:
decompressed = self._decompressor.decompress(
- compressed_data, self._chunk_size)
+ compressed_data, self._chunk_size
+ )
if decompressed:
ret = self._delegate.data_received(decompressed)
if ret is not None:
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""
- def __init__(self, stream: iostream.IOStream,
- params: HTTP1ConnectionParameters=None, context: object=None) -> None:
+
+ def __init__(
+ self,
+ stream: iostream.IOStream,
+ params: HTTP1ConnectionParameters = None,
+ context: object = None,
+ ) -> None:
"""
:arg stream: an `.IOStream`
:arg params: a `.HTTP1ConnectionParameters` or None
@gen.coroutine
def _server_request_loop(
- self, delegate: httputil.HTTPServerConnectionDelegate) -> Generator[Any, Any, None]:
+ self, delegate: httputil.HTTPServerConnectionDelegate
+ ) -> Generator[Any, Any, None]:
try:
while True:
- conn = HTTP1Connection(self.stream, False,
- self.params, self.context)
+ conn = HTTP1Connection(self.stream, False, self.params, self.context)
request_delegate = delegate.start_request(self, conn)
try:
ret = yield conn.read_response(request_delegate)
- except (iostream.StreamClosedError,
- iostream.UnsatisfiableReadError):
+ except (iostream.StreamClosedError, iostream.UnsatisfiableReadError):
return
except _QuietException:
# This exception was already logged.
Use `AsyncHTTPClient` instead.
"""
- def __init__(self, async_client_class: Type['AsyncHTTPClient']=None, **kwargs: Any) -> None:
+
+ def __init__(
+ self, async_client_class: Type["AsyncHTTPClient"] = None, **kwargs: Any
+ ) -> None:
# Initialize self._closed at the beginning of the constructor
# so that an exception raised here doesn't lead to confusing
# failures in __del__.
# Create the client while our IOLoop is "current", without
# clobbering the thread's real current IOLoop (if any).
- async def make_client() -> 'AsyncHTTPClient':
+ async def make_client() -> "AsyncHTTPClient":
await gen.sleep(0)
assert async_client_class is not None
return async_client_class(**kwargs)
+
self._async_client = self._io_loop.run_sync(make_client)
self._closed = False
self._io_loop.close()
self._closed = True
- def fetch(self, request: Union['HTTPRequest', str], **kwargs: Any) -> 'HTTPResponse':
+ def fetch(
+ self, request: Union["HTTPRequest", str], **kwargs: Any
+ ) -> "HTTPResponse":
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If an error occurs during the fetch, we raise an `HTTPError` unless
the ``raise_error`` keyword argument is set to False.
"""
- response = self._io_loop.run_sync(functools.partial(
- self._async_client.fetch, request, **kwargs))
+ response = self._io_loop.run_sync(
+ functools.partial(self._async_client.fetch, request, **kwargs)
+ )
return response
@classmethod
def configurable_default(cls) -> Type[Configurable]:
from tornado.simple_httpclient import SimpleAsyncHTTPClient
+
return SimpleAsyncHTTPClient
@classmethod
- def _async_clients(cls) -> Dict[IOLoop, 'AsyncHTTPClient']:
- attr_name = '_async_client_dict_' + cls.__name__
+ def _async_clients(cls) -> Dict[IOLoop, "AsyncHTTPClient"]:
+ attr_name = "_async_client_dict_" + cls.__name__
if not hasattr(cls, attr_name):
setattr(cls, attr_name, weakref.WeakKeyDictionary())
return getattr(cls, attr_name)
- def __new__(cls, force_instance: bool=False, **kwargs: Any) -> 'AsyncHTTPClient':
+ def __new__(cls, force_instance: bool = False, **kwargs: Any) -> "AsyncHTTPClient":
io_loop = IOLoop.current()
if force_instance:
instance_cache = None
instance_cache[instance.io_loop] = instance
return instance
- def initialize(self, defaults: Dict[str, Any]=None) -> None:
+ def initialize(self, defaults: Dict[str, Any] = None) -> None:
self.io_loop = IOLoop.current()
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
raise RuntimeError("inconsistent AsyncHTTPClient cache")
del self._instance_cache[self.io_loop]
- def fetch(self, request: Union[str, 'HTTPRequest'],
- raise_error: bool=True, **kwargs: Any) -> 'Future[HTTPResponse]':
+ def fetch(
+ self,
+ request: Union[str, "HTTPRequest"],
+ raise_error: bool = True,
+ **kwargs: Any
+ ) -> "Future[HTTPResponse]":
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
request = HTTPRequest(url=request, **kwargs)
else:
if kwargs:
- raise ValueError("kwargs can't be used if request is an HTTPRequest object")
+ raise ValueError(
+ "kwargs can't be used if request is an HTTPRequest object"
+ )
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request_proxy = _RequestProxy(request, self.defaults)
future = Future() # type: Future[HTTPResponse]
- def handle_response(response: 'HTTPResponse') -> None:
+ def handle_response(response: "HTTPResponse") -> None:
if response.error:
if raise_error or not response._error_is_response_code:
future.set_exception(response.error)
return
future_set_result_unless_cancelled(future, response)
+
self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response)
return future
- def fetch_impl(self, request: 'HTTPRequest',
- callback: Callable[['HTTPResponse'], None]) -> None:
+ def fetch_impl(
+ self, request: "HTTPRequest", callback: Callable[["HTTPResponse"], None]
+ ) -> None:
raise NotImplementedError()
@classmethod
- def configure(cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any) -> None:
+ def configure(
+ cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any
+ ) -> None:
"""Configures the `AsyncHTTPClient` subclass to use.
``AsyncHTTPClient()`` actually creates an instance of a subclass.
class HTTPRequest(object):
"""HTTP client request object."""
+
_headers = None # type: Union[Dict[str, str], httputil.HTTPHeaders]
# Default values for HTTPRequest parameters.
follow_redirects=True,
max_redirects=5,
decompress_response=True,
- proxy_password='',
+ proxy_password="",
allow_nonstandard_methods=False,
- validate_cert=True)
-
- def __init__(self, url: str, method: str="GET",
- headers: Union[Dict[str, str], httputil.HTTPHeaders]=None,
- body: Union[bytes, str]=None,
- auth_username: str=None, auth_password: str=None, auth_mode: str=None,
- connect_timeout: float=None, request_timeout: float=None,
- if_modified_since: Union[float, datetime.datetime]=None,
- follow_redirects: bool=None,
- max_redirects: int=None, user_agent: str=None, use_gzip: bool=None,
- network_interface: str=None,
- streaming_callback: Callable[[bytes], None]=None,
- header_callback: Callable[[str], None]=None,
- prepare_curl_callback: Callable[[Any], None]=None,
- proxy_host: str=None, proxy_port: int=None, proxy_username: str=None,
- proxy_password: str=None, proxy_auth_mode: str=None,
- allow_nonstandard_methods: bool=None, validate_cert: bool=None,
- ca_certs: str=None, allow_ipv6: bool=None, client_key: str=None,
- client_cert: str=None,
- body_producer: Callable[[Callable[[bytes], None]], 'Future[None]']=None,
- expect_100_continue: bool=False, decompress_response: bool=None,
- ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None) -> None:
+ validate_cert=True,
+ )
+
+ def __init__(
+ self,
+ url: str,
+ method: str = "GET",
+ headers: Union[Dict[str, str], httputil.HTTPHeaders] = None,
+ body: Union[bytes, str] = None,
+ auth_username: str = None,
+ auth_password: str = None,
+ auth_mode: str = None,
+ connect_timeout: float = None,
+ request_timeout: float = None,
+ if_modified_since: Union[float, datetime.datetime] = None,
+ follow_redirects: bool = None,
+ max_redirects: int = None,
+ user_agent: str = None,
+ use_gzip: bool = None,
+ network_interface: str = None,
+ streaming_callback: Callable[[bytes], None] = None,
+ header_callback: Callable[[str], None] = None,
+ prepare_curl_callback: Callable[[Any], None] = None,
+ proxy_host: str = None,
+ proxy_port: int = None,
+ proxy_username: str = None,
+ proxy_password: str = None,
+ proxy_auth_mode: str = None,
+ allow_nonstandard_methods: bool = None,
+ validate_cert: bool = None,
+ ca_certs: str = None,
+ allow_ipv6: bool = None,
+ client_key: str = None,
+ client_cert: str = None,
+ body_producer: Callable[[Callable[[bytes], None]], "Future[None]"] = None,
+ expect_100_continue: bool = False,
+ decompress_response: bool = None,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
+ ) -> None:
r"""All parameters except ``url`` are optional.
:arg str url: URL to fetch
self.headers = headers
if if_modified_since:
self.headers["If-Modified-Since"] = httputil.format_timestamp(
- if_modified_since)
+ if_modified_since
+ )
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_username = proxy_username
is excluded in both implementations. ``request_time`` is now more accurate for
``curl_httpclient`` because it uses a monotonic clock when available.
"""
+
# I'm not sure why these don't get type-inferred from the references in __init__.
error = None # type: Optional[BaseException]
_error_is_response_code = False
request = None # type: HTTPRequest
- def __init__(self, request: HTTPRequest, code: int,
- headers: httputil.HTTPHeaders=None, buffer: BytesIO=None,
- effective_url: str=None, error: BaseException=None,
- request_time: float=None, time_info: Dict[str, float]=None,
- reason: str=None, start_time: float=None) -> None:
+ def __init__(
+ self,
+ request: HTTPRequest,
+ code: int,
+ headers: httputil.HTTPHeaders = None,
+ buffer: BytesIO = None,
+ effective_url: str = None,
+ error: BaseException = None,
+ request_time: float = None,
+ time_info: Dict[str, float] = None,
+ reason: str = None,
+ start_time: float = None,
+ ) -> None:
if isinstance(request, _RequestProxy):
self.request = request.request
else:
if error is None:
if self.code < 200 or self.code >= 300:
self._error_is_response_code = True
- self.error = HTTPError(self.code, message=self.reason,
- response=self)
+ self.error = HTTPError(self.code, message=self.reason, response=self)
else:
self.error = None
else:
`tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains
as an alias.
"""
- def __init__(self, code: int, message: str=None, response: HTTPResponse=None) -> None:
+
+ def __init__(
+ self, code: int, message: str = None, response: HTTPResponse = None
+ ) -> None:
self.code = code
self.message = message or httputil.responses.get(code, "Unknown")
self.response = response
Used internally by AsyncHTTPClient implementations.
"""
- def __init__(self, request: HTTPRequest, defaults: Optional[Dict[str, Any]]) -> None:
+
+ def __init__(
+ self, request: HTTPRequest, defaults: Optional[Dict[str, Any]]
+ ) -> None:
self.request = request
self.defaults = defaults
def main() -> None:
from tornado.options import define, options, parse_command_line
+
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
client = HTTPClient()
for arg in args:
try:
- response = client.fetch(arg,
- follow_redirects=options.follow_redirects,
- validate_cert=options.validate_cert,
- proxy_host=options.proxy_host,
- proxy_port=options.proxy_port,
- )
+ response = client.fetch(
+ arg,
+ follow_redirects=options.follow_redirects,
+ validate_cert=options.validate_cert,
+ proxy_host=options.proxy_host,
+ proxy_port=options.proxy_port,
+ )
except HTTPError as e:
if e.response is not None:
response = e.response
from tornado.util import Configurable
import typing
-from typing import Union, Any, Dict, Callable, List, Type, Generator, Tuple, Optional, Awaitable
+from typing import (
+ Union,
+ Any,
+ Dict,
+ Callable,
+ List,
+ Type,
+ Generator,
+ Tuple,
+ Optional,
+ Awaitable,
+)
+
if typing.TYPE_CHECKING:
from typing import Set # noqa: F401
-class HTTPServer(TCPServer, Configurable,
- httputil.HTTPServerConnectionDelegate):
+class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
# Ignore args to __init__; real initialization belongs in
# initialize since we're Configurable. (there's something
# completely)
pass
- def initialize(self, # type: ignore
- request_callback: Union[httputil.HTTPServerConnectionDelegate,
- Callable[[httputil.HTTPServerRequest], None]],
- no_keep_alive: bool=False,
- xheaders: bool=False,
- ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
- protocol: str=None,
- decompress_request: bool=False,
- chunk_size: int=None,
- max_header_size: int=None,
- idle_connection_timeout: float=None,
- body_timeout: float=None,
- max_body_size: int=None,
- max_buffer_size: int=None,
- trusted_downstream: List[str]=None) -> None:
+ def initialize( # type: ignore
+ self,
+ request_callback: Union[
+ httputil.HTTPServerConnectionDelegate,
+ Callable[[httputil.HTTPServerRequest], None],
+ ],
+ no_keep_alive: bool = False,
+ xheaders: bool = False,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
+ protocol: str = None,
+ decompress_request: bool = False,
+ chunk_size: int = None,
+ max_header_size: int = None,
+ idle_connection_timeout: float = None,
+ body_timeout: float = None,
+ max_body_size: int = None,
+ max_buffer_size: int = None,
+ trusted_downstream: List[str] = None,
+ ) -> None:
self.request_callback = request_callback
self.xheaders = xheaders
self.protocol = protocol
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout,
- no_keep_alive=no_keep_alive)
- TCPServer.__init__(self, ssl_options=ssl_options,
- max_buffer_size=max_buffer_size,
- read_chunk_size=chunk_size)
+ no_keep_alive=no_keep_alive,
+ )
+ TCPServer.__init__(
+ self,
+ ssl_options=ssl_options,
+ max_buffer_size=max_buffer_size,
+ read_chunk_size=chunk_size,
+ )
self._connections = set() # type: Set[HTTP1ServerConnection]
self.trusted_downstream = trusted_downstream
yield conn.close()
def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None:
- context = _HTTPRequestContext(stream, address,
- self.protocol,
- self.trusted_downstream)
- conn = HTTP1ServerConnection(
- stream, self.conn_params, context)
+ context = _HTTPRequestContext(
+ stream, address, self.protocol, self.trusted_downstream
+ )
+ conn = HTTP1ServerConnection(stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
- def start_request(self, server_conn: object,
- request_conn: httputil.HTTPConnection) -> httputil.HTTPMessageDelegate:
+ def start_request(
+ self, server_conn: object, request_conn: httputil.HTTPConnection
+ ) -> httputil.HTTPMessageDelegate:
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
delegate = self.request_callback.start_request(server_conn, request_conn)
else:
class _CallableAdapter(httputil.HTTPMessageDelegate):
- def __init__(self, request_callback: Callable[[httputil.HTTPServerRequest], None],
- request_conn: httputil.HTTPConnection) -> None:
+ def __init__(
+ self,
+ request_callback: Callable[[httputil.HTTPServerRequest], None],
+ request_conn: httputil.HTTPConnection,
+ ) -> None:
self.connection = request_conn
self.request_callback = request_callback
self.request = None # type: Optional[httputil.HTTPServerRequest]
self.delegate = None
self._chunks = [] # type: List[bytes]
- def headers_received(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]:
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
self.request = httputil.HTTPServerRequest(
connection=self.connection,
start_line=typing.cast(httputil.RequestStartLine, start_line),
- headers=headers)
+ headers=headers,
+ )
return None
def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
def finish(self) -> None:
assert self.request is not None
- self.request.body = b''.join(self._chunks)
+ self.request.body = b"".join(self._chunks)
self.request._parse_body()
self.request_callback(self.request)
class _HTTPRequestContext(object):
- def __init__(self, stream: iostream.IOStream, address: Tuple,
- protocol: Optional[str],
- trusted_downstream: List[str]=None) -> None:
+ def __init__(
+ self,
+ stream: iostream.IOStream,
+ address: Tuple,
+ protocol: Optional[str],
+ trusted_downstream: List[str] = None,
+ ) -> None:
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
else:
self.address_family = None
# In HTTPServerRequest we want an IP, not a full socket address.
- if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
- address is not None):
+ if (
+ self.address_family in (socket.AF_INET, socket.AF_INET6)
+ and address is not None
+ ):
self.remote_ip = address[0]
else:
# Unix (or other) socket; fake the remote address.
- self.remote_ip = '0.0.0.0'
+ self.remote_ip = "0.0.0.0"
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
# Skip trusted downstream hosts in X-Forwarded-For list
- for ip in (cand.strip() for cand in reversed(ip.split(','))):
+ for ip in (cand.strip() for cand in reversed(ip.split(","))):
if ip not in self.trusted_downstream:
break
ip = headers.get("X-Real-Ip", ip)
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto_header = headers.get(
- "X-Scheme", headers.get("X-Forwarded-Proto",
- self.protocol))
+ "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol)
+ )
if proto_header:
# use only the last proto entry if there is more than one
# TODO: support trusting mutiple layers of proxied protocol
- proto_header = proto_header.split(',')[-1].strip()
+ proto_header = proto_header.split(",")[-1].strip()
if proto_header in ("http", "https"):
self.protocol = proto_header
class _ProxyAdapter(httputil.HTTPMessageDelegate):
- def __init__(self, delegate: httputil.HTTPMessageDelegate,
- request_conn: httputil.HTTPConnection) -> None:
+ def __init__(
+ self,
+ delegate: httputil.HTTPMessageDelegate,
+ request_conn: httputil.HTTPConnection,
+ ) -> None:
self.connection = request_conn
self.delegate = delegate
- def headers_received(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]:
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
# TODO: either make context an official part of the
# HTTPConnection interface or figure out some other way to do this.
self.connection.context._apply_xheaders(headers) # type: ignore
responses
import typing
-from typing import (Tuple, Iterable, List, Mapping, Iterator, Dict, Union, Optional,
- Awaitable, Generator, AnyStr)
+from typing import (
+ Tuple,
+ Iterable,
+ List,
+ Mapping,
+ Iterator,
+ Dict,
+ Union,
+ Optional,
+ Awaitable,
+ Generator,
+ AnyStr,
+)
if typing.TYPE_CHECKING:
from typing import Deque # noqa
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
# terminator and ignore any preceding CR.
-_CRLF_RE = re.compile(r'\r?\n')
+_CRLF_RE = re.compile(r"\r?\n")
class _NormalizedHeaderCache(dict):
>>> normalized_headers["coNtent-TYPE"]
'Content-Type'
"""
+
def __init__(self, size: int) -> None:
super(_NormalizedHeaderCache, self).__init__()
self.size = size
Set-Cookie: A=B
Set-Cookie: C=D
"""
+
@typing.overload
def __init__(self, __arg: Mapping[str, List[str]]) -> None:
pass
self._dict = {} # type: typing.Dict[str, str]
self._as_list = {} # type: typing.Dict[str, typing.List[str]]
self._last_key = None
- if (len(args) == 1 and len(kwargs) == 0 and
- isinstance(args[0], HTTPHeaders)):
+ if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], HTTPHeaders):
# Copy constructor
for k, v in args[0].get_all():
self.add(k, v)
norm_name = _normalized_headers[name]
self._last_key = norm_name
if norm_name in self:
- self._dict[norm_name] = (native_str(self[norm_name]) + ',' +
- native_str(value))
+ self._dict[norm_name] = (
+ native_str(self[norm_name]) + "," + native_str(value)
+ )
self._as_list[norm_name].append(value)
else:
self[norm_name] = value
# continuation of a multi-line header
if self._last_key is None:
raise HTTPInputError("first header line cannot start with whitespace")
- new_part = ' ' + line.lstrip()
+ new_part = " " + line.lstrip()
self._as_list[self._last_key][-1] += new_part
self._dict[self._last_key] += new_part
else:
self.add(name, value.strip())
@classmethod
- def parse(cls, headers: str) -> 'HTTPHeaders':
+ def parse(cls, headers: str) -> "HTTPHeaders":
"""Returns a dictionary from HTTP header text.
>>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
def __iter__(self) -> Iterator[typing.Any]:
return iter(self._dict)
- def copy(self) -> 'HTTPHeaders':
+ def copy(self) -> "HTTPHeaders":
# defined in dict but not in MutableMapping.
return HTTPHeaders(self)
.. versionchanged:: 4.0
Moved from ``tornado.httpserver.HTTPRequest``.
"""
+
path = None # type: str
query = None # type: str
# HACK: Used for stream_request_body
_body_future = None # type: Future[None]
- def __init__(self, method: str=None, uri: str=None, version: str="HTTP/1.0",
- headers: HTTPHeaders=None, body: bytes=None, host: str=None,
- files: Dict[str, List['HTTPFile']]=None, connection: 'HTTPConnection'=None,
- start_line: 'RequestStartLine'=None, server_connection: object=None) -> None:
+ def __init__(
+ self,
+ method: str = None,
+ uri: str = None,
+ version: str = "HTTP/1.0",
+ headers: HTTPHeaders = None,
+ body: bytes = None,
+ host: str = None,
+ files: Dict[str, List["HTTPFile"]] = None,
+ connection: "HTTPConnection" = None,
+ start_line: "RequestStartLine" = None,
+ server_connection: object = None,
+ ) -> None:
if start_line is not None:
method, uri, version = start_line
self.method = method
self.body = body or b""
# set remote IP and protocol
- context = getattr(connection, 'context', None)
- self.remote_ip = getattr(context, 'remote_ip', None)
- self.protocol = getattr(context, 'protocol', "http")
+ context = getattr(connection, "context", None)
+ self.remote_ip = getattr(context, "remote_ip", None)
+ self.protocol = getattr(context, "protocol", "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.host_name = split_host_and_port(self.host.lower())[0]
self._finish_time = None
if uri is not None:
- self.path, sep, self.query = uri.partition('?')
+ self.path, sep, self.query = uri.partition("?")
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.body_arguments = {} # type: Dict[str, List[bytes]]
else:
return self._finish_time - self._start_time
- def get_ssl_certificate(self, binary_form: bool=False) -> Union[None, Dict, bytes]:
+ def get_ssl_certificate(
+ self, binary_form: bool = False
+ ) -> Union[None, Dict, bytes]:
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer's
return None
# TODO: add a method to HTTPConnection for this so it can work with HTTP/2
return self.connection.stream.socket.getpeercert( # type: ignore
- binary_form=binary_form)
+ binary_form=binary_form
+ )
except SSLError:
return None
def _parse_body(self) -> None:
parse_body_arguments(
- self.headers.get("Content-Type", ""), self.body,
- self.body_arguments, self.files,
- self.headers)
+ self.headers.get("Content-Type", ""),
+ self.body,
+ self.body_arguments,
+ self.files,
+ self.headers,
+ )
for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)
.. versionadded:: 4.0
"""
+
pass
.. versionadded:: 4.0
"""
+
pass
.. versionadded:: 4.0
"""
- def start_request(self, server_conn: object,
- request_conn: 'HTTPConnection')-> 'HTTPMessageDelegate':
+
+ def start_request(
+ self, server_conn: object, request_conn: "HTTPConnection"
+ ) -> "HTTPMessageDelegate":
"""This method is called by the server when a new request has started.
:arg server_conn: is an opaque object representing the long-lived
.. versionadded:: 4.0
"""
+
# TODO: genericize this class to avoid exposing the Union.
- def headers_received(self, start_line: Union['RequestStartLine', 'ResponseStartLine'],
- headers: HTTPHeaders) -> Optional[Awaitable[None]]:
+ def headers_received(
+ self,
+ start_line: Union["RequestStartLine", "ResponseStartLine"],
+ headers: HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
"""Called when the HTTP headers have been received and parsed.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`
.. versionadded:: 4.0
"""
- def write_headers(self, start_line: Union['RequestStartLine', 'ResponseStartLine'],
- headers: HTTPHeaders, chunk: bytes=None) -> 'Future[None]':
+
+ def write_headers(
+ self,
+ start_line: Union["RequestStartLine", "ResponseStartLine"],
+ headers: HTTPHeaders,
+ chunk: bytes = None,
+ ) -> "Future[None]":
"""Write an HTTP header block.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`.
"""
raise NotImplementedError()
- def write(self, chunk: bytes) -> 'Future[None]':
+ def write(self, chunk: bytes) -> "Future[None]":
"""Writes a chunk of body data.
Returns a future for flow control.
raise NotImplementedError()
-def url_concat(url: str, args: Union[None, Dict[str, str], List[Tuple[str, str]],
- Tuple[Tuple[str, str], ...]]) -> str:
+def url_concat(
+ url: str,
+ args: Union[
+ None, Dict[str, str], List[Tuple[str, str]], Tuple[Tuple[str, str], ...]
+ ],
+) -> str:
"""Concatenate url and arguments regardless of whether
url has existing query parameters.
parsed_query.extend(args)
else:
err = "'args' parameter should be dict, list or tuple. Not {0}".format(
- type(args))
+ type(args)
+ )
raise TypeError(err)
final_query = urlencode(parsed_query)
- url = urlunparse((
- parsed_url[0],
- parsed_url[1],
- parsed_url[2],
- parsed_url[3],
- final_query,
- parsed_url[5]))
+ url = urlunparse(
+ (
+ parsed_url[0],
+ parsed_url[1],
+ parsed_url[2],
+ parsed_url[3],
+ final_query,
+ parsed_url[5],
+ )
+ )
return url
* ``body``
* ``content_type``
"""
+
pass
-def _parse_request_range(range_header: str) -> Optional[Tuple[Optional[int], Optional[int]]]:
+def _parse_request_range(
+ range_header: str
+) -> Optional[Tuple[Optional[int], Optional[int]]]:
"""Parses a Range header.
Returns either ``None`` or tuple ``(start, end)``.
return int(val)
-def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, List[bytes]],
- files: Dict[str, List[HTTPFile]], headers: HTTPHeaders=None) -> None:
+def parse_body_arguments(
+ content_type: str,
+ body: bytes,
+ arguments: Dict[str, List[bytes]],
+ files: Dict[str, List[HTTPFile]],
+ headers: HTTPHeaders = None,
+) -> None:
"""Parses a form request body.
Supports ``application/x-www-form-urlencoded`` and
and ``files`` parameters are dictionaries that will be updated
with the parsed contents.
"""
- if headers and 'Content-Encoding' in headers:
- gen_log.warning("Unsupported Content-Encoding: %s",
- headers['Content-Encoding'])
+ if headers and "Content-Encoding" in headers:
+ gen_log.warning("Unsupported Content-Encoding: %s", headers["Content-Encoding"])
return
if content_type.startswith("application/x-www-form-urlencoded"):
try:
uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True)
except Exception as e:
- gen_log.warning('Invalid x-www-form-urlencoded body: %s', e)
+ gen_log.warning("Invalid x-www-form-urlencoded body: %s", e)
uri_arguments = {}
for name, values in uri_arguments.items():
if values:
gen_log.warning("Invalid multipart/form-data: %s", e)
-def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str, List[bytes]],
- files: Dict[str, List[HTTPFile]]) -> None:
+def parse_multipart_form_data(
+ boundary: bytes,
+ data: bytes,
+ arguments: Dict[str, List[bytes]],
+ files: Dict[str, List[HTTPFile]],
+) -> None:
"""Parses a ``multipart/form-data`` body.
The ``boundary`` and ``data`` parameters are both byte strings.
if disposition != "form-data" or not part.endswith(b"\r\n"):
gen_log.warning("Invalid multipart/form-data")
continue
- value = part[eoh + 4:-2]
+ value = part[eoh + 4 : -2]
if not disp_params.get("name"):
gen_log.warning("multipart/form-data value missing name")
continue
name = disp_params["name"]
if disp_params.get("filename"):
ctype = headers.get("Content-Type", "application/unknown")
- files.setdefault(name, []).append(HTTPFile(
- filename=disp_params["filename"], body=value,
- content_type=ctype))
+ files.setdefault(name, []).append(
+ HTTPFile(
+ filename=disp_params["filename"], body=value, content_type=ctype
+ )
+ )
else:
arguments.setdefault(name, []).append(value)
-def format_timestamp(ts: Union[int, float, tuple, time.struct_time, datetime.datetime]) -> str:
+def format_timestamp(
+ ts: Union[int, float, tuple, time.struct_time, datetime.datetime]
+) -> str:
"""Formats a timestamp in the format used by HTTP.
The argument may be a numeric timestamp as returned by `time.time`,
RequestStartLine = collections.namedtuple(
- 'RequestStartLine', ['method', 'path', 'version'])
+ "RequestStartLine", ["method", "path", "version"]
+)
def parse_request_start_line(line: str) -> RequestStartLine:
raise HTTPInputError("Malformed HTTP request line")
if not re.match(r"^HTTP/1\.[0-9]$", version):
raise HTTPInputError(
- "Malformed HTTP version in HTTP Request-Line: %r" % version)
+ "Malformed HTTP version in HTTP Request-Line: %r" % version
+ )
return RequestStartLine(method, path, version)
ResponseStartLine = collections.namedtuple(
- 'ResponseStartLine', ['version', 'code', 'reason'])
+ "ResponseStartLine", ["version", "code", "reason"]
+)
def parse_response_start_line(line: str) -> ResponseStartLine:
match = re.match("(HTTP/1.[0-9]) ([0-9]+) ([^\r]*)", line)
if not match:
raise HTTPInputError("Error parsing response start line")
- return ResponseStartLine(match.group(1), int(match.group(2)),
- match.group(3))
+ return ResponseStartLine(match.group(1), int(match.group(2)), match.group(3))
+
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
def _parseparam(s: str) -> Generator[str, None, None]:
- while s[:1] == ';':
+ while s[:1] == ";":
s = s[1:]
- end = s.find(';')
+ end = s.find(";")
while end > 0 and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2:
- end = s.find(';', end + 1)
+ end = s.find(";", end + 1)
if end < 0:
end = len(s)
f = s[:end]
>>> d['foo']
'b\\a"r'
"""
- parts = _parseparam(';' + line)
+ parts = _parseparam(";" + line)
key = next(parts)
# decode_params treats first argument special, but we already stripped key
- params = [('Dummy', 'value')]
+ params = [("Dummy", "value")]
for p in parts:
- i = p.find('=')
+ i = p.find("=")
if i >= 0:
name = p[:i].strip().lower()
- value = p[i + 1:].strip()
+ value = p[i + 1 :].strip()
params.append((name, native_str(value)))
decoded_params = email.utils.decode_params(params)
decoded_params.pop(0) # get rid of the dummy again
out.append(k)
else:
# TODO: quote if necessary.
- out.append('%s=%s' % (k, v))
- return '; '.join(out)
+ out.append("%s=%s" % (k, v))
+ return "; ".join(out)
-def encode_username_password(username: Union[str, bytes], password: Union[str, bytes]) -> bytes:
+def encode_username_password(
+ username: Union[str, bytes], password: Union[str, bytes]
+) -> bytes:
"""Encodes a username/password pair in the format used by HTTP auth.
The return value is a byte string in the form ``username:password``.
.. versionadded:: 5.1
"""
if isinstance(username, unicode_type):
- username = unicodedata.normalize('NFC', username)
+ username = unicodedata.normalize("NFC", username)
if isinstance(password, unicode_type):
- password = unicodedata.normalize('NFC', password)
+ password = unicodedata.normalize("NFC", password)
return utf8(username) + b":" + utf8(password)
def doctests():
# type: () -> unittest.TestSuite
import doctest
+
return doctest.DocTestSuite()
.. versionadded:: 4.1
"""
- match = re.match(r'^(.+):(\d+)$', netloc)
+ match = re.match(r"^(.+):(\d+)$", netloc)
if match:
host = match.group(1)
port = int(match.group(2)) # type: Optional[int]
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
_QuotePatt = re.compile(r"[\\].")
-_nulljoin = ''.join
+_nulljoin = "".join
def _unquote_cookie(s: str) -> str:
while 0 <= i < n:
o_match = _OctalPatt.search(s, i)
q_match = _QuotePatt.search(s, i)
- if not o_match and not q_match: # Neither matched
+ if not o_match and not q_match: # Neither matched
res.append(s[i:])
break
# else:
j = o_match.start(0)
if q_match:
k = q_match.start(0)
- if q_match and (not o_match or k < j): # QuotePatt matched
+ if q_match and (not o_match or k < j): # QuotePatt matched
res.append(s[i:k])
res.append(s[k + 1])
i = k + 2
- else: # OctalPatt matched
+ else: # OctalPatt matched
res.append(s[i:j])
- res.append(chr(int(s[j + 1:j + 4], 8)))
+ res.append(chr(int(s[j + 1 : j + 4], 8)))
i = j + 4
return _nulljoin(res)
.. versionadded:: 4.4.2
"""
cookiedict = {}
- for chunk in cookie.split(str(';')):
- if str('=') in chunk:
- key, val = chunk.split(str('='), 1)
+ for chunk in cookie.split(str(";")):
+ if str("=") in chunk:
+ key, val = chunk.split(str("="), 1)
else:
# Assume an empty name per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
- key, val = str(''), chunk
+ key, val = str(""), chunk
key, val = key.strip(), val.strip()
if key or val:
# unquote using Python's algorithm.
import math
import random
-from tornado.concurrent import Future, is_future, chain_future, future_set_exc_info, future_add_done_callback # noqa: E501
+from tornado.concurrent import (
+ Future,
+ is_future,
+ chain_future,
+ future_set_exc_info,
+ future_add_done_callback,
+) # noqa: E501
from tornado.log import app_log
from tornado.util import Configurable, TimeoutError, import_object
import typing
from typing import Union, Any, Type, Optional, Callable, TypeVar, Tuple, Awaitable
+
if typing.TYPE_CHECKING:
from typing import Dict, List # noqa: F401
pass
-_T = TypeVar('_T')
-_S = TypeVar('_S', bound=_Selectable)
+_T = TypeVar("_T")
+_S = TypeVar("_S", bound=_Selectable)
class IOLoop(Configurable):
to redundantly specify the `asyncio` event loop.
"""
+
# These constants were originally based on constants from the epoll module.
NONE = 0
READ = 0x001
_ioloop_for_asyncio = dict() # type: Dict[asyncio.AbstractEventLoop, IOLoop]
@classmethod
- def configure(cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any) -> None:
+ def configure(
+ cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any
+ ) -> None:
if asyncio is not None:
from tornado.platform.asyncio import BaseAsyncIOLoop
impl = import_object(impl)
if isinstance(impl, type) and not issubclass(impl, BaseAsyncIOLoop):
raise RuntimeError(
- "only AsyncIOLoop is allowed when asyncio is available")
+ "only AsyncIOLoop is allowed when asyncio is available"
+ )
super(IOLoop, cls).configure(impl, **kwargs)
@staticmethod
- def instance() -> 'IOLoop':
+ def instance() -> "IOLoop":
"""Deprecated alias for `IOLoop.current()`.
.. versionchanged:: 5.0
@typing.overload
@staticmethod
- def current() -> 'IOLoop':
+ def current() -> "IOLoop":
pass
@typing.overload # noqa: F811
@staticmethod
- def current(instance: bool=True) -> Optional['IOLoop']:
+ def current(instance: bool = True) -> Optional["IOLoop"]:
pass
@staticmethod # noqa: F811
- def current(instance: bool=True) -> Optional['IOLoop']:
+ def current(instance: bool = True) -> Optional["IOLoop"]:
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as
except KeyError:
if instance:
from tornado.platform.asyncio import AsyncIOMainLoop
+
current = AsyncIOMainLoop(make_current=True) # type: Optional[IOLoop]
else:
current = None
@classmethod
def configurable_default(cls) -> Type[Configurable]:
from tornado.platform.asyncio import AsyncIOLoop
+
return AsyncIOLoop
- def initialize(self, make_current: bool=None) -> None:
+ def initialize(self, make_current: bool = None) -> None:
if make_current is None:
if IOLoop.current(instance=False) is None:
self.make_current()
raise RuntimeError("current IOLoop already exists")
self.make_current()
- def close(self, all_fds: bool=False) -> None:
+ def close(self, all_fds: bool = False) -> None:
"""Closes the `IOLoop`, freeing any resources used.
If ``all_fds`` is true, all file descriptors registered on the
raise NotImplementedError()
@typing.overload
- def add_handler(self, fd: int, handler: Callable[[int, int], None], events: int) -> None:
+ def add_handler(
+ self, fd: int, handler: Callable[[int, int], None], events: int
+ ) -> None:
pass
@typing.overload # noqa: F811
- def add_handler(self, fd: _S, handler: Callable[[_S, int], None], events: int) -> None:
+ def add_handler(
+ self, fd: _S, handler: Callable[[_S, int], None], events: int
+ ) -> None:
pass
- def add_handler(self, fd: Union[int, _Selectable], # noqa: F811
- handler: Callable[..., None], events: int) -> None:
+ def add_handler( # noqa: F811
+ self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int
+ ) -> None:
"""Registers the given handler to receive the given events for ``fd``.
The ``fd`` argument may either be an integer file descriptor or
This method should be called from start() in subclasses.
"""
- if not any([logging.getLogger().handlers,
- logging.getLogger('tornado').handlers,
- logging.getLogger('tornado.application').handlers]):
+ if not any(
+ [
+ logging.getLogger().handlers,
+ logging.getLogger("tornado").handlers,
+ logging.getLogger("tornado.application").handlers,
+ ]
+ ):
logging.basicConfig()
def stop(self) -> None:
"""
raise NotImplementedError()
- def run_sync(self, func: Callable, timeout: float=None) -> Any:
+ def run_sync(self, func: Callable, timeout: float = None) -> Any:
"""Starts the `IOLoop`, runs the given function, and stops the loop.
The function must return either an awaitable object or
result = func()
if result is not None:
from tornado.gen import convert_yielded
+
result = convert_yielded(result)
except Exception:
fut = Future() # type: Future[Any]
fut.set_result(result)
assert future_cell[0] is not None
self.add_future(future_cell[0], lambda future: self.stop())
+
self.add_callback(run)
if timeout is not None:
+
def timeout_callback() -> None:
# If we can cancel the future, do so and wait on it. If not,
# Just stop the loop and return with the task still pending.
assert future_cell[0] is not None
if not future_cell[0].cancel():
self.stop()
+
timeout_handle = self.add_timeout(self.time() + timeout, timeout_callback)
self.start()
if timeout is not None:
self.remove_timeout(timeout_handle)
assert future_cell[0] is not None
if future_cell[0].cancelled() or not future_cell[0].done():
- raise TimeoutError('Operation timed out after %s seconds' % timeout)
+ raise TimeoutError("Operation timed out after %s seconds" % timeout)
return future_cell[0].result()
def time(self) -> float:
"""
return time.time()
- def add_timeout(self, deadline: Union[float, datetime.timedelta],
- callback: Callable[..., None],
- *args: Any, **kwargs: Any) -> object:
+ def add_timeout(
+ self,
+ deadline: Union[float, datetime.timedelta],
+ callback: Callable[..., None],
+ *args: Any,
+ **kwargs: Any
+ ) -> object:
"""Runs the ``callback`` at the time ``deadline`` from the I/O loop.
Returns an opaque handle that may be passed to
if isinstance(deadline, numbers.Real):
return self.call_at(deadline, callback, *args, **kwargs)
elif isinstance(deadline, datetime.timedelta):
- return self.call_at(self.time() + deadline.total_seconds(),
- callback, *args, **kwargs)
+ return self.call_at(
+ self.time() + deadline.total_seconds(), callback, *args, **kwargs
+ )
else:
raise TypeError("Unsupported deadline %r" % deadline)
- def call_later(self, delay: float, callback: Callable[..., None],
- *args: Any, **kwargs: Any) -> object:
+ def call_later(
+ self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+ ) -> object:
"""Runs the ``callback`` after ``delay`` seconds have passed.
Returns an opaque handle that may be passed to `remove_timeout`
"""
return self.call_at(self.time() + delay, callback, *args, **kwargs)
- def call_at(self, when: float, callback: Callable[..., None],
- *args: Any, **kwargs: Any) -> object:
+ def call_at(
+ self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+ ) -> object:
"""Runs the ``callback`` at the absolute time designated by ``when``.
``when`` must be a number using the same reference point as
"""
raise NotImplementedError()
- def add_callback(self, callback: Callable,
- *args: Any, **kwargs: Any) -> None:
+ def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
"""Calls the given callback on the next I/O loop iteration.
It is safe to call this method from any thread at any time,
"""
raise NotImplementedError()
- def add_callback_from_signal(self, callback: Callable,
- *args: Any, **kwargs: Any) -> None:
+ def add_callback_from_signal(
+ self, callback: Callable, *args: Any, **kwargs: Any
+ ) -> None:
"""Calls the given callback on the next I/O loop iteration.
Safe for use from a Python signal handler; should not be used
"""
raise NotImplementedError()
- def spawn_callback(self, callback: Callable,
- *args: Any, **kwargs: Any) -> None:
+ def spawn_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
"""Calls the given callback on the next IOLoop iteration.
As of Tornado 6.0, this method is equivalent to `add_callback`.
"""
self.add_callback(callback, *args, **kwargs)
- def add_future(self, future: Union['Future[_T]', 'concurrent.futures.Future[_T]'],
- callback: Callable[['Future[_T]'], None]) -> None:
+ def add_future(
+ self,
+ future: Union["Future[_T]", "concurrent.futures.Future[_T]"],
+ callback: Callable[["Future[_T]"], None],
+ ) -> None:
"""Schedules a callback on the ``IOLoop`` when the given
`.Future` is finished.
"""
assert is_future(future)
future_add_done_callback(
- future, lambda future: self.add_callback(callback, future))
-
- def run_in_executor(self, executor: Optional[concurrent.futures.Executor],
- func: Callable[..., _T], *args: Any) -> Awaitable[_T]:
+ future, lambda future: self.add_callback(callback, future)
+ )
+
+ def run_in_executor(
+ self,
+ executor: Optional[concurrent.futures.Executor],
+ func: Callable[..., _T],
+ *args: Any
+ ) -> Awaitable[_T]:
"""Runs a function in a ``concurrent.futures.Executor``. If
``executor`` is ``None``, the IO loop's default executor will be used.
.. versionadded:: 5.0
"""
if executor is None:
- if not hasattr(self, '_executor'):
+ if not hasattr(self, "_executor"):
from tornado.process import cpu_count
+
self._executor = concurrent.futures.ThreadPoolExecutor(
- max_workers=(cpu_count() * 5)) # type: concurrent.futures.Executor
+ max_workers=(cpu_count() * 5)
+ ) # type: concurrent.futures.Executor
executor = self._executor
c_future = executor.submit(func, *args)
# Concurrent Futures are not usable with await. Wrap this in a
ret = callback()
if ret is not None:
from tornado import gen
+
# Functions that return Futures typically swallow all
# exceptions and store them in the Future. If a Future
# makes it out to the IOLoop, ensure its exception (if any)
"""Avoid unhandled-exception warnings from spawned coroutines."""
future.result()
- def split_fd(self, fd: Union[int, _Selectable]) -> Tuple[int, Union[int, _Selectable]]:
+ def split_fd(
+ self, fd: Union[int, _Selectable]
+ ) -> Tuple[int, Union[int, _Selectable]]:
"""Returns an (fd, obj) pair from an ``fd`` parameter.
We accept both raw file descriptors and file-like objects as
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
- __slots__ = ['deadline', 'callback', 'tdeadline']
+ __slots__ = ["deadline", "callback", "tdeadline"]
- def __init__(self, deadline: float, callback: Callable[[], None],
- io_loop: IOLoop) -> None:
+ def __init__(
+ self, deadline: float, callback: Callable[[], None], io_loop: IOLoop
+ ) -> None:
if not isinstance(deadline, numbers.Real):
raise TypeError("Unsupported deadline %r" % deadline)
self.deadline = deadline
self.callback = callback
- self.tdeadline = (deadline, next(io_loop._timeout_counter)) # type: Tuple[float, int]
+ self.tdeadline = (
+ deadline,
+ next(io_loop._timeout_counter),
+ ) # type: Tuple[float, int]
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
- def __lt__(self, other: '_Timeout') -> bool:
+ def __lt__(self, other: "_Timeout") -> bool:
return self.tdeadline < other.tdeadline
- def __le__(self, other: '_Timeout') -> bool:
+ def __le__(self, other: "_Timeout") -> bool:
return self.tdeadline <= other.tdeadline
.. versionchanged:: 5.1
The ``jitter`` argument is added.
"""
- def __init__(self, callback: Callable[[], None],
- callback_time: float, jitter: float=0) -> None:
+
+ def __init__(
+ self, callback: Callable[[], None], callback_time: float, jitter: float = 0
+ ) -> None:
self.callback = callback
if callback_time <= 0:
raise ValueError("Periodic callback must have a positive callback_time")
# to the start of the next. If one call takes too long,
# skip cycles to get back to a multiple of the original
# schedule.
- self._next_timeout += (math.floor((current_time - self._next_timeout) /
- callback_time_sec) + 1) * callback_time_sec
+ self._next_timeout += (
+ math.floor((current_time - self._next_timeout) / callback_time_sec) + 1
+ ) * callback_time_sec
else:
# If the clock moved backwards, ensure we advance the next
# timeout instead of recomputing the same value again.
from tornado.util import errno_from_exception
import typing
-from typing import Union, Optional, Awaitable, Callable, Type, Pattern, Any, Dict, TypeVar, Tuple
+from typing import (
+ Union,
+ Optional,
+ Awaitable,
+ Callable,
+ Type,
+ Pattern,
+ Any,
+ Dict,
+ TypeVar,
+ Tuple,
+)
from types import TracebackType
+
if typing.TYPE_CHECKING:
from typing import Deque, List # noqa: F401
-_IOStreamType = TypeVar('_IOStreamType', bound='IOStream')
+_IOStreamType = TypeVar("_IOStreamType", bound="IOStream")
try:
from tornado.platform.posix import _set_nonblocking
# These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors.
-_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
- errno.ETIMEDOUT)
+_ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE, errno.ETIMEDOUT)
if hasattr(errno, "WSAECONNRESET"):
- _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) # type: ignore # noqa: E501
+ _ERRNO_CONNRESET += ( # type: ignore
+ errno.WSAECONNRESET, # type: ignore
+ errno.WSAECONNABORTED, # type: ignore
+ errno.WSAETIMEDOUT, # type: ignore
+ )
-if sys.platform == 'darwin':
+if sys.platform == "darwin":
# OSX appears to have a race condition that causes send(2) to return
# EPROTOTYPE if called while a socket is being torn down:
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
if hasattr(errno, "WSAEINPROGRESS"):
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) # type: ignore
-_WINDOWS = sys.platform.startswith('win')
+_WINDOWS = sys.platform.startswith("win")
class StreamClosedError(IOError):
.. versionchanged:: 4.3
Added the ``real_error`` attribute.
"""
- def __init__(self, real_error: BaseException=None) -> None:
- super(StreamClosedError, self).__init__('Stream is closed')
+
+ def __init__(self, real_error: BaseException = None) -> None:
+ super(StreamClosedError, self).__init__("Stream is closed")
self.real_error = real_error
Raised by ``read_until`` and ``read_until_regex`` with a ``max_bytes``
argument.
"""
+
pass
def __init__(self) -> None:
# A sequence of (False, bytearray) and (True, memoryview) objects
- self._buffers = collections.deque() \
- # type: Deque[Tuple[bool, Union[bytearray, memoryview]]]
+ self._buffers = (
+ collections.deque()
+ ) # type: Deque[Tuple[bool, Union[bytearray, memoryview]]]
# Position in the first buffer
self._first_pos = 0
self._size = 0
try:
is_memview, b = self._buffers[0]
except IndexError:
- return memoryview(b'')
+ return memoryview(b"")
pos = self._first_pos
if is_memview:
- return typing.cast(memoryview, b[pos:pos + size])
+ return typing.cast(memoryview, b[pos : pos + size])
else:
- return memoryview(b)[pos:pos + size]
+ return memoryview(b)[pos : pos + size]
def advance(self, size: int) -> None:
"""
`read_from_fd`, and optionally `get_fd_error`.
"""
- def __init__(self, max_buffer_size: int=None,
- read_chunk_size: int=None, max_write_buffer_size: int=None) -> None:
+
+ def __init__(
+ self,
+ max_buffer_size: int = None,
+ read_chunk_size: int = None,
+ max_write_buffer_size: int = None,
+ ) -> None:
"""`BaseIOStream` constructor.
:arg max_buffer_size: Maximum amount of incoming data to buffer;
self.max_buffer_size = max_buffer_size or 104857600
# A chunk size that is too close to max_buffer_size can cause
# spurious failures.
- self.read_chunk_size = min(read_chunk_size or 65536,
- self.max_buffer_size // 2)
+ self.read_chunk_size = min(read_chunk_size or 65536, self.max_buffer_size // 2)
self.max_write_buffer_size = max_write_buffer_size
self.error = None # type: Optional[BaseException]
self._read_buffer = bytearray()
self._read_partial = False
self._read_until_close = False
self._read_future = None # type: Optional[Future]
- self._write_futures = collections.deque() # type: Deque[Tuple[int, Future[None]]]
+ self._write_futures = (
+ collections.deque()
+ ) # type: Deque[Tuple[int, Future[None]]]
self._close_callback = None # type: Optional[Callable[[], None]]
self._connect_future = None # type: Optional[Future[IOStream]]
# _ssl_connect_future should be defined in SSLIOStream
"""
return None
- def read_until_regex(self, regex: bytes, max_bytes: int=None) -> Awaitable[bytes]:
+ def read_until_regex(self, regex: bytes, max_bytes: int = None) -> Awaitable[bytes]:
"""Asynchronously read until we have matched the given regex.
The result includes the data that matches the regex and anything
raise
return future
- def read_until(self, delimiter: bytes, max_bytes: int=None) -> Awaitable[bytes]:
+ def read_until(self, delimiter: bytes, max_bytes: int = None) -> Awaitable[bytes]:
"""Asynchronously read until we have found the given delimiter.
The result includes all the data read including the delimiter.
raise
return future
- def read_bytes(self, num_bytes: int, partial: bool=False) -> Awaitable[bytes]:
+ def read_bytes(self, num_bytes: int, partial: bool = False) -> Awaitable[bytes]:
"""Asynchronously read a number of bytes.
If ``partial`` is true, data is returned as soon as we have
raise
return future
- def read_into(self, buf: bytearray, partial: bool=False) -> Awaitable[int]:
+ def read_into(self, buf: bytearray, partial: bool = False) -> Awaitable[int]:
"""Asynchronously read a number of bytes.
``buf`` must be a writable buffer into which data will be read.
n = len(buf)
if available_bytes >= n:
end = self._read_buffer_pos + n
- buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos:end]
+ buf[:] = memoryview(self._read_buffer)[self._read_buffer_pos : end]
del self._read_buffer[:end]
self._after_user_read_buffer = self._read_buffer
elif available_bytes > 0:
- buf[:available_bytes] = memoryview(self._read_buffer)[self._read_buffer_pos:]
+ buf[:available_bytes] = memoryview(self._read_buffer)[
+ self._read_buffer_pos :
+ ]
# Set up the supplied buffer as our temporary read buffer.
# The original (if it had any data remaining) has been
raise
return future
- def write(self, data: Union[bytes, memoryview]) -> 'Future[None]':
+ def write(self, data: Union[bytes, memoryview]) -> "Future[None]":
"""Asynchronously write the given data to this stream.
This method returns a `.Future` that resolves (with a result
"""
self._check_closed()
if data:
- if (self.max_write_buffer_size is not None and
- len(self._write_buffer) + len(data) > self.max_write_buffer_size):
+ if (
+ self.max_write_buffer_size is not None
+ and len(self._write_buffer) + len(data) > self.max_write_buffer_size
+ ):
raise StreamBufferFullError("Reached maximum write buffer size")
self._write_buffer.append(data)
self._total_write_index += len(data)
self._close_callback = callback
self._maybe_add_error_listener()
- def close(self, exc_info: Union[None, bool, BaseException,
- Tuple[Optional[Type[BaseException]],
- Optional[BaseException],
- Optional[TracebackType]]]=False) -> None:
+ def close(
+ self,
+ exc_info: Union[
+ None,
+ bool,
+ BaseException,
+ Tuple[
+ Optional[Type[BaseException]],
+ Optional[BaseException],
+ Optional[TracebackType],
+ ],
+ ] = False,
+ ) -> None:
"""Close this stream.
If ``exc_info`` is true, set the ``error`` attribute to the current
# yet anyway, so we don't need to listen in this case.
state |= self.io_loop.READ
if state != self._state:
- assert self._state is not None, \
- "shouldn't happen: _handle_events without self._state"
+ assert (
+ self._state is not None
+ ), "shouldn't happen: _handle_events without self._state"
self._state = state
self.io_loop.update_handler(self.fileno(), self._state)
except UnsatisfiableReadError as e:
gen_log.info("Unsatisfiable read, closing connection: %s" % e)
self.close(exc_info=e)
except Exception as e:
- gen_log.error("Uncaught exception, closing connection.",
- exc_info=True)
+ gen_log.error("Uncaught exception, closing connection.", exc_info=True)
self.close(exc_info=e)
raise
# this loop.
# If we've reached target_bytes, we know we're done.
- if (target_bytes is not None and
- self._read_buffer_size >= target_bytes):
+ if target_bytes is not None and self._read_buffer_size >= target_bytes:
break
# Otherwise, we need to call the more expensive find_read_pos.
while True:
try:
if self._user_read_buffer:
- buf = memoryview(self._read_buffer)[self._read_buffer_size:] \
- # type: Union[memoryview, bytearray]
+ buf = memoryview(self._read_buffer)[
+ self._read_buffer_size :
+ ] # type: Union[memoryview, bytearray]
else:
buf = bytearray(self.read_chunk_size)
bytes_read = self.read_from_fd(buf)
Returns a position in the buffer if the current read can be satisfied,
or None if it cannot.
"""
- if (self._read_bytes is not None and
- (self._read_buffer_size >= self._read_bytes or
- (self._read_partial and self._read_buffer_size > 0))):
+ if self._read_bytes is not None and (
+ self._read_buffer_size >= self._read_bytes
+ or (self._read_partial and self._read_buffer_size > 0)
+ ):
num_bytes = min(self._read_bytes, self._read_buffer_size)
return num_bytes
elif self._read_delimiter is not None:
# since large merges are relatively expensive and get undone in
# _consume().
if self._read_buffer:
- loc = self._read_buffer.find(self._read_delimiter,
- self._read_buffer_pos)
+ loc = self._read_buffer.find(
+ self._read_delimiter, self._read_buffer_pos
+ )
if loc != -1:
loc -= self._read_buffer_pos
delimiter_len = len(self._read_delimiter)
- self._check_max_bytes(self._read_delimiter,
- loc + delimiter_len)
+ self._check_max_bytes(self._read_delimiter, loc + delimiter_len)
return loc + delimiter_len
- self._check_max_bytes(self._read_delimiter,
- self._read_buffer_size)
+ self._check_max_bytes(self._read_delimiter, self._read_buffer_size)
elif self._read_regex is not None:
if self._read_buffer:
- m = self._read_regex.search(self._read_buffer,
- self._read_buffer_pos)
+ m = self._read_regex.search(self._read_buffer, self._read_buffer_pos)
if m is not None:
loc = m.end() - self._read_buffer_pos
self._check_max_bytes(self._read_regex, loc)
return None
def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None:
- if (self._read_max_bytes is not None and
- size > self._read_max_bytes):
+ if self._read_max_bytes is not None and size > self._read_max_bytes:
raise UnsatisfiableReadError(
- "delimiter %r not found within %d bytes" % (
- delimiter, self._read_max_bytes))
+ "delimiter %r not found within %d bytes"
+ % (delimiter, self._read_max_bytes)
+ )
def _handle_write(self) -> None:
while True:
# Broken pipe errors are usually caused by connection
# reset, and its better to not log EPIPE errors to
# minimize log spam
- gen_log.warning("Write error on %s: %s",
- self.fileno(), e)
+ gen_log.warning("Write error on %s: %s", self.fileno(), e)
self.close(exc_info=e)
return
return b""
assert loc <= self._read_buffer_size
# Slice the bytearray buffer into bytes, without intermediate copying
- b = (memoryview(self._read_buffer)
- [self._read_buffer_pos:self._read_buffer_pos + loc]
- ).tobytes()
+ b = (
+ memoryview(self._read_buffer)[
+ self._read_buffer_pos : self._read_buffer_pos + loc
+ ]
+ ).tobytes()
self._read_buffer_pos += loc
self._read_buffer_size -= loc
# Amortized O(1) shrink
# (this heuristic is implemented natively in Python 3.4+
# but is replicated here for Python 2)
if self._read_buffer_pos > self._read_buffer_size:
- del self._read_buffer[:self._read_buffer_pos]
+ del self._read_buffer[: self._read_buffer_pos]
self._read_buffer_pos = 0
return b
# immediately anyway. Instead, we insert checks at various times to
# see if the connection is idle and add the read listener then.
if self._state is None or self._state == ioloop.IOLoop.ERROR:
- if (not self.closed() and
- self._read_buffer_size == 0 and
- self._close_callback is not None):
+ if (
+ not self.closed()
+ and self._read_buffer_size == 0
+ and self._close_callback is not None
+ ):
self._add_io_state(ioloop.IOLoop.READ)
def _add_io_state(self, state: int) -> None:
return
if self._state is None:
self._state = ioloop.IOLoop.ERROR | state
- self.io_loop.add_handler(
- self.fileno(), self._handle_events, self._state)
+ self.io_loop.add_handler(self.fileno(), self._handle_events, self._state)
elif not self._state & state:
self._state = self._state | state
self.io_loop.update_handler(self.fileno(), self._state)
May be overridden in subclasses.
"""
- return (isinstance(exc, (socket.error, IOError)) and
- errno_from_exception(exc) in _ERRNO_CONNRESET)
+ return (
+ isinstance(exc, (socket.error, IOError))
+ and errno_from_exception(exc) in _ERRNO_CONNRESET
+ )
class IOStream(BaseIOStream):
:hide:
"""
+
def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None:
self.socket = socket
self.socket.setblocking(False)
self.socket = None # type: ignore
def get_fd_error(self) -> Optional[Exception]:
- errno = self.socket.getsockopt(socket.SOL_SOCKET,
- socket.SO_ERROR)
+ errno = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
return socket.error(errno, os.strerror(errno))
def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
# See https://github.com/tornadoweb/tornado/pull/2008
del data
- def connect(self: _IOStreamType, address: tuple,
- server_hostname: str=None) -> 'Future[_IOStreamType]':
+ def connect(
+ self: _IOStreamType, address: tuple, server_hostname: str = None
+ ) -> "Future[_IOStreamType]":
"""Connects the socket to a remote address without blocking.
May only be called if the socket passed to the constructor was
"""
self._connecting = True
future = Future() # type: Future[_IOStreamType]
- self._connect_future = typing.cast('Future[IOStream]', future)
+ self._connect_future = typing.cast("Future[IOStream]", future)
try:
self.socket.connect(address)
except socket.error as e:
# returned immediately when attempting to connect to
# localhost, so handle them the same way as an error
# reported later in _handle_connect.
- if (errno_from_exception(e) not in _ERRNO_INPROGRESS and
- errno_from_exception(e) not in _ERRNO_WOULDBLOCK):
+ if (
+ errno_from_exception(e) not in _ERRNO_INPROGRESS
+ and errno_from_exception(e) not in _ERRNO_WOULDBLOCK
+ ):
if future is None:
- gen_log.warning("Connect error on fd %s: %s",
- self.socket.fileno(), e)
+ gen_log.warning(
+ "Connect error on fd %s: %s", self.socket.fileno(), e
+ )
self.close(exc_info=e)
return future
self._add_io_state(self.io_loop.WRITE)
return future
- def start_tls(self, server_side: bool,
- ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
- server_hostname: str=None) -> Awaitable['SSLIOStream']:
+ def start_tls(
+ self,
+ server_side: bool,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
+ server_hostname: str = None,
+ ) -> Awaitable["SSLIOStream"]:
"""Convert this `IOStream` to an `SSLIOStream`.
This enables protocols that begin in clear-text mode and
``ssl_options=dict(cert_reqs=ssl.CERT_NONE)`` or a
suitably-configured `ssl.SSLContext` to disable.
"""
- if (self._read_future or
- self._write_futures or
- self._connect_future or
- self._closed or
- self._read_buffer or self._write_buffer):
+ if (
+ self._read_future
+ or self._write_futures
+ or self._connect_future
+ or self._closed
+ or self._read_buffer
+ or self._write_buffer
+ ):
raise ValueError("IOStream is not idle; cannot convert to SSL")
if ssl_options is None:
if server_side:
socket = self.socket
self.io_loop.remove_handler(socket)
self.socket = None # type: ignore
- socket = ssl_wrap_socket(socket, ssl_options,
- server_hostname=server_hostname,
- server_side=server_side,
- do_handshake_on_connect=False)
+ socket = ssl_wrap_socket(
+ socket,
+ ssl_options,
+ server_hostname=server_hostname,
+ server_side=server_side,
+ do_handshake_on_connect=False,
+ )
orig_close_callback = self._close_callback
self._close_callback = None
# in that case a connection failure would be handled by the
# error path in _handle_events instead of here.
if self._connect_future is None:
- gen_log.warning("Connect error on fd %s: %s",
- self.socket.fileno(), errno.errorcode[err])
+ gen_log.warning(
+ "Connect error on fd %s: %s",
+ self.socket.fileno(),
+ errno.errorcode[err],
+ )
self.close()
return
if self._connect_future is not None:
self._connecting = False
def set_nodelay(self, value: bool) -> None:
- if (self.socket is not None and
- self.socket.family in (socket.AF_INET, socket.AF_INET6)):
+ if self.socket is not None and self.socket.family in (
+ socket.AF_INET,
+ socket.AF_INET6,
+ ):
try:
- self.socket.setsockopt(socket.IPPROTO_TCP,
- socket.TCP_NODELAY, 1 if value else 0)
+ self.socket.setsockopt(
+ socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 if value else 0
+ )
except socket.error as e:
# Sometimes setsockopt will fail if the socket is closed
# at the wrong time. This can happen with HTTPServer
before constructing the `SSLIOStream`. Unconnected sockets will be
wrapped when `IOStream.connect` is finished.
"""
+
socket = None # type: ssl.SSLSocket
def __init__(self, *args: Any, **kwargs: Any) -> None:
`ssl.SSLContext` object or a dictionary of keywords arguments
for `ssl.wrap_socket`
"""
- self._ssl_options = kwargs.pop('ssl_options', _client_ssl_defaults)
+ self._ssl_options = kwargs.pop("ssl_options", _client_ssl_defaults)
super(SSLIOStream, self).__init__(*args, **kwargs)
self._ssl_accepting = True
self._handshake_reading = False
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
self._handshake_writing = True
return
- elif err.args[0] in (ssl.SSL_ERROR_EOF,
- ssl.SSL_ERROR_ZERO_RETURN):
+ elif err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
return self.close(exc_info=err)
elif err.args[0] == ssl.SSL_ERROR_SSL:
try:
peer = self.socket.getpeername()
except Exception:
- peer = '(not connected)'
- gen_log.warning("SSL Error on %s %s: %s",
- self.socket.fileno(), peer, err)
+ peer = "(not connected)"
+ gen_log.warning(
+ "SSL Error on %s %s: %s", self.socket.fileno(), peer, err
+ )
return self.close(exc_info=err)
raise
except socket.error as err:
# to cause do_handshake to raise EBADF and ENOTCONN, so make
# those errors quiet as well.
# https://groups.google.com/forum/?fromgroups#!topic/python-tornado/ApucKJat1_0
- if (self._is_connreset(err) or
- err.args[0] in (errno.EBADF, errno.ENOTCONN)):
+ if self._is_connreset(err) or err.args[0] in (errno.EBADF, errno.ENOTCONN):
return self.close(exc_info=err)
raise
except AttributeError as err:
the hostname.
"""
if isinstance(self._ssl_options, dict):
- verify_mode = self._ssl_options.get('cert_reqs', ssl.CERT_NONE)
+ verify_mode = self._ssl_options.get("cert_reqs", ssl.CERT_NONE)
elif isinstance(self._ssl_options, ssl.SSLContext):
verify_mode = self._ssl_options.verify_mode
assert verify_mode in (ssl.CERT_NONE, ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL)
return
super(SSLIOStream, self)._handle_write()
- def connect(self, address: Tuple, server_hostname: str=None) -> 'Future[SSLIOStream]':
+ def connect(
+ self, address: Tuple, server_hostname: str = None
+ ) -> "Future[SSLIOStream]":
self._server_hostname = server_hostname
# Ignore the result of connect(). If it fails,
# wait_for_handshake will raise an error too. This is
old_state = self._state
assert old_state is not None
self._state = None
- self.socket = ssl_wrap_socket(self.socket, self._ssl_options,
- server_hostname=self._server_hostname,
- do_handshake_on_connect=False)
+ self.socket = ssl_wrap_socket(
+ self.socket,
+ self._ssl_options,
+ server_hostname=self._server_hostname,
+ do_handshake_on_connect=False,
+ )
self._add_io_state(old_state)
- def wait_for_handshake(self) -> 'Future[SSLIOStream]':
+ def wait_for_handshake(self) -> "Future[SSLIOStream]":
"""Wait for the initial SSL handshake to complete.
If a ``callback`` is given, it will be called with no
one-way, so a `PipeIOStream` can be used for reading or writing but not
both.
"""
+
def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None:
self.fd = fd
self._fio = io.FileIO(self.fd, "r+")
def doctests() -> Any:
import doctest
+
return doctest.DocTestSuite()
CONTEXT_SEPARATOR = "\x04"
-def get(*locale_codes: str) -> 'Locale':
+def get(*locale_codes: str) -> "Locale":
"""Returns the closest match for the given locale codes.
We iterate over all given locale codes in order. If we have a tight
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
-def load_translations(directory: str, encoding: str=None) -> None:
+def load_translations(directory: str, encoding: str = None) -> None:
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
- gen_log.error("Unrecognized locale %r (path: %s)", locale,
- os.path.join(directory, path))
+ gen_log.error(
+ "Unrecognized locale %r (path: %s)",
+ locale,
+ os.path.join(directory, path),
+ )
continue
full_path = os.path.join(directory, path)
if encoding is None:
# Try to autodetect encoding based on the BOM.
- with open(full_path, 'rb') as bf:
+ with open(full_path, "rb") as bf:
data = bf.read(len(codecs.BOM_UTF16_LE))
if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
- encoding = 'utf-16'
+ encoding = "utf-16"
else:
# utf-8-sig is "utf-8 with optional BOM". It's discouraged
# in most cases but is common with CSV files because Excel
# cannot read utf-8 files without a BOM.
- encoding = 'utf-8-sig'
+ encoding = "utf-8-sig"
# python 3: csv.reader requires a file open in text mode.
# Specify an encoding to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding=encoding)
else:
plural = "unknown"
if plural not in ("plural", "singular", "unknown"):
- gen_log.error("Unrecognized plural indicator %r in %s line %d",
- plural, path, i + 1)
+ gen_log.error(
+ "Unrecognized plural indicator %r in %s line %d",
+ plural,
+ path,
+ i + 1,
+ )
continue
_translations[locale].setdefault(plural, {})[english] = translation
f.close()
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
"""
import gettext
+
global _translations
global _supported_locales
global _use_gettext
_translations = {}
for lang in os.listdir(directory):
- if lang.startswith('.'):
+ if lang.startswith("."):
continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)):
continue
try:
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
- _translations[lang] = gettext.translation(domain, directory,
- languages=[lang])
+ _translations[lang] = gettext.translation(
+ domain, directory, languages=[lang]
+ )
except Exception as e:
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
continue
After calling one of `load_translations` or `load_gettext_translations`,
call `get` or `get_closest` to get a Locale object.
"""
+
_cache = {} # type: Dict[str, Locale]
@classmethod
- def get_closest(cls, *locale_codes: str) -> 'Locale':
+ def get_closest(cls, *locale_codes: str) -> "Locale":
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code:
return cls.get(_default_locale)
@classmethod
- def get(cls, code: str) -> 'Locale':
+ def get(cls, code: str) -> "Locale":
"""Returns the Locale for the given locale code.
If it is not supported, we raise an exception.
# Initialize strings for date formatting
_ = self.translate
self._months = [
- _("January"), _("February"), _("March"), _("April"),
- _("May"), _("June"), _("July"), _("August"),
- _("September"), _("October"), _("November"), _("December")]
+ _("January"),
+ _("February"),
+ _("March"),
+ _("April"),
+ _("May"),
+ _("June"),
+ _("July"),
+ _("August"),
+ _("September"),
+ _("October"),
+ _("November"),
+ _("December"),
+ ]
self._weekdays = [
- _("Monday"), _("Tuesday"), _("Wednesday"), _("Thursday"),
- _("Friday"), _("Saturday"), _("Sunday")]
-
- def translate(self, message: str, plural_message: str=None, count: int=None) -> str:
+ _("Monday"),
+ _("Tuesday"),
+ _("Wednesday"),
+ _("Thursday"),
+ _("Friday"),
+ _("Saturday"),
+ _("Sunday"),
+ ]
+
+ def translate(
+ self, message: str, plural_message: str = None, count: int = None
+ ) -> str:
"""Returns the translation for the given message for this locale.
If ``plural_message`` is given, you must also provide
"""
raise NotImplementedError()
- def pgettext(self, context: str, message: str, plural_message: str=None,
- count: int=None) -> str:
+ def pgettext(
+ self, context: str, message: str, plural_message: str = None, count: int = None
+ ) -> str:
raise NotImplementedError()
- def format_date(self, date: Union[int, float, datetime.datetime], gmt_offset: int=0,
- relative: bool=True, shorter: bool=False, full_format: bool=False) -> str:
+ def format_date(
+ self,
+ date: Union[int, float, datetime.datetime],
+ gmt_offset: int = 0,
+ relative: bool = True,
+ shorter: bool = False,
+ full_format: bool = False,
+ ) -> str:
"""Formats the given date (which should be GMT).
By default, we return a relative time (e.g., "2 minutes ago"). You
if not full_format:
if relative and days == 0:
if seconds < 50:
- return _("1 second ago", "%(seconds)d seconds ago",
- seconds) % {"seconds": seconds}
+ return _("1 second ago", "%(seconds)d seconds ago", seconds) % {
+ "seconds": seconds
+ }
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
- return _("1 minute ago", "%(minutes)d minutes ago",
- minutes) % {"minutes": minutes}
+ return _("1 minute ago", "%(minutes)d minutes ago", minutes) % {
+ "minutes": minutes
+ }
hours = round(seconds / (60.0 * 60))
- return _("1 hour ago", "%(hours)d hours ago",
- hours) % {"hours": hours}
+ return _("1 hour ago", "%(hours)d hours ago", hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
- elif days == 1 and local_date.day == local_yesterday.day and \
- relative:
- format = _("yesterday") if shorter else \
- _("yesterday at %(time)s")
+ elif days == 1 and local_date.day == local_yesterday.day and relative:
+ format = _("yesterday") if shorter else _("yesterday at %(time)s")
elif days < 5:
- format = _("%(weekday)s") if shorter else \
- _("%(weekday)s at %(time)s")
+ format = _("%(weekday)s") if shorter else _("%(weekday)s at %(time)s")
elif days < 334: # 11mo, since confusing for same month last year
- format = _("%(month_name)s %(day)s") if shorter else \
- _("%(month_name)s %(day)s at %(time)s")
+ format = (
+ _("%(month_name)s %(day)s")
+ if shorter
+ else _("%(month_name)s %(day)s at %(time)s")
+ )
if format is None:
- format = _("%(month_name)s %(day)s, %(year)s") if shorter else \
- _("%(month_name)s %(day)s, %(year)s at %(time)s")
+ format = (
+ _("%(month_name)s %(day)s, %(year)s")
+ if shorter
+ else _("%(month_name)s %(day)s, %(year)s at %(time)s")
+ )
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
if tfhour_clock:
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
- (u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
- local_date.hour % 12 or 12, local_date.minute)
+ (u"\u4e0a\u5348", u"\u4e0b\u5348")[local_date.hour >= 12],
+ local_date.hour % 12 or 12,
+ local_date.minute,
+ )
else:
str_time = "%d:%02d %s" % (
- local_date.hour % 12 or 12, local_date.minute,
- ("am", "pm")[local_date.hour >= 12])
+ local_date.hour % 12 or 12,
+ local_date.minute,
+ ("am", "pm")[local_date.hour >= 12],
+ )
return format % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
"year": str(local_date.year),
- "time": str_time
+ "time": str_time,
}
- def format_day(self, date: datetime.datetime, gmt_offset: int=0, dow: bool=True) -> bool:
+ def format_day(
+ self, date: datetime.datetime, gmt_offset: int = 0, dow: bool = True
+ ) -> bool:
"""Formats the given date as a day of week.
Example: "Monday, January 22". You can remove the day of week with
return ""
if len(parts) == 1:
return parts[0]
- comma = u' \u0648 ' if self.code.startswith("fa") else u", "
+ comma = u" \u0648 " if self.code.startswith("fa") else u", "
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
+
def __init__(self, code: str, translations: Dict[str, Dict[str, str]]) -> None:
self.translations = translations
super(CSVLocale, self).__init__(code)
- def translate(self, message: str, plural_message: str=None, count: int=None) -> str:
+ def translate(
+ self, message: str, plural_message: str = None, count: int = None
+ ) -> str:
if plural_message is not None:
assert count is not None
if count != 1:
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
- def pgettext(self, context: str, message: str, plural_message: str=None,
- count: int=None) -> str:
+ def pgettext(
+ self, context: str, message: str, plural_message: str = None, count: int = None
+ ) -> str:
if self.translations:
- gen_log.warning('pgettext is not supported by CSVLocale')
+ gen_log.warning("pgettext is not supported by CSVLocale")
return self.translate(message, plural_message, count)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
+
def __init__(self, code: str, translations: gettext.NullTranslations) -> None:
self.ngettext = translations.ngettext
self.gettext = translations.gettext
# calls into self.translate
super(GettextLocale, self).__init__(code)
- def translate(self, message: str, plural_message: str=None, count: int=None) -> str:
+ def translate(
+ self, message: str, plural_message: str = None, count: int = None
+ ) -> str:
if plural_message is not None:
assert count is not None
return self.ngettext(message, plural_message, count)
else:
return self.gettext(message)
- def pgettext(self, context: str, message: str, plural_message: str=None,
- count: int=None) -> str:
+ def pgettext(
+ self, context: str, message: str, plural_message: str = None, count: int = None
+ ) -> str:
"""Allows to set context for translation, accepts plural forms.
Usage example::
"""
if plural_message is not None:
assert count is not None
- msgs_with_ctxt = ("%s%s%s" % (context, CONTEXT_SEPARATOR, message),
- "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
- count)
+ msgs_with_ctxt = (
+ "%s%s%s" % (context, CONTEXT_SEPARATOR, message),
+ "%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
+ count,
+ )
result = self.ngettext(*msgs_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
from typing import Union, Optional, Type, Any, Generator
import typing
+
if typing.TYPE_CHECKING:
from typing import Deque, Set # noqa: F401
-__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
+__all__ = ["Condition", "Event", "Semaphore", "BoundedSemaphore", "Lock"]
class _TimeoutGarbageCollector(object):
yield condition.wait(short_timeout)
print('looping....')
"""
+
def __init__(self) -> None:
self._waiters = collections.deque() # type: Deque[Future]
self._timeouts = 0
self._timeouts += 1
if self._timeouts > 100:
self._timeouts = 0
- self._waiters = collections.deque(
- w for w in self._waiters if not w.done())
+ self._waiters = collections.deque(w for w in self._waiters if not w.done())
class Condition(_TimeoutGarbageCollector):
self.io_loop = ioloop.IOLoop.current()
def __repr__(self) -> str:
- result = '<%s' % (self.__class__.__name__, )
+ result = "<%s" % (self.__class__.__name__,)
if self._waiters:
- result += ' waiters[%s]' % len(self._waiters)
- return result + '>'
+ result += " waiters[%s]" % len(self._waiters)
+ return result + ">"
- def wait(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[bool]':
+ def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[bool]":
"""Wait for `.notify`.
Returns a `.Future` that resolves ``True`` if the condition is notified,
waiter = Future() # type: Future[bool]
self._waiters.append(waiter)
if timeout:
+
def on_timeout() -> None:
if not waiter.done():
future_set_result_unless_cancelled(waiter, False)
self._garbage_collect()
+
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
- waiter.add_done_callback(
- lambda _: io_loop.remove_timeout(timeout_handle))
+ waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
- def notify(self, n: int=1) -> None:
+ def notify(self, n: int = 1) -> None:
"""Wake ``n`` waiters."""
waiters = [] # Waiters we plan to run right now.
while n and self._waiters:
Not waiting this time
Done
"""
+
def __init__(self) -> None:
self._value = False
self._waiters = set() # type: Set[Future[None]]
def __repr__(self) -> str:
- return '<%s %s>' % (
- self.__class__.__name__, 'set' if self.is_set() else 'clear')
+ return "<%s %s>" % (
+ self.__class__.__name__,
+ "set" if self.is_set() else "clear",
+ )
def is_set(self) -> bool:
"""Return ``True`` if the internal flag is true."""
"""
self._value = False
- def wait(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]':
+ def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[None]":
"""Block until the internal flag is true.
Returns a Future, which raises `tornado.util.TimeoutError` after a
if timeout is None:
return fut
else:
- timeout_fut = gen.with_timeout(timeout, fut, quiet_exceptions=(CancelledError,))
+ timeout_fut = gen.with_timeout(
+ timeout, fut, quiet_exceptions=(CancelledError,)
+ )
# This is a slightly clumsy workaround for the fact that
# gen.with_timeout doesn't cancel its futures. Cancelling
# fut will remove it from the waiters list.
- timeout_fut.add_done_callback(lambda tf: fut.cancel() if not fut.done() else None)
+ timeout_fut.add_done_callback(
+ lambda tf: fut.cancel() if not fut.done() else None
+ )
return timeout_fut
# Now semaphore.release() has been called.
"""
+
def __init__(self, obj: Any) -> None:
self._obj = obj
def __enter__(self) -> None:
pass
- def __exit__(self, exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[types.TracebackType]) -> None:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[types.TracebackType],
+ ) -> None:
self._obj.release()
Added ``async with`` support in Python 3.5.
"""
- def __init__(self, value: int=1) -> None:
+
+ def __init__(self, value: int = 1) -> None:
super(Semaphore, self).__init__()
if value < 0:
- raise ValueError('semaphore initial value must be >= 0')
+ raise ValueError("semaphore initial value must be >= 0")
self._value = value
def __repr__(self) -> str:
res = super(Semaphore, self).__repr__()
- extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format(
- self._value)
+ extra = (
+ "locked" if self._value == 0 else "unlocked,value:{0}".format(self._value)
+ )
if self._waiters:
- extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
- return '<{0} [{1}]>'.format(res[1:-1], extra)
+ extra = "{0},waiters:{1}".format(extra, len(self._waiters))
+ return "<{0} [{1}]>".format(res[1:-1], extra)
def release(self) -> None:
"""Increment the counter and wake one waiter."""
break
def acquire(
- self, timeout: Union[float, datetime.timedelta]=None,
- ) -> 'Future[_ReleasingContextManager]':
+ self, timeout: Union[float, datetime.timedelta] = None
+ ) -> "Future[_ReleasingContextManager]":
"""Decrement the counter. Returns a Future.
Block if the counter is zero and wait for a `.release`. The Future
else:
self._waiters.append(waiter)
if timeout:
+
def on_timeout() -> None:
if not waiter.done():
waiter.set_exception(gen.TimeoutError())
self._garbage_collect()
+
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
- lambda _: io_loop.remove_timeout(timeout_handle))
+ lambda _: io_loop.remove_timeout(timeout_handle)
+ )
return waiter
def __enter__(self) -> None:
raise RuntimeError(
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
- " 'with semaphore'")
-
- def __exit__(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- traceback: Optional[types.TracebackType]) -> None:
+ " 'with semaphore'"
+ )
+
+ def __exit__(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> None:
self.__enter__()
@gen.coroutine
def __aenter__(self) -> Generator[Any, Any, None]:
yield self.acquire()
- async def __aexit__(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[types.TracebackType]) -> None:
+ async def __aexit__(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[types.TracebackType],
+ ) -> None:
self.release()
resources with limited capacity, so a semaphore released too many times
is a sign of a bug.
"""
- def __init__(self, value: int=1) -> None:
+
+ def __init__(self, value: int = 1) -> None:
super(BoundedSemaphore, self).__init__(value=value)
self._initial_value = value
Added ``async with`` support in Python 3.5.
"""
+
def __init__(self) -> None:
self._block = BoundedSemaphore(value=1)
def __repr__(self) -> str:
- return "<%s _block=%s>" % (
- self.__class__.__name__,
- self._block)
+ return "<%s _block=%s>" % (self.__class__.__name__, self._block)
def acquire(
- self, timeout: Union[float, datetime.timedelta]=None,
- ) -> 'Future[_ReleasingContextManager]':
+ self, timeout: Union[float, datetime.timedelta] = None
+ ) -> "Future[_ReleasingContextManager]":
"""Attempt to lock. Returns a Future.
Returns a Future, which raises `tornado.util.TimeoutError` after a
try:
self._block.release()
except ValueError:
- raise RuntimeError('release unlocked lock')
+ raise RuntimeError("release unlocked lock")
def __enter__(self) -> None:
- raise RuntimeError(
- "Use Lock like 'with (yield lock)', not like 'with lock'")
-
- def __exit__(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[types.TracebackType]) -> None:
+ raise RuntimeError("Use Lock like 'with (yield lock)', not like 'with lock'")
+
+ def __exit__(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[types.TracebackType],
+ ) -> None:
self.__enter__()
@gen.coroutine
def __aenter__(self) -> Generator[Any, Any, None]:
yield self.acquire()
- async def __aexit__(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[types.TracebackType]) -> None:
+ async def __aexit__(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[types.TracebackType],
+ ) -> None:
self.release()
def _stderr_supports_color() -> bool:
try:
- if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
+ if hasattr(sys.stderr, "isatty") and sys.stderr.isatty():
if curses:
curses.setupterm()
if curses.tigetnum("colors") > 0:
return True
elif colorama:
- if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
- object()):
+ if sys.stderr is getattr(
+ colorama.initialise, "wrapped_stderr", object()
+ ):
return True
except Exception:
# Very broad exception handling because it's always better to
Added support for ``colorama``. Changed the constructor
signature to be compatible with `logging.config.dictConfig`.
"""
- DEFAULT_FORMAT = \
- '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
- DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
+
+ DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" # noqa: E501
+ DEFAULT_DATE_FORMAT = "%y%m%d %H:%M:%S"
DEFAULT_COLORS = {
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.ERROR: 1, # Red
}
- def __init__(self, fmt: str=DEFAULT_FORMAT, datefmt: str=DEFAULT_DATE_FORMAT,
- style: str='%', color: bool=True, colors: Dict[int, int]=DEFAULT_COLORS) -> None:
+ def __init__(
+ self,
+ fmt: str = DEFAULT_FORMAT,
+ datefmt: str = DEFAULT_DATE_FORMAT,
+ style: str = "%",
+ color: bool = True,
+ colors: Dict[int, int] = DEFAULT_COLORS,
+ ) -> None:
r"""
:arg bool color: Enables color support.
:arg str fmt: Log message format.
self._colors = {} # type: Dict[int, str]
if color and _stderr_supports_color():
if curses is not None:
- fg_color = (curses.tigetstr("setaf") or
- curses.tigetstr("setf") or b"")
+ fg_color = curses.tigetstr("setaf") or curses.tigetstr("setf") or b""
for levelno, code in colors.items():
# Convert the terminal control characters from
# bytes to unicode strings for easier use with the
# logging module.
- self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
+ self._colors[levelno] = unicode_type(
+ curses.tparm(fg_color, code), "ascii"
+ )
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
else:
# If curses is not present (currently we'll only get here for
# colorama on windows), assume hard-coded ANSI color codes.
for levelno, code in colors.items():
- self._colors[levelno] = '\033[2;3%dm' % code
- self._normal = '\033[0m'
+ self._colors[levelno] = "\033[2;3%dm" % code
+ self._normal = "\033[0m"
else:
- self._normal = ''
+ self._normal = ""
def format(self, record: Any) -> str:
try:
record.color = self._colors[record.levelno]
record.end_color = self._normal
else:
- record.color = record.end_color = ''
+ record.color = record.end_color = ""
formatted = self._fmt % record.__dict__
# each line separately so that non-utf8 bytes don't cause
# all the newlines to turn into '\n'.
lines = [formatted.rstrip()]
- lines.extend(_safe_unicode(ln) for ln in record.exc_text.split('\n'))
- formatted = '\n'.join(lines)
+ lines.extend(_safe_unicode(ln) for ln in record.exc_text.split("\n"))
+ formatted = "\n".join(lines)
return formatted.replace("\n", "\n ")
-def enable_pretty_logging(options: Any=None,
- logger: logging.Logger=None) -> None:
+def enable_pretty_logging(options: Any = None, logger: logging.Logger = None) -> None:
"""Turns on formatted logging output as configured.
This is called automatically by `tornado.options.parse_command_line`
"""
if options is None:
import tornado.options
+
options = tornado.options.options
- if options.logging is None or options.logging.lower() == 'none':
+ if options.logging is None or options.logging.lower() == "none":
return
if logger is None:
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
rotate_mode = options.log_rotate_mode
- if rotate_mode == 'size':
+ if rotate_mode == "size":
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
- backupCount=options.log_file_num_backups) # type: logging.Handler
- elif rotate_mode == 'time':
+ backupCount=options.log_file_num_backups,
+ ) # type: logging.Handler
+ elif rotate_mode == "time":
channel = logging.handlers.TimedRotatingFileHandler(
filename=options.log_file_prefix,
when=options.log_rotate_when,
interval=options.log_rotate_interval,
- backupCount=options.log_file_num_backups)
+ backupCount=options.log_file_num_backups,
+ )
else:
- error_message = 'The value of log_rotate_mode option should be ' +\
- '"size" or "time", not "%s".' % rotate_mode
+ error_message = (
+ "The value of log_rotate_mode option should be "
+ + '"size" or "time", not "%s".' % rotate_mode
+ )
raise ValueError(error_message)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
- if (options.log_to_stderr or
- (options.log_to_stderr is None and not logger.handlers)):
+ if options.log_to_stderr or (options.log_to_stderr is None and not logger.handlers):
# Set up color if we are in a tty and curses is installed
channel = logging.StreamHandler()
channel.setFormatter(LogFormatter())
logger.addHandler(channel)
-def define_logging_options(options: Any=None) -> None:
+def define_logging_options(options: Any = None) -> None:
"""Add logging-related flags to ``options``.
These options are present automatically on the default options instance;
if options is None:
# late import to prevent cycle
import tornado.options
+
options = tornado.options.options
- options.define("logging", default="info",
- help=("Set the Python log level. If 'none', tornado won't touch the "
- "logging configuration."),
- metavar="debug|info|warning|error|none")
- options.define("log_to_stderr", type=bool, default=None,
- help=("Send log output to stderr (colorized if possible). "
- "By default use stderr if --log_file_prefix is not set and "
- "no other logging is configured."))
- options.define("log_file_prefix", type=str, default=None, metavar="PATH",
- help=("Path prefix for log files. "
- "Note that if you are running multiple tornado processes, "
- "log_file_prefix must be different for each of them (e.g. "
- "include the port number)"))
- options.define("log_file_max_size", type=int, default=100 * 1000 * 1000,
- help="max size of log files before rollover")
- options.define("log_file_num_backups", type=int, default=10,
- help="number of log files to keep")
-
- options.define("log_rotate_when", type=str, default='midnight',
- help=("specify the type of TimedRotatingFileHandler interval "
- "other options:('S', 'M', 'H', 'D', 'W0'-'W6')"))
- options.define("log_rotate_interval", type=int, default=1,
- help="The interval value of timed rotating")
-
- options.define("log_rotate_mode", type=str, default='size',
- help="The mode of rotating files(time or size)")
+ options.define(
+ "logging",
+ default="info",
+ help=(
+ "Set the Python log level. If 'none', tornado won't touch the "
+ "logging configuration."
+ ),
+ metavar="debug|info|warning|error|none",
+ )
+ options.define(
+ "log_to_stderr",
+ type=bool,
+ default=None,
+ help=(
+ "Send log output to stderr (colorized if possible). "
+ "By default use stderr if --log_file_prefix is not set and "
+ "no other logging is configured."
+ ),
+ )
+ options.define(
+ "log_file_prefix",
+ type=str,
+ default=None,
+ metavar="PATH",
+ help=(
+ "Path prefix for log files. "
+ "Note that if you are running multiple tornado processes, "
+ "log_file_prefix must be different for each of them (e.g. "
+ "include the port number)"
+ ),
+ )
+ options.define(
+ "log_file_max_size",
+ type=int,
+ default=100 * 1000 * 1000,
+ help="max size of log files before rollover",
+ )
+ options.define(
+ "log_file_num_backups", type=int, default=10, help="number of log files to keep"
+ )
+
+ options.define(
+ "log_rotate_when",
+ type=str,
+ default="midnight",
+ help=(
+ "specify the type of TimedRotatingFileHandler interval "
+ "other options:('S', 'M', 'H', 'D', 'W0'-'W6')"
+ ),
+ )
+ options.define(
+ "log_rotate_interval",
+ type=int,
+ default=1,
+ help="The interval value of timed rotating",
+ )
+
+ options.define(
+ "log_rotate_mode",
+ type=str,
+ default="size",
+ help="The mode of rotating files(time or size)",
+ )
options.add_parse_callback(lambda: enable_pretty_logging(options))
# Note that the naming of ssl.Purpose is confusing; the purpose
# of a context is to authentiate the opposite side of the connection.
-_client_ssl_defaults = ssl.create_default_context(
- ssl.Purpose.SERVER_AUTH)
-_server_ssl_defaults = ssl.create_default_context(
- ssl.Purpose.CLIENT_AUTH)
-if hasattr(ssl, 'OP_NO_COMPRESSION'):
+_client_ssl_defaults = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+_server_ssl_defaults = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+if hasattr(ssl, "OP_NO_COMPRESSION"):
# See netutil.ssl_options_to_context
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
-u'foo'.encode('idna')
+u"foo".encode("idna")
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
-u'foo'.encode('latin1')
+u"foo".encode("latin1")
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
_DEFAULT_BACKLOG = 128
-def bind_sockets(port: int, address: str=None,
- family: socket.AddressFamily=socket.AF_UNSPEC,
- backlog: int=_DEFAULT_BACKLOG, flags: int=None,
- reuse_port: bool=False) -> List[socket.socket]:
+def bind_sockets(
+ port: int,
+ address: str = None,
+ family: socket.AddressFamily = socket.AF_UNSPEC,
+ backlog: int = _DEFAULT_BACKLOG,
+ flags: int = None,
+ reuse_port: bool = False,
+) -> List[socket.socket]:
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
flags = socket.AI_PASSIVE
bound_port = None
unique_addresses = set() # type: set
- for res in sorted(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
- 0, flags), key=lambda x: x[0]):
+ for res in sorted(
+ socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags),
+ key=lambda x: x[0],
+ ):
if res in unique_addresses:
continue
unique_addresses.add(res)
af, socktype, proto, canonname, sockaddr = res
- if (sys.platform == 'darwin' and address == 'localhost' and
- af == socket.AF_INET6 and sockaddr[3] != 0):
+ if (
+ sys.platform == "darwin"
+ and address == "localhost"
+ and af == socket.AF_INET6
+ and sockaddr[3] != 0
+ ):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
# doesn't understand that this is a local address and will
continue
raise
set_close_exec(sock.fileno())
- if os.name != 'nt':
+ if os.name != "nt":
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error as e:
return sockets
-if hasattr(socket, 'AF_UNIX'):
- def bind_unix_socket(file: str, mode: int=0o600,
- backlog: int=_DEFAULT_BACKLOG) -> socket.socket:
+if hasattr(socket, "AF_UNIX"):
+
+ def bind_unix_socket(
+ file: str, mode: int = 0o600, backlog: int = _DEFAULT_BACKLOG
+ ) -> socket.socket:
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
return sock
-def add_accept_handler(sock: socket.socket,
- callback: Callable[[socket.socket, Any], None]) -> Callable[[], None]:
+def add_accept_handler(
+ sock: socket.socket, callback: Callable[[socket.socket, Any], None]
+) -> Callable[[], None]:
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
When a connection is accepted, ``callback(connection, address)`` will
Supports IPv4 and IPv6.
"""
- if not ip or '\x00' in ip:
+ if not ip or "\x00" in ip:
# getaddrinfo resolves empty strings to localhost, and truncates
# on zero bytes.
return False
try:
- res = socket.getaddrinfo(ip, 0, socket.AF_UNSPEC,
- socket.SOCK_STREAM,
- 0, socket.AI_NUMERICHOST)
+ res = socket.getaddrinfo(
+ ip, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST
+ )
return bool(res)
except socket.gaierror as e:
if e.args[0] == socket.EAI_NONAME:
The default implementation has changed from `BlockingResolver` to
`DefaultExecutorResolver`.
"""
+
@classmethod
- def configurable_base(cls) -> Type['Resolver']:
+ def configurable_base(cls) -> Type["Resolver"]:
return Resolver
@classmethod
- def configurable_default(cls) -> Type['Resolver']:
+ def configurable_default(cls) -> Type["Resolver"]:
return DefaultExecutorResolver
def resolve(
- self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
- ) -> 'Future[List[Tuple[int, Any]]]':
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+ ) -> "Future[List[Tuple[int, Any]]]":
"""Resolves an address.
The ``host`` argument is a string which may be a hostname or a
def _resolve_addr(
- host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+ host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> List[Tuple[int, Any]]:
# On Solaris, getaddrinfo fails if the given port is not found
# in /etc/services and no socket type is given, so we must pass
.. versionadded:: 5.0
"""
+
@gen.coroutine
def resolve(
- self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> Generator[Any, Any, List[Tuple[int, Any]]]:
result = yield IOLoop.current().run_in_executor(
- None, _resolve_addr, host, port, family)
+ None, _resolve_addr, host, port, family
+ )
return result
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
- def initialize(self, executor: concurrent.futures.Executor=None,
- close_executor: bool=True) -> None:
+
+ def initialize(
+ self, executor: concurrent.futures.Executor = None, close_executor: bool = True
+ ) -> None:
self.io_loop = IOLoop.current()
if executor is not None:
self.executor = executor
@run_on_executor
def resolve(
- self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> List[Tuple[int, Any]]:
return _resolve_addr(host, port, family)
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
+
def initialize(self) -> None: # type: ignore
super(BlockingResolver, self).initialize()
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
+
_threadpool = None # type: ignore
_threadpool_pid = None # type: int
- def initialize(self, num_threads: int=10) -> None: # type: ignore
+ def initialize(self, num_threads: int = 10) -> None: # type: ignore
threadpool = ThreadedResolver._create_threadpool(num_threads)
super(ThreadedResolver, self).initialize(
- executor=threadpool, close_executor=False)
+ executor=threadpool, close_executor=False
+ )
@classmethod
- def _create_threadpool(cls, num_threads: int) -> concurrent.futures.ThreadPoolExecutor:
+ def _create_threadpool(
+ cls, num_threads: int
+ ) -> concurrent.futures.ThreadPoolExecutor:
pid = os.getpid()
if cls._threadpool_pid != pid:
# Threads cannot survive after a fork, so if our pid isn't what it
.. versionchanged:: 5.0
Added support for host-port-family triplets.
"""
+
def initialize(self, resolver: Resolver, mapping: dict) -> None: # type: ignore
self.resolver = resolver
self.mapping = mapping
self.resolver.close()
def resolve(
- self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
- ) -> 'Future[List[Tuple[int, Any]]]':
+ self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
+ ) -> "Future[List[Tuple[int, Any]]]":
if (host, port, family) in self.mapping:
host, port = self.mapping[(host, port, family)]
elif (host, port) in self.mapping:
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
# to SSLContext.wrap_socket).
-_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
- 'cert_reqs', 'ca_certs', 'ciphers'])
+_SSL_CONTEXT_KEYWORDS = frozenset(
+ ["ssl_version", "certfile", "keyfile", "cert_reqs", "ca_certs", "ciphers"]
+)
-def ssl_options_to_context(ssl_options: Union[Dict[str, Any], ssl.SSLContext]) -> ssl.SSLContext:
+def ssl_options_to_context(
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext]
+) -> ssl.SSLContext:
"""Try to convert an ``ssl_options`` dictionary to an
`~ssl.SSLContext` object.
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
# Can't use create_default_context since this interface doesn't
# tell us client vs server.
- context = ssl.SSLContext(
- ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
- if 'certfile' in ssl_options:
- context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
- if 'cert_reqs' in ssl_options:
- context.verify_mode = ssl_options['cert_reqs']
- if 'ca_certs' in ssl_options:
- context.load_verify_locations(ssl_options['ca_certs'])
- if 'ciphers' in ssl_options:
- context.set_ciphers(ssl_options['ciphers'])
- if hasattr(ssl, 'OP_NO_COMPRESSION'):
+ context = ssl.SSLContext(ssl_options.get("ssl_version", ssl.PROTOCOL_SSLv23))
+ if "certfile" in ssl_options:
+ context.load_cert_chain(
+ ssl_options["certfile"], ssl_options.get("keyfile", None)
+ )
+ if "cert_reqs" in ssl_options:
+ context.verify_mode = ssl_options["cert_reqs"]
+ if "ca_certs" in ssl_options:
+ context.load_verify_locations(ssl_options["ca_certs"])
+ if "ciphers" in ssl_options:
+ context.set_ciphers(ssl_options["ciphers"])
+ if hasattr(ssl, "OP_NO_COMPRESSION"):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant depends on openssl version 1.0.
# TODO: Do we need to do this ourselves or can we trust
return context
-def ssl_wrap_socket(socket: socket.socket, ssl_options: Union[Dict[str, Any], ssl.SSLContext],
- server_hostname: str=None, **kwargs: Any) -> ssl.SSLSocket:
+def ssl_wrap_socket(
+ socket: socket.socket,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext],
+ server_hostname: str = None,
+ **kwargs: Any
+) -> ssl.SSLSocket:
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either an `ssl.SSLContext` object or a
# TODO: add a unittest (python added server-side SNI support in 3.4)
# In the meantime it can be manually tested with
# python3 -m tornado.httpclient https://sni.velox.ch
- return context.wrap_socket(socket, server_hostname=server_hostname,
- **kwargs)
+ return context.wrap_socket(socket, server_hostname=server_hostname, **kwargs)
else:
return context.wrap_socket(socket, **kwargs)
class Error(Exception):
"""Exception raised by errors in the options module."""
+
pass
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
+
def __init__(self) -> None:
# we have to use self.__dict__ because we override setattr.
- self.__dict__['_options'] = {}
- self.__dict__['_parse_callbacks'] = []
- self.define("help", type=bool, help="show this help information",
- callback=self._help_callback)
+ self.__dict__["_options"] = {}
+ self.__dict__["_parse_callbacks"] = []
+ self.define(
+ "help",
+ type=bool,
+ help="show this help information",
+ callback=self._help_callback,
+ )
def _normalize_name(self, name: str) -> str:
- return name.replace('_', '-')
+ return name.replace("_", "-")
def __getattr__(self, name: str) -> Any:
name = self._normalize_name(name)
.. versionadded:: 3.1
"""
return dict(
- (opt.name, opt.value()) for name, opt in self._options.items()
- if not group or group == opt.group_name)
+ (opt.name, opt.value())
+ for name, opt in self._options.items()
+ if not group or group == opt.group_name
+ )
def as_dict(self) -> Dict[str, Any]:
"""The names and values of all options.
.. versionadded:: 3.1
"""
- return dict(
- (opt.name, opt.value()) for name, opt in self._options.items())
-
- def define(self, name: str, default: Any=None, type: type=None,
- help: str=None, metavar: str=None,
- multiple: bool=False, group: str=None, callback: Callable[[Any], None]=None) -> None:
+ return dict((opt.name, opt.value()) for name, opt in self._options.items())
+
+ def define(
+ self,
+ name: str,
+ default: Any = None,
+ type: type = None,
+ help: str = None,
+ metavar: str = None,
+ multiple: bool = False,
+ group: str = None,
+ callback: Callable[[Any], None] = None,
+ ) -> None:
"""Defines a new command line option.
``type`` can be any of `str`, `int`, `float`, `bool`,
"""
normalized = self._normalize_name(name)
if normalized in self._options:
- raise Error("Option %r already defined in %s" %
- (normalized, self._options[normalized].file_name))
+ raise Error(
+ "Option %r already defined in %s"
+ % (normalized, self._options[normalized].file_name)
+ )
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
# Can be called directly, or through top level define() fn, in which
# case, step up above that frame to look for real caller.
- if (frame.f_back.f_code.co_filename == options_file and
- frame.f_back.f_code.co_name == 'define'):
+ if (
+ frame.f_back.f_code.co_filename == options_file
+ and frame.f_back.f_code.co_name == "define"
+ ):
frame = frame.f_back
file_name = frame.f_back.f_code.co_filename
group_name = group # type: Optional[str]
else:
group_name = file_name
- option = _Option(name, file_name=file_name,
- default=default, type=type, help=help,
- metavar=metavar, multiple=multiple,
- group_name=group_name,
- callback=callback)
+ option = _Option(
+ name,
+ file_name=file_name,
+ default=default,
+ type=type,
+ help=help,
+ metavar=metavar,
+ multiple=multiple,
+ group_name=group_name,
+ callback=callback,
+ )
self._options[normalized] = option
- def parse_command_line(self, args: List[str]=None, final: bool=True) -> List[str]:
+ def parse_command_line(
+ self, args: List[str] = None, final: bool = True
+ ) -> List[str]:
"""Parses all options given on the command line (defaults to
`sys.argv`).
remaining = args[i:]
break
if args[i] == "--":
- remaining = args[i + 1:]
+ remaining = args[i + 1 :]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = self._normalize_name(name)
if name not in self._options:
self.print_help()
- raise Error('Unrecognized command line option: %r' % name)
+ raise Error("Unrecognized command line option: %r" % name)
option = self._options[name]
if not equals:
if option.type == bool:
value = "true"
else:
- raise Error('Option %r requires a value' % name)
+ raise Error("Option %r requires a value" % name)
option.parse(value)
if final:
return remaining
- def parse_config_file(self, path: str, final: bool=True) -> None:
+ def parse_config_file(self, path: str, final: bool = True) -> None:
"""Parses and loads the config file at the given path.
The config file contains Python code that will be executed (so
Added the ability to set options via strings in config files.
"""
- config = {'__file__': os.path.abspath(path)}
- with open(path, 'rb') as f:
+ config = {"__file__": os.path.abspath(path)}
+ with open(path, "rb") as f:
exec_in(native_str(f.read()), config, config)
for name in config:
normalized = self._normalize_name(name)
option = self._options[normalized]
if option.multiple:
if not isinstance(config[name], (list, str)):
- raise Error("Option %r is required to be a list of %s "
- "or a comma-separated string" %
- (option.name, option.type.__name__))
+ raise Error(
+ "Option %r is required to be a list of %s "
+ "or a comma-separated string"
+ % (option.name, option.type.__name__)
+ )
if type(config[name]) == str and option.type != str:
option.parse(config[name])
if final:
self.run_parse_callbacks()
- def print_help(self, file: TextIO=None) -> None:
+ def print_help(self, file: TextIO = None) -> None:
"""Prints all the command line options to stderr (or another file)."""
if file is None:
file = sys.stderr
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
- if option.default is not None and option.default != '':
+ if option.default is not None and option.default != "":
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
- lines.insert(0, '')
+ lines.insert(0, "")
print(" --%-30s %s" % (prefix, lines[0]), file=file)
for line in lines[1:]:
- print("%-34s %s" % (' ', line), file=file)
+ print("%-34s %s" % (" ", line), file=file)
print(file=file)
def _help_callback(self, value: bool) -> None:
for callback in self._parse_callbacks:
callback()
- def mockable(self) -> '_Mockable':
+ def mockable(self) -> "_Mockable":
"""Returns a wrapper around self that is compatible with
`mock.patch <unittest.mock.patch>`.
_Mockable's getattr and setattr pass through to the underlying
OptionParser, and delattr undoes the effect of a previous setattr.
"""
+
def __init__(self, options: OptionParser) -> None:
# Modify __dict__ directly to bypass __setattr__
- self.__dict__['_options'] = options
- self.__dict__['_originals'] = {}
+ self.__dict__["_options"] = options
+ self.__dict__["_originals"] = {}
def __getattr__(self, name: str) -> Any:
return getattr(self._options, name)
# and the callback use List[T], but type is still Type[T]).
UNSET = object()
- def __init__(self, name: str, default: Any=None, type: type=None,
- help: str=None, metavar: str=None, multiple: bool=False, file_name: str=None,
- group_name: str=None, callback: Callable[[Any], None]=None) -> None:
+ def __init__(
+ self,
+ name: str,
+ default: Any = None,
+ type: type = None,
+ help: str = None,
+ metavar: str = None,
+ multiple: bool = False,
+ file_name: str = None,
+ group_name: str = None,
+ callback: Callable[[Any], None] = None,
+ ) -> None:
if default is None and multiple:
default = []
self.name = name
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
basestring_type: self._parse_string,
- }.get(self.type, self.type) # type: Callable[[str], Any]
+ }.get(
+ self.type, self.type
+ ) # type: Callable[[str], Any]
if self.multiple:
self._value = []
for part in value.split(","):
def set(self, value: Any) -> None:
if self.multiple:
if not isinstance(value, list):
- raise Error("Option %r is required to be a list of %s" %
- (self.name, self.type.__name__))
+ raise Error(
+ "Option %r is required to be a list of %s"
+ % (self.name, self.type.__name__)
+ )
for item in value:
if item is not None and not isinstance(item, self.type):
- raise Error("Option %r is required to be a list of %s" %
- (self.name, self.type.__name__))
+ raise Error(
+ "Option %r is required to be a list of %s"
+ % (self.name, self.type.__name__)
+ )
else:
if value is not None and not isinstance(value, self.type):
- raise Error("Option %r is required to be a %s (%s given)" %
- (self.name, self.type.__name__, type(value)))
+ raise Error(
+ "Option %r is required to be a %s (%s given)"
+ % (self.name, self.type.__name__, type(value))
+ )
self._value = value
if self.callback is not None:
self.callback(self._value)
return datetime.datetime.strptime(value, format)
except ValueError:
pass
- raise Error('Unrecognized date/time format: %r' % value)
+ raise Error("Unrecognized date/time format: %r" % value)
_TIMEDELTA_ABBREV_DICT = {
- 'h': 'hours',
- 'm': 'minutes',
- 'min': 'minutes',
- 's': 'seconds',
- 'sec': 'seconds',
- 'ms': 'milliseconds',
- 'us': 'microseconds',
- 'd': 'days',
- 'w': 'weeks',
+ "h": "hours",
+ "m": "minutes",
+ "min": "minutes",
+ "s": "seconds",
+ "sec": "seconds",
+ "ms": "milliseconds",
+ "us": "microseconds",
+ "d": "days",
+ "w": "weeks",
}
- _FLOAT_PATTERN = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
+ _FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
_TIMEDELTA_PATTERN = re.compile(
- r'\s*(%s)\s*(\w*)\s*' % _FLOAT_PATTERN, re.IGNORECASE)
+ r"\s*(%s)\s*(\w*)\s*" % _FLOAT_PATTERN, re.IGNORECASE
+ )
def _parse_timedelta(self, value: str) -> datetime.timedelta:
try:
if not m:
raise Exception()
num = float(m.group(1))
- units = m.group(2) or 'seconds'
+ units = m.group(2) or "seconds"
units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
sum += datetime.timedelta(**{units: num})
start = m.end()
"""
-def define(name: str, default: Any=None, type: type=None, help: str=None,
- metavar: str=None, multiple: bool=False, group: str=None,
- callback: Callable[[Any], None]=None) -> None:
+def define(
+ name: str,
+ default: Any = None,
+ type: type = None,
+ help: str = None,
+ metavar: str = None,
+ multiple: bool = False,
+ group: str = None,
+ callback: Callable[[Any], None] = None,
+) -> None:
"""Defines an option in the global namespace.
See `OptionParser.define`.
"""
- return options.define(name, default=default, type=type, help=help,
- metavar=metavar, multiple=multiple, group=group,
- callback=callback)
-
-
-def parse_command_line(args: List[str]=None, final: bool=True) -> List[str]:
+ return options.define(
+ name,
+ default=default,
+ type=type,
+ help=help,
+ metavar=metavar,
+ multiple=multiple,
+ group=group,
+ callback=callback,
+ )
+
+
+def parse_command_line(args: List[str] = None, final: bool = True) -> List[str]:
"""Parses global options from the command line.
See `OptionParser.parse_command_line`.
return options.parse_command_line(args, final=final)
-def parse_config_file(path: str, final: bool=True) -> None:
+def parse_config_file(path: str, final: bool = True) -> None:
"""Parses global options from a config file.
See `OptionParser.parse_config_file`.
return options.parse_config_file(path, final=final)
-def print_help(file: TextIO=None) -> None:
+def print_help(file: TextIO = None) -> None:
"""Prints all the command line options to stderr (or another file).
See `OptionParser.print_help`.
import typing
from typing import Any, TypeVar, Awaitable, Callable, Union, Optional
+
if typing.TYPE_CHECKING:
from typing import Set, Dict, Tuple # noqa: F401
-_T = TypeVar('_T')
+_T = TypeVar("_T")
class BaseAsyncIOLoop(IOLoop):
- def initialize(self, asyncio_loop: asyncio.AbstractEventLoop, # type: ignore
- **kwargs: Any) -> None:
+ def initialize( # type: ignore
+ self, asyncio_loop: asyncio.AbstractEventLoop, **kwargs: Any
+ ) -> None:
self.asyncio_loop = asyncio_loop
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
self.handlers = {} # type: Dict[int, Tuple[Union[int, _Selectable], Callable]]
IOLoop._ioloop_for_asyncio[asyncio_loop] = self
super(BaseAsyncIOLoop, self).initialize(**kwargs)
- def close(self, all_fds: bool=False) -> None:
+ def close(self, all_fds: bool = False) -> None:
self.closing = True
for fd in list(self.handlers):
fileobj, handler_func = self.handlers[fd]
del IOLoop._ioloop_for_asyncio[self.asyncio_loop]
self.asyncio_loop.close()
- def add_handler(self, fd: Union[int, _Selectable],
- handler: Callable[..., None], events: int) -> None:
+ def add_handler(
+ self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int
+ ) -> None:
fd, fileobj = self.split_fd(fd)
if fd in self.handlers:
raise ValueError("fd %s added twice" % fd)
self.handlers[fd] = (fileobj, handler)
if events & IOLoop.READ:
- self.asyncio_loop.add_reader(
- fd, self._handle_events, fd, IOLoop.READ)
+ self.asyncio_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
if events & IOLoop.WRITE:
- self.asyncio_loop.add_writer(
- fd, self._handle_events, fd, IOLoop.WRITE)
+ self.asyncio_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
def update_handler(self, fd: Union[int, _Selectable], events: int) -> None:
fd, fileobj = self.split_fd(fd)
if events & IOLoop.READ:
if fd not in self.readers:
- self.asyncio_loop.add_reader(
- fd, self._handle_events, fd, IOLoop.READ)
+ self.asyncio_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
else:
if fd in self.readers:
self.readers.remove(fd)
if events & IOLoop.WRITE:
if fd not in self.writers:
- self.asyncio_loop.add_writer(
- fd, self._handle_events, fd, IOLoop.WRITE)
+ self.asyncio_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
else:
if fd in self.writers:
def stop(self) -> None:
self.asyncio_loop.stop()
- def call_at(self, when: float, callback: Callable[..., None],
- *args: Any, **kwargs: Any) -> object:
+ def call_at(
+ self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
+ ) -> object:
# asyncio.call_at supports *args but not **kwargs, so bind them here.
# We do not synchronize self.time and asyncio_loop.time, so
# convert from absolute to relative.
return self.asyncio_loop.call_later(
- max(0, when - self.time()), self._run_callback,
- functools.partial(callback, *args, **kwargs))
+ max(0, when - self.time()),
+ self._run_callback,
+ functools.partial(callback, *args, **kwargs),
+ )
def remove_timeout(self, timeout: object) -> None:
timeout.cancel() # type: ignore
def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
try:
self.asyncio_loop.call_soon_threadsafe(
- self._run_callback,
- functools.partial(callback, *args, **kwargs))
+ self._run_callback, functools.partial(callback, *args, **kwargs)
+ )
except RuntimeError:
# "Event loop is closed". Swallow the exception for
# consistency with PollIOLoop (and logical consistency
add_callback_from_signal = add_callback
- def run_in_executor(self, executor: Optional[concurrent.futures.Executor],
- func: Callable[..., _T], *args: Any) -> Awaitable[_T]:
+ def run_in_executor(
+ self,
+ executor: Optional[concurrent.futures.Executor],
+ func: Callable[..., _T],
+ *args: Any
+ ) -> Awaitable[_T]:
return self.asyncio_loop.run_in_executor(executor, func, *args)
def set_default_executor(self, executor: concurrent.futures.Executor) -> None:
Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop.
"""
+
def initialize(self, **kwargs: Any) -> None: # type: ignore
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs)
Now used automatically when appropriate; it is no longer necessary
to refer to this class directly.
"""
+
def initialize(self, **kwargs: Any) -> None: # type: ignore
self.is_current = False
loop = asyncio.new_event_loop()
loop.close()
raise
- def close(self, all_fds: bool=False) -> None:
+ def close(self, all_fds: bool = False) -> None:
if self.is_current:
self.clear_current()
super(AsyncIOLoop, self).close(all_fds=all_fds)
.. versionadded:: 5.0
"""
+
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
import os
-if os.name == 'nt':
+if os.name == "nt":
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec
-__all__ = ['set_close_exec']
+__all__ = ["set_close_exec"]
from tornado.netutil import Resolver, is_valid_ip
import typing
+
if typing.TYPE_CHECKING:
from typing import Generator, Any, List, Tuple, Dict # noqa: F401
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
+
def initialize(self) -> None:
self.io_loop = IOLoop.current()
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
self.fds = {} # type: Dict[int, int]
def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
- state = ((IOLoop.READ if readable else 0) |
- (IOLoop.WRITE if writable else 0))
+ state = (IOLoop.READ if readable else 0) | (IOLoop.WRITE if writable else 0)
if not state:
self.io_loop.remove_handler(fd)
del self.fds[fd]
@gen.coroutine
def resolve(
- self, host: str, port: int, family: int=0,
- ) -> 'Generator[Any, Any, List[Tuple[int, Any]]]':
+ self, host: str, port: int, family: int = 0
+ ) -> "Generator[Any, Any, List[Tuple[int, Any]]]":
if is_valid_ip(host):
addresses = [host]
else:
# gethostbyname doesn't take callback as a kwarg
fut = Future() # type: Future[Tuple[Any, Any]]
- self.channel.gethostbyname(host, family,
- lambda result, error: fut.set_result((result, error)))
+ self.channel.gethostbyname(
+ host, family, lambda result, error: fut.set_result((result, error))
+ )
result, error = yield fut
if error:
- raise IOError('C-Ares returned error %s: %s while resolving %s' %
- (error, pycares.errno.strerror(error), host))
+ raise IOError(
+ "C-Ares returned error %s: %s while resolving %s"
+ % (error, pycares.errno.strerror(error), host)
+ )
addresses = result.addresses
addrinfo = []
for address in addresses:
- if '.' in address:
+ if "." in address:
address_family = socket.AF_INET
- elif ':' in address:
+ elif ":" in address:
address_family = socket.AF_INET6
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
- raise IOError('Requested socket family %d but got %d' %
- (family, address_family))
+ raise IOError(
+ "Requested socket family %d but got %d" % (family, address_family)
+ )
addrinfo.append((typing.cast(int, address_family), (address, port)))
return addrinfo
from tornado.netutil import Resolver
import typing
+
if typing.TYPE_CHECKING:
from typing import Generator, Any, List, Tuple # noqa: F401
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
+
def initialize(self) -> None:
# partial copy of twisted.names.client.createResolver, which doesn't
# allow for a reactor to be passed in.
self.reactor = twisted.internet.asyncioreactor.AsyncioSelectorReactor()
- host_resolver = twisted.names.hosts.Resolver('/etc/hosts')
+ host_resolver = twisted.names.hosts.Resolver("/etc/hosts")
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
- real_resolver = twisted.names.client.Resolver('/etc/resolv.conf',
- reactor=self.reactor)
+ real_resolver = twisted.names.client.Resolver(
+ "/etc/resolv.conf", reactor=self.reactor
+ )
self.resolver = twisted.names.resolve.ResolverChain(
- [host_resolver, cache_resolver, real_resolver])
+ [host_resolver, cache_resolver, real_resolver]
+ )
@gen.coroutine
def resolve(
- self, host: str, port: int, family: int=0,
- ) -> 'Generator[Any, Any, List[Tuple[int, Any]]]':
+ self, host: str, port: int, family: int = 0
+ ) -> "Generator[Any, Any, List[Tuple[int, Any]]]":
# getHostByName doesn't accept IP addresses, so if the input
# looks like an IP address just return it immediately.
if twisted.internet.abstract.isIPAddress(host):
else:
resolved_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != resolved_family:
- raise Exception('Requested socket family %d but got %d' %
- (family, resolved_family))
- result = [
- (typing.cast(int, resolved_family), (resolved, port)),
- ]
+ raise Exception(
+ "Requested socket family %d but got %d" % (family, resolved_family)
+ )
+ result = [(typing.cast(int, resolved_family), (resolved, port))]
return result
-if hasattr(gen.convert_yielded, 'register'):
+if hasattr(gen.convert_yielded, "register"):
+
@gen.convert_yielded.register(Deferred) # type: ignore
def _(d: Deferred) -> Future:
f = Future() # type: Future[Any]
raise Exception("errback called without error")
except:
future_set_exc_info(f, sys.exc_info())
+
d.addCallbacks(f.set_result, errback)
return f
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation # type: ignore
-SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) # noqa: E501
+SetHandleInformation.argtypes = (
+ ctypes.wintypes.HANDLE,
+ ctypes.wintypes.DWORD,
+ ctypes.wintypes.DWORD,
+) # noqa: E501
SetHandleInformation.restype = ctypes.wintypes.BOOL
HANDLE_FLAG_INHERIT = 0x00000001
import typing
from typing import Tuple, Optional, Any, Callable
+
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
def _reseed_random() -> None:
- if 'random' not in sys.modules:
+ if "random" not in sys.modules:
return
import random
+
# If os.urandom is available, this method does the same thing as
# random.seed (at least as of python 2.6). If os.urandom is not
# available, we mix in the pid in addition to a timestamp.
_task_id = None
-def fork_processes(num_processes: Optional[int], max_restarts: int=100) -> int:
+def fork_processes(num_processes: Optional[int], max_restarts: int = 100) -> int:
"""Starts multiple worker processes.
If ``num_processes`` is None or <= 0, we detect the number of cores
continue
id = children.pop(pid)
if os.WIFSIGNALED(status):
- gen_log.warning("child %d (pid %d) killed by signal %d, restarting",
- id, pid, os.WTERMSIG(status))
+ gen_log.warning(
+ "child %d (pid %d) killed by signal %d, restarting",
+ id,
+ pid,
+ os.WTERMSIG(status),
+ )
elif os.WEXITSTATUS(status) != 0:
- gen_log.warning("child %d (pid %d) exited with status %d, restarting",
- id, pid, os.WEXITSTATUS(status))
+ gen_log.warning(
+ "child %d (pid %d) exited with status %d, restarting",
+ id,
+ pid,
+ os.WEXITSTATUS(status),
+ )
else:
gen_log.info("child %d (pid %d) exited normally", id, pid)
continue
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
+
STREAM = object()
_initialized = False
# should be closed in the parent process on success.
pipe_fds = [] # type: List[int]
to_close = [] # type: List[int]
- if kwargs.get('stdin') is Subprocess.STREAM:
+ if kwargs.get("stdin") is Subprocess.STREAM:
in_r, in_w = _pipe_cloexec()
- kwargs['stdin'] = in_r
+ kwargs["stdin"] = in_r
pipe_fds.extend((in_r, in_w))
to_close.append(in_r)
self.stdin = PipeIOStream(in_w)
- if kwargs.get('stdout') is Subprocess.STREAM:
+ if kwargs.get("stdout") is Subprocess.STREAM:
out_r, out_w = _pipe_cloexec()
- kwargs['stdout'] = out_w
+ kwargs["stdout"] = out_w
pipe_fds.extend((out_r, out_w))
to_close.append(out_w)
self.stdout = PipeIOStream(out_r)
- if kwargs.get('stderr') is Subprocess.STREAM:
+ if kwargs.get("stderr") is Subprocess.STREAM:
err_r, err_w = _pipe_cloexec()
- kwargs['stderr'] = err_w
+ kwargs["stderr"] = err_w
pipe_fds.extend((err_r, err_w))
to_close.append(err_w)
self.stderr = PipeIOStream(err_r)
for fd in to_close:
os.close(fd)
self.pid = self.proc.pid
- for attr in ['stdin', 'stdout', 'stderr']:
+ for attr in ["stdin", "stdout", "stderr"]:
if not hasattr(self, attr): # don't clobber streams set above
setattr(self, attr, getattr(self.proc, attr))
self._exit_callback = None # type: Optional[Callable[[int], None]]
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
- def wait_for_exit(self, raise_error: bool=True) -> 'Future[int]':
+ def wait_for_exit(self, raise_error: bool = True) -> "Future[int]":
"""Returns a `.Future` which resolves when the process exits.
Usage::
def callback(ret: int) -> None:
if ret != 0 and raise_error:
# Unfortunately we don't have the original args any more.
- future.set_exception(CalledProcessError(ret, 'unknown'))
+ future.set_exception(CalledProcessError(ret, "unknown"))
else:
future_set_result_unless_cancelled(future, ret)
+
self.set_exit_callback(callback)
return future
io_loop = ioloop.IOLoop.current()
cls._old_sigchld = signal.signal(
signal.SIGCHLD,
- lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup))
+ lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup),
+ )
cls._initialized = True
@classmethod
return
assert ret_pid == pid
subproc = cls._waiting.pop(pid)
- subproc.io_loop.add_callback_from_signal(
- subproc._set_returncode, status)
+ subproc.io_loop.add_callback_from_signal(subproc._set_returncode, status)
def _set_returncode(self, status: int) -> None:
if os.WIFSIGNALED(status):
from typing import Union, TypeVar, Generic, Awaitable
import typing
+
if typing.TYPE_CHECKING:
from typing import Deque, Tuple, List, Any # noqa: F401
-_T = TypeVar('_T')
+_T = TypeVar("_T")
-__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
+__all__ = ["Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty"]
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""
+
pass
class QueueFull(Exception):
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
+
pass
-def _set_timeout(future: Future, timeout: Union[None, float, datetime.timedelta]) -> None:
+def _set_timeout(
+ future: Future, timeout: Union[None, float, datetime.timedelta]
+) -> None:
if timeout:
+
def on_timeout() -> None:
if not future.done():
future.set_exception(gen.TimeoutError())
+
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
- future.add_done_callback(
- lambda _: io_loop.remove_timeout(timeout_handle))
+ future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
class _QueueIterator(Generic[_T]):
- def __init__(self, q: 'Queue[_T]') -> None:
+ def __init__(self, q: "Queue[_T]") -> None:
self.q = q
def __anext__(self) -> Awaitable[_T]:
Added ``async for`` support in Python 3.5.
"""
+
# Exact type depends on subclass. Could be another generic
# parameter and use protocols to be more precise here.
_queue = None # type: Any
- def __init__(self, maxsize: int=0) -> None:
+ def __init__(self, maxsize: int = 0) -> None:
if maxsize is None:
raise TypeError("maxsize can't be None")
else:
return self.qsize() >= self.maxsize
- def put(self, item: _T, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]':
+ def put(
+ self, item: _T, timeout: Union[float, datetime.timedelta] = None
+ ) -> "Future[None]":
"""Put an item into the queue, perhaps waiting until there is room.
Returns a Future, which raises `tornado.util.TimeoutError` after a
else:
self.__put_internal(item)
- def get(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[_T]':
+ def get(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[_T]":
"""Remove and return an item from the queue.
Returns a Future which resolves once an item is available, or raises
Raises `ValueError` if called more times than `.put`.
"""
if self._unfinished_tasks <= 0:
- raise ValueError('task_done() called too many times')
+ raise ValueError("task_done() called too many times")
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
- def join(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]':
+ def join(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[None]":
"""Block until all items in the queue are processed.
Returns a Future, which raises `tornado.util.TimeoutError` after a
def _put(self, item: _T) -> None:
self._queue.append(item)
+
# End of the overridable methods.
def __put_internal(self, item: _T) -> None:
self._getters.popleft()
def __repr__(self) -> str:
- return '<%s at %s %s>' % (
- type(self).__name__, hex(id(self)), self._format())
+ return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format())
def __str__(self) -> str:
- return '<%s %s>' % (type(self).__name__, self._format())
+ return "<%s %s>" % (type(self).__name__, self._format())
def _format(self) -> str:
- result = 'maxsize=%r' % (self.maxsize, )
- if getattr(self, '_queue', None):
- result += ' queue=%r' % self._queue
+ result = "maxsize=%r" % (self.maxsize,)
+ if getattr(self, "_queue", None):
+ result += " queue=%r" % self._queue
if self._getters:
- result += ' getters[%s]' % len(self._getters)
+ result += " getters[%s]" % len(self._getters)
if self._putters:
- result += ' putters[%s]' % len(self._putters)
+ result += " putters[%s]" % len(self._putters)
if self._unfinished_tasks:
- result += ' tasks=%s' % self._unfinished_tasks
+ result += " tasks=%s" % self._unfinished_tasks
return result
(1, 'medium-priority item')
(10, 'low-priority item')
"""
+
def _init(self) -> None:
self._queue = []
2
3
"""
+
def _init(self) -> None:
self._queue = []
class Router(httputil.HTTPServerConnectionDelegate):
"""Abstract router interface."""
- def find_handler(self, request: httputil.HTTPServerRequest,
- **kwargs: Any) -> Optional[httputil.HTTPMessageDelegate]:
+ def find_handler(
+ self, request: httputil.HTTPServerRequest, **kwargs: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
that can serve the request.
Routing implementations may pass additional kwargs to extend the routing logic.
"""
raise NotImplementedError()
- def start_request(self, server_conn: object,
- request_conn: httputil.HTTPConnection) -> httputil.HTTPMessageDelegate:
+ def start_request(
+ self, server_conn: object, request_conn: httputil.HTTPConnection
+ ) -> httputil.HTTPMessageDelegate:
return _RoutingDelegate(self, server_conn, request_conn)
class _RoutingDelegate(httputil.HTTPMessageDelegate):
- def __init__(self, router: Router, server_conn: object,
- request_conn: httputil.HTTPConnection) -> None:
+ def __init__(
+ self, router: Router, server_conn: object, request_conn: httputil.HTTPConnection
+ ) -> None:
self.server_conn = server_conn
self.request_conn = request_conn
self.delegate = None # type: Optional[httputil.HTTPMessageDelegate]
self.router = router # type: Router
- def headers_received(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]:
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
assert isinstance(start_line, httputil.RequestStartLine)
request = httputil.HTTPServerRequest(
connection=self.request_conn,
server_connection=self.server_conn,
- start_line=start_line, headers=headers)
+ start_line=start_line,
+ headers=headers,
+ )
self.delegate = self.router.find_handler(request)
if self.delegate is None:
- app_log.debug("Delegate for %s %s request not found",
- start_line.method, start_line.path)
+ app_log.debug(
+ "Delegate for %s %s request not found",
+ start_line.method,
+ start_line.path,
+ )
self.delegate = _DefaultMessageDelegate(self.request_conn)
return self.delegate.headers_received(start_line, headers)
def finish(self) -> None:
self.connection.write_headers(
- httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders())
+ httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"),
+ httputil.HTTPHeaders(),
+ )
self.connection.finish()
# _RuleList can either contain pre-constructed Rules or a sequence of
# arguments to be passed to the Rule constructor.
-_RuleList = List[Union['Rule',
- List[Any], # Can't do detailed typechecking of lists.
- Tuple[Union[str, 'Matcher'], Any],
- Tuple[Union[str, 'Matcher'], Any, Dict[str, Any]],
- Tuple[Union[str, 'Matcher'], Any, Dict[str, Any], str]]]
+_RuleList = List[
+ Union[
+ "Rule",
+ List[Any], # Can't do detailed typechecking of lists.
+ Tuple[Union[str, "Matcher"], Any],
+ Tuple[Union[str, "Matcher"], Any, Dict[str, Any]],
+ Tuple[Union[str, "Matcher"], Any, Dict[str, Any], str],
+ ]
+]
class RuleRouter(Router):
"""Rule-based router implementation."""
- def __init__(self, rules: _RuleList=None) -> None:
+ def __init__(self, rules: _RuleList = None) -> None:
"""Constructs a router from an ordered list of rules::
RuleRouter([
self.rules.append(self.process_rule(rule))
- def process_rule(self, rule: 'Rule') -> 'Rule':
+ def process_rule(self, rule: "Rule") -> "Rule":
"""Override this method for additional preprocessing of each rule.
:arg Rule rule: a rule to be processed.
"""
return rule
- def find_handler(self, request: httputil.HTTPServerRequest,
- **kwargs: Any) -> Optional[httputil.HTTPMessageDelegate]:
+ def find_handler(
+ self, request: httputil.HTTPServerRequest, **kwargs: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
for rule in self.rules:
target_params = rule.matcher.match(request)
if target_params is not None:
if rule.target_kwargs:
- target_params['target_kwargs'] = rule.target_kwargs
+ target_params["target_kwargs"] = rule.target_kwargs
delegate = self.get_target_delegate(
- rule.target, request, **target_params)
+ rule.target, request, **target_params
+ )
if delegate is not None:
return delegate
return None
- def get_target_delegate(self, target: Any, request: httputil.HTTPServerRequest,
- **target_params: Any) -> Optional[httputil.HTTPMessageDelegate]:
+ def get_target_delegate(
+ self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
Rule's target. This method is called by `~.find_handler` and can be
extended to provide additional target types.
in a rule's matcher (see `Matcher.reverse`).
"""
- def __init__(self, rules: _RuleList=None) -> None:
+ def __init__(self, rules: _RuleList = None) -> None:
self.named_rules = {} # type: Dict[str, Any]
super(ReversibleRuleRouter, self).__init__(rules)
- def process_rule(self, rule: 'Rule') -> 'Rule':
+ def process_rule(self, rule: "Rule") -> "Rule":
rule = super(ReversibleRuleRouter, self).process_rule(rule)
if rule.name:
if rule.name in self.named_rules:
app_log.warning(
- "Multiple handlers named %s; replacing previous value",
- rule.name)
+ "Multiple handlers named %s; replacing previous value", rule.name
+ )
self.named_rules[rule.name] = rule
return rule
class Rule(object):
"""A routing rule."""
- def __init__(self, matcher: 'Matcher', target: Any,
- target_kwargs: Dict[str, Any]=None, name: str=None) -> None:
+ def __init__(
+ self,
+ matcher: "Matcher",
+ target: Any,
+ target_kwargs: Dict[str, Any] = None,
+ name: str = None,
+ ) -> None:
"""Constructs a Rule instance.
:arg Matcher matcher: a `Matcher` instance used for determining
return self.matcher.reverse(*args)
def __repr__(self) -> str:
- return '%s(%r, %s, kwargs=%r, name=%r)' % \
- (self.__class__.__name__, self.matcher,
- self.target, self.target_kwargs, self.name)
+ return "%s(%r, %s, kwargs=%r, name=%r)" % (
+ self.__class__.__name__,
+ self.matcher,
+ self.target,
+ self.target_kwargs,
+ self.name,
+ )
class Matcher(object):
def __init__(self, path_pattern: Union[str, Pattern]) -> None:
if isinstance(path_pattern, basestring_type):
- if not path_pattern.endswith('$'):
- path_pattern += '$'
+ if not path_pattern.endswith("$"):
+ path_pattern += "$"
self.regex = re.compile(path_pattern)
else:
self.regex = path_pattern
- assert len(self.regex.groupindex) in (0, self.regex.groups), \
- ("groups in url regexes must either be all named or all "
- "positional: %r" % self.regex.pattern)
+ assert len(self.regex.groupindex) in (0, self.regex.groups), (
+ "groups in url regexes must either be all named or all "
+ "positional: %r" % self.regex.pattern
+ )
self._path, self._group_count = self._find_groups()
# or groupdict but not both.
if self.regex.groupindex:
path_kwargs = dict(
- (str(k), _unquote_or_none(v))
- for (k, v) in match.groupdict().items())
+ (str(k), _unquote_or_none(v)) for (k, v) in match.groupdict().items()
+ )
else:
path_args = [_unquote_or_none(s) for s in match.groups()]
def reverse(self, *args: Any) -> Optional[str]:
if self._path is None:
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
- assert len(args) == self._group_count, "required number of arguments " \
- "not found"
+ assert len(args) == self._group_count, (
+ "required number of arguments " "not found"
+ )
if not len(args):
return self._path
converted_args = []
would return ('/%s/%s/', 2).
"""
pattern = self.regex.pattern
- if pattern.startswith('^'):
+ if pattern.startswith("^"):
pattern = pattern[1:]
- if pattern.endswith('$'):
+ if pattern.endswith("$"):
pattern = pattern[:-1]
- if self.regex.groups != pattern.count('('):
+ if self.regex.groups != pattern.count("("):
# The pattern is too complicated for our simplistic matching,
# so we can't support reversing it.
return None, None
pieces = []
- for fragment in pattern.split('('):
- if ')' in fragment:
- paren_loc = fragment.index(')')
+ for fragment in pattern.split("("):
+ if ")" in fragment:
+ paren_loc = fragment.index(")")
if paren_loc >= 0:
- pieces.append('%s' + fragment[paren_loc + 1:])
+ pieces.append("%s" + fragment[paren_loc + 1 :])
else:
try:
unescaped_fragment = re_unescape(fragment)
return (None, None)
pieces.append(unescaped_fragment)
- return ''.join(pieces), self.regex.groups
+ return "".join(pieces), self.regex.groups
class URLSpec(Rule):
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
backwards compatibility.
"""
- def __init__(self, pattern: Union[str, Pattern], handler: Any,
- kwargs: Dict[str, Any]=None, name: str=None) -> None:
+
+ def __init__(
+ self,
+ pattern: Union[str, Pattern],
+ handler: Any,
+ kwargs: Dict[str, Any] = None,
+ name: str = None,
+ ) -> None:
"""Parameters:
* ``pattern``: Regular expression to be matched. Any capturing
self.kwargs = kwargs
def __repr__(self) -> str:
- return '%s(%r, %s, kwargs=%r, name=%r)' % \
- (self.__class__.__name__, self.regex.pattern,
- self.handler_class, self.kwargs, self.name)
+ return "%s(%r, %s, kwargs=%r, name=%r)" % (
+ self.__class__.__name__,
+ self.regex.pattern,
+ self.handler_class,
+ self.kwargs,
+ self.name,
+ )
@overload
from tornado.escape import _unicode
from tornado import gen
-from tornado.httpclient import (HTTPResponse, HTTPError, AsyncHTTPClient, main,
- _RequestProxy, HTTPRequest)
+from tornado.httpclient import (
+ HTTPResponse,
+ HTTPError,
+ AsyncHTTPClient,
+ main,
+ _RequestProxy,
+ HTTPRequest,
+)
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.ioloop import IOLoop
from typing import Dict, Any, Generator, Callable, Optional, Type, Union
from types import TracebackType
import typing
+
if typing.TYPE_CHECKING:
from typing import Deque, Tuple, List # noqa: F401
.. versionadded:: 5.1
"""
+
def __init__(self, message: str) -> None:
super(HTTPTimeoutError, self).__init__(599, message=message)
.. versionadded:: 5.1
"""
+
def __init__(self, message: str) -> None:
super(HTTPStreamClosedError, self).__init__(599, message=message)
are not reused, and callers cannot select the network interface to be
used.
"""
- def initialize(self, max_clients: int=10, # type: ignore
- hostname_mapping: Dict[str, str]=None, max_buffer_size: int=104857600,
- resolver: Resolver=None, defaults: Dict[str, Any]=None,
- max_header_size: int=None, max_body_size: int=None) -> None:
+
+ def initialize( # type: ignore
+ self,
+ max_clients: int = 10,
+ hostname_mapping: Dict[str, str] = None,
+ max_buffer_size: int = 104857600,
+ resolver: Resolver = None,
+ defaults: Dict[str, Any] = None,
+ max_header_size: int = None,
+ max_body_size: int = None,
+ ) -> None:
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
"""
super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults)
self.max_clients = max_clients
- self.queue = collections.deque() \
- # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]]
- self.active = {} # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]]
- self.waiting = {} \
- # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]]
+ self.queue = (
+ collections.deque()
+ ) # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]]
+ self.active = (
+ {}
+ ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]]
+ self.waiting = (
+ {}
+ ) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]]
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
self.max_body_size = max_body_size
self.resolver = Resolver()
self.own_resolver = True
if hostname_mapping is not None:
- self.resolver = OverrideResolver(resolver=self.resolver,
- mapping=hostname_mapping)
+ self.resolver = OverrideResolver(
+ resolver=self.resolver, mapping=hostname_mapping
+ )
self.tcp_client = TCPClient(resolver=self.resolver)
def close(self) -> None:
self.resolver.close()
self.tcp_client.close()
- def fetch_impl(self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]) -> None:
+ def fetch_impl(
+ self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
+ ) -> None:
key = object()
self.queue.append((key, request, callback))
if not len(self.active) < self.max_clients:
assert request.connect_timeout is not None
assert request.request_timeout is not None
timeout_handle = self.io_loop.add_timeout(
- self.io_loop.time() + min(request.connect_timeout,
- request.request_timeout),
- functools.partial(self._on_timeout, key, "in request queue"))
+ self.io_loop.time()
+ + min(request.connect_timeout, request.request_timeout),
+ functools.partial(self._on_timeout, key, "in request queue"),
+ )
else:
timeout_handle = None
self.waiting[key] = (request, callback, timeout_handle)
self._process_queue()
if self.queue:
- gen_log.debug("max_clients limit reached, request queued. "
- "%d active, %d queued requests." % (
- len(self.active), len(self.queue)))
+ gen_log.debug(
+ "max_clients limit reached, request queued. "
+ "%d active, %d queued requests." % (len(self.active), len(self.queue))
+ )
def _process_queue(self) -> None:
while self.queue and len(self.active) < self.max_clients:
def _connection_class(self) -> type:
return _HTTPConnection
- def _handle_request(self, request: HTTPRequest, release_callback: Callable[[], None],
- final_callback: Callable[[HTTPResponse], None]) -> None:
+ def _handle_request(
+ self,
+ request: HTTPRequest,
+ release_callback: Callable[[], None],
+ final_callback: Callable[[HTTPResponse], None],
+ ) -> None:
self._connection_class()(
- self, request, release_callback,
- final_callback, self.max_buffer_size, self.tcp_client,
- self.max_header_size, self.max_body_size)
+ self,
+ request,
+ release_callback,
+ final_callback,
+ self.max_buffer_size,
+ self.tcp_client,
+ self.max_header_size,
+ self.max_body_size,
+ )
def _release_fetch(self, key: object) -> None:
del self.active[key]
self.io_loop.remove_timeout(timeout_handle)
del self.waiting[key]
- def _on_timeout(self, key: object, info: str=None) -> None:
+ def _on_timeout(self, key: object, info: str = None) -> None:
"""Timeout callback of request.
Construct a timeout HTTPResponse when a timeout occurs.
error_message = "Timeout {0}".format(info) if info else "Timeout"
timeout_response = HTTPResponse(
- request, 599, error=HTTPTimeoutError(error_message),
- request_time=self.io_loop.time() - request.start_time)
+ request,
+ 599,
+ error=HTTPTimeoutError(error_message),
+ request_time=self.io_loop.time() - request.start_time,
+ )
self.io_loop.add_callback(callback, timeout_response)
del self.waiting[key]
class _HTTPConnection(httputil.HTTPMessageDelegate):
- _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
-
- def __init__(self, client: Optional[SimpleAsyncHTTPClient], request: HTTPRequest,
- release_callback: Callable[[], None],
- final_callback: Callable[[HTTPResponse], None], max_buffer_size: int,
- tcp_client: TCPClient, max_header_size: int, max_body_size: int) -> None:
+ _SUPPORTED_METHODS = set(
+ ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
+ )
+
+ def __init__(
+ self,
+ client: Optional[SimpleAsyncHTTPClient],
+ request: HTTPRequest,
+ release_callback: Callable[[], None],
+ final_callback: Callable[[HTTPResponse], None],
+ max_buffer_size: int,
+ tcp_client: TCPClient,
+ max_header_size: int,
+ max_body_size: int,
+ ) -> None:
self.io_loop = IOLoop.current()
self.start_time = self.io_loop.time()
self.start_wall_time = time.time()
try:
self.parsed = urllib.parse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
- raise ValueError("Unsupported url scheme: %s" %
- self.request.url)
+ raise ValueError("Unsupported url scheme: %s" % self.request.url)
# urlsplit results have hostname and port results, but they
# didn't support ipv6 literals until python 2.7.
netloc = self.parsed.netloc
host, port = httputil.split_host_and_port(netloc)
if port is None:
port = 443 if self.parsed.scheme == "https" else 80
- if re.match(r'^\[.*\]$', host):
+ if re.match(r"^\[.*\]$", host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
- functools.partial(self._on_timeout, "while connecting"))
+ functools.partial(self._on_timeout, "while connecting"),
+ )
stream = yield self.tcp_client.connect(
- host, port, af=af,
+ host,
+ port,
+ af=af,
ssl_options=ssl_options,
- max_buffer_size=self.max_buffer_size)
+ max_buffer_size=self.max_buffer_size,
+ )
if self.final_callback is None:
# final_callback is cleared if we've hit our timeout.
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
- functools.partial(self._on_timeout, "during request"))
- if (self.request.method not in self._SUPPORTED_METHODS and
- not self.request.allow_nonstandard_methods):
+ functools.partial(self._on_timeout, "during request"),
+ )
+ if (
+ self.request.method not in self._SUPPORTED_METHODS
+ and not self.request.allow_nonstandard_methods
+ ):
raise KeyError("unknown method %s" % self.request.method)
- for key in ('network_interface',
- 'proxy_host', 'proxy_port',
- 'proxy_username', 'proxy_password',
- 'proxy_auth_mode'):
+ for key in (
+ "network_interface",
+ "proxy_host",
+ "proxy_port",
+ "proxy_username",
+ "proxy_password",
+ "proxy_auth_mode",
+ ):
if getattr(self.request, key, None):
- raise NotImplementedError('%s not supported' % key)
+ raise NotImplementedError("%s not supported" % key)
if "Connection" not in self.request.headers:
self.request.headers["Connection"] = "close"
if "Host" not in self.request.headers:
- if '@' in self.parsed.netloc:
- self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
+ if "@" in self.parsed.netloc:
+ self.request.headers["Host"] = self.parsed.netloc.rpartition(
+ "@"
+ )[-1]
else:
self.request.headers["Host"] = self.parsed.netloc
username, password = None, None
username, password = self.parsed.username, self.parsed.password
elif self.request.auth_username is not None:
username = self.request.auth_username
- password = self.request.auth_password or ''
+ password = self.request.auth_password or ""
if username is not None:
assert password is not None
if self.request.auth_mode not in (None, "basic"):
- raise ValueError("unsupported auth_mode %s",
- self.request.auth_mode)
- self.request.headers["Authorization"] = (
- "Basic " + _unicode(base64.b64encode(
- httputil.encode_username_password(username, password))))
+ raise ValueError(
+ "unsupported auth_mode %s", self.request.auth_mode
+ )
+ self.request.headers["Authorization"] = "Basic " + _unicode(
+ base64.b64encode(
+ httputil.encode_username_password(username, password)
+ )
+ )
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
body_expected = self.request.method in ("POST", "PATCH", "PUT")
- body_present = (self.request.body is not None or
- self.request.body_producer is not None)
- if ((body_expected and not body_present) or
- (body_present and not body_expected)):
+ body_present = (
+ self.request.body is not None
+ or self.request.body_producer is not None
+ )
+ if (body_expected and not body_present) or (
+ body_present and not body_expected
+ ):
raise ValueError(
- 'Body must %sbe None for method %s (unless '
- 'allow_nonstandard_methods is true)' %
- ('not ' if body_expected else '', self.request.method))
+ "Body must %sbe None for method %s (unless "
+ "allow_nonstandard_methods is true)"
+ % ("not " if body_expected else "", self.request.method)
+ )
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
# When body_producer is used the caller is responsible for
# setting Content-Length (or else chunked encoding will be used).
- self.request.headers["Content-Length"] = str(len(
- self.request.body))
- if (self.request.method == "POST" and
- "Content-Type" not in self.request.headers):
- self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+ self.request.headers["Content-Length"] = str(len(self.request.body))
+ if (
+ self.request.method == "POST"
+ and "Content-Type" not in self.request.headers
+ ):
+ self.request.headers[
+ "Content-Type"
+ ] = "application/x-www-form-urlencoded"
if self.request.decompress_response:
self.request.headers["Accept-Encoding"] = "gzip"
- req_path = ((self.parsed.path or '/') +
- (('?' + self.parsed.query) if self.parsed.query else ''))
+ req_path = (self.parsed.path or "/") + (
+ ("?" + self.parsed.query) if self.parsed.query else ""
+ )
self.connection = self._create_connection(stream)
- start_line = httputil.RequestStartLine(self.request.method,
- req_path, '')
+ start_line = httputil.RequestStartLine(
+ self.request.method, req_path, ""
+ )
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
yield self.connection.read_response(self)
if not self._handle_exception(*sys.exc_info()):
raise
- def _get_ssl_options(self, scheme: str) -> Union[None, Dict[str, Any], ssl.SSLContext]:
+ def _get_ssl_options(
+ self, scheme: str
+ ) -> Union[None, Dict[str, Any], ssl.SSLContext]:
if scheme == "https":
if self.request.ssl_options is not None:
return self.request.ssl_options
# If we are using the defaults, don't construct a
# new SSLContext.
- if (self.request.validate_cert and
- self.request.ca_certs is None and
- self.request.client_cert is None and
- self.request.client_key is None):
+ if (
+ self.request.validate_cert
+ and self.request.ca_certs is None
+ and self.request.client_cert is None
+ and self.request.client_key is None
+ ):
return _client_ssl_defaults
ssl_ctx = ssl.create_default_context(
- ssl.Purpose.SERVER_AUTH,
- cafile=self.request.ca_certs)
+ ssl.Purpose.SERVER_AUTH, cafile=self.request.ca_certs
+ )
if not self.request.validate_cert:
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
if self.request.client_cert is not None:
- ssl_ctx.load_cert_chain(self.request.client_cert,
- self.request.client_key)
- if hasattr(ssl, 'OP_NO_COMPRESSION'):
+ ssl_ctx.load_cert_chain(
+ self.request.client_cert, self.request.client_key
+ )
+ if hasattr(ssl, "OP_NO_COMPRESSION"):
# See netutil.ssl_options_to_context
ssl_ctx.options |= ssl.OP_NO_COMPRESSION
return ssl_ctx
return None
- def _on_timeout(self, info: str=None) -> None:
+ def _on_timeout(self, info: str = None) -> None:
"""Timeout callback of _HTTPConnection instance.
Raise a `HTTPTimeoutError` when a timeout occurs.
self._timeout = None
error_message = "Timeout {0}".format(info) if info else "Timeout"
if self.final_callback is not None:
- self._handle_exception(HTTPTimeoutError, HTTPTimeoutError(error_message),
- None)
+ self._handle_exception(
+ HTTPTimeoutError, HTTPTimeoutError(error_message), None
+ )
def _remove_timeout(self) -> None:
if self._timeout is not None:
def _create_connection(self, stream: IOStream) -> HTTP1Connection:
stream.set_nodelay(True)
connection = HTTP1Connection(
- stream, True,
+ stream,
+ True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
max_body_size=self.max_body_size,
- decompress=bool(self.request.decompress_response)),
- self._sockaddr)
+ decompress=bool(self.request.decompress_response),
+ ),
+ self._sockaddr,
+ )
return connection
@gen.coroutine
self.final_callback = None # type: ignore
self.io_loop.add_callback(final_callback, response)
- def _handle_exception(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[TracebackType]) -> bool:
+ def _handle_exception(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> bool:
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
value = HTTPStreamClosedError("Stream closed")
else:
value = value.real_error
- self._run_callback(HTTPResponse(self.request, 599, error=value,
- request_time=self.io_loop.time() - self.start_time,
- start_time=self.start_wall_time,
- ))
+ self._run_callback(
+ HTTPResponse(
+ self.request,
+ 599,
+ error=value,
+ request_time=self.io_loop.time() - self.start_time,
+ start_time=self.start_wall_time,
+ )
+ )
if hasattr(self, "stream"):
# TODO: this may cause a StreamClosedError to be raised
except HTTPStreamClosedError:
self._handle_exception(*sys.exc_info())
- def headers_received(self, first_line: Union[httputil.ResponseStartLine,
- httputil.RequestStartLine],
- headers: httputil.HTTPHeaders) -> None:
+ def headers_received(
+ self,
+ first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> None:
assert isinstance(first_line, httputil.ResponseStartLine)
if self.request.expect_100_continue and first_line.code == 100:
self._write_body(False)
if self.request.header_callback is not None:
# Reassemble the start line.
- self.request.header_callback('%s %s %s\r\n' % first_line)
+ self.request.header_callback("%s %s %s\r\n" % first_line)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
- self.request.header_callback('\r\n')
+ self.request.header_callback("\r\n")
def _should_follow_redirect(self) -> bool:
if self.request.follow_redirects:
assert self.request.max_redirects is not None
- return (self.code in (301, 302, 303, 307, 308) and
- self.request.max_redirects > 0)
+ return (
+ self.code in (301, 302, 303, 307, 308)
+ and self.request.max_redirects > 0
+ )
return False
def finish(self) -> None:
assert self.code is not None
- data = b''.join(self.chunks)
+ data = b"".join(self.chunks)
self._remove_timeout()
- original_request = getattr(self.request, "original_request",
- self.request)
+ original_request = getattr(self.request, "original_request", self.request)
if self._should_follow_redirect():
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
- new_request.url = urllib.parse.urljoin(self.request.url,
- self.headers["Location"])
+ new_request.url = urllib.parse.urljoin(
+ self.request.url, self.headers["Location"]
+ )
new_request.max_redirects = self.request.max_redirects - 1
del new_request.headers["Host"]
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
if self.code in (302, 303):
new_request.method = "GET"
new_request.body = None
- for h in ["Content-Length", "Content-Type",
- "Content-Encoding", "Transfer-Encoding"]:
+ for h in [
+ "Content-Length",
+ "Content-Type",
+ "Content-Encoding",
+ "Transfer-Encoding",
+ ]:
try:
del self.request.headers[h]
except KeyError:
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
- response = HTTPResponse(original_request,
- self.code, reason=getattr(self, 'reason', None),
- headers=self.headers,
- request_time=self.io_loop.time() - self.start_time,
- start_time=self.start_wall_time,
- buffer=buffer,
- effective_url=self.request.url)
+ response = HTTPResponse(
+ original_request,
+ self.code,
+ reason=getattr(self, "reason", None),
+ headers=self.headers,
+ request_time=self.io_loop.time() - self.start_time,
+ start_time=self.start_wall_time,
+ buffer=buffer,
+ effective_url=self.request.url,
+ )
self._run_callback(response)
self._on_end_request()
import typing
from typing import Generator, Any, Union, Dict, Tuple, List, Callable, Iterator
+
if typing.TYPE_CHECKING:
from typing import Optional, Set # noqa: F401
http://tools.ietf.org/html/rfc6555
"""
- def __init__(self, addrinfo: List[Tuple],
- connect: Callable[[socket.AddressFamily, Tuple],
- Tuple[IOStream, 'Future[IOStream]']]) -> None:
+
+ def __init__(
+ self,
+ addrinfo: List[Tuple],
+ connect: Callable[
+ [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"]
+ ],
+ ) -> None:
self.io_loop = IOLoop.current()
self.connect = connect
- self.future = Future() # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
+ self.future = (
+ Future()
+ ) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
self.timeout = None # type: Optional[object]
self.connect_timeout = None # type: Optional[object]
self.last_error = None # type: Optional[Exception]
self.streams = set() # type: Set[IOStream]
@staticmethod
- def split(addrinfo: List[Tuple]) -> Tuple[List[Tuple[socket.AddressFamily, Tuple]],
- List[Tuple[socket.AddressFamily, Tuple]]]:
+ def split(
+ addrinfo: List[Tuple]
+ ) -> Tuple[
+ List[Tuple[socket.AddressFamily, Tuple]],
+ List[Tuple[socket.AddressFamily, Tuple]],
+ ]:
"""Partition the ``addrinfo`` list by address family.
Returns two lists. The first list contains the first entry from
return primary, secondary
def start(
- self, timeout: float=_INITIAL_CONNECT_TIMEOUT,
- connect_timeout: Union[float, datetime.timedelta]=None,
- ) -> 'Future[Tuple[socket.AddressFamily, Any, IOStream]]':
+ self,
+ timeout: float = _INITIAL_CONNECT_TIMEOUT,
+ connect_timeout: Union[float, datetime.timedelta] = None,
+ ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]":
self.try_connect(iter(self.primary_addrs))
self.set_timeout(timeout)
if connect_timeout is not None:
# might still be working. Send a final error on the future
# only when both queues are finished.
if self.remaining == 0 and not self.future.done():
- self.future.set_exception(self.last_error or
- IOError("connection failed"))
+ self.future.set_exception(
+ self.last_error or IOError("connection failed")
+ )
return
stream, future = self.connect(af, addr)
self.streams.add(stream)
future_add_done_callback(
- future, functools.partial(self.on_connect_done, addrs, af, addr))
+ future, functools.partial(self.on_connect_done, addrs, af, addr)
+ )
- def on_connect_done(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
- af: socket.AddressFamily, addr: Tuple,
- future: 'Future[IOStream]') -> None:
+ def on_connect_done(
+ self,
+ addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
+ af: socket.AddressFamily,
+ addr: Tuple,
+ future: "Future[IOStream]",
+ ) -> None:
self.remaining -= 1
try:
stream = future.result()
self.close_streams()
def set_timeout(self, timeout: float) -> None:
- self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
- self.on_timeout)
+ self.timeout = self.io_loop.add_timeout(
+ self.io_loop.time() + timeout, self.on_timeout
+ )
def on_timeout(self) -> None:
self.timeout = None
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
- def set_connect_timeout(self, connect_timeout: Union[float, datetime.timedelta]) -> None:
+ def set_connect_timeout(
+ self, connect_timeout: Union[float, datetime.timedelta]
+ ) -> None:
self.connect_timeout = self.io_loop.add_timeout(
- connect_timeout, self.on_connect_timeout)
+ connect_timeout, self.on_connect_timeout
+ )
def on_connect_timeout(self) -> None:
if not self.future.done():
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
- def __init__(self, resolver: Resolver=None) -> None:
+
+ def __init__(self, resolver: Resolver = None) -> None:
if resolver is not None:
self.resolver = resolver
self._own_resolver = False
self.resolver.close()
@gen.coroutine
- def connect(self, host: str, port: int, af: socket.AddressFamily=socket.AF_UNSPEC,
- ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
- max_buffer_size: int=None, source_ip: str=None, source_port: int=None,
- timeout: Union[float, datetime.timedelta]=None) -> Generator[Any, Any, IOStream]:
+ def connect(
+ self,
+ host: str,
+ port: int,
+ af: socket.AddressFamily = socket.AF_UNSPEC,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
+ max_buffer_size: int = None,
+ source_ip: str = None,
+ source_port: int = None,
+ timeout: Union[float, datetime.timedelta] = None,
+ ) -> Generator[Any, Any, IOStream]:
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
raise TypeError("Unsupported timeout %r" % timeout)
if timeout is not None:
addrinfo = yield gen.with_timeout(
- timeout, self.resolver.resolve(host, port, af))
+ timeout, self.resolver.resolve(host, port, af)
+ )
else:
addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo,
- functools.partial(self._create_stream, max_buffer_size,
- source_ip=source_ip, source_port=source_port)
+ functools.partial(
+ self._create_stream,
+ max_buffer_size,
+ source_ip=source_ip,
+ source_port=source_port,
+ ),
)
af, addr, stream = yield connector.start(connect_timeout=timeout)
# TODO: For better performance we could cache the (af, addr)
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
if timeout is not None:
- stream = yield gen.with_timeout(timeout, stream.start_tls(
- False, ssl_options=ssl_options, server_hostname=host))
+ stream = yield gen.with_timeout(
+ timeout,
+ stream.start_tls(
+ False, ssl_options=ssl_options, server_hostname=host
+ ),
+ )
else:
- stream = yield stream.start_tls(False, ssl_options=ssl_options,
- server_hostname=host)
+ stream = yield stream.start_tls(
+ False, ssl_options=ssl_options, server_hostname=host
+ )
return stream
- def _create_stream(self, max_buffer_size: int, af: socket.AddressFamily,
- addr: Tuple, source_ip: str=None,
- source_port: int=None) -> Tuple[IOStream, 'Future[IOStream]']:
+ def _create_stream(
+ self,
+ max_buffer_size: int,
+ af: socket.AddressFamily,
+ addr: Tuple,
+ source_ip: str = None,
+ source_port: int = None,
+ ) -> Tuple[IOStream, "Future[IOStream]"]:
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
source_port_bind = source_port if isinstance(source_port, int) else 0
if source_port_bind and not source_ip:
# User required a specific port, but did not specify
# a certain source IP, will bind to the default loopback.
- source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
+ source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1"
# Trying to use the same address family as the requested af socket:
# - 127.0.0.1 for IPv4
# - ::1 for IPv6
# Fail loudly if unable to use the IP/port.
raise
try:
- stream = IOStream(socket_obj,
- max_buffer_size=max_buffer_size)
+ stream = IOStream(socket_obj, max_buffer_size=max_buffer_size)
except socket.error as e:
fu = Future() # type: Future[IOStream]
fu.set_exception(e)
import typing
from typing import Union, Dict, Any, Iterable, Optional, Awaitable
+
if typing.TYPE_CHECKING:
from typing import Callable, List # noqa: F401
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
- def __init__(self, ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
- max_buffer_size: int=None, read_chunk_size: int=None) -> None:
+
+ def __init__(
+ self,
+ ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
+ max_buffer_size: int = None,
+ read_chunk_size: int = None,
+ ) -> None:
self.ssl_options = ssl_options
- self._sockets = {} # type: Dict[int, socket.socket]
+ self._sockets = {} # type: Dict[int, socket.socket]
self._handlers = {} # type: Dict[int, Callable[[], None]]
self._pending_sockets = [] # type: List[socket.socket]
self._started = False
# which seems like too much work
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
# Only certfile is required: it can contain both keys
- if 'certfile' not in self.ssl_options:
+ if "certfile" not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
- if not os.path.exists(self.ssl_options['certfile']):
- raise ValueError('certfile "%s" does not exist' %
- self.ssl_options['certfile'])
- if ('keyfile' in self.ssl_options and
- not os.path.exists(self.ssl_options['keyfile'])):
- raise ValueError('keyfile "%s" does not exist' %
- self.ssl_options['keyfile'])
-
- def listen(self, port: int, address: str="") -> None:
+ if not os.path.exists(self.ssl_options["certfile"]):
+ raise ValueError(
+ 'certfile "%s" does not exist' % self.ssl_options["certfile"]
+ )
+ if "keyfile" in self.ssl_options and not os.path.exists(
+ self.ssl_options["keyfile"]
+ ):
+ raise ValueError(
+ 'keyfile "%s" does not exist' % self.ssl_options["keyfile"]
+ )
+
+ def listen(self, port: int, address: str = "") -> None:
"""Starts accepting connections on the given port.
This method may be called more than once to listen on multiple ports.
for sock in sockets:
self._sockets[sock.fileno()] = sock
self._handlers[sock.fileno()] = add_accept_handler(
- sock, self._handle_connection)
+ sock, self._handle_connection
+ )
def add_socket(self, socket: socket.socket) -> None:
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
- def bind(self, port: int, address: str=None,
- family: socket.AddressFamily=socket.AF_UNSPEC,
- backlog: int=128, reuse_port: bool=False) -> None:
+ def bind(
+ self,
+ port: int,
+ address: str = None,
+ family: socket.AddressFamily = socket.AF_UNSPEC,
+ backlog: int = 128,
+ reuse_port: bool = False,
+ ) -> None:
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
.. versionchanged:: 4.4
Added the ``reuse_port`` argument.
"""
- sockets = bind_sockets(port, address=address, family=family,
- backlog=backlog, reuse_port=reuse_port)
+ sockets = bind_sockets(
+ port, address=address, family=family, backlog=backlog, reuse_port=reuse_port
+ )
if self._started:
self.add_sockets(sockets)
else:
self._pending_sockets.extend(sockets)
- def start(self, num_processes: Optional[int]=1) -> None:
+ def start(self, num_processes: Optional[int] = 1) -> None:
"""Starts this server in the `.IOLoop`.
By default, we run the server in this process and do not fork any
self._handlers.pop(fd)()
sock.close()
- def handle_stream(self, stream: IOStream, address: tuple) -> Optional[Awaitable[None]]:
+ def handle_stream(
+ self, stream: IOStream, address: tuple
+ ) -> Optional[Awaitable[None]]:
"""Override to handle a new `.IOStream` from an incoming connection.
This method may be a coroutine; if so any exceptions it raises
if self.ssl_options is not None:
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
try:
- connection = ssl_wrap_socket(connection,
- self.ssl_options,
- server_side=True,
- do_handshake_on_connect=False)
+ connection = ssl_wrap_socket(
+ connection,
+ self.ssl_options,
+ server_side=True,
+ do_handshake_on_connect=False,
+ )
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_EOF:
return connection.close()
stream = SSLIOStream(
connection,
max_buffer_size=self.max_buffer_size,
- read_chunk_size=self.read_chunk_size) # type: IOStream
+ read_chunk_size=self.read_chunk_size,
+ ) # type: IOStream
else:
- stream = IOStream(connection,
- max_buffer_size=self.max_buffer_size,
- read_chunk_size=self.read_chunk_size)
+ stream = IOStream(
+ connection,
+ max_buffer_size=self.max_buffer_size,
+ read_chunk_size=self.read_chunk_size,
+ )
future = self.handle_stream(stream, address)
if future is not None:
- IOLoop.current().add_future(gen.convert_yielded(future),
- lambda f: f.result())
+ IOLoop.current().add_future(
+ gen.convert_yielded(future), lambda f: f.result()
+ )
except Exception:
app_log.error("Error in connection callback", exc_info=True)
from tornado.log import app_log
from tornado.util import ObjectDict, exec_in, unicode_type
-from typing import Any, Union, Callable, List, Dict, Iterable, Optional, TextIO, ContextManager
+from typing import (
+ Any,
+ Union,
+ Callable,
+ List,
+ Dict,
+ Iterable,
+ Optional,
+ TextIO,
+ ContextManager,
+)
import typing
+
if typing.TYPE_CHECKING:
from typing import Tuple # noqa: F401
.. versionadded:: 4.3
"""
- if mode == 'all':
+ if mode == "all":
return text
- elif mode == 'single':
+ elif mode == "single":
text = re.sub(r"([\t ]+)", " ", text)
text = re.sub(r"(\s*\n\s*)", "\n", text)
return text
- elif mode == 'oneline':
+ elif mode == "oneline":
return re.sub(r"(\s+)", " ", text)
else:
raise Exception("invalid whitespace mode %s" % mode)
We compile into Python from the given template_string. You can generate
the template from variables with generate().
"""
+
# note that the constructor's signature is not extracted with
# autodoc because _UNSET looks like garbage. When changing
# this signature update website/sphinx/template.rst too.
- def __init__(self, template_string: Union[str, bytes], name: str="<string>",
- loader: 'BaseLoader'=None, compress_whitespace: Union[bool, _UnsetMarker]=_UNSET,
- autoescape: Union[str, _UnsetMarker]=_UNSET,
- whitespace: str=None) -> None:
+ def __init__(
+ self,
+ template_string: Union[str, bytes],
+ name: str = "<string>",
+ loader: "BaseLoader" = None,
+ compress_whitespace: Union[bool, _UnsetMarker] = _UNSET,
+ autoescape: Union[str, _UnsetMarker] = _UNSET,
+ whitespace: str = None,
+ ) -> None:
"""Construct a Template.
:arg str template_string: the contents of the template file.
whitespace = "all"
# Validate the whitespace setting.
assert whitespace is not None
- filter_whitespace(whitespace, '')
+ filter_whitespace(whitespace, "")
if not isinstance(autoescape, _UnsetMarker):
self.autoescape = autoescape # type: Optional[str]
self.autoescape = _DEFAULT_AUTOESCAPE
self.namespace = loader.namespace if loader else {}
- reader = _TemplateReader(name, escape.native_str(template_string),
- whitespace)
+ reader = _TemplateReader(name, escape.native_str(template_string), whitespace)
self.file = _File(self, _parse(reader, self))
self.code = self._generate_python(loader)
self.loader = loader
# from being applied to the generated code.
self.compiled = compile(
escape.to_unicode(self.code),
- "%s.generated.py" % self.name.replace('.', '_'),
- "exec", dont_inherit=True)
+ "%s.generated.py" % self.name.replace(".", "_"),
+ "exec",
+ dont_inherit=True,
+ )
except Exception:
formatted_code = _format_code(self.code).rstrip()
app_log.error("%s code:\n%s", self.name, formatted_code)
"_tt_string_types": (unicode_type, bytes),
# __name__ and __loader__ allow the traceback mechanism to find
# the generated source code.
- "__name__": self.name.replace('.', '_'),
+ "__name__": self.name.replace(".", "_"),
"__loader__": ObjectDict(get_source=lambda name: self.code),
}
namespace.update(self.namespace)
linecache.clearcache()
return execute()
- def _generate_python(self, loader: Optional['BaseLoader']) -> str:
+ def _generate_python(self, loader: Optional["BaseLoader"]) -> str:
buffer = StringIO()
try:
# named_blocks maps from names to _NamedBlock objects
ancestors.reverse()
for ancestor in ancestors:
ancestor.find_named_blocks(loader, named_blocks)
- writer = _CodeWriter(buffer, named_blocks, loader,
- ancestors[0].template)
+ writer = _CodeWriter(buffer, named_blocks, loader, ancestors[0].template)
ancestors[0].generate(writer)
return buffer.getvalue()
finally:
buffer.close()
- def _get_ancestors(self, loader: Optional['BaseLoader']) -> List['_File']:
+ def _get_ancestors(self, loader: Optional["BaseLoader"]) -> List["_File"]:
ancestors = [self.file]
for chunk in self.file.body.chunks:
if isinstance(chunk, _ExtendsBlock):
if not loader:
- raise ParseError("{% extends %} block found, but no "
- "template loader")
+ raise ParseError(
+ "{% extends %} block found, but no " "template loader"
+ )
template = loader.load(chunk.name, self.name)
ancestors.extend(template._get_ancestors(loader))
return ancestors
``{% extends %}`` and ``{% include %}``. The loader caches all
templates after they are loaded the first time.
"""
- def __init__(self, autoescape: str=_DEFAULT_AUTOESCAPE,
- namespace: Dict[str, Any]=None,
- whitespace: str=None) -> None:
+
+ def __init__(
+ self,
+ autoescape: str = _DEFAULT_AUTOESCAPE,
+ namespace: Dict[str, Any] = None,
+ whitespace: str = None,
+ ) -> None:
"""Construct a template loader.
:arg str autoescape: The name of a function in the template
with self.lock:
self.templates = {}
- def resolve_path(self, name: str, parent_path: str=None) -> str:
+ def resolve_path(self, name: str, parent_path: str = None) -> str:
"""Converts a possibly-relative path to absolute (used internally)."""
raise NotImplementedError()
- def load(self, name: str, parent_path: str=None) -> Template:
+ def load(self, name: str, parent_path: str = None) -> Template:
"""Loads a template."""
name = self.resolve_path(name, parent_path=parent_path)
with self.lock:
class Loader(BaseLoader):
"""A template loader that loads from a single root directory.
"""
+
def __init__(self, root_directory: str, **kwargs: Any) -> None:
super(Loader, self).__init__(**kwargs)
self.root = os.path.abspath(root_directory)
- def resolve_path(self, name: str, parent_path: str=None) -> str:
- if parent_path and not parent_path.startswith("<") and \
- not parent_path.startswith("/") and \
- not name.startswith("/"):
+ def resolve_path(self, name: str, parent_path: str = None) -> str:
+ if (
+ parent_path
+ and not parent_path.startswith("<")
+ and not parent_path.startswith("/")
+ and not name.startswith("/")
+ ):
current_path = os.path.join(self.root, parent_path)
file_dir = os.path.dirname(os.path.abspath(current_path))
relative_path = os.path.abspath(os.path.join(file_dir, name))
if relative_path.startswith(self.root):
- name = relative_path[len(self.root) + 1:]
+ name = relative_path[len(self.root) + 1 :]
return name
def _create_template(self, name: str) -> Template:
class DictLoader(BaseLoader):
"""A template loader that loads from a dictionary."""
+
def __init__(self, dict: Dict[str, str], **kwargs: Any) -> None:
super(DictLoader, self).__init__(**kwargs)
self.dict = dict
- def resolve_path(self, name: str, parent_path: str=None) -> str:
- if parent_path and not parent_path.startswith("<") and \
- not parent_path.startswith("/") and \
- not name.startswith("/"):
+ def resolve_path(self, name: str, parent_path: str = None) -> str:
+ if (
+ parent_path
+ and not parent_path.startswith("<")
+ and not parent_path.startswith("/")
+ and not name.startswith("/")
+ ):
file_dir = posixpath.dirname(parent_path)
name = posixpath.normpath(posixpath.join(file_dir, name))
return name
class _Node(object):
- def each_child(self) -> Iterable['_Node']:
+ def each_child(self) -> Iterable["_Node"]:
return ()
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
raise NotImplementedError()
- def find_named_blocks(self, loader: Optional[BaseLoader],
- named_blocks: Dict[str, '_NamedBlock']) -> None:
+ def find_named_blocks(
+ self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"]
+ ) -> None:
for child in self.each_child():
child.find_named_blocks(loader, named_blocks)
class _File(_Node):
- def __init__(self, template: Template, body: '_ChunkList') -> None:
+ def __init__(self, template: Template, body: "_ChunkList") -> None:
self.template = template
self.body = body
self.line = 0
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line("def _tt_execute():", self.line)
with writer.indent():
writer.write_line("_tt_buffer = []", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
- def each_child(self) -> Iterable['_Node']:
+ def each_child(self) -> Iterable["_Node"]:
return (self.body,)
def __init__(self, chunks: List[_Node]) -> None:
self.chunks = chunks
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
for chunk in self.chunks:
chunk.generate(writer)
- def each_child(self) -> Iterable['_Node']:
+ def each_child(self) -> Iterable["_Node"]:
return self.chunks
self.template = template
self.line = line
- def each_child(self) -> Iterable['_Node']:
+ def each_child(self) -> Iterable["_Node"]:
return (self.body,)
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
block = writer.named_blocks[self.name]
with writer.include(block.template, self.line):
block.body.generate(writer)
- def find_named_blocks(self, loader: Optional[BaseLoader],
- named_blocks: Dict[str, '_NamedBlock']) -> None:
+ def find_named_blocks(
+ self, loader: Optional[BaseLoader], named_blocks: Dict[str, "_NamedBlock"]
+ ) -> None:
named_blocks[self.name] = self
_Node.find_named_blocks(self, loader, named_blocks)
class _IncludeBlock(_Node):
- def __init__(self, name: str, reader: '_TemplateReader', line: int) -> None:
+ def __init__(self, name: str, reader: "_TemplateReader", line: int) -> None:
self.name = name
self.template_name = reader.name
self.line = line
- def find_named_blocks(self, loader: Optional[BaseLoader],
- named_blocks: Dict[str, _NamedBlock]) -> None:
+ def find_named_blocks(
+ self, loader: Optional[BaseLoader], named_blocks: Dict[str, _NamedBlock]
+ ) -> None:
assert loader is not None
included = loader.load(self.name, self.template_name)
included.file.find_named_blocks(loader, named_blocks)
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
assert writer.loader is not None
included = writer.loader.load(self.name, self.template_name)
with writer.include(included, self.line):
self.line = line
self.body = body
- def each_child(self) -> Iterable['_Node']:
+ def each_child(self) -> Iterable["_Node"]:
return (self.body,)
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
method_name = "_tt_apply%d" % writer.apply_counter
writer.apply_counter += 1
writer.write_line("def %s():" % method_name, self.line)
writer.write_line("_tt_append = _tt_buffer.append", self.line)
self.body.generate(writer)
writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
- writer.write_line("_tt_append(_tt_utf8(%s(%s())))" % (
- self.method, method_name), self.line)
+ writer.write_line(
+ "_tt_append(_tt_utf8(%s(%s())))" % (self.method, method_name), self.line
+ )
class _ControlBlock(_Node):
def each_child(self) -> Iterable[_Node]:
return (self.body,)
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line("%s:" % self.statement, self.line)
with writer.indent():
self.body.generate(writer)
self.statement = statement
self.line = line
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
# In case the previous block was empty
writer.write_line("pass", self.line)
writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
self.statement = statement
self.line = line
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line(self.statement, self.line)
class _Expression(_Node):
- def __init__(self, expression: str, line: int, raw: bool=False) -> None:
+ def __init__(self, expression: str, line: int, raw: bool = False) -> None:
self.expression = expression
self.line = line
self.raw = raw
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
writer.write_line("_tt_tmp = %s" % self.expression, self.line)
- writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
- " _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
+ writer.write_line(
+ "if isinstance(_tt_tmp, _tt_string_types):" " _tt_tmp = _tt_utf8(_tt_tmp)",
+ self.line,
+ )
writer.write_line("else: _tt_tmp = _tt_utf8(str(_tt_tmp))", self.line)
if not self.raw and writer.current_template.autoescape is not None:
# In python3 functions like xhtml_escape return unicode,
# so we have to convert to utf8 again.
- writer.write_line("_tt_tmp = _tt_utf8(%s(_tt_tmp))" %
- writer.current_template.autoescape, self.line)
+ writer.write_line(
+ "_tt_tmp = _tt_utf8(%s(_tt_tmp))" % writer.current_template.autoescape,
+ self.line,
+ )
writer.write_line("_tt_append(_tt_tmp)", self.line)
class _Module(_Expression):
def __init__(self, expression: str, line: int) -> None:
- super(_Module, self).__init__("_tt_modules." + expression, line,
- raw=True)
+ super(_Module, self).__init__("_tt_modules." + expression, line, raw=True)
class _Text(_Node):
self.line = line
self.whitespace = whitespace
- def generate(self, writer: '_CodeWriter') -> None:
+ def generate(self, writer: "_CodeWriter") -> None:
value = self.value
# Compress whitespace if requested, with a crude heuristic to avoid
value = filter_whitespace(self.whitespace, value)
if value:
- writer.write_line('_tt_append(%r)' % escape.utf8(value), self.line)
+ writer.write_line("_tt_append(%r)" % escape.utf8(value), self.line)
class ParseError(Exception):
.. versionchanged:: 4.3
Added ``filename`` and ``lineno`` attributes.
"""
- def __init__(self, message: str, filename: str=None, lineno: int=0) -> None:
+
+ def __init__(self, message: str, filename: str = None, lineno: int = 0) -> None:
self.message = message
# The names "filename" and "lineno" are chosen for consistency
# with python SyntaxError.
self.lineno = lineno
def __str__(self) -> str:
- return '%s at %s:%d' % (self.message, self.filename, self.lineno)
+ return "%s at %s:%d" % (self.message, self.filename, self.lineno)
class _CodeWriter(object):
- def __init__(self, file: TextIO, named_blocks: Dict[str, _NamedBlock],
- loader: Optional[BaseLoader], current_template: Template) -> None:
+ def __init__(
+ self,
+ file: TextIO,
+ named_blocks: Dict[str, _NamedBlock],
+ loader: Optional[BaseLoader],
+ current_template: Template,
+ ) -> None:
self.file = file
self.named_blocks = named_blocks
self.loader = loader
def indent(self) -> ContextManager:
class Indenter(object):
- def __enter__(_) -> '_CodeWriter':
+ def __enter__(_) -> "_CodeWriter":
self._indent += 1
return self
self.current_template = template
class IncludeTemplate(object):
- def __enter__(_) -> '_CodeWriter':
+ def __enter__(_) -> "_CodeWriter":
return self
def __exit__(_, *args: Any) -> None:
return IncludeTemplate()
- def write_line(self, line: str, line_number: int, indent: int=None) -> None:
+ def write_line(self, line: str, line_number: int, indent: int = None) -> None:
if indent is None:
indent = self._indent
- line_comment = ' # %s:%d' % (self.current_template.name, line_number)
+ line_comment = " # %s:%d" % (self.current_template.name, line_number)
if self.include_stack:
- ancestors = ["%s:%d" % (tmpl.name, lineno)
- for (tmpl, lineno) in self.include_stack]
- line_comment += ' (via %s)' % ', '.join(reversed(ancestors))
+ ancestors = [
+ "%s:%d" % (tmpl.name, lineno) for (tmpl, lineno) in self.include_stack
+ ]
+ line_comment += " (via %s)" % ", ".join(reversed(ancestors))
print(" " * indent + line + line_comment, file=self.file)
self.line = 1
self.pos = 0
- def find(self, needle: str, start: int=0, end: int=None) -> int:
+ def find(self, needle: str, start: int = 0, end: int = None) -> int:
assert start >= 0, start
pos = self.pos
start += pos
index -= pos
return index
- def consume(self, count: int=None) -> str:
+ def consume(self, count: int = None) -> str:
if count is None:
count = len(self.text) - self.pos
newpos = self.pos + count
self.line += self.text.count("\n", self.pos, newpos)
- s = self.text[self.pos:newpos]
+ s = self.text[self.pos : newpos]
self.pos = newpos
return s
return self.text[self.pos + key]
def __str__(self) -> str:
- return self.text[self.pos:]
+ return self.text[self.pos :]
def raise_parse_error(self, msg: str) -> None:
raise ParseError(msg, self.name, self.line)
return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
-def _parse(reader: _TemplateReader, template: Template,
- in_block: str=None, in_loop: str=None) -> _ChunkList:
+def _parse(
+ reader: _TemplateReader,
+ template: Template,
+ in_block: str = None,
+ in_loop: str = None,
+) -> _ChunkList:
body = _ChunkList([])
while True:
# Find next template directive
# EOF
if in_block:
reader.raise_parse_error(
- "Missing {%% end %%} block for %s" % in_block)
- body.chunks.append(_Text(reader.consume(), reader.line,
- reader.whitespace))
+ "Missing {%% end %%} block for %s" % in_block
+ )
+ body.chunks.append(
+ _Text(reader.consume(), reader.line, reader.whitespace)
+ )
return body
# If the first curly brace is not the start of a special token,
# start searching from the character after it
# When there are more than 2 curlies in a row, use the
# innermost ones. This is useful when generating languages
# like latex where curlies are also meaningful
- if (curly + 2 < reader.remaining() and
- reader[curly + 1] == '{' and reader[curly + 2] == '{'):
+ if (
+ curly + 2 < reader.remaining()
+ and reader[curly + 1] == "{"
+ and reader[curly + 2] == "{"
+ ):
curly += 1
continue
break
# Append any text before the special token
if curly > 0:
cons = reader.consume(curly)
- body.chunks.append(_Text(cons, reader.line,
- reader.whitespace))
+ body.chunks.append(_Text(cons, reader.line, reader.whitespace))
start_brace = reader.consume(2)
line = reader.line
# which also use double braces.
if reader.remaining() and reader[0] == "!":
reader.consume(1)
- body.chunks.append(_Text(start_brace, line,
- reader.whitespace))
+ body.chunks.append(_Text(start_brace, line, reader.whitespace))
continue
# Comment
allowed_parents = intermediate_blocks.get(operator)
if allowed_parents is not None:
if not in_block:
- reader.raise_parse_error("%s outside %s block" %
- (operator, allowed_parents))
+ reader.raise_parse_error(
+ "%s outside %s block" % (operator, allowed_parents)
+ )
if in_block not in allowed_parents:
reader.raise_parse_error(
- "%s block cannot be attached to %s block" %
- (operator, in_block))
+ "%s block cannot be attached to %s block" % (operator, in_block)
+ )
body.chunks.append(_IntermediateControlBlock(contents, line))
continue
reader.raise_parse_error("Extra {% end %} block")
return body
- elif operator in ("extends", "include", "set", "import", "from",
- "comment", "autoescape", "whitespace", "raw",
- "module"):
+ elif operator in (
+ "extends",
+ "include",
+ "set",
+ "import",
+ "from",
+ "comment",
+ "autoescape",
+ "whitespace",
+ "raw",
+ "module",
+ ):
if operator == "comment":
continue
if operator == "extends":
elif operator == "whitespace":
mode = suffix.strip()
# Validate the selected mode
- filter_whitespace(mode, '')
+ filter_whitespace(mode, "")
reader.whitespace = mode
continue
elif operator == "raw":
elif operator in ("break", "continue"):
if not in_loop:
- reader.raise_parse_error("%s outside %s block" %
- (operator, set(["for", "while"])))
+ reader.raise_parse_error(
+ "%s outside %s block" % (operator, set(["for", "while"]))
+ )
body.chunks.append(_Statement(contents, line))
continue
from concurrent.futures import ThreadPoolExecutor
from tornado import gen
from tornado.ioloop import IOLoop
-from tornado.platform.asyncio import AsyncIOLoop, to_asyncio_future, AnyThreadEventLoopPolicy
+from tornado.platform.asyncio import (
+ AsyncIOLoop,
+ to_asyncio_future,
+ AnyThreadEventLoopPolicy,
+)
from tornado.testing import AsyncTestCase, gen_test
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in ensure_future,
# which was introduced during Python 3.4, deprecating the prior "async".
- if hasattr(asyncio, 'ensure_future'):
+ if hasattr(asyncio, "ensure_future"):
ensure_future = asyncio.ensure_future
else:
# async is a reserved word in Python 3.7
- ensure_future = getattr(asyncio, 'async')
+ ensure_future = getattr(asyncio, "async")
x = yield ensure_future(
- asyncio.get_event_loop().run_in_executor(None, lambda: 42))
+ asyncio.get_event_loop().run_in_executor(None, lambda: 42)
+ )
self.assertEqual(x, 42)
@gen_test
event_loop = asyncio.get_event_loop()
x = yield from event_loop.run_in_executor(None, lambda: 42)
return x
+
result = yield f()
self.assertEqual(result, 42)
return await to_asyncio_future(native_coroutine_without_adapter())
# Tornado supports native coroutines both with and without adapters
- self.assertEqual(
- self.io_loop.run_sync(native_coroutine_without_adapter),
- 42)
- self.assertEqual(
- self.io_loop.run_sync(native_coroutine_with_adapter),
- 42)
- self.assertEqual(
- self.io_loop.run_sync(native_coroutine_with_adapter2),
- 42)
+ self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42)
+ self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42)
+ self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42)
# Asyncio only supports coroutines that yield asyncio-compatible
# Futures (which our Future is since 5.0).
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
- native_coroutine_without_adapter()),
- 42)
+ native_coroutine_without_adapter()
+ ),
+ 42,
+ )
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
- native_coroutine_with_adapter()),
- 42)
+ native_coroutine_with_adapter()
+ ),
+ 42,
+ )
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
- native_coroutine_with_adapter2()),
- 42)
+ native_coroutine_with_adapter2()
+ ),
+ 42,
+ )
class LeakTest(unittest.TestCase):
loop = asyncio.get_event_loop()
loop.close()
return loop
+
future = self.executor.submit(get_and_close_event_loop)
return future.result()
def run_policy_test(self, accessor, expected_type):
# With the default policy, non-main threads don't get an event
# loop.
- self.assertRaises((RuntimeError, AssertionError),
- self.executor.submit(accessor).result)
+ self.assertRaises(
+ (RuntimeError, AssertionError), self.executor.submit(accessor).result
+ )
# Set the policy and we can get a loop.
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
- self.assertIsInstance(
- self.executor.submit(accessor).result(),
- expected_type)
+ self.assertIsInstance(self.executor.submit(accessor).result(), expected_type)
# Clean up to silence leak warnings. Always use asyncio since
# IOLoop doesn't (currently) close the underlying loop.
self.executor.submit(lambda: asyncio.get_event_loop().close()).result()
import unittest
from tornado.auth import (
- OpenIdMixin, OAuthMixin, OAuth2Mixin,
- GoogleOAuth2Mixin, FacebookGraphMixin, TwitterMixin,
+ OpenIdMixin,
+ OAuthMixin,
+ OAuth2Mixin,
+ GoogleOAuth2Mixin,
+ FacebookGraphMixin,
+ TwitterMixin,
)
from tornado.escape import json_decode
from tornado import gen
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
- self._OPENID_ENDPOINT = test.get_url('/openid/server/authenticate')
+ self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate")
@gen.coroutine
def get(self):
- if self.get_argument('openid.mode', None):
- user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
+ if self.get_argument("openid.mode", None):
+ user = yield self.get_authenticated_user(
+ http_client=self.settings["http_client"]
+ )
if user is None:
raise Exception("user is None")
self.finish(user)
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
- if self.get_argument('openid.mode') != 'check_authentication':
+ if self.get_argument("openid.mode") != "check_authentication":
raise Exception("incorrect openid.mode %r")
- self.write('is_valid:true')
+ self.write("is_valid:true")
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
- self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/oauth1/server/access_token')
+ self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token")
def _oauth_consumer_token(self):
- return dict(key='asdf', secret='qwer')
+ return dict(key="asdf", secret="qwer")
@gen.coroutine
def get(self):
- if self.get_argument('oauth_token', None):
- user = yield self.get_authenticated_user(http_client=self.settings['http_client'])
+ if self.get_argument("oauth_token", None):
+ user = yield self.get_authenticated_user(
+ http_client=self.settings["http_client"]
+ )
if user is None:
raise Exception("user is None")
self.finish(user)
return
- yield self.authorize_redirect(http_client=self.settings['http_client'])
+ yield self.authorize_redirect(http_client=self.settings["http_client"])
@gen.coroutine
def _oauth_get_user_future(self, access_token):
- if self.get_argument('fail_in_get_user', None):
+ if self.get_argument("fail_in_get_user", None):
raise Exception("failing in get_user")
- if access_token != dict(key='uiop', secret='5678'):
+ if access_token != dict(key="uiop", secret="5678"):
raise Exception("incorrect access token %r" % access_token)
- return dict(email='foo@example.com')
+ return dict(email="foo@example.com")
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
+
@gen.coroutine
def get(self):
- if self.get_argument('oauth_token', None):
+ if self.get_argument("oauth_token", None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
self._OAUTH_VERSION = version
def _oauth_consumer_token(self):
- return dict(key='asdf', secret='qwer')
+ return dict(key="asdf", secret="qwer")
def get(self):
params = self._oauth_request_parameters(
- 'http://www.example.com/api/asdf',
- dict(key='uiop', secret='5678'),
- parameters=dict(foo='bar'))
+ "http://www.example.com/api/asdf",
+ dict(key="uiop", secret="5678"),
+ parameters=dict(foo="bar"),
+ )
self.write(params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
- self.write('oauth_token=zxcv&oauth_token_secret=1234')
+ self.write("oauth_token=zxcv&oauth_token_secret=1234")
class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
- self.write('oauth_token=uiop&oauth_token_secret=5678')
+ self.write("oauth_token=uiop&oauth_token_secret=5678")
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth2/server/authorize')
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize")
def get(self):
res = self.authorize_redirect()
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
def initialize(self, test):
- self._OAUTH_AUTHORIZE_URL = test.get_url('/facebook/server/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/facebook/server/access_token')
- self._FACEBOOK_BASE_URL = test.get_url('/facebook/server')
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token")
+ self._FACEBOOK_BASE_URL = test.get_url("/facebook/server")
@gen.coroutine
def get(self):
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
- code=self.get_argument("code"))
+ code=self.get_argument("code"),
+ )
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
- extra_params={"scope": "read_stream,offline_access"})
+ extra_params={"scope": "read_stream,offline_access"},
+ )
class FacebookServerAccessTokenHandler(RequestHandler):
class FacebookServerMeHandler(RequestHandler):
def get(self):
- self.write('{}')
+ self.write("{}")
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
- self._OAUTH_REQUEST_TOKEN_URL = test.get_url('/oauth1/server/request_token')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/twitter/server/access_token')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/oauth1/server/authorize')
- self._OAUTH_AUTHENTICATE_URL = test.get_url('/twitter/server/authenticate')
- self._TWITTER_BASE_URL = test.get_url('/twitter/api')
+ self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token")
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
+ self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate")
+ self._TWITTER_BASE_URL = test.get_url("/twitter/api")
def get_auth_http_client(self):
- return self.settings['http_client']
+ return self.settings["http_client"]
class TwitterClientLoginHandler(TwitterClientHandler):
# cheating with a hard-coded access token.
try:
response = yield self.twitter_request(
- '/users/show/%s' % self.get_argument('name'),
- access_token=dict(key='hjkl', secret='vbnm'))
+ "/users/show/%s" % self.get_argument("name"),
+ access_token=dict(key="hjkl", secret="vbnm"),
+ )
except HTTPClientError:
# TODO(bdarnell): Should we catch HTTP errors and
# transform some of them (like 403s) into AuthError?
self.set_status(500)
- self.finish('error from twitter request')
+ self.finish("error from twitter request")
else:
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
def get(self):
- self.write('oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo')
+ self.write("oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo")
class TwitterServerShowUserHandler(RequestHandler):
def get(self, screen_name):
- if screen_name == 'error':
+ if screen_name == "error":
raise HTTPError(500)
- assert 'oauth_nonce' in self.request.arguments
- assert 'oauth_timestamp' in self.request.arguments
- assert 'oauth_signature' in self.request.arguments
- assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
- assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
- assert self.get_argument('oauth_version') == '1.0'
- assert self.get_argument('oauth_token') == 'hjkl'
+ assert "oauth_nonce" in self.request.arguments
+ assert "oauth_timestamp" in self.request.arguments
+ assert "oauth_signature" in self.request.arguments
+ assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
+ assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
+ assert self.get_argument("oauth_version") == "1.0"
+ assert self.get_argument("oauth_token") == "hjkl"
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
def get(self):
- assert 'oauth_nonce' in self.request.arguments
- assert 'oauth_timestamp' in self.request.arguments
- assert 'oauth_signature' in self.request.arguments
- assert self.get_argument('oauth_consumer_key') == 'test_twitter_consumer_key'
- assert self.get_argument('oauth_signature_method') == 'HMAC-SHA1'
- assert self.get_argument('oauth_version') == '1.0'
- assert self.get_argument('oauth_token') == 'hjkl'
- self.write(dict(screen_name='foo', name='Foo'))
+ assert "oauth_nonce" in self.request.arguments
+ assert "oauth_timestamp" in self.request.arguments
+ assert "oauth_signature" in self.request.arguments
+ assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
+ assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
+ assert self.get_argument("oauth_version") == "1.0"
+ assert self.get_argument("oauth_token") == "hjkl"
+ self.write(dict(screen_name="foo", name="Foo"))
class AuthTest(AsyncHTTPTestCase):
return Application(
[
# test endpoints
- ('/openid/client/login', OpenIdClientLoginHandler, dict(test=self)),
- ('/oauth10/client/login', OAuth1ClientLoginHandler,
- dict(test=self, version='1.0')),
- ('/oauth10/client/request_params',
- OAuth1ClientRequestParametersHandler,
- dict(version='1.0')),
- ('/oauth10a/client/login', OAuth1ClientLoginHandler,
- dict(test=self, version='1.0a')),
- ('/oauth10a/client/login_coroutine',
- OAuth1ClientLoginCoroutineHandler,
- dict(test=self, version='1.0a')),
- ('/oauth10a/client/request_params',
- OAuth1ClientRequestParametersHandler,
- dict(version='1.0a')),
- ('/oauth2/client/login', OAuth2ClientLoginHandler, dict(test=self)),
-
- ('/facebook/client/login', FacebookClientLoginHandler, dict(test=self)),
-
- ('/twitter/client/login', TwitterClientLoginHandler, dict(test=self)),
- ('/twitter/client/authenticate', TwitterClientAuthenticateHandler, dict(test=self)),
- ('/twitter/client/login_gen_coroutine',
- TwitterClientLoginGenCoroutineHandler, dict(test=self)),
- ('/twitter/client/show_user',
- TwitterClientShowUserHandler, dict(test=self)),
-
+ ("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)),
+ (
+ "/oauth10/client/login",
+ OAuth1ClientLoginHandler,
+ dict(test=self, version="1.0"),
+ ),
+ (
+ "/oauth10/client/request_params",
+ OAuth1ClientRequestParametersHandler,
+ dict(version="1.0"),
+ ),
+ (
+ "/oauth10a/client/login",
+ OAuth1ClientLoginHandler,
+ dict(test=self, version="1.0a"),
+ ),
+ (
+ "/oauth10a/client/login_coroutine",
+ OAuth1ClientLoginCoroutineHandler,
+ dict(test=self, version="1.0a"),
+ ),
+ (
+ "/oauth10a/client/request_params",
+ OAuth1ClientRequestParametersHandler,
+ dict(version="1.0a"),
+ ),
+ ("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)),
+ ("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)),
+ ("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)),
+ (
+ "/twitter/client/authenticate",
+ TwitterClientAuthenticateHandler,
+ dict(test=self),
+ ),
+ (
+ "/twitter/client/login_gen_coroutine",
+ TwitterClientLoginGenCoroutineHandler,
+ dict(test=self),
+ ),
+ (
+ "/twitter/client/show_user",
+ TwitterClientShowUserHandler,
+ dict(test=self),
+ ),
# simulated servers
- ('/openid/server/authenticate', OpenIdServerAuthenticateHandler),
- ('/oauth1/server/request_token', OAuth1ServerRequestTokenHandler),
- ('/oauth1/server/access_token', OAuth1ServerAccessTokenHandler),
-
- ('/facebook/server/access_token', FacebookServerAccessTokenHandler),
- ('/facebook/server/me', FacebookServerMeHandler),
- ('/twitter/server/access_token', TwitterServerAccessTokenHandler),
- (r'/twitter/api/users/show/(.*)\.json', TwitterServerShowUserHandler),
- (r'/twitter/api/account/verify_credentials\.json',
- TwitterServerVerifyCredentialsHandler),
+ ("/openid/server/authenticate", OpenIdServerAuthenticateHandler),
+ ("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler),
+ ("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler),
+ ("/facebook/server/access_token", FacebookServerAccessTokenHandler),
+ ("/facebook/server/me", FacebookServerMeHandler),
+ ("/twitter/server/access_token", TwitterServerAccessTokenHandler),
+ (r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler),
+ (
+ r"/twitter/api/account/verify_credentials\.json",
+ TwitterServerVerifyCredentialsHandler,
+ ),
],
http_client=self.http_client,
- twitter_consumer_key='test_twitter_consumer_key',
- twitter_consumer_secret='test_twitter_consumer_secret',
- facebook_api_key='test_facebook_api_key',
- facebook_secret='test_facebook_secret')
+ twitter_consumer_key="test_twitter_consumer_key",
+ twitter_consumer_secret="test_twitter_consumer_secret",
+ facebook_api_key="test_facebook_api_key",
+ facebook_secret="test_facebook_secret",
+ )
def test_openid_redirect(self):
- response = self.fetch('/openid/client/login', follow_redirects=False)
+ response = self.fetch("/openid/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(
- '/openid/server/authenticate?' in response.headers['Location'])
+ self.assertTrue("/openid/server/authenticate?" in response.headers["Location"])
def test_openid_get_user(self):
- response = self.fetch('/openid/client/login?openid.mode=blah'
- '&openid.ns.ax=http://openid.net/srv/ax/1.0'
- '&openid.ax.type.email=http://axschema.org/contact/email'
- '&openid.ax.value.email=foo@example.com')
+ response = self.fetch(
+ "/openid/client/login?openid.mode=blah"
+ "&openid.ns.ax=http://openid.net/srv/ax/1.0"
+ "&openid.ax.type.email=http://axschema.org/contact/email"
+ "&openid.ax.value.email=foo@example.com"
+ )
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_oauth10_redirect(self):
- response = self.fetch('/oauth10/client/login', follow_redirects=False)
+ response = self.fetch("/oauth10/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/oauth1/server/authorize?oauth_token=zxcv"
+ )
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
def test_oauth10_get_user(self):
response = self.fetch(
- '/oauth10/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/oauth10/client/login?oauth_token=zxcv",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['email'], 'foo@example.com')
- self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+ self.assertEqual(parsed["email"], "foo@example.com")
+ self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))
def test_oauth10_request_parameters(self):
- response = self.fetch('/oauth10/client/request_params')
+ response = self.fetch("/oauth10/client/request_params")
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
- self.assertEqual(parsed['oauth_token'], 'uiop')
- self.assertTrue('oauth_nonce' in parsed)
- self.assertTrue('oauth_signature' in parsed)
+ self.assertEqual(parsed["oauth_consumer_key"], "asdf")
+ self.assertEqual(parsed["oauth_token"], "uiop")
+ self.assertTrue("oauth_nonce" in parsed)
+ self.assertTrue("oauth_signature" in parsed)
def test_oauth10a_redirect(self):
- response = self.fetch('/oauth10a/client/login', follow_redirects=False)
+ response = self.fetch("/oauth10a/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/oauth1/server/authorize?oauth_token=zxcv"
+ )
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
- @unittest.skipIf(mock is None, 'mock package not present')
+ @unittest.skipIf(mock is None, "mock package not present")
def test_oauth10a_redirect_error(self):
- with mock.patch.object(OAuth1ServerRequestTokenHandler, 'get') as get:
+ with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get:
get.side_effect = Exception("boom")
with ExpectLog(app_log, "Uncaught exception"):
- response = self.fetch('/oauth10a/client/login', follow_redirects=False)
+ response = self.fetch("/oauth10a/client/login", follow_redirects=False)
self.assertEqual(response.code, 500)
def test_oauth10a_get_user(self):
response = self.fetch(
- '/oauth10a/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/oauth10a/client/login?oauth_token=zxcv",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['email'], 'foo@example.com')
- self.assertEqual(parsed['access_token'], dict(key='uiop', secret='5678'))
+ self.assertEqual(parsed["email"], "foo@example.com")
+ self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))
def test_oauth10a_request_parameters(self):
- response = self.fetch('/oauth10a/client/request_params')
+ response = self.fetch("/oauth10a/client/request_params")
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed['oauth_consumer_key'], 'asdf')
- self.assertEqual(parsed['oauth_token'], 'uiop')
- self.assertTrue('oauth_nonce' in parsed)
- self.assertTrue('oauth_signature' in parsed)
+ self.assertEqual(parsed["oauth_consumer_key"], "asdf")
+ self.assertEqual(parsed["oauth_token"], "uiop")
+ self.assertTrue("oauth_nonce" in parsed)
+ self.assertTrue("oauth_signature" in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
- '/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
- response = self.fetch('/oauth2/client/login', follow_redirects=False)
+ response = self.fetch("/oauth2/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue('/oauth2/server/authorize?' in response.headers['Location'])
+ self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"])
def test_facebook_login(self):
- response = self.fetch('/facebook/client/login', follow_redirects=False)
+ response = self.fetch("/facebook/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue('/facebook/server/authorize?' in response.headers['Location'])
- response = self.fetch('/facebook/client/login?code=1234', follow_redirects=False)
+ self.assertTrue("/facebook/server/authorize?" in response.headers["Location"])
+ response = self.fetch(
+ "/facebook/client/login?code=1234", follow_redirects=False
+ )
self.assertEqual(response.code, 200)
user = json_decode(response.body)
- self.assertEqual(user['access_token'], 'asdf')
- self.assertEqual(user['session_expires'], '3600')
+ self.assertEqual(user["access_token"], "asdf")
+ self.assertEqual(user["session_expires"], "3600")
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/oauth1/server/authorize?oauth_token=zxcv'))
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/oauth1/server/authorize?oauth_token=zxcv"
+ )
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
def test_twitter_redirect(self):
- self.base_twitter_redirect('/twitter/client/login')
+ self.base_twitter_redirect("/twitter/client/login")
def test_twitter_redirect_gen_coroutine(self):
- self.base_twitter_redirect('/twitter/client/login_gen_coroutine')
+ self.base_twitter_redirect("/twitter/client/login_gen_coroutine")
def test_twitter_authenticate_redirect(self):
- response = self.fetch('/twitter/client/authenticate', follow_redirects=False)
+ response = self.fetch("/twitter/client/authenticate", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(response.headers['Location'].endswith(
- '/twitter/server/authenticate?oauth_token=zxcv'), response.headers['Location'])
+ self.assertTrue(
+ response.headers["Location"].endswith(
+ "/twitter/server/authenticate?oauth_token=zxcv"
+ ),
+ response.headers["Location"],
+ )
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
- '_oauth_request_token="enhjdg==|MTIzNA=="' in response.headers['Set-Cookie'],
- response.headers['Set-Cookie'])
+ '_oauth_request_token="enhjdg==|MTIzNA=="'
+ in response.headers["Set-Cookie"],
+ response.headers["Set-Cookie"],
+ )
def test_twitter_get_user(self):
response = self.fetch(
- '/twitter/client/login?oauth_token=zxcv',
- headers={'Cookie': '_oauth_request_token=enhjdg==|MTIzNA=='})
+ "/twitter/client/login?oauth_token=zxcv",
+ headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
+ )
response.rethrow()
parsed = json_decode(response.body)
- self.assertEqual(parsed,
- {u'access_token': {u'key': u'hjkl',
- u'screen_name': u'foo',
- u'secret': u'vbnm'},
- u'name': u'Foo',
- u'screen_name': u'foo',
- u'username': u'foo'})
+ self.assertEqual(
+ parsed,
+ {
+ u"access_token": {
+ u"key": u"hjkl",
+ u"screen_name": u"foo",
+ u"secret": u"vbnm",
+ },
+ u"name": u"Foo",
+ u"screen_name": u"foo",
+ u"username": u"foo",
+ },
+ )
def test_twitter_show_user(self):
- response = self.fetch('/twitter/client/show_user?name=somebody')
+ response = self.fetch("/twitter/client/show_user?name=somebody")
response.rethrow()
- self.assertEqual(json_decode(response.body),
- {'name': 'Somebody', 'screen_name': 'somebody'})
+ self.assertEqual(
+ json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"}
+ )
def test_twitter_show_user_error(self):
- response = self.fetch('/twitter/client/show_user?name=error')
+ response = self.fetch("/twitter/client/show_user?name=error")
self.assertEqual(response.code, 500)
- self.assertEqual(response.body, b'error from twitter request')
+ self.assertEqual(response.body, b"error from twitter request")
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
def initialize(self, test):
self.test = test
- self._OAUTH_REDIRECT_URI = test.get_url('/client/login')
- self._OAUTH_AUTHORIZE_URL = test.get_url('/google/oauth2/authorize')
- self._OAUTH_ACCESS_TOKEN_URL = test.get_url('/google/oauth2/token')
+ self._OAUTH_REDIRECT_URI = test.get_url("/client/login")
+ self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize")
+ self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token")
@gen.coroutine
def get(self):
- code = self.get_argument('code', None)
+ code = self.get_argument("code", None)
if code is not None:
# retrieve authenticate google user
- access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI,
- code)
+ access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code)
user = yield self.oauth2_request(
self.test.get_url("/google/oauth2/userinfo"),
- access_token=access["access_token"])
+ access_token=access["access_token"],
+ )
# return the user and access token as json
user["access_token"] = access["access_token"]
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self._OAUTH_REDIRECT_URI,
- client_id=self.settings['google_oauth']['key'],
- client_secret=self.settings['google_oauth']['secret'],
- scope=['profile', 'email'],
- response_type='code',
- extra_params={'prompt': 'select_account'})
+ client_id=self.settings["google_oauth"]["key"],
+ client_secret=self.settings["google_oauth"]["secret"],
+ scope=["profile", "email"],
+ response_type="code",
+ extra_params={"prompt": "select_account"},
+ )
class GoogleOAuth2AuthorizeHandler(RequestHandler):
def get(self):
# issue a fake auth code and redirect to redirect_uri
- code = 'fake-authorization-code'
- self.redirect(url_concat(self.get_argument('redirect_uri'),
- dict(code=code)))
+ code = "fake-authorization-code"
+ self.redirect(url_concat(self.get_argument("redirect_uri"), dict(code=code)))
class GoogleOAuth2TokenHandler(RequestHandler):
def post(self):
- assert self.get_argument('code') == 'fake-authorization-code'
+ assert self.get_argument("code") == "fake-authorization-code"
# issue a fake token
- self.finish({
- 'access_token': 'fake-access-token',
- 'expires_in': 'never-expires'
- })
+ self.finish(
+ {"access_token": "fake-access-token", "expires_in": "never-expires"}
+ )
class GoogleOAuth2UserinfoHandler(RequestHandler):
def get(self):
- assert self.get_argument('access_token') == 'fake-access-token'
+ assert self.get_argument("access_token") == "fake-access-token"
# return a fake user
- self.finish({
- 'name': 'Foo',
- 'email': 'foo@example.com'
- })
+ self.finish({"name": "Foo", "email": "foo@example.com"})
class GoogleOAuth2Test(AsyncHTTPTestCase):
return Application(
[
# test endpoints
- ('/client/login', GoogleLoginHandler, dict(test=self)),
-
+ ("/client/login", GoogleLoginHandler, dict(test=self)),
# simulated google authorization server endpoints
- ('/google/oauth2/authorize', GoogleOAuth2AuthorizeHandler),
- ('/google/oauth2/token', GoogleOAuth2TokenHandler),
- ('/google/oauth2/userinfo', GoogleOAuth2UserinfoHandler),
+ ("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler),
+ ("/google/oauth2/token", GoogleOAuth2TokenHandler),
+ ("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler),
],
google_oauth={
- "key": 'fake_google_client_id',
- "secret": 'fake_google_client_secret'
- })
+ "key": "fake_google_client_id",
+ "secret": "fake_google_client_secret",
+ },
+ )
def test_google_login(self):
- response = self.fetch('/client/login')
- self.assertDictEqual({
- u'name': u'Foo',
- u'email': u'foo@example.com',
- u'access_token': u'fake-access-token',
- }, json_decode(response.body))
+ response = self.fetch("/client/login")
+ self.assertDictEqual(
+ {
+ u"name": u"Foo",
+ u"email": u"foo@example.com",
+ u"access_token": u"fake-access-token",
+ },
+ json_decode(response.body),
+ )
"""
# Create temporary test application
- os.mkdir(os.path.join(self.path, 'testapp'))
- open(os.path.join(self.path, 'testapp/__init__.py'), 'w').close()
- with open(os.path.join(self.path, 'testapp/__main__.py'), 'w') as f:
+ os.mkdir(os.path.join(self.path, "testapp"))
+ open(os.path.join(self.path, "testapp/__init__.py"), "w").close()
+ with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f:
f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
- if 'PYTHONPATH' in os.environ:
- pythonpath += os.pathsep + os.environ['PYTHONPATH']
+ if "PYTHONPATH" in os.environ:
+ pythonpath += os.pathsep + os.environ["PYTHONPATH"]
p = Popen(
- [sys.executable, '-m', 'testapp'], stdout=subprocess.PIPE,
- cwd=self.path, env=dict(os.environ, PYTHONPATH=pythonpath),
- universal_newlines=True)
+ [sys.executable, "-m", "testapp"],
+ stdout=subprocess.PIPE,
+ cwd=self.path,
+ env=dict(os.environ, PYTHONPATH=pythonpath),
+ universal_newlines=True,
+ )
out = p.communicate()[0]
- self.assertEqual(out, 'Starting\nStarting\n')
+ self.assertEqual(out, "Starting\nStarting\n")
def test_reload_wrapper_preservation(self):
# This test verifies that when `python -m tornado.autoreload`
"""
# Create temporary test application
- os.mkdir(os.path.join(self.path, 'testapp'))
- init_file = os.path.join(self.path, 'testapp', '__init__.py')
- open(init_file, 'w').close()
- main_file = os.path.join(self.path, 'testapp', '__main__.py')
- with open(main_file, 'w') as f:
+ os.mkdir(os.path.join(self.path, "testapp"))
+ init_file = os.path.join(self.path, "testapp", "__init__.py")
+ open(init_file, "w").close()
+ main_file = os.path.join(self.path, "testapp", "__main__.py")
+ with open(main_file, "w") as f:
f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
- if 'PYTHONPATH' in os.environ:
- pythonpath += os.pathsep + os.environ['PYTHONPATH']
+ if "PYTHONPATH" in os.environ:
+ pythonpath += os.pathsep + os.environ["PYTHONPATH"]
autoreload_proc = Popen(
- [sys.executable, '-m', 'tornado.autoreload', '-m', 'testapp'],
- stdout=subprocess.PIPE, cwd=self.path,
+ [sys.executable, "-m", "tornado.autoreload", "-m", "testapp"],
+ stdout=subprocess.PIPE,
+ cwd=self.path,
env=dict(os.environ, PYTHONPATH=pythonpath),
- universal_newlines=True)
+ universal_newlines=True,
+ )
# This timeout needs to be fairly generous for pypy due to jit
# warmup costs.
raise Exception("subprocess failed to terminate")
out = autoreload_proc.communicate()[0]
- self.assertEqual(out, 'Starting\n' * 2)
+ self.assertEqual(out, "Starting\n" * 2)
import socket
import unittest
-from tornado.concurrent import Future, run_on_executor, future_set_result_unless_cancelled
+from tornado.concurrent import (
+ Future,
+ run_on_executor,
+ future_set_result_unless_cancelled,
+)
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
class MiscFutureTest(AsyncTestCase):
-
def test_future_set_result_unless_cancelled(self):
fut = Future() # type: Future[int]
future_set_result_unless_cancelled(fut, 42)
self.port = port
def process_response(self, data):
- m = re.match('(.*)\t(.*)\n', to_unicode(data))
+ m = re.match("(.*)\t(.*)\n", to_unicode(data))
if m is None:
raise Exception("did not match")
status, message = m.groups()
- if status == 'ok':
+ if status == "ok":
return message
else:
raise CapError(message)
class GeneratorCapClient(BaseCapClient):
@gen.coroutine
def capitalize(self, request_data):
- logging.debug('capitalize')
+ logging.debug("capitalize")
stream = IOStream(socket.socket())
- logging.debug('connecting')
- yield stream.connect(('127.0.0.1', self.port))
- stream.write(utf8(request_data + '\n'))
- logging.debug('reading')
- data = yield stream.read_until(b'\n')
- logging.debug('returning')
+ logging.debug("connecting")
+ yield stream.connect(("127.0.0.1", self.port))
+ stream.write(utf8(request_data + "\n"))
+ logging.debug("reading")
+ data = yield stream.read_until(b"\n")
+ logging.debug("returning")
stream.close()
raise gen.Return(self.process_response(data))
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
+
self.io_loop.run_sync(f)
def test_generator_error(self):
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
+
self.io_loop.run_sync(f)
def __init__(self):
self.__executor = futures.thread.ThreadPoolExecutor(1)
- @run_on_executor(executor='_Object__executor')
+ @run_on_executor(executor="_Object__executor")
def f(self):
return 42
async def f():
answer = await o.f()
return answer
+
result = yield f()
self.assertEqual(result, 42)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
self.password = password
def get(self):
- realm = 'test'
- opaque = 'asdf'
+ realm = "test"
+ opaque = "asdf"
# Real implementations would use a random nonce.
nonce = "1234"
- auth_header = self.request.headers.get('Authorization', None)
+ auth_header = self.request.headers.get("Authorization", None)
if auth_header is not None:
- auth_mode, params = auth_header.split(' ', 1)
- assert auth_mode == 'Digest'
+ auth_mode, params = auth_header.split(" ", 1)
+ assert auth_mode == "Digest"
param_dict = {}
- for pair in params.split(','):
- k, v = pair.strip().split('=', 1)
+ for pair in params.split(","):
+ k, v = pair.strip().split("=", 1)
if v[0] == '"' and v[-1] == '"':
v = v[1:-1]
param_dict[k] = v
- assert param_dict['realm'] == realm
- assert param_dict['opaque'] == opaque
- assert param_dict['nonce'] == nonce
- assert param_dict['username'] == self.username
- assert param_dict['uri'] == self.request.path
- h1 = md5(utf8('%s:%s:%s' % (self.username, realm, self.password))).hexdigest()
- h2 = md5(utf8('%s:%s' % (self.request.method,
- self.request.path))).hexdigest()
- digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
- if digest == param_dict['response']:
- self.write('ok')
+ assert param_dict["realm"] == realm
+ assert param_dict["opaque"] == opaque
+ assert param_dict["nonce"] == nonce
+ assert param_dict["username"] == self.username
+ assert param_dict["uri"] == self.request.path
+ h1 = md5(
+ utf8("%s:%s:%s" % (self.username, realm, self.password))
+ ).hexdigest()
+ h2 = md5(
+ utf8("%s:%s" % (self.request.method, self.request.path))
+ ).hexdigest()
+ digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest()
+ if digest == param_dict["response"]:
+ self.write("ok")
else:
- self.write('fail')
+ self.write("fail")
else:
self.set_status(401)
- self.set_header('WWW-Authenticate',
- 'Digest realm="%s", nonce="%s", opaque="%s"' %
- (realm, nonce, opaque))
+ self.set_header(
+ "WWW-Authenticate",
+ 'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque),
+ )
class CustomReasonHandler(RequestHandler):
self.http_client = self.create_client()
def get_app(self):
- return Application([
- ('/digest', DigestAuthHandler, {'username': 'foo', 'password': 'bar'}),
- ('/digest_non_ascii', DigestAuthHandler, {'username': 'foo', 'password': 'barユ£'}),
- ('/custom_reason', CustomReasonHandler),
- ('/custom_fail_reason', CustomFailReasonHandler),
- ])
+ return Application(
+ [
+ ("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}),
+ (
+ "/digest_non_ascii",
+ DigestAuthHandler,
+ {"username": "foo", "password": "barユ£"},
+ ),
+ ("/custom_reason", CustomReasonHandler),
+ ("/custom_fail_reason", CustomFailReasonHandler),
+ ]
+ )
def create_client(self, **kwargs):
- return CurlAsyncHTTPClient(force_instance=True,
- defaults=dict(allow_ipv6=False),
- **kwargs)
+ return CurlAsyncHTTPClient(
+ force_instance=True, defaults=dict(allow_ipv6=False), **kwargs
+ )
def test_digest_auth(self):
- response = self.fetch('/digest', auth_mode='digest',
- auth_username='foo', auth_password='bar')
- self.assertEqual(response.body, b'ok')
+ response = self.fetch(
+ "/digest", auth_mode="digest", auth_username="foo", auth_password="bar"
+ )
+ self.assertEqual(response.body, b"ok")
def test_custom_reason(self):
- response = self.fetch('/custom_reason')
+ response = self.fetch("/custom_reason")
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
- response = self.fetch('/custom_fail_reason')
+ response = self.fetch("/custom_fail_reason")
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
def test_digest_auth_non_ascii(self):
- response = self.fetch('/digest_non_ascii', auth_mode='digest',
- auth_username='foo', auth_password='barユ£')
- self.assertEqual(response.body, b'ok')
+ response = self.fetch(
+ "/digest_non_ascii",
+ auth_mode="digest",
+ auth_username="foo",
+ auth_password="barユ£",
+ )
+ self.assertEqual(response.body, b"ok")
import tornado.escape
from tornado.escape import (
- utf8, xhtml_escape, xhtml_unescape, url_escape, url_unescape,
- to_unicode, json_decode, json_encode, squeeze, recursive_unicode,
+ utf8,
+ xhtml_escape,
+ xhtml_unescape,
+ url_escape,
+ url_unescape,
+ to_unicode,
+ json_decode,
+ json_encode,
+ squeeze,
+ recursive_unicode,
)
from tornado.util import unicode_type
linkify_tests = [
# (input, linkify_kwargs, expected_output)
-
- ("hello http://world.com/!", {},
- u'hello <a href="http://world.com/">http://world.com/</a>!'),
-
- ("hello http://world.com/with?param=true&stuff=yes", {},
- u'hello <a href="http://world.com/with?param=true&stuff=yes">http://world.com/with?param=true&stuff=yes</a>'), # noqa: E501
-
+ (
+ "hello http://world.com/!",
+ {},
+ u'hello <a href="http://world.com/">http://world.com/</a>!',
+ ),
+ (
+ "hello http://world.com/with?param=true&stuff=yes",
+ {},
+ u'hello <a href="http://world.com/with?param=true&stuff=yes">http://world.com/with?param=true&stuff=yes</a>', # noqa: E501
+ ),
# an opened paren followed by many chars killed Gruber's regex
- ("http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", {},
- u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'), # noqa: E501
-
+ (
+ "http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+ {},
+ u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501
+ ),
# as did too many dots at the end
- ("http://url.com/withmany.......................................", {},
- u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................'), # noqa: E501
-
- ("http://url.com/withmany((((((((((((((((((((((((((((((((((a)", {},
- u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)'), # noqa: E501
-
+ (
+ "http://url.com/withmany.......................................",
+ {},
+ u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................', # noqa: E501
+ ),
+ (
+ "http://url.com/withmany((((((((((((((((((((((((((((((((((a)",
+ {},
+ u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)', # noqa: E501
+ ),
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a fex extras (such as multiple parentheses).
- ("http://foo.com/blah_blah", {},
- u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>'),
-
- ("http://foo.com/blah_blah/", {},
- u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>'),
-
- ("(Something like http://foo.com/blah_blah)", {},
- u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)'),
-
- ("http://foo.com/blah_blah_(wikipedia)", {},
- u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>'),
-
- ("http://foo.com/blah_(blah)_(wikipedia)_blah", {},
- u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>'), # noqa: E501
-
- ("(Something like http://foo.com/blah_blah_(wikipedia))", {},
- u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)'), # noqa: E501
-
- ("http://foo.com/blah_blah.", {},
- u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.'),
-
- ("http://foo.com/blah_blah/.", {},
- u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.'),
-
- ("<http://foo.com/blah_blah>", {},
- u'<<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>>'),
-
- ("<http://foo.com/blah_blah/>", {},
- u'<<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>>'),
-
- ("http://foo.com/blah_blah,", {},
- u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,'),
-
- ("http://www.example.com/wpstyle/?p=364.", {},
- u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.'),
-
- ("rdar://1234",
- {"permitted_protocols": ["http", "rdar"]},
- u'<a href="rdar://1234">rdar://1234</a>'),
-
- ("rdar:/1234",
- {"permitted_protocols": ["rdar"]},
- u'<a href="rdar:/1234">rdar:/1234</a>'),
-
- ("http://userid:password@example.com:8080", {},
- u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>'), # noqa: E501
-
- ("http://userid@example.com", {},
- u'<a href="http://userid@example.com">http://userid@example.com</a>'),
-
- ("http://userid@example.com:8080", {},
- u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>'),
-
- ("http://userid:password@example.com", {},
- u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>'),
-
- ("message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
- {"permitted_protocols": ["http", "message"]},
- u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
- u'message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>'),
-
- (u"http://\u27a1.ws/\u4a39", {},
- u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>'),
-
- ("<tag>http://example.com</tag>", {},
- u'<tag><a href="http://example.com">http://example.com</a></tag>'),
-
- ("Just a www.example.com link.", {},
- u'Just a <a href="http://www.example.com">www.example.com</a> link.'),
-
- ("Just a www.example.com link.",
- {"require_protocol": True},
- u'Just a www.example.com link.'),
-
- ("A http://reallylong.com/link/that/exceedsthelenglimit.html",
- {"require_protocol": True, "shorten": True},
- u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
- u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>'), # noqa: E501
-
- ("A http://reallylongdomainnamethatwillbetoolong.com/hi!",
- {"shorten": True},
- u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
- u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!'), # noqa: E501
-
- ("A file:///passwords.txt and http://web.com link", {},
- u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link'),
-
- ("A file:///passwords.txt and http://web.com link",
- {"permitted_protocols": ["file"]},
- u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link'),
-
- ("www.external-link.com",
- {"extra_params": 'rel="nofollow" class="external"'},
- u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
-
- ("www.external-link.com and www.internal-link.com/blogs extra",
- {"extra_params": lambda href: 'class="internal"' if href.startswith("http://www.internal-link.com") else 'rel="nofollow" class="external"'}, # noqa: E501
- u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
- u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra'), # noqa: E501
-
- ("www.external-link.com",
- {"extra_params": lambda href: ' rel="nofollow" class="external" '},
- u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>'), # noqa: E501
+ (
+ "http://foo.com/blah_blah",
+ {},
+ u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>',
+ ),
+ (
+ "http://foo.com/blah_blah/",
+ {},
+ u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>',
+ ),
+ (
+ "(Something like http://foo.com/blah_blah)",
+ {},
+ u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)',
+ ),
+ (
+ "http://foo.com/blah_blah_(wikipedia)",
+ {},
+ u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>',
+ ),
+ (
+ "http://foo.com/blah_(blah)_(wikipedia)_blah",
+ {},
+ u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>', # noqa: E501
+ ),
+ (
+ "(Something like http://foo.com/blah_blah_(wikipedia))",
+ {},
+ u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)', # noqa: E501
+ ),
+ (
+ "http://foo.com/blah_blah.",
+ {},
+ u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.',
+ ),
+ (
+ "http://foo.com/blah_blah/.",
+ {},
+ u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.',
+ ),
+ (
+ "<http://foo.com/blah_blah>",
+ {},
+ u'<<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>>',
+ ),
+ (
+ "<http://foo.com/blah_blah/>",
+ {},
+ u'<<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>>',
+ ),
+ (
+ "http://foo.com/blah_blah,",
+ {},
+ u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,',
+ ),
+ (
+ "http://www.example.com/wpstyle/?p=364.",
+ {},
+ u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.', # noqa: E501
+ ),
+ (
+ "rdar://1234",
+ {"permitted_protocols": ["http", "rdar"]},
+ u'<a href="rdar://1234">rdar://1234</a>',
+ ),
+ (
+ "rdar:/1234",
+ {"permitted_protocols": ["rdar"]},
+ u'<a href="rdar:/1234">rdar:/1234</a>',
+ ),
+ (
+ "http://userid:password@example.com:8080",
+ {},
+ u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>', # noqa: E501
+ ),
+ (
+ "http://userid@example.com",
+ {},
+ u'<a href="http://userid@example.com">http://userid@example.com</a>',
+ ),
+ (
+ "http://userid@example.com:8080",
+ {},
+ u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>',
+ ),
+ (
+ "http://userid:password@example.com",
+ {},
+ u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>',
+ ),
+ (
+ "message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
+ {"permitted_protocols": ["http", "message"]},
+ u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
+ u"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>",
+ ),
+ (
+ u"http://\u27a1.ws/\u4a39",
+ {},
+ u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>',
+ ),
+ (
+ "<tag>http://example.com</tag>",
+ {},
+ u'<tag><a href="http://example.com">http://example.com</a></tag>',
+ ),
+ (
+ "Just a www.example.com link.",
+ {},
+ u'Just a <a href="http://www.example.com">www.example.com</a> link.',
+ ),
+ (
+ "Just a www.example.com link.",
+ {"require_protocol": True},
+ u"Just a www.example.com link.",
+ ),
+ (
+ "A http://reallylong.com/link/that/exceedsthelenglimit.html",
+ {"require_protocol": True, "shorten": True},
+ u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
+ u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>', # noqa: E501
+ ),
+ (
+ "A http://reallylongdomainnamethatwillbetoolong.com/hi!",
+ {"shorten": True},
+ u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
+ u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!', # noqa: E501
+ ),
+ (
+ "A file:///passwords.txt and http://web.com link",
+ {},
+ u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link',
+ ),
+ (
+ "A file:///passwords.txt and http://web.com link",
+ {"permitted_protocols": ["file"]},
+ u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link',
+ ),
+ (
+ "www.external-link.com",
+ {"extra_params": 'rel="nofollow" class="external"'},
+ u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501
+ ),
+ (
+ "www.external-link.com and www.internal-link.com/blogs extra",
+ {
+ "extra_params": lambda href: 'class="internal"'
+ if href.startswith("http://www.internal-link.com")
+ else 'rel="nofollow" class="external"'
+ },
+ u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
+ u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra', # noqa: E501
+ ),
+ (
+ "www.external-link.com",
+ {"extra_params": lambda href: ' rel="nofollow" class="external" '},
+ u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501
+ ),
] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]]
("<foo>", "<foo>"),
(u"<foo>", u"<foo>"),
(b"<foo>", b"<foo>"),
-
("<>&\"'", "<>&"'"),
("&", "&amp;"),
-
(u"<\u00e9>", u"<\u00e9>"),
(b"<\xc3\xa9>", b"<\xc3\xa9>"),
] # type: List[Tuple[Union[str, bytes], Union[str, bytes]]]
def test_xhtml_unescape_numeric(self):
tests = [
- ('foo bar', 'foo bar'),
- ('foo bar', 'foo bar'),
- ('foo bar', 'foo bar'),
- ('foo઼bar', u'foo\u0abcbar'),
- ('foo&#xyz;bar', 'foo&#xyz;bar'), # invalid encoding
- ('foo&#;bar', 'foo&#;bar'), # invalid encoding
- ('foo&#x;bar', 'foo&#x;bar'), # invalid encoding
+ ("foo bar", "foo bar"),
+ ("foo bar", "foo bar"),
+ ("foo bar", "foo bar"),
+ ("foo઼bar", u"foo\u0abcbar"),
+ ("foo&#xyz;bar", "foo&#xyz;bar"), # invalid encoding
+ ("foo&#;bar", "foo&#;bar"), # invalid encoding
+ ("foo&#x;bar", "foo&#x;bar"), # invalid encoding
]
for escaped, unescaped in tests:
self.assertEqual(unescaped, xhtml_unescape(escaped))
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
- (u'\u00e9'.encode('utf8'), '%C3%A9'),
- (u'\u00e9'.encode('latin1'), '%E9'),
-
+ (u"\u00e9".encode("utf8"), "%C3%A9"),
+ (u"\u00e9".encode("latin1"), "%E9"),
# unicode strings become utf8
- (u'\u00e9', '%C3%A9'),
+ (u"\u00e9", "%C3%A9"),
] # type: List[Tuple[Union[str, bytes], str]]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
- ('%C3%A9', u'\u00e9', 'utf8'),
- ('%C3%A9', u'\u00c3\u00a9', 'latin1'),
- ('%C3%A9', utf8(u'\u00e9'), None),
+ ("%C3%A9", u"\u00e9", "utf8"),
+ ("%C3%A9", u"\u00c3\u00a9", "latin1"),
+ ("%C3%A9", utf8(u"\u00e9"), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
def test_url_escape_quote_plus(self):
- unescaped = '+ #%'
- plus_escaped = '%2B+%23%25'
- escaped = '%2B%20%23%25'
+ unescaped = "+ #%"
+ plus_escaped = "%2B+%23%25"
+ escaped = "%2B%20%23%25"
self.assertEqual(url_escape(unescaped), plus_escaped)
self.assertEqual(url_escape(unescaped, plus=False), escaped)
self.assertEqual(url_unescape(plus_escaped), unescaped)
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
- self.assertEqual(url_unescape(plus_escaped, encoding=None),
- utf8(unescaped))
- self.assertEqual(url_unescape(escaped, encoding=None, plus=False),
- utf8(unescaped))
+ self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped))
+ self.assertEqual(
+ url_unescape(escaped, encoding=None, plus=False), utf8(unescaped)
+ )
def test_escape_return_types(self):
# On python2 the escape methods should generally return the same
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
- self.assertEqual(squeeze(u'sequences of whitespace chars'),
- u'sequences of whitespace chars')
+ self.assertEqual(
+ squeeze(u"sequences of whitespace chars"),
+ u"sequences of whitespace chars",
+ )
def test_recursive_unicode(self):
tests = {
- 'dict': {b"foo": b"bar"},
- 'list': [b"foo", b"bar"],
- 'tuple': (b"foo", b"bar"),
- 'bytes': b"foo"
+ "dict": {b"foo": b"bar"},
+ "list": [b"foo", b"bar"],
+ "tuple": (b"foo", b"bar"),
+ "bytes": b"foo",
}
- self.assertEqual(recursive_unicode(tests['dict']), {u"foo": u"bar"})
- self.assertEqual(recursive_unicode(tests['list']), [u"foo", u"bar"])
- self.assertEqual(recursive_unicode(tests['tuple']), (u"foo", u"bar"))
- self.assertEqual(recursive_unicode(tests['bytes']), u"foo")
+ self.assertEqual(recursive_unicode(tests["dict"]), {u"foo": u"bar"})
+ self.assertEqual(recursive_unicode(tests["list"]), [u"foo", u"bar"])
+ self.assertEqual(recursive_unicode(tests["tuple"]), (u"foo", u"bar"))
+ self.assertEqual(recursive_unicode(tests["bytes"]), u"foo")
@gen.coroutine
def f():
pass
+
self.io_loop.run_sync(f)
def test_exception_phase1(self):
@gen.coroutine
def f():
1 / 0
+
self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)
def test_exception_phase2(self):
def f():
yield gen.moment
1 / 0
+
self.assertRaises(ZeroDivisionError, self.io_loop.run_sync, f)
def test_bogus_yield(self):
@gen.coroutine
def f():
yield 42
+
self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)
def test_bogus_yield_tuple(self):
@gen.coroutine
def f():
yield (1, 2)
+
self.assertRaises(gen.BadYieldError, self.io_loop.run_sync, f)
def test_reuse(self):
@gen.coroutine
def f():
yield gen.moment
+
self.io_loop.run_sync(f)
self.io_loop.run_sync(f)
@gen.coroutine
def f():
yield None
+
self.io_loop.run_sync(f)
def test_multi(self):
def f():
results = yield [self.add_one_async(1), self.add_one_async(2)]
self.assertEqual(results, [2, 3])
+
self.io_loop.run_sync(f)
def test_multi_dict(self):
def f():
results = yield dict(foo=self.add_one_async(1), bar=self.add_one_async(2))
self.assertEqual(results, dict(foo=2, bar=3))
+
self.io_loop.run_sync(f)
def test_multi_delayed(self):
@gen.coroutine
def f():
# callbacks run at different times
- responses = yield gen.multi_future([
- self.delay(3, "v1"),
- self.delay(1, "v2"),
- ])
+ responses = yield gen.multi_future(
+ [self.delay(3, "v1"), self.delay(1, "v2")]
+ )
self.assertEqual(responses, ["v1", "v2"])
+
self.io_loop.run_sync(f)
def test_multi_dict_delayed(self):
@gen.coroutine
def f():
# callbacks run at different times
- responses = yield gen.multi_future(dict(
- foo=self.delay(3, "v1"),
- bar=self.delay(1, "v2"),
- ))
+ responses = yield gen.multi_future(
+ dict(foo=self.delay(3, "v1"), bar=self.delay(1, "v2"))
+ )
self.assertEqual(responses, dict(foo="v1", bar="v2"))
+
self.io_loop.run_sync(f)
@skipOnTravis
def test_multi_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))])
+ yield gen.Multi(
+ [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ]
+ )
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_future(2)])
+ yield gen.Multi(
+ [self.async_exception(RuntimeError("error 1")), self.async_future(2)]
+ )
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
- yield gen.Multi([self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))],
- quiet_exceptions=RuntimeError)
+ yield gen.Multi(
+ [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ],
+ quiet_exceptions=RuntimeError,
+ )
@gen_test
def test_multi_future_exceptions(self):
with ExpectLog(app_log, "Multiple exceptions in yield list"):
with self.assertRaises(RuntimeError) as cm:
- yield [self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))]
+ yield [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ]
self.assertEqual(str(cm.exception), "error 1")
# With only one exception, no error is logged.
with self.assertRaises(RuntimeError):
- yield [self.async_exception(RuntimeError("error 1")),
- self.async_future(2)]
+ yield [self.async_exception(RuntimeError("error 1")), self.async_future(2)]
# Exception logging may be explicitly quieted.
with self.assertRaises(RuntimeError):
yield gen.multi_future(
- [self.async_exception(RuntimeError("error 1")),
- self.async_exception(RuntimeError("error 2"))],
- quiet_exceptions=RuntimeError)
+ [
+ self.async_exception(RuntimeError("error 1")),
+ self.async_exception(RuntimeError("error 2")),
+ ],
+ quiet_exceptions=RuntimeError,
+ )
def test_sync_raise_return(self):
@gen.coroutine
@gen.coroutine
def f():
raise gen.Return(42)
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
def f():
yield gen.moment
raise gen.Return(42)
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@gen.coroutine
def f():
return 42
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
def f():
yield gen.moment
return 42
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
if True:
return 42
yield gen.Task(self.io_loop.add_callback)
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
async def f2():
result = await f1()
return result
+
result = yield f2()
self.assertEqual(result, 42)
self.finished = True
# `yield None`)
async def f():
import asyncio
+
await asyncio.sleep(0)
return 42
+
result = yield f()
self.assertEqual(result, 42)
self.finished = True
@gen.coroutine
def f():
return
+
result = yield f()
self.assertEqual(result, None)
self.finished = True
def f():
yield gen.moment
return
+
result = yield f()
self.assertEqual(result, None)
self.finished = True
@gen.coroutine
def f():
1 / 0
+
# The exception is raised when the future is yielded
# (or equivalently when its result method is called),
# not when the function itself is called).
def f():
yield gen.moment
1 / 0
+
future = f()
with self.assertRaises(ZeroDivisionError):
yield future
for i in range(5):
calls.append(name)
yield yieldable
+
# First, confirm the behavior without moment: each coroutine
# monopolizes the event loop until it finishes.
immediate = Future() # type: Future[None]
immediate.set_result(None)
- yield [f('a', immediate), f('b', immediate)]
- self.assertEqual(''.join(calls), 'aaaaabbbbb')
+ yield [f("a", immediate), f("b", immediate)]
+ self.assertEqual("".join(calls), "aaaaabbbbb")
# With moment, they take turns.
calls = []
- yield [f('a', gen.moment), f('b', gen.moment)]
- self.assertEqual(''.join(calls), 'ababababab')
+ yield [f("a", gen.moment), f("b", gen.moment)]
+ self.assertEqual("".join(calls), "ababababab")
self.finished = True
calls = []
- yield [f('a', gen.moment), f('b', immediate)]
- self.assertEqual(''.join(calls), 'abbbbbaaaa')
+ yield [f("a", gen.moment), f("b", immediate)]
+ self.assertEqual("".join(calls), "abbbbbaaaa")
@gen_test
def test_sleep(self):
self.finished = True
@skipNotCPython
- @unittest.skipIf((3,) < sys.version_info < (3, 6),
- "asyncio.Future has reference cycles")
+ @unittest.skipIf(
+ (3,) < sys.version_info < (3, 6), "asyncio.Future has reference cycles"
+ )
def test_coroutine_refcounting(self):
# On CPython, tasks and their arguments should be released immediately
# without waiting for garbage collection.
def inner():
class Foo(object):
pass
+
local_var = Foo()
self.local_ref = weakref.ref(local_var)
def dummy():
pass
+
yield gen.coroutine(dummy)()
- raise ValueError('Some error')
+ raise ValueError("Some error")
@gen.coroutine
def inner2():
self.assertIsInstance(coro, asyncio.Future)
# We expect the coroutine repr() to show the place where
# it was instantiated
- expected = ("created at %s:%d"
- % (__file__, f.__code__.co_firstlineno + 3))
+ expected = "created at %s:%d" % (__file__, f.__code__.co_firstlineno + 3)
actual = repr(coro)
self.assertIn(expected, actual)
def prepare(self):
self.chunks = [] # type: List[str]
yield gen.moment
- self.chunks.append('1')
+ self.chunks.append("1")
@gen.coroutine
def get(self):
- self.chunks.append('2')
+ self.chunks.append("2")
yield gen.moment
- self.chunks.append('3')
+ self.chunks.append("3")
yield gen.moment
- self.write(''.join(self.chunks))
+ self.write("".join(self.chunks))
class AsyncPrepareErrorHandler(RequestHandler):
raise HTTPError(403)
def get(self):
- self.finish('ok')
+ self.finish("ok")
class NativeCoroutineHandler(RequestHandler):
class GenWebTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([
- ('/coroutine_sequence', GenCoroutineSequenceHandler),
- ('/coroutine_unfinished_sequence',
- GenCoroutineUnfinishedSequenceHandler),
- ('/undecorated_coroutine', UndecoratedCoroutinesHandler),
- ('/async_prepare_error', AsyncPrepareErrorHandler),
- ('/native_coroutine', NativeCoroutineHandler),
- ])
+ return Application(
+ [
+ ("/coroutine_sequence", GenCoroutineSequenceHandler),
+ (
+ "/coroutine_unfinished_sequence",
+ GenCoroutineUnfinishedSequenceHandler,
+ ),
+ ("/undecorated_coroutine", UndecoratedCoroutinesHandler),
+ ("/async_prepare_error", AsyncPrepareErrorHandler),
+ ("/native_coroutine", NativeCoroutineHandler),
+ ]
+ )
def test_coroutine_sequence_handler(self):
- response = self.fetch('/coroutine_sequence')
+ response = self.fetch("/coroutine_sequence")
self.assertEqual(response.body, b"123")
def test_coroutine_unfinished_sequence_handler(self):
- response = self.fetch('/coroutine_unfinished_sequence')
+ response = self.fetch("/coroutine_unfinished_sequence")
self.assertEqual(response.body, b"123")
def test_undecorated_coroutines(self):
- response = self.fetch('/undecorated_coroutine')
- self.assertEqual(response.body, b'123')
+ response = self.fetch("/undecorated_coroutine")
+ self.assertEqual(response.body, b"123")
def test_async_prepare_error_handler(self):
- response = self.fetch('/async_prepare_error')
+ response = self.fetch("/async_prepare_error")
self.assertEqual(response.code, 403)
def test_native_coroutine_handler(self):
- response = self.fetch('/native_coroutine')
+ response = self.fetch("/native_coroutine")
self.assertEqual(response.code, 200)
- self.assertEqual(response.body, b'ok')
+ self.assertEqual(response.body, b"ok")
class WithTimeoutTest(AsyncTestCase):
@gen_test
def test_timeout(self):
with self.assertRaises(gen.TimeoutError):
- yield gen.with_timeout(datetime.timedelta(seconds=0.1),
- Future())
+ yield gen.with_timeout(datetime.timedelta(seconds=0.1), Future())
@gen_test
def test_completes_before_timeout(self):
future = Future() # type: Future[str]
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.1),
- lambda: future.set_result('asdf'))
- result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
- self.assertEqual(result, 'asdf')
+ self.io_loop.add_timeout(
+ datetime.timedelta(seconds=0.1), lambda: future.set_result("asdf")
+ )
+ result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
+ self.assertEqual(result, "asdf")
@gen_test
def test_fails_before_timeout(self):
future = Future() # type: Future[str]
self.io_loop.add_timeout(
datetime.timedelta(seconds=0.1),
- lambda: future.set_exception(ZeroDivisionError()))
+ lambda: future.set_exception(ZeroDivisionError()),
+ )
with self.assertRaises(ZeroDivisionError):
- yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
+ yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
@gen_test
def test_already_resolved(self):
future = Future() # type: Future[str]
- future.set_result('asdf')
- result = yield gen.with_timeout(datetime.timedelta(seconds=3600),
- future)
- self.assertEqual(result, 'asdf')
+ future.set_result("asdf")
+ result = yield gen.with_timeout(datetime.timedelta(seconds=3600), future)
+ self.assertEqual(result, "asdf")
@gen_test
def test_timeout_concurrent_future(self):
# A concurrent future that does not resolve before the timeout.
with futures.ThreadPoolExecutor(1) as executor:
with self.assertRaises(gen.TimeoutError):
- yield gen.with_timeout(self.io_loop.time(),
- executor.submit(time.sleep, 0.1))
+ yield gen.with_timeout(
+ self.io_loop.time(), executor.submit(time.sleep, 0.1)
+ )
@gen_test
def test_completed_concurrent_future(self):
# A concurrent future that is resolved before we even submit it
# to with_timeout.
with futures.ThreadPoolExecutor(1) as executor:
+
def dummy():
pass
+
f = executor.submit(dummy)
f.result() # wait for completion
yield gen.with_timeout(datetime.timedelta(seconds=3600), f)
def test_normal_concurrent_future(self):
# A conccurrent future that resolves while waiting for the timeout.
with futures.ThreadPoolExecutor(1) as executor:
- yield gen.with_timeout(datetime.timedelta(seconds=3600),
- executor.submit(lambda: time.sleep(0.01)))
+ yield gen.with_timeout(
+ datetime.timedelta(seconds=3600),
+ executor.submit(lambda: time.sleep(0.01)),
+ )
class WaitIteratorTest(AsyncTestCase):
@gen_test
def test_empty_iterator(self):
g = gen.WaitIterator()
- self.assertTrue(g.done(), 'empty generator iterated')
+ self.assertTrue(g.done(), "empty generator iterated")
with self.assertRaises(ValueError):
g = gen.WaitIterator(Future(), bar=Future())
while not dg.done():
dr = yield dg.next()
if dg.current_index == "f1":
- self.assertTrue(dg.current_future == f1 and dr == 24,
- "WaitIterator dict status incorrect")
+ self.assertTrue(
+ dg.current_future == f1 and dr == 24,
+ "WaitIterator dict status incorrect",
+ )
elif dg.current_index == "f2":
- self.assertTrue(dg.current_future == f2 and dr == 42,
- "WaitIterator dict status incorrect")
+ self.assertTrue(
+ dg.current_future == f2 and dr == 42,
+ "WaitIterator dict status incorrect",
+ )
else:
- self.fail("got bad WaitIterator index {}".format(
- dg.current_index))
+ self.fail("got bad WaitIterator index {}".format(dg.current_index))
i += 1
try:
r = yield g.next()
except ZeroDivisionError:
- self.assertIs(g.current_future, futures[0],
- 'exception future invalid')
+ self.assertIs(g.current_future, futures[0], "exception future invalid")
else:
if i == 0:
- self.assertEqual(r, 24, 'iterator value incorrect')
- self.assertEqual(g.current_index, 2, 'wrong index')
+ self.assertEqual(r, 24, "iterator value incorrect")
+ self.assertEqual(g.current_index, 2, "wrong index")
elif i == 2:
- self.assertEqual(r, 42, 'iterator value incorrect')
- self.assertEqual(g.current_index, 1, 'wrong index')
+ self.assertEqual(r, 42, "iterator value incorrect")
+ self.assertEqual(g.current_index, 1, "wrong index")
elif i == 3:
- self.assertEqual(r, 84, 'iterator value incorrect')
- self.assertEqual(g.current_index, 3, 'wrong index')
+ self.assertEqual(r, 84, "iterator value incorrect")
+ self.assertEqual(g.current_index, 3, "wrong index")
i += 1
@gen_test
try:
async for r in g:
if i == 0:
- self.assertEqual(r, 24, 'iterator value incorrect')
- self.assertEqual(g.current_index, 2, 'wrong index')
+ self.assertEqual(r, 24, "iterator value incorrect")
+ self.assertEqual(g.current_index, 2, "wrong index")
else:
raise Exception("expected exception on iteration 1")
i += 1
i += 1
async for r in g:
if i == 2:
- self.assertEqual(r, 42, 'iterator value incorrect')
- self.assertEqual(g.current_index, 1, 'wrong index')
+ self.assertEqual(r, 42, "iterator value incorrect")
+ self.assertEqual(g.current_index, 1, "wrong index")
elif i == 3:
- self.assertEqual(r, 84, 'iterator value incorrect')
- self.assertEqual(g.current_index, 3, 'wrong index')
+ self.assertEqual(r, 84, "iterator value incorrect")
+ self.assertEqual(g.current_index, 3, "wrong index")
else:
raise Exception("didn't expect iteration %d" % i)
i += 1
self.finished = True
+
yield f()
self.assertTrue(self.finished)
# WaitIterator itself, only the Future it returns. Since
# WaitIterator uses weak references internally to improve GC
# performance, this used to cause problems.
- yield gen.with_timeout(datetime.timedelta(seconds=0.1),
- gen.WaitIterator(gen.sleep(0)).next())
+ yield gen.with_timeout(
+ datetime.timedelta(seconds=0.1), gen.WaitIterator(gen.sleep(0)).next()
+ )
class RunnerGCTest(AsyncTestCase):
def is_pypy3(self):
- return (platform.python_implementation() == 'PyPy' and
- sys.version_info > (3,))
+ return platform.python_implementation() == "PyPy" and sys.version_info > (3,)
@gen_test
def test_gc(self):
self.io_loop.add_callback(callback)
yield fut
- yield gen.with_timeout(
- datetime.timedelta(seconds=0.2),
- tester()
- )
+ yield gen.with_timeout(datetime.timedelta(seconds=0.2), tester())
def test_gc_infinite_coro(self):
# Github issue 2229: suspended coroutines should be GCed when
yield gen.sleep(0.2)
loop.run_sync(do_something)
- with ExpectLog('asyncio', "Task was destroyed but it is pending"):
+ with ExpectLog("asyncio", "Task was destroyed but it is pending"):
loop.close()
gc.collect()
# Future was collected
self.assertEqual(result, [None, None])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
add_accept_handler(listener, accept_callback)
self.client_stream = IOStream(socket.socket())
self.addCleanup(self.client_stream.close)
- yield [self.client_stream.connect(('127.0.0.1', port)),
- event.wait()]
+ yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()]
self.io_loop.remove_handler(listener)
listener.close()
yield conn.read_response(Delegate())
yield event.wait()
self.assertEqual(self.code, 200)
- self.assertEqual(b''.join(body), b'hello')
+ self.assertEqual(b"".join(body), b"hello")
from tornado.escape import utf8, native_str
from tornado import gen
-from tornado.httpclient import HTTPRequest, HTTPResponse, _RequestProxy, HTTPError, HTTPClient
+from tornado.httpclient import (
+ HTTPRequest,
+ HTTPResponse,
+ _RequestProxy,
+ HTTPError,
+ HTTPClient,
+)
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
class PostHandler(RequestHandler):
def post(self):
- self.finish("Post arg1: %s, arg2: %s" % (
- self.get_argument("arg1"), self.get_argument("arg2")))
+ self.finish(
+ "Post arg1: %s, arg2: %s"
+ % (self.get_argument("arg1"), self.get_argument("arg2"))
+ )
class PutHandler(RequestHandler):
class RedirectHandler(RequestHandler):
def prepare(self):
- self.write('redirects can have bodies too')
- self.redirect(self.get_argument("url"),
- status=int(self.get_argument("status", "302")))
+ self.write("redirects can have bodies too")
+ self.redirect(
+ self.get_argument("url"), status=int(self.get_argument("status", "302"))
+ )
class ChunkHandler(RequestHandler):
class UserAgentHandler(RequestHandler):
def get(self):
- self.write(self.request.headers.get('User-Agent', 'User agent not set'))
+ self.write(self.request.headers.get("User-Agent", "User agent not set"))
class ContentLength304Handler(RequestHandler):
def get(self):
self.set_status(304)
- self.set_header('Content-Length', 42)
+ self.set_header("Content-Length", 42)
def _clear_headers_for_304(self):
# Tornado strips content-length from 304 responses, but here we
class PatchHandler(RequestHandler):
-
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
- SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',) # type: ignore
+ SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",) # type: ignore
def method(self):
self.write(self.request.method)
def get(self):
# Use get_arguments for keys to get strings, but
# request.arguments for values to get bytes.
- for k, v in zip(self.get_arguments('k'),
- self.request.arguments['v']):
+ for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]):
self.set_header(k, v)
+
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def get_app(self):
- return Application([
- url("/hello", HelloWorldHandler),
- url("/post", PostHandler),
- url("/put", PutHandler),
- url("/redirect", RedirectHandler),
- url("/chunk", ChunkHandler),
- url("/auth", AuthHandler),
- url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
- url("/echopost", EchoPostHandler),
- url("/user_agent", UserAgentHandler),
- url("/304_with_content_length", ContentLength304Handler),
- url("/all_methods", AllMethodsHandler),
- url('/patch', PatchHandler),
- url('/set_header', SetHeaderHandler),
- ], gzip=True)
+ return Application(
+ [
+ url("/hello", HelloWorldHandler),
+ url("/post", PostHandler),
+ url("/put", PutHandler),
+ url("/redirect", RedirectHandler),
+ url("/chunk", ChunkHandler),
+ url("/auth", AuthHandler),
+ url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
+ url("/echopost", EchoPostHandler),
+ url("/user_agent", UserAgentHandler),
+ url("/304_with_content_length", ContentLength304Handler),
+ url("/all_methods", AllMethodsHandler),
+ url("/patch", PatchHandler),
+ url("/set_header", SetHeaderHandler),
+ ],
+ gzip=True,
+ )
def test_patch_receives_payload(self):
body = b"some patch data"
- response = self.fetch("/patch", method='PATCH', body=body)
+ response = self.fetch("/patch", method="PATCH", body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
chunks = [] # type: typing.List[bytes]
- response = self.fetch("/hello",
- streaming_callback=chunks.append)
+ response = self.fetch("/hello", streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, [b"Hello world!"])
self.assertFalse(response.body)
def test_post(self):
- response = self.fetch("/post", method="POST",
- body="arg1=foo&arg2=bar")
+ response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
self.assertEqual(response.body, b"asdfqwer")
chunks = [] # type: typing.List[bytes]
- response = self.fetch("/chunk",
- streaming_callback=chunks.append)
+ response = self.fetch("/chunk", streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
self.assertFalse(response.body)
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
+
@gen.coroutine
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
- yield stream.write(b"""\
+ yield stream.write(
+ b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
2
0
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
stream.close()
+
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
def test_basic_auth(self):
# This test data appears in section 2 of RFC 7617.
- self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
- auth_password="open sesame").body,
- b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ self.assertEqual(
+ self.fetch(
+ "/auth", auth_username="Aladdin", auth_password="open sesame"
+ ).body,
+ b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
+ )
def test_basic_auth_explicit_mode(self):
- self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
- auth_password="open sesame",
- auth_mode="basic").body,
- b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ self.assertEqual(
+ self.fetch(
+ "/auth",
+ auth_username="Aladdin",
+ auth_password="open sesame",
+ auth_mode="basic",
+ ).body,
+ b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
+ )
def test_basic_auth_unicode(self):
# This test data appears in section 2.1 of RFC 7617.
- self.assertEqual(self.fetch("/auth", auth_username="test",
- auth_password="123£").body,
- b"Basic dGVzdDoxMjPCow==")
+ self.assertEqual(
+ self.fetch("/auth", auth_username="test", auth_password="123£").body,
+ b"Basic dGVzdDoxMjPCow==",
+ )
# The standard mandates NFC. Give it a decomposed username
# and ensure it is normalized to composed form.
username = unicodedata.normalize("NFD", u"josé")
- self.assertEqual(self.fetch("/auth",
- auth_username=username,
- auth_password="səcrət").body,
- b"Basic am9zw6k6c8mZY3LJmXQ=")
+ self.assertEqual(
+ self.fetch("/auth", auth_username=username, auth_password="səcrət").body,
+ b"Basic am9zw6k6c8mZY3LJmXQ=",
+ )
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
with self.assertRaises((ValueError, HTTPError)):
- self.fetch("/auth", auth_username="Aladdin",
- auth_password="open sesame",
- auth_mode="asdf",
- raise_error=True)
+ self.fetch(
+ "/auth",
+ auth_username="Aladdin",
+ auth_password="open sesame",
+ auth_mode="asdf",
+ raise_error=True,
+ )
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
response = self.fetch(url)
- self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"),
- response.body)
+ self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body)
def test_body_encoding(self):
unicode_body = u"\xe9"
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
- response = self.fetch("/echopost", method="POST", body=unicode_body,
- headers={"Content-Type": "application/blah"})
+ response = self.fetch(
+ "/echopost",
+ method="POST",
+ body=unicode_body,
+ headers={"Content-Type": "application/blah"},
+ )
self.assertEqual(response.headers["Content-Length"], "2")
self.assertEqual(response.body, utf8(unicode_body))
# byte strings pass through directly
- response = self.fetch("/echopost", method="POST",
- body=byte_body,
- headers={"Content-Type": "application/blah"})
+ response = self.fetch(
+ "/echopost",
+ method="POST",
+ body=byte_body,
+ headers={"Content-Type": "application/blah"},
+ )
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
# Mixing unicode in headers and byte string bodies shouldn't
# break anything
- response = self.fetch("/echopost", method="POST", body=byte_body,
- headers={"Content-Type": "application/blah"},
- user_agent=u"foo")
+ response = self.fetch(
+ "/echopost",
+ method="POST",
+ body=byte_body,
+ headers={"Content-Type": "application/blah"},
+ user_agent=u"foo",
+ )
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
chunks = []
def header_callback(header_line):
- if header_line.startswith('HTTP/1.1 101'):
+ if header_line.startswith("HTTP/1.1 101"):
# Upgrading to HTTP/2
pass
- elif header_line.startswith('HTTP/'):
+ elif header_line.startswith("HTTP/"):
first_line.append(header_line)
- elif header_line != '\r\n':
- k, v = header_line.split(':', 1)
+ elif header_line != "\r\n":
+ k, v = header_line.split(":", 1)
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
- self.assertEqual(headers['content-type'], 'text/html; charset=UTF-8')
+ self.assertEqual(headers["content-type"], "text/html; charset=UTF-8")
chunks.append(chunk)
- self.fetch('/chunk', header_callback=header_callback,
- streaming_callback=streaming_callback)
+ self.fetch(
+ "/chunk",
+ header_callback=header_callback,
+ streaming_callback=streaming_callback,
+ )
self.assertEqual(len(first_line), 1, first_line)
- self.assertRegexpMatches(first_line[0], 'HTTP/[0-9]\\.[0-9] 200.*\r\n')
- self.assertEqual(chunks, [b'asdf', b'qwer'])
+ self.assertRegexpMatches(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n")
+ self.assertEqual(chunks, [b"asdf", b"qwer"])
@gen_test
def test_configure_defaults(self):
- defaults = dict(user_agent='TestDefaultUserAgent', allow_ipv6=False)
+ defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False)
# Construct a new instance of the configured client class
- client = self.http_client.__class__(force_instance=True,
- defaults=defaults)
+ client = self.http_client.__class__(force_instance=True, defaults=defaults)
try:
- response = yield client.fetch(self.get_url('/user_agent'))
- self.assertEqual(response.body, b'TestDefaultUserAgent')
+ response = yield client.fetch(self.get_url("/user_agent"))
+ self.assertEqual(response.body, b"TestDefaultUserAgent")
finally:
client.close()
for value in [u"MyUserAgent", b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
- headers['User-Agent'] = value
- resp = self.fetch('/user_agent', headers=headers)
+ headers["User-Agent"] = value
+ resp = self.fetch("/user_agent", headers=headers)
self.assertEqual(
- resp.body, b"MyUserAgent",
- "response=%r, value=%r, container=%r" %
- (resp.body, value, container))
+ resp.body,
+ b"MyUserAgent",
+ "response=%r, value=%r, container=%r"
+ % (resp.body, value, container),
+ )
def test_multi_line_headers(self):
# Multi-line http headers are rare but rfc-allowed
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
sock, port = bind_unused_port()
with closing(sock):
+
@gen.coroutine
def accept_callback(conn, address):
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
- yield stream.write(b"""\
+ yield stream.write(
+ b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
stream.close()
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
- self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")
+ self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block")
self.io_loop.remove_handler(sock.fileno())
def test_304_with_content_length(self):
# Content-Length or other entity headers, but some servers do it
# anyway.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
- response = self.fetch('/304_with_content_length')
+ response = self.fetch("/304_with_content_length")
self.assertEqual(response.code, 304)
- self.assertEqual(response.headers['Content-Length'], '42')
+ self.assertEqual(response.headers["Content-Length"], "42")
@gen_test
def test_future_interface(self):
- response = yield self.http_client.fetch(self.get_url('/hello'))
- self.assertEqual(response.body, b'Hello world!')
+ response = yield self.http_client.fetch(self.get_url("/hello"))
+ self.assertEqual(response.body, b"Hello world!")
@gen_test
def test_future_http_error(self):
with self.assertRaises(HTTPError) as context:
- yield self.http_client.fetch(self.get_url('/notfound'))
+ yield self.http_client.fetch(self.get_url("/notfound"))
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
- response = yield self.http_client.fetch(self.get_url('/notfound'), raise_error=False)
+ response = yield self.http_client.fetch(
+ self.get_url("/notfound"), raise_error=False
+ )
self.assertEqual(response.code, 404)
@gen_test
# a _RequestProxy.
# This test uses self.http_client.fetch because self.fetch calls
# self.get_url on the input unconditionally.
- url = self.get_url('/hello')
+ url = self.get_url("/hello")
response = yield self.http_client.fetch(url)
self.assertEqual(response.request.url, url)
self.assertTrue(isinstance(response.request, HTTPRequest))
response2 = yield self.http_client.fetch(response.request)
- self.assertEqual(response2.body, b'Hello world!')
+ self.assertEqual(response2.body, b"Hello world!")
def test_all_methods(self):
- for method in ['GET', 'DELETE', 'OPTIONS']:
- response = self.fetch('/all_methods', method=method)
+ for method in ["GET", "DELETE", "OPTIONS"]:
+ response = self.fetch("/all_methods", method=method)
self.assertEqual(response.body, utf8(method))
- for method in ['POST', 'PUT', 'PATCH']:
- response = self.fetch('/all_methods', method=method, body=b'')
+ for method in ["POST", "PUT", "PATCH"]:
+ response = self.fetch("/all_methods", method=method, body=b"")
self.assertEqual(response.body, utf8(method))
- response = self.fetch('/all_methods', method='HEAD')
- self.assertEqual(response.body, b'')
- response = self.fetch('/all_methods', method='OTHER',
- allow_nonstandard_methods=True)
- self.assertEqual(response.body, b'OTHER')
+ response = self.fetch("/all_methods", method="HEAD")
+ self.assertEqual(response.body, b"")
+ response = self.fetch(
+ "/all_methods", method="OTHER", allow_nonstandard_methods=True
+ )
+ self.assertEqual(response.body, b"OTHER")
def test_body_sanity_checks(self):
# These methods require a body.
- for method in ('POST', 'PUT', 'PATCH'):
+ for method in ("POST", "PUT", "PATCH"):
with self.assertRaises(ValueError) as context:
- self.fetch('/all_methods', method=method, raise_error=True)
- self.assertIn('must not be None', str(context.exception))
+ self.fetch("/all_methods", method=method, raise_error=True)
+ self.assertIn("must not be None", str(context.exception))
- resp = self.fetch('/all_methods', method=method,
- allow_nonstandard_methods=True)
+ resp = self.fetch(
+ "/all_methods", method=method, allow_nonstandard_methods=True
+ )
self.assertEqual(resp.code, 200)
# These methods don't allow a body.
- for method in ('GET', 'DELETE', 'OPTIONS'):
+ for method in ("GET", "DELETE", "OPTIONS"):
with self.assertRaises(ValueError) as context:
- self.fetch('/all_methods', method=method, body=b'asdf', raise_error=True)
- self.assertIn('must be None', str(context.exception))
+ self.fetch(
+ "/all_methods", method=method, body=b"asdf", raise_error=True
+ )
+ self.assertIn("must be None", str(context.exception))
# In most cases this can be overridden, but curl_httpclient
# does not allow body with a GET at all.
- if method != 'GET':
- self.fetch('/all_methods', method=method, body=b'asdf',
- allow_nonstandard_methods=True, raise_error=True)
+ if method != "GET":
+ self.fetch(
+ "/all_methods",
+ method=method,
+ body=b"asdf",
+ allow_nonstandard_methods=True,
+ raise_error=True,
+ )
self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
- response = self.fetch("/redirect?status=307&url=/put",
- method="PUT", body=b"hello")
+ response = self.fetch(
+ "/redirect?status=307&url=/put", method="PUT", body=b"hello"
+ )
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/',
- user_agent='foo'),
- dict())
- self.assertEqual(proxy.user_agent, 'foo')
+ proxy = _RequestProxy(
+ HTTPRequest("http://example.com/", user_agent="foo"), dict()
+ )
+ self.assertEqual(proxy.user_agent, "foo")
def test_default_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'),
- dict(network_interface='foo'))
- self.assertEqual(proxy.network_interface, 'foo')
+ proxy = _RequestProxy(
+ HTTPRequest("http://example.com/"), dict(network_interface="foo")
+ )
+ self.assertEqual(proxy.network_interface, "foo")
def test_both_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/',
- proxy_host='foo'),
- dict(proxy_host='bar'))
- self.assertEqual(proxy.proxy_host, 'foo')
+ proxy = _RequestProxy(
+ HTTPRequest("http://example.com/", proxy_host="foo"), dict(proxy_host="bar")
+ )
+ self.assertEqual(proxy.proxy_host, "foo")
def test_neither_set(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'),
- dict())
+ proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict())
self.assertIs(proxy.auth_username, None)
def test_bad_attribute(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'),
- dict())
+ proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict())
with self.assertRaises(AttributeError):
proxy.foo
def test_defaults_none(self):
- proxy = _RequestProxy(HTTPRequest('http://example.com/'), None)
+ proxy = _RequestProxy(HTTPRequest("http://example.com/"), None)
self.assertIs(proxy.auth_username, None)
class HTTPResponseTestCase(unittest.TestCase):
def test_str(self):
- response = HTTPResponse(HTTPRequest('http://example.com'), # type: ignore
- 200, headers={}, buffer=BytesIO())
+ response = HTTPResponse( # type: ignore
+ HTTPRequest("http://example.com"), 200, headers={}, buffer=BytesIO()
+ )
s = str(response)
- self.assertTrue(s.startswith('HTTPResponse('))
- self.assertIn('code=200', s)
+ self.assertTrue(s.startswith("HTTPResponse("))
+ self.assertIn("code=200", s)
class SyncHTTPClientTest(unittest.TestCase):
@gen.coroutine
def init_server():
sock, self.port = bind_unused_port()
- app = Application([('/', HelloWorldHandler)])
+ app = Application([("/", HelloWorldHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
+
self.server_ioloop.run_sync(init_server)
self.server_thread = threading.Thread(target=self.server_ioloop.start)
for i in range(5):
yield
self.server_ioloop.stop()
+
self.server_ioloop.add_callback(slow_stop)
+
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
- return 'http://127.0.0.1:%d%s' % (self.port, path)
+ return "http://127.0.0.1:%d%s" % (self.port, path)
def test_sync_client(self):
- response = self.http_client.fetch(self.get_url('/'))
- self.assertEqual(b'Hello world!', response.body)
+ response = self.http_client.fetch(self.get_url("/"))
+ self.assertEqual(b"Hello world!", response.body)
def test_sync_client_error(self):
# Synchronous HTTPClient raises errors directly; no need for
# response.rethrow()
with self.assertRaises(HTTPError) as assertion:
- self.http_client.fetch(self.get_url('/notfound'))
+ self.http_client.fetch(self.get_url("/notfound"))
self.assertEqual(assertion.exception.code, 404)
class HTTPRequestTestCase(unittest.TestCase):
def test_headers(self):
- request = HTTPRequest('http://example.com', headers={'foo': 'bar'})
- self.assertEqual(request.headers, {'foo': 'bar'})
+ request = HTTPRequest("http://example.com", headers={"foo": "bar"})
+ self.assertEqual(request.headers, {"foo": "bar"})
def test_headers_setter(self):
- request = HTTPRequest('http://example.com')
- request.headers = {'bar': 'baz'} # type: ignore
- self.assertEqual(request.headers, {'bar': 'baz'})
+ request = HTTPRequest("http://example.com")
+ request.headers = {"bar": "baz"} # type: ignore
+ self.assertEqual(request.headers, {"bar": "baz"})
def test_null_headers_setter(self):
- request = HTTPRequest('http://example.com')
+ request = HTTPRequest("http://example.com")
request.headers = None # type: ignore
self.assertEqual(request.headers, {})
def test_body(self):
- request = HTTPRequest('http://example.com', body='foo')
- self.assertEqual(request.body, utf8('foo'))
+ request = HTTPRequest("http://example.com", body="foo")
+ self.assertEqual(request.body, utf8("foo"))
def test_body_setter(self):
- request = HTTPRequest('http://example.com')
- request.body = 'foo' # type: ignore
- self.assertEqual(request.body, utf8('foo'))
+ request = HTTPRequest("http://example.com")
+ request.body = "foo" # type: ignore
+ self.assertEqual(request.body, utf8("foo"))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
- request = HTTPRequest('http://example.com', if_modified_since=http_date)
- self.assertEqual(request.headers,
- {'If-Modified-Since': format_timestamp(http_date)})
+ request = HTTPRequest("http://example.com", if_modified_since=http_date)
+ self.assertEqual(
+ request.headers, {"If-Modified-Since": format_timestamp(http_date)}
+ )
class HTTPErrorTestCase(unittest.TestCase):
self.assertEqual(repr(e), "HTTP 403: Forbidden")
def test_error_with_response(self):
- resp = HTTPResponse(HTTPRequest('http://example.com/'), 403)
+ resp = HTTPResponse(HTTPRequest("http://example.com/"), 403)
with self.assertRaises(HTTPError) as cm:
resp.rethrow()
e = cm.exception
from tornado import gen, netutil
from tornado.concurrent import Future
-from tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
+from tornado.escape import (
+ json_decode,
+ json_encode,
+ utf8,
+ _unicode,
+ recursive_unicode,
+ native_str,
+)
from tornado.http1connection import HTTP1Connection
from tornado.httpclient import HTTPError
from tornado.httpserver import HTTPServer
-from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
+from tornado.httputil import (
+ HTTPHeaders,
+ HTTPMessageDelegate,
+ HTTPServerConnectionDelegate,
+ ResponseStartLine,
+) # noqa: E501
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test # noqa: E501
+from tornado.testing import (
+ AsyncHTTPTestCase,
+ AsyncHTTPSTestCase,
+ AsyncTestCase,
+ ExpectLog,
+ gen_test,
+) # noqa: E501
from tornado.test.util import skipOnTravis
from tornado.web import Application, RequestHandler, stream_request_body
from io import BytesIO
import typing
+
if typing.TYPE_CHECKING:
from typing import Dict, List # noqa: F401
def finish(self):
conn.detach() # type: ignore
- callback((self.start_line, self.headers, b''.join(chunks)))
+ callback((self.start_line, self.headers, b"".join(chunks)))
+
conn = HTTP1Connection(stream, True)
conn.read_response(Delegate())
class HandlerBaseTestCase(AsyncHTTPTestCase):
def get_app(self):
- return Application([('/', self.__class__.Handler)])
+ return Application([("/", self.__class__.Handler)])
def fetch_json(self, *args, **kwargs):
response = self.fetch(*args, **kwargs)
# introduced in python3.2, it was present but undocumented in
# python 2.7
skipIfOldSSL = unittest.skipIf(
- getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0),
- "old version of ssl module and/or openssl")
+ getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)) < (1, 0),
+ "old version of ssl module and/or openssl",
+)
class BaseSSLTest(AsyncHTTPSTestCase):
def get_app(self):
- return Application([('/', HelloWorldRequestHandler,
- dict(protocol="https"))])
+ return Application([("/", HelloWorldRequestHandler, dict(protocol="https"))])
class SSLTestMixin(object):
def get_ssl_options(self):
- return dict(ssl_version=self.get_ssl_version(),
- **AsyncHTTPSTestCase.default_ssl_options())
+ return dict(
+ ssl_version=self.get_ssl_version(),
+ **AsyncHTTPSTestCase.default_ssl_options()
+ )
def get_ssl_version(self):
raise NotImplementedError()
def test_ssl(self):
- response = self.fetch('/')
+ response = self.fetch("/")
self.assertEqual(response.body, b"Hello world")
def test_large_post(self):
- response = self.fetch('/',
- method='POST',
- body='A' * 5000)
+ response = self.fetch("/", method="POST", body="A" * 5000)
self.assertEqual(response.body, b"Got 5000 bytes in POST")
def test_non_ssl_request(self):
# Make sure the server closes the connection when it gets a non-ssl
# connection, rather than waiting for a timeout or otherwise
# misbehaving.
- with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
- with ExpectLog(gen_log, 'Uncaught exception', required=False):
+ with ExpectLog(gen_log, "(SSL Error|uncaught exception)"):
+ with ExpectLog(gen_log, "Uncaught exception", required=False):
with self.assertRaises((IOError, HTTPError)):
self.fetch(
- self.get_url("/").replace('https:', 'http:'),
+ self.get_url("/").replace("https:", "http:"),
request_timeout=3600,
connect_timeout=3600,
- raise_error=True)
+ raise_error=True,
+ )
def test_error_logging(self):
# No stack traces are logged for SSL errors.
- with ExpectLog(gen_log, 'SSL Error') as expect_log:
+ with ExpectLog(gen_log, "SSL Error") as expect_log:
with self.assertRaises((IOError, HTTPError)):
- self.fetch(self.get_url("/").replace("https:", "http:"),
- raise_error=True)
+ self.fetch(
+ self.get_url("/").replace("https:", "http:"), raise_error=True
+ )
self.assertFalse(expect_log.logged_stack)
+
# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
class SSLContextTest(BaseSSLTest, SSLTestMixin):
def get_ssl_options(self):
- context = ssl_options_to_context(
- AsyncHTTPSTestCase.get_ssl_options(self))
+ context = ssl_options_to_context(AsyncHTTPSTestCase.get_ssl_options(self))
assert isinstance(context, ssl.SSLContext)
return context
class BadSSLOptionsTest(unittest.TestCase):
def test_missing_arguments(self):
application = Application()
- self.assertRaises(KeyError, HTTPServer, application, ssl_options={
- "keyfile": "/__missing__.crt",
- })
+ self.assertRaises(
+ KeyError,
+ HTTPServer,
+ application,
+ ssl_options={"keyfile": "/__missing__.crt"},
+ )
def test_missing_key(self):
"""A missing SSL key should cause an immediate exception."""
application = Application()
module_dir = os.path.dirname(__file__)
- existing_certificate = os.path.join(module_dir, 'test.crt')
- existing_key = os.path.join(module_dir, 'test.key')
-
- self.assertRaises((ValueError, IOError),
- HTTPServer, application, ssl_options={
- "certfile": "/__mising__.crt",
- })
- self.assertRaises((ValueError, IOError),
- HTTPServer, application, ssl_options={
- "certfile": existing_certificate,
- "keyfile": "/__missing__.key"
- })
+ existing_certificate = os.path.join(module_dir, "test.crt")
+ existing_key = os.path.join(module_dir, "test.key")
+
+ self.assertRaises(
+ (ValueError, IOError),
+ HTTPServer,
+ application,
+ ssl_options={"certfile": "/__mising__.crt"},
+ )
+ self.assertRaises(
+ (ValueError, IOError),
+ HTTPServer,
+ application,
+ ssl_options={
+ "certfile": existing_certificate,
+ "keyfile": "/__missing__.key",
+ },
+ )
# This actually works because both files exist
- HTTPServer(application, ssl_options={
- "certfile": existing_certificate,
- "keyfile": existing_key,
- })
+ HTTPServer(
+ application,
+ ssl_options={"certfile": existing_certificate, "keyfile": existing_key},
+ )
class MultipartTestHandler(RequestHandler):
def post(self):
- self.finish({"header": self.request.headers["X-Header-Encoding-Test"],
- "argument": self.get_argument("argument"),
- "filename": self.request.files["files"][0].filename,
- "filebody": _unicode(self.request.files["files"][0]["body"]),
- })
+ self.finish(
+ {
+ "header": self.request.headers["X-Header-Encoding-Test"],
+ "argument": self.get_argument("argument"),
+ "filename": self.request.files["files"][0].filename,
+ "filebody": _unicode(self.request.files["files"][0]["body"]),
+ }
+ )
# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
def get_handlers(self):
- return [("/multipart", MultipartTestHandler),
- ("/hello", HelloWorldRequestHandler)]
+ return [
+ ("/multipart", MultipartTestHandler),
+ ("/hello", HelloWorldRequestHandler),
+ ]
def get_app(self):
return Application(self.get_handlers())
def raw_fetch(self, headers, body, newline=b"\r\n"):
with closing(IOStream(socket.socket())) as stream:
- self.io_loop.run_sync(lambda: stream.connect(('127.0.0.1', self.get_http_port())))
+ self.io_loop.run_sync(
+ lambda: stream.connect(("127.0.0.1", self.get_http_port()))
+ )
stream.write(
- newline.join(headers +
- [utf8("Content-Length: %d" % len(body))]) +
- newline + newline + body)
+ newline.join(headers + [utf8("Content-Length: %d" % len(body))])
+ + newline
+ + newline
+ + body
+ )
read_stream_body(stream, self.stop)
start_line, headers, body = self.wait()
return body
def test_multipart_form(self):
# Encodings here are tricky: Headers are latin1, bodies can be
# anything (we use utf8 by default).
- response = self.raw_fetch([
- b"POST /multipart HTTP/1.0",
- b"Content-Type: multipart/form-data; boundary=1234567890",
- b"X-Header-encoding-test: \xe9",
- ],
- b"\r\n".join([
- b"Content-Disposition: form-data; name=argument",
- b"",
- u"\u00e1".encode("utf-8"),
- b"--1234567890",
- u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
- b"",
- u"\u00fa".encode("utf-8"),
- b"--1234567890--",
- b"",
- ]))
+ response = self.raw_fetch(
+ [
+ b"POST /multipart HTTP/1.0",
+ b"Content-Type: multipart/form-data; boundary=1234567890",
+ b"X-Header-encoding-test: \xe9",
+ ],
+ b"\r\n".join(
+ [
+ b"Content-Disposition: form-data; name=argument",
+ b"",
+ u"\u00e1".encode("utf-8"),
+ b"--1234567890",
+ u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode(
+ "utf8"
+ ),
+ b"",
+ u"\u00fa".encode("utf-8"),
+ b"--1234567890--",
+ b"",
+ ]
+ ),
+ )
data = json_decode(response)
self.assertEqual(u"\u00e9", data["header"])
self.assertEqual(u"\u00e1", data["argument"])
def test_newlines(self):
# We support both CRLF and bare LF as line separators.
for newline in (b"\r\n", b"\n"):
- response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
- newline=newline)
- self.assertEqual(response, b'Hello world')
+ response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline)
+ self.assertEqual(response, b"Hello world")
@gen_test
def test_100_continue(self):
# headers, and then the real response after the body.
stream = IOStream(socket.socket())
yield stream.connect(("127.0.0.1", self.get_http_port()))
- yield stream.write(b"\r\n".join([
- b"POST /hello HTTP/1.1",
- b"Content-Length: 1024",
- b"Expect: 100-continue",
- b"Connection: close",
- b"\r\n"]))
+ yield stream.write(
+ b"\r\n".join(
+ [
+ b"POST /hello HTTP/1.1",
+ b"Content-Length: 1024",
+ b"Expect: 100-continue",
+ b"Connection: close",
+ b"\r\n",
+ ]
+ )
+ )
data = yield stream.read_until(b"\r\n\r\n")
self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
stream.write(b"a" * 1024)
first_line = yield stream.read_until(b"\r\n")
self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
header_data = yield stream.read_until(b"\r\n\r\n")
- headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
+ headers = HTTPHeaders.parse(native_str(header_data.decode("latin1")))
body = yield stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Got 1024 bytes in POST")
stream.close()
def prepare(self):
self.errors = {} # type: Dict[str, str]
fields = [
- ('method', str),
- ('uri', str),
- ('version', str),
- ('remote_ip', str),
- ('protocol', str),
- ('host', str),
- ('path', str),
- ('query', str),
+ ("method", str),
+ ("uri", str),
+ ("version", str),
+ ("remote_ip", str),
+ ("protocol", str),
+ ("host", str),
+ ("path", str),
+ ("query", str),
]
for field, expected_type in fields:
self.check_type(field, getattr(self.request, field), expected_type)
- self.check_type('header_key', list(self.request.headers.keys())[0], str)
- self.check_type('header_value', list(self.request.headers.values())[0], str)
+ self.check_type("header_key", list(self.request.headers.keys())[0], str)
+ self.check_type("header_value", list(self.request.headers.values())[0], str)
- self.check_type('cookie_key', list(self.request.cookies.keys())[0], str)
- self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str)
+ self.check_type("cookie_key", list(self.request.cookies.keys())[0], str)
+ self.check_type(
+ "cookie_value", list(self.request.cookies.values())[0].value, str
+ )
# secure cookies
- self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
- self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes)
+ self.check_type("arg_key", list(self.request.arguments.keys())[0], str)
+ self.check_type("arg_value", list(self.request.arguments.values())[0][0], bytes)
def post(self):
- self.check_type('body', self.request.body, bytes)
+ self.check_type("body", self.request.body, bytes)
self.write(self.errors)
def get(self):
def check_type(self, name, obj, expected_type):
actual_type = type(obj)
if expected_type != actual_type:
- self.errors[name] = "expected %s, got %s" % (expected_type,
- actual_type)
+ self.errors[name] = "expected %s, got %s" % (expected_type, actual_type)
class HTTPServerTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([("/echo", EchoHandler),
- ("/typecheck", TypeCheckHandler),
- ("//doubleslash", EchoHandler),
- ])
+ return Application(
+ [
+ ("/echo", EchoHandler),
+ ("/typecheck", TypeCheckHandler),
+ ("//doubleslash", EchoHandler),
+ ]
+ )
def test_query_string_encoding(self):
response = self.fetch("/echo?foo=%C3%A9")
data = json_decode(response.body)
self.assertEqual(data, {})
- response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
+ response = self.fetch(
+ "/typecheck", method="POST", body="foo=bar", headers=headers
+ )
data = json_decode(response.body)
self.assertEqual(data, {})
def test_malformed_body(self):
# parse_qs is pretty forgiving, but it will fail on python 3
# if the data is not utf8.
- with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
+ with ExpectLog(gen_log, "Invalid x-www-form-urlencoded body"):
response = self.fetch(
- '/echo', method="POST",
- headers={'Content-Type': 'application/x-www-form-urlencoded'},
- body=b'\xe9')
+ "/echo",
+ method="POST",
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ body=b"\xe9",
+ )
self.assertEqual(200, response.code)
- self.assertEqual(b'{}', response.body)
+ self.assertEqual(b"{}", response.body)
class HTTPServerRawTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([
- ('/echo', EchoHandler),
- ])
+ return Application([("/echo", EchoHandler)])
def setUp(self):
super(HTTPServerRawTest, self).setUp()
self.stream = IOStream(socket.socket())
- self.io_loop.run_sync(lambda: self.stream.connect(('127.0.0.1', self.get_http_port())))
+ self.io_loop.run_sync(
+ lambda: self.stream.connect(("127.0.0.1", self.get_http_port()))
+ )
def tearDown(self):
self.stream.close()
self.wait()
def test_malformed_first_line_response(self):
- with ExpectLog(gen_log, '.*Malformed HTTP request line'):
- self.stream.write(b'asdf\r\n\r\n')
+ with ExpectLog(gen_log, ".*Malformed HTTP request line"):
+ self.stream.write(b"asdf\r\n\r\n")
read_stream_body(self.stream, self.stop)
start_line, headers, response = self.wait()
- self.assertEqual('HTTP/1.1', start_line.version)
+ self.assertEqual("HTTP/1.1", start_line.version)
self.assertEqual(400, start_line.code)
- self.assertEqual('Bad Request', start_line.reason)
+ self.assertEqual("Bad Request", start_line.reason)
def test_malformed_first_line_log(self):
- with ExpectLog(gen_log, '.*Malformed HTTP request line'):
- self.stream.write(b'asdf\r\n\r\n')
+ with ExpectLog(gen_log, ".*Malformed HTTP request line"):
+ self.stream.write(b"asdf\r\n\r\n")
# TODO: need an async version of ExpectLog so we don't need
# hard-coded timeouts here.
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
- self.stop)
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop)
self.wait()
def test_malformed_headers(self):
- with ExpectLog(gen_log, '.*Malformed HTTP message.*no colon in header line'):
- self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
- self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
- self.stop)
+ with ExpectLog(gen_log, ".*Malformed HTTP message.*no colon in header line"):
+ self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop)
self.wait()
def test_chunked_request_body(self):
# Chunked requests are not widely supported and we don't have a way
# to generate them in AsyncHTTPClient, but HTTPServer will read them.
- self.stream.write(b"""\
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded
bar
0
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
read_stream_body(self.stream, self.stop)
start_line, headers, response = self.wait()
- self.assertEqual(json_decode(response), {u'foo': [u'bar']})
+ self.assertEqual(json_decode(response), {u"foo": [u"bar"]})
def test_chunked_request_uppercase(self):
# As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
# case-insensitive.
- self.stream.write(b"""\
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded
bar
0
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
read_stream_body(self.stream, self.stop)
start_line, headers, response = self.wait()
- self.assertEqual(json_decode(response), {u'foo': [u'bar']})
+ self.assertEqual(json_decode(response), {u"foo": [u"bar"]})
@gen_test
def test_invalid_content_length(self):
- with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
- self.stream.write(b"""\
+ with ExpectLog(gen_log, ".*Only integer Content-Length is allowed"):
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
Content-Length: foo
bar
-""".replace(b"\n", b"\r\n"))
+""".replace(
+ b"\n", b"\r\n"
+ )
+ )
yield self.stream.read_until_close()
class XHeaderTest(HandlerBaseTestCase):
class Handler(RequestHandler):
def get(self):
- self.set_header('request-version', self.request.version)
- self.write(dict(remote_ip=self.request.remote_ip,
- remote_protocol=self.request.protocol))
+ self.set_header("request-version", self.request.version)
+ self.write(
+ dict(
+ remote_ip=self.request.remote_ip,
+ remote_protocol=self.request.protocol,
+ )
+ )
def get_httpserver_options(self):
- return dict(xheaders=True, trusted_downstream=['5.5.5.5'])
+ return dict(xheaders=True, trusted_downstream=["5.5.5.5"])
def test_ip_headers(self):
self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")
valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
self.assertEqual(
- self.fetch_json("/", headers=valid_ipv4)["remote_ip"],
- "4.4.4.4")
+ self.fetch_json("/", headers=valid_ipv4)["remote_ip"], "4.4.4.4"
+ )
valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"}
self.assertEqual(
- self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"],
- "4.4.4.4")
+ self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "4.4.4.4"
+ )
valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
- "2620:0:1cfe:face:b00c::3")
+ "2620:0:1cfe:face:b00c::3",
+ )
valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"}
self.assertEqual(
self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"],
- "2620:0:1cfe:face:b00c::3")
+ "2620:0:1cfe:face:b00c::3",
+ )
invalid_chars = {"X-Real-IP": "4.4.4.4<script>"}
self.assertEqual(
- self.fetch_json("/", headers=invalid_chars)["remote_ip"],
- "127.0.0.1")
+ self.fetch_json("/", headers=invalid_chars)["remote_ip"], "127.0.0.1"
+ )
invalid_chars_list = {"X-Forwarded-For": "4.4.4.4, 5.5.5.5<script>"}
self.assertEqual(
- self.fetch_json("/", headers=invalid_chars_list)["remote_ip"],
- "127.0.0.1")
+ self.fetch_json("/", headers=invalid_chars_list)["remote_ip"], "127.0.0.1"
+ )
invalid_host = {"X-Real-IP": "www.google.com"}
self.assertEqual(
- self.fetch_json("/", headers=invalid_host)["remote_ip"],
- "127.0.0.1")
+ self.fetch_json("/", headers=invalid_host)["remote_ip"], "127.0.0.1"
+ )
def test_trusted_downstream(self):
valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"}
resp = self.fetch("/", headers=valid_ipv4_list)
- if resp.headers['request-version'].startswith('HTTP/2'):
+ if resp.headers["request-version"].startswith("HTTP/2"):
# This is a hack - there's nothing that fundamentally requires http/1
# here but tornado_http2 doesn't support it yet.
- self.skipTest('requires HTTP/1.x')
+ self.skipTest("requires HTTP/1.x")
result = json_decode(resp.body)
- self.assertEqual(result['remote_ip'], "4.4.4.4")
+ self.assertEqual(result["remote_ip"], "4.4.4.4")
def test_scheme_headers(self):
self.assertEqual(self.fetch_json("/")["remote_protocol"], "http")
https_scheme = {"X-Scheme": "https"}
self.assertEqual(
- self.fetch_json("/", headers=https_scheme)["remote_protocol"],
- "https")
+ self.fetch_json("/", headers=https_scheme)["remote_protocol"], "https"
+ )
https_forwarded = {"X-Forwarded-Proto": "https"}
self.assertEqual(
- self.fetch_json("/", headers=https_forwarded)["remote_protocol"],
- "https")
+ self.fetch_json("/", headers=https_forwarded)["remote_protocol"], "https"
+ )
https_multi_forwarded = {"X-Forwarded-Proto": "https , http"}
self.assertEqual(
self.fetch_json("/", headers=https_multi_forwarded)["remote_protocol"],
- "http")
+ "http",
+ )
http_multi_forwarded = {"X-Forwarded-Proto": "http,https"}
self.assertEqual(
self.fetch_json("/", headers=http_multi_forwarded)["remote_protocol"],
- "https")
+ "https",
+ )
bad_forwarded = {"X-Forwarded-Proto": "unknown"}
self.assertEqual(
- self.fetch_json("/", headers=bad_forwarded)["remote_protocol"],
- "http")
+ self.fetch_json("/", headers=bad_forwarded)["remote_protocol"], "http"
+ )
class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
def get_app(self):
- return Application([('/', XHeaderTest.Handler)])
+ return Application([("/", XHeaderTest.Handler)])
def get_httpserver_options(self):
output = super(SSLXHeaderTest, self).get_httpserver_options()
- output['xheaders'] = True
+ output["xheaders"] = True
return output
def test_request_without_xprotocol(self):
http_scheme = {"X-Scheme": "http"}
self.assertEqual(
- self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http")
+ self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http"
+ )
bad_scheme = {"X-Scheme": "unknown"}
self.assertEqual(
- self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https")
+ self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https"
+ )
class ManualProtocolTest(HandlerBaseTestCase):
self.write(dict(protocol=self.request.protocol))
def get_httpserver_options(self):
- return dict(protocol='https')
+ return dict(protocol="https")
def test_manual_protocol(self):
- self.assertEqual(self.fetch_json('/')['protocol'], 'https')
+ self.assertEqual(self.fetch_json("/")["protocol"], "https")
-@unittest.skipIf(not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin',
- "unix sockets not supported on this platform")
+@unittest.skipIf(
+ not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
+ "unix sockets not supported on this platform",
+)
class UnixSocketTest(AsyncTestCase):
"""HTTPServers can listen on Unix sockets too.
Unfortunately, there's no way to specify a unix socket in a url for
an HTTP client, so we have to test this by hand.
"""
+
def setUp(self):
super(UnixSocketTest, self).setUp()
self.tmpdir = tempfile.mkdtemp()
response = yield self.stream.read_until(b"\r\n")
self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
header_data = yield self.stream.read_until(b"\r\n\r\n")
- headers = HTTPHeaders.parse(header_data.decode('latin1'))
+ headers = HTTPHeaders.parse(header_data.decode("latin1"))
body = yield self.stream.read_bytes(int(headers["Content-Length"]))
self.assertEqual(body, b"Hello world")
These tests don't use AsyncHTTPClient because we want to control
connection reuse and closing.
"""
+
def get_app(self):
class HelloHandler(RequestHandler):
def get(self):
- self.finish('Hello world')
+ self.finish("Hello world")
def post(self):
- self.finish('Hello world')
+ self.finish("Hello world")
class LargeHandler(RequestHandler):
def get(self):
# 512KB should be bigger than the socket buffers so it will
# be written out in chunks.
- self.write(''.join(chr(i % 256) * 1024 for i in range(512)))
+ self.write("".join(chr(i % 256) * 1024 for i in range(512)))
class FinishOnCloseHandler(RequestHandler):
@gen.coroutine
# This is not very realistic, but finishing the request
# from the close callback has the right timing to mimic
# some errors seen in the wild.
- self.finish('closed')
+ self.finish("closed")
- return Application([('/', HelloHandler),
- ('/large', LargeHandler),
- ('/finish_on_close', FinishOnCloseHandler)])
+ return Application(
+ [
+ ("/", HelloHandler),
+ ("/large", LargeHandler),
+ ("/finish_on_close", FinishOnCloseHandler),
+ ]
+ )
def setUp(self):
super(KeepAliveTest, self).setUp()
- self.http_version = b'HTTP/1.1'
+ self.http_version = b"HTTP/1.1"
def tearDown(self):
# We just closed the client side of the socket; let the IOLoop run
self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
self.wait()
- if hasattr(self, 'stream'):
+ if hasattr(self, "stream"):
self.stream.close()
super(KeepAliveTest, self).tearDown()
@gen.coroutine
def connect(self):
self.stream = IOStream(socket.socket())
- yield self.stream.connect(('127.0.0.1', self.get_http_port()))
+ yield self.stream.connect(("127.0.0.1", self.get_http_port()))
@gen.coroutine
def read_headers(self):
- first_line = yield self.stream.read_until(b'\r\n')
- self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
- header_bytes = yield self.stream.read_until(b'\r\n\r\n')
- headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
+ first_line = yield self.stream.read_until(b"\r\n")
+ self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
+ header_bytes = yield self.stream.read_until(b"\r\n\r\n")
+ headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
raise gen.Return(headers)
@gen.coroutine
def read_response(self):
self.headers = yield self.read_headers()
- body = yield self.stream.read_bytes(int(self.headers['Content-Length']))
- self.assertEqual(b'Hello world', body)
+ body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
+ self.assertEqual(b"Hello world", body)
def close(self):
self.stream.close()
@gen_test
def test_two_requests(self):
yield self.connect()
- self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
yield self.read_response()
- self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
yield self.read_response()
self.close()
@gen_test
def test_request_close(self):
yield self.connect()
- self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
yield self.read_response()
data = yield self.stream.read_until_close()
self.assertTrue(not data)
- self.assertEqual(self.headers['Connection'], 'close')
+ self.assertEqual(self.headers["Connection"], "close")
self.close()
# keepalive is supported for http 1.0 too, but it's opt-in
@gen_test
def test_http10(self):
- self.http_version = b'HTTP/1.0'
+ self.http_version = b"HTTP/1.0"
yield self.connect()
- self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
yield self.read_response()
data = yield self.stream.read_until_close()
self.assertTrue(not data)
- self.assertTrue('Connection' not in self.headers)
+ self.assertTrue("Connection" not in self.headers)
self.close()
@gen_test
def test_http10_keepalive(self):
- self.http_version = b'HTTP/1.0'
+ self.http_version = b"HTTP/1.0"
yield self.connect()
- self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
yield self.read_response()
- self.assertEqual(self.headers['Connection'], 'Keep-Alive')
- self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
+ self.assertEqual(self.headers["Connection"], "Keep-Alive")
+ self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
yield self.read_response()
- self.assertEqual(self.headers['Connection'], 'Keep-Alive')
+ self.assertEqual(self.headers["Connection"], "Keep-Alive")
self.close()
@gen_test
def test_http10_keepalive_extra_crlf(self):
- self.http_version = b'HTTP/1.0'
+ self.http_version = b"HTTP/1.0"
yield self.connect()
- self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
yield self.read_response()
- self.assertEqual(self.headers['Connection'], 'Keep-Alive')
- self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
+ self.assertEqual(self.headers["Connection"], "Keep-Alive")
+ self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
yield self.read_response()
- self.assertEqual(self.headers['Connection'], 'Keep-Alive')
+ self.assertEqual(self.headers["Connection"], "Keep-Alive")
self.close()
@gen_test
def test_pipelined_requests(self):
yield self.connect()
- self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
yield self.read_response()
yield self.read_response()
self.close()
@gen_test
def test_pipelined_cancel(self):
yield self.connect()
- self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
+ self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
# only read once
yield self.read_response()
self.close()
@gen_test
def test_cancel_during_download(self):
yield self.connect()
- self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
+ self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
yield self.read_headers()
yield self.stream.read_bytes(1024)
self.close()
@gen_test
def test_finish_while_closed(self):
yield self.connect()
- self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
+ self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
yield self.read_headers()
self.close()
@gen_test
def test_keepalive_chunked(self):
- self.http_version = b'HTTP/1.0'
+ self.http_version = b"HTTP/1.0"
yield self.connect()
- self.stream.write(b'POST / HTTP/1.0\r\n'
- b'Connection: keep-alive\r\n'
- b'Transfer-Encoding: chunked\r\n'
- b'\r\n'
- b'0\r\n'
- b'\r\n')
+ self.stream.write(
+ b"POST / HTTP/1.0\r\n"
+ b"Connection: keep-alive\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
yield self.read_response()
- self.assertEqual(self.headers['Connection'], 'Keep-Alive')
- self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
+ self.assertEqual(self.headers["Connection"], "Keep-Alive")
+ self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
yield self.read_response()
- self.assertEqual(self.headers['Connection'], 'Keep-Alive')
+ self.assertEqual(self.headers["Connection"], "Keep-Alive")
self.close()
class GzipBaseTest(object):
def get_app(self):
- return Application([('/', EchoHandler)])
+ return Application([("/", EchoHandler)])
def post_gzip(self, body):
bytesio = BytesIO()
- gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
+ gzip_file = gzip.GzipFile(mode="w", fileobj=bytesio)
gzip_file.write(utf8(body))
gzip_file.close()
compressed_body = bytesio.getvalue()
- return self.fetch('/', method='POST', body=compressed_body,
- headers={'Content-Encoding': 'gzip'})
+ return self.fetch(
+ "/",
+ method="POST",
+ body=compressed_body,
+ headers={"Content-Encoding": "gzip"},
+ )
def test_uncompressed(self):
- response = self.fetch('/', method='POST', body='foo=bar')
- self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
+ response = self.fetch("/", method="POST", body="foo=bar")
+ self.assertEquals(json_decode(response.body), {u"foo": [u"bar"]})
class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
return dict(decompress_request=True)
def test_gzip(self):
- response = self.post_gzip('foo=bar')
- self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})
+ response = self.post_gzip("foo=bar")
+ self.assertEquals(json_decode(response.body), {u"foo": [u"bar"]})
class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
# the body (but parsing form bodies is currently just a log message,
# not a fatal error).
with ExpectLog(gen_log, "Unsupported Content-Encoding"):
- response = self.post_gzip('foo=bar')
+ response = self.post_gzip("foo=bar")
self.assertEquals(json_decode(response.body), {})
class StreamingChunkSizeTest(AsyncHTTPTestCase):
# 50 characters long, and repetitive so it can be compressed.
- BODY = b'01234567890123456789012345678901234567890123456789'
+ BODY = b"01234567890123456789012345678901234567890123456789"
CHUNK_SIZE = 16
def get_http_client(self):
def finish(self):
response_body = utf8(json_encode(self.chunk_lengths))
self.connection.write_headers(
- ResponseStartLine('HTTP/1.1', 200, 'OK'),
- HTTPHeaders({'Content-Length': str(len(response_body))}))
+ ResponseStartLine("HTTP/1.1", 200, "OK"),
+ HTTPHeaders({"Content-Length": str(len(response_body))}),
+ )
self.connection.write(response_body)
self.connection.finish()
class App(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
return StreamingChunkSizeTest.MessageDelegate(request_conn)
+
return App()
def fetch_chunk_sizes(self, **kwargs):
- response = self.fetch('/', method='POST', **kwargs)
+ response = self.fetch("/", method="POST", **kwargs)
response.rethrow()
chunks = json_decode(response.body)
self.assertEqual(len(self.BODY), sum(chunks))
for chunk_size in chunks:
- self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
- 'oversized chunk: ' + str(chunks))
- self.assertGreater(chunk_size, 0,
- 'empty chunk: ' + str(chunks))
+ self.assertLessEqual(
+ chunk_size, self.CHUNK_SIZE, "oversized chunk: " + str(chunks)
+ )
+ self.assertGreater(chunk_size, 0, "empty chunk: " + str(chunks))
return chunks
def compress(self, body):
bytesio = BytesIO()
- gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
+ gzfile = gzip.GzipFile(mode="w", fileobj=bytesio)
gzfile.write(body)
gzfile.close()
compressed = bytesio.getvalue()
self.assertEqual([16, 16, 16, 2], chunks)
def test_compressed_body(self):
- self.fetch_chunk_sizes(body=self.compress(self.BODY),
- headers={'Content-Encoding': 'gzip'})
+ self.fetch_chunk_sizes(
+ body=self.compress(self.BODY), headers={"Content-Encoding": "gzip"}
+ )
# Compression creates irregular boundaries so the assertions
# in fetch_chunk_sizes are as specific as we can get.
def body_producer(write):
write(self.BODY[:20])
write(self.BODY[20:])
+
chunks = self.fetch_chunk_sizes(body_producer=body_producer)
# HTTP chunk boundaries translate to application-visible breaks
self.assertEqual([16, 4, 16, 14], chunks)
def body_producer(write):
write(compressed[:20])
write(compressed[20:])
- self.fetch_chunk_sizes(body_producer=body_producer,
- headers={'Content-Encoding': 'gzip'})
+
+ self.fetch_chunk_sizes(
+ body_producer=body_producer, headers={"Content-Encoding": "gzip"}
+ )
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([('/', HelloWorldRequestHandler)])
+ return Application([("/", HelloWorldRequestHandler)])
def get_httpserver_options(self):
return dict(max_header_size=1024)
def test_small_headers(self):
- response = self.fetch("/", headers={'X-Filler': 'a' * 100})
+ response = self.fetch("/", headers={"X-Filler": "a" * 100})
response.rethrow()
self.assertEqual(response.body, b"Hello world")
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read", required=False):
try:
- self.fetch("/", headers={'X-Filler': 'a' * 1000}, raise_error=True)
+ self.fetch("/", headers={"X-Filler": "a" * 1000}, raise_error=True)
self.fail("did not raise expected exception")
except HTTPError as e:
# 431 is "Request Header Fields Too Large", defined in RFC
@skipOnTravis
class IdleTimeoutTest(AsyncHTTPTestCase):
def get_app(self):
- return Application([('/', HelloWorldRequestHandler)])
+ return Application([("/", HelloWorldRequestHandler)])
def get_httpserver_options(self):
return dict(idle_connection_timeout=0.1)
@gen.coroutine
def connect(self):
stream = IOStream(socket.socket())
- yield stream.connect(('127.0.0.1', self.get_http_port()))
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
self.streams.append(stream)
raise gen.Return(stream)
self.bytes_read = 0
def prepare(self):
- if 'expected_size' in self.request.arguments:
+ if "expected_size" in self.request.arguments:
self.request.connection.set_max_body_size(
- int(self.get_argument('expected_size')))
- if 'body_timeout' in self.request.arguments:
+ int(self.get_argument("expected_size"))
+ )
+ if "body_timeout" in self.request.arguments:
self.request.connection.set_body_timeout(
- float(self.get_argument('body_timeout')))
+ float(self.get_argument("body_timeout"))
+ )
def data_received(self, data):
self.bytes_read += len(data)
def put(self):
self.write(str(self.bytes_read))
- return Application([('/buffered', BufferedHandler),
- ('/streaming', StreamingHandler)])
+ return Application(
+ [("/buffered", BufferedHandler), ("/streaming", StreamingHandler)]
+ )
def get_httpserver_options(self):
return dict(body_timeout=3600, max_body_size=4096)
return SimpleAsyncHTTPClient()
def test_small_body(self):
- response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
- self.assertEqual(response.body, b'4096')
- response = self.fetch('/streaming', method='PUT', body=b'a' * 4096)
- self.assertEqual(response.body, b'4096')
+ response = self.fetch("/buffered", method="PUT", body=b"a" * 4096)
+ self.assertEqual(response.body, b"4096")
+ response = self.fetch("/streaming", method="PUT", body=b"a" * 4096)
+ self.assertEqual(response.body, b"4096")
def test_large_body_buffered(self):
- with ExpectLog(gen_log, '.*Content-Length too long'):
- response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
+ with ExpectLog(gen_log, ".*Content-Length too long"):
+ response = self.fetch("/buffered", method="PUT", body=b"a" * 10240)
self.assertEqual(response.code, 400)
- @unittest.skipIf(os.name == 'nt', 'flaky on windows')
+ @unittest.skipIf(os.name == "nt", "flaky on windows")
def test_large_body_buffered_chunked(self):
# This test is flaky on windows for unknown reasons.
- with ExpectLog(gen_log, '.*chunked body too large'):
- response = self.fetch('/buffered', method='PUT',
- body_producer=lambda write: write(b'a' * 10240))
+ with ExpectLog(gen_log, ".*chunked body too large"):
+ response = self.fetch(
+ "/buffered",
+ method="PUT",
+ body_producer=lambda write: write(b"a" * 10240),
+ )
self.assertEqual(response.code, 400)
def test_large_body_streaming(self):
- with ExpectLog(gen_log, '.*Content-Length too long'):
- response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
+ with ExpectLog(gen_log, ".*Content-Length too long"):
+ response = self.fetch("/streaming", method="PUT", body=b"a" * 10240)
self.assertEqual(response.code, 400)
- @unittest.skipIf(os.name == 'nt', 'flaky on windows')
+ @unittest.skipIf(os.name == "nt", "flaky on windows")
def test_large_body_streaming_chunked(self):
- with ExpectLog(gen_log, '.*chunked body too large'):
- response = self.fetch('/streaming', method='PUT',
- body_producer=lambda write: write(b'a' * 10240))
+ with ExpectLog(gen_log, ".*chunked body too large"):
+ response = self.fetch(
+ "/streaming",
+ method="PUT",
+ body_producer=lambda write: write(b"a" * 10240),
+ )
self.assertEqual(response.code, 400)
def test_large_body_streaming_override(self):
- response = self.fetch('/streaming?expected_size=10240', method='PUT',
- body=b'a' * 10240)
- self.assertEqual(response.body, b'10240')
+ response = self.fetch(
+ "/streaming?expected_size=10240", method="PUT", body=b"a" * 10240
+ )
+ self.assertEqual(response.body, b"10240")
def test_large_body_streaming_chunked_override(self):
- response = self.fetch('/streaming?expected_size=10240', method='PUT',
- body_producer=lambda write: write(b'a' * 10240))
- self.assertEqual(response.body, b'10240')
+ response = self.fetch(
+ "/streaming?expected_size=10240",
+ method="PUT",
+ body_producer=lambda write: write(b"a" * 10240),
+ )
+ self.assertEqual(response.body, b"10240")
@gen_test
def test_timeout(self):
stream = IOStream(socket.socket())
try:
- yield stream.connect(('127.0.0.1', self.get_http_port()))
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
# Use a raw stream because AsyncHTTPClient won't let us read a
# response without finishing a body.
- stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
- b'Content-Length: 42\r\n\r\n')
- with ExpectLog(gen_log, 'Timeout reading body'):
+ stream.write(
+ b"PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n"
+ b"Content-Length: 42\r\n\r\n"
+ )
+ with ExpectLog(gen_log, "Timeout reading body"):
response = yield stream.read_until_close()
- self.assertEqual(response, b'')
+ self.assertEqual(response, b"")
finally:
stream.close()
# The max_body_size override is reset between requests.
stream = IOStream(socket.socket())
try:
- yield stream.connect(('127.0.0.1', self.get_http_port()))
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
# Use a raw stream so we can make sure it's all on one connection.
- stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
- b'Content-Length: 10240\r\n\r\n')
- stream.write(b'a' * 10240)
+ stream.write(
+ b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n"
+ b"Content-Length: 10240\r\n\r\n"
+ )
+ stream.write(b"a" * 10240)
fut = Future() # type: Future[bytes]
read_stream_body(stream, callback=fut.set_result)
start_line, headers, response = yield fut
- self.assertEqual(response, b'10240')
+ self.assertEqual(response, b"10240")
# Without the ?expected_size parameter, we get the old default value
- stream.write(b'PUT /streaming HTTP/1.1\r\n'
- b'Content-Length: 10240\r\n\r\n')
- with ExpectLog(gen_log, '.*Content-Length too long'):
+ stream.write(
+ b"PUT /streaming HTTP/1.1\r\n" b"Content-Length: 10240\r\n\r\n"
+ )
+ with ExpectLog(gen_log, ".*Content-Length too long"):
data = yield stream.read_until_close()
- self.assertEqual(data, b'HTTP/1.1 400 Bad Request\r\n\r\n')
+ self.assertEqual(data, b"HTTP/1.1 400 Bad Request\r\n\r\n")
finally:
stream.close()
# This test will be skipped if we're using HTTP/2,
# so just close it out cleanly using the modern interface.
request.connection.write_headers(
- ResponseStartLine('', 200, 'OK'),
- HTTPHeaders())
+ ResponseStartLine("", 200, "OK"), HTTPHeaders()
+ )
request.connection.finish()
return
message = b"Hello world"
- request.connection.write(utf8("HTTP/1.1 200 OK\r\n"
- "Content-Length: %d\r\n\r\n" % len(message)))
+ request.connection.write(
+ utf8("HTTP/1.1 200 OK\r\n" "Content-Length: %d\r\n\r\n" % len(message))
+ )
request.connection.write(message)
request.connection.finish()
+
return handle_request
def test_legacy_interface(self):
- response = self.fetch('/')
+ response = self.fetch("/")
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(response.body, b"Hello world")
# -*- coding: utf-8 -*-
from tornado.httputil import (
- url_concat, parse_multipart_form_data, HTTPHeaders, format_timestamp,
- HTTPServerRequest, parse_request_start_line, parse_cookie, qs_to_qsl,
- HTTPInputError, HTTPFile
+ url_concat,
+ parse_multipart_form_data,
+ HTTPHeaders,
+ format_timestamp,
+ HTTPServerRequest,
+ parse_request_start_line,
+ parse_cookie,
+ qs_to_qsl,
+ HTTPInputError,
+ HTTPFile,
)
from tornado.escape import utf8, native_str
from tornado.log import gen_log
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
- url = url_concat(
- "https://localhost/path",
- [('y', 'y'), ('z', 'z')],
- )
+ url = url_concat("https://localhost/path", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_encode_args(self):
- url = url_concat(
- "https://localhost/path",
- [('y', '/y'), ('z', 'z')],
- )
+ url = url_concat("https://localhost/path", [("y", "/y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
def test_url_concat_trailing_q(self):
- url = url_concat(
- "https://localhost/path?",
- [('y', 'y'), ('z', 'z')],
- )
+ url = url_concat("https://localhost/path?", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_q_with_no_trailing_amp(self):
- url = url_concat(
- "https://localhost/path?x",
- [('y', 'y'), ('z', 'z')],
- )
+ url = url_concat("https://localhost/path?x", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_trailing_amp(self):
- url = url_concat(
- "https://localhost/path?x&",
- [('y', 'y'), ('z', 'z')],
- )
+ url = url_concat("https://localhost/path?x&", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_mult_params(self):
- url = url_concat(
- "https://localhost/path?a=1&b=2",
- [('y', 'y'), ('z', 'z')],
- )
+ url = url_concat("https://localhost/path?a=1&b=2", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
def test_url_concat_no_params(self):
- url = url_concat(
- "https://localhost/path?r=1&t=2",
- [],
- )
+ url = url_concat("https://localhost/path?r=1&t=2", [])
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_none_params(self):
- url = url_concat(
- "https://localhost/path?r=1&t=2",
- None,
- )
+ url = url_concat("https://localhost/path?r=1&t=2", None)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_with_frag(self):
- url = url_concat(
- "https://localhost/path#tab",
- [('y', 'y')],
- )
+ url = url_concat("https://localhost/path#tab", [("y", "y")])
self.assertEqual(url, "https://localhost/path?y=y#tab")
def test_url_concat_multi_same_params(self):
- url = url_concat(
- "https://localhost/path",
- [('y', 'y1'), ('y', 'y2')],
- )
+ url = url_concat("https://localhost/path", [("y", "y1"), ("y", "y2")])
self.assertEqual(url, "https://localhost/path?y=y1&y=y2")
def test_url_concat_multi_same_query_params(self):
- url = url_concat(
- "https://localhost/path?r=1&r=2",
- [('y', 'y')],
- )
+ url = url_concat("https://localhost/path?r=1&r=2", [("y", "y")])
self.assertEqual(url, "https://localhost/path?r=1&r=2&y=y")
def test_url_concat_dict_params(self):
- url = url_concat(
- "https://localhost/path",
- dict(y='y'),
- )
+ url = url_concat("https://localhost/path", dict(y="y"))
self.assertEqual(url, "https://localhost/path?y=y")
class QsParseTest(unittest.TestCase):
-
def test_parsing(self):
qsstring = "a=1&b=2&a=3"
qs = urllib.parse.parse_qs(qsstring)
qsl = list(qs_to_qsl(qs))
- self.assertIn(('a', '1'), qsl)
- self.assertIn(('a', '3'), qsl)
- self.assertIn(('b', '2'), qsl)
+ self.assertIn(("a", "1"), qsl)
+ self.assertIn(("a", "3"), qsl)
+ self.assertIn(("b", "2"), qsl)
class MultipartFormDataTest(unittest.TestCase):
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
---1234--""".replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
Content-Disposition: form-data; name=files; filename=ab.txt
Foo
---1234--""".replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["body"], b"Foo")
def test_special_filenames(self):
- filenames = ['a;b.txt',
- 'a"b.txt',
- 'a";b.txt',
- 'a;"b.txt',
- 'a";";.txt',
- 'a\\"b.txt',
- 'a\\b.txt',
- ]
+ filenames = [
+ "a;b.txt",
+ 'a"b.txt',
+ 'a";b.txt',
+ 'a;"b.txt',
+ 'a";";.txt',
+ 'a\\"b.txt',
+ "a\\b.txt",
+ ]
for filename in filenames:
logging.debug("trying filename %r", filename)
str_data = """\
Content-Disposition: form-data; name="files"; filename="%s"
Foo
---1234--""" % filename.replace('\\', '\\\\').replace('"', '\\"')
+--1234--""" % filename.replace(
+ "\\", "\\\\"
+ ).replace(
+ '"', '\\"'
+ )
data = utf8(str_data.replace("\n", "\r\n"))
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt
Foo
---1234--""".replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["body"], b"Foo")
def test_boundary_starts_and_ends_with_quotes(self):
- data = b'''\
+ data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
---1234--'''.replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
parse_multipart_form_data(b'"1234"', data, args, files)
file = files["files"][0]
self.assertEqual(file["body"], b"Foo")
def test_missing_headers(self):
- data = b'''\
+ data = b"""\
--1234
Foo
---1234--'''.replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
with ExpectLog(gen_log, "multipart/form-data missing headers"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_invalid_content_disposition(self):
- data = b'''\
+ data = b"""\
--1234
Content-Disposition: invalid; name="files"; filename="ab.txt"
Foo
---1234--'''.replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_line_does_not_end_with_correct_line_break(self):
- data = b'''\
+ data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
-Foo--1234--'''.replace(b"\n", b"\r\n")
+Foo--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
Content-Disposition: form-data; filename="ab.txt"
Foo
---1234--""".replace(b"\n", b"\r\n")
+--1234--""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
with ExpectLog(gen_log, "multipart/form-data value missing name"):
parse_multipart_form_data(b"1234", data, args, files)
Foo
--1234--
-""".replace(b"\n", b"\r\n")
+""".replace(
+ b"\n", b"\r\n"
+ )
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
Foo: even
more
lines
-""".replace("\n", "\r\n")
+""".replace(
+ "\n", "\r\n"
+ )
headers = HTTPHeaders.parse(data)
self.assertEqual(headers["asdf"], "qwer zxcv")
self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
self.assertEqual(headers["Foo"], "bar baz,even more lines")
self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
- self.assertEqual(sorted(list(headers.get_all())),
- [("Asdf", "qwer zxcv"),
- ("Foo", "bar baz"),
- ("Foo", "even more lines")])
+ self.assertEqual(
+ sorted(list(headers.get_all())),
+ [("Asdf", "qwer zxcv"), ("Foo", "bar baz"), ("Foo", "even more lines")],
+ )
def test_malformed_continuation(self):
# If the first line starts with whitespace, it's a
# and cpython's unicodeobject.c (which defines the implementation
# of unicode_type.splitlines(), and uses a different list than TR13).
newlines = [
- u'\u001b', # VERTICAL TAB
- u'\u001c', # FILE SEPARATOR
- u'\u001d', # GROUP SEPARATOR
- u'\u001e', # RECORD SEPARATOR
- u'\u0085', # NEXT LINE
- u'\u2028', # LINE SEPARATOR
- u'\u2029', # PARAGRAPH SEPARATOR
+ u"\u001b", # VERTICAL TAB
+ u"\u001c", # FILE SEPARATOR
+ u"\u001d", # GROUP SEPARATOR
+ u"\u001e", # RECORD SEPARATOR
+ u"\u0085", # NEXT LINE
+ u"\u2028", # LINE SEPARATOR
+ u"\u2029", # PARAGRAPH SEPARATOR
]
for newline in newlines:
# Try the utf8 and latin1 representations of each newline
- for encoding in ['utf8', 'latin1']:
+ for encoding in ["utf8", "latin1"]:
try:
try:
encoded = newline.encode(encoding)
except UnicodeEncodeError:
# Some chars cannot be represented in latin1
continue
- data = b'Cookie: foo=' + encoded + b'bar'
+ data = b"Cookie: foo=" + encoded + b"bar"
# parse() wants a native_str, so decode through latin1
# in the same way the real parser does.
- headers = HTTPHeaders.parse(
- native_str(data.decode('latin1')))
- expected = [('Cookie', 'foo=' +
- native_str(encoded.decode('latin1')) + 'bar')]
- self.assertEqual(
- expected, list(headers.get_all()))
+ headers = HTTPHeaders.parse(native_str(data.decode("latin1")))
+ expected = [
+ (
+ "Cookie",
+ "foo=" + native_str(encoded.decode("latin1")) + "bar",
+ )
+ ]
+ self.assertEqual(expected, list(headers.get_all()))
except Exception:
- gen_log.warning("failed while trying %r in %s",
- newline, encoding)
+ gen_log.warning("failed while trying %r in %s", newline, encoding)
raise
def test_optional_cr(self):
# Both CRLF and LF should be accepted as separators. CR should not be
# part of the data when followed by LF, but it is a normal char
# otherwise (or should bare CR be an error?)
- headers = HTTPHeaders.parse(
- 'CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n')
- self.assertEqual(sorted(headers.get_all()),
- [('Cr', 'cr\rMore: more'),
- ('Crlf', 'crlf'),
- ('Lf', 'lf'),
- ])
+ headers = HTTPHeaders.parse("CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n")
+ self.assertEqual(
+ sorted(headers.get_all()),
+ [("Cr", "cr\rMore: more"), ("Crlf", "crlf"), ("Lf", "lf")],
+ )
def test_copy(self):
- all_pairs = [('A', '1'), ('A', '2'), ('B', 'c')]
+ all_pairs = [("A", "1"), ("A", "2"), ("B", "c")]
h1 = HTTPHeaders()
for k, v in all_pairs:
h1.add(k, v)
for headers in [h2, h3, h4]:
# Neither the dict or its member lists are reused.
self.assertIsNot(headers, h1)
- self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
+ self.assertIsNot(headers.get_list("A"), h1.get_list("A"))
def test_pickle_roundtrip(self):
headers = HTTPHeaders()
- headers.add('Set-Cookie', 'a=b')
- headers.add('Set-Cookie', 'c=d')
- headers.add('Content-Type', 'text/html')
+ headers.add("Set-Cookie", "a=b")
+ headers.add("Set-Cookie", "c=d")
+ headers.add("Content-Type", "text/html")
pickled = pickle.dumps(headers)
unpickled = pickle.loads(pickled)
self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
def test_setdefault(self):
headers = HTTPHeaders()
- headers['foo'] = 'bar'
+ headers["foo"] = "bar"
# If a value is present, setdefault returns it without changes.
- self.assertEqual(headers.setdefault('foo', 'baz'), 'bar')
- self.assertEqual(headers['foo'], 'bar')
+ self.assertEqual(headers.setdefault("foo", "baz"), "bar")
+ self.assertEqual(headers["foo"], "bar")
# If a value is not present, setdefault sets it for future use.
- self.assertEqual(headers.setdefault('quux', 'xyzzy'), 'xyzzy')
- self.assertEqual(headers['quux'], 'xyzzy')
- self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
+ self.assertEqual(headers.setdefault("quux", "xyzzy"), "xyzzy")
+ self.assertEqual(headers["quux"], "xyzzy")
+ self.assertEqual(sorted(headers.get_all()), [("Foo", "bar"), ("Quux", "xyzzy")])
def test_string(self):
headers = HTTPHeaders()
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
TIMESTAMP = 1359312200.503611
- EXPECTED = 'Sun, 27 Jan 2013 18:43:20 GMT'
+ EXPECTED = "Sun, 27 Jan 2013 18:43:20 GMT"
def check(self, value):
self.assertEqual(format_timestamp(value), self.EXPECTED)
# All parameters are formally optional, but uri is required
# (and has been for some time). This test ensures that no
# more required parameters slip in.
- HTTPServerRequest(uri='/')
+ HTTPServerRequest(uri="/")
def test_body_is_a_byte_string(self):
- requets = HTTPServerRequest(uri='/')
+ requets = HTTPServerRequest(uri="/")
self.assertIsInstance(requets.body, bytes)
def test_repr_does_not_contain_headers(self):
- request = HTTPServerRequest(uri='/', headers=HTTPHeaders({'Canary': ['Coal Mine']}))
- self.assertTrue('Canary' not in repr(request))
+ request = HTTPServerRequest(
+ uri="/", headers=HTTPHeaders({"Canary": ["Coal Mine"]})
+ )
+ self.assertTrue("Canary" not in repr(request))
class ParseRequestStartLineTest(unittest.TestCase):
"""
Test cases copied from Python's Lib/test/test_http_cookies.py
"""
- self.assertEqual(parse_cookie('chips=ahoy; vienna=finger'),
- {'chips': 'ahoy', 'vienna': 'finger'})
+ self.assertEqual(
+ parse_cookie("chips=ahoy; vienna=finger"),
+ {"chips": "ahoy", "vienna": "finger"},
+ )
# Here parse_cookie() differs from Python's cookie parsing in that it
# treats all semicolons as delimiters, even within quotes.
self.assertEqual(
parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
- {'keebler': '"E=mc2', 'L': '\\"Loves\\"', 'fudge': '\\012', '': '"'}
+ {"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'},
)
# Illegal cookies that have an '=' char in an unquoted value.
- self.assertEqual(parse_cookie('keebler=E=mc2'), {'keebler': 'E=mc2'})
+ self.assertEqual(parse_cookie("keebler=E=mc2"), {"keebler": "E=mc2"})
# Cookies with ':' character in their name.
- self.assertEqual(parse_cookie('key:term=value:term'), {'key:term': 'value:term'})
+ self.assertEqual(
+ parse_cookie("key:term=value:term"), {"key:term": "value:term"}
+ )
# Cookies with '[' and ']'.
- self.assertEqual(parse_cookie('a=b; c=[; d=r; f=h'),
- {'a': 'b', 'c': '[', 'd': 'r', 'f': 'h'})
+ self.assertEqual(
+ parse_cookie("a=b; c=[; d=r; f=h"), {"a": "b", "c": "[", "d": "r", "f": "h"}
+ )
def test_cookie_edgecases(self):
# Cookies that RFC6265 allows.
- self.assertEqual(parse_cookie('a=b; Domain=example.com'),
- {'a': 'b', 'Domain': 'example.com'})
+ self.assertEqual(
+ parse_cookie("a=b; Domain=example.com"), {"a": "b", "Domain": "example.com"}
+ )
# parse_cookie() has historically kept only the last cookie with the
# same name.
- self.assertEqual(parse_cookie('a=b; h=i; a=c'), {'a': 'c', 'h': 'i'})
+ self.assertEqual(parse_cookie("a=b; h=i; a=c"), {"a": "c", "h": "i"})
def test_invalid_cookies(self):
"""
"""
# Chunks without an equals sign appear as unnamed values per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
- self.assertIn('django_language',
- parse_cookie('abc=def; unnamed; django_language=en').keys())
+ self.assertIn(
+ "django_language",
+ parse_cookie("abc=def; unnamed; django_language=en").keys(),
+ )
# Even a double quote may be an unamed value.
- self.assertEqual(parse_cookie('a=b; "; c=d'), {'a': 'b', '': '"', 'c': 'd'})
+ self.assertEqual(parse_cookie('a=b; "; c=d'), {"a": "b", "": '"', "c": "d"})
# Spaces in names and values, and an equals sign in values.
- self.assertEqual(parse_cookie('a b c=d e = f; gh=i'), {'a b c': 'd e = f', 'gh': 'i'})
+ self.assertEqual(
+ parse_cookie("a b c=d e = f; gh=i"), {"a b c": "d e = f", "gh": "i"}
+ )
# More characters the spec forbids.
- self.assertEqual(parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
- {'a b,c<>@:/[]?{}': 'd " =e,f g'})
+ self.assertEqual(
+ parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
+ {"a b,c<>@:/[]?{}": 'd " =e,f g'},
+ )
# Unicode characters. The spec only allows ASCII.
- self.assertEqual(parse_cookie('saint=André Bessette'),
- {'saint': native_str('André Bessette')})
+ self.assertEqual(
+ parse_cookie("saint=André Bessette"),
+ {"saint": native_str("André Bessette")},
+ )
# Browsers don't send extra whitespace or semicolons in Cookie headers,
# but parse_cookie() should parse whitespace the same way
# document.cookie parses whitespace.
- self.assertEqual(parse_cookie(' = b ; ; = ; c = ; '), {'': 'b', 'c': ''})
+ self.assertEqual(
+ parse_cookie(" = b ; ; = ; c = ; "), {"": "b", "c": ""}
+ )
import tornado.ioloop
import tornado.gen
import tornado.util
+
self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)
from tornado.test.util import skipIfNonUnix, skipOnTravis
import typing
+
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
self.io_loop.add_callback(callback)
# Store away the time so we can check if we woke up immediately
self.start_time = time.time()
+
self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
self.wait()
self.assertAlmostEqual(time.time(), self.start_time, places=2)
time.sleep(0.01)
self.stop_time = time.time()
self.io_loop.add_callback(self.stop)
+
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
def test_multiple_add(self):
sock, port = bind_unused_port()
try:
- self.io_loop.add_handler(sock.fileno(), lambda fd, events: None,
- IOLoop.READ)
+ self.io_loop.add_handler(
+ sock.fileno(), lambda fd, events: None, IOLoop.READ
+ )
# Attempting to add the same handler twice fails
# (with a platform-dependent exception)
- self.assertRaises(Exception, self.io_loop.add_handler,
- sock.fileno(), lambda fd, events: None,
- IOLoop.READ)
+ self.assertRaises(
+ Exception,
+ self.io_loop.add_handler,
+ sock.fileno(),
+ lambda fd, events: None,
+ IOLoop.READ,
+ )
finally:
self.io_loop.remove_handler(sock.fileno())
sock.close()
other_ioloop.start()
closing.set()
other_ioloop.close(all_fds=True)
+
other_ioloop = IOLoop()
thread = threading.Thread(target=target)
thread.start()
# difficult to test for)
client, server = socket.socketpair()
try:
+
def handler(fd, events):
self.assertEqual(events, IOLoop.READ)
self.stop()
+
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
- self.io_loop.add_timeout(self.io_loop.time() + 0.01,
- functools.partial(server.send, b'asdf'))
+ self.io_loop.add_timeout(
+ self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf")
+ )
self.wait()
self.io_loop.remove_handler(client.fileno())
finally:
# on PollIOLoop subclasses, but it should run silently on any
# implementation.
for i in range(2000):
- timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600,
- lambda: None)
+ timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600, lambda: None)
self.io_loop.remove_timeout(timeout)
# HACK: wait two IOLoop iterations for the GC to happen.
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
+
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
+
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
# This tests that all the timeout methods pass through *args correctly.
results = [] # type: List[int]
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
- self.io_loop.add_timeout(datetime.timedelta(seconds=0),
- results.append, 2)
+ self.io_loop.add_timeout(datetime.timedelta(seconds=0), results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
def close(self):
self.closed = True
self.sockobj.close()
+
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
- io_loop.add_handler(socket_wrapper, lambda fd, events: None,
- IOLoop.READ)
+ io_loop.add_handler(socket_wrapper, lambda fd, events: None, IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
conn, addr = server_sock.accept()
conn.close()
self.stop()
+
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
- client_sock.connect(('127.0.0.1', port))
+ client_sock.connect(("127.0.0.1", port))
self.wait()
self.io_loop.remove_handler(server_sock)
- self.io_loop.add_handler(server_sock.fileno(), handle_connection,
- IOLoop.READ)
+ self.io_loop.add_handler(server_sock.fileno(), handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
- client_sock.connect(('127.0.0.1', port))
+ client_sock.connect(("127.0.0.1", port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
def f(fd, events):
pass
+
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
except Exception:
got_exception[0] = True
self.stop()
+
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
+
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1 / 0
+
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_native_coro(self):
"""The IOLoop examines exceptions from awaitables and logs them."""
+
async def callback():
# Stop the IOLoop two iterations after raising an exception
# to give the exception time to be logged.
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
1 / 0
+
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
- client.send(b'abc')
- server.send(b'abc')
+ client.send(b"abc")
+ server.send(b"abc")
# After reading from one fd, remove the other from the IOLoop.
chunks = []
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
+
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.1, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
- self.assertEqual(chunks, [b'abc'])
+ self.assertEqual(chunks, [b"abc"])
finally:
client.close()
server.close()
# Starting the IOLoop makes it current, and stopping the loop
# makes it non-current. This process is repeatable.
for i in range(3):
+
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
+
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
class TestIOLoopFutures(AsyncTestCase):
def test_add_future_threads(self):
with futures.ThreadPoolExecutor(1) as pool:
+
def dummy():
pass
- self.io_loop.add_future(pool.submit(dummy),
- lambda future: self.stop(future))
+
+ self.io_loop.add_future(
+ pool.submit(dummy), lambda future: self.stop(future)
+ )
future = self.wait()
self.assertTrue(future.done())
self.assertTrue(future.result() is None)
# run in parallel.
res = yield [
IOLoop.current().run_in_executor(None, sync_func, event1, event2),
- IOLoop.current().run_in_executor(None, sync_func, event2, event1)
+ IOLoop.current().run_in_executor(None, sync_func, event2, event1),
]
self.assertEqual([event1, event2], res)
# (simply passing the underlying concurrrent future would do that).
async def async_wrapper(self_event, other_event):
return await IOLoop.current().run_in_executor(
- None, sync_func, self_event, other_event)
+ None, sync_func, self_event, other_event
+ )
- res = yield [
- async_wrapper(event1, event2),
- async_wrapper(event2, event1)
- ]
+ res = yield [async_wrapper(event1, event2), async_wrapper(event2, event1)]
self.assertEqual([event1, event2], res)
def f():
yield gen.moment
raise gen.Return(42)
+
self.assertEqual(self.io_loop.run_sync(f), 42)
def test_async_exception(self):
def f():
yield gen.moment
1 / 0
+
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(f)
def test_current(self):
def f():
self.assertIs(IOLoop.current(), self.io_loop)
+
self.io_loop.run_sync(f)
def test_timeout(self):
@gen.coroutine
def f():
yield gen.sleep(1)
+
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
def test_native_coroutine(self):
async def f2():
await f1()
+
self.io_loop.run_sync(f2)
def test_basic(self):
pc = PeriodicCallback(self.dummy, 10000)
- self.assertEqual(self.simulate_calls(pc, [0] * 5),
- [1010, 1020, 1030, 1040, 1050])
+ self.assertEqual(
+ self.simulate_calls(pc, [0] * 5), [1010, 1020, 1030, 1040, 1050]
+ )
def test_overrun(self):
# If a call runs for too long, we skip entire cycles to get
# back on schedule.
call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
expected = [
- 1010, 1020, 1030, # first 3 calls on schedule
- 1050, 1070, # next 2 delayed one cycle
- 1100, 1130, # next 2 delayed 2 cycles
- 1170, 1210, # next 2 delayed 3 cycles
- 1220, 1230, # then back on schedule.
+ 1010,
+ 1020,
+ 1030, # first 3 calls on schedule
+ 1050,
+ 1070, # next 2 delayed one cycle
+ 1100,
+ 1130, # next 2 delayed 2 cycles
+ 1170,
+ 1210, # next 2 delayed 3 cycles
+ 1220,
+ 1230, # then back on schedule.
]
pc = PeriodicCallback(self.dummy, 10000)
- self.assertEqual(self.simulate_calls(pc, call_durations),
- expected)
+ self.assertEqual(self.simulate_calls(pc, call_durations), expected)
def test_clock_backwards(self):
pc = PeriodicCallback(self.dummy, 10000)
# slightly slow schedule (although we assume that when
# time.time() and time.monotonic() are different, time.time()
# is getting adjusted by NTP and is therefore more accurate)
- self.assertEqual(self.simulate_calls(pc, [-2, -1, -3, -2, 0]),
- [1010, 1020, 1030, 1040, 1050])
+ self.assertEqual(
+ self.simulate_calls(pc, [-2, -1, -3, -2, 0]), [1010, 1020, 1030, 1040, 1050]
+ )
# For big jumps, we should perhaps alter the schedule, but we
# don't currently. This trace shows that we run callbacks
# every 10s of time.time(), but the first and second calls are
# 110s of real time apart because the backwards jump is
# ignored.
- self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]),
- [1010, 1020, 1030])
+ self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]), [1010, 1020, 1030])
def test_jitter(self):
random_times = [0.5, 1, 0, 0.75]
def mock_random():
return random_times.pop(0)
- with mock.patch('random.random', mock_random):
- self.assertEqual(self.simulate_calls(pc, call_durations),
- expected)
+
+ with mock.patch("random.random", mock_random):
+ self.assertEqual(self.simulate_calls(pc, call_durations), expected)
class TestIOLoopConfiguration(unittest.TestCase):
def run_python(self, *statements):
stmt_list = [
- 'from tornado.ioloop import IOLoop',
- 'classname = lambda x: x.__class__.__name__',
+ "from tornado.ioloop import IOLoop",
+ "classname = lambda x: x.__class__.__name__",
] + list(statements)
- args = [sys.executable, '-c', '; '.join(stmt_list)]
+ args = [sys.executable, "-c", "; ".join(stmt_list)]
return native_str(subprocess.check_output(args)).strip()
def test_default(self):
# When asyncio is available, it is used by default.
- cls = self.run_python('print(classname(IOLoop.current()))')
- self.assertEqual(cls, 'AsyncIOMainLoop')
- cls = self.run_python('print(classname(IOLoop()))')
- self.assertEqual(cls, 'AsyncIOLoop')
+ cls = self.run_python("print(classname(IOLoop.current()))")
+ self.assertEqual(cls, "AsyncIOMainLoop")
+ cls = self.run_python("print(classname(IOLoop()))")
+ self.assertEqual(cls, "AsyncIOLoop")
def test_asyncio(self):
cls = self.run_python(
'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
- 'print(classname(IOLoop.current()))')
- self.assertEqual(cls, 'AsyncIOMainLoop')
+ "print(classname(IOLoop.current()))",
+ )
+ self.assertEqual(cls, "AsyncIOMainLoop")
def test_asyncio_main(self):
cls = self.run_python(
- 'from tornado.platform.asyncio import AsyncIOMainLoop',
- 'AsyncIOMainLoop().install()',
- 'print(classname(IOLoop.current()))')
- self.assertEqual(cls, 'AsyncIOMainLoop')
+ "from tornado.platform.asyncio import AsyncIOMainLoop",
+ "AsyncIOMainLoop().install()",
+ "print(classname(IOLoop.current()))",
+ )
+ self.assertEqual(cls, "AsyncIOMainLoop")
if __name__ == "__main__":
from tornado.concurrent import Future
from tornado import gen
from tornado import netutil
-from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError, _StreamBuffer
+from tornado.iostream import (
+ IOStream,
+ SSLIOStream,
+ PipeIOStream,
+ StreamClosedError,
+ _StreamBuffer,
+)
from tornado.httputil import HTTPHeaders
from tornado.locks import Condition, Event
from tornado.log import gen_log
from tornado.netutil import ssl_wrap_socket
from tornado.tcpserver import TCPServer
-from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test # noqa: E501
+from tornado.testing import (
+ AsyncHTTPTestCase,
+ AsyncHTTPSTestCase,
+ AsyncTestCase,
+ bind_unused_port,
+ ExpectLog,
+ gen_test,
+)
from tornado.test.util import skipIfNonUnix, refusing_port, skipPypy3V58
from tornado.web import RequestHandler, Application
import errno
def _server_ssl_options():
return dict(
- certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
- keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
+ certfile=os.path.join(os.path.dirname(__file__), "test.crt"),
+ keyfile=os.path.join(os.path.dirname(__file__), "test.key"),
)
raise NotImplementedError()
def get_app(self):
- return Application([('/', HelloHandler)])
+ return Application([("/", HelloHandler)])
def test_connection_closed(self):
# When a server sends a response and then closes the connection,
@gen_test
def test_read_until_close(self):
stream = self._make_client_iostream()
- yield stream.connect(('127.0.0.1', self.get_http_port()))
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
stream.write(b"GET / HTTP/1.0\r\n\r\n")
data = yield stream.read_until_close()
def test_future_interface(self):
"""Basic test of IOStream's ability to return Futures."""
stream = self._make_client_iostream()
- connect_result = yield stream.connect(
- ("127.0.0.1", self.get_http_port()))
+ connect_result = yield stream.connect(("127.0.0.1", self.get_http_port()))
self.assertIs(connect_result, stream)
yield stream.write(b"GET / HTTP/1.0\r\n\r\n")
first_line = yield stream.read_until(b"\r\n")
self.assertEqual(first_line, b"HTTP/1.1 200 OK\r\n")
# callback=None is equivalent to no callback.
header_data = yield stream.read_until(b"\r\n\r\n")
- headers = HTTPHeaders.parse(header_data.decode('latin1'))
- content_length = int(headers['Content-Length'])
+ headers = HTTPHeaders.parse(header_data.decode("latin1"))
+ content_length = int(headers["Content-Length"])
body = yield stream.read_bytes(content_length)
- self.assertEqual(body, b'Hello')
+ self.assertEqual(body, b"Hello")
stream.close()
@gen_test
# Attempting to write zero bytes should run the callback without
# going into an infinite loop.
rs, ws = yield self.make_iostream_pair()
- yield ws.write(b'')
+ yield ws.write(b"")
ws.close()
rs.close()
# This test fails on pypy with ssl. I think it's because
# pypy's gc defeats moves objects, breaking the
# "frozen write buffer" assumption.
- if (isinstance(rs, SSLIOStream) and
- platform.python_implementation() == 'PyPy'):
- raise unittest.SkipTest(
- "pypy gc causes problems with openssl")
+ if (
+ isinstance(rs, SSLIOStream)
+ and platform.python_implementation() == "PyPy"
+ ):
+ raise unittest.SkipTest("pypy gc causes problems with openssl")
NUM_KB = 4096
for i in range(NUM_KB):
ws.write(b"A" * 1024)
def close_callback():
closed[0] = True
cond.notify()
+
rs.set_close_callback(close_callback)
try:
- ws.write(b'a')
+ ws.write(b"a")
res = yield rs.read_bytes(1)
- self.assertEqual(res, b'a')
+ self.assertEqual(res, b"a")
self.assertFalse(closed[0])
ws.close()
yield cond.wait()
# Partial reads won't return an empty string, but read_bytes(0)
# will.
data = yield rs.read_bytes(0, partial=True)
- self.assertEqual(data, b'')
+ self.assertEqual(data, b"")
finally:
ws.close()
rs.close()
def sleep_some():
self.io_loop.run_sync(lambda: gen.sleep(0.05))
+
try:
buf = bytearray(10)
fut = rs.read_into(buf)
server_stream_fut = Future() # type: Future[IOStream]
def accept_callback(connection, address):
- server_stream_fut.set_result(self._make_server_iostream(connection, **kwargs))
+ server_stream_fut.set_result(
+ self._make_server_iostream(connection, **kwargs)
+ )
netutil.add_accept_handler(listener, accept_callback)
client_stream = self._make_client_iostream(socket.socket(), **kwargs)
- connect_fut = client_stream.connect(('127.0.0.1', port))
+ connect_fut = client_stream.connect(("127.0.0.1", port))
server_stream, client_stream = yield [server_stream_fut, connect_fut]
self.io_loop.remove_handler(listener.fileno())
listener.close()
yield stream.connect(("127.0.0.1", port))
self.assertTrue(isinstance(stream.error, socket.error), stream.error)
- if sys.platform != 'cygwin':
+ if sys.platform != "cygwin":
_ERRNO_CONNREFUSED = [errno.ECONNREFUSED]
if hasattr(errno, "WSAECONNREFUSED"):
_ERRNO_CONNREFUSED.append(errno.WSAECONNREFUSED) # type: ignore
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
stream = IOStream(s)
stream.set_close_callback(self.stop)
- with mock.patch('socket.socket.connect',
- side_effect=socket.gaierror(errno.EIO, 'boom')):
+ with mock.patch(
+ "socket.socket.connect", side_effect=socket.gaierror(errno.EIO, "boom")
+ ):
with self.assertRaises(StreamClosedError):
- yield stream.connect(('localhost', 80))
+ yield stream.connect(("localhost", 80))
self.assertTrue(isinstance(stream.error, socket.gaierror))
@gen_test
def test_read_until_close_with_error(self):
server, client = yield self.make_iostream_pair()
try:
- with mock.patch('tornado.iostream.BaseIOStream._try_inline_read',
- side_effect=IOError('boom')):
- with self.assertRaisesRegexp(IOError, 'boom'):
+ with mock.patch(
+ "tornado.iostream.BaseIOStream._try_inline_read",
+ side_effect=IOError("boom"),
+ ):
+ with self.assertRaisesRegexp(IOError, "boom"):
client.read_until_close()
finally:
server.close()
try:
# Start a read that will be fulfilled asynchronously.
server.read_bytes(1)
- client.write(b'a')
+ client.write(b"a")
# Stub out read_from_fd to make it fail.
def fake_read_from_fd():
os.close(server.socket.fileno())
server.__class__.read_from_fd(server)
+
server.read_from_fd = fake_read_from_fd
# This log message is from _handle_read (not read_from_fd).
with ExpectLog(gen_log, "error on read"):
@gen.coroutine
def produce():
- data = b'x' * m
+ data = b"x" * m
for i in range(n):
yield server.write(data)
class TestIOStreamWebHTTPS(TestIOStreamWebMixin, AsyncHTTPSTestCase):
def _make_client_iostream(self):
- return SSLIOStream(socket.socket(),
- ssl_options=dict(cert_reqs=ssl.CERT_NONE))
+ return SSLIOStream(socket.socket(), ssl_options=dict(cert_reqs=ssl.CERT_NONE))
class TestIOStream(TestIOStreamMixin, AsyncTestCase):
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
- connection = ssl.wrap_socket(connection,
- server_side=True,
- do_handshake_on_connect=False,
- **_server_ssl_options())
+ connection = ssl.wrap_socket(
+ connection,
+ server_side=True,
+ do_handshake_on_connect=False,
+ **_server_ssl_options()
+ )
return SSLIOStream(connection, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
- return SSLIOStream(connection,
- ssl_options=dict(cert_reqs=ssl.CERT_NONE),
- **kwargs)
+ return SSLIOStream(
+ connection, ssl_options=dict(cert_reqs=ssl.CERT_NONE), **kwargs
+ )
# This will run some tests that are basically redundant but it's the
def _make_server_iostream(self, connection, **kwargs):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.load_cert_chain(
- os.path.join(os.path.dirname(__file__), 'test.crt'),
- os.path.join(os.path.dirname(__file__), 'test.key'))
- connection = ssl_wrap_socket(connection, context,
- server_side=True,
- do_handshake_on_connect=False)
+ os.path.join(os.path.dirname(__file__), "test.crt"),
+ os.path.join(os.path.dirname(__file__), "test.key"),
+ )
+ connection = ssl_wrap_socket(
+ connection, context, server_side=True, do_handshake_on_connect=False
+ )
return SSLIOStream(connection, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
self.server_accepted = Future() # type: Future[None]
netutil.add_accept_handler(self.listener, self.accept)
self.client_stream = IOStream(socket.socket())
- self.io_loop.add_future(self.client_stream.connect(
- ('127.0.0.1', self.port)), self.stop)
+ self.io_loop.add_future(
+ self.client_stream.connect(("127.0.0.1", self.port)), self.stop
+ )
self.wait()
self.io_loop.add_future(self.server_accepted, self.stop)
self.wait()
# up and in python 3.4 and up.
server_future = self.server_start_tls(_server_ssl_options())
client_future = self.client_start_tls(
- ssl.create_default_context(),
- server_hostname='127.0.0.1')
+ ssl.create_default_context(), server_hostname="127.0.0.1"
+ )
with ExpectLog(gen_log, "SSL Error"):
with self.assertRaises(ssl.SSLError):
# The client fails to connect with an SSL error.
server = server_cls(ssl_options=_server_ssl_options())
server.add_socket(sock)
- client = SSLIOStream(socket.socket(),
- ssl_options=dict(cert_reqs=ssl.CERT_NONE))
- yield client.connect(('127.0.0.1', port))
+ client = SSLIOStream(
+ socket.socket(), ssl_options=dict(cert_reqs=ssl.CERT_NONE)
+ )
+ yield client.connect(("127.0.0.1", port))
self.assertIsNotNone(client.socket.cipher())
finally:
if server is not None:
ws.write(b"hel")
ws.write(b"lo world")
- data = yield rs.read_until(b' ')
+ data = yield rs.read_until(b" ")
self.assertEqual(data, b"hello ")
data = yield rs.read_bytes(3)
if isinstance(b, (bytes, bytearray)):
return bytes(b)
elif isinstance(b, memoryview):
- return b.tobytes() # For py2
+ return b.tobytes() # For py2
else:
raise TypeError(b)
size = 1
while size < 2 * len(expected):
got = self.to_bytes(buf.peek(size))
- self.assertTrue(got) # Not empty
+ self.assertTrue(got) # Not empty
self.assertLessEqual(len(got), size)
self.assertTrue(expected.startswith(got), (expected, got))
size = (size * 3 + 1) // 2
def check_append_all_then_skip_all(self, buf, objs, input_type):
self.assertEqual(len(buf), 0)
- expected = b''
+ expected = b""
for o in objs:
expected += o
self.assertEqual(len(buf), 0)
def test_small(self):
- objs = [b'12', b'345', b'67', b'89a', b'bcde', b'fgh', b'ijklmn']
+ objs = [b"12", b"345", b"67", b"89a", b"bcde", b"fgh", b"ijklmn"]
buf = self.make_streambuffer()
self.check_append_all_then_skip_all(buf, objs, bytes)
# Test internal algorithm
buf = self.make_streambuffer(10)
for i in range(9):
- buf.append(b'x')
+ buf.append(b"x")
self.assertEqual(len(buf._buffers), 1)
for i in range(9):
- buf.append(b'x')
+ buf.append(b"x")
self.assertEqual(len(buf._buffers), 2)
buf.advance(10)
self.assertEqual(len(buf._buffers), 1)
self.assertEqual(len(buf), 0)
def test_large(self):
- objs = [b'12' * 5,
- b'345' * 2,
- b'67' * 20,
- b'89a' * 12,
- b'bcde' * 1,
- b'fgh' * 7,
- b'ijklmn' * 2]
+ objs = [
+ b"12" * 5,
+ b"345" * 2,
+ b"67" * 20,
+ b"89a" * 12,
+ b"bcde" * 1,
+ b"fgh" * 7,
+ b"ijklmn" * 2,
+ ]
buf = self.make_streambuffer()
self.check_append_all_then_skip_all(buf, objs, bytes)
# Test internal algorithm
buf = self.make_streambuffer(10)
for i in range(3):
- buf.append(b'x' * 11)
+ buf.append(b"x" * 11)
self.assertEqual(len(buf._buffers), 3)
- buf.append(b'y')
+ buf.append(b"y")
self.assertEqual(len(buf._buffers), 4)
- buf.append(b'z')
+ buf.append(b"z")
self.assertEqual(len(buf._buffers), 4)
buf.advance(33)
self.assertEqual(len(buf._buffers), 1)
class TranslationLoaderTest(unittest.TestCase):
# TODO: less hacky way to get isolated tests
- SAVE_VARS = ['_translations', '_supported_locales', '_use_gettext']
+ SAVE_VARS = ["_translations", "_supported_locales", "_use_gettext"]
def clear_locale_cache(self):
tornado.locale.Locale._cache = {}
def test_csv(self):
tornado.locale.load_translations(
- os.path.join(os.path.dirname(__file__), 'csv_translations'))
+ os.path.join(os.path.dirname(__file__), "csv_translations")
+ )
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
def test_csv_bom(self):
- with open(os.path.join(os.path.dirname(__file__), 'csv_translations',
- 'fr_FR.csv'), 'rb') as f:
+ with open(
+ os.path.join(os.path.dirname(__file__), "csv_translations", "fr_FR.csv"),
+ "rb",
+ ) as f:
char_data = to_unicode(f.read())
# Re-encode our input data (which is utf-8 without BOM) in
# encodings that use the BOM and ensure that we can still load
# it. Note that utf-16-le and utf-16-be do not write a BOM,
# so we only test whichver variant is native to our platform.
- for encoding in ['utf-8-sig', 'utf-16']:
+ for encoding in ["utf-8-sig", "utf-16"]:
tmpdir = tempfile.mkdtemp()
try:
- with open(os.path.join(tmpdir, 'fr_FR.csv'), 'wb') as f:
+ with open(os.path.join(tmpdir, "fr_FR.csv"), "wb") as f:
f.write(char_data.encode(encoding))
tornado.locale.load_translations(tmpdir)
- locale = tornado.locale.get('fr_FR')
+ locale = tornado.locale.get("fr_FR")
self.assertIsInstance(locale, tornado.locale.CSVLocale)
self.assertEqual(locale.translate("school"), u"\u00e9cole")
finally:
def test_gettext(self):
tornado.locale.load_gettext_translations(
- os.path.join(os.path.dirname(__file__), 'gettext_translations'),
- "tornado_test")
+ os.path.join(os.path.dirname(__file__), "gettext_translations"),
+ "tornado_test",
+ )
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
self.assertEqual(locale.pgettext("law", "right"), u"le droit")
self.assertEqual(locale.pgettext("good", "right"), u"le bien")
- self.assertEqual(locale.pgettext("organization", "club", "clubs", 1), u"le club")
- self.assertEqual(locale.pgettext("organization", "club", "clubs", 2), u"les clubs")
+ self.assertEqual(
+ locale.pgettext("organization", "club", "clubs", 1), u"le club"
+ )
+ self.assertEqual(
+ locale.pgettext("organization", "club", "clubs", 2), u"les clubs"
+ )
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
class LocaleDataTest(unittest.TestCase):
def test_non_ascii_name(self):
- name = tornado.locale.LOCALE_NAMES['es_LA']['name']
+ name = tornado.locale.LOCALE_NAMES["es_LA"]["name"]
self.assertTrue(isinstance(name, unicode_type))
- self.assertEqual(name, u'Espa\u00f1ol')
- self.assertEqual(utf8(name), b'Espa\xc3\xb1ol')
+ self.assertEqual(name, u"Espa\u00f1ol")
+ self.assertEqual(utf8(name), b"Espa\xc3\xb1ol")
class EnglishTest(unittest.TestCase):
def test_format_date(self):
- locale = tornado.locale.get('en_US')
+ locale = tornado.locale.get("en_US")
date = datetime.datetime(2013, 4, 28, 18, 35)
- self.assertEqual(locale.format_date(date, full_format=True),
- 'April 28, 2013 at 6:35 pm')
+ self.assertEqual(
+ locale.format_date(date, full_format=True), "April 28, 2013 at 6:35 pm"
+ )
now = datetime.datetime.utcnow()
- self.assertEqual(locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
- '2 seconds ago')
- self.assertEqual(locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
- '2 minutes ago')
- self.assertEqual(locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
- '2 hours ago')
-
- self.assertEqual(locale.format_date(now - datetime.timedelta(days=1),
- full_format=False, shorter=True), 'yesterday')
+ self.assertEqual(
+ locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
+ "2 seconds ago",
+ )
+ self.assertEqual(
+ locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
+ "2 minutes ago",
+ )
+ self.assertEqual(
+ locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
+ "2 hours ago",
+ )
+
+ self.assertEqual(
+ locale.format_date(
+ now - datetime.timedelta(days=1), full_format=False, shorter=True
+ ),
+ "yesterday",
+ )
date = now - datetime.timedelta(days=2)
- self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
- locale._weekdays[date.weekday()])
+ self.assertEqual(
+ locale.format_date(date, full_format=False, shorter=True),
+ locale._weekdays[date.weekday()],
+ )
date = now - datetime.timedelta(days=300)
- self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
- '%s %d' % (locale._months[date.month - 1], date.day))
+ self.assertEqual(
+ locale.format_date(date, full_format=False, shorter=True),
+ "%s %d" % (locale._months[date.month - 1], date.day),
+ )
date = now - datetime.timedelta(days=500)
- self.assertEqual(locale.format_date(date, full_format=False, shorter=True),
- '%s %d, %d' % (locale._months[date.month - 1], date.day, date.year))
+ self.assertEqual(
+ locale.format_date(date, full_format=False, shorter=True),
+ "%s %d, %d" % (locale._months[date.month - 1], date.day, date.year),
+ )
def test_friendly_number(self):
- locale = tornado.locale.get('en_US')
- self.assertEqual(locale.friendly_number(1000000), '1,000,000')
+ locale = tornado.locale.get("en_US")
+ self.assertEqual(locale.friendly_number(1000000), "1,000,000")
def test_list(self):
- locale = tornado.locale.get('en_US')
- self.assertEqual(locale.list([]), '')
- self.assertEqual(locale.list(['A']), 'A')
- self.assertEqual(locale.list(['A', 'B']), 'A and B')
- self.assertEqual(locale.list(['A', 'B', 'C']), 'A, B and C')
+ locale = tornado.locale.get("en_US")
+ self.assertEqual(locale.list([]), "")
+ self.assertEqual(locale.list(["A"]), "A")
+ self.assertEqual(locale.list(["A", "B"]), "A and B")
+ self.assertEqual(locale.list(["A", "B", "C"]), "A, B and C")
def test_format_day(self):
- locale = tornado.locale.get('en_US')
+ locale = tornado.locale.get("en_US")
date = datetime.datetime(2013, 4, 28, 18, 35)
- self.assertEqual(locale.format_day(date=date, dow=True), 'Sunday, April 28')
- self.assertEqual(locale.format_day(date=date, dow=False), 'April 28')
+ self.assertEqual(locale.format_day(date=date, dow=True), "Sunday, April 28")
+ self.assertEqual(locale.format_day(date=date, dow=False), "April 28")
def record_done(self, future, key):
"""Record the resolution of a Future returned by Condition.wait."""
+
def callback(_):
if not future.result():
# wait() resolved to False, meaning it timed out.
- self.history.append('timeout')
+ self.history.append("timeout")
else:
self.history.append(key)
+
future.add_done_callback(callback)
def loop_briefly(self):
def test_repr(self):
c = locks.Condition()
- self.assertIn('Condition', repr(c))
- self.assertNotIn('waiters', repr(c))
+ self.assertIn("Condition", repr(c))
+ self.assertNotIn("waiters", repr(c))
c.wait()
- self.assertIn('waiters', repr(c))
+ self.assertIn("waiters", repr(c))
@gen_test
def test_notify(self):
def test_notify_1(self):
c = locks.Condition()
- self.record_done(c.wait(), 'wait1')
- self.record_done(c.wait(), 'wait2')
+ self.record_done(c.wait(), "wait1")
+ self.record_done(c.wait(), "wait2")
c.notify(1)
self.loop_briefly()
- self.history.append('notify1')
+ self.history.append("notify1")
c.notify(1)
self.loop_briefly()
- self.history.append('notify2')
- self.assertEqual(['wait1', 'notify1', 'wait2', 'notify2'],
- self.history)
+ self.history.append("notify2")
+ self.assertEqual(["wait1", "notify1", "wait2", "notify2"], self.history)
def test_notify_n(self):
c = locks.Condition()
c.notify_all()
self.loop_briefly()
- self.history.append('notify_all')
+ self.history.append("notify_all")
# Callbacks execute in the order they were registered.
- self.assertEqual(
- list(range(4)) + ['notify_all'], # type: ignore
- self.history)
+ self.assertEqual(list(range(4)) + ["notify_all"], self.history) # type: ignore
@gen_test
def test_wait_timeout(self):
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
- self.assertEqual(['timeout'], self.history)
+ self.assertEqual(["timeout"], self.history)
c.notify(2)
yield gen.sleep(0.01)
- self.assertEqual(['timeout', 0, 2], self.history)
- self.assertEqual(['timeout', 0, 2], self.history)
+ self.assertEqual(["timeout", 0, 2], self.history)
+ self.assertEqual(["timeout", 0, 2], self.history)
c.notify()
yield
- self.assertEqual(['timeout', 0, 2, 3], self.history)
+ self.assertEqual(["timeout", 0, 2, 3], self.history)
@gen_test
def test_notify_all_with_timeout(self):
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
- self.assertEqual(['timeout'], self.history)
+ self.assertEqual(["timeout"], self.history)
c.notify_all()
yield
- self.assertEqual(['timeout', 0, 2], self.history)
+ self.assertEqual(["timeout", 0, 2], self.history)
@gen_test
def test_nested_notify(self):
class EventTest(AsyncTestCase):
def test_repr(self):
event = locks.Event()
- self.assertTrue('clear' in str(event))
- self.assertFalse('set' in str(event))
+ self.assertTrue("clear" in str(event))
+ self.assertFalse("set" in str(event))
event.set()
- self.assertFalse('clear' in str(event))
- self.assertTrue('set' in str(event))
+ self.assertFalse("clear" in str(event))
+ self.assertTrue("set" in str(event))
def test_event(self):
e = locks.Event()
def test_repr(self):
sem = locks.Semaphore()
- self.assertIn('Semaphore', repr(sem))
- self.assertIn('unlocked,value:1', repr(sem))
+ self.assertIn("Semaphore", repr(sem))
+ self.assertIn("unlocked,value:1", repr(sem))
sem.acquire()
- self.assertIn('locked', repr(sem))
- self.assertNotIn('waiters', repr(sem))
+ self.assertIn("locked", repr(sem))
+ self.assertNotIn("waiters", repr(sem))
sem.acquire()
- self.assertIn('waiters', repr(sem))
+ self.assertIn("waiters", repr(sem))
def test_acquire(self):
sem = locks.Semaphore()
async def f():
async with sem as yielded:
self.assertTrue(yielded is None)
+
yield f()
# Semaphore was released and can be acquired again.
@gen.coroutine
def f(index):
with (yield sem.acquire()):
- history.append('acquired %d' % index)
+ history.append("acquired %d" % index)
yield gen.sleep(0.01)
- history.append('release %d' % index)
+ history.append("release %d" % index)
yield [f(i) for i in range(2)]
expected_history = []
for i in range(2):
- expected_history.extend(['acquired %d' % i, 'release %d' % i])
+ expected_history.extend(["acquired %d" % i, "release %d" % i])
self.assertEqual(expected_history, history)
async def f(idx):
async with lock:
history.append(idx)
+
futures = [f(i) for i in range(N)]
lock.release()
yield futures
pass
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
@contextlib.contextmanager
def ignore_bytes_warning():
with warnings.catch_warnings():
- warnings.simplefilter('ignore', category=BytesWarning)
+ warnings.simplefilter("ignore", category=BytesWarning)
yield
# Matches the output of a single logging call (which may be multiple lines
# if a traceback was included, so we use the DOTALL option)
LINE_RE = re.compile(
- b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)")
+ b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)"
+ )
def setUp(self):
self.formatter = LogFormatter(color=False)
# variable when the tests are run, so just patch in some values
# for testing. (testing with color off fails to expose some potential
# encoding issues from the control characters)
- self.formatter._colors = {
- logging.ERROR: u"\u0001",
- }
+ self.formatter._colors = {logging.ERROR: u"\u0001"}
self.formatter._normal = u"\u0002"
# construct a Logger directly to bypass getLogger's caching
- self.logger = logging.Logger('LogFormatterTest')
+ self.logger = logging.Logger("LogFormatterTest")
self.logger.propagate = False
self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'log.out')
+ self.filename = os.path.join(self.tempdir, "log.out")
self.handler = self.make_handler(self.filename)
self.handler.setFormatter(self.formatter)
self.logger.addHandler(self.handler)
def test_bytes_exception_logging(self):
try:
- raise Exception(b'\xe9')
+ raise Exception(b"\xe9")
except Exception:
- self.logger.exception('caught exception')
+ self.logger.exception("caught exception")
# This will be "Exception: \xe9" on python 2 or
# "Exception: b'\xe9'" on python 3.
output = self.get_output()
- self.assertRegexpMatches(output, br'Exception.*\\xe9')
+ self.assertRegexpMatches(output, br"Exception.*\\xe9")
# The traceback contains newlines, which should not have been escaped.
- self.assertNotIn(br'\n', output)
+ self.assertNotIn(br"\n", output)
class UnicodeLogFormatterTest(LogFormatterTest):
super(EnablePrettyLoggingTest, self).setUp()
self.options = OptionParser()
define_logging_options(self.options)
- self.logger = logging.Logger('tornado.test.log_test.EnablePrettyLoggingTest')
+ self.logger = logging.Logger("tornado.test.log_test.EnablePrettyLoggingTest")
self.logger.propagate = False
def test_log_file(self):
tmpdir = tempfile.mkdtemp()
try:
- self.options.log_file_prefix = tmpdir + '/test_log'
+ self.options.log_file_prefix = tmpdir + "/test_log"
enable_pretty_logging(options=self.options, logger=self.logger)
self.assertEqual(1, len(self.logger.handlers))
- self.logger.error('hello')
+ self.logger.error("hello")
self.logger.handlers[0].flush()
- filenames = glob.glob(tmpdir + '/test_log*')
+ filenames = glob.glob(tmpdir + "/test_log*")
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
- self.assertRegexpMatches(f.read(), r'^\[E [^]]*\] hello$')
+ self.assertRegexpMatches(f.read(), r"^\[E [^]]*\] hello$")
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
- for filename in glob.glob(tmpdir + '/test_log*'):
+ for filename in glob.glob(tmpdir + "/test_log*"):
os.unlink(filename)
os.rmdir(tmpdir)
def test_log_file_with_timed_rotating(self):
tmpdir = tempfile.mkdtemp()
try:
- self.options.log_file_prefix = tmpdir + '/test_log'
- self.options.log_rotate_mode = 'time'
+ self.options.log_file_prefix = tmpdir + "/test_log"
+ self.options.log_rotate_mode = "time"
enable_pretty_logging(options=self.options, logger=self.logger)
- self.logger.error('hello')
+ self.logger.error("hello")
self.logger.handlers[0].flush()
- filenames = glob.glob(tmpdir + '/test_log*')
+ filenames = glob.glob(tmpdir + "/test_log*")
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
- self.assertRegexpMatches(
- f.read(),
- r'^\[E [^]]*\] hello$')
+ self.assertRegexpMatches(f.read(), r"^\[E [^]]*\] hello$")
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
- for filename in glob.glob(tmpdir + '/test_log*'):
+ for filename in glob.glob(tmpdir + "/test_log*"):
os.unlink(filename)
os.rmdir(tmpdir)
def test_wrong_rotate_mode_value(self):
try:
- self.options.log_file_prefix = 'some_path'
- self.options.log_rotate_mode = 'wrong_mode'
- self.assertRaises(ValueError, enable_pretty_logging,
- options=self.options, logger=self.logger)
+ self.options.log_file_prefix = "some_path"
+ self.options.log_rotate_mode = "wrong_mode"
+ self.assertRaises(
+ ValueError,
+ enable_pretty_logging,
+ options=self.options,
+ logger=self.logger,
+ )
finally:
for handler in self.logger.handlers:
handler.flush()
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""
+
def logs_present(self, statement, args=None):
# Each test may manipulate and/or parse the options and then logs
# a line at the 'info' level. This level is ignored in the
# logging module by default, but Tornado turns it on by default
# so it is the easiest way to tell whether tornado's logging hooks
# ran.
- IMPORT = 'from tornado.options import options, parse_command_line'
+ IMPORT = "from tornado.options import options, parse_command_line"
LOG_INFO = 'import logging; logging.info("hello")'
- program = ';'.join([IMPORT, statement, LOG_INFO])
+ program = ";".join([IMPORT, statement, LOG_INFO])
proc = subprocess.Popen(
- [sys.executable, '-c', program] + (args or []),
- stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ [sys.executable, "-c", program] + (args or []),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ )
stdout, stderr = proc.communicate()
- self.assertEqual(proc.returncode, 0, 'process failed: %r' % stdout)
- return b'hello' in stdout
+ self.assertEqual(proc.returncode, 0, "process failed: %r" % stdout)
+ return b"hello" in stdout
def test_default(self):
- self.assertFalse(self.logs_present('pass'))
+ self.assertFalse(self.logs_present("pass"))
def test_tornado_default(self):
- self.assertTrue(self.logs_present('parse_command_line()'))
+ self.assertTrue(self.logs_present("parse_command_line()"))
def test_disable_command_line(self):
- self.assertFalse(self.logs_present('parse_command_line()',
- ['--logging=none']))
+ self.assertFalse(self.logs_present("parse_command_line()", ["--logging=none"]))
def test_disable_command_line_case_insensitive(self):
- self.assertFalse(self.logs_present('parse_command_line()',
- ['--logging=None']))
+ self.assertFalse(self.logs_present("parse_command_line()", ["--logging=None"]))
def test_disable_code_string(self):
- self.assertFalse(self.logs_present(
- 'options.logging = "none"; parse_command_line()'))
+ self.assertFalse(
+ self.logs_present('options.logging = "none"; parse_command_line()')
+ )
def test_disable_code_none(self):
- self.assertFalse(self.logs_present(
- 'options.logging = None; parse_command_line()'))
+ self.assertFalse(
+ self.logs_present("options.logging = None; parse_command_line()")
+ )
def test_disable_override(self):
# command line trumps code defaults
- self.assertTrue(self.logs_present(
- 'options.logging = None; parse_command_line()',
- ['--logging=info']))
+ self.assertTrue(
+ self.logs_present(
+ "options.logging = None; parse_command_line()", ["--logging=info"]
+ )
+ )
import unittest
from tornado.netutil import (
- BlockingResolver, OverrideResolver, ThreadedResolver, is_valid_ip, bind_sockets
+ BlockingResolver,
+ OverrideResolver,
+ ThreadedResolver,
+ is_valid_ip,
+ bind_sockets,
)
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
from tornado.test.util import skipIfNoNetwork
import typing
+
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
class _ResolverTestMixin(object):
@gen_test
def test_localhost(self):
- addrinfo = yield self.resolver.resolve('localhost', 80,
- socket.AF_UNSPEC)
- self.assertIn((socket.AF_INET, ('127.0.0.1', 80)),
- addrinfo)
+ addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC)
+ self.assertIn((socket.AF_INET, ("127.0.0.1", 80)), addrinfo)
# It is impossible to quickly and consistently generate an error in name
@gen_test
def test_bad_host(self):
with self.assertRaises(IOError):
- yield self.resolver.resolve('an invalid domain', 80,
- socket.AF_UNSPEC)
+ yield self.resolver.resolve("an invalid domain", 80, socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
def setUp(self):
super(OverrideResolverTest, self).setUp()
mapping = {
- ('google.com', 80): ('1.2.3.4', 80),
- ('google.com', 80, socket.AF_INET): ('1.2.3.4', 80),
- ('google.com', 80, socket.AF_INET6): ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80)
+ ("google.com", 80): ("1.2.3.4", 80),
+ ("google.com", 80, socket.AF_INET): ("1.2.3.4", 80),
+ ("google.com", 80, socket.AF_INET6): (
+ "2a02:6b8:7c:40c:c51e:495f:e23a:3",
+ 80,
+ ),
}
self.resolver = OverrideResolver(BlockingResolver(), mapping)
@gen_test
def test_resolve_multiaddr(self):
- result = yield self.resolver.resolve('google.com', 80, socket.AF_INET)
- self.assertIn((socket.AF_INET, ('1.2.3.4', 80)), result)
+ result = yield self.resolver.resolve("google.com", 80, socket.AF_INET)
+ self.assertIn((socket.AF_INET, ("1.2.3.4", 80)), result)
- result = yield self.resolver.resolve('google.com', 80, socket.AF_INET6)
- self.assertIn((socket.AF_INET6, ('2a02:6b8:7c:40c:c51e:495f:e23a:3', 80, 0, 0)), result)
+ result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6)
+ self.assertIn(
+ (socket.AF_INET6, ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0)), result
+ )
@skipIfNoNetwork
@skipIfNoNetwork
-@unittest.skipIf(sys.platform == 'win32', "preexec_fn not available on win32")
+@unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
def test_import(self):
TIMEOUT = 5
# Test for a deadlock when importing a module that runs the
# ThreadedResolver at import-time. See resolve_test.py for
# full explanation.
- command = [
- sys.executable,
- '-c',
- 'import tornado.test.resolve_test_helper']
+ command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"]
start = time.time()
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
# test error cases here.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
-@unittest.skipIf(getattr(twisted, '__version__', '0.0') < "12.1", "old version of twisted")
+@unittest.skipIf(
+ getattr(twisted, "__version__", "0.0") < "12.1", "old version of twisted"
+)
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
class IsValidIPTest(unittest.TestCase):
def test_is_valid_ip(self):
- self.assertTrue(is_valid_ip('127.0.0.1'))
- self.assertTrue(is_valid_ip('4.4.4.4'))
- self.assertTrue(is_valid_ip('::1'))
- self.assertTrue(is_valid_ip('2620:0:1cfe:face:b00c::3'))
- self.assertTrue(not is_valid_ip('www.google.com'))
- self.assertTrue(not is_valid_ip('localhost'))
- self.assertTrue(not is_valid_ip('4.4.4.4<'))
- self.assertTrue(not is_valid_ip(' 127.0.0.1'))
- self.assertTrue(not is_valid_ip(''))
- self.assertTrue(not is_valid_ip(' '))
- self.assertTrue(not is_valid_ip('\n'))
- self.assertTrue(not is_valid_ip('\x00'))
+ self.assertTrue(is_valid_ip("127.0.0.1"))
+ self.assertTrue(is_valid_ip("4.4.4.4"))
+ self.assertTrue(is_valid_ip("::1"))
+ self.assertTrue(is_valid_ip("2620:0:1cfe:face:b00c::3"))
+ self.assertTrue(not is_valid_ip("www.google.com"))
+ self.assertTrue(not is_valid_ip("localhost"))
+ self.assertTrue(not is_valid_ip("4.4.4.4<"))
+ self.assertTrue(not is_valid_ip(" 127.0.0.1"))
+ self.assertTrue(not is_valid_ip(""))
+ self.assertTrue(not is_valid_ip(" "))
+ self.assertTrue(not is_valid_ip("\n"))
+ self.assertTrue(not is_valid_ip("\x00"))
class TestPortAllocation(unittest.TestCase):
def test_same_port_allocation(self):
- if 'TRAVIS' in os.environ:
+ if "TRAVIS" in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
- sockets = bind_sockets(0, 'localhost')
+ sockets = bind_sockets(0, "localhost")
try:
port = sockets[0].getsockname()[1]
- self.assertTrue(all(s.getsockname()[1] == port
- for s in sockets[1:]))
+ self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:]))
finally:
for sock in sockets:
sock.close()
- @unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
+ @unittest.skipIf(
+ not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported"
+ )
def test_reuse_port(self):
sockets = [] # type: List[socket.socket]
socket, port = bind_unused_port(reuse_port=True)
try:
- sockets = bind_sockets(port, '127.0.0.1', reuse_port=True)
+ sockets = bind_sockets(port, "127.0.0.1", reuse_port=True)
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
finally:
socket.close()
from tornado.test.util import subTest
import typing
+
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
class Email(object):
def __init__(self, value):
- if isinstance(value, str) and '@' in value:
+ if isinstance(value, str) and "@" in value:
self._value = value
else:
raise ValueError()
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
- options.define("username", default='foo')
+ options.define("username", default="foo")
options.define("my_path")
- config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
- "options_test.cfg")
+ config_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "options_test.cfg"
+ )
options.parse_config_file(config_path)
self.assertEqual(options.port, 443)
self.assertEqual(options.username, "李康")
def callback():
self.called = True
+
options.add_parse_callback(callback)
# non-final parse doesn't run callbacks
sub_options = OptionParser()
sub_options.define("foo", type=str)
rest = base_options.parse_command_line(
- ["main.py", "--verbose", "subcommand", "--foo=bar"])
+ ["main.py", "--verbose", "subcommand", "--foo=bar"]
+ )
self.assertEqual(rest, ["subcommand", "--foo=bar"])
self.assertTrue(base_options.verbose)
rest2 = sub_options.parse_command_line(rest)
def test_setattr(self):
options = OptionParser()
- options.define('foo', default=1, type=int)
+ options.define("foo", default=1, type=int)
options.foo = 2
self.assertEqual(options.foo, 2)
# setattr requires that options be the right type and doesn't
# parse from string formats.
options = OptionParser()
- options.define('foo', default=1, type=int)
+ options.define("foo", default=1, type=int)
with self.assertRaises(Error):
- options.foo = '2'
+ options.foo = "2"
def test_setattr_with_callback(self):
values = [] # type: List[int]
options = OptionParser()
- options.define('foo', default=1, type=int, callback=values.append)
+ options.define("foo", default=1, type=int, callback=values.append)
options.foo = 2
self.assertEqual(values, [2])
def _sample_options(self):
options = OptionParser()
- options.define('a', default=1)
- options.define('b', default=2)
+ options.define("a", default=1)
+ options.define("b", default=2)
return options
def test_iter(self):
options = self._sample_options()
# OptionParsers always define 'help'.
- self.assertEqual(set(['a', 'b', 'help']), set(iter(options)))
+ self.assertEqual(set(["a", "b", "help"]), set(iter(options)))
def test_getitem(self):
options = self._sample_options()
- self.assertEqual(1, options['a'])
+ self.assertEqual(1, options["a"])
def test_setitem(self):
options = OptionParser()
- options.define('foo', default=1, type=int)
- options['foo'] = 2
- self.assertEqual(options['foo'], 2)
+ options.define("foo", default=1, type=int)
+ options["foo"] = 2
+ self.assertEqual(options["foo"], 2)
def test_items(self):
options = self._sample_options()
# OptionParsers always define 'help'.
- expected = [('a', 1), ('b', 2), ('help', options.help)]
+ expected = [("a", 1), ("b", 2), ("help", options.help)]
actual = sorted(options.items())
self.assertEqual(expected, actual)
def test_as_dict(self):
options = self._sample_options()
- expected = {'a': 1, 'b': 2, 'help': options.help}
+ expected = {"a": 1, "b": 2, "help": options.help}
self.assertEqual(expected, options.as_dict())
def test_group_dict(self):
options = OptionParser()
- options.define('a', default=1)
- options.define('b', group='b_group', default=2)
+ options.define("a", default=1)
+ options.define("b", group="b_group", default=2)
frame = sys._getframe(0)
this_file = frame.f_code.co_filename
- self.assertEqual(set(['b_group', '', this_file]), options.groups())
+ self.assertEqual(set(["b_group", "", this_file]), options.groups())
- b_group_dict = options.group_dict('b_group')
- self.assertEqual({'b': 2}, b_group_dict)
+ b_group_dict = options.group_dict("b_group")
+ self.assertEqual({"b": 2}, b_group_dict)
- self.assertEqual({}, options.group_dict('nonexistent'))
+ self.assertEqual({}, options.group_dict("nonexistent"))
def test_mock_patch(self):
# ensure that our setattr hooks don't interfere with mock.patch
options = OptionParser()
- options.define('foo', default=1)
- options.parse_command_line(['main.py', '--foo=2'])
+ options.define("foo", default=1)
+ options.parse_command_line(["main.py", "--foo=2"])
self.assertEqual(options.foo, 2)
- with mock.patch.object(options.mockable(), 'foo', 3):
+ with mock.patch.object(options.mockable(), "foo", 3):
self.assertEqual(options.foo, 3)
self.assertEqual(options.foo, 2)
# Try nested patches mixed with explicit sets
- with mock.patch.object(options.mockable(), 'foo', 4):
+ with mock.patch.object(options.mockable(), "foo", 4):
self.assertEqual(options.foo, 4)
options.foo = 5
self.assertEqual(options.foo, 5)
- with mock.patch.object(options.mockable(), 'foo', 6):
+ with mock.patch.object(options.mockable(), "foo", 6):
self.assertEqual(options.foo, 6)
self.assertEqual(options.foo, 5)
self.assertEqual(options.foo, 2)
def _define_options(self):
options = OptionParser()
- options.define('str', type=str)
- options.define('basestring', type=basestring_type)
- options.define('int', type=int)
- options.define('float', type=float)
- options.define('datetime', type=datetime.datetime)
- options.define('timedelta', type=datetime.timedelta)
- options.define('email', type=Email)
- options.define('list-of-int', type=int, multiple=True)
+ options.define("str", type=str)
+ options.define("basestring", type=basestring_type)
+ options.define("int", type=int)
+ options.define("float", type=float)
+ options.define("datetime", type=datetime.datetime)
+ options.define("timedelta", type=datetime.timedelta)
+ options.define("email", type=Email)
+ options.define("list-of-int", type=int, multiple=True)
return options
def _check_options_values(self, options):
- self.assertEqual(options.str, 'asdf')
- self.assertEqual(options.basestring, 'qwer')
+ self.assertEqual(options.str, "asdf")
+ self.assertEqual(options.basestring, "qwer")
self.assertEqual(options.int, 42)
self.assertEqual(options.float, 1.5)
- self.assertEqual(options.datetime,
- datetime.datetime(2013, 4, 28, 5, 16))
+ self.assertEqual(options.datetime, datetime.datetime(2013, 4, 28, 5, 16))
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
- self.assertEqual(options.email.value, 'tornado@web.com')
+ self.assertEqual(options.email.value, "tornado@web.com")
self.assertTrue(isinstance(options.email, Email))
self.assertEqual(options.list_of_int, [1, 2, 3])
def test_types(self):
options = self._define_options()
- options.parse_command_line(['main.py',
- '--str=asdf',
- '--basestring=qwer',
- '--int=42',
- '--float=1.5',
- '--datetime=2013-04-28 05:16',
- '--timedelta=45s',
- '--email=tornado@web.com',
- '--list-of-int=1,2,3'])
+ options.parse_command_line(
+ [
+ "main.py",
+ "--str=asdf",
+ "--basestring=qwer",
+ "--int=42",
+ "--float=1.5",
+ "--datetime=2013-04-28 05:16",
+ "--timedelta=45s",
+ "--email=tornado@web.com",
+ "--list-of-int=1,2,3",
+ ]
+ )
self._check_options_values(options)
def test_types_with_conf_file(self):
- for config_file_name in ("options_test_types.cfg",
- "options_test_types_str.cfg"):
+ for config_file_name in (
+ "options_test_types.cfg",
+ "options_test_types_str.cfg",
+ ):
options = self._define_options()
- options.parse_config_file(os.path.join(os.path.dirname(__file__),
- config_file_name))
+ options.parse_config_file(
+ os.path.join(os.path.dirname(__file__), config_file_name)
+ )
self._check_options_values(options)
def test_multiple_string(self):
options = OptionParser()
- options.define('foo', type=str, multiple=True)
- options.parse_command_line(['main.py', '--foo=a,b,c'])
- self.assertEqual(options.foo, ['a', 'b', 'c'])
+ options.define("foo", type=str, multiple=True)
+ options.parse_command_line(["main.py", "--foo=a,b,c"])
+ self.assertEqual(options.foo, ["a", "b", "c"])
def test_multiple_int(self):
options = OptionParser()
- options.define('foo', type=int, multiple=True)
- options.parse_command_line(['main.py', '--foo=1,3,5:7'])
+ options.define("foo", type=int, multiple=True)
+ options.parse_command_line(["main.py", "--foo=1,3,5:7"])
self.assertEqual(options.foo, [1, 3, 5, 6, 7])
def test_error_redefine(self):
options = OptionParser()
- options.define('foo')
+ options.define("foo")
with self.assertRaises(Error) as cm:
- options.define('foo')
- self.assertRegexpMatches(str(cm.exception),
- 'Option.*foo.*already defined')
+ options.define("foo")
+ self.assertRegexpMatches(str(cm.exception), "Option.*foo.*already defined")
def test_error_redefine_underscore(self):
# Ensure that the dash/underscore normalization doesn't
# interfere with the redefinition error.
tests = [
- ('foo-bar', 'foo-bar'),
- ('foo_bar', 'foo_bar'),
- ('foo-bar', 'foo_bar'),
- ('foo_bar', 'foo-bar'),
+ ("foo-bar", "foo-bar"),
+ ("foo_bar", "foo_bar"),
+ ("foo-bar", "foo_bar"),
+ ("foo_bar", "foo-bar"),
]
for a, b in tests:
with subTest(self, a=a, b=b):
options.define(a)
with self.assertRaises(Error) as cm:
options.define(b)
- self.assertRegexpMatches(str(cm.exception),
- 'Option.*foo.bar.*already defined')
+ self.assertRegexpMatches(
+ str(cm.exception), "Option.*foo.bar.*already defined"
+ )
def test_dash_underscore_cli(self):
# Dashes and underscores should be interchangeable.
- for defined_name in ['foo-bar', 'foo_bar']:
- for flag in ['--foo-bar=a', '--foo_bar=a']:
+ for defined_name in ["foo-bar", "foo_bar"]:
+ for flag in ["--foo-bar=a", "--foo_bar=a"]:
options = OptionParser()
options.define(defined_name)
- options.parse_command_line(['main.py', flag])
+ options.parse_command_line(["main.py", flag])
# Attr-style access always uses underscores.
- self.assertEqual(options.foo_bar, 'a')
+ self.assertEqual(options.foo_bar, "a")
# Dict-style access allows both.
- self.assertEqual(options['foo-bar'], 'a')
- self.assertEqual(options['foo_bar'], 'a')
+ self.assertEqual(options["foo-bar"], "a")
+ self.assertEqual(options["foo_bar"], "a")
def test_dash_underscore_file(self):
# No matter how an option was defined, it can be set with underscores
# in a config file.
- for defined_name in ['foo-bar', 'foo_bar']:
+ for defined_name in ["foo-bar", "foo_bar"]:
options = OptionParser()
options.define(defined_name)
- options.parse_config_file(os.path.join(os.path.dirname(__file__),
- "options_test.cfg"))
- self.assertEqual(options.foo_bar, 'a')
+ options.parse_config_file(
+ os.path.join(os.path.dirname(__file__), "options_test.cfg")
+ )
+ self.assertEqual(options.foo_bar, "a")
def test_dash_underscore_introspection(self):
# Original names are preserved in introspection APIs.
options = OptionParser()
- options.define('with-dash', group='g')
- options.define('with_underscore', group='g')
- all_options = ['help', 'with-dash', 'with_underscore']
+ options.define("with-dash", group="g")
+ options.define("with_underscore", group="g")
+ all_options = ["help", "with-dash", "with_underscore"]
self.assertEqual(sorted(options), all_options)
self.assertEqual(sorted(k for (k, v) in options.items()), all_options)
self.assertEqual(sorted(options.as_dict().keys()), all_options)
- self.assertEqual(sorted(options.group_dict('g')),
- ['with-dash', 'with_underscore'])
+ self.assertEqual(
+ sorted(options.group_dict("g")), ["with-dash", "with_underscore"]
+ )
# --help shows CLI-style names with dashes.
buf = StringIO()
options.print_help(buf)
- self.assertIn('--with-dash', buf.getvalue())
- self.assertIn('--with-underscore', buf.getvalue())
+ self.assertIn("--with-dash", buf.getvalue())
+ self.assertIn("--with-underscore", buf.getvalue())
# exception handler doesn't catch it
os._exit(int(self.get_argument("exit")))
if self.get_argument("signal", None):
- os.kill(os.getpid(),
- int(self.get_argument("signal")))
+ os.kill(os.getpid(), int(self.get_argument("signal")))
self.write(str(os.getpid()))
+
return Application([("/", ProcessHandler)])
def tearDown(self):
# reactor and don't restore it to a sane state after the fork
# (asyncio has the same issue, but we have a special case in
# place for it).
- with ExpectLog(gen_log, "(Starting .* processes|child .* exited|uncaught exception)"):
+ with ExpectLog(
+ gen_log, "(Starting .* processes|child .* exited|uncaught exception)"
+ ):
sock, port = bind_unused_port()
def get_url(path):
return "http://127.0.0.1:%d%s" % (port, path)
+
# ensure that none of these processes live too long
signal.alarm(5) # master process
try:
@gen_test
def test_subprocess(self):
- if IOLoop.configured_class().__name__.endswith('LayeredTwistedIOLoop'):
+ if IOLoop.configured_class().__name__.endswith("LayeredTwistedIOLoop"):
# This test fails non-deterministically with LayeredTwistedIOLoop.
# (the read_until('\n') returns '\n' instead of 'hello\n')
# This probably indicates a problem with either TornadoReactor
# or TwistedIOLoop, but I haven't been able to track it down
# and for now this is just causing spurious travis-ci failures.
- raise unittest.SkipTest("Subprocess tests not compatible with "
- "LayeredTwistedIOLoop")
- subproc = Subprocess([sys.executable, '-u', '-i'],
- stdin=Subprocess.STREAM,
- stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
+ raise unittest.SkipTest(
+ "Subprocess tests not compatible with " "LayeredTwistedIOLoop"
+ )
+ subproc = Subprocess(
+ [sys.executable, "-u", "-i"],
+ stdin=Subprocess.STREAM,
+ stdout=Subprocess.STREAM,
+ stderr=subprocess.STDOUT,
+ )
self.addCleanup(lambda: self.term_and_wait(subproc))
self.addCleanup(subproc.stdout.close)
self.addCleanup(subproc.stdin.close)
- yield subproc.stdout.read_until(b'>>> ')
+ yield subproc.stdout.read_until(b">>> ")
subproc.stdin.write(b"print('hello')\n")
- data = yield subproc.stdout.read_until(b'\n')
+ data = yield subproc.stdout.read_until(b"\n")
self.assertEqual(data, b"hello\n")
yield subproc.stdout.read_until(b">>> ")
@gen_test
def test_close_stdin(self):
# Close the parent's stdin handle and see that the child recognizes it.
- subproc = Subprocess([sys.executable, '-u', '-i'],
- stdin=Subprocess.STREAM,
- stdout=Subprocess.STREAM, stderr=subprocess.STDOUT)
+ subproc = Subprocess(
+ [sys.executable, "-u", "-i"],
+ stdin=Subprocess.STREAM,
+ stdout=Subprocess.STREAM,
+ stderr=subprocess.STDOUT,
+ )
self.addCleanup(lambda: self.term_and_wait(subproc))
- yield subproc.stdout.read_until(b'>>> ')
+ yield subproc.stdout.read_until(b">>> ")
subproc.stdin.close()
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"\n")
def test_stderr(self):
# This test is mysteriously flaky on twisted: it succeeds, but logs
# an error of EBADF on closing a file descriptor.
- subproc = Subprocess([sys.executable, '-u', '-c',
- r"import sys; sys.stderr.write('hello\n')"],
- stderr=Subprocess.STREAM)
+ subproc = Subprocess(
+ [sys.executable, "-u", "-c", r"import sys; sys.stderr.write('hello\n')"],
+ stderr=Subprocess.STREAM,
+ )
self.addCleanup(lambda: self.term_and_wait(subproc))
- data = yield subproc.stderr.read_until(b'\n')
- self.assertEqual(data, b'hello\n')
+ data = yield subproc.stderr.read_until(b"\n")
+ self.assertEqual(data, b"hello\n")
# More mysterious EBADF: This fails if done with self.addCleanup instead of here.
subproc.stderr.close()
def test_sigchild(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
- subproc = Subprocess([sys.executable, '-c', 'pass'])
+ subproc = Subprocess([sys.executable, "-c", "pass"])
subproc.set_exit_callback(self.stop)
ret = self.wait()
self.assertEqual(ret, 0)
def test_sigchild_future(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
- subproc = Subprocess([sys.executable, '-c', 'pass'])
+ subproc = Subprocess([sys.executable, "-c", "pass"])
ret = yield subproc.wait_for_exit()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
- subproc = Subprocess([sys.executable, '-c',
- 'import time; time.sleep(30)'],
- stdout=Subprocess.STREAM)
+ subproc = Subprocess(
+ [sys.executable, "-c", "import time; time.sleep(30)"],
+ stdout=Subprocess.STREAM,
+ )
self.addCleanup(subproc.stdout.close)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
except AssertionError:
raise AssertionError("subprocess failed to terminate")
else:
- raise AssertionError("subprocess closed stdout but failed to "
- "get termination signal")
+ raise AssertionError(
+ "subprocess closed stdout but failed to " "get termination signal"
+ )
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)
def test_wait_for_exit_raise(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
- subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
+ subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"])
with self.assertRaises(subprocess.CalledProcessError) as cm:
yield subproc.wait_for_exit()
self.assertEqual(cm.exception.returncode, 1)
def test_wait_for_exit_raise_disabled(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
- subproc = Subprocess([sys.executable, '-c', 'import sys; sys.exit(1)'])
+ subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"])
ret = yield subproc.wait_for_exit(raise_error=False)
self.assertEqual(ret, 1)
q.get()
for q_str in repr(q), str(q):
- self.assertTrue(q_str.startswith('<Queue'))
- self.assertIn('maxsize=1', q_str)
- self.assertIn('getters[1]', q_str)
- self.assertNotIn('putters', q_str)
- self.assertNotIn('tasks', q_str)
+ self.assertTrue(q_str.startswith("<Queue"))
+ self.assertIn("maxsize=1", q_str)
+ self.assertIn("getters[1]", q_str)
+ self.assertNotIn("putters", q_str)
+ self.assertNotIn("tasks", q_str)
q.put(None)
q.put(None)
q.put(None)
for q_str in repr(q), str(q):
- self.assertNotIn('getters', q_str)
- self.assertIn('putters[1]', q_str)
- self.assertIn('tasks=2', q_str)
+ self.assertNotIn("getters", q_str)
+ self.assertIn("putters[1]", q_str)
+ self.assertIn("tasks=2", q_str)
def test_order(self):
q = queues.Queue() # type: queues.Queue[int]
results.append(i)
if i == 4:
return results
+
results = yield f()
self.assertEqual(results, list(range(5)))
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
- q.put_nowait((1, 'a'))
- q.put_nowait((0, 'b'))
+ q.put_nowait((1, "a"))
+ q.put_nowait((0, "b"))
self.assertTrue(q.full())
- q.put((3, 'c'))
- q.put((2, 'd'))
- self.assertEqual((0, 'b'), q.get_nowait())
- self.assertEqual((1, 'a'), (yield q.get()))
- self.assertEqual((2, 'd'), q.get_nowait())
- self.assertEqual((3, 'c'), (yield q.get()))
+ q.put((3, "c"))
+ q.put((2, "d"))
+ self.assertEqual((0, "b"), q.get_nowait())
+ self.assertEqual((1, "a"), (yield q.get()))
+ self.assertEqual((2, "d"), q.get_nowait())
+ self.assertEqual((3, "c"), (yield q.get()))
self.assertTrue(q.empty())
self.assertEqual(list(range(10)), history)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
# this deadlock.
resolver = ThreadedResolver()
-IOLoop.current().run_sync(lambda: resolver.resolve(u'localhost', 80))
+IOLoop.current().run_sync(lambda: resolver.resolve(u"localhost", 80))
# License for the specific language governing permissions and limitations
# under the License.
-from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine # noqa: E501
-from tornado.routing import HostMatches, PathMatches, ReversibleRouter, Router, Rule, RuleRouter
+from tornado.httputil import (
+ HTTPHeaders,
+ HTTPMessageDelegate,
+ HTTPServerConnectionDelegate,
+ ResponseStartLine,
+)
+from tornado.routing import (
+ HostMatches,
+ PathMatches,
+ ReversibleRouter,
+ Router,
+ Rule,
+ RuleRouter,
+)
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, HTTPError, RequestHandler
from tornado.wsgi import WSGIContainer
class BasicRouter(Router):
def find_handler(self, request, **kwargs):
-
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
- b"OK"
+ b"OK",
)
self.connection.finish()
return app.get_handler_delegate(request, handler)
def reverse_url(self, name, *args):
- handler_path = '/' + name
+ handler_path = "/" + name
return handler_path if handler_path in self.routes else None
app1 = CustomApplication(app_name="app1")
app2 = CustomApplication(app_name="app2")
- router.add_routes({
- "/first_handler": (app1, FirstHandler),
- "/second_handler": (app2, SecondHandler),
- "/first_handler_second_app": (app2, FirstHandler),
- })
+ router.add_routes(
+ {
+ "/first_handler": (app1, FirstHandler),
+ "/second_handler": (app2, SecondHandler),
+ "/first_handler_second_app": (app2, FirstHandler),
+ }
+ )
return router
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
-
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
response_body = b"OK"
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
- HTTPHeaders({"Content-Length": str(len(response_body))}))
+ HTTPHeaders({"Content-Length": str(len(response_body))}),
+ )
self.connection.write(response_body)
self.connection.finish()
def request_callable(request):
request.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
- HTTPHeaders({"Content-Length": "2"}))
+ HTTPHeaders({"Content-Length": "2"}),
+ )
request.connection.write(b"OK")
request.connection.finish()
router = CustomRouter()
- router.add_routes({
- "/nested_handler": (app, _get_named_handler("nested_handler"))
- })
-
- app.add_handlers(".*", [
- (HostMatches("www.example.com"), [
- (PathMatches("/first_handler"),
- "tornado.test.routing_test.SecondHandler", {}, "second_handler")
- ]),
- Rule(PathMatches("/.*handler"), router),
- Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
- Rule(PathMatches("/request_callable"), request_callable),
- ("/connection_delegate", ConnectionDelegate())
- ])
+ router.add_routes(
+ {"/nested_handler": (app, _get_named_handler("nested_handler"))}
+ )
+
+ app.add_handlers(
+ ".*",
+ [
+ (
+ HostMatches("www.example.com"),
+ [
+ (
+ PathMatches("/first_handler"),
+ "tornado.test.routing_test.SecondHandler",
+ {},
+ "second_handler",
+ )
+ ],
+ ),
+ Rule(PathMatches("/.*handler"), router),
+ Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
+ Rule(PathMatches("/request_callable"), request_callable),
+ ("/connection_delegate", ConnectionDelegate()),
+ ],
+ )
return app
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"first_handler: /first_handler")
- response = self.fetch("/first_handler", headers={'Host': 'www.example.com'})
+ response = self.fetch("/first_handler", headers={"Host": "www.example.com"})
self.assertEqual(response.body, b"second_handler: /first_handler")
response = self.fetch("/nested_handler")
def get(self, *args, **kwargs):
self.finish(self.reverse_url("tornado"))
- return RuleRouter([
- (PathMatches("/tornado.*"), Application([(r"/tornado/test", Handler, {}, "tornado")])),
- (PathMatches("/wsgi"), wsgi_app),
- ])
+ return RuleRouter(
+ [
+ (
+ PathMatches("/tornado.*"),
+ Application([(r"/tornado/test", Handler, {}, "tornado")]),
+ ),
+ (PathMatches("/wsgi"), wsgi_app),
+ ]
+ )
def wsgi_app(self, environ, start_response):
start_response("200 OK", [])
TEST_MODULES = [
- 'tornado.httputil.doctests',
- 'tornado.iostream.doctests',
- 'tornado.util.doctests',
- 'tornado.test.asyncio_test',
- 'tornado.test.auth_test',
- 'tornado.test.autoreload_test',
- 'tornado.test.concurrent_test',
- 'tornado.test.curl_httpclient_test',
- 'tornado.test.escape_test',
- 'tornado.test.gen_test',
- 'tornado.test.http1connection_test',
- 'tornado.test.httpclient_test',
- 'tornado.test.httpserver_test',
- 'tornado.test.httputil_test',
- 'tornado.test.import_test',
- 'tornado.test.ioloop_test',
- 'tornado.test.iostream_test',
- 'tornado.test.locale_test',
- 'tornado.test.locks_test',
- 'tornado.test.netutil_test',
- 'tornado.test.log_test',
- 'tornado.test.options_test',
- 'tornado.test.process_test',
- 'tornado.test.queues_test',
- 'tornado.test.routing_test',
- 'tornado.test.simple_httpclient_test',
- 'tornado.test.tcpclient_test',
- 'tornado.test.tcpserver_test',
- 'tornado.test.template_test',
- 'tornado.test.testing_test',
- 'tornado.test.twisted_test',
- 'tornado.test.util_test',
- 'tornado.test.web_test',
- 'tornado.test.websocket_test',
- 'tornado.test.windows_test',
- 'tornado.test.wsgi_test',
+ "tornado.httputil.doctests",
+ "tornado.iostream.doctests",
+ "tornado.util.doctests",
+ "tornado.test.asyncio_test",
+ "tornado.test.auth_test",
+ "tornado.test.autoreload_test",
+ "tornado.test.concurrent_test",
+ "tornado.test.curl_httpclient_test",
+ "tornado.test.escape_test",
+ "tornado.test.gen_test",
+ "tornado.test.http1connection_test",
+ "tornado.test.httpclient_test",
+ "tornado.test.httpserver_test",
+ "tornado.test.httputil_test",
+ "tornado.test.import_test",
+ "tornado.test.ioloop_test",
+ "tornado.test.iostream_test",
+ "tornado.test.locale_test",
+ "tornado.test.locks_test",
+ "tornado.test.netutil_test",
+ "tornado.test.log_test",
+ "tornado.test.options_test",
+ "tornado.test.process_test",
+ "tornado.test.queues_test",
+ "tornado.test.routing_test",
+ "tornado.test.simple_httpclient_test",
+ "tornado.test.tcpclient_test",
+ "tornado.test.tcpserver_test",
+ "tornado.test.template_test",
+ "tornado.test.testing_test",
+ "tornado.test.twisted_test",
+ "tornado.test.util_test",
+ "tornado.test.web_test",
+ "tornado.test.websocket_test",
+ "tornado.test.windows_test",
+ "tornado.test.wsgi_test",
]
def test_runner_factory(stderr):
class TornadoTextTestRunner(unittest.TextTestRunner):
def __init__(self, *args, **kwargs):
- kwargs['stream'] = stderr
+ kwargs["stream"] = stderr
super(TornadoTextTestRunner, self).__init__(*args, **kwargs)
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
- self.stream.write(textwrap.fill(
- "Some tests were skipped because: %s" %
- ", ".join(sorted(skip_reasons))))
+ self.stream.write(
+ textwrap.fill(
+ "Some tests were skipped because: %s"
+ % ", ".join(sorted(skip_reasons))
+ )
+ )
self.stream.write("\n")
return result
+
return TornadoTextTestRunner
class LogCounter(logging.Filter):
"""Counts the number of WARNING or higher log records."""
+
def __init__(self, *args, **kwargs):
super(LogCounter, self).__init__(*args, **kwargs)
self.info_count = self.warning_count = self.error_count = 0
# python 3 (as of virtualenv 1.7), so configure warnings
# programmatically instead.
import warnings
+
# Be strict about most warnings. This also turns on warnings that are
# ignored by default, including DeprecationWarnings and
# python 3.2's ResourceWarnings.
# Tornado generally shouldn't use anything deprecated, but some of
# our dependencies do (last match wins).
warnings.filterwarnings("ignore", category=DeprecationWarning)
- warnings.filterwarnings("error", category=DeprecationWarning,
- module=r"tornado\..*")
+ warnings.filterwarnings("error", category=DeprecationWarning, module=r"tornado\..*")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
- warnings.filterwarnings("error", category=PendingDeprecationWarning,
- module=r"tornado\..*")
+ warnings.filterwarnings(
+ "error", category=PendingDeprecationWarning, module=r"tornado\..*"
+ )
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
- warnings.filterwarnings("ignore", category=DeprecationWarning,
- message="Please use assert.* instead")
- warnings.filterwarnings("ignore", category=PendingDeprecationWarning,
- message="Please use assert.* instead")
+ warnings.filterwarnings(
+ "ignore", category=DeprecationWarning, message="Please use assert.* instead"
+ )
+ warnings.filterwarnings(
+ "ignore",
+ category=PendingDeprecationWarning,
+ message="Please use assert.* instead",
+ )
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
- warnings.filterwarnings("ignore", category=BytesWarning,
- module=r"twisted\..*")
+ warnings.filterwarnings("ignore", category=BytesWarning, module=r"twisted\..*")
if (3,) < sys.version_info < (3, 6):
# Prior to 3.6, async ResourceWarnings were rather noisy
# and even
# `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
# would generate a warning.
- warnings.filterwarnings("ignore", category=ResourceWarning, # noqa: F821
- module=r"asyncio\..*")
+ warnings.filterwarnings(
+ "ignore", category=ResourceWarning, module=r"asyncio\..*" # noqa: F821
+ )
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
- define('httpclient', type=str, default=None,
- callback=lambda s: AsyncHTTPClient.configure(
- s, defaults=dict(allow_ipv6=False)))
- define('httpserver', type=str, default=None,
- callback=HTTPServer.configure)
- define('resolver', type=str, default=None,
- callback=Resolver.configure)
- define('debug_gc', type=str, multiple=True,
- help="A comma-separated list of gc module debug constants, "
- "e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
- callback=lambda values: gc.set_debug(
- reduce(operator.or_, (getattr(gc, v) for v in values))))
+ define(
+ "httpclient",
+ type=str,
+ default=None,
+ callback=lambda s: AsyncHTTPClient.configure(
+ s, defaults=dict(allow_ipv6=False)
+ ),
+ )
+ define("httpserver", type=str, default=None, callback=HTTPServer.configure)
+ define("resolver", type=str, default=None, callback=Resolver.configure)
+ define(
+ "debug_gc",
+ type=str,
+ multiple=True,
+ help="A comma-separated list of gc module debug constants, "
+ "e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
+ callback=lambda values: gc.set_debug(
+ reduce(operator.or_, (getattr(gc, v) for v in values))
+ ),
+ )
def set_locale(x):
locale.setlocale(locale.LC_ALL, x)
- define('locale', type=str, default=None, callback=set_locale)
+
+ define("locale", type=str, default=None, callback=set_locale)
log_counter = LogCounter()
- add_parse_callback(
- lambda: logging.getLogger().handlers[0].addFilter(log_counter))
+ add_parse_callback(lambda: logging.getLogger().handlers[0].addFilter(log_counter))
# Certain errors (especially "unclosed resource" errors raised in
# destructors) go directly to stderr instead of logging. Count
sys.stderr = counting_stderr # type: ignore
import tornado.testing
+
kwargs = {}
# HACK: unittest.main will make its own changes to the warning
# or command-line flags like -bb. Passing warnings=False
# suppresses this behavior, although this looks like an implementation
# detail. http://bugs.python.org/issue15626
- kwargs['warnings'] = False
+ kwargs["warnings"] = False
- kwargs['testRunner'] = test_runner_factory(orig_stderr)
+ kwargs["testRunner"] = test_runner_factory(orig_stderr)
try:
tornado.testing.main(**kwargs)
finally:
# The tests should run clean; consider it a failure if they
# logged anything at info level or above.
- if (log_counter.info_count > 0 or
- log_counter.warning_count > 0 or
- log_counter.error_count > 0 or
- counting_stderr.byte_count > 0):
- logging.error("logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
- log_counter.info_count, log_counter.warning_count,
- log_counter.error_count, counting_stderr.byte_count)
+ if (
+ log_counter.info_count > 0
+ or log_counter.warning_count > 0
+ or log_counter.error_count > 0
+ or counting_stderr.byte_count > 0
+ ):
+ logging.error(
+ "logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
+ log_counter.info_count,
+ log_counter.warning_count,
+ log_counter.error_count,
+ counting_stderr.byte_count,
+ )
sys.exit(1)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
from tornado.log import gen_log
from tornado.concurrent import Future
from tornado.netutil import Resolver, bind_sockets
-from tornado.simple_httpclient import SimpleAsyncHTTPClient, HTTPStreamClosedError, HTTPTimeoutError
-from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler # noqa: E501
+from tornado.simple_httpclient import (
+ SimpleAsyncHTTPClient,
+ HTTPStreamClosedError,
+ HTTPTimeoutError,
+)
+from tornado.test.httpclient_test import (
+ ChunkHandler,
+ CountdownHandler,
+ HelloWorldHandler,
+ RedirectHandler,
+)
from tornado.test import httpclient_test
-from tornado.testing import (AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase,
- ExpectLog, gen_test)
+from tornado.testing import (
+ AsyncHTTPTestCase,
+ AsyncHTTPSTestCase,
+ AsyncTestCase,
+ ExpectLog,
+ gen_test,
+)
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port
from tornado.web import RequestHandler, Application, url, stream_request_body
@gen.coroutine
def write_response(self):
- yield self.stream.write(utf8("HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok" %
- self.get_argument("value")))
+ yield self.stream.write(
+ utf8(
+ "HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok"
+ % self.get_argument("value")
+ )
+ )
self.stream.close()
class NoContentLengthHandler(RequestHandler):
def get(self):
- if self.request.version.startswith('HTTP/1'):
+ if self.request.version.startswith("HTTP/1"):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.detach()
- stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
- b"hello")
+ stream.write(b"HTTP/1.0 200 OK\r\n\r\n" b"hello")
stream.close()
else:
- self.finish('HTTP/1 required')
+ self.finish("HTTP/1 required")
class EchoPostHandler(RequestHandler):
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque() # type: typing.Deque[str]
- return Application([
- url("/trigger", TriggerHandler, dict(queue=self.triggers,
- wake_callback=self.stop)),
- url("/chunk", ChunkHandler),
- url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
- url("/hang", HangHandler),
- url("/hello", HelloWorldHandler),
- url("/content_length", ContentLengthHandler),
- url("/head", HeadHandler),
- url("/options", OptionsHandler),
- url("/no_content", NoContentHandler),
- url("/see_other_post", SeeOtherPostHandler),
- url("/see_other_get", SeeOtherGetHandler),
- url("/host_echo", HostEchoHandler),
- url("/no_content_length", NoContentLengthHandler),
- url("/echo_post", EchoPostHandler),
- url("/respond_in_prepare", RespondInPrepareHandler),
- url("/redirect", RedirectHandler),
- ], gzip=True)
+ return Application(
+ [
+ url(
+ "/trigger",
+ TriggerHandler,
+ dict(queue=self.triggers, wake_callback=self.stop),
+ ),
+ url("/chunk", ChunkHandler),
+ url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
+ url("/hang", HangHandler),
+ url("/hello", HelloWorldHandler),
+ url("/content_length", ContentLengthHandler),
+ url("/head", HeadHandler),
+ url("/options", OptionsHandler),
+ url("/no_content", NoContentHandler),
+ url("/see_other_post", SeeOtherPostHandler),
+ url("/see_other_get", SeeOtherGetHandler),
+ url("/host_echo", HostEchoHandler),
+ url("/no_content_length", NoContentLengthHandler),
+ url("/echo_post", EchoPostHandler),
+ url("/respond_in_prepare", RespondInPrepareHandler),
+ url("/redirect", RedirectHandler),
+ ],
+ gzip=True,
+ )
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
- self.assertTrue(SimpleAsyncHTTPClient() is
- SimpleAsyncHTTPClient())
+ self.assertTrue(SimpleAsyncHTTPClient() is SimpleAsyncHTTPClient())
# unless force_instance is used
- self.assertTrue(SimpleAsyncHTTPClient() is not
- SimpleAsyncHTTPClient(force_instance=True))
+ self.assertTrue(
+ SimpleAsyncHTTPClient() is not SimpleAsyncHTTPClient(force_instance=True)
+ )
# different IOLoops use different objects
with closing(IOLoop()) as io_loop2:
+
async def make_client():
await gen.sleep(0)
return SimpleAsyncHTTPClient()
+
client1 = self.io_loop.run_sync(make_client)
client2 = io_loop2.run_sync(make_client)
self.assertTrue(client1 is not client2)
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
+
def cb(fut, i=i):
seen.append(i)
self.stop()
+
client.fetch(self.get_url("/trigger")).add_done_callback(cb)
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
# Finish the first two requests and let the next two through
self.triggers.popleft()()
self.triggers.popleft()()
- self.wait(condition=lambda: (len(self.triggers) == 2 and
- len(seen) == 2))
+ self.wait(condition=lambda: (len(self.triggers) == 2 and len(seen) == 2))
self.assertEqual(set(seen), set([0, 1]))
self.assertEqual(len(client.queue), 0)
def test_redirect_connection_limit(self):
# following redirects should not consume additional connections
with closing(self.create_client(max_clients=1)) as client:
- response = yield client.fetch(self.get_url('/countdown/3'),
- max_redirects=3)
+ response = yield client.fetch(self.get_url("/countdown/3"), max_redirects=3)
response.rethrow()
def test_gzip(self):
# ensures that it is in fact getting compressed.
# Setting Accept-Encoding manually bypasses the client's
# decompression so we can see the raw data.
- response = self.fetch("/chunk", use_gzip=False,
- headers={"Accept-Encoding": "gzip"})
+ response = self.fetch(
+ "/chunk", use_gzip=False, headers={"Accept-Encoding": "gzip"}
+ )
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
def test_header_reuse(self):
# Apps may reuse a headers object if they are only passing in constant
# headers like user-agent. The header object should not be modified.
- headers = HTTPHeaders({'User-Agent': 'Foo'})
+ headers = HTTPHeaders({"User-Agent": "Foo"})
self.fetch("/hello", headers=headers)
- self.assertEqual(list(headers.get_all()), [('User-Agent', 'Foo')])
+ self.assertEqual(list(headers.get_all()), [("User-Agent", "Foo")])
def test_see_other_redirect(self):
for code in (302, 303):
with closing(self.create_client(resolver=TimeoutResolver())) as client:
with self.assertRaises(HTTPTimeoutError):
- yield client.fetch(self.get_url('/hello'),
- connect_timeout=timeout,
- request_timeout=3600,
- raise_error=True)
+ yield client.fetch(
+ self.get_url("/hello"),
+ connect_timeout=timeout,
+ request_timeout=3600,
+ raise_error=True,
+ )
@skipOnTravis
def test_request_timeout(self):
timeout = 0.1
- if os.name == 'nt':
+ if os.name == "nt":
timeout = 0.5
with self.assertRaises(HTTPTimeoutError):
- self.fetch('/trigger?wake=false', request_timeout=timeout, raise_error=True)
+ self.fetch("/trigger?wake=false", request_timeout=timeout, raise_error=True)
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
@skipIfNoIPv6
def test_ipv6(self):
- [sock] = bind_sockets(0, '::1', family=socket.AF_INET6)
+ [sock] = bind_sockets(0, "::1", family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
- url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
+ url = "%s://[::1]:%d/hello" % (self.get_protocol(), port)
# ipv6 is currently enabled by default but can be disabled
with self.assertRaises(Exception):
with self.assertRaises(socket.error) as cm:
self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)
- if sys.platform != 'cygwin':
+ if sys.platform != "cygwin":
# cygwin returns EPERM instead of ECONNREFUSED here
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
- contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception) # type: ignore
+ contains_errno = str(errno.WSAECONNREFUSED) in str( # type: ignore
+ cm.exception
+ )
self.assertTrue(contains_errno, cm.exception)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
- self.assertTrue(expected_message in str(cm.exception),
- cm.exception)
+ self.assertTrue(expected_message in str(cm.exception), cm.exception)
def test_queue_timeout(self):
with closing(self.create_client(max_clients=1)) as client:
# Wait for the trigger request to block, not complete.
- fut1 = client.fetch(self.get_url('/trigger'), request_timeout=10)
+ fut1 = client.fetch(self.get_url("/trigger"), request_timeout=10)
self.wait()
with self.assertRaises(HTTPTimeoutError) as cm:
- self.io_loop.run_sync(lambda: client.fetch(
- self.get_url('/hello'), connect_timeout=0.1, raise_error=True))
+ self.io_loop.run_sync(
+ lambda: client.fetch(
+ self.get_url("/hello"), connect_timeout=0.1, raise_error=True
+ )
+ )
self.assertEqual(str(cm.exception), "Timeout in request queue")
self.triggers.popleft()()
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
- write(b'1234')
- write(b'5678')
+ write(b"1234")
+ write(b"5678")
@gen.coroutine
def async_body_producer(self, write):
- yield write(b'1234')
+ yield write(b"1234")
yield gen.moment
- yield write(b'5678')
+ yield write(b"5678")
def test_sync_body_producer_chunked(self):
- response = self.fetch("/echo_post", method="POST",
- body_producer=self.sync_body_producer)
+ response = self.fetch(
+ "/echo_post", method="POST", body_producer=self.sync_body_producer
+ )
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_sync_body_producer_content_length(self):
- response = self.fetch("/echo_post", method="POST",
- body_producer=self.sync_body_producer,
- headers={'Content-Length': '8'})
+ response = self.fetch(
+ "/echo_post",
+ method="POST",
+ body_producer=self.sync_body_producer,
+ headers={"Content-Length": "8"},
+ )
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_chunked(self):
- response = self.fetch("/echo_post", method="POST",
- body_producer=self.async_body_producer)
+ response = self.fetch(
+ "/echo_post", method="POST", body_producer=self.async_body_producer
+ )
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_content_length(self):
- response = self.fetch("/echo_post", method="POST",
- body_producer=self.async_body_producer,
- headers={'Content-Length': '8'})
+ response = self.fetch(
+ "/echo_post",
+ method="POST",
+ body_producer=self.async_body_producer,
+ headers={"Content-Length": "8"},
+ )
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_native_body_producer_chunked(self):
async def body_producer(write):
- await write(b'1234')
+ await write(b"1234")
import asyncio
+
await asyncio.sleep(0)
- await write(b'5678')
- response = self.fetch("/echo_post", method="POST",
- body_producer=body_producer)
+ await write(b"5678")
+
+ response = self.fetch("/echo_post", method="POST", body_producer=body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_native_body_producer_content_length(self):
async def body_producer(write):
- await write(b'1234')
+ await write(b"1234")
import asyncio
+
await asyncio.sleep(0)
- await write(b'5678')
- response = self.fetch("/echo_post", method="POST",
- body_producer=body_producer,
- headers={'Content-Length': '8'})
+ await write(b"5678")
+
+ response = self.fetch(
+ "/echo_post",
+ method="POST",
+ body_producer=body_producer,
+ headers={"Content-Length": "8"},
+ )
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
- response = self.fetch("/echo_post", method="POST",
- body=b"1234",
- expect_100_continue=True)
+ response = self.fetch(
+ "/echo_post", method="POST", body=b"1234", expect_100_continue=True
+ )
self.assertEqual(response.body, b"1234")
def test_100_continue_early_response(self):
def body_producer(write):
raise Exception("should not be called")
- response = self.fetch("/respond_in_prepare", method="POST",
- body_producer=body_producer,
- expect_100_continue=True)
+
+ response = self.fetch(
+ "/respond_in_prepare",
+ method="POST",
+ body_producer=body_producer,
+ expect_100_continue=True,
+ )
self.assertEqual(response.code, 403)
def test_streaming_follow_redirects(self):
# or we have a better framework to skip tests based on curl version.
headers = [] # type: typing.List[str]
chunk_bytes = [] # type: typing.List[bytes]
- self.fetch("/redirect?url=/hello",
- header_callback=headers.append,
- streaming_callback=chunk_bytes.append)
+ self.fetch(
+ "/redirect?url=/hello",
+ header_callback=headers.append,
+ streaming_callback=chunk_bytes.append,
+ )
chunks = list(map(to_unicode, chunk_bytes))
- self.assertEqual(chunks, ['Hello world!'])
+ self.assertEqual(chunks, ["Hello world!"])
# Make sure we only got one set of headers.
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
self.assertEqual(num_start_lines, 1)
self.http_client = self.create_client()
def create_client(self, **kwargs):
- return SimpleAsyncHTTPClient(force_instance=True,
- defaults=dict(validate_cert=False),
- **kwargs)
+ return SimpleAsyncHTTPClient(
+ force_instance=True, defaults=dict(validate_cert=False), **kwargs
+ )
def test_ssl_options(self):
resp = self.fetch("/hello", ssl_options={})
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_context(self):
- resp = self.fetch("/hello",
- ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
+ resp = self.fetch("/hello", ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_options_handshake_fail(self):
- with ExpectLog(gen_log, "SSL Error|Uncaught exception",
- required=False):
+ with ExpectLog(gen_log, "SSL Error|Uncaught exception", required=False):
with self.assertRaises(ssl.SSLError):
self.fetch(
- "/hello", ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
- raise_error=True)
+ "/hello",
+ ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
+ raise_error=True,
+ )
def test_ssl_context_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
# No stack traces are logged for SSL errors (in this case,
# failure to validate the testing self-signed cert).
# The SSLError is exposed through ssl.SSLError.
- with ExpectLog(gen_log, '.*') as expect_log:
+ with ExpectLog(gen_log, ".*") as expect_log:
with self.assertRaises(ssl.SSLError):
self.fetch("/", validate_cert=True, raise_error=True)
self.assertFalse(expect_log.logged_stack)
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 10) # type: ignore
- with closing(AsyncHTTPClient(
- max_clients=11, force_instance=True)) as client:
+ with closing(AsyncHTTPClient(max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11) # type: ignore
# Now configure max_clients statically and try overriding it
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 12) # type: ignore
- with closing(AsyncHTTPClient(
- max_clients=13, force_instance=True)) as client:
+ with closing(AsyncHTTPClient(max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13) # type: ignore
- with closing(AsyncHTTPClient(
- max_clients=14, force_instance=True)) as client:
+ with closing(AsyncHTTPClient(max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14) # type: ignore
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
- self.http1 = request.version.startswith('HTTP/1.')
+ self.http1 = request.version.startswith("HTTP/1.")
if not self.http1:
- request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
- HTTPHeaders())
+ request.connection.write_headers(
+ ResponseStartLine("", 200, "OK"), HTTPHeaders()
+ )
request.connection.finish()
return
self.request = request
- fut = self.request.connection.stream.write(
- b"HTTP/1.1 100 CONTINUE\r\n\r\n")
+ fut = self.request.connection.stream.write(b"HTTP/1.1 100 CONTINUE\r\n\r\n")
fut.add_done_callback(self.respond_200)
def respond_200(self, fut):
fut.result()
fut = self.request.connection.stream.write(
- b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA")
+ b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA"
+ )
fut.add_done_callback(lambda f: self.request.connection.stream.close())
def get_app(self):
return self.respond_100
def test_100_continue(self):
- res = self.fetch('/')
+ res = self.fetch("/")
if not self.http1:
self.skipTest("requires HTTP/1.x")
- self.assertEqual(res.body, b'A')
+ self.assertEqual(res.body, b"A")
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
- self.http1 = request.version.startswith('HTTP/1.')
+ self.http1 = request.version.startswith("HTTP/1.")
if not self.http1:
# Close the request cleanly in HTTP/2; it will be skipped anyway.
- request.connection.write_headers(ResponseStartLine('', 200, 'OK'),
- HTTPHeaders())
+ request.connection.write_headers(
+ ResponseStartLine("", 200, "OK"), HTTPHeaders()
+ )
request.connection.finish()
return
return self.respond_204
def test_204_no_content(self):
- resp = self.fetch('/')
+ resp = self.fetch("/")
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(resp.code, 204)
- self.assertEqual(resp.body, b'')
+ self.assertEqual(resp.body, b"")
def test_204_invalid_content_length(self):
# 204 status with non-zero content length is malformed
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
hostname_mapping={
- 'www.example.com': '127.0.0.1',
- ('foo.example.com', 8000): ('127.0.0.1', self.get_http_port()),
- })
+ "www.example.com": "127.0.0.1",
+ ("foo.example.com", 8000): ("127.0.0.1", self.get_http_port()),
+ }
+ )
def get_app(self):
- return Application([url("/hello", HelloWorldHandler), ])
+ return Application([url("/hello", HelloWorldHandler)])
def test_hostname_mapping(self):
- response = self.fetch(
- 'http://www.example.com:%d/hello' % self.get_http_port())
+ response = self.fetch("http://www.example.com:%d/hello" % self.get_http_port())
response.rethrow()
- self.assertEqual(response.body, b'Hello world!')
+ self.assertEqual(response.body, b"Hello world!")
def test_port_mapping(self):
- response = self.fetch('http://foo.example.com:8000/hello')
+ response = self.fetch("http://foo.example.com:8000/hello")
response.rethrow()
- self.assertEqual(response.body, b'Hello world!')
+ self.assertEqual(response.body, b"Hello world!")
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
yield Event().wait()
super(ResolveTimeoutTestCase, self).setUp()
- self.http_client = SimpleAsyncHTTPClient(
- resolver=BadResolver())
+ self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver())
def get_app(self):
- return Application([url("/hello", HelloWorldHandler), ])
+ return Application([url("/hello", HelloWorldHandler)])
def test_resolve_timeout(self):
with self.assertRaises(HTTPTimeoutError):
- self.fetch('/hello', connect_timeout=0.1, raise_error=True)
+ self.fetch("/hello", connect_timeout=0.1, raise_error=True)
class MaxHeaderSizeTest(AsyncHTTPTestCase):
self.set_header("X-Filler", "a" * 1000)
self.write("ok")
- return Application([('/small', SmallHeaders),
- ('/large', LargeHeaders)])
+ return Application([("/small", SmallHeaders), ("/large", LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(max_header_size=1024)
def test_small_headers(self):
- response = self.fetch('/small')
+ response = self.fetch("/small")
response.rethrow()
- self.assertEqual(response.body, b'ok')
+ self.assertEqual(response.body, b"ok")
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
with self.assertRaises(UnsatisfiableReadError):
- self.fetch('/large', raise_error=True)
+ self.fetch("/large", raise_error=True)
class MaxBodySizeTest(AsyncHTTPTestCase):
def get(self):
self.write("a" * 1024 * 100)
- return Application([('/small', SmallBody),
- ('/large', LargeBody)])
+ return Application([("/small", SmallBody), ("/large", LargeBody)])
def get_http_client(self):
return SimpleAsyncHTTPClient(max_body_size=1024 * 64)
def test_small_body(self):
- response = self.fetch('/small')
+ response = self.fetch("/small")
response.rethrow()
- self.assertEqual(response.body, b'a' * 1024 * 64)
+ self.assertEqual(response.body, b"a" * 1024 * 64)
def test_large_body(self):
- with ExpectLog(gen_log, "Malformed HTTP message from None: Content-Length too long"):
+ with ExpectLog(
+ gen_log, "Malformed HTTP message from None: Content-Length too long"
+ ):
with self.assertRaises(HTTPStreamClosedError):
- self.fetch('/large', raise_error=True)
+ self.fetch("/large", raise_error=True)
class MaxBufferSizeTest(AsyncHTTPTestCase):
def get_app(self):
-
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
- return Application([('/large', LargeBody)])
+ return Application([("/large", LargeBody)])
def get_http_client(self):
# 100KB body with 64KB buffer
- return SimpleAsyncHTTPClient(max_body_size=1024 * 100, max_buffer_size=1024 * 64)
+ return SimpleAsyncHTTPClient(
+ max_body_size=1024 * 100, max_buffer_size=1024 * 64
+ )
def test_large_body(self):
- response = self.fetch('/large')
+ response = self.fetch("/large")
response.rethrow()
- self.assertEqual(response.body, b'a' * 1024 * 100)
+ self.assertEqual(response.body, b"a" * 1024 * 100)
class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
def get_app(self):
-
class ChunkedWithContentLength(RequestHandler):
def get(self):
# Add an invalid Transfer-Encoding to the response
- self.set_header('Transfer-Encoding', 'chunked')
+ self.set_header("Transfer-Encoding", "chunked")
self.write("Hello world")
- return Application([('/chunkwithcl', ChunkedWithContentLength)])
+ return Application([("/chunkwithcl", ChunkedWithContentLength)])
def get_http_client(self):
return SimpleAsyncHTTPClient()
def test_chunked_with_content_length(self):
# Make sure the invalid headers are detected
- with ExpectLog(gen_log, ("Malformed HTTP message from None: Response "
- "with both Transfer-Encoding and Content-Length")):
+ with ExpectLog(
+ gen_log,
+ (
+ "Malformed HTTP message from None: Response "
+ "with both Transfer-Encoding and Content-Length"
+ ),
+ ):
with self.assertRaises(HTTPStreamClosedError):
- self.fetch('/chunkwithcl', raise_error=True)
+ self.fetch("/chunkwithcl", raise_error=True)
from tornado.gen import TimeoutError
import typing
+
if typing.TYPE_CHECKING:
from tornado.iostream import IOStream # noqa: F401
from typing import List, Dict, Tuple # noqa: F401
super(TestTCPServer, self).__init__()
self.streams = [] # type: List[IOStream]
self.queue = Queue() # type: Queue[IOStream]
- sockets = bind_sockets(0, 'localhost', family)
+ sockets = bind_sockets(0, "localhost", family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
self.client = TCPClient()
def start_server(self, family):
- if family == socket.AF_UNSPEC and 'TRAVIS' in os.environ:
+ if family == socket.AF_UNSPEC and "TRAVIS" in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
self.server = TestTCPServer(family)
return self.server.port
def skipIfLocalhostV4(self):
# The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
- addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve('localhost', 80))
+ addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve("localhost", 80))
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host, source_ip=None, source_port=None):
port = self.start_server(family)
- stream = yield self.client.connect(host, port,
- source_ip=source_ip,
- source_port=source_port)
+ stream = yield self.client.connect(
+ host, port, source_ip=source_ip, source_port=source_port
+ )
server_stream = yield self.server.queue.get()
with closing(stream):
stream.write(b"hello")
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
- self.do_test_connect(socket.AF_INET, '127.0.0.1')
+ self.do_test_connect(socket.AF_INET, "127.0.0.1")
def test_connect_ipv4_dual(self):
- self.do_test_connect(socket.AF_INET, 'localhost')
+ self.do_test_connect(socket.AF_INET, "localhost")
@skipIfNoIPv6
def test_connect_ipv6_ipv6(self):
self.skipIfLocalhostV4()
- self.do_test_connect(socket.AF_INET6, '::1')
+ self.do_test_connect(socket.AF_INET6, "::1")
@skipIfNoIPv6
def test_connect_ipv6_dual(self):
self.skipIfLocalhostV4()
- if Resolver.configured_class().__name__.endswith('TwistedResolver'):
- self.skipTest('TwistedResolver does not support multiple addresses')
- self.do_test_connect(socket.AF_INET6, 'localhost')
+ if Resolver.configured_class().__name__.endswith("TwistedResolver"):
+ self.skipTest("TwistedResolver does not support multiple addresses")
+ self.do_test_connect(socket.AF_INET6, "localhost")
def test_connect_unspec_ipv4(self):
- self.do_test_connect(socket.AF_UNSPEC, '127.0.0.1')
+ self.do_test_connect(socket.AF_UNSPEC, "127.0.0.1")
@skipIfNoIPv6
def test_connect_unspec_ipv6(self):
self.skipIfLocalhostV4()
- self.do_test_connect(socket.AF_UNSPEC, '::1')
+ self.do_test_connect(socket.AF_UNSPEC, "::1")
def test_connect_unspec_dual(self):
- self.do_test_connect(socket.AF_UNSPEC, 'localhost')
+ self.do_test_connect(socket.AF_UNSPEC, "localhost")
@gen_test
def test_refused_ipv4(self):
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with self.assertRaises(IOError):
- yield self.client.connect('127.0.0.1', port)
+ yield self.client.connect("127.0.0.1", port)
def test_source_ip_fail(self):
- '''
+ """
Fail when trying to use the source IP Address '8.8.8.8'.
- '''
- self.assertRaises(socket.error,
- self.do_test_connect,
- socket.AF_INET,
- '127.0.0.1',
- source_ip='8.8.8.8')
+ """
+ self.assertRaises(
+ socket.error,
+ self.do_test_connect,
+ socket.AF_INET,
+ "127.0.0.1",
+ source_ip="8.8.8.8",
+ )
def test_source_ip_success(self):
- '''
+ """
Success when trying to use the source IP Address '127.0.0.1'
- '''
- self.do_test_connect(socket.AF_INET, '127.0.0.1', source_ip='127.0.0.1')
+ """
+ self.do_test_connect(socket.AF_INET, "127.0.0.1", source_ip="127.0.0.1")
@skipIfNonUnix
def test_source_port_fail(self):
- '''
+ """
Fail when trying to use source port 1.
- '''
- self.assertRaises(socket.error,
- self.do_test_connect,
- socket.AF_INET,
- '127.0.0.1',
- source_port=1)
+ """
+ self.assertRaises(
+ socket.error,
+ self.do_test_connect,
+ socket.AF_INET,
+ "127.0.0.1",
+ source_port=1,
+ )
@gen_test
def test_connect_timeout(self):
class TimeoutResolver(Resolver):
def resolve(self, *args, **kwargs):
return Future() # never completes
+
with self.assertRaises(TimeoutError):
yield TCPClient(resolver=TimeoutResolver()).connect(
- '1.2.3.4', 12345, timeout=timeout)
+ "1.2.3.4", 12345, timeout=timeout
+ )
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
# These addresses aren't in the right format, but split doesn't care.
- primary, secondary = _Connector.split(
- [(AF1, 'a'),
- (AF1, 'b')])
- self.assertEqual(primary, [(AF1, 'a'),
- (AF1, 'b')])
+ primary, secondary = _Connector.split([(AF1, "a"), (AF1, "b")])
+ self.assertEqual(primary, [(AF1, "a"), (AF1, "b")])
self.assertEqual(secondary, [])
def test_mixed(self):
primary, secondary = _Connector.split(
- [(AF1, 'a'),
- (AF2, 'b'),
- (AF1, 'c'),
- (AF2, 'd')])
- self.assertEqual(primary, [(AF1, 'a'), (AF1, 'c')])
- self.assertEqual(secondary, [(AF2, 'b'), (AF2, 'd')])
+ [(AF1, "a"), (AF2, "b"), (AF1, "c"), (AF2, "d")]
+ )
+ self.assertEqual(primary, [(AF1, "a"), (AF1, "c")])
+ self.assertEqual(secondary, [(AF2, "b"), (AF2, "d")])
class ConnectorTest(AsyncTestCase):
def setUp(self):
super(ConnectorTest, self).setUp()
- self.connect_futures = {} \
- # type: Dict[Tuple[int, Tuple], Future[ConnectorTest.FakeStream]]
+ self.connect_futures = (
+ {}
+ ) # type: Dict[Tuple[int, Tuple], Future[ConnectorTest.FakeStream]]
self.streams = {} # type: Dict[Tuple, ConnectorTest.FakeStream]
- self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
- (AF2, 'c'), (AF2, 'd')]
+ self.addrinfo = [(AF1, "a"), (AF1, "b"), (AF2, "c"), (AF2, "d")]
def tearDown(self):
# Unless explicitly checked (and popped) in the test, we shouldn't
def test_immediate_success(self):
conn, future = self.start_connect(self.addrinfo)
- self.assertEqual(list(self.connect_futures.keys()),
- [(AF1, 'a')])
- self.resolve_connect(AF1, 'a', True)
- self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+ self.assertEqual(list(self.connect_futures.keys()), [(AF1, "a")])
+ self.resolve_connect(AF1, "a", True)
+ self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
def test_immediate_failure(self):
# Fail with just one address.
- conn, future = self.start_connect([(AF1, 'a')])
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
+ conn, future = self.start_connect([(AF1, "a")])
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'))
- self.resolve_connect(AF1, 'b', True)
- self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"))
+ self.resolve_connect(AF1, "b", True)
+ self.assertEqual(future.result(), (AF1, "b", self.streams["b"]))
def test_one_family_second_try_failure(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'))
- self.resolve_connect(AF1, 'b', False)
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"))
+ self.resolve_connect(AF1, "b", False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try_timeout(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
# trigger the timeout while the first lookup is pending;
# nothing happens.
conn.on_timeout()
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'))
- self.resolve_connect(AF1, 'b', True)
- self.assertEqual(future.result(), (AF1, 'b', self.streams['b']))
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"))
+ self.resolve_connect(AF1, "b", True)
+ self.assertEqual(future.result(), (AF1, "b", self.streams["b"]))
def test_two_families_immediate_failure(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'), (AF2, 'c'))
- self.resolve_connect(AF1, 'b', False)
- self.resolve_connect(AF2, 'c', True)
- self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"), (AF2, "c"))
+ self.resolve_connect(AF1, "b", False)
+ self.resolve_connect(AF2, "c", True)
+ self.assertEqual(future.result(), (AF2, "c", self.streams["c"]))
def test_two_families_timeout(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
+ self.assert_pending((AF1, "a"))
conn.on_timeout()
- self.assert_pending((AF1, 'a'), (AF2, 'c'))
- self.resolve_connect(AF2, 'c', True)
- self.assertEqual(future.result(), (AF2, 'c', self.streams['c']))
+ self.assert_pending((AF1, "a"), (AF2, "c"))
+ self.resolve_connect(AF2, "c", True)
+ self.assertEqual(future.result(), (AF2, "c", self.streams["c"]))
# resolving 'a' after the connection has completed doesn't start 'b'
- self.resolve_connect(AF1, 'a', False)
+ self.resolve_connect(AF1, "a", False)
self.assert_pending()
def test_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
+ self.assert_pending((AF1, "a"))
conn.on_timeout()
- self.assert_pending((AF1, 'a'), (AF2, 'c'))
- self.resolve_connect(AF1, 'a', True)
- self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+ self.assert_pending((AF1, "a"), (AF2, "c"))
+ self.resolve_connect(AF1, "a", True)
+ self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
# resolving 'c' after completion closes the connection.
- self.resolve_connect(AF2, 'c', True)
- self.assertTrue(self.streams.pop('c').closed)
+ self.resolve_connect(AF2, "c", True)
+ self.assertTrue(self.streams.pop("c").closed)
def test_all_fail(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
+ self.assert_pending((AF1, "a"))
conn.on_timeout()
- self.assert_pending((AF1, 'a'), (AF2, 'c'))
- self.resolve_connect(AF2, 'c', False)
- self.assert_pending((AF1, 'a'), (AF2, 'd'))
- self.resolve_connect(AF2, 'd', False)
+ self.assert_pending((AF1, "a"), (AF2, "c"))
+ self.resolve_connect(AF2, "c", False)
+ self.assert_pending((AF1, "a"), (AF2, "d"))
+ self.resolve_connect(AF2, "d", False)
# one queue is now empty
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'))
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"))
self.assertFalse(future.done())
- self.resolve_connect(AF1, 'b', False)
+ self.resolve_connect(AF1, "b", False)
self.assertRaises(IOError, future.result)
def test_one_family_timeout_after_connect_timeout(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
conn.on_connect_timeout()
# the connector will close all streams on connect timeout, we
# should explicitly pop the connect_future.
- self.connect_futures.pop((AF1, 'a'))
- self.assertTrue(self.streams.pop('a').closed)
+ self.connect_futures.pop((AF1, "a"))
+ self.assertTrue(self.streams.pop("a").closed)
conn.on_timeout()
# if the future is set with TimeoutError, we will not iterate next
# possible address.
self.assertRaises(TimeoutError, future.result)
def test_one_family_success_before_connect_timeout(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', True)
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", True)
conn.on_connect_timeout()
self.assert_pending()
- self.assertEqual(self.streams['a'].closed, False)
+ self.assertEqual(self.streams["a"].closed, False)
# success stream will be pop
self.assertEqual(len(conn.streams), 0)
# streams in connector should be closed after connect timeout
self.assert_connector_streams_closed(conn)
- self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+ self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
def test_one_family_second_try_after_connect_timeout(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'))
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"))
conn.on_connect_timeout()
- self.connect_futures.pop((AF1, 'b'))
- self.assertTrue(self.streams.pop('b').closed)
+ self.connect_futures.pop((AF1, "b"))
+ self.assertTrue(self.streams.pop("b").closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_one_family_second_try_failure_before_connect_timeout(self):
- conn, future = self.start_connect([(AF1, 'a'), (AF1, 'b')])
- self.assert_pending((AF1, 'a'))
- self.resolve_connect(AF1, 'a', False)
- self.assert_pending((AF1, 'b'))
- self.resolve_connect(AF1, 'b', False)
+ conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
+ self.assert_pending((AF1, "a"))
+ self.resolve_connect(AF1, "a", False)
+ self.assert_pending((AF1, "b"))
+ self.resolve_connect(AF1, "b", False)
conn.on_connect_timeout()
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
def test_two_family_timeout_before_connect_timeout(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
+ self.assert_pending((AF1, "a"))
conn.on_timeout()
- self.assert_pending((AF1, 'a'), (AF2, 'c'))
+ self.assert_pending((AF1, "a"), (AF2, "c"))
conn.on_connect_timeout()
- self.connect_futures.pop((AF1, 'a'))
- self.assertTrue(self.streams.pop('a').closed)
- self.connect_futures.pop((AF2, 'c'))
- self.assertTrue(self.streams.pop('c').closed)
+ self.connect_futures.pop((AF1, "a"))
+ self.assertTrue(self.streams.pop("a").closed)
+ self.connect_futures.pop((AF2, "c"))
+ self.assertTrue(self.streams.pop("c").closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
def test_two_family_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
+ self.assert_pending((AF1, "a"))
conn.on_timeout()
- self.assert_pending((AF1, 'a'), (AF2, 'c'))
- self.resolve_connect(AF1, 'a', True)
+ self.assert_pending((AF1, "a"), (AF2, "c"))
+ self.resolve_connect(AF1, "a", True)
# if one of streams succeed, connector will close all other streams
- self.connect_futures.pop((AF2, 'c'))
- self.assertTrue(self.streams.pop('c').closed)
+ self.connect_futures.pop((AF2, "c"))
+ self.assertTrue(self.streams.pop("c").closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
- self.assertEqual(future.result(), (AF1, 'a', self.streams['a']))
+ self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
def test_two_family_timeout_after_connect_timeout(self):
conn, future = self.start_connect(self.addrinfo)
- self.assert_pending((AF1, 'a'))
+ self.assert_pending((AF1, "a"))
conn.on_connect_timeout()
- self.connect_futures.pop((AF1, 'a'))
- self.assertTrue(self.streams.pop('a').closed)
+ self.connect_futures.pop((AF1, "a"))
+ self.assertTrue(self.streams.pop("a").closed)
self.assert_pending()
conn.on_timeout()
# if the future is set with TimeoutError, connector will not
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
- yield stream.read_bytes(len(b'hello'))
+ yield stream.read_bytes(len(b"hello"))
stream.close()
1 / 0
server.add_socket(sock)
client = IOStream(socket.socket())
with ExpectLog(app_log, "Exception in callback"):
- yield client.connect(('localhost', port))
- yield client.write(b'hello')
+ yield client.connect(("localhost", port))
+ yield client.write(b"hello")
yield client.read_until_close()
yield gen.moment
finally:
class TestServer(TCPServer):
async def handle_stream(self, stream, address):
- stream.write(b'data')
+ stream.write(b"data")
stream.close()
sock, port = bind_unused_port()
server = TestServer()
server.add_socket(sock)
client = IOStream(socket.socket())
- yield client.connect(('localhost', port))
+ yield client.connect(("localhost", port))
result = yield client.read_until_close()
- self.assertEqual(result, b'data')
+ self.assertEqual(result, b"data")
server.stop()
client.close()
sock, port = bind_unused_port()
server = TestServer()
server.add_socket(sock)
- server_addr = ('localhost', port)
+ server_addr = ("localhost", port)
N = 40
clients = [IOStream(socket.socket()) for i in range(N)]
connected_clients = []
yield [connect(c) for c in clients]
- self.assertGreater(len(connected_clients), 0,
- "all clients failed connecting")
+ self.assertGreater(len(connected_clients), 0, "all clients failed connecting")
try:
if len(connected_clients) == N:
# Ideally we'd make the test deterministic, but we're testing
# for a race condition in combination with the system's TCP stack...
- self.skipTest("at least one client should fail connecting "
- "for the test to be meaningful")
+ self.skipTest(
+ "at least one client should fail connecting "
+ "for the test to be meaningful"
+ )
finally:
for c in connected_clients:
c.close()
# byte, so we don't have to worry about atomicity of the shared
# stdout stream) and then exits.
def run_subproc(self, code):
- proc = subprocess.Popen(sys.executable,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE)
+ proc = subprocess.Popen(
+ sys.executable, stdin=subprocess.PIPE, stdout=subprocess.PIPE
+ )
proc.stdin.write(utf8(code))
proc.stdin.close()
proc.wait()
stdout = proc.stdout.read()
proc.stdout.close()
if proc.returncode != 0:
- raise RuntimeError("Process returned %d. stdout=%r" % (
- proc.returncode, stdout))
+ raise RuntimeError(
+ "Process returned %d. stdout=%r" % (proc.returncode, stdout)
+ )
return to_unicode(stdout)
def test_single(self):
# As a sanity check, run the single-process version through this test
# harness too.
- code = textwrap.dedent("""
+ code = textwrap.dedent(
+ """
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer
server.listen(0, address='127.0.0.1')
IOLoop.current().run_sync(lambda: None)
print('012', end='')
- """)
+ """
+ )
out = self.run_subproc(code)
- self.assertEqual(''.join(sorted(out)), "012")
+ self.assertEqual("".join(sorted(out)), "012")
def test_simple(self):
- code = textwrap.dedent("""
+ code = textwrap.dedent(
+ """
from tornado.ioloop import IOLoop
from tornado.process import task_id
from tornado.tcpserver import TCPServer
server.start(3)
IOLoop.current().run_sync(lambda: None)
print(task_id(), end='')
- """)
+ """
+ )
out = self.run_subproc(code)
- self.assertEqual(''.join(sorted(out)), "012")
+ self.assertEqual("".join(sorted(out)), "012")
def test_advanced(self):
- code = textwrap.dedent("""
+ code = textwrap.dedent(
+ """
from tornado.ioloop import IOLoop
from tornado.netutil import bind_sockets
from tornado.process import fork_processes, task_id
server.add_sockets(sockets)
IOLoop.current().run_sync(lambda: None)
print(task_id(), end='')
- """)
+ """
+ )
out = self.run_subproc(code)
- self.assertEqual(''.join(sorted(out)), "012")
+ self.assertEqual("".join(sorted(out)), "012")
class TemplateTest(unittest.TestCase):
def test_simple(self):
template = Template("Hello {{ name }}!")
- self.assertEqual(template.generate(name="Ben"),
- b"Hello Ben!")
+ self.assertEqual(template.generate(name="Ben"), b"Hello Ben!")
def test_bytes(self):
template = Template("Hello {{ name }}!")
- self.assertEqual(template.generate(name=utf8("Ben")),
- b"Hello Ben!")
+ self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!")
def test_expressions(self):
template = Template("2 + 2 = {{ 2 + 2 }}")
def test_comment(self):
template = Template("Hello{# TODO i18n #} {{ name }}!")
- self.assertEqual(template.generate(name=utf8("Ben")),
- b"Hello Ben!")
+ self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!")
def test_include(self):
- loader = DictLoader({
- "index.html": '{% include "header.html" %}\nbody text',
- "header.html": "header text",
- })
- self.assertEqual(loader.load("index.html").generate(),
- b"header text\nbody text")
+ loader = DictLoader(
+ {
+ "index.html": '{% include "header.html" %}\nbody text',
+ "header.html": "header text",
+ }
+ )
+ self.assertEqual(
+ loader.load("index.html").generate(), b"header text\nbody text"
+ )
def test_extends(self):
- loader = DictLoader({
- "base.html": """\
+ loader = DictLoader(
+ {
+ "base.html": """\
<title>{% block title %}default title{% end %}</title>
<body>{% block body %}default body{% end %}</body>
""",
- "page.html": """\
+ "page.html": """\
{% extends "base.html" %}
{% block title %}page title{% end %}
{% block body %}page body{% end %}
""",
- })
- self.assertEqual(loader.load("page.html").generate(),
- b"<title>page title</title>\n<body>page body</body>\n")
+ }
+ )
+ self.assertEqual(
+ loader.load("page.html").generate(),
+ b"<title>page title</title>\n<body>page body</body>\n",
+ )
def test_relative_load(self):
- loader = DictLoader({
- "a/1.html": "{% include '2.html' %}",
- "a/2.html": "{% include '../b/3.html' %}",
- "b/3.html": "ok",
- })
- self.assertEqual(loader.load("a/1.html").generate(),
- b"ok")
+ loader = DictLoader(
+ {
+ "a/1.html": "{% include '2.html' %}",
+ "a/2.html": "{% include '../b/3.html' %}",
+ "b/3.html": "ok",
+ }
+ )
+ self.assertEqual(loader.load("a/1.html").generate(), b"ok")
def test_escaping(self):
self.assertRaises(ParseError, lambda: Template("{{"))
self.assertEqual(Template("{{!").generate(), b"{{")
self.assertEqual(Template("{%!").generate(), b"{%")
self.assertEqual(Template("{#!").generate(), b"{#")
- self.assertEqual(Template("{{ 'expr' }} {{!jquery expr}}").generate(),
- b"expr {{jquery expr}}")
+ self.assertEqual(
+ Template("{{ 'expr' }} {{!jquery expr}}").generate(),
+ b"expr {{jquery expr}}",
+ )
def test_unicode_template(self):
template = Template(utf8(u"\u00e9"))
self.assertEqual(template.generate(), utf8(u"\u00e9"))
def test_custom_namespace(self):
- loader = DictLoader({"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1})
+ loader = DictLoader(
+ {"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1}
+ )
self.assertEqual(loader.load("test.html").generate(), b"6")
def test_apply(self):
def upper(s):
return s.upper()
+
template = Template(utf8("{% apply upper %}foo{% end %}"))
self.assertEqual(template.generate(upper=upper), b"FOO")
def test_unicode_apply(self):
def upper(s):
return to_unicode(s).upper()
+
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
def test_bytes_apply(self):
def upper(s):
return utf8(to_unicode(s).upper())
+
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
self.assertEqual(template.generate(), b"")
def test_try(self):
- template = Template(utf8("""{% try %}
+ template = Template(
+ utf8(
+ """{% try %}
try{% set y = 1/x %}
{% except %}-except
{% else %}-else
{% finally %}-finally
-{% end %}"""))
+{% end %}"""
+ )
+ )
self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n")
self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n")
self.assertEqual(template.generate(), b"foo")
def test_break_continue(self):
- template = Template(utf8("""\
+ template = Template(
+ utf8(
+ """\
{% for i in range(10) %}
{% if i == 2 %}
{% continue %}
{% if i == 6 %}
{% break %}
{% end %}
-{% end %}"""))
+{% end %}"""
+ )
+ )
result = template.generate()
# remove extraneous whitespace
- result = b''.join(result.split())
+ result = b"".join(result.split())
self.assertEqual(result, b"013456")
def test_break_outside_loop(self):
# This test verifies current behavior, although of course it would
# be nice if apply didn't cause seemingly unrelated breakage
try:
- Template(utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}"))
+ Template(
+ utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}")
+ )
raise Exception("Did not get expected exception")
except ParseError:
pass
- @unittest.skip('no testable future imports')
+ @unittest.skip("no testable future imports")
def test_no_inherit_future(self):
# TODO(bdarnell): make a test like this for one of the future
# imports available in python 3. Unfortunately they're harder
# This file has from __future__ import division...
self.assertEqual(1 / 2, 0.5)
# ...but the template doesn't
- template = Template('{{ 1 / 2 }}')
- self.assertEqual(template.generate(), '0')
+ template = Template("{{ 1 / 2 }}")
+ self.assertEqual(template.generate(), "0")
def test_non_ascii_name(self):
loader = DictLoader({u"t\u00e9st.html": "hello"})
class StackTraceTest(unittest.TestCase):
def test_error_line_number_expression(self):
- loader = DictLoader({"test.html": """one
+ loader = DictLoader(
+ {
+ "test.html": """one
two{{1/0}}
three
- """})
+ """
+ }
+ )
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_directive(self):
- loader = DictLoader({"test.html": """one
+ loader = DictLoader(
+ {
+ "test.html": """one
two{%if 1/0%}
three{%end%}
- """})
+ """
+ }
+ )
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
assert loader is not None
return loader.load(path).generate(**kwargs)
- loader = DictLoader({
- "base.html": "{% module Template('sub.html') %}",
- "sub.html": "{{1/0}}",
- }, namespace={"_tt_modules": ObjectDict(Template=load_generate)})
+ loader = DictLoader(
+ {"base.html": "{% module Template('sub.html') %}", "sub.html": "{{1/0}}"},
+ namespace={"_tt_modules": ObjectDict(Template=load_generate)},
+ )
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
- self.assertTrue('# base.html:1' in exc_stack)
- self.assertTrue('# sub.html:1' in exc_stack)
+ self.assertTrue("# base.html:1" in exc_stack)
+ self.assertTrue("# sub.html:1" in exc_stack)
def test_error_line_number_include(self):
- loader = DictLoader({
- "base.html": "{% include 'sub.html' %}",
- "sub.html": "{{1/0}}",
- })
+ loader = DictLoader(
+ {"base.html": "{% include 'sub.html' %}", "sub.html": "{{1/0}}"}
+ )
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
- self.assertTrue("# sub.html:1 (via base.html:1)" in
- traceback.format_exc())
+ self.assertTrue("# sub.html:1 (via base.html:1)" in traceback.format_exc())
def test_error_line_number_extends_base_error(self):
- loader = DictLoader({
- "base.html": "{{1/0}}",
- "sub.html": "{% extends 'base.html' %}",
- })
+ loader = DictLoader(
+ {"base.html": "{{1/0}}", "sub.html": "{% extends 'base.html' %}"}
+ )
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
self.assertTrue("# base.html:1" in exc_stack)
def test_error_line_number_extends_sub_error(self):
- loader = DictLoader({
- "base.html": "{% block 'block' %}{% end %}",
- "sub.html": """
+ loader = DictLoader(
+ {
+ "base.html": "{% block 'block' %}{% end %}",
+ "sub.html": """
{% extends 'base.html' %}
{% block 'block' %}
{{1/0}}
{% end %}
- """})
+ """,
+ }
+ )
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
- self.assertTrue("# sub.html:4 (via base.html:1)" in
- traceback.format_exc())
+ self.assertTrue("# sub.html:4 (via base.html:1)" in traceback.format_exc())
def test_multi_includes(self):
- loader = DictLoader({
- "a.html": "{% include 'b.html' %}",
- "b.html": "{% include 'c.html' %}",
- "c.html": "{{1/0}}",
- })
+ loader = DictLoader(
+ {
+ "a.html": "{% include 'b.html' %}",
+ "b.html": "{% include 'c.html' %}",
+ "c.html": "{{1/0}}",
+ }
+ )
try:
loader.load("a.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
- self.assertTrue("# c.html:1 (via b.html:1, a.html:1)" in
- traceback.format_exc())
+ self.assertTrue(
+ "# c.html:1 (via b.html:1, a.html:1)" in traceback.format_exc()
+ )
class ParseErrorDetailTest(unittest.TestCase):
def test_details(self):
- loader = DictLoader({
- "foo.html": "\n\n{{",
- })
+ loader = DictLoader({"foo.html": "\n\n{{"})
with self.assertRaises(ParseError) as cm:
loader.load("foo.html")
- self.assertEqual("Missing end expression }} at foo.html:3",
- str(cm.exception))
+ self.assertEqual("Missing end expression }} at foo.html:3", str(cm.exception))
self.assertEqual("foo.html", cm.exception.filename)
self.assertEqual(3, cm.exception.lineno)
"escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
"unescaped.html": "{% autoescape None %}{{ name }}",
"default.html": "{{ name }}",
-
"include.html": """\
escaped: {% include 'escaped.html' %}
unescaped: {% include 'unescaped.html' %}
default: {% include 'default.html' %}
""",
-
"escaped_block.html": """\
{% autoescape xhtml_escape %}\
{% block name %}base: {{ name }}{% end %}""",
"unescaped_block.html": """\
{% autoescape None %}\
{% block name %}base: {{ name }}{% end %}""",
-
# Extend a base template with different autoescape policy,
# with and without overriding the base's blocks
"escaped_extends_unescaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
-
"raw_expression.html": """\
{% autoescape xhtml_escape %}\
expr: {{ name }}
def test_default_off(self):
loader = DictLoader(self.templates, autoescape=None)
name = "Bobby <table>s"
- self.assertEqual(loader.load("escaped.html").generate(name=name),
- b"Bobby &lt;table&gt;s")
- self.assertEqual(loader.load("unescaped.html").generate(name=name),
- b"Bobby <table>s")
- self.assertEqual(loader.load("default.html").generate(name=name),
- b"Bobby <table>s")
-
- self.assertEqual(loader.load("include.html").generate(name=name),
- b"escaped: Bobby &lt;table&gt;s\n"
- b"unescaped: Bobby <table>s\n"
- b"default: Bobby <table>s\n")
+ self.assertEqual(
+ loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
+ )
+ self.assertEqual(
+ loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
+ )
+ self.assertEqual(
+ loader.load("default.html").generate(name=name), b"Bobby <table>s"
+ )
+
+ self.assertEqual(
+ loader.load("include.html").generate(name=name),
+ b"escaped: Bobby &lt;table&gt;s\n"
+ b"unescaped: Bobby <table>s\n"
+ b"default: Bobby <table>s\n",
+ )
def test_default_on(self):
loader = DictLoader(self.templates, autoescape="xhtml_escape")
name = "Bobby <table>s"
- self.assertEqual(loader.load("escaped.html").generate(name=name),
- b"Bobby &lt;table&gt;s")
- self.assertEqual(loader.load("unescaped.html").generate(name=name),
- b"Bobby <table>s")
- self.assertEqual(loader.load("default.html").generate(name=name),
- b"Bobby &lt;table&gt;s")
-
- self.assertEqual(loader.load("include.html").generate(name=name),
- b"escaped: Bobby &lt;table&gt;s\n"
- b"unescaped: Bobby <table>s\n"
- b"default: Bobby &lt;table&gt;s\n")
+ self.assertEqual(
+ loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
+ )
+ self.assertEqual(
+ loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
+ )
+ self.assertEqual(
+ loader.load("default.html").generate(name=name), b"Bobby &lt;table&gt;s"
+ )
+
+ self.assertEqual(
+ loader.load("include.html").generate(name=name),
+ b"escaped: Bobby &lt;table&gt;s\n"
+ b"unescaped: Bobby <table>s\n"
+ b"default: Bobby &lt;table&gt;s\n",
+ )
def test_unextended_block(self):
loader = DictLoader(self.templates)
name = "<script>"
- self.assertEqual(loader.load("escaped_block.html").generate(name=name),
- b"base: &lt;script&gt;")
- self.assertEqual(loader.load("unescaped_block.html").generate(name=name),
- b"base: <script>")
+ self.assertEqual(
+ loader.load("escaped_block.html").generate(name=name),
+ b"base: &lt;script&gt;",
+ )
+ self.assertEqual(
+ loader.load("unescaped_block.html").generate(name=name), b"base: <script>"
+ )
def test_extended_block(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name="<script>")
- self.assertEqual(render("escaped_extends_unescaped.html"),
- b"base: <script>")
- self.assertEqual(render("escaped_overrides_unescaped.html"),
- b"extended: &lt;script&gt;")
- self.assertEqual(render("unescaped_extends_escaped.html"),
- b"base: &lt;script&gt;")
- self.assertEqual(render("unescaped_overrides_escaped.html"),
- b"extended: <script>")
+ self.assertEqual(render("escaped_extends_unescaped.html"), b"base: <script>")
+ self.assertEqual(
+ render("escaped_overrides_unescaped.html"), b"extended: &lt;script&gt;"
+ )
+
+ self.assertEqual(
+ render("unescaped_extends_escaped.html"), b"base: &lt;script&gt;"
+ )
+ self.assertEqual(
+ render("unescaped_overrides_escaped.html"), b"extended: <script>"
+ )
def test_raw_expression(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name='<>&"')
- self.assertEqual(render("raw_expression.html"),
- b"expr: &lt;&gt;&amp;&quot;\n"
- b"raw: <>&\"")
+
+ self.assertEqual(
+ render("raw_expression.html"), b"expr: &lt;&gt;&amp;&quot;\n" b'raw: <>&"'
+ )
def test_custom_escape(self):
- loader = DictLoader({"foo.py":
- "{% autoescape py_escape %}s = {{ name }}\n"})
+ loader = DictLoader({"foo.py": "{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s):
self.assertEqual(type(s), bytes)
return repr(native_str(s))
def render(template, name):
- return loader.load(template).generate(py_escape=py_escape,
- name=name)
- self.assertEqual(render("foo.py", "<html>"),
- b"s = '<html>'\n")
- self.assertEqual(render("foo.py", "';sys.exit()"),
- b"""s = "';sys.exit()"\n""")
- self.assertEqual(render("foo.py", ["not a string"]),
- b"""s = "['not a string']"\n""")
+ return loader.load(template).generate(py_escape=py_escape, name=name)
+
+ self.assertEqual(render("foo.py", "<html>"), b"s = '<html>'\n")
+ self.assertEqual(render("foo.py", "';sys.exit()"), b"""s = "';sys.exit()"\n""")
+ self.assertEqual(
+ render("foo.py", ["not a string"]), b"""s = "['not a string']"\n"""
+ )
def test_manual_minimize_whitespace(self):
# Whitespace including newlines is allowed within template tags
# and directives, and this is one way to avoid long lines while
# keeping extra whitespace out of the rendered output.
- loader = DictLoader({'foo.txt': """\
+ loader = DictLoader(
+ {
+ "foo.txt": """\
{% for i in items
%}{% if i > 0 %}, {% end %}{#
#}{{i
}}{% end
-%}""",
- })
- self.assertEqual(loader.load("foo.txt").generate(items=range(5)),
- b"0, 1, 2, 3, 4")
+%}"""
+ }
+ )
+ self.assertEqual(
+ loader.load("foo.txt").generate(items=range(5)), b"0, 1, 2, 3, 4"
+ )
def test_whitespace_by_filename(self):
# Default whitespace handling depends on the template filename.
- loader = DictLoader({
- "foo.html": " \n\t\n asdf\t ",
- "bar.js": " \n\n\n\t qwer ",
- "baz.txt": "\t zxcv\n\n",
- "include.html": " {% include baz.txt %} \n ",
- "include.txt": "\t\t{% include foo.html %} ",
- })
+ loader = DictLoader(
+ {
+ "foo.html": " \n\t\n asdf\t ",
+ "bar.js": " \n\n\n\t qwer ",
+ "baz.txt": "\t zxcv\n\n",
+ "include.html": " {% include baz.txt %} \n ",
+ "include.txt": "\t\t{% include foo.html %} ",
+ }
+ )
# HTML and JS files have whitespace compressed by default.
- self.assertEqual(loader.load("foo.html").generate(),
- b"\nasdf ")
- self.assertEqual(loader.load("bar.js").generate(),
- b"\nqwer ")
+ self.assertEqual(loader.load("foo.html").generate(), b"\nasdf ")
+ self.assertEqual(loader.load("bar.js").generate(), b"\nqwer ")
# TXT files do not.
- self.assertEqual(loader.load("baz.txt").generate(),
- b"\t zxcv\n\n")
+ self.assertEqual(loader.load("baz.txt").generate(), b"\t zxcv\n\n")
# Each file maintains its own status even when included in
# a file of the other type.
- self.assertEqual(loader.load("include.html").generate(),
- b" \t zxcv\n\n\n")
- self.assertEqual(loader.load("include.txt").generate(),
- b"\t\t\nasdf ")
+ self.assertEqual(loader.load("include.html").generate(), b" \t zxcv\n\n\n")
+ self.assertEqual(loader.load("include.txt").generate(), b"\t\t\nasdf ")
def test_whitespace_by_loader(self):
- templates = {
- "foo.html": "\t\tfoo\n\n",
- "bar.txt": "\t\tbar\n\n",
- }
- loader = DictLoader(templates, whitespace='all')
+ templates = {"foo.html": "\t\tfoo\n\n", "bar.txt": "\t\tbar\n\n"}
+ loader = DictLoader(templates, whitespace="all")
self.assertEqual(loader.load("foo.html").generate(), b"\t\tfoo\n\n")
self.assertEqual(loader.load("bar.txt").generate(), b"\t\tbar\n\n")
- loader = DictLoader(templates, whitespace='single')
+ loader = DictLoader(templates, whitespace="single")
self.assertEqual(loader.load("foo.html").generate(), b" foo\n")
self.assertEqual(loader.load("bar.txt").generate(), b" bar\n")
- loader = DictLoader(templates, whitespace='oneline')
+ loader = DictLoader(templates, whitespace="oneline")
self.assertEqual(loader.load("foo.html").generate(), b" foo ")
self.assertEqual(loader.load("bar.txt").generate(), b" bar ")
def test_whitespace_directive(self):
- loader = DictLoader({
- "foo.html": """\
+ loader = DictLoader(
+ {
+ "foo.html": """\
{% whitespace oneline %}
{% for i in range(3) %}
{{ i }}
{% end %}
{% whitespace all %}
pre\tformatted
-"""})
- self.assertEqual(loader.load("foo.html").generate(),
- b" 0 1 2 \n pre\tformatted\n")
+"""
+ }
+ )
+ self.assertEqual(
+ loader.load("foo.html").generate(), b" 0 1 2 \n pre\tformatted\n"
+ )
class TemplateLoaderTest(unittest.TestCase):
# Timeout set with environment variable
self.io_loop.add_timeout(time() + 1, self.stop)
- with set_environ('ASYNC_TEST_TIMEOUT', '0.01'):
+ with set_environ("ASYNC_TEST_TIMEOUT", "0.01"):
with self.assertRaises(self.failureException):
self.wait()
return Application()
def test_fetch_segment(self):
- path = '/path'
+ path = "/path"
response = self.fetch(path)
self.assertEqual(response.request.url, self.get_url(path))
def test_fetch_full_http_url(self):
# Ensure that self.fetch() recognizes absolute urls and does
# not transform them into references to our main test server.
- path = 'http://localhost:%d/path' % self.second_port
+ path = "http://localhost:%d/path" % self.second_port
response = self.fetch(path)
self.assertEqual(response.request.url, path)
class Test(AsyncTestCase):
def test_gen(self):
yield
- test = Test('test_gen')
+
+ test = Test("test_gen")
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
- @unittest.skipIf(platform.python_implementation() == 'PyPy',
- 'pypy destructor warnings cannot be silenced')
+ @unittest.skipIf(
+ platform.python_implementation() == "PyPy",
+ "pypy destructor warnings cannot be silenced",
+ )
def test_undecorated_coroutine(self):
class Test(AsyncTestCase):
async def test_coro(self):
pass
- test = Test('test_coro')
+ test = Test("test_coro")
result = unittest.TestResult()
# Silence "RuntimeWarning: coroutine 'test_coro' was never awaited".
with warnings.catch_warnings():
- warnings.simplefilter('ignore')
+ warnings.simplefilter("ignore")
test.run(result)
self.assertEqual(len(result.errors), 1)
@unittest.skip("don't run this")
def test_gen(self):
yield
- test = Test('test_gen')
+
+ test = Test("test_gen")
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 0)
class Test(AsyncTestCase):
def test_other_return(self):
return 42
- test = Test('test_other_return')
+
+ test = Test("test_other_return")
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
class SetUpTearDown(unittest.TestCase):
def setUp(self):
- events.append('setUp')
+ events.append("setUp")
def tearDown(self):
- events.append('tearDown')
+ events.append("tearDown")
class InheritBoth(AsyncTestCase, SetUpTearDown):
def test(self):
- events.append('test')
+ events.append("test")
- InheritBoth('test').run(result)
- expected = ['setUp', 'test', 'tearDown']
+ InheritBoth("test").run(result)
+ expected = ["setUp", "test", "tearDown"]
self.assertEqual(expected, events)
except ioloop.TimeoutError:
# The stack trace should blame the add_timeout line, not just
# unrelated IOLoop/testing internals.
- self.assertIn(
- "gen.sleep(1)",
- traceback.format_exc())
+ self.assertIn("gen.sleep(1)", traceback.format_exc())
self.finished = True
yield gen.sleep(0.25)
# Uses provided timeout of 0.5 seconds, doesn't time out.
- with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
+ with set_environ("ASYNC_TEST_TIMEOUT", "0.1"):
test_long_timeout(self)
self.finished = True
yield gen.sleep(1)
# Uses environment-variable timeout of 0.1, times out.
- with set_environ('ASYNC_TEST_TIMEOUT', '0.1'):
+ with set_environ("ASYNC_TEST_TIMEOUT", "0.1"):
with self.assertRaises(ioloop.TimeoutError):
test_short_timeout(self)
def test_with_method_args(self):
@gen_test
def test_with_args(self, *args):
- self.assertEqual(args, ('test',))
+ self.assertEqual(args, ("test",))
yield gen.moment
- test_with_args(self, 'test')
+ test_with_args(self, "test")
self.finished = True
def test_with_method_kwargs(self):
@gen_test
def test_with_kwargs(self, **kwargs):
- self.assertDictEqual(kwargs, {'test': 'test'})
+ self.assertDictEqual(kwargs, {"test": "test"})
yield gen.moment
- test_with_kwargs(self, test='test')
+ test_with_kwargs(self, test="test")
self.finished = True
def test_native_coroutine(self):
@gen_test
async def test(self):
self.finished = True
+
test(self)
def test_native_coroutine_timeout(self):
self.assertIs(self.io_loop.asyncio_loop, self.new_loop)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
from tornado.web import RequestHandler, Application
try:
- from twisted.internet.defer import Deferred, inlineCallbacks, returnValue # type: ignore
+ from twisted.internet.defer import ( # type: ignore
+ Deferred,
+ inlineCallbacks,
+ returnValue,
+ )
from twisted.internet.protocol import Protocol # type: ignore
from twisted.internet.asyncioreactor import AsyncioSelectorReactor # type: ignore
from twisted.web.client import Agent, readBody # type: ignore
# Not used directly but needed for `yield deferred` to work.
import tornado.platform.twisted # noqa: F401
-skipIfNoTwisted = unittest.skipUnless(have_twisted,
- "twisted module not present")
+skipIfNoTwisted = unittest.skipUnless(have_twisted, "twisted module not present")
def save_signal_handlers():
def render_GET(self, request):
return b"Hello from twisted!"
+
site = Site(HelloResource())
- port = self.reactor.listenTCP(0, site, interface='127.0.0.1')
+ port = self.reactor.listenTCP(0, site, interface="127.0.0.1")
self.twisted_port = port.getHost().port
def start_tornado_server(self):
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello from tornado!")
- app = Application([('/', HelloHandler)],
- log_function=lambda x: None)
+
+ app = Application([("/", HelloHandler)], log_function=lambda x: None)
server = HTTPServer(app)
sock, self.tornado_port = bind_unused_port()
server.add_sockets([sock])
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
- d = client.request(b'GET', utf8(url))
+ d = client.request(b"GET", utf8(url))
class Accumulator(Protocol):
def __init__(self, finished):
finished = Deferred()
response.deliverBody(Accumulator(finished))
return finished
+
d.addCallback(callback)
def shutdown(failure):
- if hasattr(self, 'stop_loop'):
+ if hasattr(self, "stop_loop"):
self.stop_loop()
elif failure is not None:
# loop hasn't been initialized yet; try our best to
try:
failure.raiseException()
except:
- logging.error('exception before starting loop', exc_info=True)
+ logging.error("exception before starting loop", exc_info=True)
+
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
- return b''.join(chunks)
+ return b"".join(chunks)
def twisted_coroutine_fetch(self, url, runner):
body = [None]
# by reading the body in one blob instead of streaming it with
# a Protocol.
client = Agent(self.reactor)
- response = yield client.request(b'GET', utf8(url))
+ response = yield client.request(b"GET", utf8(url))
with warnings.catch_warnings():
# readBody has a buggy DeprecationWarning in Twisted 15.0:
# https://twistedmatrix.com/trac/changeset/43379
- warnings.simplefilter('ignore', category=DeprecationWarning)
+ warnings.simplefilter("ignore", category=DeprecationWarning)
body[0] = yield readBody(response)
self.stop_loop()
+
self.io_loop.add_callback(f)
runner()
return body[0]
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
- 'http://127.0.0.1:%d' % self.twisted_port, self.run_reactor)
- self.assertEqual(response.body, b'Hello from twisted!')
+ "http://127.0.0.1:%d" % self.twisted_port, self.run_reactor
+ )
+ self.assertEqual(response.body, b"Hello from twisted!")
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
- 'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
- self.assertEqual(response, b'Hello from tornado!')
+ "http://127.0.0.1:%d" % self.tornado_port, self.run_reactor
+ )
+ self.assertEqual(response, b"Hello from tornado!")
def testTornadoServerTwistedCoroutineClientReactor(self):
self.start_tornado_server()
response = self.twisted_coroutine_fetch(
- 'http://127.0.0.1:%d' % self.tornado_port, self.run_reactor)
- self.assertEqual(response, b'Hello from tornado!')
+ "http://127.0.0.1:%d" % self.tornado_port, self.run_reactor
+ )
+ self.assertEqual(response, b"Hello from tornado!")
@skipIfNoTwisted
# must have a yield even if it's unreachable.
yield
returnValue(42)
+
res = yield fn()
self.assertEqual(res, 42)
if False:
yield
1 / 0
+
with self.assertRaises(ZeroDivisionError):
yield fn()
from tornado.testing import bind_unused_port
-skipIfNonUnix = unittest.skipIf(os.name != 'posix' or sys.platform == 'cygwin',
- "non-unix platform")
+skipIfNonUnix = unittest.skipIf(
+ os.name != "posix" or sys.platform == "cygwin", "non-unix platform"
+)
# travis-ci.org runs our tests in an overworked virtual machine, which makes
# timing-related tests unreliable.
-skipOnTravis = unittest.skipIf('TRAVIS' in os.environ,
- 'timing tests unreliable on travis')
+skipOnTravis = unittest.skipIf(
+ "TRAVIS" in os.environ, "timing tests unreliable on travis"
+)
# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
-skipIfNoNetwork = unittest.skipIf('NO_NETWORK' in os.environ,
- 'network access disabled')
+skipIfNoNetwork = unittest.skipIf("NO_NETWORK" in os.environ, "network access disabled")
-skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
- 'Not CPython implementation')
+skipNotCPython = unittest.skipIf(
+ platform.python_implementation() != "CPython", "Not CPython implementation"
+)
# Used for tests affected by
# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in
# TODO: remove this after pypy3 5.8 is obsolete.
-skipPypy3V58 = unittest.skipIf(platform.python_implementation() == 'PyPy' and
- sys.version_info > (3,) and
- sys.pypy_version_info < (5, 9), # type: ignore
- 'pypy3 5.8 has buggy ssl module')
+skipPypy3V58 = unittest.skipIf(
+ platform.python_implementation() == "PyPy"
+ and sys.version_info > (3,)
+ and sys.pypy_version_info < (5, 9), # type: ignore
+ "pypy3 5.8 has buggy ssl module",
+)
def _detect_ipv6():
sock = None
try:
sock = socket.socket(socket.AF_INET6)
- sock.bind(('::1', 0))
+ sock.bind(("::1", 0))
except socket.error:
return False
finally:
return True
-skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), 'ipv6 support not present')
+skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), "ipv6 support not present")
def refusing_port():
def ignore_deprecation():
"""Context manager to ignore deprecation warnings."""
with warnings.catch_warnings():
- warnings.simplefilter('ignore', DeprecationWarning)
+ warnings.simplefilter("ignore", DeprecationWarning)
yield
import tornado.escape
from tornado.escape import utf8
from tornado.util import (
- raise_exc_info, Configurable, exec_in, ArgReplacer,
- timedelta_to_seconds, import_object, re_unescape, is_finalizing
+ raise_exc_info,
+ Configurable,
+ exec_in,
+ ArgReplacer,
+ timedelta_to_seconds,
+ import_object,
+ re_unescape,
+ is_finalizing,
)
import typing
self.checkSubclasses()
def test_config_str(self):
- TestConfigurable.configure('tornado.test.util_test.TestConfig2')
+ TestConfigurable.configure("tornado.test.util_test.TestConfig2")
obj = cast(TestConfig2, TestConfigurable())
self.assertIsInstance(obj, TestConfig2)
self.assertIs(obj.b, None)
class UnicodeLiteralTest(unittest.TestCase):
def test_unicode_escapes(self):
- self.assertEqual(utf8(u'\u00e9'), b'\xc3\xa9')
+ self.assertEqual(utf8(u"\u00e9"), b"\xc3\xa9")
class ExecInTest(unittest.TestCase):
# TODO(bdarnell): make a version of this test for one of the new
# future imports available in python 3.
- @unittest.skip('no testable future imports')
+ @unittest.skip("no testable future imports")
def test_no_inherit_future(self):
# This file has from __future__ import print_function...
f = StringIO()
- print('hello', file=f)
+ print("hello", file=f)
# ...but the template doesn't
exec_in('print >> f, "world"', dict(f=f))
- self.assertEqual(f.getvalue(), 'hello\nworld\n')
+ self.assertEqual(f.getvalue(), "hello\nworld\n")
class ArgReplacerTest(unittest.TestCase):
def setUp(self):
def function(x, y, callback=None, z=None):
pass
- self.replacer = ArgReplacer(function, 'callback')
+
+ self.replacer = ArgReplacer(function, "callback")
def test_omitted(self):
args = (1, 2)
kwargs = dict() # type: Dict[str, Any]
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
- self.assertEqual(self.replacer.replace('new', args, kwargs),
- (None, (1, 2), dict(callback='new')))
+ self.assertEqual(
+ self.replacer.replace("new", args, kwargs),
+ (None, (1, 2), dict(callback="new")),
+ )
def test_position(self):
- args = (1, 2, 'old', 3)
+ args = (1, 2, "old", 3)
kwargs = dict() # type: Dict[str, Any]
- self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
- self.assertEqual(self.replacer.replace('new', args, kwargs),
- ('old', [1, 2, 'new', 3], dict()))
+ self.assertEqual(self.replacer.get_old_value(args, kwargs), "old")
+ self.assertEqual(
+ self.replacer.replace("new", args, kwargs),
+ ("old", [1, 2, "new", 3], dict()),
+ )
def test_keyword(self):
args = (1,)
- kwargs = dict(y=2, callback='old', z=3)
- self.assertEqual(self.replacer.get_old_value(args, kwargs), 'old')
- self.assertEqual(self.replacer.replace('new', args, kwargs),
- ('old', (1,), dict(y=2, callback='new', z=3)))
+ kwargs = dict(y=2, callback="old", z=3)
+ self.assertEqual(self.replacer.get_old_value(args, kwargs), "old")
+ self.assertEqual(
+ self.replacer.replace("new", args, kwargs),
+ ("old", (1,), dict(y=2, callback="new", z=3)),
+ )
class TimedeltaToSecondsTest(unittest.TestCase):
class ImportObjectTest(unittest.TestCase):
def test_import_member(self):
- self.assertIs(import_object('tornado.escape.utf8'), utf8)
+ self.assertIs(import_object("tornado.escape.utf8"), utf8)
def test_import_member_unicode(self):
- self.assertIs(import_object(u'tornado.escape.utf8'), utf8)
+ self.assertIs(import_object(u"tornado.escape.utf8"), utf8)
def test_import_module(self):
- self.assertIs(import_object('tornado.escape'), tornado.escape)
+ self.assertIs(import_object("tornado.escape"), tornado.escape)
def test_import_module_unicode(self):
# The internal implementation of __import__ differs depending on
# whether the thing being imported is a module or not.
# This variant requires a byte string in python 2.
- self.assertIs(import_object(u'tornado.escape'), tornado.escape)
+ self.assertIs(import_object(u"tornado.escape"), tornado.escape)
class ReUnescapeTest(unittest.TestCase):
def test_re_unescape(self):
- test_strings = (
- '/favicon.ico',
- 'index.html',
- 'Hello, World!',
- '!$@#%;',
- )
+ test_strings = ("/favicon.ico", "index.html", "Hello, World!", "!$@#%;")
for string in test_strings:
self.assertEqual(string, re_unescape(re.escape(string)))
def test_re_unescape_raises_error_on_invalid_input(self):
with self.assertRaises(ValueError):
- re_unescape('\\d')
+ re_unescape("\\d")
with self.assertRaises(ValueError):
- re_unescape('\\b')
+ re_unescape("\\b")
with self.assertRaises(ValueError):
- re_unescape('\\Z')
+ re_unescape("\\Z")
class IsFinalizingTest(unittest.TestCase):
from tornado.concurrent import Future
from tornado import gen
-from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring # noqa: E501
+from tornado.escape import (
+ json_decode,
+ utf8,
+ to_unicode,
+ recursive_unicode,
+ native_str,
+ to_basestring,
+) # noqa: E501
from tornado.httpclient import HTTPClientError
from tornado.httputil import format_timestamp
from tornado.iostream import IOStream
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
from tornado.util import ObjectDict, unicode_type
from tornado.web import (
- Application, RequestHandler, StaticFileHandler, RedirectHandler as WebRedirectHandler,
- HTTPError, MissingArgumentError, ErrorHandler, authenticated, url,
- _create_signature_v1, create_signed_value, decode_signed_value, get_signature_key_version,
- UIModule, Finish, stream_request_body, removeslash, addslash, GZipContentEncoding,
+ Application,
+ RequestHandler,
+ StaticFileHandler,
+ RedirectHandler as WebRedirectHandler,
+ HTTPError,
+ MissingArgumentError,
+ ErrorHandler,
+ authenticated,
+ url,
+ _create_signature_v1,
+ create_signed_value,
+ decode_signed_value,
+ get_signature_key_version,
+ UIModule,
+ Finish,
+ stream_request_body,
+ removeslash,
+ addslash,
+ GZipContentEncoding,
)
import binascii
Override get_handlers and get_app_kwargs instead of get_app.
This class is deprecated since WSGI mode is no longer supported.
"""
+
def get_app(self):
self.app = Application(self.get_handlers(), **self.get_app_kwargs())
return self.app
To use, define a nested class named ``Handler``.
"""
+
def get_handlers(self):
- return [('/', self.Handler)]
+ return [("/", self.Handler)]
class HelloHandler(RequestHandler):
def get(self):
- self.write('hello')
+ self.write("hello")
class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
- def __init__(self, cookie_secret='0123456789', key_version=None):
+ def __init__(self, cookie_secret="0123456789", key_version=None):
# don't call super.__init__
self._cookies = {} # type: typing.Dict[str, bytes]
if key_version is None:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
else:
- self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
- key_version=key_version))
+ self.application = ObjectDict(
+ settings=dict(cookie_secret=cookie_secret, key_version=key_version)
+ )
def get_cookie(self, name):
return self._cookies.get(name)
class SecureCookieV1Test(unittest.TestCase):
def test_round_trip(self):
handler = CookieTestRequestHandler()
- handler.set_secure_cookie('foo', b'bar', version=1)
- self.assertEqual(handler.get_secure_cookie('foo', min_version=1),
- b'bar')
+ handler.set_secure_cookie("foo", b"bar", version=1)
+ self.assertEqual(handler.get_secure_cookie("foo", min_version=1), b"bar")
def test_cookie_tampering_future_timestamp(self):
handler = CookieTestRequestHandler()
# this string base64-encodes to '12345678'
- handler.set_secure_cookie('foo', binascii.a2b_hex(b'd76df8e7aefc'),
- version=1)
- cookie = handler._cookies['foo']
- match = re.match(br'12345678\|([0-9]+)\|([0-9a-f]+)', cookie)
+ handler.set_secure_cookie("foo", binascii.a2b_hex(b"d76df8e7aefc"), version=1)
+ cookie = handler._cookies["foo"]
+ match = re.match(br"12345678\|([0-9]+)\|([0-9a-f]+)", cookie)
assert match is not None
timestamp = match.group(1)
sig = match.group(2)
self.assertEqual(
- _create_signature_v1(handler.application.settings["cookie_secret"],
- 'foo', '12345678', timestamp),
- sig)
+ _create_signature_v1(
+ handler.application.settings["cookie_secret"],
+ "foo",
+ "12345678",
+ timestamp,
+ ),
+ sig,
+ )
# shifting digits from payload to timestamp doesn't alter signature
# (this is not desirable behavior, just confirming that that's how it
# works)
self.assertEqual(
- _create_signature_v1(handler.application.settings["cookie_secret"],
- 'foo', '1234', b'5678' + timestamp),
- sig)
+ _create_signature_v1(
+ handler.application.settings["cookie_secret"],
+ "foo",
+ "1234",
+ b"5678" + timestamp,
+ ),
+ sig,
+ )
# tamper with the cookie
- handler._cookies['foo'] = utf8('1234|5678%s|%s' % (
- to_basestring(timestamp), to_basestring(sig)))
+ handler._cookies["foo"] = utf8(
+ "1234|5678%s|%s" % (to_basestring(timestamp), to_basestring(sig))
+ )
# it gets rejected
with ExpectLog(gen_log, "Cookie timestamp in future"):
- self.assertTrue(
- handler.get_secure_cookie('foo', min_version=1) is None)
+ self.assertTrue(handler.get_secure_cookie("foo", min_version=1) is None)
def test_arbitrary_bytes(self):
# Secure cookies accept arbitrary data (which is base64 encoded).
# Note that normal cookies accept only a subset of ascii.
handler = CookieTestRequestHandler()
- handler.set_secure_cookie('foo', b'\xe9', version=1)
- self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')
+ handler.set_secure_cookie("foo", b"\xe9", version=1)
+ self.assertEqual(handler.get_secure_cookie("foo", min_version=1), b"\xe9")
# See SignedValueTest below for more.
class SecureCookieV2Test(unittest.TestCase):
- KEY_VERSIONS = {
- 0: 'ajklasdf0ojaisdf',
- 1: 'aslkjasaolwkjsdf'
- }
+ KEY_VERSIONS = {0: "ajklasdf0ojaisdf", 1: "aslkjasaolwkjsdf"}
def test_round_trip(self):
handler = CookieTestRequestHandler()
- handler.set_secure_cookie('foo', b'bar', version=2)
- self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')
+ handler.set_secure_cookie("foo", b"bar", version=2)
+ self.assertEqual(handler.get_secure_cookie("foo", min_version=2), b"bar")
def test_key_version_roundtrip(self):
- handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
- key_version=0)
- handler.set_secure_cookie('foo', b'bar')
- self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
+ handler = CookieTestRequestHandler(
+ cookie_secret=self.KEY_VERSIONS, key_version=0
+ )
+ handler.set_secure_cookie("foo", b"bar")
+ self.assertEqual(handler.get_secure_cookie("foo"), b"bar")
def test_key_version_roundtrip_differing_version(self):
- handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
- key_version=1)
- handler.set_secure_cookie('foo', b'bar')
- self.assertEqual(handler.get_secure_cookie('foo'), b'bar')
+ handler = CookieTestRequestHandler(
+ cookie_secret=self.KEY_VERSIONS, key_version=1
+ )
+ handler.set_secure_cookie("foo", b"bar")
+ self.assertEqual(handler.get_secure_cookie("foo"), b"bar")
def test_key_version_increment_version(self):
- handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
- key_version=0)
- handler.set_secure_cookie('foo', b'bar')
- new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
- key_version=1)
+ handler = CookieTestRequestHandler(
+ cookie_secret=self.KEY_VERSIONS, key_version=0
+ )
+ handler.set_secure_cookie("foo", b"bar")
+ new_handler = CookieTestRequestHandler(
+ cookie_secret=self.KEY_VERSIONS, key_version=1
+ )
new_handler._cookies = handler._cookies
- self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')
+ self.assertEqual(new_handler.get_secure_cookie("foo"), b"bar")
def test_key_version_invalidate_version(self):
- handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
- key_version=0)
- handler.set_secure_cookie('foo', b'bar')
+ handler = CookieTestRequestHandler(
+ cookie_secret=self.KEY_VERSIONS, key_version=0
+ )
+ handler.set_secure_cookie("foo", b"bar")
new_key_versions = self.KEY_VERSIONS.copy()
new_key_versions.pop(0)
- new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
- key_version=1)
+ new_handler = CookieTestRequestHandler(
+ cookie_secret=new_key_versions, key_version=1
+ )
new_handler._cookies = handler._cookies
- self.assertEqual(new_handler.get_secure_cookie('foo'), None)
+ self.assertEqual(new_handler.get_secure_cookie("foo"), None)
class FinalReturnTest(WebTestCase):
class RenderHandler(RequestHandler):
def create_template_loader(self, path):
- return DictLoader({'foo.html': 'hi'})
+ return DictLoader({"foo.html": "hi"})
@gen.coroutine
def get(self):
- test.final_return = self.render('foo.html')
+ test.final_return = self.render("foo.html")
- return [("/finish", FinishHandler),
- ("/render", RenderHandler)]
+ return [("/finish", FinishHandler), ("/render", RenderHandler)]
def get_app_kwargs(self):
- return dict(template_path='FinalReturnTest')
+ return dict(template_path="FinalReturnTest")
def test_finish_method_return_future(self):
- response = self.fetch(self.get_url('/finish'))
+ response = self.fetch(self.get_url("/finish"))
self.assertEqual(response.code, 200)
self.assertIsInstance(self.final_return, Future)
self.assertTrue(self.final_return.done())
def test_render_method_return_future(self):
- response = self.fetch(self.get_url('/render'))
+ response = self.fetch(self.get_url("/render"))
self.assertEqual(response.code, 200)
self.assertIsInstance(self.final_return, Future)
def get(self):
# unicode domain and path arguments shouldn't break things
# either (see bug #285)
- self.set_cookie("unicode_args", "blah", domain=u"foo.com",
- path=u"/foo")
+ self.set_cookie("unicode_args", "blah", domain=u"foo.com", path=u"/foo")
class SetCookieSpecialCharHandler(RequestHandler):
def get(self):
self.set_cookie("c", "1", httponly=True)
self.set_cookie("d", "1", httponly=False)
- return [("/set", SetCookieHandler),
- ("/get", GetCookieHandler),
- ("/set_domain", SetCookieDomainHandler),
- ("/special_char", SetCookieSpecialCharHandler),
- ("/set_overwrite", SetCookieOverwriteHandler),
- ("/set_max_age", SetCookieMaxAgeHandler),
- ("/set_expires_days", SetCookieExpiresDaysHandler),
- ("/set_falsy_flags", SetCookieFalsyFlags)
- ]
+ return [
+ ("/set", SetCookieHandler),
+ ("/get", GetCookieHandler),
+ ("/set_domain", SetCookieDomainHandler),
+ ("/special_char", SetCookieSpecialCharHandler),
+ ("/set_overwrite", SetCookieOverwriteHandler),
+ ("/set_max_age", SetCookieMaxAgeHandler),
+ ("/set_expires_days", SetCookieExpiresDaysHandler),
+ ("/set_falsy_flags", SetCookieFalsyFlags),
+ ]
def test_set_cookie(self):
response = self.fetch("/set")
- self.assertEqual(sorted(response.headers.get_list("Set-Cookie")),
- ["bytes=zxcv; Path=/",
- "str=asdf; Path=/",
- "unicode=qwer; Path=/",
- ])
+ self.assertEqual(
+ sorted(response.headers.get_list("Set-Cookie")),
+ ["bytes=zxcv; Path=/", "str=asdf; Path=/", "unicode=qwer; Path=/"],
+ )
def test_get_cookie(self):
response = self.fetch("/get", headers={"Cookie": "foo=bar"})
def test_set_cookie_domain(self):
response = self.fetch("/set_domain")
- self.assertEqual(response.headers.get_list("Set-Cookie"),
- ["unicode_args=blah; Domain=foo.com; Path=/foo"])
+ self.assertEqual(
+ response.headers.get_list("Set-Cookie"),
+ ["unicode_args=blah; Domain=foo.com; Path=/foo"],
+ )
def test_cookie_special_char(self):
response = self.fetch("/special_char")
self.assertEqual(headers[0], 'equals="a=b"; Path=/')
self.assertEqual(headers[1], 'quote="a\\"b"; Path=/')
# python 2.7 octal-escapes the semicolon; older versions leave it alone
- self.assertTrue(headers[2] in ('semicolon="a;b"; Path=/',
- 'semicolon="a\\073b"; Path=/'),
- headers[2])
-
- data = [('foo=a=b', 'a=b'),
- ('foo="a=b"', 'a=b'),
- ('foo="a;b"', '"a'), # even quoted, ";" is a delimiter
- ('foo=a\\073b', 'a\\073b'), # escapes only decoded in quotes
- ('foo="a\\073b"', 'a;b'),
- ('foo="a\\"b"', 'a"b'),
- ]
+ self.assertTrue(
+ headers[2] in ('semicolon="a;b"; Path=/', 'semicolon="a\\073b"; Path=/'),
+ headers[2],
+ )
+
+ data = [
+ ("foo=a=b", "a=b"),
+ ('foo="a=b"', "a=b"),
+ ('foo="a;b"', '"a'), # even quoted, ";" is a delimiter
+ ("foo=a\\073b", "a\\073b"), # escapes only decoded in quotes
+ ('foo="a\\073b"', "a;b"),
+ ('foo="a\\"b"', 'a"b'),
+ ]
for header, expected in data:
logging.debug("trying %r", header)
response = self.fetch("/get", headers={"Cookie": header})
def test_set_cookie_overwrite(self):
response = self.fetch("/set_overwrite")
headers = response.headers.get_list("Set-Cookie")
- self.assertEqual(sorted(headers),
- ["a=e; Path=/", "c=d; Domain=example.com; Path=/"])
+ self.assertEqual(
+ sorted(headers), ["a=e; Path=/", "c=d; Domain=example.com; Path=/"]
+ )
def test_set_cookie_max_age(self):
response = self.fetch("/set_max_age")
headers = response.headers.get_list("Set-Cookie")
- self.assertEqual(sorted(headers),
- ["foo=bar; Max-Age=10; Path=/"])
+ self.assertEqual(sorted(headers), ["foo=bar; Max-Age=10; Path=/"])
def test_set_cookie_expires_days(self):
response = self.fetch("/set_expires_days")
headers = sorted(response.headers.get_list("Set-Cookie"))
# The secure and httponly headers are capitalized in py35 and
# lowercase in older versions.
- self.assertEqual(headers[0].lower(), 'a=1; path=/; secure')
- self.assertEqual(headers[1].lower(), 'b=1; path=/')
- self.assertEqual(headers[2].lower(), 'c=1; httponly; path=/')
- self.assertEqual(headers[3].lower(), 'd=1; path=/')
+ self.assertEqual(headers[0].lower(), "a=1; path=/; secure")
+ self.assertEqual(headers[1].lower(), "b=1; path=/")
+ self.assertEqual(headers[2].lower(), "c=1; httponly; path=/")
+ self.assertEqual(headers[3].lower(), "d=1; path=/")
class AuthRedirectRequestHandler(RequestHandler):
class AuthRedirectTest(WebTestCase):
def get_handlers(self):
- return [('/relative', AuthRedirectRequestHandler,
- dict(login_url='/login')),
- ('/absolute', AuthRedirectRequestHandler,
- dict(login_url='http://example.com/login'))]
+ return [
+ ("/relative", AuthRedirectRequestHandler, dict(login_url="/login")),
+ (
+ "/absolute",
+ AuthRedirectRequestHandler,
+ dict(login_url="http://example.com/login"),
+ ),
+ ]
def test_relative_auth_redirect(self):
- response = self.fetch(self.get_url('/relative'),
- follow_redirects=False)
+ response = self.fetch(self.get_url("/relative"), follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertEqual(response.headers['Location'], '/login?next=%2Frelative')
+ self.assertEqual(response.headers["Location"], "/login?next=%2Frelative")
def test_absolute_auth_redirect(self):
- response = self.fetch(self.get_url('/absolute'),
- follow_redirects=False)
+ response = self.fetch(self.get_url("/absolute"), follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertTrue(re.match(
- 'http://example.com/login\?next=http%3A%2F%2F127.0.0.1%3A[0-9]+%2Fabsolute',
- response.headers['Location']), response.headers['Location'])
+ self.assertTrue(
+ re.match(
+ "http://example.com/login\?next=http%3A%2F%2F127.0.0.1%3A[0-9]+%2Fabsolute",
+ response.headers["Location"],
+ ),
+ response.headers["Location"],
+ )
class ConnectionCloseHandler(RequestHandler):
class ConnectionCloseTest(WebTestCase):
def get_handlers(self):
- return [('/', ConnectionCloseHandler, dict(test=self))]
+ return [("/", ConnectionCloseHandler, dict(test=self))]
def test_connection_close(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.wait()
def on_handler_waiting(self):
- logging.debug('handler waiting')
+ logging.debug("handler waiting")
self.stream.close()
def on_connection_close(self):
- logging.debug('connection closed')
+ logging.debug("connection closed")
self.stop()
raise Exception("incorrect type for key: %r" % type(key))
for value in self.request.arguments[key]:
if type(value) != bytes:
- raise Exception("incorrect type for value: %r" %
- type(value))
+ raise Exception("incorrect type for value: %r" % type(value))
for value in self.get_arguments(key):
if type(value) != unicode_type:
- raise Exception("incorrect type for value: %r" %
- type(value))
+ raise Exception("incorrect type for value: %r" % type(value))
for arg in path_args:
if type(arg) != unicode_type:
raise Exception("incorrect type for path arg: %r" % type(arg))
- self.write(dict(path=self.request.path,
- path_args=path_args,
- args=recursive_unicode(self.request.arguments)))
+ self.write(
+ dict(
+ path=self.request.path,
+ path_args=path_args,
+ args=recursive_unicode(self.request.arguments),
+ )
+ )
class RequestEncodingTest(WebTestCase):
def get_handlers(self):
- return [("/group/(.*)", EchoHandler),
- ("/slashes/([^/]*)/([^/]*)", EchoHandler),
- ]
+ return [("/group/(.*)", EchoHandler), ("/slashes/([^/]*)/([^/]*)", EchoHandler)]
def fetch_json(self, path):
return json_decode(self.fetch(path).body)
def test_group_question_mark(self):
# Ensure that url-encoded question marks are handled properly
- self.assertEqual(self.fetch_json('/group/%3F'),
- dict(path='/group/%3F', path_args=['?'], args={}))
- self.assertEqual(self.fetch_json('/group/%3F?%3F=%3F'),
- dict(path='/group/%3F', path_args=['?'], args={'?': ['?']}))
+ self.assertEqual(
+ self.fetch_json("/group/%3F"),
+ dict(path="/group/%3F", path_args=["?"], args={}),
+ )
+ self.assertEqual(
+ self.fetch_json("/group/%3F?%3F=%3F"),
+ dict(path="/group/%3F", path_args=["?"], args={"?": ["?"]}),
+ )
def test_group_encoding(self):
# Path components and query arguments should be decoded the same way
- self.assertEqual(self.fetch_json('/group/%C3%A9?arg=%C3%A9'),
- {u"path": u"/group/%C3%A9",
- u"path_args": [u"\u00e9"],
- u"args": {u"arg": [u"\u00e9"]}})
+ self.assertEqual(
+ self.fetch_json("/group/%C3%A9?arg=%C3%A9"),
+ {
+ u"path": u"/group/%C3%A9",
+ u"path_args": [u"\u00e9"],
+ u"args": {u"arg": [u"\u00e9"]},
+ },
+ )
def test_slashes(self):
# Slashes may be escaped to appear as a single "directory" in the path,
# but they are then unescaped when passed to the get() method.
- self.assertEqual(self.fetch_json('/slashes/foo/bar'),
- dict(path="/slashes/foo/bar",
- path_args=["foo", "bar"],
- args={}))
- self.assertEqual(self.fetch_json('/slashes/a%2Fb/c%2Fd'),
- dict(path="/slashes/a%2Fb/c%2Fd",
- path_args=["a/b", "c/d"],
- args={}))
+ self.assertEqual(
+ self.fetch_json("/slashes/foo/bar"),
+ dict(path="/slashes/foo/bar", path_args=["foo", "bar"], args={}),
+ )
+ self.assertEqual(
+ self.fetch_json("/slashes/a%2Fb/c%2Fd"),
+ dict(path="/slashes/a%2Fb/c%2Fd", path_args=["a/b", "c/d"], args={}),
+ )
def test_error(self):
# Percent signs (encoded as %25) should not mess up printf-style
def prepare(self):
self.errors = {} # type: typing.Dict[str, str]
- self.check_type('status', self.get_status(), int)
+ self.check_type("status", self.get_status(), int)
# get_argument is an exception from the general rule of using
# type str for non-body data mainly for historical reasons.
- self.check_type('argument', self.get_argument('foo'), unicode_type)
- self.check_type('cookie_key', list(self.cookies.keys())[0], str)
- self.check_type('cookie_value', list(self.cookies.values())[0].value, str)
+ self.check_type("argument", self.get_argument("foo"), unicode_type)
+ self.check_type("cookie_key", list(self.cookies.keys())[0], str)
+ self.check_type("cookie_value", list(self.cookies.values())[0].value, str)
# Secure cookies return bytes because they can contain arbitrary
# data, but regular cookies are native strings.
- if list(self.cookies.keys()) != ['asdf']:
- raise Exception("unexpected values for cookie keys: %r" %
- self.cookies.keys())
- self.check_type('get_secure_cookie', self.get_secure_cookie('asdf'), bytes)
- self.check_type('get_cookie', self.get_cookie('asdf'), str)
+ if list(self.cookies.keys()) != ["asdf"]:
+ raise Exception(
+ "unexpected values for cookie keys: %r" % self.cookies.keys()
+ )
+ self.check_type("get_secure_cookie", self.get_secure_cookie("asdf"), bytes)
+ self.check_type("get_cookie", self.get_cookie("asdf"), str)
- self.check_type('xsrf_token', self.xsrf_token, bytes)
- self.check_type('xsrf_form_html', self.xsrf_form_html(), str)
+ self.check_type("xsrf_token", self.xsrf_token, bytes)
+ self.check_type("xsrf_form_html", self.xsrf_form_html(), str)
- self.check_type('reverse_url', self.reverse_url('typecheck', 'foo'), str)
+ self.check_type("reverse_url", self.reverse_url("typecheck", "foo"), str)
- self.check_type('request_summary', self._request_summary(), str)
+ self.check_type("request_summary", self._request_summary(), str)
def get(self, path_component):
# path_component uses type unicode instead of str for consistency
# with get_argument()
- self.check_type('path_component', path_component, unicode_type)
+ self.check_type("path_component", path_component, unicode_type)
self.write(self.errors)
def post(self, path_component):
- self.check_type('path_component', path_component, unicode_type)
+ self.check_type("path_component", path_component, unicode_type)
self.write(self.errors)
def check_type(self, name, obj, expected_type):
actual_type = type(obj)
if expected_type != actual_type:
- self.errors[name] = "expected %s, got %s" % (expected_type,
- actual_type)
+ self.errors[name] = "expected %s, got %s" % (expected_type, actual_type)
class DecodeArgHandler(RequestHandler):
if type(value) != bytes:
raise Exception("unexpected type for value: %r" % type(value))
# use self.request.arguments directly to avoid recursion
- if 'encoding' in self.request.arguments:
- return value.decode(to_unicode(self.request.arguments['encoding'][0]))
+ if "encoding" in self.request.arguments:
+ return value.decode(to_unicode(self.request.arguments["encoding"][0]))
else:
return value
elif type(s) == unicode_type:
return ["unicode", s]
raise Exception("unknown type")
- self.write({'path': describe(arg),
- 'query': describe(self.get_argument("foo")),
- })
+
+ self.write({"path": describe(arg), "query": describe(self.get_argument("foo"))})
class LinkifyHandler(RequestHandler):
class RedirectHandler(RequestHandler):
def get(self):
- if self.get_argument('permanent', None) is not None:
- self.redirect('/', permanent=int(self.get_argument('permanent')))
- elif self.get_argument('status', None) is not None:
- self.redirect('/', status=int(self.get_argument('status')))
+ if self.get_argument("permanent", None) is not None:
+ self.redirect("/", permanent=int(self.get_argument("permanent")))
+ elif self.get_argument("status", None) is not None:
+ self.redirect("/", status=int(self.get_argument("status")))
else:
raise Exception("didn't get permanent or status arguments")
class GetArgumentHandler(RequestHandler):
def prepare(self):
- if self.get_argument('source', None) == 'query':
+ if self.get_argument("source", None) == "query":
method = self.get_query_argument
- elif self.get_argument('source', None) == 'body':
+ elif self.get_argument("source", None) == "body":
method = self.get_body_argument
else:
method = self.get_argument
class GetArgumentsHandler(RequestHandler):
def prepare(self):
- self.finish(dict(default=self.get_arguments("foo"),
- query=self.get_query_arguments("foo"),
- body=self.get_body_arguments("foo")))
+ self.finish(
+ dict(
+ default=self.get_arguments("foo"),
+ query=self.get_query_arguments("foo"),
+ body=self.get_body_arguments("foo"),
+ )
+ )
# This test was shared with wsgi_test.py; now the name is meaningless.
COOKIE_SECRET = "WebTest.COOKIE_SECRET"
def get_app_kwargs(self):
- loader = DictLoader({
- "linkify.html": "{% module linkify(message) %}",
- "page.html": """\
+ loader = DictLoader(
+ {
+ "linkify.html": "{% module linkify(message) %}",
+ "page.html": """\
<html><head></head><body>
{% for e in entries %}
{% module Template("entry.html", entry=e) %}
{% end %}
</body></html>""",
- "entry.html": """\
+ "entry.html": """\
{{ set_resources(embedded_css=".entry { margin-bottom: 1em; }",
embedded_javascript="js_embed()",
css_files=["/base.css", "/foo.css"],
html_head="<meta>",
html_body='<script src="/analytics.js"/>') }}
<div class="entry">...</div>""",
- })
- return dict(template_loader=loader,
- autoescape="xhtml_escape",
- cookie_secret=self.COOKIE_SECRET)
+ }
+ )
+ return dict(
+ template_loader=loader,
+ autoescape="xhtml_escape",
+ cookie_secret=self.COOKIE_SECRET,
+ )
def tearDown(self):
super(WSGISafeWebTest, self).tearDown()
def get_handlers(self):
urls = [
- url("/typecheck/(.*)", TypeCheckHandler, name='typecheck'),
- url("/decode_arg/(.*)", DecodeArgHandler, name='decode_arg'),
+ url("/typecheck/(.*)", TypeCheckHandler, name="typecheck"),
+ url("/decode_arg/(.*)", DecodeArgHandler, name="decode_arg"),
url("/decode_arg_kw/(?P<arg>.*)", DecodeArgHandler),
url("/linkify", LinkifyHandler),
url("/uimodule_resources", UIModuleResourceHandler),
url("/optional_path/(.+)?", OptionalPathHandler),
url("/multi_header", MultiHeaderHandler),
url("/redirect", RedirectHandler),
- url("/web_redirect_permanent", WebRedirectHandler, {"url": "/web_redirect_newpath"}),
- url("/web_redirect", WebRedirectHandler,
- {"url": "/web_redirect_newpath", "permanent": False}),
- url("//web_redirect_double_slash", WebRedirectHandler,
- {"url": '/web_redirect_newpath'}),
+ url(
+ "/web_redirect_permanent",
+ WebRedirectHandler,
+ {"url": "/web_redirect_newpath"},
+ ),
+ url(
+ "/web_redirect",
+ WebRedirectHandler,
+ {"url": "/web_redirect_newpath", "permanent": False},
+ ),
+ url(
+ "//web_redirect_double_slash",
+ WebRedirectHandler,
+ {"url": "/web_redirect_newpath"},
+ ),
url("/header_injection", HeaderInjectionHandler),
url("/get_argument", GetArgumentHandler),
url("/get_arguments", GetArgumentsHandler),
return json_decode(response.body)
def test_types(self):
- cookie_value = to_unicode(create_signed_value(self.COOKIE_SECRET,
- "asdf", "qwer"))
- response = self.fetch("/typecheck/asdf?foo=bar",
- headers={"Cookie": "asdf=" + cookie_value})
+ cookie_value = to_unicode(
+ create_signed_value(self.COOKIE_SECRET, "asdf", "qwer")
+ )
+ response = self.fetch(
+ "/typecheck/asdf?foo=bar", headers={"Cookie": "asdf=" + cookie_value}
+ )
data = json_decode(response.body)
self.assertEqual(data, {})
- response = self.fetch("/typecheck/asdf?foo=bar", method="POST",
- headers={"Cookie": "asdf=" + cookie_value},
- body="foo=bar")
+ response = self.fetch(
+ "/typecheck/asdf?foo=bar",
+ method="POST",
+ headers={"Cookie": "asdf=" + cookie_value},
+ body="foo=bar",
+ )
def test_decode_argument(self):
# These urls all decode to the same thing
- urls = ["/decode_arg/%C3%A9?foo=%C3%A9&encoding=utf-8",
- "/decode_arg/%E9?foo=%E9&encoding=latin1",
- "/decode_arg_kw/%E9?foo=%E9&encoding=latin1",
- ]
+ urls = [
+ "/decode_arg/%C3%A9?foo=%C3%A9&encoding=utf-8",
+ "/decode_arg/%E9?foo=%E9&encoding=latin1",
+ "/decode_arg_kw/%E9?foo=%E9&encoding=latin1",
+ ]
for req_url in urls:
response = self.fetch(req_url)
response.rethrow()
data = json_decode(response.body)
- self.assertEqual(data, {u'path': [u'unicode', u'\u00e9'],
- u'query': [u'unicode', u'\u00e9'],
- })
+ self.assertEqual(
+ data,
+ {u"path": [u"unicode", u"\u00e9"], u"query": [u"unicode", u"\u00e9"]},
+ )
response = self.fetch("/decode_arg/%C3%A9?foo=%C3%A9")
response.rethrow()
data = json_decode(response.body)
- self.assertEqual(data, {u'path': [u'bytes', u'c3a9'],
- u'query': [u'bytes', u'c3a9'],
- })
+ self.assertEqual(
+ data, {u"path": [u"bytes", u"c3a9"], u"query": [u"bytes", u"c3a9"]}
+ )
def test_decode_argument_invalid_unicode(self):
# test that invalid unicode in URLs causes 400, not 500
def test_decode_argument_plus(self):
# These urls are all equivalent.
- urls = ["/decode_arg/1%20%2B%201?foo=1%20%2B%201&encoding=utf-8",
- "/decode_arg/1%20+%201?foo=1+%2B+1&encoding=utf-8"]
+ urls = [
+ "/decode_arg/1%20%2B%201?foo=1%20%2B%201&encoding=utf-8",
+ "/decode_arg/1%20+%201?foo=1+%2B+1&encoding=utf-8",
+ ]
for req_url in urls:
response = self.fetch(req_url)
response.rethrow()
data = json_decode(response.body)
- self.assertEqual(data, {u'path': [u'unicode', u'1 + 1'],
- u'query': [u'unicode', u'1 + 1'],
- })
+ self.assertEqual(
+ data,
+ {u"path": [u"unicode", u"1 + 1"], u"query": [u"unicode", u"1 + 1"]},
+ )
def test_reverse_url(self):
- self.assertEqual(self.app.reverse_url('decode_arg', 'foo'),
- '/decode_arg/foo')
- self.assertEqual(self.app.reverse_url('decode_arg', 42),
- '/decode_arg/42')
- self.assertEqual(self.app.reverse_url('decode_arg', b'\xe9'),
- '/decode_arg/%E9')
- self.assertEqual(self.app.reverse_url('decode_arg', u'\u00e9'),
- '/decode_arg/%C3%A9')
- self.assertEqual(self.app.reverse_url('decode_arg', '1 + 1'),
- '/decode_arg/1%20%2B%201')
+ self.assertEqual(self.app.reverse_url("decode_arg", "foo"), "/decode_arg/foo")
+ self.assertEqual(self.app.reverse_url("decode_arg", 42), "/decode_arg/42")
+ self.assertEqual(self.app.reverse_url("decode_arg", b"\xe9"), "/decode_arg/%E9")
+ self.assertEqual(
+ self.app.reverse_url("decode_arg", u"\u00e9"), "/decode_arg/%C3%A9"
+ )
+ self.assertEqual(
+ self.app.reverse_url("decode_arg", "1 + 1"), "/decode_arg/1%20%2B%201"
+ )
def test_uimodule_unescaped(self):
response = self.fetch("/linkify")
- self.assertEqual(response.body,
- b"<a href=\"http://example.com\">http://example.com</a>")
+ self.assertEqual(
+ response.body, b'<a href="http://example.com">http://example.com</a>'
+ )
def test_uimodule_resources(self):
response = self.fetch("/uimodule_resources")
- self.assertEqual(response.body, b"""\
+ self.assertEqual(
+ response.body,
+ b"""\
<html><head><link href="/base.css" type="text/css" rel="stylesheet"/><link href="/foo.css" type="text/css" rel="stylesheet"/>
<style type="text/css">
.entry { margin-bottom: 1em; }
//]]>
</script>
<script src="/analytics.js"/>
-</body></html>""") # noqa: E501
+</body></html>""", # noqa: E501
+ )
def test_optional_path(self):
- self.assertEqual(self.fetch_json("/optional_path/foo"),
- {u"path": u"foo"})
- self.assertEqual(self.fetch_json("/optional_path/"),
- {u"path": None})
+ self.assertEqual(self.fetch_json("/optional_path/foo"), {u"path": u"foo"})
+ self.assertEqual(self.fetch_json("/optional_path/"), {u"path": None})
def test_multi_header(self):
response = self.fetch("/multi_header")
def test_web_redirect(self):
response = self.fetch("/web_redirect_permanent", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
+ self.assertEqual(response.headers["Location"], "/web_redirect_newpath")
response = self.fetch("/web_redirect", follow_redirects=False)
self.assertEqual(response.code, 302)
- self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
+ self.assertEqual(response.headers["Location"], "/web_redirect_newpath")
def test_web_redirect_double_slash(self):
response = self.fetch("//web_redirect_double_slash", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], '/web_redirect_newpath')
+ self.assertEqual(response.headers["Location"], "/web_redirect_newpath")
def test_header_injection(self):
response = self.fetch("/header_injection")
response = self.fetch("/get_argument?foo=bar", method="POST", body=body)
self.assertEqual(response.body, b"hello")
# In plural methods they are merged.
- response = self.fetch("/get_arguments?foo=bar",
- method="POST", body=body)
- self.assertEqual(json_decode(response.body),
- dict(default=['bar', 'hello'],
- query=['bar'],
- body=['hello']))
+ response = self.fetch("/get_arguments?foo=bar", method="POST", body=body)
+ self.assertEqual(
+ json_decode(response.body),
+ dict(default=["bar", "hello"], query=["bar"], body=["hello"]),
+ )
def test_get_query_arguments(self):
# send as a post so we can ensure the separation between query
# string and body arguments.
body = urllib.parse.urlencode(dict(foo="hello"))
- response = self.fetch("/get_argument?source=query&foo=bar",
- method="POST", body=body)
+ response = self.fetch(
+ "/get_argument?source=query&foo=bar", method="POST", body=body
+ )
self.assertEqual(response.body, b"bar")
- response = self.fetch("/get_argument?source=query&foo=",
- method="POST", body=body)
+ response = self.fetch(
+ "/get_argument?source=query&foo=", method="POST", body=body
+ )
self.assertEqual(response.body, b"")
- response = self.fetch("/get_argument?source=query",
- method="POST", body=body)
+ response = self.fetch("/get_argument?source=query", method="POST", body=body)
self.assertEqual(response.body, b"default")
def test_get_body_arguments(self):
body = urllib.parse.urlencode(dict(foo="bar"))
- response = self.fetch("/get_argument?source=body&foo=hello",
- method="POST", body=body)
+ response = self.fetch(
+ "/get_argument?source=body&foo=hello", method="POST", body=body
+ )
self.assertEqual(response.body, b"bar")
body = urllib.parse.urlencode(dict(foo=""))
- response = self.fetch("/get_argument?source=body&foo=hello",
- method="POST", body=body)
+ response = self.fetch(
+ "/get_argument?source=body&foo=hello", method="POST", body=body
+ )
self.assertEqual(response.body, b"")
body = urllib.parse.urlencode(dict())
- response = self.fetch("/get_argument?source=body&foo=hello",
- method="POST", body=body)
+ response = self.fetch(
+ "/get_argument?source=body&foo=hello", method="POST", body=body
+ )
self.assertEqual(response.body, b"default")
def test_no_gzip(self):
- response = self.fetch('/get_argument')
- self.assertNotIn('Accept-Encoding', response.headers.get('Vary', ''))
- self.assertNotIn('gzip', response.headers.get('Content-Encoding', ''))
+ response = self.fetch("/get_argument")
+ self.assertNotIn("Accept-Encoding", response.headers.get("Vary", ""))
+ self.assertNotIn("gzip", response.headers.get("Content-Encoding", ""))
class NonWSGIWebTests(WebTestCase):
def get_handlers(self):
- return [("/empty_flush", EmptyFlushCallbackHandler),
- ]
+ return [("/empty_flush", EmptyFlushCallbackHandler)]
def test_empty_flush(self):
response = self.fetch("/empty_flush")
def write_error(self, status_code, **kwargs):
raise Exception("exception in write_error")
- return [url("/default", DefaultHandler),
- url("/write_error", WriteErrorHandler),
- url("/failed_write_error", FailedWriteErrorHandler),
- ]
+ return [
+ url("/default", DefaultHandler),
+ url("/write_error", WriteErrorHandler),
+ url("/failed_write_error", FailedWriteErrorHandler),
+ ]
def test_default(self):
with ExpectLog(app_log, "Uncaught exception"):
# The expected MD5 hash of robots.txt, used in tests that call
# StaticFileHandler.get_version
robots_txt_hash = b"f71d20196d4caf35b6a670db8c70b03d"
- static_dir = os.path.join(os.path.dirname(__file__), 'static')
+ static_dir = os.path.join(os.path.dirname(__file__), "static")
def get_handlers(self):
class StaticUrlHandler(RequestHandler):
def get(self, path):
- with_v = int(self.get_argument('include_version', 1))
+ with_v = int(self.get_argument("include_version", 1))
self.write(self.static_url(path, include_version=with_v))
class AbsoluteStaticUrlHandler(StaticUrlHandler):
check_override = override_url.find(protocol, 0, protocol_length)
if do_include:
- result = (check_override == 0 and check_regular == -1)
+ result = check_override == 0 and check_regular == -1
else:
- result = (check_override == -1 and check_regular == 0)
+ result = check_override == -1 and check_regular == 0
self.write(str(result))
- return [('/static_url/(.*)', StaticUrlHandler),
- ('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
- ('/override_static_url/(.*)', OverrideStaticUrlHandler),
- ('/root_static/(.*)', StaticFileHandler, dict(path='/'))]
+ return [
+ ("/static_url/(.*)", StaticUrlHandler),
+ ("/abs_static_url/(.*)", AbsoluteStaticUrlHandler),
+ ("/override_static_url/(.*)", OverrideStaticUrlHandler),
+ ("/root_static/(.*)", StaticFileHandler, dict(path="/")),
+ ]
def get_app_kwargs(self):
- return dict(static_path=relpath('static'))
+ return dict(static_path=relpath("static"))
def test_static_files(self):
- response = self.fetch('/robots.txt')
+ response = self.fetch("/robots.txt")
self.assertTrue(b"Disallow: /" in response.body)
- response = self.fetch('/static/robots.txt')
+ response = self.fetch("/static/robots.txt")
self.assertTrue(b"Disallow: /" in response.body)
self.assertEqual(response.headers.get("Content-Type"), "text/plain")
def test_static_compressed_files(self):
response = self.fetch("/static/sample.xml.gz")
- self.assertEqual(response.headers.get("Content-Type"),
- "application/gzip")
+ self.assertEqual(response.headers.get("Content-Type"), "application/gzip")
response = self.fetch("/static/sample.xml.bz2")
- self.assertEqual(response.headers.get("Content-Type"),
- "application/octet-stream")
+ self.assertEqual(
+ response.headers.get("Content-Type"), "application/octet-stream"
+ )
# make sure the uncompressed file still has the correct type
response = self.fetch("/static/sample.xml")
- self.assertTrue(response.headers.get("Content-Type")
- in set(("text/xml", "application/xml")))
+ self.assertTrue(
+ response.headers.get("Content-Type") in set(("text/xml", "application/xml"))
+ )
def test_static_url(self):
response = self.fetch("/static_url/robots.txt")
- self.assertEqual(response.body,
- b"/static/robots.txt?v=" + self.robots_txt_hash)
+ self.assertEqual(response.body, b"/static/robots.txt?v=" + self.robots_txt_hash)
def test_absolute_static_url(self):
response = self.fetch("/abs_static_url/robots.txt")
- self.assertEqual(response.body, (
- utf8(self.get_url("/")) +
- b"static/robots.txt?v=" +
- self.robots_txt_hash
- ))
+ self.assertEqual(
+ response.body,
+ (utf8(self.get_url("/")) + b"static/robots.txt?v=" + self.robots_txt_hash),
+ )
def test_relative_version_exclusion(self):
response = self.fetch("/static_url/robots.txt?include_version=0")
def test_absolute_version_exclusion(self):
response = self.fetch("/abs_static_url/robots.txt?include_version=0")
- self.assertEqual(response.body,
- utf8(self.get_url("/") + "static/robots.txt"))
+ self.assertEqual(response.body, utf8(self.get_url("/") + "static/robots.txt"))
def test_include_host_override(self):
self._trigger_include_host_check(False)
get_response = self.fetch(*args, method="GET", **kwargs)
content_headers = set()
for h in itertools.chain(head_response.headers, get_response.headers):
- if h.startswith('Content-'):
+ if h.startswith("Content-"):
content_headers.add(h)
for h in content_headers:
- self.assertEqual(head_response.headers.get(h),
- get_response.headers.get(h),
- "%s differs between GET (%s) and HEAD (%s)" %
- (h, head_response.headers.get(h),
- get_response.headers.get(h)))
+ self.assertEqual(
+ head_response.headers.get(h),
+ get_response.headers.get(h),
+ "%s differs between GET (%s) and HEAD (%s)"
+ % (h, head_response.headers.get(h), get_response.headers.get(h)),
+ )
return get_response
def test_static_304_if_modified_since(self):
response1 = self.get_and_head("/static/robots.txt")
- response2 = self.get_and_head("/static/robots.txt", headers={
- 'If-Modified-Since': response1.headers['Last-Modified']})
+ response2 = self.get_and_head(
+ "/static/robots.txt",
+ headers={"If-Modified-Since": response1.headers["Last-Modified"]},
+ )
self.assertEqual(response2.code, 304)
- self.assertTrue('Content-Length' not in response2.headers)
- self.assertTrue('Last-Modified' not in response2.headers)
+ self.assertTrue("Content-Length" not in response2.headers)
+ self.assertTrue("Last-Modified" not in response2.headers)
def test_static_304_if_none_match(self):
response1 = self.get_and_head("/static/robots.txt")
- response2 = self.get_and_head("/static/robots.txt", headers={
- 'If-None-Match': response1.headers['Etag']})
+ response2 = self.get_and_head(
+ "/static/robots.txt", headers={"If-None-Match": response1.headers["Etag"]}
+ )
self.assertEqual(response2.code, 304)
def test_static_304_etag_modified_bug(self):
response1 = self.get_and_head("/static/robots.txt")
- response2 = self.get_and_head("/static/robots.txt", headers={
- 'If-None-Match': '"MISMATCH"',
- 'If-Modified-Since': response1.headers['Last-Modified']})
+ response2 = self.get_and_head(
+ "/static/robots.txt",
+ headers={
+ "If-None-Match": '"MISMATCH"',
+ "If-Modified-Since": response1.headers["Last-Modified"],
+ },
+ )
self.assertEqual(response2.code, 200)
def test_static_if_modified_since_pre_epoch(self):
# On windows, the functions that work with time_t do not accept
# negative values, and at least one client (processing.js) seems
# to use if-modified-since 1/1/1960 as a cache-busting technique.
- response = self.get_and_head("/static/robots.txt", headers={
- 'If-Modified-Since': 'Fri, 01 Jan 1960 00:00:00 GMT'})
+ response = self.get_and_head(
+ "/static/robots.txt",
+ headers={"If-Modified-Since": "Fri, 01 Jan 1960 00:00:00 GMT"},
+ )
self.assertEqual(response.code, 200)
def test_static_if_modified_since_time_zone(self):
# chosen just before and after the known modification time
# of the file to ensure that the right time zone is being used
# when parsing If-Modified-Since.
- stat = os.stat(relpath('static/robots.txt'))
+ stat = os.stat(relpath("static/robots.txt"))
- response = self.get_and_head('/static/robots.txt', headers={
- 'If-Modified-Since': format_timestamp(stat.st_mtime - 1)})
+ response = self.get_and_head(
+ "/static/robots.txt",
+ headers={"If-Modified-Since": format_timestamp(stat.st_mtime - 1)},
+ )
self.assertEqual(response.code, 200)
- response = self.get_and_head('/static/robots.txt', headers={
- 'If-Modified-Since': format_timestamp(stat.st_mtime + 1)})
+ response = self.get_and_head(
+ "/static/robots.txt",
+ headers={"If-Modified-Since": format_timestamp(stat.st_mtime + 1)},
+ )
self.assertEqual(response.code, 304)
def test_static_etag(self):
- response = self.get_and_head('/static/robots.txt')
- self.assertEqual(utf8(response.headers.get("Etag")),
- b'"' + self.robots_txt_hash + b'"')
+ response = self.get_and_head("/static/robots.txt")
+ self.assertEqual(
+ utf8(response.headers.get("Etag")), b'"' + self.robots_txt_hash + b'"'
+ )
def test_static_with_range(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=0-9'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=0-9"}
+ )
self.assertEqual(response.code, 206)
self.assertEqual(response.body, b"User-agent")
- self.assertEqual(utf8(response.headers.get("Etag")),
- b'"' + self.robots_txt_hash + b'"')
+ self.assertEqual(
+ utf8(response.headers.get("Etag")), b'"' + self.robots_txt_hash + b'"'
+ )
self.assertEqual(response.headers.get("Content-Length"), "10")
- self.assertEqual(response.headers.get("Content-Range"),
- "bytes 0-9/26")
+ self.assertEqual(response.headers.get("Content-Range"), "bytes 0-9/26")
def test_static_with_range_full_file(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=0-'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=0-"}
+ )
# Note: Chrome refuses to play audio if it gets an HTTP 206 in response
# to ``Range: bytes=0-`` :(
self.assertEqual(response.code, 200)
self.assertEqual(response.headers.get("Content-Range"), None)
def test_static_with_range_full_past_end(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=0-10000000'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=0-10000000"}
+ )
self.assertEqual(response.code, 200)
robots_file_path = os.path.join(self.static_dir, "robots.txt")
with open(robots_file_path) as f:
self.assertEqual(response.headers.get("Content-Range"), None)
def test_static_with_range_partial_past_end(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=1-10000000'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=1-10000000"}
+ )
self.assertEqual(response.code, 206)
robots_file_path = os.path.join(self.static_dir, "robots.txt")
with open(robots_file_path) as f:
self.assertEqual(response.headers.get("Content-Range"), "bytes 1-25/26")
def test_static_with_range_end_edge(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=22-'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=22-"}
+ )
self.assertEqual(response.body, b": /\n")
self.assertEqual(response.headers.get("Content-Length"), "4")
- self.assertEqual(response.headers.get("Content-Range"),
- "bytes 22-25/26")
+ self.assertEqual(response.headers.get("Content-Range"), "bytes 22-25/26")
def test_static_with_range_neg_end(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=-4'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=-4"}
+ )
self.assertEqual(response.body, b": /\n")
self.assertEqual(response.headers.get("Content-Length"), "4")
- self.assertEqual(response.headers.get("Content-Range"),
- "bytes 22-25/26")
+ self.assertEqual(response.headers.get("Content-Range"), "bytes 22-25/26")
def test_static_invalid_range(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'asdf'})
+ response = self.get_and_head("/static/robots.txt", headers={"Range": "asdf"})
self.assertEqual(response.code, 200)
def test_static_unsatisfiable_range_zero_suffix(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=-0'})
- self.assertEqual(response.headers.get("Content-Range"),
- "bytes */26")
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=-0"}
+ )
+ self.assertEqual(response.headers.get("Content-Range"), "bytes */26")
self.assertEqual(response.code, 416)
def test_static_unsatisfiable_range_invalid_start(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=26'})
+ response = self.get_and_head(
+ "/static/robots.txt", headers={"Range": "bytes=26"}
+ )
self.assertEqual(response.code, 416)
- self.assertEqual(response.headers.get("Content-Range"),
- "bytes */26")
+ self.assertEqual(response.headers.get("Content-Range"), "bytes */26")
def test_static_head(self):
- response = self.fetch('/static/robots.txt', method='HEAD')
+ response = self.fetch("/static/robots.txt", method="HEAD")
self.assertEqual(response.code, 200)
# No body was returned, but we did get the right content length.
- self.assertEqual(response.body, b'')
- self.assertEqual(response.headers['Content-Length'], '26')
- self.assertEqual(utf8(response.headers['Etag']),
- b'"' + self.robots_txt_hash + b'"')
+ self.assertEqual(response.body, b"")
+ self.assertEqual(response.headers["Content-Length"], "26")
+ self.assertEqual(
+ utf8(response.headers["Etag"]), b'"' + self.robots_txt_hash + b'"'
+ )
def test_static_head_range(self):
- response = self.fetch('/static/robots.txt', method='HEAD',
- headers={'Range': 'bytes=1-4'})
+ response = self.fetch(
+ "/static/robots.txt", method="HEAD", headers={"Range": "bytes=1-4"}
+ )
self.assertEqual(response.code, 206)
- self.assertEqual(response.body, b'')
- self.assertEqual(response.headers['Content-Length'], '4')
- self.assertEqual(utf8(response.headers['Etag']),
- b'"' + self.robots_txt_hash + b'"')
+ self.assertEqual(response.body, b"")
+ self.assertEqual(response.headers["Content-Length"], "4")
+ self.assertEqual(
+ utf8(response.headers["Etag"]), b'"' + self.robots_txt_hash + b'"'
+ )
def test_static_range_if_none_match(self):
- response = self.get_and_head('/static/robots.txt', headers={
- 'Range': 'bytes=1-4',
- 'If-None-Match': b'"' + self.robots_txt_hash + b'"'})
+ response = self.get_and_head(
+ "/static/robots.txt",
+ headers={
+ "Range": "bytes=1-4",
+ "If-None-Match": b'"' + self.robots_txt_hash + b'"',
+ },
+ )
self.assertEqual(response.code, 304)
- self.assertEqual(response.body, b'')
- self.assertTrue('Content-Length' not in response.headers)
- self.assertEqual(utf8(response.headers['Etag']),
- b'"' + self.robots_txt_hash + b'"')
+ self.assertEqual(response.body, b"")
+ self.assertTrue("Content-Length" not in response.headers)
+ self.assertEqual(
+ utf8(response.headers["Etag"]), b'"' + self.robots_txt_hash + b'"'
+ )
def test_static_404(self):
- response = self.get_and_head('/static/blarg')
+ response = self.get_and_head("/static/blarg")
self.assertEqual(response.code, 404)
def test_path_traversal_protection(self):
self.http_client.close()
self.http_client = SimpleAsyncHTTPClient()
with ExpectLog(gen_log, ".*not in root static directory"):
- response = self.get_and_head('/static/../static_foo.txt')
+ response = self.get_and_head("/static/../static_foo.txt")
# Attempted path traversal should result in 403, not 200
# (which means the check failed and the file was served)
# or 404 (which means that the file didn't exist and
# is probably a packaging error).
self.assertEqual(response.code, 403)
- @unittest.skipIf(os.name != 'posix', 'non-posix OS')
+ @unittest.skipIf(os.name != "posix", "non-posix OS")
def test_root_static_path(self):
# Sometimes people set the StaticFileHandler's path to '/'
# to disable Tornado's path validation (in conjunction with
# their own validation in get_absolute_path). Make sure
# that the stricter validation in 4.2.1 doesn't break them.
- path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
- 'static/robots.txt')
- response = self.get_and_head('/root_static' + urllib.parse.quote(path))
+ path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "static/robots.txt"
+ )
+ response = self.get_and_head("/root_static" + urllib.parse.quote(path))
self.assertEqual(response.code, 200)
class StaticDefaultFilenameTest(WebTestCase):
def get_app_kwargs(self):
- return dict(static_path=relpath('static'),
- static_handler_args=dict(default_filename='index.html'))
+ return dict(
+ static_path=relpath("static"),
+ static_handler_args=dict(default_filename="index.html"),
+ )
def get_handlers(self):
return []
def test_static_default_filename(self):
- response = self.fetch('/static/dir/', follow_redirects=False)
+ response = self.fetch("/static/dir/", follow_redirects=False)
self.assertEqual(response.code, 200)
- self.assertEqual(b'this is the index\n', response.body)
+ self.assertEqual(b"this is the index\n", response.body)
def test_static_default_redirect(self):
- response = self.fetch('/static/dir', follow_redirects=False)
+ response = self.fetch("/static/dir", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertTrue(response.headers['Location'].endswith('/static/dir/'))
+ self.assertTrue(response.headers["Location"].endswith("/static/dir/"))
class StaticFileWithPathTest(WebTestCase):
def get_app_kwargs(self):
- return dict(static_path=relpath('static'),
- static_handler_args=dict(default_filename='index.html'))
+ return dict(
+ static_path=relpath("static"),
+ static_handler_args=dict(default_filename="index.html"),
+ )
def get_handlers(self):
- return [("/foo/(.*)", StaticFileHandler, {
- "path": relpath("templates/"),
- })]
+ return [("/foo/(.*)", StaticFileHandler, {"path": relpath("templates/")})]
def test_serve(self):
response = self.fetch("/foo/utf8.html")
@classmethod
def make_static_url(cls, settings, path):
version_hash = cls.get_version(settings, path)
- extension_index = path.rindex('.')
+ extension_index = path.rindex(".")
before_version = path[:extension_index]
- after_version = path[(extension_index + 1):]
- return '/static/%s.%s.%s' % (before_version, version_hash,
- after_version)
+ after_version = path[(extension_index + 1) :]
+ return "/static/%s.%s.%s" % (
+ before_version,
+ version_hash,
+ after_version,
+ )
def parse_url_path(self, url_path):
- extension_index = url_path.rindex('.')
- version_index = url_path.rindex('.', 0, extension_index)
- return '%s%s' % (url_path[:version_index],
- url_path[extension_index:])
+ extension_index = url_path.rindex(".")
+ version_index = url_path.rindex(".", 0, extension_index)
+ return "%s%s" % (url_path[:version_index], url_path[extension_index:])
@classmethod
def get_absolute_path(cls, settings, path):
- return 'CustomStaticFileTest:' + path
+ return "CustomStaticFileTest:" + path
def validate_absolute_path(self, root, absolute_path):
return absolute_path
@classmethod
def get_content(self, path, start=None, end=None):
assert start is None and end is None
- if path == 'CustomStaticFileTest:foo.txt':
- return b'bar'
+ if path == "CustomStaticFileTest:foo.txt":
+ return b"bar"
raise Exception("unexpected path %r" % path)
def get_content_size(self):
- if self.absolute_path == 'CustomStaticFileTest:foo.txt':
+ if self.absolute_path == "CustomStaticFileTest:foo.txt":
return 3
raise Exception("unexpected path %r" % self.absolute_path)
return [("/static_url/(.*)", StaticUrlHandler)]
def get_app_kwargs(self):
- return dict(static_path="dummy",
- static_handler_class=self.static_handler_class)
+ return dict(static_path="dummy", static_handler_class=self.static_handler_class)
def test_serve(self):
response = self.fetch("/static/foo.42.txt")
return [("/foo", HostMatchingTest.Handler, {"reply": "wildcard"})]
def test_host_matching(self):
- self.app.add_handlers("www.example.com",
- [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})])
- self.app.add_handlers(r"www\.example\.com",
- [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
- self.app.add_handlers("www.example.com",
- [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
- self.app.add_handlers("www.e.*e.com",
- [("/baz", HostMatchingTest.Handler, {"reply": "[3]"})])
+ self.app.add_handlers(
+ "www.example.com", [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})]
+ )
+ self.app.add_handlers(
+ r"www\.example\.com", [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})]
+ )
+ self.app.add_handlers(
+ "www.example.com", [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})]
+ )
+ self.app.add_handlers(
+ "www.e.*e.com", [("/baz", HostMatchingTest.Handler, {"reply": "[3]"})]
+ )
response = self.fetch("/foo")
self.assertEqual(response.body, b"wildcard")
response = self.fetch("/baz")
self.assertEqual(response.code, 404)
- response = self.fetch("/foo", headers={'Host': 'www.example.com'})
+ response = self.fetch("/foo", headers={"Host": "www.example.com"})
self.assertEqual(response.body, b"[0]")
- response = self.fetch("/bar", headers={'Host': 'www.example.com'})
+ response = self.fetch("/bar", headers={"Host": "www.example.com"})
self.assertEqual(response.body, b"[1]")
- response = self.fetch("/baz", headers={'Host': 'www.example.com'})
+ response = self.fetch("/baz", headers={"Host": "www.example.com"})
self.assertEqual(response.body, b"[2]")
- response = self.fetch("/baz", headers={'Host': 'www.exe.com'})
+ response = self.fetch("/baz", headers={"Host": "www.exe.com"})
self.assertEqual(response.body, b"[3]")
return []
def get_app_kwargs(self):
- return {'default_host': "www.example.com"}
+ return {"default_host": "www.example.com"}
def test_default_host_matching(self):
- self.app.add_handlers("www.example.com",
- [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})])
- self.app.add_handlers(r"www\.example\.com",
- [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
- self.app.add_handlers("www.test.com",
- [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
+ self.app.add_handlers(
+ "www.example.com", [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})]
+ )
+ self.app.add_handlers(
+ r"www\.example\.com", [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})]
+ )
+ self.app.add_handlers(
+ "www.test.com", [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})]
+ )
response = self.fetch("/foo")
self.assertEqual(response.body, b"[0]")
def get(self, path):
self.write(path)
- return [("/str/(?P<path>.*)", EchoHandler),
- (u"/unicode/(?P<path>.*)", EchoHandler)]
+ return [
+ ("/str/(?P<path>.*)", EchoHandler),
+ (u"/unicode/(?P<path>.*)", EchoHandler),
+ ]
def test_named_urlspec_groups(self):
response = self.fetch("/str/foo")
self.finish()
def test_204_headers(self):
- response = self.fetch('/')
+ response = self.fetch("/")
self.assertEqual(response.code, 204)
self.assertNotIn("Content-Length", response.headers)
self.assertNotIn("Transfer-Encoding", response.headers)
self.write("hello")
def test_304_headers(self):
- response1 = self.fetch('/')
+ response1 = self.fetch("/")
self.assertEqual(response1.headers["Content-Length"], "5")
self.assertEqual(response1.headers["Content-Language"], "en_US")
- response2 = self.fetch('/', headers={
- 'If-None-Match': response1.headers["Etag"]})
+ response2 = self.fetch(
+ "/", headers={"If-None-Match": response1.headers["Etag"]}
+ )
self.assertEqual(response2.code, 304)
self.assertTrue("Content-Length" not in response2.headers)
self.assertTrue("Content-Language" not in response2.headers)
class StatusReasonTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
- reason = self.request.arguments.get('reason', [])
- self.set_status(int(self.get_argument('code')),
- reason=reason[0] if reason else None)
+ reason = self.request.arguments.get("reason", [])
+ self.set_status(
+ int(self.get_argument("code")), reason=reason[0] if reason else None
+ )
def get_http_client(self):
# simple_httpclient only: curl doesn't expose the reason string
self.write("hello")
def test_date_header(self):
- response = self.fetch('/')
- parsed = email.utils.parsedate(response.headers['Date'])
+ response = self.fetch("/")
+ parsed = email.utils.parsedate(response.headers["Date"])
assert parsed is not None
header_date = datetime.datetime(*parsed[:6])
- self.assertTrue(header_date - datetime.datetime.utcnow() <
- datetime.timedelta(seconds=2))
+ self.assertTrue(
+ header_date - datetime.datetime.utcnow() < datetime.timedelta(seconds=2)
+ )
class RaiseWithReasonTest(SimpleHandlerTestCase):
response = self.fetch("/")
self.assertEqual(response.code, 682)
self.assertEqual(response.reason, "Foo")
- self.assertIn(b'682: Foo', response.body)
+ self.assertIn(b"682: Foo", response.body)
def test_httperror_str(self):
self.assertEqual(str(HTTPError(682, reason="Foo")), "HTTP 682: Foo")
# note that if the handlers list is empty we get the default_host
# redirect fallback instead of a 404, so test with both an
# explicitly defined error handler and an implicit 404.
- return [('/error', ErrorHandler, dict(status_code=417))]
+ return [("/error", ErrorHandler, dict(status_code=417))]
def get_app_kwargs(self):
return dict(xsrf_cookies=True)
def test_error_xsrf(self):
- response = self.fetch('/error', method='POST', body='')
+ response = self.fetch("/error", method="POST", body="")
self.assertEqual(response.code, 417)
def test_404_xsrf(self):
- response = self.fetch('/404', method='POST', body='')
+ response = self.fetch("/404", method="POST", body="")
self.assertEqual(response.code, 404)
class GzipTestCase(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
- for v in self.get_arguments('vary'):
- self.add_header('Vary', v)
+ for v in self.get_arguments("vary"):
+ self.add_header("Vary", v)
# Must write at least MIN_LENGTH bytes to activate compression.
- self.write('hello world' + ('!' * GZipContentEncoding.MIN_LENGTH))
+ self.write("hello world" + ("!" * GZipContentEncoding.MIN_LENGTH))
def get_app_kwargs(self):
return dict(
- gzip=True,
- static_path=os.path.join(os.path.dirname(__file__), 'static'))
+ gzip=True, static_path=os.path.join(os.path.dirname(__file__), "static")
+ )
def assert_compressed(self, response):
# simple_httpclient renames the content-encoding header;
# curl_httpclient doesn't.
self.assertEqual(
response.headers.get(
- 'Content-Encoding',
- response.headers.get('X-Consumed-Content-Encoding')),
- 'gzip')
+ "Content-Encoding", response.headers.get("X-Consumed-Content-Encoding")
+ ),
+ "gzip",
+ )
def test_gzip(self):
- response = self.fetch('/')
+ response = self.fetch("/")
self.assert_compressed(response)
- self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
+ self.assertEqual(response.headers["Vary"], "Accept-Encoding")
def test_gzip_static(self):
# The streaming responses in StaticFileHandler have subtle
# interactions with the gzip output so test this case separately.
- response = self.fetch('/robots.txt')
+ response = self.fetch("/robots.txt")
self.assert_compressed(response)
- self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
+ self.assertEqual(response.headers["Vary"], "Accept-Encoding")
def test_gzip_not_requested(self):
- response = self.fetch('/', use_gzip=False)
- self.assertNotIn('Content-Encoding', response.headers)
- self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
+ response = self.fetch("/", use_gzip=False)
+ self.assertNotIn("Content-Encoding", response.headers)
+ self.assertEqual(response.headers["Vary"], "Accept-Encoding")
def test_vary_already_present(self):
- response = self.fetch('/?vary=Accept-Language')
+ response = self.fetch("/?vary=Accept-Language")
self.assert_compressed(response)
- self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
- ['Accept-Language', 'Accept-Encoding'])
+ self.assertEqual(
+ [s.strip() for s in response.headers["Vary"].split(",")],
+ ["Accept-Language", "Accept-Encoding"],
+ )
def test_vary_already_present_multiple(self):
# Regression test for https://github.com/tornadoweb/tornado/issues/1670
- response = self.fetch('/?vary=Accept-Language&vary=Cookie')
+ response = self.fetch("/?vary=Accept-Language&vary=Cookie")
self.assert_compressed(response)
- self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
- ['Accept-Language', 'Cookie', 'Accept-Encoding'])
+ self.assertEqual(
+ [s.strip() for s in response.headers["Vary"].split(",")],
+ ["Accept-Language", "Cookie", "Accept-Encoding"],
+ )
class PathArgsInPrepareTest(WebTestCase):
self.write(dict(args=self.path_args, kwargs=self.path_kwargs))
def get(self, path):
- assert path == 'foo'
+ assert path == "foo"
self.finish()
def get_handlers(self):
- return [('/pos/(.*)', self.Handler),
- ('/kw/(?P<path>.*)', self.Handler)]
+ return [("/pos/(.*)", self.Handler), ("/kw/(?P<path>.*)", self.Handler)]
def test_pos(self):
- response = self.fetch('/pos/foo')
+ response = self.fetch("/pos/foo")
response.rethrow()
data = json_decode(response.body)
- self.assertEqual(data, {'args': ['foo'], 'kwargs': {}})
+ self.assertEqual(data, {"args": ["foo"], "kwargs": {}})
def test_kw(self):
- response = self.fetch('/kw/foo')
+ response = self.fetch("/kw/foo")
response.rethrow()
data = json_decode(response.body)
- self.assertEqual(data, {'args': [], 'kwargs': {'path': 'foo'}})
+ self.assertEqual(data, {"args": [], "kwargs": {"path": "foo"}})
class ClearAllCookiesTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
self.clear_all_cookies()
- self.write('ok')
+ self.write("ok")
def test_clear_all_cookies(self):
- response = self.fetch('/', headers={'Cookie': 'foo=bar; baz=xyzzy'})
- set_cookies = sorted(response.headers.get_list('Set-Cookie'))
+ response = self.fetch("/", headers={"Cookie": "foo=bar; baz=xyzzy"})
+ set_cookies = sorted(response.headers.get_list("Set-Cookie"))
# Python 3.5 sends 'baz="";'; older versions use 'baz=;'
- self.assertTrue(set_cookies[0].startswith('baz=;') or
- set_cookies[0].startswith('baz="";'))
- self.assertTrue(set_cookies[1].startswith('foo=;') or
- set_cookies[1].startswith('foo="";'))
+ self.assertTrue(
+ set_cookies[0].startswith("baz=;") or set_cookies[0].startswith('baz="";')
+ )
+ self.assertTrue(
+ set_cookies[1].startswith("foo=;") or set_cookies[1].startswith('foo="";')
+ )
class PermissionError(Exception):
class ExceptionHandlerTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
- exc = self.get_argument('exc')
- if exc == 'http':
+ exc = self.get_argument("exc")
+ if exc == "http":
raise HTTPError(410, "no longer here")
- elif exc == 'zero':
+ elif exc == "zero":
1 / 0
- elif exc == 'permission':
- raise PermissionError('not allowed')
+ elif exc == "permission":
+ raise PermissionError("not allowed")
def write_error(self, status_code, **kwargs):
- if 'exc_info' in kwargs:
- typ, value, tb = kwargs['exc_info']
+ if "exc_info" in kwargs:
+ typ, value, tb = kwargs["exc_info"]
if isinstance(value, PermissionError):
self.set_status(403)
- self.write('PermissionError')
+ self.write("PermissionError")
return
RequestHandler.write_error(self, status_code, **kwargs)
def log_exception(self, typ, value, tb):
if isinstance(value, PermissionError):
- app_log.warning('custom logging for PermissionError: %s',
- value.args[0])
+ app_log.warning("custom logging for PermissionError: %s", value.args[0])
else:
RequestHandler.log_exception(self, typ, value, tb)
def test_http_error(self):
# HTTPErrors are logged as warnings with no stack trace.
# TODO: extend ExpectLog to test this more precisely
- with ExpectLog(gen_log, '.*no longer here'):
- response = self.fetch('/?exc=http')
+ with ExpectLog(gen_log, ".*no longer here"):
+ response = self.fetch("/?exc=http")
self.assertEqual(response.code, 410)
def test_unknown_error(self):
# Unknown errors are logged as errors with a stack trace.
- with ExpectLog(app_log, 'Uncaught exception'):
- response = self.fetch('/?exc=zero')
+ with ExpectLog(app_log, "Uncaught exception"):
+ response = self.fetch("/?exc=zero")
self.assertEqual(response.code, 500)
def test_known_error(self):
# log_exception can override logging behavior, and write_error
# can override the response.
- with ExpectLog(app_log,
- 'custom logging for PermissionError: not allowed'):
- response = self.fetch('/?exc=permission')
+ with ExpectLog(app_log, "custom logging for PermissionError: not allowed"):
+ response = self.fetch("/?exc=permission")
self.assertEqual(response.code, 403)
def test_buggy_log_exception(self):
# Something gets logged even though the application's
# logger is broken.
- with ExpectLog(app_log, '.*'):
- self.fetch('/')
+ with ExpectLog(app_log, ".*"):
+ self.fetch("/")
class UIMethodUIModuleTest(SimpleHandlerTestCase):
"""Test that UI methods and modules are created correctly and
associated with the handler.
"""
+
class Handler(RequestHandler):
def get(self):
- self.render('foo.html')
+ self.render("foo.html")
def value(self):
return self.get_argument("value")
def get_app_kwargs(self):
def my_ui_method(handler, x):
- return "In my_ui_method(%s) with handler value %s." % (
- x, handler.value())
+ return "In my_ui_method(%s) with handler value %s." % (x, handler.value())
class MyModule(UIModule):
def render(self, x):
return "In MyModule(%s) with handler value %s." % (
- x, self.handler.value())
+ x,
+ self.handler.value(),
+ )
- loader = DictLoader({
- 'foo.html': '{{ my_ui_method(42) }} {% module MyModule(123) %}',
- })
- return dict(template_loader=loader,
- ui_methods={'my_ui_method': my_ui_method},
- ui_modules={'MyModule': MyModule})
+ loader = DictLoader(
+ {"foo.html": "{{ my_ui_method(42) }} {% module MyModule(123) %}"}
+ )
+ return dict(
+ template_loader=loader,
+ ui_methods={"my_ui_method": my_ui_method},
+ ui_modules={"MyModule": MyModule},
+ )
def tearDown(self):
super(UIMethodUIModuleTest, self).tearDown()
RequestHandler._template_loaders.clear()
def test_ui_method(self):
- response = self.fetch('/?value=asdf')
- self.assertEqual(response.body,
- b'In my_ui_method(42) with handler value asdf. '
- b'In MyModule(123) with handler value asdf.')
+ response = self.fetch("/?value=asdf")
+ self.assertEqual(
+ response.body,
+ b"In my_ui_method(42) with handler value asdf. "
+ b"In MyModule(123) with handler value asdf.",
+ )
class GetArgumentErrorTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
try:
- self.get_argument('foo')
+ self.get_argument("foo")
self.write({})
except MissingArgumentError as e:
- self.write({'arg_name': e.arg_name,
- 'log_message': e.log_message})
+ self.write({"arg_name": e.arg_name, "log_message": e.log_message})
def test_catch_error(self):
- response = self.fetch('/')
- self.assertEqual(json_decode(response.body),
- {'arg_name': 'foo',
- 'log_message': 'Missing argument foo'})
+ response = self.fetch("/")
+ self.assertEqual(
+ json_decode(response.body),
+ {"arg_name": "foo", "log_message": "Missing argument foo"},
+ )
class SetLazyPropertiesTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def prepare(self):
- self.current_user = 'Ben'
- self.locale = locale.get('en_US')
+ self.current_user = "Ben"
+ self.locale = locale.get("en_US")
def get_user_locale(self):
raise NotImplementedError()
raise NotImplementedError()
def get(self):
- self.write('Hello %s (%s)' % (self.current_user, self.locale.code))
+ self.write("Hello %s (%s)" % (self.current_user, self.locale.code))
def test_set_properties(self):
# Ensure that current_user can be assigned to normally for apps
# that want to forgo the lazy get_current_user property
- response = self.fetch('/')
- self.assertEqual(response.body, b'Hello Ben (en_US)')
+ response = self.fetch("/")
+ self.assertEqual(response.body, b"Hello Ben (en_US)")
class GetCurrentUserTest(WebTestCase):
def get_app_kwargs(self):
class WithoutUserModule(UIModule):
def render(self):
- return ''
+ return ""
class WithUserModule(UIModule):
def render(self):
return str(self.current_user)
- loader = DictLoader({
- 'without_user.html': '',
- 'with_user.html': '{{ current_user }}',
- 'without_user_module.html': '{% module WithoutUserModule() %}',
- 'with_user_module.html': '{% module WithUserModule() %}',
- })
- return dict(template_loader=loader,
- ui_modules={'WithUserModule': WithUserModule,
- 'WithoutUserModule': WithoutUserModule})
+ loader = DictLoader(
+ {
+ "without_user.html": "",
+ "with_user.html": "{{ current_user }}",
+ "without_user_module.html": "{% module WithoutUserModule() %}",
+ "with_user_module.html": "{% module WithUserModule() %}",
+ }
+ )
+ return dict(
+ template_loader=loader,
+ ui_modules={
+ "WithUserModule": WithUserModule,
+ "WithoutUserModule": WithoutUserModule,
+ },
+ )
def tearDown(self):
super(GetCurrentUserTest, self).tearDown()
def get_current_user(self):
self.has_loaded_current_user = True
- return ''
+ return ""
class WithoutUserHandler(CurrentUserHandler):
def get(self):
- self.render_string('without_user.html')
+ self.render_string("without_user.html")
self.finish(str(self.has_loaded_current_user))
class WithUserHandler(CurrentUserHandler):
def get(self):
- self.render_string('with_user.html')
+ self.render_string("with_user.html")
self.finish(str(self.has_loaded_current_user))
class CurrentUserModuleHandler(CurrentUserHandler):
class WithoutUserModuleHandler(CurrentUserModuleHandler):
def get(self):
- self.render_string('without_user_module.html')
+ self.render_string("without_user_module.html")
self.finish(str(self.has_loaded_current_user))
class WithUserModuleHandler(CurrentUserModuleHandler):
def get(self):
- self.render_string('with_user_module.html')
+ self.render_string("with_user_module.html")
self.finish(str(self.has_loaded_current_user))
- return [('/without_user', WithoutUserHandler),
- ('/with_user', WithUserHandler),
- ('/without_user_module', WithoutUserModuleHandler),
- ('/with_user_module', WithUserModuleHandler)]
+ return [
+ ("/without_user", WithoutUserHandler),
+ ("/with_user", WithUserHandler),
+ ("/without_user_module", WithoutUserModuleHandler),
+ ("/with_user_module", WithUserModuleHandler),
+ ]
- @unittest.skip('needs fix')
+ @unittest.skip("needs fix")
def test_get_current_user_is_lazy(self):
# TODO: Make this test pass. See #820.
- response = self.fetch('/without_user')
- self.assertEqual(response.body, b'False')
+ response = self.fetch("/without_user")
+ self.assertEqual(response.body, b"False")
def test_get_current_user_works(self):
- response = self.fetch('/with_user')
- self.assertEqual(response.body, b'True')
+ response = self.fetch("/with_user")
+ self.assertEqual(response.body, b"True")
def test_get_current_user_from_ui_module_is_lazy(self):
- response = self.fetch('/without_user_module')
- self.assertEqual(response.body, b'False')
+ response = self.fetch("/without_user_module")
+ self.assertEqual(response.body, b"False")
def test_get_current_user_from_ui_module_works(self):
- response = self.fetch('/with_user_module')
- self.assertEqual(response.body, b'True')
+ response = self.fetch("/with_user_module")
+ self.assertEqual(response.body, b"True")
class UnimplementedHTTPMethodsTest(SimpleHandlerTestCase):
pass
def test_unimplemented_standard_methods(self):
- for method in ['HEAD', 'GET', 'DELETE', 'OPTIONS']:
- response = self.fetch('/', method=method)
+ for method in ["HEAD", "GET", "DELETE", "OPTIONS"]:
+ response = self.fetch("/", method=method)
self.assertEqual(response.code, 405)
- for method in ['POST', 'PUT']:
- response = self.fetch('/', method=method, body=b'')
+ for method in ["POST", "PUT"]:
+ response = self.fetch("/", method=method, body=b"")
self.assertEqual(response.code, 405)
def other(self):
# Even though this method exists, it won't get called automatically
# because it is not in SUPPORTED_METHODS.
- self.write('other')
+ self.write("other")
def test_unimplemented_patch(self):
# PATCH is recently standardized; Tornado supports it by default
# but wsgiref.validate doesn't like it.
- response = self.fetch('/', method='PATCH', body=b'')
+ response = self.fetch("/", method="PATCH", body=b"")
self.assertEqual(response.code, 405)
def test_unimplemented_other(self):
- response = self.fetch('/', method='OTHER',
- allow_nonstandard_methods=True)
+ response = self.fetch("/", method="OTHER", allow_nonstandard_methods=True)
self.assertEqual(response.code, 405)
get = delete = options = post = put = method # type: ignore
def test_standard_methods(self):
- response = self.fetch('/', method='HEAD')
- self.assertEqual(response.body, b'')
- for method in ['GET', 'DELETE', 'OPTIONS']:
- response = self.fetch('/', method=method)
+ response = self.fetch("/", method="HEAD")
+ self.assertEqual(response.body, b"")
+ for method in ["GET", "DELETE", "OPTIONS"]:
+ response = self.fetch("/", method=method)
self.assertEqual(response.body, utf8(method))
- for method in ['POST', 'PUT']:
- response = self.fetch('/', method=method, body=b'')
+ for method in ["POST", "PUT"]:
+ response = self.fetch("/", method=method, body=b"")
self.assertEqual(response.body, utf8(method))
class PatchMethodTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
- SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ('OTHER',) # type: ignore
+ SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ( # type: ignore
+ "OTHER",
+ )
def patch(self):
- self.write('patch')
+ self.write("patch")
def other(self):
- self.write('other')
+ self.write("other")
def test_patch(self):
- response = self.fetch('/', method='PATCH', body=b'')
- self.assertEqual(response.body, b'patch')
+ response = self.fetch("/", method="PATCH", body=b"")
+ self.assertEqual(response.body, b"patch")
def test_other(self):
- response = self.fetch('/', method='OTHER',
- allow_nonstandard_methods=True)
- self.assertEqual(response.body, b'other')
+ response = self.fetch("/", method="OTHER", allow_nonstandard_methods=True)
+ self.assertEqual(response.body, b"other")
class FinishInPrepareTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def prepare(self):
- self.finish('done')
+ self.finish("done")
def get(self):
# It's difficult to assert for certain that a method did not
# or will not be called in an asynchronous context, but this
# will be logged noisily if it is reached.
- raise Exception('should not reach this method')
+ raise Exception("should not reach this method")
def test_finish_in_prepare(self):
- response = self.fetch('/')
- self.assertEqual(response.body, b'done')
+ response = self.fetch("/")
+ self.assertEqual(response.body, b"done")
class Default404Test(WebTestCase):
def get_handlers(self):
# If there are no handlers at all a default redirect handler gets added.
- return [('/foo', RequestHandler)]
+ return [("/foo", RequestHandler)]
def test_404(self):
- response = self.fetch('/')
+ response = self.fetch("/")
self.assertEqual(response.code, 404)
- self.assertEqual(response.body,
- b'<html><title>404: Not Found</title>'
- b'<body>404: Not Found</body></html>')
+ self.assertEqual(
+ response.body,
+ b"<html><title>404: Not Found</title>"
+ b"<body>404: Not Found</body></html>",
+ )
class Custom404Test(WebTestCase):
def get_handlers(self):
- return [('/foo', RequestHandler)]
+ return [("/foo", RequestHandler)]
def get_app_kwargs(self):
class Custom404Handler(RequestHandler):
def get(self):
self.set_status(404)
- self.write('custom 404 response')
+ self.write("custom 404 response")
return dict(default_handler_class=Custom404Handler)
def test_404(self):
- response = self.fetch('/')
+ response = self.fetch("/")
self.assertEqual(response.code, 404)
- self.assertEqual(response.body, b'custom 404 response')
+ self.assertEqual(response.body, b"custom 404 response")
class DefaultHandlerArgumentsTest(WebTestCase):
def get_handlers(self):
- return [('/foo', RequestHandler)]
+ return [("/foo", RequestHandler)]
def get_app_kwargs(self):
- return dict(default_handler_class=ErrorHandler,
- default_handler_args=dict(status_code=403))
+ return dict(
+ default_handler_class=ErrorHandler,
+ default_handler_args=dict(status_code=403),
+ )
def test_403(self):
- response = self.fetch('/')
+ response = self.fetch("/")
self.assertEqual(response.code, 403)
class HandlerByNameTest(WebTestCase):
def get_handlers(self):
# All three are equivalent.
- return [('/hello1', HelloHandler),
- ('/hello2', 'tornado.test.web_test.HelloHandler'),
- url('/hello3', 'tornado.test.web_test.HelloHandler'),
- ]
+ return [
+ ("/hello1", HelloHandler),
+ ("/hello2", "tornado.test.web_test.HelloHandler"),
+ url("/hello3", "tornado.test.web_test.HelloHandler"),
+ ]
def test_handler_by_name(self):
- resp = self.fetch('/hello1')
- self.assertEqual(resp.body, b'hello')
- resp = self.fetch('/hello2')
- self.assertEqual(resp.body, b'hello')
- resp = self.fetch('/hello3')
- self.assertEqual(resp.body, b'hello')
+ resp = self.fetch("/hello1")
+ self.assertEqual(resp.body, b"hello")
+ resp = self.fetch("/hello2")
+ self.assertEqual(resp.body, b"hello")
+ resp = self.fetch("/hello3")
+ self.assertEqual(resp.body, b"hello")
class StreamingRequestBodyTest(WebTestCase):
super(CloseDetectionHandler, self).on_connection_close()
self.test.close_future.set_result(None)
- return [('/stream_body', StreamingBodyHandler, dict(test=self)),
- ('/early_return', EarlyReturnHandler),
- ('/close_detection', CloseDetectionHandler, dict(test=self))]
+ return [
+ ("/stream_body", StreamingBodyHandler, dict(test=self)),
+ ("/early_return", EarlyReturnHandler),
+ ("/close_detection", CloseDetectionHandler, dict(test=self)),
+ ]
def connect(self, url, connection_close):
# Use a raw connection so we can control the sending of data.
@contextlib.contextmanager
def in_method(self, method):
if self.method is not None:
- self.test.fail("entered method %s while in %s" %
- (method, self.method))
+ self.test.fail("entered method %s while in %s" % (method, self.method))
self.method = method
self.methods.append(method)
try:
def prepare(self):
# Note that asynchronous prepare() does not block data_received,
# so we don't use in_method here.
- self.methods.append('prepare')
+ self.methods.append("prepare")
yield gen.moment
@gen.coroutine
def post(self):
- with self.in_method('post'):
+ with self.in_method("post"):
yield gen.moment
self.write(dict(methods=self.methods))
# Test all the slightly different code paths for fixed, chunked, etc bodies.
def test_flow_control_fixed_body(self):
- response = self.fetch('/', body='abcdefghijklmnopqrstuvwxyz',
- method='POST')
+ response = self.fetch("/", body="abcdefghijklmnopqrstuvwxyz", method="POST")
response.rethrow()
- self.assertEqual(json_decode(response.body),
- dict(methods=['prepare', 'data_received',
- 'data_received', 'data_received',
- 'post']))
+ self.assertEqual(
+ json_decode(response.body),
+ dict(
+ methods=[
+ "prepare",
+ "data_received",
+ "data_received",
+ "data_received",
+ "post",
+ ]
+ ),
+ )
def test_flow_control_chunked_body(self):
- chunks = [b'abcd', b'efgh', b'ijkl']
+ chunks = [b"abcd", b"efgh", b"ijkl"]
@gen.coroutine
def body_producer(write):
for i in chunks:
yield write(i)
- response = self.fetch('/', body_producer=body_producer, method='POST')
+
+ response = self.fetch("/", body_producer=body_producer, method="POST")
response.rethrow()
- self.assertEqual(json_decode(response.body),
- dict(methods=['prepare', 'data_received',
- 'data_received', 'data_received',
- 'post']))
+ self.assertEqual(
+ json_decode(response.body),
+ dict(
+ methods=[
+ "prepare",
+ "data_received",
+ "data_received",
+ "data_received",
+ "post",
+ ]
+ ),
+ )
def test_flow_control_compressed_body(self):
bytesio = BytesIO()
- gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
- gzip_file.write(b'abcdefghijklmnopqrstuvwxyz')
+ gzip_file = gzip.GzipFile(mode="w", fileobj=bytesio)
+ gzip_file.write(b"abcdefghijklmnopqrstuvwxyz")
gzip_file.close()
compressed_body = bytesio.getvalue()
- response = self.fetch('/', body=compressed_body, method='POST',
- headers={'Content-Encoding': 'gzip'})
+ response = self.fetch(
+ "/",
+ body=compressed_body,
+ method="POST",
+ headers={"Content-Encoding": "gzip"},
+ )
response.rethrow()
- self.assertEqual(json_decode(response.body),
- dict(methods=['prepare', 'data_received',
- 'data_received', 'data_received',
- 'post']))
+ self.assertEqual(
+ json_decode(response.body),
+ dict(
+ methods=[
+ "prepare",
+ "data_received",
+ "data_received",
+ "data_received",
+ "post",
+ ]
+ ),
+ )
class DecoratedStreamingRequestFlowControlTest(
- BaseStreamingRequestFlowControlTest,
- WebTestCase):
+ BaseStreamingRequestFlowControlTest, WebTestCase
+):
def get_handlers(self):
class DecoratedFlowControlHandler(BaseFlowControlHandler):
@gen.coroutine
def data_received(self, data):
- with self.in_method('data_received'):
+ with self.in_method("data_received"):
yield gen.moment
- return [('/', DecoratedFlowControlHandler, dict(test=self))]
+
+ return [("/", DecoratedFlowControlHandler, dict(test=self))]
class NativeStreamingRequestFlowControlTest(
- BaseStreamingRequestFlowControlTest,
- WebTestCase):
+ BaseStreamingRequestFlowControlTest, WebTestCase
+):
def get_handlers(self):
class NativeFlowControlHandler(BaseFlowControlHandler):
async def data_received(self, data):
- with self.in_method('data_received'):
+ with self.in_method("data_received"):
import asyncio
+
await asyncio.sleep(0)
- return [('/', NativeFlowControlHandler, dict(test=self))]
+
+ return [("/", NativeFlowControlHandler, dict(test=self))]
class IncorrectContentLengthTest(SimpleHandlerTestCase):
test.server_error = e
raise
- return [('/high', TooHigh),
- ('/low', TooLow)]
+ return [("/high", TooHigh), ("/low", TooLow)]
def test_content_length_too_high(self):
# When the content-length is too high, the connection is simply
# closed without completing the response. An error is logged on
# the server.
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
- with ExpectLog(gen_log,
- "(Cannot send error response after headers written"
- "|Failed to flush partial response)"):
+ with ExpectLog(
+ gen_log,
+ "(Cannot send error response after headers written"
+ "|Failed to flush partial response)",
+ ):
with self.assertRaises(HTTPClientError):
self.fetch("/high", raise_error=True)
- self.assertEqual(str(self.server_error),
- "Tried to write 40 bytes less than Content-Length")
+ self.assertEqual(
+ str(self.server_error), "Tried to write 40 bytes less than Content-Length"
+ )
def test_content_length_too_low(self):
# When the content-length is too low, the connection is closed
# without writing the last chunk, so the client never sees the request
# complete (which would be a framing error).
with ExpectLog(app_log, "(Uncaught exception|Exception in callback)"):
- with ExpectLog(gen_log,
- "(Cannot send error response after headers written"
- "|Failed to flush partial response)"):
+ with ExpectLog(
+ gen_log,
+ "(Cannot send error response after headers written"
+ "|Failed to flush partial response)",
+ ):
with self.assertRaises(HTTPClientError):
self.fetch("/low", raise_error=True)
- self.assertEqual(str(self.server_error),
- "Tried to write more data than Content-Length")
+ self.assertEqual(
+ str(self.server_error), "Tried to write more data than Content-Length"
+ )
class ClientCloseTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
- if self.request.version.startswith('HTTP/1'):
+ if self.request.version.startswith("HTTP/1"):
# Simulate a connection closed by the client during
# request processing. The client will see an error, but the
# server should respond gracefully (without logging errors
# because we were unable to write out as many bytes as
# Content-Length said we would)
self.request.connection.stream.close()
- self.write('hello')
+ self.write("hello")
else:
# TODO: add a HTTP2-compatible version of this test.
- self.write('requires HTTP/1.x')
+ self.write("requires HTTP/1.x")
def test_client_close(self):
with self.assertRaises((HTTPClientError, unittest.SkipTest)):
- response = self.fetch('/', raise_error=True)
- if response.body == b'requires HTTP/1.x':
- self.skipTest('requires HTTP/1.x')
+ response = self.fetch("/", raise_error=True)
+ if response.body == b"requires HTTP/1.x":
+ self.skipTest("requires HTTP/1.x")
self.assertEqual(response.code, 599)
return 1300000000
def test_known_values(self):
- signed_v1 = create_signed_value(SignedValueTest.SECRET, "key", "value",
- version=1, clock=self.present)
+ signed_v1 = create_signed_value(
+ SignedValueTest.SECRET, "key", "value", version=1, clock=self.present
+ )
self.assertEqual(
- signed_v1,
- b"dmFsdWU=|1300000000|31c934969f53e48164c50768b40cbd7e2daaaa4f")
+ signed_v1, b"dmFsdWU=|1300000000|31c934969f53e48164c50768b40cbd7e2daaaa4f"
+ )
- signed_v2 = create_signed_value(SignedValueTest.SECRET, "key", "value",
- version=2, clock=self.present)
+ signed_v2 = create_signed_value(
+ SignedValueTest.SECRET, "key", "value", version=2, clock=self.present
+ )
self.assertEqual(
signed_v2,
b"2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
- b"3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152")
+ b"3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152",
+ )
- signed_default = create_signed_value(SignedValueTest.SECRET,
- "key", "value", clock=self.present)
+ signed_default = create_signed_value(
+ SignedValueTest.SECRET, "key", "value", clock=self.present
+ )
self.assertEqual(signed_default, signed_v2)
- decoded_v1 = decode_signed_value(SignedValueTest.SECRET, "key",
- signed_v1, min_version=1,
- clock=self.present)
+ decoded_v1 = decode_signed_value(
+ SignedValueTest.SECRET, "key", signed_v1, min_version=1, clock=self.present
+ )
self.assertEqual(decoded_v1, b"value")
- decoded_v2 = decode_signed_value(SignedValueTest.SECRET, "key",
- signed_v2, min_version=2,
- clock=self.present)
+ decoded_v2 = decode_signed_value(
+ SignedValueTest.SECRET, "key", signed_v2, min_version=2, clock=self.present
+ )
self.assertEqual(decoded_v2, b"value")
def test_name_swap(self):
- signed1 = create_signed_value(SignedValueTest.SECRET, "key1", "value",
- clock=self.present)
- signed2 = create_signed_value(SignedValueTest.SECRET, "key2", "value",
- clock=self.present)
+ signed1 = create_signed_value(
+ SignedValueTest.SECRET, "key1", "value", clock=self.present
+ )
+ signed2 = create_signed_value(
+ SignedValueTest.SECRET, "key2", "value", clock=self.present
+ )
# Try decoding each string with the other's "name"
- decoded1 = decode_signed_value(SignedValueTest.SECRET, "key2", signed1,
- clock=self.present)
+ decoded1 = decode_signed_value(
+ SignedValueTest.SECRET, "key2", signed1, clock=self.present
+ )
self.assertIs(decoded1, None)
- decoded2 = decode_signed_value(SignedValueTest.SECRET, "key1", signed2,
- clock=self.present)
+ decoded2 = decode_signed_value(
+ SignedValueTest.SECRET, "key1", signed2, clock=self.present
+ )
self.assertIs(decoded2, None)
def test_expired(self):
- signed = create_signed_value(SignedValueTest.SECRET, "key1", "value",
- clock=self.past)
- decoded_past = decode_signed_value(SignedValueTest.SECRET, "key1",
- signed, clock=self.past)
+ signed = create_signed_value(
+ SignedValueTest.SECRET, "key1", "value", clock=self.past
+ )
+ decoded_past = decode_signed_value(
+ SignedValueTest.SECRET, "key1", signed, clock=self.past
+ )
self.assertEqual(decoded_past, b"value")
- decoded_present = decode_signed_value(SignedValueTest.SECRET, "key1",
- signed, clock=self.present)
+ decoded_present = decode_signed_value(
+ SignedValueTest.SECRET, "key1", signed, clock=self.present
+ )
self.assertIs(decoded_present, None)
def test_payload_tampering(self):
sig = "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"
def validate(prefix):
- return (b'value' ==
- decode_signed_value(SignedValueTest.SECRET, "key",
- prefix + sig, clock=self.present))
+ return b"value" == decode_signed_value(
+ SignedValueTest.SECRET, "key", prefix + sig, clock=self.present
+ )
+
self.assertTrue(validate("2|1:0|10:1300000000|3:key|8:dmFsdWU=|"))
# Change key version
self.assertFalse(validate("2|1:1|10:1300000000|3:key|8:dmFsdWU=|"))
prefix = "2|1:0|10:1300000000|3:key|8:dmFsdWU=|"
def validate(sig):
- return (b'value' ==
- decode_signed_value(SignedValueTest.SECRET, "key",
- prefix + sig, clock=self.present))
- self.assertTrue(validate(
- "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"))
+ return b"value" == decode_signed_value(
+ SignedValueTest.SECRET, "key", prefix + sig, clock=self.present
+ )
+
+ self.assertTrue(
+ validate("3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152")
+ )
# All zeros
self.assertFalse(validate("0" * 32))
# Change one character
- self.assertFalse(validate(
- "4d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152"))
+ self.assertFalse(
+ validate("4d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e152")
+ )
# Change another character
- self.assertFalse(validate(
- "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e153"))
+ self.assertFalse(
+ validate("3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e153")
+ )
# Truncate
- self.assertFalse(validate(
- "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e15"))
+ self.assertFalse(
+ validate("3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e15")
+ )
# Lengthen
- self.assertFalse(validate(
- "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e1538"))
+ self.assertFalse(
+ validate(
+ "3d4e60b996ff9c5d5788e333a0cba6f238a22c6c0f94788870e1a9ecd482e1538"
+ )
+ )
def test_non_ascii(self):
value = b"\xe9"
- signed = create_signed_value(SignedValueTest.SECRET, "key", value,
- clock=self.present)
- decoded = decode_signed_value(SignedValueTest.SECRET, "key", signed,
- clock=self.present)
+ signed = create_signed_value(
+ SignedValueTest.SECRET, "key", value, clock=self.present
+ )
+ decoded = decode_signed_value(
+ SignedValueTest.SECRET, "key", signed, clock=self.present
+ )
self.assertEqual(value, decoded)
def test_key_versioning_read_write_default_key(self):
value = b"\xe9"
- signed = create_signed_value(SignedValueTest.SECRET_DICT,
- "key", value, clock=self.present,
- key_version=0)
- decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
- "key", signed, clock=self.present)
+ signed = create_signed_value(
+ SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=0
+ )
+ decoded = decode_signed_value(
+ SignedValueTest.SECRET_DICT, "key", signed, clock=self.present
+ )
self.assertEqual(value, decoded)
def test_key_versioning_read_write_non_default_key(self):
value = b"\xe9"
- signed = create_signed_value(SignedValueTest.SECRET_DICT,
- "key", value, clock=self.present,
- key_version=1)
- decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
- "key", signed, clock=self.present)
+ signed = create_signed_value(
+ SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=1
+ )
+ decoded = decode_signed_value(
+ SignedValueTest.SECRET_DICT, "key", signed, clock=self.present
+ )
self.assertEqual(value, decoded)
def test_key_versioning_invalid_key(self):
value = b"\xe9"
- signed = create_signed_value(SignedValueTest.SECRET_DICT,
- "key", value, clock=self.present,
- key_version=0)
+ signed = create_signed_value(
+ SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=0
+ )
newkeys = SignedValueTest.SECRET_DICT.copy()
newkeys.pop(0)
- decoded = decode_signed_value(newkeys,
- "key", signed, clock=self.present)
+ decoded = decode_signed_value(newkeys, "key", signed, clock=self.present)
self.assertEqual(None, decoded)
def test_key_version_retrieval(self):
value = b"\xe9"
- signed = create_signed_value(SignedValueTest.SECRET_DICT,
- "key", value, clock=self.present,
- key_version=1)
+ signed = create_signed_value(
+ SignedValueTest.SECRET_DICT, "key", value, clock=self.present, key_version=1
+ )
key_version = get_signature_key_version(signed)
self.assertEqual(1, key_version)
else:
headers = None
response = self.fetch(
- "/" if version is None else ("/?version=%d" % version),
- headers=headers)
+ "/" if version is None else ("/?version=%d" % version), headers=headers
+ )
response.rethrow()
return native_str(response.body)
def test_xsrf_fail_body_no_cookie(self):
with ExpectLog(gen_log, ".*XSRF cookie does not match POST"):
response = self.fetch(
- "/", method="POST",
- body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)))
+ "/",
+ method="POST",
+ body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)),
+ )
self.assertEqual(response.code, 403)
def test_xsrf_fail_argument_invalid_format(self):
with ExpectLog(gen_log, ".*'_xsrf' argument has invalid format"):
response = self.fetch(
- "/", method="POST",
+ "/",
+ method="POST",
headers=self.cookie_headers(),
- body=urllib.parse.urlencode(dict(_xsrf='3|')))
+ body=urllib.parse.urlencode(dict(_xsrf="3|")),
+ )
self.assertEqual(response.code, 403)
def test_xsrf_fail_cookie_invalid_format(self):
with ExpectLog(gen_log, ".*XSRF cookie does not match POST"):
response = self.fetch(
- "/", method="POST",
- headers=self.cookie_headers(token='3|'),
- body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)))
+ "/",
+ method="POST",
+ headers=self.cookie_headers(token="3|"),
+ body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)),
+ )
self.assertEqual(response.code, 403)
def test_xsrf_fail_cookie_no_body(self):
with ExpectLog(gen_log, ".*'_xsrf' argument missing"):
response = self.fetch(
- "/", method="POST", body=b"",
- headers=self.cookie_headers())
+ "/", method="POST", body=b"", headers=self.cookie_headers()
+ )
self.assertEqual(response.code, 403)
def test_xsrf_success_short_token(self):
response = self.fetch(
- "/", method="POST",
- body=urllib.parse.urlencode(dict(_xsrf='deadbeef')),
- headers=self.cookie_headers(token='deadbeef'))
+ "/",
+ method="POST",
+ body=urllib.parse.urlencode(dict(_xsrf="deadbeef")),
+ headers=self.cookie_headers(token="deadbeef"),
+ )
self.assertEqual(response.code, 200)
def test_xsrf_success_non_hex_token(self):
response = self.fetch(
- "/", method="POST",
- body=urllib.parse.urlencode(dict(_xsrf='xoxo')),
- headers=self.cookie_headers(token='xoxo'))
+ "/",
+ method="POST",
+ body=urllib.parse.urlencode(dict(_xsrf="xoxo")),
+ headers=self.cookie_headers(token="xoxo"),
+ )
self.assertEqual(response.code, 200)
def test_xsrf_success_post_body(self):
response = self.fetch(
- "/", method="POST",
+ "/",
+ method="POST",
body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)),
- headers=self.cookie_headers())
+ headers=self.cookie_headers(),
+ )
self.assertEqual(response.code, 200)
def test_xsrf_success_query_string(self):
response = self.fetch(
"/?" + urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)),
- method="POST", body=b"",
- headers=self.cookie_headers())
+ method="POST",
+ body=b"",
+ headers=self.cookie_headers(),
+ )
self.assertEqual(response.code, 200)
def test_xsrf_success_header(self):
- response = self.fetch("/", method="POST", body=b"",
- headers=dict({"X-Xsrftoken": self.xsrf_token}, # type: ignore
- **self.cookie_headers()))
+ response = self.fetch(
+ "/",
+ method="POST",
+ body=b"",
+ headers=dict(
+ {"X-Xsrftoken": self.xsrf_token}, # type: ignore
+ **self.cookie_headers()
+ ),
+ )
self.assertEqual(response.code, 200)
def test_distinct_tokens(self):
# Each token can be used to authenticate its own request.
for token in (self.xsrf_token, token2):
response = self.fetch(
- "/", method="POST",
+ "/",
+ method="POST",
body=urllib.parse.urlencode(dict(_xsrf=token)),
- headers=self.cookie_headers(token))
+ headers=self.cookie_headers(token),
+ )
self.assertEqual(response.code, 200)
# Sending one in the cookie and the other in the body is not allowed.
- for cookie_token, body_token in ((self.xsrf_token, token2),
- (token2, self.xsrf_token)):
- with ExpectLog(gen_log, '.*XSRF cookie does not match POST'):
+ for cookie_token, body_token in (
+ (self.xsrf_token, token2),
+ (token2, self.xsrf_token),
+ ):
+ with ExpectLog(gen_log, ".*XSRF cookie does not match POST"):
response = self.fetch(
- "/", method="POST",
+ "/",
+ method="POST",
body=urllib.parse.urlencode(dict(_xsrf=body_token)),
- headers=self.cookie_headers(cookie_token))
+ headers=self.cookie_headers(cookie_token),
+ )
self.assertEqual(response.code, 403)
def test_refresh_token(self):
# Tokens are encoded uniquely each time
tokens_seen.add(token)
response = self.fetch(
- "/", method="POST",
+ "/",
+ method="POST",
body=urllib.parse.urlencode(dict(_xsrf=self.xsrf_token)),
- headers=self.cookie_headers(token))
+ headers=self.cookie_headers(token),
+ )
self.assertEqual(response.code, 200)
self.assertEqual(len(tokens_seen), 6)
def test_versioning(self):
# Version 1 still produces distinct tokens per request.
- self.assertNotEqual(self.get_token(version=1),
- self.get_token(version=1))
+ self.assertNotEqual(self.get_token(version=1), self.get_token(version=1))
# Refreshed v1 tokens are all identical.
v1_token = self.get_token(version=1)
self.assertNotEqual(v2_token, self.get_token(v1_token))
# The tokens are cross-compatible.
- for cookie_token, body_token in ((v1_token, v2_token),
- (v2_token, v1_token)):
+ for cookie_token, body_token in ((v1_token, v2_token), (v2_token, v1_token)):
response = self.fetch(
- "/", method="POST",
+ "/",
+ method="POST",
body=urllib.parse.urlencode(dict(_xsrf=body_token)),
- headers=self.cookie_headers(cookie_token))
+ headers=self.cookie_headers(cookie_token),
+ )
self.assertEqual(response.code, 200)
self.write(self.xsrf_token)
def get_app_kwargs(self):
- return dict(xsrf_cookies=True,
- xsrf_cookie_kwargs=dict(httponly=True))
+ return dict(xsrf_cookies=True, xsrf_cookie_kwargs=dict(httponly=True))
def test_xsrf_httponly(self):
response = self.fetch("/")
- self.assertIn('httponly;', response.headers['Set-Cookie'].lower())
+ self.assertIn("httponly;", response.headers["Set-Cookie"].lower())
class FinishExceptionTest(SimpleHandlerTestCase):
class Handler(RequestHandler):
def get(self):
self.set_status(401)
- self.set_header('WWW-Authenticate', 'Basic realm="something"')
- if self.get_argument('finish_value', ''):
- raise Finish('authentication required')
+ self.set_header("WWW-Authenticate", 'Basic realm="something"')
+ if self.get_argument("finish_value", ""):
+ raise Finish("authentication required")
else:
- self.write('authentication required')
+ self.write("authentication required")
raise Finish()
def test_finish_exception(self):
- for u in ['/', '/?finish_value=1']:
+ for u in ["/", "/?finish_value=1"]:
response = self.fetch(u)
self.assertEqual(response.code, 401)
- self.assertEqual('Basic realm="something"',
- response.headers.get('WWW-Authenticate'))
- self.assertEqual(b'authentication required', response.body)
+ self.assertEqual(
+ 'Basic realm="something"', response.headers.get("WWW-Authenticate")
+ )
+ self.assertEqual(b"authentication required", response.body)
class DecoratorTest(WebTestCase):
def get(self):
pass
- return [("/removeslash/", RemoveSlashHandler),
- ("/addslash", AddSlashHandler),
- ]
+ return [("/removeslash/", RemoveSlashHandler), ("/addslash", AddSlashHandler)]
def test_removeslash(self):
response = self.fetch("/removeslash/", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], "/removeslash")
+ self.assertEqual(response.headers["Location"], "/removeslash")
response = self.fetch("/removeslash/?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], "/removeslash?foo=bar")
+ self.assertEqual(response.headers["Location"], "/removeslash?foo=bar")
def test_addslash(self):
response = self.fetch("/addslash", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], "/addslash/")
+ self.assertEqual(response.headers["Location"], "/addslash/")
response = self.fetch("/addslash?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], "/addslash/?foo=bar")
+ self.assertEqual(response.headers["Location"], "/addslash/?foo=bar")
class CacheTest(WebTestCase):
def compute_etag(self):
return self._write_buffer[0]
- return [
- ('/etag/(.*)', EtagHandler)
- ]
+ return [("/etag/(.*)", EtagHandler)]
def test_wildcard_etag(self):
computed_etag = '"xyzzy"'
- etags = '*'
+ etags = "*"
self._test_etag(computed_etag, etags, 304)
def test_strong_etag_match(self):
def _test_etag(self, computed_etag, etags, status_code):
response = self.fetch(
- '/etag/' + computed_etag,
- headers={'If-None-Match': etags}
+ "/etag/" + computed_etag, headers={"If-None-Match": etags}
)
self.assertEqual(response.code, status_code)
class ApplicationTest(AsyncTestCase):
def test_listen(self):
app = Application([])
- server = app.listen(0, address='127.0.0.1')
+ server = app.listen(0, address="127.0.0.1")
server.stop()
class URLSpecReverseTest(unittest.TestCase):
def test_reverse(self):
- self.assertEqual('/favicon.ico', url(r'/favicon\.ico', None).reverse())
- self.assertEqual('/favicon.ico', url(r'^/favicon\.ico$', None).reverse())
+ self.assertEqual("/favicon.ico", url(r"/favicon\.ico", None).reverse())
+ self.assertEqual("/favicon.ico", url(r"^/favicon\.ico$", None).reverse())
def test_non_reversible(self):
# URLSpecs are non-reversible if they include non-constant
# regex features outside capturing groups. Currently, this is
# only strictly enforced for backslash-escaped character
# classes.
- paths = [
- r'^/api/v\d+/foo/(\w+)$',
- ]
+ paths = [r"^/api/v\d+/foo/(\w+)$"]
for path in paths:
# A URLSpec can still be created even if it cannot be reversed.
url_spec = url(path, None)
try:
result = url_spec.reverse()
- self.fail("did not get expected exception when reversing %s. "
- "result: %s" % (path, result))
+ self.fail(
+ "did not get expected exception when reversing %s. "
+ "result: %s" % (path, result)
+ )
except ValueError:
pass
def test_reverse_arguments(self):
- self.assertEqual('/api/v1/foo/bar',
- url(r'^/api/v1/foo/(\w+)$', None).reverse('bar'))
+ self.assertEqual(
+ "/api/v1/foo/bar", url(r"^/api/v1/foo/(\w+)$", None).reverse("bar")
+ )
class RedirectHandlerTest(WebTestCase):
def get_handlers(self):
return [
- ('/src', WebRedirectHandler, {'url': '/dst'}),
- ('/src2', WebRedirectHandler, {'url': '/dst2?foo=bar'}),
- (r'/(.*?)/(.*?)/(.*)', WebRedirectHandler, {'url': '/{1}/{0}/{2}'})]
+ ("/src", WebRedirectHandler, {"url": "/dst"}),
+ ("/src2", WebRedirectHandler, {"url": "/dst2?foo=bar"}),
+ (r"/(.*?)/(.*?)/(.*)", WebRedirectHandler, {"url": "/{1}/{0}/{2}"}),
+ ]
def test_basic_redirect(self):
- response = self.fetch('/src', follow_redirects=False)
+ response = self.fetch("/src", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], '/dst')
+ self.assertEqual(response.headers["Location"], "/dst")
def test_redirect_with_argument(self):
- response = self.fetch('/src?foo=bar', follow_redirects=False)
+ response = self.fetch("/src?foo=bar", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], '/dst?foo=bar')
+ self.assertEqual(response.headers["Location"], "/dst?foo=bar")
def test_redirect_with_appending_argument(self):
- response = self.fetch('/src2?foo2=bar2', follow_redirects=False)
+ response = self.fetch("/src2?foo2=bar2", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], '/dst2?foo=bar&foo2=bar2')
+ self.assertEqual(response.headers["Location"], "/dst2?foo=bar&foo2=bar2")
def test_redirect_pattern(self):
- response = self.fetch('/a/b/c', follow_redirects=False)
+ response = self.fetch("/a/b/c", follow_redirects=False)
self.assertEqual(response.code, 301)
- self.assertEqual(response.headers['Location'], '/b/a/c')
+ self.assertEqual(response.headers["Location"], "/b/a/c")
raise
from tornado.websocket import (
- WebSocketHandler, websocket_connect, WebSocketError, WebSocketClosedError,
+ WebSocketHandler,
+ websocket_connect,
+ WebSocketError,
+ WebSocketClosedError,
)
try:
This allows for deterministic cleanup of the associated socket.
"""
+
def initialize(self, close_future, compression_options=None):
self.close_future = close_future
self.compression_options = compression_options
class HeaderHandler(TestWebSocketHandler):
def open(self):
methods_to_test = [
- functools.partial(self.write, 'This should not work'),
- functools.partial(self.redirect, 'http://localhost/elsewhere'),
- functools.partial(self.set_header, 'X-Test', ''),
- functools.partial(self.set_cookie, 'Chocolate', 'Chip'),
+ functools.partial(self.write, "This should not work"),
+ functools.partial(self.redirect, "http://localhost/elsewhere"),
+ functools.partial(self.set_header, "X-Test", ""),
+ functools.partial(self.set_cookie, "Chocolate", "Chip"),
functools.partial(self.set_status, 503),
self.flush,
self.finish,
raise Exception("did not get expected exception")
except RuntimeError:
pass
- self.write_message(self.request.headers.get('X-Test', ''))
+ self.write_message(self.request.headers.get("X-Test", ""))
class HeaderEchoHandler(TestWebSocketHandler):
def prepare(self):
for k, v in self.request.headers.get_all():
- if k.lower().startswith('x-test'):
+ if k.lower().startswith("x-test"):
self.set_header(k, v)
class NonWebSocketHandler(RequestHandler):
def get(self):
- self.write('ok')
+ self.write("ok")
class CloseReasonHandler(TestWebSocketHandler):
class CoroutineOnMessageHandler(TestWebSocketHandler):
def initialize(self, close_future, compression_options=None):
- super(CoroutineOnMessageHandler, self).initialize(close_future,
- compression_options)
+ super(CoroutineOnMessageHandler, self).initialize(
+ close_future, compression_options
+ )
self.sleeping = 0
@gen.coroutine
def on_message(self, message):
if self.sleeping > 0:
- self.write_message('another coroutine is already sleeping')
+ self.write_message("another coroutine is already sleeping")
self.sleeping += 1
yield gen.sleep(0.01)
self.sleeping -= 1
class RenderMessageHandler(TestWebSocketHandler):
def on_message(self, message):
- self.write_message(self.render_string('message.html', message=message))
+ self.write_message(self.render_string("message.html", message=message))
class SubprotocolHandler(TestWebSocketHandler):
if self.select_subprotocol_called:
raise Exception("select_subprotocol called twice")
self.select_subprotocol_called = True
- if 'goodproto' in subprotocols:
- return 'goodproto'
+ if "goodproto" in subprotocols:
+ return "goodproto"
return None
def open(self):
def on_message(self, message):
if not self.open_finished:
- raise Exception('on_message called before open finished')
- self.write_message('ok')
+ raise Exception("on_message called before open finished")
+ self.write_message("ok")
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, **kwargs):
ws = yield websocket_connect(
- 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
- **kwargs)
+ "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
+ )
raise gen.Return(ws)
@gen.coroutine
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future() # type: Future[None]
- return Application([
- ('/echo', EchoHandler, dict(close_future=self.close_future)),
- ('/non_ws', NonWebSocketHandler),
- ('/header', HeaderHandler, dict(close_future=self.close_future)),
- ('/header_echo', HeaderEchoHandler,
- dict(close_future=self.close_future)),
- ('/close_reason', CloseReasonHandler,
- dict(close_future=self.close_future)),
- ('/error_in_on_message', ErrorInOnMessageHandler,
- dict(close_future=self.close_future)),
- ('/async_prepare', AsyncPrepareHandler,
- dict(close_future=self.close_future)),
- ('/path_args/(.*)', PathArgsHandler,
- dict(close_future=self.close_future)),
- ('/coroutine', CoroutineOnMessageHandler,
- dict(close_future=self.close_future)),
- ('/render', RenderMessageHandler,
- dict(close_future=self.close_future)),
- ('/subprotocol', SubprotocolHandler,
- dict(close_future=self.close_future)),
- ('/open_coroutine', OpenCoroutineHandler,
- dict(close_future=self.close_future, test=self)),
- ], template_loader=DictLoader({
- 'message.html': '<b>{{ message }}</b>',
- }))
+ return Application(
+ [
+ ("/echo", EchoHandler, dict(close_future=self.close_future)),
+ ("/non_ws", NonWebSocketHandler),
+ ("/header", HeaderHandler, dict(close_future=self.close_future)),
+ (
+ "/header_echo",
+ HeaderEchoHandler,
+ dict(close_future=self.close_future),
+ ),
+ (
+ "/close_reason",
+ CloseReasonHandler,
+ dict(close_future=self.close_future),
+ ),
+ (
+ "/error_in_on_message",
+ ErrorInOnMessageHandler,
+ dict(close_future=self.close_future),
+ ),
+ (
+ "/async_prepare",
+ AsyncPrepareHandler,
+ dict(close_future=self.close_future),
+ ),
+ (
+ "/path_args/(.*)",
+ PathArgsHandler,
+ dict(close_future=self.close_future),
+ ),
+ (
+ "/coroutine",
+ CoroutineOnMessageHandler,
+ dict(close_future=self.close_future),
+ ),
+ ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
+ (
+ "/subprotocol",
+ SubprotocolHandler,
+ dict(close_future=self.close_future),
+ ),
+ (
+ "/open_coroutine",
+ OpenCoroutineHandler,
+ dict(close_future=self.close_future, test=self),
+ ),
+ ],
+ template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
+ )
def get_http_client(self):
# These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
def test_http_request(self):
# WS server, HTTP client.
- response = self.fetch('/echo')
+ response = self.fetch("/echo")
self.assertEqual(response.code, 400)
def test_missing_websocket_key(self):
- response = self.fetch('/echo',
- headers={'Connection': 'Upgrade',
- 'Upgrade': 'WebSocket',
- 'Sec-WebSocket-Version': '13'})
+ response = self.fetch(
+ "/echo",
+ headers={
+ "Connection": "Upgrade",
+ "Upgrade": "WebSocket",
+ "Sec-WebSocket-Version": "13",
+ },
+ )
self.assertEqual(response.code, 400)
def test_bad_websocket_version(self):
- response = self.fetch('/echo',
- headers={'Connection': 'Upgrade',
- 'Upgrade': 'WebSocket',
- 'Sec-WebSocket-Version': '12'})
+ response = self.fetch(
+ "/echo",
+ headers={
+ "Connection": "Upgrade",
+ "Upgrade": "WebSocket",
+ "Sec-WebSocket-Version": "12",
+ },
+ )
self.assertEqual(response.code, 426)
@gen_test
def test_websocket_gen(self):
- ws = yield self.ws_connect('/echo')
- yield ws.write_message('hello')
+ ws = yield self.ws_connect("/echo")
+ yield ws.write_message("hello")
response = yield ws.read_message()
- self.assertEqual(response, 'hello')
+ self.assertEqual(response, "hello")
yield self.close(ws)
def test_websocket_callbacks(self):
websocket_connect(
- 'ws://127.0.0.1:%d/echo' % self.get_http_port(),
- callback=self.stop)
+ "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
+ )
ws = self.wait().result()
- ws.write_message('hello')
+ ws.write_message("hello")
ws.read_message(self.stop)
response = self.wait().result()
- self.assertEqual(response, 'hello')
+ self.assertEqual(response, "hello")
self.close_future.add_done_callback(lambda f: self.stop())
ws.close()
self.wait()
@gen_test
def test_binary_message(self):
- ws = yield self.ws_connect('/echo')
- ws.write_message(b'hello \xe9', binary=True)
+ ws = yield self.ws_connect("/echo")
+ ws.write_message(b"hello \xe9", binary=True)
response = yield ws.read_message()
- self.assertEqual(response, b'hello \xe9')
+ self.assertEqual(response, b"hello \xe9")
yield self.close(ws)
@gen_test
def test_unicode_message(self):
- ws = yield self.ws_connect('/echo')
- ws.write_message(u'hello \u00e9')
+ ws = yield self.ws_connect("/echo")
+ ws.write_message(u"hello \u00e9")
response = yield ws.read_message()
- self.assertEqual(response, u'hello \u00e9')
+ self.assertEqual(response, u"hello \u00e9")
yield self.close(ws)
@gen_test
def test_render_message(self):
- ws = yield self.ws_connect('/render')
- ws.write_message('hello')
+ ws = yield self.ws_connect("/render")
+ ws.write_message("hello")
response = yield ws.read_message()
- self.assertEqual(response, '<b>hello</b>')
+ self.assertEqual(response, "<b>hello</b>")
yield self.close(ws)
@gen_test
def test_error_in_on_message(self):
- ws = yield self.ws_connect('/error_in_on_message')
- ws.write_message('hello')
+ ws = yield self.ws_connect("/error_in_on_message")
+ ws.write_message("hello")
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
- yield self.ws_connect('/notfound')
+ yield self.ws_connect("/notfound")
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
- yield self.ws_connect('/non_ws')
+ yield self.ws_connect("/non_ws")
@gen_test
def test_websocket_network_fail(self):
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
- 'ws://127.0.0.1:%d/' % port,
- connect_timeout=3600)
+ "ws://127.0.0.1:%d/" % port, connect_timeout=3600
+ )
@gen_test
def test_websocket_close_buffered_data(self):
- ws = yield websocket_connect(
- 'ws://127.0.0.1:%d/echo' % self.get_http_port())
- ws.write_message('hello')
- ws.write_message('world')
+ ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
+ ws.write_message("hello")
+ ws.write_message("world")
# Close the underlying stream.
ws.stream.close()
yield self.close_future
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
- HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
- headers={'X-Test': 'hello'}))
+ HTTPRequest(
+ "ws://127.0.0.1:%d/header" % self.get_http_port(),
+ headers={"X-Test": "hello"},
+ )
+ )
response = yield ws.read_message()
- self.assertEqual(response, 'hello')
+ self.assertEqual(response, "hello")
yield self.close(ws)
@gen_test
# Specifically, that arbitrary headers passed through websocket_connect
# can be returned.
ws = yield websocket_connect(
- HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
- headers={'X-Test-Hello': 'hello'}))
- self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
- self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
+ HTTPRequest(
+ "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
+ headers={"X-Test-Hello": "hello"},
+ )
+ )
+ self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
+ self.assertEqual(
+ ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
+ )
yield self.close(ws)
@gen_test
def test_server_close_reason(self):
- ws = yield self.ws_connect('/close_reason')
+ ws = yield self.ws_connect("/close_reason")
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
@gen_test
def test_client_close_reason(self):
- ws = yield self.ws_connect('/echo')
- ws.close(1001, 'goodbye')
+ ws = yield self.ws_connect("/echo")
+ ws.close(1001, "goodbye")
code, reason = yield self.close_future
self.assertEqual(code, 1001)
- self.assertEqual(reason, 'goodbye')
+ self.assertEqual(reason, "goodbye")
@gen_test
def test_write_after_close(self):
- ws = yield self.ws_connect('/close_reason')
+ ws = yield self.ws_connect("/close_reason")
msg = yield ws.read_message()
self.assertIs(msg, None)
with self.assertRaises(WebSocketClosedError):
- ws.write_message('hello')
+ ws.write_message("hello")
@gen_test
def test_async_prepare(self):
# Previously, an async prepare method triggered a bug that would
# result in a timeout on test shutdown (and a memory leak).
- ws = yield self.ws_connect('/async_prepare')
- ws.write_message('hello')
+ ws = yield self.ws_connect("/async_prepare")
+ ws.write_message("hello")
res = yield ws.read_message()
- self.assertEqual(res, 'hello')
+ self.assertEqual(res, "hello")
@gen_test
def test_path_args(self):
- ws = yield self.ws_connect('/path_args/hello')
+ ws = yield self.ws_connect("/path_args/hello")
res = yield ws.read_message()
- self.assertEqual(res, 'hello')
+ self.assertEqual(res, "hello")
@gen_test
def test_coroutine(self):
- ws = yield self.ws_connect('/coroutine')
+ ws = yield self.ws_connect("/coroutine")
# Send both messages immediately, coroutine must process one at a time.
- yield ws.write_message('hello1')
- yield ws.write_message('hello2')
+ yield ws.write_message("hello1")
+ yield ws.write_message("hello2")
res = yield ws.read_message()
- self.assertEqual(res, 'hello1')
+ self.assertEqual(res, "hello1")
res = yield ws.read_message()
- self.assertEqual(res, 'hello2')
+ self.assertEqual(res, "hello2")
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
- url = 'ws://127.0.0.1:%d/echo' % port
- headers = {'Origin': 'http://127.0.0.1:%d' % port}
+ url = "ws://127.0.0.1:%d/echo" % port
+ headers = {"Origin": "http://127.0.0.1:%d" % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
- ws.write_message('hello')
+ ws.write_message("hello")
response = yield ws.read_message()
- self.assertEqual(response, 'hello')
+ self.assertEqual(response, "hello")
yield self.close(ws)
@gen_test
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
- url = 'ws://127.0.0.1:%d/echo' % port
- headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
+ url = "ws://127.0.0.1:%d/echo" % port
+ headers = {"Origin": "http://127.0.0.1:%d/something" % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
- ws.write_message('hello')
+ ws.write_message("hello")
response = yield ws.read_message()
- self.assertEqual(response, 'hello')
+ self.assertEqual(response, "hello")
yield self.close(ws)
@gen_test
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
- url = 'ws://127.0.0.1:%d/echo' % port
- headers = {'Origin': '127.0.0.1:%d' % port}
+ url = "ws://127.0.0.1:%d/echo" % port
+ headers = {"Origin": "127.0.0.1:%d" % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers))
def test_check_origin_invalid(self):
port = self.get_http_port()
- url = 'ws://127.0.0.1:%d/echo' % port
+ url = "ws://127.0.0.1:%d/echo" % port
# Host is 127.0.0.1, which should not be accessible from some other
# domain
- headers = {'Origin': 'http://somewhereelse.com'}
+ headers = {"Origin": "http://somewhereelse.com"}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers))
def test_check_origin_invalid_subdomains(self):
port = self.get_http_port()
- url = 'ws://localhost:%d/echo' % port
+ url = "ws://localhost:%d/echo" % port
# Subdomains should be disallowed by default. If we could pass a
# resolver to websocket_connect we could test sibling domains as well.
- headers = {'Origin': 'http://subtenant.localhost'}
+ headers = {"Origin": "http://subtenant.localhost"}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers))
@gen_test
def test_subprotocols(self):
- ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
- self.assertEqual(ws.selected_subprotocol, 'goodproto')
+ ws = yield self.ws_connect(
+ "/subprotocol", subprotocols=["badproto", "goodproto"]
+ )
+ self.assertEqual(ws.selected_subprotocol, "goodproto")
res = yield ws.read_message()
- self.assertEqual(res, 'subprotocol=goodproto')
+ self.assertEqual(res, "subprotocol=goodproto")
yield self.close(ws)
@gen_test
def test_subprotocols_not_offered(self):
- ws = yield self.ws_connect('/subprotocol')
+ ws = yield self.ws_connect("/subprotocol")
self.assertIs(ws.selected_subprotocol, None)
res = yield ws.read_message()
- self.assertEqual(res, 'subprotocol=None')
+ self.assertEqual(res, "subprotocol=None")
yield self.close(ws)
@gen_test
def test_open_coroutine(self):
self.message_sent = Event()
- ws = yield self.ws_connect('/open_coroutine')
- yield ws.write_message('hello')
+ ws = yield self.ws_connect("/open_coroutine")
+ yield ws.write_message("hello")
self.message_sent.set()
res = yield ws.read_message()
- self.assertEqual(res, 'ok')
+ self.assertEqual(res, "ok")
yield self.close(ws)
async def on_message(self, message):
if self.sleeping > 0:
- self.write_message('another coroutine is already sleeping')
+ self.write_message("another coroutine is already sleeping")
self.sleeping += 1
await gen.sleep(0.01)
self.sleeping -= 1
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future() # type: Future[None]
- return Application([
- ('/native', NativeCoroutineOnMessageHandler,
- dict(close_future=self.close_future))])
+ return Application(
+ [
+ (
+ "/native",
+ NativeCoroutineOnMessageHandler,
+ dict(close_future=self.close_future),
+ )
+ ]
+ )
@gen_test
def test_native_coroutine(self):
- ws = yield self.ws_connect('/native')
+ ws = yield self.ws_connect("/native")
# Send both messages immediately, coroutine must process one at a time.
- yield ws.write_message('hello1')
- yield ws.write_message('hello2')
+ yield ws.write_message("hello1")
+ yield ws.write_message("hello2")
res = yield ws.read_message()
- self.assertEqual(res, 'hello1')
+ self.assertEqual(res, "hello1")
res = yield ws.read_message()
- self.assertEqual(res, 'hello2')
+ self.assertEqual(res, "hello2")
class CompressionTestMixin(object):
- MESSAGE = 'Hello world. Testing 123 123'
+ MESSAGE = "Hello world. Testing 123 123"
def get_app(self):
self.close_future = Future() # type: Future[None]
def on_message(self, message):
self.write_message(str(len(message)))
- return Application([
- ('/echo', EchoHandler, dict(
- close_future=self.close_future,
- compression_options=self.get_server_compression_options())),
- ('/limited', LimitedHandler, dict(
- close_future=self.close_future,
- compression_options=self.get_server_compression_options())),
- ])
+ return Application(
+ [
+ (
+ "/echo",
+ EchoHandler,
+ dict(
+ close_future=self.close_future,
+ compression_options=self.get_server_compression_options(),
+ ),
+ ),
+ (
+ "/limited",
+ LimitedHandler,
+ dict(
+ close_future=self.close_future,
+ compression_options=self.get_server_compression_options(),
+ ),
+ ),
+ ]
+ )
def get_server_compression_options(self):
return None
@gen_test
def test_message_sizes(self):
ws = yield self.ws_connect(
- '/echo',
- compression_options=self.get_client_compression_options())
+ "/echo", compression_options=self.get_client_compression_options()
+ )
# Send the same message three times so we can measure the
# effect of the context_takeover options.
for i in range(3):
self.assertEqual(response, self.MESSAGE)
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
- self.verify_wire_bytes(ws.protocol._wire_bytes_in,
- ws.protocol._wire_bytes_out)
+ self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
yield self.close(ws)
@gen_test
def test_size_limit(self):
ws = yield self.ws_connect(
- '/limited',
- compression_options=self.get_client_compression_options())
+ "/limited", compression_options=self.get_client_compression_options()
+ )
# Small messages pass through.
- ws.write_message('a' * 128)
+ ws.write_message("a" * 128)
response = yield ws.read_message()
- self.assertEqual(response, '128')
+ self.assertEqual(response, "128")
# This message is too big after decompression, but it compresses
# down to a size that will pass the initial checks.
- ws.write_message('a' * 2048)
+ ws.write_message("a" * 2048)
response = yield ws.read_message()
self.assertIsNone(response)
yield self.close(ws)
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
+
def verify_wire_bytes(self, bytes_in, bytes_out):
# Bytes out includes the 4-byte mask key per message.
self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):
- self.assertEqual(self.mask(b'abcd', b''), b'')
- self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
- self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
- self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
+ self.assertEqual(self.mask(b"abcd", b""), b"")
+ self.assertEqual(self.mask(b"abcd", b"b"), b"\x03")
+ self.assertEqual(self.mask(b"abcd", b"54321"), b"TVPVP")
+ self.assertEqual(self.mask(b"ZXCV", b"98765432"), b"c`t`olpd")
# Include test cases with \x00 bytes (to ensure that the C
# extension isn't depending on null-terminated strings) and
# bytes with the high bit set (to smoke out signedness issues).
- self.assertEqual(self.mask(b'\x00\x01\x02\x03',
- b'\xff\xfb\xfd\xfc\xfe\xfa'),
- b'\xff\xfa\xff\xff\xfe\xfb')
- self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc',
- b'\x00\x01\x02\x03\x04\x05'),
- b'\xff\xfa\xff\xff\xfb\xfe')
+ self.assertEqual(
+ self.mask(b"\x00\x01\x02\x03", b"\xff\xfb\xfd\xfc\xfe\xfa"),
+ b"\xff\xfa\xff\xff\xfe\xfb",
+ )
+ self.assertEqual(
+ self.mask(b"\xff\xfb\xfd\xfc", b"\x00\x01\x02\x03\x04\x05"),
+ b"\xff\xfa\xff\xff\xfb\xfe",
+ )
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
self.write_message("got pong")
self.close_future = Future() # type: Future[None]
- return Application([
- ('/', PingHandler, dict(close_future=self.close_future)),
- ], websocket_ping_interval=0.01)
+ return Application(
+ [("/", PingHandler, dict(close_future=self.close_future))],
+ websocket_ping_interval=0.01,
+ )
@gen_test
def test_server_ping(self):
- ws = yield self.ws_connect('/')
+ ws = yield self.ws_connect("/")
for i in range(3):
response = yield ws.read_message()
self.assertEqual(response, "got pong")
self.write_message("got ping")
self.close_future = Future() # type: Future[None]
- return Application([
- ('/', PingHandler, dict(close_future=self.close_future)),
- ])
+ return Application([("/", PingHandler, dict(close_future=self.close_future))])
@gen_test
def test_client_ping(self):
- ws = yield self.ws_connect('/', ping_interval=0.01)
+ ws = yield self.ws_connect("/", ping_interval=0.01)
for i in range(3):
response = yield ws.read_message()
self.assertEqual(response, "got ping")
self.write_message(data, binary=isinstance(data, bytes))
self.close_future = Future() # type: Future[None]
- return Application([
- ('/', PingHandler, dict(close_future=self.close_future)),
- ])
+ return Application([("/", PingHandler, dict(close_future=self.close_future))])
@gen_test
def test_manual_ping(self):
- ws = yield self.ws_connect('/')
+ ws = yield self.ws_connect("/")
- self.assertRaises(ValueError, ws.ping, 'a' * 126)
+ self.assertRaises(ValueError, ws.ping, "a" * 126)
- ws.ping('hello')
+ ws.ping("hello")
resp = yield ws.read_message()
# on_ping always sees bytes.
- self.assertEqual(resp, b'hello')
+ self.assertEqual(resp, b"hello")
- ws.ping(b'binary hello')
+ ws.ping(b"binary hello")
resp = yield ws.read_message()
- self.assertEqual(resp, b'binary hello')
+ self.assertEqual(resp, b"binary hello")
yield self.close(ws)
class MaxMessageSizeTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future() # type: Future[None]
- return Application([
- ('/', EchoHandler, dict(close_future=self.close_future)),
- ], websocket_max_message_size=1024)
+ return Application(
+ [("/", EchoHandler, dict(close_future=self.close_future))],
+ websocket_max_message_size=1024,
+ )
@gen_test
def test_large_message(self):
- ws = yield self.ws_connect('/')
+ ws = yield self.ws_connect("/")
# Write a message that is allowed.
- msg = 'a' * 1024
+ msg = "a" * 1024
ws.write_message(msg)
resp = yield ws.read_message()
self.assertEqual(resp, msg)
# Write a message that is too large.
- ws.write_message(msg + 'b')
+ ws.write_message(msg + "b")
resp = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(resp, None)
from tornado.platform.auto import set_close_exec
-skipIfNonWindows = unittest.skipIf(os.name != 'nt', 'non-windows platform')
+skipIfNonWindows = unittest.skipIf(os.name != "nt", "non-windows platform")
@skipIfNonWindows
from types import TracebackType
if typing.TYPE_CHECKING:
- _ExcInfoTuple = Tuple[Optional[Type[BaseException]], Optional[BaseException],
- Optional[TracebackType]]
+ _ExcInfoTuple = Tuple[
+ Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]
+ ]
_NON_OWNED_IOLOOPS = AsyncIOMainLoop
-def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]:
+def bind_unused_port(reuse_port: bool = False) -> Tuple[socket.socket, int]:
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
Always binds to ``127.0.0.1`` without resolving the name
``localhost``.
"""
- sock = netutil.bind_sockets(0, '127.0.0.1', family=socket.AF_INET,
- reuse_port=reuse_port)[0]
+ sock = netutil.bind_sockets(
+ 0, "127.0.0.1", family=socket.AF_INET, reuse_port=reuse_port
+ )[0]
port = sock.getsockname()[1]
return sock, port
.. versionadded:: 3.1
"""
- env = os.environ.get('ASYNC_TEST_TIMEOUT')
+ env = os.environ.get("ASYNC_TEST_TIMEOUT")
if env is not None:
try:
return float(env)
necessarily errors, but we alert anyway since there is no good
reason to return a value from a test).
"""
+
def __init__(self, orig_method: Callable) -> None:
self.orig_method = orig_method
def __call__(self, *args: Any, **kwargs: Any) -> None:
result = self.orig_method(*args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
- raise TypeError("Generator and coroutine test methods should be"
- " decorated with tornado.testing.gen_test")
+ raise TypeError(
+ "Generator and coroutine test methods should be"
+ " decorated with tornado.testing.gen_test"
+ )
elif result is not None:
- raise ValueError("Return value from test method ignored: %r" %
- result)
+ raise ValueError("Return value from test method ignored: %r" % result)
def __getattr__(self, name: str) -> Any:
"""Proxy all unknown attributes to the original method.
# Test contents of response
self.assertIn("FriendFeed", response.body)
"""
- def __init__(self, methodName: str='runTest') -> None:
+
+ def __init__(self, methodName: str = "runTest") -> None:
super(AsyncTestCase, self).__init__(methodName)
self.__stopped = False
self.__running = False
"""
return IOLoop()
- def _handle_exception(self, typ: Type[Exception], value: Exception, tb: TracebackType) -> bool:
+ def _handle_exception(
+ self, typ: Type[Exception], value: Exception, tb: TracebackType
+ ) -> bool:
if self.__failure is None:
self.__failure = (typ, value, tb)
else:
- app_log.error("multiple unhandled exceptions in test",
- exc_info=(typ, value, tb))
+ app_log.error(
+ "multiple unhandled exceptions in test", exc_info=(typ, value, tb)
+ )
self.stop()
return True
self.__failure = None
raise_exc_info(failure)
- def run(self, result: unittest.TestResult=None) -> unittest.TestCase:
+ def run(self, result: unittest.TestResult = None) -> unittest.TestCase:
ret = super(AsyncTestCase, self).run(result)
# As a last resort, if an exception escaped super.run() and wasn't
# re-raised in tearDown, raise it here. This will cause the
self.__rethrow()
return ret
- def stop(self, _arg: Any=None, **kwargs: Any) -> None:
+ def stop(self, _arg: Any = None, **kwargs: Any) -> None:
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
to return.
self.__running = False
self.__stopped = True
- def wait(self, condition: Callable[..., bool]=None, timeout: float=None) -> None:
+ def wait(
+ self, condition: Callable[..., bool] = None, timeout: float = None
+ ) -> None:
"""Runs the `.IOLoop` until stop is called or timeout has passed.
In the event of a timeout, an exception will be thrown. The
if not self.__stopped:
if timeout:
+
def timeout_func() -> None:
try:
raise self.failureException(
- 'Async operation timed out after %s seconds' %
- timeout)
+ "Async operation timed out after %s seconds" % timeout
+ )
except Exception:
self.__failure = sys.exc_info()
self.stop()
- self.__timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
- timeout_func)
+
+ self.__timeout = self.io_loop.add_timeout(
+ self.io_loop.time() + timeout, timeout_func
+ )
while True:
self.__running = True
self.io_loop.start()
- if (self.__failure is not None or
- condition is None or condition()):
+ if self.__failure is not None or condition is None or condition():
break
if self.__timeout is not None:
self.io_loop.remove_timeout(self.__timeout)
to do other asynchronous operations in tests, you'll probably need to use
``stop()`` and ``wait()`` yourself.
"""
+
def setUp(self) -> None:
super(AsyncHTTPTestCase, self).setUp()
sock, port = bind_unused_port()
"""
raise NotImplementedError()
- def fetch(self, path: str, raise_error: bool=False, **kwargs: Any) -> HTTPResponse:
+ def fetch(
+ self, path: str, raise_error: bool = False, **kwargs: Any
+ ) -> HTTPResponse:
"""Convenience method to synchronously fetch a URL.
The given path will be appended to the local server's host and
response codes.
"""
- if path.lower().startswith(('http://', 'https://')):
+ if path.lower().startswith(("http://", "https://")):
url = path
else:
url = self.get_url(path)
return self.io_loop.run_sync(
lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
- timeout=get_async_test_timeout())
+ timeout=get_async_test_timeout(),
+ )
def get_httpserver_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to return additional
return self.__port
def get_protocol(self) -> str:
- return 'http'
+ return "http"
def get_url(self, path: str) -> str:
"""Returns an absolute url for the given path on the test server."""
- return '%s://127.0.0.1:%s%s' % (self.get_protocol(),
- self.get_http_port(), path)
+ return "%s://127.0.0.1:%s%s" % (self.get_protocol(), self.get_http_port(), path)
def tearDown(self) -> None:
self.http_server.stop()
- self.io_loop.run_sync(self.http_server.close_all_connections,
- timeout=get_async_test_timeout())
+ self.io_loop.run_sync(
+ self.http_server.close_all_connections, timeout=get_async_test_timeout()
+ )
self.http_client.close()
super(AsyncHTTPTestCase, self).tearDown()
Interface is generally the same as `AsyncHTTPTestCase`.
"""
+
def get_http_client(self) -> AsyncHTTPClient:
- return AsyncHTTPClient(force_instance=True,
- defaults=dict(validate_cert=False))
+ return AsyncHTTPClient(force_instance=True, defaults=dict(validate_cert=False))
def get_httpserver_options(self) -> Dict[str, Any]:
return dict(ssl_options=self.get_ssl_options())
# -out tornado/test/test.crt -nodes -days 3650 -x509
module_dir = os.path.dirname(__file__)
return dict(
- certfile=os.path.join(module_dir, 'test', 'test.crt'),
- keyfile=os.path.join(module_dir, 'test', 'test.key'))
+ certfile=os.path.join(module_dir, "test", "test.crt"),
+ keyfile=os.path.join(module_dir, "test", "test.key"),
+ )
def get_protocol(self) -> str:
- return 'https'
+ return "https"
@typing.overload
-def gen_test(*, timeout: float=None) -> Callable[[Callable[..., Union[Generator, Coroutine]]],
- Callable[..., None]]:
+def gen_test(
+ *, timeout: float = None
+) -> Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]:
pass
def gen_test( # noqa: F811
- func: Callable[..., Union[Generator, Coroutine]]=None, timeout: float=None,
-) -> Union[Callable[..., None],
- Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]]:
+ func: Callable[..., Union[Generator, Coroutine]] = None, timeout: float = None
+) -> Union[
+ Callable[..., None],
+ Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]],
+]:
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
# type: (AsyncTestCase, *Any, **Any) -> None
try:
return self.io_loop.run_sync(
- functools.partial(coro, self, *args, **kwargs),
- timeout=timeout)
+ functools.partial(coro, self, *args, **kwargs), timeout=timeout
+ )
except TimeoutError as e:
# run_sync raises an error with an unhelpful traceback.
# If the underlying generator is still running, we can throw the
# point where the test is stopped. The only reason the generator
# would not be running would be if it were cancelled, which means
# a native coroutine, so we can rely on the cr_running attribute.
- if (self._test_generator is not None and
- getattr(self._test_generator, 'cr_running', True)):
+ if self._test_generator is not None and getattr(
+ self._test_generator, "cr_running", True
+ ):
self._test_generator.throw(type(e), e)
# In case the test contains an overly broad except
# clause, we may get back here.
# Coroutine was stopped or didn't raise a useful stack trace,
# so re-raise the original exception which is better than nothing.
raise
+
return post_coroutine
if func is not None:
.. versionchanged:: 4.3
Added the ``logged_stack`` attribute.
"""
- def __init__(self, logger: Union[logging.Logger, basestring_type], regex: str,
- required: bool=True) -> None:
+
+ def __init__(
+ self,
+ logger: Union[logging.Logger, basestring_type],
+ regex: str,
+ required: bool = True,
+ ) -> None:
"""Constructs an ExpectLog context manager.
:param logger: Logger object (or name of logger) to watch. Pass
return False
return True
- def __enter__(self) -> 'ExpectLog':
+ def __enter__(self) -> "ExpectLog":
self.logger.addFilter(self)
return self
- def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException],
- tb: Optional[TracebackType]) -> None:
+ def __exit__(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
self.logger.removeFilter(self)
if not typ and self.required and not self.matched:
raise Exception("did not get expected log message")
"""
from tornado.options import define, options, parse_command_line
- define('exception_on_interrupt', type=bool, default=True,
- help=("If true (default), ctrl-c raises a KeyboardInterrupt "
- "exception. This prints a stack trace but cannot interrupt "
- "certain operations. If false, the process is more reliably "
- "killed, but does not print a stack trace."))
+ define(
+ "exception_on_interrupt",
+ type=bool,
+ default=True,
+ help=(
+ "If true (default), ctrl-c raises a KeyboardInterrupt "
+ "exception. This prints a stack trace but cannot interrupt "
+ "certain operations. If false, the process is more reliably "
+ "killed, but does not print a stack trace."
+ ),
+ )
# support the same options as unittest's command-line interface
- define('verbose', type=bool)
- define('quiet', type=bool)
- define('failfast', type=bool)
- define('catch', type=bool)
- define('buffer', type=bool)
+ define("verbose", type=bool)
+ define("quiet", type=bool)
+ define("failfast", type=bool)
+ define("catch", type=bool)
+ define("buffer", type=bool)
argv = [sys.argv[0]] + parse_command_line(sys.argv)
signal.signal(signal.SIGINT, signal.SIG_DFL)
if options.verbose is not None:
- kwargs['verbosity'] = 2
+ kwargs["verbosity"] = 2
if options.quiet is not None:
- kwargs['verbosity'] = 0
+ kwargs["verbosity"] = 0
if options.failfast is not None:
- kwargs['failfast'] = True
+ kwargs["failfast"] = True
if options.catch is not None:
- kwargs['catchbreak'] = True
+ kwargs["catchbreak"] = True
if options.buffer is not None:
- kwargs['buffer'] = True
+ kwargs["buffer"] = True
- if __name__ == '__main__' and len(argv) == 1:
+ if __name__ == "__main__" and len(argv) == 1:
print("No tests specified", file=sys.stderr)
sys.exit(1)
# In order to be able to run tests by their fully-qualified name
unittest.main(defaultTest="all", argv=argv, **kwargs)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
import zlib
from typing import (
- Any, Optional, Dict, Mapping, List, Tuple, Match, Callable, Type, Sequence
+ Any,
+ Optional,
+ Dict,
+ Mapping,
+ List,
+ Tuple,
+ Match,
+ Callable,
+ Type,
+ Sequence,
)
if typing.TYPE_CHECKING:
class ObjectDict(Dict[str, Any]):
"""Makes a dictionary behave like an object, with attribute-style access.
"""
+
def __getattr__(self, name: str) -> Any:
try:
return self[name]
The interface is like that of `zlib.decompressobj` (without some of the
optional arguments, but it understands gzip headers and checksums.
"""
+
def __init__(self) -> None:
# Magic parameter makes zlib module understand gzip header
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
- def decompress(self, value: bytes, max_length: int=0) -> bytes:
+ def decompress(self, value: bytes, max_length: int = 0) -> bytes:
"""Decompress a chunk, returning newly-available data.
Some data may be buffered for later processing; `flush` must
...
ImportError: No module named missing_module
"""
- if name.count('.') == 0:
+ if name.count(".") == 0:
return __import__(name)
- parts = name.split('.')
- obj = __import__('.'.join(parts[:-1]), fromlist=[parts[-1]])
+ parts = name.split(".")
+ obj = __import__(".".join(parts[:-1]), fromlist=[parts[-1]])
try:
return getattr(obj, parts[-1])
except AttributeError:
raise ImportError("No module named %s" % parts[-1])
-def exec_in(code: Any, glob: Dict[str, Any], loc: Mapping[str, Any]=None) -> None:
+def exec_in(code: Any, glob: Dict[str, Any], loc: Mapping[str, Any] = None) -> None:
if isinstance(code, str):
# exec(string) inherits the caller's future imports; compile
# the string first to prevent that.
- code = compile(code, '<string>', 'exec', dont_inherit=True)
+ code = compile(code, "<string>", "exec", dont_inherit=True)
exec(code, glob, loc)
def raise_exc_info(
- exc_info, # type: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]]
+ exc_info, # type: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]]
):
# type: (...) -> typing.NoReturn
#
errno.
"""
- if hasattr(e, 'errno'):
+ if hasattr(e, "errno"):
return e.errno # type: ignore
elif e.args:
return e.args[0]
return None
-_alphanum = frozenset(
- "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
+_alphanum = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
def _re_unescape_replacement(match: Match[str]) -> str:
return group
-_re_unescape_pattern = re.compile(r'\\(.)', re.DOTALL)
+_re_unescape_pattern = re.compile(r"\\(.)", re.DOTALL)
def re_unescape(s: str) -> str:
multiple levels of a class hierarchy.
"""
+
# Type annotations on this class are mostly done with comments
# because they need to refer to Configurable, which isn't defined
# until after the class definition block. These can use regular
# Manually mangle the private name to see whether this base
# has been configured (and not another base higher in the
# hierarchy).
- if base.__dict__.get('_Configurable__impl_class') is None:
+ if base.__dict__.get("_Configurable__impl_class") is None:
base.__impl_class = cls.configurable_default()
if base.__impl_class is not None:
return base.__impl_class
whether it is passed by position or keyword. For use in decorators
and similar wrappers.
"""
+
def __init__(self, func: Callable, name: str) -> None:
self.name = name
try:
try:
return getfullargspec(func).args
except TypeError:
- if hasattr(func, 'func_code'):
+ if hasattr(func, "func_code"):
# Cython-generated code has all the attributes needed
# by inspect.getfullargspec, but the inspect module only
# works with ordinary functions. Inline the portion of
# functions the @cython.binding(True) decorator must
# be used (for methods it works out of the box).
code = func.func_code # type: ignore
- return code.co_varnames[:code.co_argcount]
+ return code.co_varnames[: code.co_argcount]
raise
- def get_old_value(self, args: Sequence[Any], kwargs: Dict[str, Any], default: Any=None) -> Any:
+ def get_old_value(
+ self, args: Sequence[Any], kwargs: Dict[str, Any], default: Any = None
+ ) -> Any:
"""Returns the old value of the named argument without replacing it.
Returns ``default`` if the argument is not present.
else:
return kwargs.get(self.name, default)
- def replace(self, new_value: Any, args: Sequence[Any],
- kwargs: Dict[str, Any]) -> Tuple[Any, Sequence[Any], Dict[str, Any]]:
+ def replace(
+ self, new_value: Any, args: Sequence[Any], kwargs: Dict[str, Any]
+ ) -> Tuple[Any, Sequence[Any], Dict[str, Any]]:
"""Replace the named argument in ``args, kwargs`` with ``new_value``.
Returns ``(old_value, args, kwargs)``. The returned ``args`` and
return unmasked_arr.tobytes()
-if (os.environ.get('TORNADO_NO_EXTENSION') or
- os.environ.get('TORNADO_EXTENSION') == '0'):
+if os.environ.get("TORNADO_NO_EXTENSION") or os.environ.get("TORNADO_EXTENSION") == "0":
# These environment variables exist to make it easier to do performance
# comparisons; they are not guaranteed to remain supported in the future.
_websocket_mask = _websocket_mask_python
try:
from tornado.speedups import websocket_mask as _websocket_mask
except ImportError:
- if os.environ.get('TORNADO_EXTENSION') == '1':
+ if os.environ.get("TORNADO_EXTENSION") == "1":
raise
_websocket_mask = _websocket_mask_python
def doctests():
# type: () -> unittest.TestSuite
import doctest
+
return doctest.DocTestSuite()
from tornado.log import access_log, app_log, gen_log
from tornado import template
from tornado.escape import utf8, _unicode
-from tornado.routing import (AnyMatches, DefaultHostMatches, HostMatches,
- ReversibleRouter, Rule, ReversibleRuleRouter,
- URLSpec, _RuleList)
+from tornado.routing import (
+ AnyMatches,
+ DefaultHostMatches,
+ HostMatches,
+ ReversibleRouter,
+ Rule,
+ ReversibleRuleRouter,
+ URLSpec,
+ _RuleList,
+)
from tornado.util import ObjectDict, unicode_type, _websocket_mask
url = URLSpec
-from typing import (Dict, Any, Union, Optional, Awaitable, Tuple, List, Callable, Iterable,
- Generator, Type, cast, overload)
+from typing import (
+ Dict,
+ Any,
+ Union,
+ Optional,
+ Awaitable,
+ Tuple,
+ List,
+ Callable,
+ Iterable,
+ Generator,
+ Type,
+ cast,
+ overload,
+)
from types import TracebackType
import typing
+
if typing.TYPE_CHECKING:
from typing import Set # noqa: F401
# The following types are accepted by RequestHandler.set_header
# and related methods.
-_HeaderTypes = Union[bytes, unicode_type,
- int, numbers.Integral, datetime.datetime]
+_HeaderTypes = Union[bytes, unicode_type, int, numbers.Integral, datetime.datetime]
-_CookieSecretTypes = Union[str, bytes,
- Dict[int, str],
- Dict[int, bytes]]
+_CookieSecretTypes = Union[str, bytes, Dict[int, str], Dict[int, bytes]]
MIN_SUPPORTED_SIGNED_VALUE_VERSION = 1
Subclasses must define at least one of the methods defined in the
"Entry points" section below.
"""
- SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT",
- "OPTIONS")
+
+ SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT", "OPTIONS")
_template_loaders = {} # type: Dict[str, template.BaseLoader]
_template_loader_lock = threading.Lock()
path_args = None # type: List[str]
path_kwargs = None # type: Dict[str, str]
- def __init__(self, application: 'Application', request: httputil.HTTPServerRequest,
- **kwargs: Any) -> None:
+ def __init__(
+ self,
+ application: "Application",
+ request: httputil.HTTPServerRequest,
+ **kwargs: Any
+ ) -> None:
super(RequestHandler, self).__init__()
self.application = application
self._finished = False
self._auto_finish = True
self._prepared_future = None
- self.ui = ObjectDict((n, self._ui_method(m)) for n, m in
- application.ui_methods.items())
+ self.ui = ObjectDict(
+ (n, self._ui_method(m)) for n, m in application.ui_methods.items()
+ )
# UIModules are available as both `modules` and `_tt_modules` in the
# template namespace. Historically only `modules` was available
# but could be clobbered by user additions to the namespace.
# The template {% module %} directive looks in `_tt_modules` to avoid
# possible conflicts.
- self.ui["_tt_modules"] = _UIModuleNamespace(self,
- application.ui_modules)
+ self.ui["_tt_modules"] = _UIModuleNamespace(self, application.ui_modules)
self.ui["modules"] = self.ui["_tt_modules"]
self.clear()
assert self.request.connection is not None
# TODO: need to add set_close_callback to HTTPConnection interface
- self.request.connection.set_close_callback(self.on_connection_close) # type: ignore
+ self.request.connection.set_close_callback( # type: ignore
+ self.on_connection_close
+ )
self.initialize(**kwargs) # type: ignore
def initialize(self) -> None:
def clear(self) -> None:
"""Resets all headers and content for this response."""
- self._headers = httputil.HTTPHeaders({
- "Server": "TornadoServer/%s" % tornado.version,
- "Content-Type": "text/html; charset=UTF-8",
- "Date": httputil.format_timestamp(time.time()),
- })
+ self._headers = httputil.HTTPHeaders(
+ {
+ "Server": "TornadoServer/%s" % tornado.version,
+ "Content-Type": "text/html; charset=UTF-8",
+ "Date": httputil.format_timestamp(time.time()),
+ }
+ )
self.set_default_headers()
self._write_buffer = [] # type: List[bytes]
self._status_code = 200
"""
pass
- def set_status(self, status_code: int, reason: str=None) -> None:
+ def set_status(self, status_code: int, reason: str = None) -> None:
"""Sets the status code for our response.
:arg int status_code: Response status code.
elif isinstance(value, bytes): # py3
# Non-ascii characters in headers are not well supported,
# but if you pass bytes, use latin1 so they pass through as-is.
- retval = value.decode('latin1')
+ retval = value.decode("latin1")
elif isinstance(value, unicode_type): # py2
# TODO: This is inconsistent with the use of latin1 above,
# but it's been that way for a long time. Should it change?
return retval
@overload
- def get_argument(self, name: str, default: str, strip: bool=True) -> str:
+ def get_argument(self, name: str, default: str, strip: bool = True) -> str:
pass
@overload # noqa: F811
- def get_argument(self, name: str, default: _ArgDefaultMarker=_ARG_DEFAULT,
- strip: bool=True) -> str:
+ def get_argument(
+ self, name: str, default: _ArgDefaultMarker = _ARG_DEFAULT, strip: bool = True
+ ) -> str:
pass
@overload # noqa: F811
- def get_argument(self, name: str, default: None, strip: bool=True) -> Optional[str]:
+ def get_argument(
+ self, name: str, default: None, strip: bool = True
+ ) -> Optional[str]:
pass
- def get_argument(self, name: str, # noqa: F811
- default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT,
- strip: bool=True) -> Optional[str]:
+ def get_argument( # noqa: F811
+ self,
+ name: str,
+ default: Union[None, str, _ArgDefaultMarker] = _ARG_DEFAULT,
+ strip: bool = True,
+ ) -> Optional[str]:
"""Returns the value of the argument with the given name.
If default is not provided, the argument is considered to be
"""
return self._get_argument(name, default, self.request.arguments, strip)
- def get_arguments(self, name: str, strip: bool=True) -> List[str]:
+ def get_arguments(self, name: str, strip: bool = True) -> List[str]:
"""Returns a list of the arguments with the given name.
If the argument is not present, returns an empty list.
return self._get_arguments(name, self.request.arguments, strip)
- def get_body_argument(self, name: str,
- default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT,
- strip: bool=True) -> Optional[str]:
+ def get_body_argument(
+ self,
+ name: str,
+ default: Union[None, str, _ArgDefaultMarker] = _ARG_DEFAULT,
+ strip: bool = True,
+ ) -> Optional[str]:
"""Returns the value of the argument with the given name
from the request body.
.. versionadded:: 3.2
"""
- return self._get_argument(name, default, self.request.body_arguments,
- strip)
+ return self._get_argument(name, default, self.request.body_arguments, strip)
- def get_body_arguments(self, name: str, strip: bool=True) -> List[str]:
+ def get_body_arguments(self, name: str, strip: bool = True) -> List[str]:
"""Returns a list of the body arguments with the given name.
If the argument is not present, returns an empty list.
"""
return self._get_arguments(name, self.request.body_arguments, strip)
- def get_query_argument(self, name: str,
- default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT,
- strip: bool=True) -> Optional[str]:
+ def get_query_argument(
+ self,
+ name: str,
+ default: Union[None, str, _ArgDefaultMarker] = _ARG_DEFAULT,
+ strip: bool = True,
+ ) -> Optional[str]:
"""Returns the value of the argument with the given name
from the request query string.
.. versionadded:: 3.2
"""
- return self._get_argument(name, default,
- self.request.query_arguments, strip)
+ return self._get_argument(name, default, self.request.query_arguments, strip)
- def get_query_arguments(self, name: str, strip: bool=True) -> List[str]:
+ def get_query_arguments(self, name: str, strip: bool = True) -> List[str]:
"""Returns a list of the query arguments with the given name.
If the argument is not present, returns an empty list.
"""
return self._get_arguments(name, self.request.query_arguments, strip)
- def _get_argument(self, name: str, default: Union[None, str, _ArgDefaultMarker],
- source: Dict[str, List[bytes]], strip: bool=True) -> Optional[str]:
+ def _get_argument(
+ self,
+ name: str,
+ default: Union[None, str, _ArgDefaultMarker],
+ source: Dict[str, List[bytes]],
+ strip: bool = True,
+ ) -> Optional[str]:
args = self._get_arguments(name, source, strip=strip)
if not args:
if isinstance(default, _ArgDefaultMarker):
return default
return args[-1]
- def _get_arguments(self, name: str, source: Dict[str, List[bytes]],
- strip: bool=True) -> List[str]:
+ def _get_arguments(
+ self, name: str, source: Dict[str, List[bytes]], strip: bool = True
+ ) -> List[str]:
values = []
for v in source.get(name, []):
s = self.decode_argument(v, name=name)
values.append(s)
return values
- def decode_argument(self, value: bytes, name: str=None) -> str:
+ def decode_argument(self, value: bytes, name: str = None) -> str:
"""Decodes an argument from the request.
The argument has been percent-decoded and is now a byte string.
try:
return _unicode(value)
except UnicodeDecodeError:
- raise HTTPError(400, "Invalid unicode in %s: %r" %
- (name or "url", value[:40]))
+ raise HTTPError(
+ 400, "Invalid unicode in %s: %r" % (name or "url", value[:40])
+ )
@property
def cookies(self) -> Dict[str, http.cookies.Morsel]:
`self.request.cookies <.httputil.HTTPServerRequest.cookies>`."""
return self.request.cookies
- def get_cookie(self, name: str, default: str=None) -> Optional[str]:
+ def get_cookie(self, name: str, default: str = None) -> Optional[str]:
"""Returns the value of the request cookie with the given name.
If the named cookie is not present, returns ``default``.
return self.request.cookies[name].value
return default
- def set_cookie(self, name: str, value: Union[str, bytes], domain: str=None,
- expires: Union[float, Tuple, datetime.datetime]=None,
- path: str="/",
- expires_days: int=None, **kwargs: Any) -> None:
+ def set_cookie(
+ self,
+ name: str,
+ value: Union[str, bytes],
+ domain: str = None,
+ expires: Union[float, Tuple, datetime.datetime] = None,
+ path: str = "/",
+ expires_days: int = None,
+ **kwargs: Any
+ ) -> None:
"""Sets an outgoing cookie name/value with the given options.
Newly-set cookies are not immediately visible via `get_cookie`;
if domain:
morsel["domain"] = domain
if expires_days is not None and not expires:
- expires = datetime.datetime.utcnow() + datetime.timedelta(
- days=expires_days)
+ expires = datetime.datetime.utcnow() + datetime.timedelta(days=expires_days)
if expires:
morsel["expires"] = httputil.format_timestamp(expires)
if path:
morsel["path"] = path
for k, v in kwargs.items():
- if k == 'max_age':
- k = 'max-age'
+ if k == "max_age":
+ k = "max-age"
# skip falsy values for httponly and secure flags because
# SimpleCookie sets them regardless
- if k in ['httponly', 'secure'] and not v:
+ if k in ["httponly", "secure"] and not v:
continue
morsel[k] = v
- def clear_cookie(self, name: str, path: str="/", domain: str=None) -> None:
+ def clear_cookie(self, name: str, path: str = "/", domain: str = None) -> None:
"""Deletes the cookie with the given name.
Due to limitations of the cookie protocol, you must pass the same
seen until the following request.
"""
expires = datetime.datetime.utcnow() - datetime.timedelta(days=365)
- self.set_cookie(name, value="", path=path, expires=expires,
- domain=domain)
+ self.set_cookie(name, value="", path=path, expires=expires, domain=domain)
- def clear_all_cookies(self, path: str="/", domain: str=None) -> None:
+ def clear_all_cookies(self, path: str = "/", domain: str = None) -> None:
"""Deletes all the cookies the user sent with this request.
See `clear_cookie` for more information on the path and domain
for name in self.request.cookies:
self.clear_cookie(name, path=path, domain=domain)
- def set_secure_cookie(self, name: str, value: Union[str, bytes], expires_days: int=30,
- version: int=None, **kwargs: Any) -> None:
+ def set_secure_cookie(
+ self,
+ name: str,
+ value: Union[str, bytes],
+ expires_days: int = 30,
+ version: int = None,
+ **kwargs: Any
+ ) -> None:
"""Signs and timestamps a cookie so it cannot be forged.
You must specify the ``cookie_secret`` setting in your Application
Added the ``version`` argument. Introduced cookie version 2
and made it the default.
"""
- self.set_cookie(name, self.create_signed_value(name, value,
- version=version),
- expires_days=expires_days, **kwargs)
+ self.set_cookie(
+ name,
+ self.create_signed_value(name, value, version=version),
+ expires_days=expires_days,
+ **kwargs
+ )
- def create_signed_value(self, name: str, value: Union[str, bytes], version: int=None) -> bytes:
+ def create_signed_value(
+ self, name: str, value: Union[str, bytes], version: int = None
+ ) -> bytes:
"""Signs and timestamps a string so it cannot be forged.
Normally used via set_secure_cookie, but provided as a separate
raise Exception("key_version setting must be used for secret_key dicts")
key_version = self.application.settings["key_version"]
- return create_signed_value(secret, name, value, version=version,
- key_version=key_version)
+ return create_signed_value(
+ secret, name, value, version=version, key_version=key_version
+ )
- def get_secure_cookie(self, name: str, value: str=None, max_age_days: int=31,
- min_version: int=None) -> Optional[bytes]:
+ def get_secure_cookie(
+ self,
+ name: str,
+ value: str = None,
+ max_age_days: int = 31,
+ min_version: int = None,
+ ) -> Optional[bytes]:
"""Returns the given signed cookie if it validates, or None.
The decoded cookie value is returned as a byte string (unlike
self.require_setting("cookie_secret", "secure cookies")
if value is None:
value = self.get_cookie(name)
- return decode_signed_value(self.application.settings["cookie_secret"],
- name, value, max_age_days=max_age_days,
- min_version=min_version)
+ return decode_signed_value(
+ self.application.settings["cookie_secret"],
+ name,
+ value,
+ max_age_days=max_age_days,
+ min_version=min_version,
+ )
- def get_secure_cookie_key_version(self, name: str, value: str=None) -> Optional[int]:
+ def get_secure_cookie_key_version(
+ self, name: str, value: str = None
+ ) -> Optional[int]:
"""Returns the signing key version of the secure cookie.
The version is returned as int.
return None
return get_signature_key_version(value)
- def redirect(self, url: str, permanent: bool=False, status: int=None) -> None:
+ def redirect(self, url: str, permanent: bool = False, status: int = None) -> None:
"""Sends a redirect to the given (optionally relative) URL.
If the ``status`` argument is specified, that value is used as the
if not isinstance(chunk, (bytes, unicode_type, dict)):
message = "write() only accepts bytes, unicode, and dict objects"
if isinstance(chunk, list):
- message += ". Lists not accepted for security reasons; see " + \
- "http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write"
+ message += (
+ ". Lists not accepted for security reasons; see "
+ + "http://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write" # noqa: E501
+ )
raise TypeError(message)
if isinstance(chunk, dict):
chunk = escape.json_encode(chunk)
chunk = utf8(chunk)
self._write_buffer.append(chunk)
- def render(self, template_name: str, **kwargs: Any) -> 'Future[None]':
+ def render(self, template_name: str, **kwargs: Any) -> "Future[None]":
"""Renders the template with the given arguments as the response.
``render()`` calls ``finish()``, so no other output methods can be called
if js_files:
# Maintain order of JavaScript files given by modules
js = self.render_linked_js(js_files)
- sloc = html.rindex(b'</body>')
- html = html[:sloc] + utf8(js) + b'\n' + html[sloc:]
+ sloc = html.rindex(b"</body>")
+ html = html[:sloc] + utf8(js) + b"\n" + html[sloc:]
if js_embed:
js_bytes = self.render_embed_js(js_embed)
- sloc = html.rindex(b'</body>')
- html = html[:sloc] + js_bytes + b'\n' + html[sloc:]
+ sloc = html.rindex(b"</body>")
+ html = html[:sloc] + js_bytes + b"\n" + html[sloc:]
if css_files:
css = self.render_linked_css(css_files)
- hloc = html.index(b'</head>')
- html = html[:hloc] + utf8(css) + b'\n' + html[hloc:]
+ hloc = html.index(b"</head>")
+ html = html[:hloc] + utf8(css) + b"\n" + html[hloc:]
if css_embed:
css_bytes = self.render_embed_css(css_embed)
- hloc = html.index(b'</head>')
- html = html[:hloc] + css_bytes + b'\n' + html[hloc:]
+ hloc = html.index(b"</head>")
+ html = html[:hloc] + css_bytes + b"\n" + html[hloc:]
if html_heads:
- hloc = html.index(b'</head>')
- html = html[:hloc] + b''.join(html_heads) + b'\n' + html[hloc:]
+ hloc = html.index(b"</head>")
+ html = html[:hloc] + b"".join(html_heads) + b"\n" + html[hloc:]
if html_bodies:
- hloc = html.index(b'</body>')
- html = html[:hloc] + b''.join(html_bodies) + b'\n' + html[hloc:]
+ hloc = html.index(b"</body>")
+ html = html[:hloc] + b"".join(html_bodies) + b"\n" + html[hloc:]
return self.finish(html)
def render_linked_js(self, js_files: Iterable[str]) -> str:
paths.append(path)
unique_paths.add(path)
- return ''.join('<script src="' + escape.xhtml_escape(p) +
- '" type="text/javascript"></script>'
- for p in paths)
+ return "".join(
+ '<script src="'
+ + escape.xhtml_escape(p)
+ + '" type="text/javascript"></script>'
+ for p in paths
+ )
def render_embed_js(self, js_embed: Iterable[bytes]) -> bytes:
"""Default method used to render the final embedded js for the
Override this method in a sub-classed controller to change the output.
"""
- return b'<script type="text/javascript">\n//<![CDATA[\n' + \
- b'\n'.join(js_embed) + b'\n//]]>\n</script>'
+ return (
+ b'<script type="text/javascript">\n//<![CDATA[\n'
+ + b"\n".join(js_embed)
+ + b"\n//]]>\n</script>"
+ )
def render_linked_css(self, css_files: Iterable[str]) -> str:
"""Default method used to render the final css links for the
paths.append(path)
unique_paths.add(path)
- return ''.join('<link href="' + escape.xhtml_escape(p) + '" '
- 'type="text/css" rel="stylesheet"/>'
- for p in paths)
+ return "".join(
+ '<link href="' + escape.xhtml_escape(p) + '" '
+ 'type="text/css" rel="stylesheet"/>'
+ for p in paths
+ )
def render_embed_css(self, css_embed: Iterable[bytes]) -> bytes:
"""Default method used to render the final embedded css for the
Override this method in a sub-classed controller to change the output.
"""
- return b'<style type="text/css">\n' + b'\n'.join(css_embed) + \
- b'\n</style>'
+ return b'<style type="text/css">\n' + b"\n".join(css_embed) + b"\n</style>"
def render_string(self, template_name: str, **kwargs: Any) -> bytes:
"""Generate the given template with the given arguments.
pgettext=self.locale.pgettext,
static_url=self.static_url,
xsrf_form_html=self.xsrf_form_html,
- reverse_url=self.reverse_url
+ reverse_url=self.reverse_url,
)
namespace.update(self.ui)
return namespace
kwargs["whitespace"] = settings["template_whitespace"]
return template.Loader(template_path, **kwargs)
- def flush(self, include_footers: bool=False) -> 'Future[None]':
+ def flush(self, include_footers: bool = False) -> "Future[None]":
"""Flushes the current output buffer to the network.
The ``callback`` argument, if given, can be used for flow control:
self._headers_written = True
for transform in self._transforms:
assert chunk is not None
- self._status_code, self._headers, chunk = \
- transform.transform_first_chunk(
- self._status_code, self._headers,
- chunk, include_footers)
+ self._status_code, self._headers, chunk = transform.transform_first_chunk(
+ self._status_code, self._headers, chunk, include_footers
+ )
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
- chunk = b''
+ chunk = b""
# Finalize the cookie headers (which have been stored in a side
# object so an outgoing cookie could be overwritten before it
for cookie in self._new_cookie.values():
self.add_header("Set-Cookie", cookie.OutputString(None))
- start_line = httputil.ResponseStartLine('',
- self._status_code,
- self._reason)
+ start_line = httputil.ResponseStartLine("", self._status_code, self._reason)
return self.request.connection.write_headers(
- start_line, self._headers, chunk)
+ start_line, self._headers, chunk
+ )
else:
for transform in self._transforms:
chunk = transform.transform_chunk(chunk, include_footers)
future.set_result(None)
return future
- def finish(self, chunk: Union[str, bytes, dict]=None) -> 'Future[None]':
+ def finish(self, chunk: Union[str, bytes, dict] = None) -> "Future[None]":
"""Finishes this response, ending the HTTP request.
Passing a ``chunk`` to ``finish()`` is equivalent to passing that
# Automatically support ETags and add the Content-Length header if
# we have not flushed any content yet.
if not self._headers_written:
- if (self._status_code == 200 and
- self.request.method in ("GET", "HEAD") and
- "Etag" not in self._headers):
+ if (
+ self._status_code == 200
+ and self.request.method in ("GET", "HEAD")
+ and "Etag" not in self._headers
+ ):
self.set_etag_header()
if self.check_etag_header():
self._write_buffer = []
self.set_status(304)
- if (self._status_code in (204, 304) or
- (self._status_code >= 100 and self._status_code < 200)):
- assert not self._write_buffer, "Cannot send body with %s" % self._status_code
+ if self._status_code in (204, 304) or (
+ self._status_code >= 100 and self._status_code < 200
+ ):
+ assert not self._write_buffer, (
+ "Cannot send body with %s" % self._status_code
+ )
self._clear_headers_for_304()
elif "Content-Length" not in self._headers:
content_length = sum(len(part) for part in self._write_buffer)
# _ui_module closures to allow for faster GC on CPython.
self.ui = None # type: ignore
- def send_error(self, status_code: int=500, **kwargs: Any) -> None:
+ def send_error(self, status_code: int = 500, **kwargs: Any) -> None:
"""Sends the given HTTP error code to the browser.
If `flush()` has already been called, it is not possible to send
try:
self.finish()
except Exception:
- gen_log.error("Failed to flush partial response",
- exc_info=True)
+ gen_log.error("Failed to flush partial response", exc_info=True)
return
self.clear()
- reason = kwargs.get('reason')
- if 'exc_info' in kwargs:
- exception = kwargs['exc_info'][1]
+ reason = kwargs.get("reason")
+ if "exc_info" in kwargs:
+ exception = kwargs["exc_info"][1]
if isinstance(exception, HTTPError) and exception.reason:
reason = exception.reason
self.set_status(status_code, reason=reason)
"""
if self.settings.get("serve_traceback") and "exc_info" in kwargs:
# in debug mode, try to send a traceback
- self.set_header('Content-Type', 'text/plain')
+ self.set_header("Content-Type", "text/plain")
for line in traceback.format_exception(*kwargs["exc_info"]):
self.write(line)
self.finish()
else:
- self.finish("<html><title>%(code)d: %(message)s</title>"
- "<body>%(code)d: %(message)s</body></html>" % {
- "code": status_code,
- "message": self._reason,
- })
+ self.finish(
+ "<html><title>%(code)d: %(message)s</title>"
+ "<body>%(code)d: %(message)s</body></html>"
+ % {"code": status_code, "message": self._reason}
+ )
@property
def locale(self) -> tornado.locale.Locale:
"""
return None
- def get_browser_locale(self, default: str="en_US") -> tornado.locale.Locale:
+ def get_browser_locale(self, default: str = "en_US") -> tornado.locale.Locale:
"""Determines the user's locale from ``Accept-Language`` header.
See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4
self._xsrf_token = binascii.b2a_hex(token)
elif output_version == 2:
mask = os.urandom(4)
- self._xsrf_token = b"|".join([
- b"2",
- binascii.b2a_hex(mask),
- binascii.b2a_hex(_websocket_mask(mask, token)),
- utf8(str(int(timestamp)))])
+ self._xsrf_token = b"|".join(
+ [
+ b"2",
+ binascii.b2a_hex(mask),
+ binascii.b2a_hex(_websocket_mask(mask, token)),
+ utf8(str(int(timestamp))),
+ ]
+ )
else:
- raise ValueError("unknown xsrf cookie version %d",
- output_version)
+ raise ValueError("unknown xsrf cookie version %d", output_version)
if version is None:
expires_days = 30 if self.current_user else None
- self.set_cookie("_xsrf", self._xsrf_token,
- expires_days=expires_days,
- **cookie_kwargs)
+ self.set_cookie(
+ "_xsrf",
+ self._xsrf_token,
+ expires_days=expires_days,
+ **cookie_kwargs
+ )
return self._xsrf_token
def _get_raw_xsrf_token(self) -> Tuple[Optional[int], bytes, float]:
* timestamp: the time this token was generated (will not be accurate
for version 1 cookies)
"""
- if not hasattr(self, '_raw_xsrf_token'):
+ if not hasattr(self, "_raw_xsrf_token"):
cookie = self.get_cookie("_xsrf")
if cookie:
version, token, timestamp = self._decode_xsrf_token(cookie)
self._raw_xsrf_token = (version, token, timestamp)
return self._raw_xsrf_token
- def _decode_xsrf_token(self, cookie: str) -> Tuple[
- Optional[int], Optional[bytes], Optional[float]]:
+ def _decode_xsrf_token(
+ self, cookie: str
+ ) -> Tuple[Optional[int], Optional[bytes], Optional[float]]:
"""Convert a cookie string into a the tuple form returned by
_get_raw_xsrf_token.
"""
_, mask_str, masked_token, timestamp_str = cookie.split("|")
mask = binascii.a2b_hex(utf8(mask_str))
- token = _websocket_mask(
- mask, binascii.a2b_hex(utf8(masked_token)))
+ token = _websocket_mask(mask, binascii.a2b_hex(utf8(masked_token)))
timestamp = int(timestamp_str)
return version, token, timestamp
else:
return (version, token, timestamp)
except Exception:
# Catch exceptions and return nothing instead of failing.
- gen_log.debug("Uncaught exception in _decode_xsrf_token",
- exc_info=True)
+ gen_log.debug("Uncaught exception in _decode_xsrf_token", exc_info=True)
return None, None, None
def check_xsrf_cookie(self) -> None:
Added support for cookie version 2. Both versions 1 and 2 are
supported.
"""
- token = (self.get_argument("_xsrf", None) or
- self.request.headers.get("X-Xsrftoken") or
- self.request.headers.get("X-Csrftoken"))
+ token = (
+ self.get_argument("_xsrf", None)
+ or self.request.headers.get("X-Xsrftoken")
+ or self.request.headers.get("X-Csrftoken")
+ )
if not token:
raise HTTPError(403, "'_xsrf' argument missing from POST")
_, token, _ = self._decode_xsrf_token(token)
See `check_xsrf_cookie()` above for more information.
"""
- return '<input type="hidden" name="_xsrf" value="' + \
- escape.xhtml_escape(self.xsrf_token) + '"/>'
+ return (
+ '<input type="hidden" name="_xsrf" value="'
+ + escape.xhtml_escape(self.xsrf_token)
+ + '"/>'
+ )
- def static_url(self, path: str, include_host: bool=None, **kwargs: Any) -> str:
+ def static_url(self, path: str, include_host: bool = None, **kwargs: Any) -> str:
"""Returns a static URL for the given relative static file path.
This method requires you set the ``static_path`` setting in your
"""
self.require_setting("static_path", "static_url")
- get_url = self.settings.get("static_handler_class",
- StaticFileHandler).make_static_url
+ get_url = self.settings.get(
+ "static_handler_class", StaticFileHandler
+ ).make_static_url
if include_host is None:
include_host = getattr(self, "include_host", False)
return base + get_url(self.settings, path, **kwargs)
- def require_setting(self, name: str, feature: str="this feature") -> None:
+ def require_setting(self, name: str, feature: str = "this feature") -> None:
"""Raises an exception if the given app setting is not defined."""
if not self.application.settings.get(name):
- raise Exception("You must define the '%s' setting in your "
- "application to use %s" % (name, feature))
+ raise Exception(
+ "You must define the '%s' setting in your "
+ "application to use %s" % (name, feature)
+ )
def reverse_url(self, name: str, *args: Any) -> str:
"""Alias for `Application.reverse_url`."""
# Find all weak and strong etag values from If-None-Match header
# because RFC 7232 allows multiple etag values in a single header.
etags = re.findall(
- br'\*|(?:W/)?"[^"]*"',
- utf8(self.request.headers.get("If-None-Match", ""))
+ br'\*|(?:W/)?"[^"]*"', utf8(self.request.headers.get("If-None-Match", ""))
)
if not computed_etag or not etags:
return False
match = False
- if etags[0] == b'*':
+ if etags[0] == b"*":
match = True
else:
# Use a weak comparison when comparing entity-tags.
def val(x: bytes) -> bytes:
- return x[2:] if x.startswith(b'W/') else x
+ return x[2:] if x.startswith(b"W/") else x
for etag in etags:
if val(etag) == val(computed_etag):
return match
@gen.coroutine
- def _execute(self, transforms: List['OutputTransform'], *args: bytes,
- **kwargs: bytes) -> Generator[Any, Any, None]:
+ def _execute(
+ self, transforms: List["OutputTransform"], *args: bytes, **kwargs: bytes
+ ) -> Generator[Any, Any, None]:
"""Executes this request with the given output transforms."""
self._transforms = transforms
try:
if self.request.method not in self.SUPPORTED_METHODS:
raise HTTPError(405)
self.path_args = [self.decode_argument(arg) for arg in args]
- self.path_kwargs = dict((k, self.decode_argument(v, name=k))
- for (k, v) in kwargs.items())
+ self.path_kwargs = dict(
+ (k, self.decode_argument(v, name=k)) for (k, v) in kwargs.items()
+ )
# If XSRF cookies are turned on, reject form submissions without
# the proper cookie
- if self.request.method not in ("GET", "HEAD", "OPTIONS") and \
- self.application.settings.get("xsrf_cookies"):
+ if self.request.method not in (
+ "GET",
+ "HEAD",
+ "OPTIONS",
+ ) and self.application.settings.get("xsrf_cookies"):
self.check_xsrf_cookie()
result = self.prepare()
finally:
# Unset result to avoid circular references
result = None
- if (self._prepared_future is not None and
- not self._prepared_future.done()):
+ if self._prepared_future is not None and not self._prepared_future.done():
# In case we failed before setting _prepared_future, do it
# now (to unblock the HTTP server). Note that this is not
# in a finally block to avoid GC issues prior to Python 3.4.
self.application.log_request(self)
def _request_summary(self) -> str:
- return "%s %s (%s)" % (self.request.method, self.request.uri,
- self.request.remote_ip)
+ return "%s %s (%s)" % (
+ self.request.method,
+ self.request.uri,
+ self.request.remote_ip,
+ )
def _handle_request_exception(self, e: BaseException) -> None:
if isinstance(e, Finish):
else:
self.send_error(500, exc_info=sys.exc_info())
- def log_exception(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[TracebackType]) -> None:
+ def log_exception(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
"""Override to customize logging of uncaught exceptions.
By default logs instances of `HTTPError` as warnings without
if isinstance(value, HTTPError):
if value.log_message:
format = "%d %s: " + value.log_message
- args = ([value.status_code, self._request_summary()] +
- list(value.args))
+ args = [value.status_code, self._request_summary()] + list(value.args)
gen_log.warning(format, *args)
else:
- app_log.error("Uncaught exception %s\n%r", self._request_summary(), # type: ignore
- self.request, exc_info=(typ, value, tb))
-
- def _ui_module(self, name: str, module: Type['UIModule']) -> Callable[..., str]:
+ app_log.error( # type: ignore
+ "Uncaught exception %s\n%r",
+ self._request_summary(),
+ self.request,
+ exc_info=(typ, value, tb),
+ )
+
+ def _ui_module(self, name: str, module: Type["UIModule"]) -> Callable[..., str]:
def render(*args: Any, **kwargs: Any) -> str:
if not hasattr(self, "_active_modules"):
self._active_modules = {} # type: Dict[str, UIModule]
self._active_modules[name] = module(self)
rendered = self._active_modules[name].render(*args, **kwargs)
return rendered
+
return render
def _ui_method(self, method: Callable[..., str]) -> Callable[..., str]:
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.1)
# not explicitly allowed by
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
- headers = ["Allow", "Content-Encoding", "Content-Language",
- "Content-Length", "Content-MD5", "Content-Range",
- "Content-Type", "Last-Modified"]
+ headers = [
+ "Allow",
+ "Content-Encoding",
+ "Content-Language",
+ "Content-Length",
+ "Content-MD5",
+ "Content-Range",
+ "Content-Type",
+ "Last-Modified",
+ ]
for h in headers:
self.clear_header(h)
def removeslash(
- method: Callable[..., Optional[Awaitable[None]]]
+ method: Callable[..., Optional[Awaitable[None]]]
) -> Callable[..., Optional[Awaitable[None]]]:
"""Use this decorator to remove trailing slashes from the request path.
decorator. Your request handler mapping should use a regular expression
like ``r'/foo/*'`` in conjunction with using the decorator.
"""
+
@functools.wraps(method)
- def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]:
+ def wrapper(
+ self: RequestHandler, *args: Any, **kwargs: Any
+ ) -> Optional[Awaitable[None]]:
if self.request.path.endswith("/"):
if self.request.method in ("GET", "HEAD"):
uri = self.request.path.rstrip("/")
else:
raise HTTPError(404)
return method(self, *args, **kwargs)
+
return wrapper
def addslash(
- method: Callable[..., Optional[Awaitable[None]]]
+ method: Callable[..., Optional[Awaitable[None]]]
) -> Callable[..., Optional[Awaitable[None]]]:
"""Use this decorator to add a missing trailing slash to the request path.
decorator. Your request handler mapping should use a regular expression
like ``r'/foo/?'`` in conjunction with using the decorator.
"""
+
@functools.wraps(method)
- def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]:
+ def wrapper(
+ self: RequestHandler, *args: Any, **kwargs: Any
+ ) -> Optional[Awaitable[None]]:
if not self.request.path.endswith("/"):
if self.request.method in ("GET", "HEAD"):
uri = self.request.path + "/"
return None
raise HTTPError(404)
return method(self, *args, **kwargs)
+
return wrapper
`_ApplicationRouter` instance.
"""
- def __init__(self, application: 'Application', rules: _RuleList=None) -> None:
+ def __init__(self, application: "Application", rules: _RuleList = None) -> None:
assert isinstance(application, Application)
self.application = application
super(_ApplicationRouter, self).__init__(rules)
rule = super(_ApplicationRouter, self).process_rule(rule)
if isinstance(rule.target, (list, tuple)):
- rule.target = _ApplicationRouter(self.application, rule.target) # type: ignore
+ rule.target = _ApplicationRouter( # type: ignore
+ self.application, rule.target
+ )
return rule
- def get_target_delegate(self, target: Any, request: httputil.HTTPServerRequest,
- **target_params: Any) -> Optional[httputil.HTTPMessageDelegate]:
+ def get_target_delegate(
+ self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any
+ ) -> Optional[httputil.HTTPMessageDelegate]:
if isclass(target) and issubclass(target, RequestHandler):
- return self.application.get_handler_delegate(request, target, **target_params)
+ return self.application.get_handler_delegate(
+ request, target, **target_params
+ )
- return super(_ApplicationRouter, self).get_target_delegate(target, request, **target_params)
+ return super(_ApplicationRouter, self).get_target_delegate(
+ target, request, **target_params
+ )
class Application(ReversibleRouter):
Integration with the new `tornado.routing` module.
"""
- def __init__(self, handlers: _RuleList=None, default_host: str=None,
- transforms: List[Type['OutputTransform']]=None, **settings: Any) -> None:
+
+ def __init__(
+ self,
+ handlers: _RuleList = None,
+ default_host: str = None,
+ transforms: List[Type["OutputTransform"]] = None,
+ **settings: Any
+ ) -> None:
if transforms is None:
self.transforms = [] # type: List[Type[OutputTransform]]
if settings.get("compress_response") or settings.get("gzip"):
self.transforms = transforms
self.default_host = default_host
self.settings = settings
- self.ui_modules = {'linkify': _linkify,
- 'xsrf_form_html': _xsrf_form_html,
- 'Template': TemplateModule,
- }
+ self.ui_modules = {
+ "linkify": _linkify,
+ "xsrf_form_html": _xsrf_form_html,
+ "Template": TemplateModule,
+ }
self.ui_methods = {} # type: Dict[str, Callable[..., str]]
self._load_ui_modules(settings.get("ui_modules", {}))
self._load_ui_methods(settings.get("ui_methods", {}))
if self.settings.get("static_path"):
path = self.settings["static_path"]
handlers = list(handlers or [])
- static_url_prefix = settings.get("static_url_prefix",
- "/static/")
- static_handler_class = settings.get("static_handler_class",
- StaticFileHandler)
+ static_url_prefix = settings.get("static_url_prefix", "/static/")
+ static_handler_class = settings.get(
+ "static_handler_class", StaticFileHandler
+ )
static_handler_args = settings.get("static_handler_args", {})
- static_handler_args['path'] = path
- for pattern in [re.escape(static_url_prefix) + r"(.*)",
- r"/(favicon\.ico)", r"/(robots\.txt)"]:
- handlers.insert(0, (pattern, static_handler_class,
- static_handler_args))
-
- if self.settings.get('debug'):
- self.settings.setdefault('autoreload', True)
- self.settings.setdefault('compiled_template_cache', False)
- self.settings.setdefault('static_hash_cache', False)
- self.settings.setdefault('serve_traceback', True)
+ static_handler_args["path"] = path
+ for pattern in [
+ re.escape(static_url_prefix) + r"(.*)",
+ r"/(favicon\.ico)",
+ r"/(robots\.txt)",
+ ]:
+ handlers.insert(0, (pattern, static_handler_class, static_handler_args))
+
+ if self.settings.get("debug"):
+ self.settings.setdefault("autoreload", True)
+ self.settings.setdefault("compiled_template_cache", False)
+ self.settings.setdefault("static_hash_cache", False)
+ self.settings.setdefault("serve_traceback", True)
self.wildcard_router = _ApplicationRouter(self, handlers)
- self.default_router = _ApplicationRouter(self, [
- Rule(AnyMatches(), self.wildcard_router)
- ])
+ self.default_router = _ApplicationRouter(
+ self, [Rule(AnyMatches(), self.wildcard_router)]
+ )
# Automatically reload modified modules
- if self.settings.get('autoreload'):
+ if self.settings.get("autoreload"):
from tornado import autoreload
+
autoreload.start()
- def listen(self, port: int, address: str="", **kwargs: Any) -> HTTPServer:
+ def listen(self, port: int, address: str = "", **kwargs: Any) -> HTTPServer:
"""Starts an HTTP server for this application on the given port.
This is a convenience alias for creating an `.HTTPServer`
self.default_router.rules.insert(-1, rule)
if self.default_host is not None:
- self.wildcard_router.add_rules([(
- DefaultHostMatches(self, host_matcher.host_pattern),
- host_handlers
- )])
+ self.wildcard_router.add_rules(
+ [(DefaultHostMatches(self, host_matcher.host_pattern), host_handlers)]
+ )
- def add_transform(self, transform_class: Type['OutputTransform']) -> None:
+ def add_transform(self, transform_class: Type["OutputTransform"]) -> None:
self.transforms.append(transform_class)
def _load_ui_methods(self, methods: Any) -> None:
if isinstance(methods, types.ModuleType):
- self._load_ui_methods(dict((n, getattr(methods, n))
- for n in dir(methods)))
+ self._load_ui_methods(dict((n, getattr(methods, n)) for n in dir(methods)))
elif isinstance(methods, list):
for m in methods:
self._load_ui_methods(m)
else:
for name, fn in methods.items():
- if not name.startswith("_") and hasattr(fn, "__call__") \
- and name[0].lower() == name[0]:
+ if (
+ not name.startswith("_")
+ and hasattr(fn, "__call__")
+ and name[0].lower() == name[0]
+ ):
self.ui_methods[name] = fn
def _load_ui_modules(self, modules: Any) -> None:
if isinstance(modules, types.ModuleType):
- self._load_ui_modules(dict((n, getattr(modules, n))
- for n in dir(modules)))
+ self._load_ui_modules(dict((n, getattr(modules, n)) for n in dir(modules)))
elif isinstance(modules, list):
for m in modules:
self._load_ui_modules(m)
except TypeError:
pass
- def __call__(self, request: httputil.HTTPServerRequest) -> Optional[Awaitable[None]]:
+ def __call__(
+ self, request: httputil.HTTPServerRequest
+ ) -> Optional[Awaitable[None]]:
# Legacy HTTPServer interface
dispatcher = self.find_handler(request)
return dispatcher.execute()
- def find_handler(self, request: httputil.HTTPServerRequest,
- **kwargs: Any) -> '_HandlerDelegate':
+ def find_handler(
+ self, request: httputil.HTTPServerRequest, **kwargs: Any
+ ) -> "_HandlerDelegate":
route = self.default_router.find_handler(request)
if route is not None:
- return cast('_HandlerDelegate', route)
+ return cast("_HandlerDelegate", route)
- if self.settings.get('default_handler_class'):
+ if self.settings.get("default_handler_class"):
return self.get_handler_delegate(
request,
- self.settings['default_handler_class'],
- self.settings.get('default_handler_args', {}))
-
- return self.get_handler_delegate(
- request, ErrorHandler, {'status_code': 404})
-
- def get_handler_delegate(self, request: httputil.HTTPServerRequest,
- target_class: Type[RequestHandler],
- target_kwargs: Dict[str, Any]=None,
- path_args: List[bytes]=None,
- path_kwargs: Dict[str, bytes]=None) -> '_HandlerDelegate':
+ self.settings["default_handler_class"],
+ self.settings.get("default_handler_args", {}),
+ )
+
+ return self.get_handler_delegate(request, ErrorHandler, {"status_code": 404})
+
+ def get_handler_delegate(
+ self,
+ request: httputil.HTTPServerRequest,
+ target_class: Type[RequestHandler],
+ target_kwargs: Dict[str, Any] = None,
+ path_args: List[bytes] = None,
+ path_kwargs: Dict[str, bytes] = None,
+ ) -> "_HandlerDelegate":
"""Returns `~.httputil.HTTPMessageDelegate` that can serve a request
for application and `RequestHandler` subclass.
:arg dict path_kwargs: keyword arguments for ``target_class`` HTTP method.
"""
return _HandlerDelegate(
- self, request, target_class, target_kwargs, path_args, path_kwargs)
+ self, request, target_class, target_kwargs, path_args, path_kwargs
+ )
def reverse_url(self, name: str, *args: Any) -> str:
"""Returns a URL path for handler named ``name``
else:
log_method = access_log.error
request_time = 1000.0 * handler.request.request_time()
- log_method("%d %s %.2fms", handler.get_status(),
- handler._request_summary(), request_time)
+ log_method(
+ "%d %s %.2fms",
+ handler.get_status(),
+ handler._request_summary(),
+ request_time,
+ )
class _HandlerDelegate(httputil.HTTPMessageDelegate):
- def __init__(self, application: Application, request: httputil.HTTPServerRequest,
- handler_class: Type[RequestHandler], handler_kwargs: Optional[Dict[str, Any]],
- path_args: Optional[List[bytes]], path_kwargs: Optional[Dict[str, bytes]]) -> None:
+ def __init__(
+ self,
+ application: Application,
+ request: httputil.HTTPServerRequest,
+ handler_class: Type[RequestHandler],
+ handler_kwargs: Optional[Dict[str, Any]],
+ path_args: Optional[List[bytes]],
+ path_kwargs: Optional[Dict[str, bytes]],
+ ) -> None:
self.application = application
self.connection = request.connection
self.request = request
self.chunks = [] # type: List[bytes]
self.stream_request_body = _has_stream_request_body(self.handler_class)
- def headers_received(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]:
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> Optional[Awaitable[None]]:
if self.stream_request_body:
self.request._body_future = Future()
return self.execute()
if self.stream_request_body:
future_set_result_unless_cancelled(self.request._body_future, None)
else:
- self.request.body = b''.join(self.chunks)
+ self.request.body = b"".join(self.chunks)
self.request._parse_body()
self.execute()
with RequestHandler._template_loader_lock:
for loader in RequestHandler._template_loaders.values():
loader.reset()
- if not self.application.settings.get('static_hash_cache', True):
+ if not self.application.settings.get("static_hash_cache", True):
StaticFileHandler.reset()
- self.handler = self.handler_class(self.application, self.request,
- **self.handler_kwargs)
+ self.handler = self.handler_class(
+ self.application, self.request, **self.handler_kwargs
+ )
transforms = [t(self.request) for t in self.application.transforms]
if self.stream_request_body:
# except handler, and we cannot easily access the IOLoop here to
# call add_future (because of the requirement to remain compatible
# with WSGI)
- self.handler._execute(transforms, *self.path_args,
- **self.path_kwargs)
+ self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
# If we are streaming the request body, then execute() is finished
# when the handler has prepared to receive the body. If not,
# it doesn't matter when execute() finishes (so we return None)
determined automatically from ``status_code``, but can be used
to use a non-standard numeric code.
"""
- def __init__(self, status_code: int=500, log_message: str=None,
- *args: Any, **kwargs: Any) -> None:
+
+ def __init__(
+ self, status_code: int = 500, log_message: str = None, *args: Any, **kwargs: Any
+ ) -> None:
self.status_code = status_code
self.log_message = log_message
self.args = args
- self.reason = kwargs.get('reason', None)
+ self.reason = kwargs.get("reason", None)
if log_message and not args:
- self.log_message = log_message.replace('%', '%%')
+ self.log_message = log_message.replace("%", "%%")
def __str__(self) -> str:
message = "HTTP %d: %s" % (
self.status_code,
- self.reason or httputil.responses.get(self.status_code, 'Unknown'))
+ self.reason or httputil.responses.get(self.status_code, "Unknown"),
+ )
if self.log_message:
return message + " (" + (self.log_message % self.args) + ")"
else:
Arguments passed to ``Finish()`` will be passed on to
`RequestHandler.finish`.
"""
+
pass
.. versionadded:: 3.1
"""
+
def __init__(self, arg_name: str) -> None:
super(MissingArgumentError, self).__init__(
- 400, 'Missing argument %s' % arg_name)
+ 400, "Missing argument %s" % arg_name
+ )
self.arg_name = arg_name
class ErrorHandler(RequestHandler):
"""Generates an error response with ``status_code`` for all requests."""
+
def initialize(self, status_code: int) -> None: # type: ignore
self.set_status(status_code)
If any query arguments are present, they will be copied to the
destination URL.
"""
- def initialize(self, url: str, permanent: bool=True) -> None: # type: ignore
+
+ def initialize(self, url: str, permanent: bool = True) -> None: # type: ignore
self._url = url
self._permanent = permanent
if self.request.query_arguments:
# TODO: figure out typing for the next line.
to_url = httputil.url_concat(
- to_url, list(httputil.qs_to_qsl(self.request.query_arguments))) # type: ignore
+ to_url,
+ list(httputil.qs_to_qsl(self.request.query_arguments)), # type: ignore
+ )
self.redirect(to_url, permanent=self._permanent)
.. versionchanged:: 3.1
Many of the methods for subclasses were added in Tornado 3.1.
"""
+
CACHE_MAX_AGE = 86400 * 365 * 10 # 10 years
_static_hashes = {} # type: Dict[str, Optional[str]]
_lock = threading.Lock() # protects _static_hashes
- def initialize(self, path: str, default_filename: str=None) -> None: # type: ignore
+ def initialize( # type: ignore
+ self, path: str, default_filename: str = None
+ ) -> None:
self.root = path
self.default_filename = default_filename
with cls._lock:
cls._static_hashes = {}
- def head(self, path: str) -> 'Future[None]': # type: ignore
+ def head(self, path: str) -> "Future[None]": # type: ignore
return self.get(path, include_body=False)
@gen.coroutine
- def get(self, path: str, include_body: bool=True) -> Generator[Any, Any, None]:
+ def get(self, path: str, include_body: bool = True) -> Generator[Any, Any, None]:
# Set up our path instance variables.
self.path = self.parse_url_path(path)
del path # make sure we don't refer to path instead of self.path again
absolute_path = self.get_absolute_path(self.root, self.path)
- self.absolute_path = self.validate_absolute_path(
- self.root, absolute_path)
+ self.absolute_path = self.validate_absolute_path(self.root, absolute_path)
if self.absolute_path is None:
return
# content, or when a suffix with length 0 is specified
self.set_status(416) # Range Not Satisfiable
self.set_header("Content-Type", "text/plain")
- self.set_header("Content-Range", "bytes */%s" % (size, ))
+ self.set_header("Content-Range", "bytes */%s" % (size,))
return
if start is not None and start < 0:
start += size
# ``Range: bytes=0-``.
if size != (end or size) - (start or 0):
self.set_status(206) # Partial Content
- self.set_header("Content-Range",
- httputil._get_content_range(start, end, size))
+ self.set_header(
+ "Content-Range", httputil._get_content_range(start, end, size)
+ )
else:
start = end = None
version_hash = self._get_cached_version(self.absolute_path)
if not version_hash:
return None
- return '"%s"' % (version_hash, )
+ return '"%s"' % (version_hash,)
def set_headers(self) -> None:
"""Sets the content and caching headers on the response.
if content_type:
self.set_header("Content-Type", content_type)
- cache_time = self.get_cache_time(self.path, self.modified,
- content_type)
+ cache_time = self.get_cache_time(self.path, self.modified, content_type)
if cache_time > 0:
- self.set_header("Expires", datetime.datetime.utcnow() +
- datetime.timedelta(seconds=cache_time))
+ self.set_header(
+ "Expires",
+ datetime.datetime.utcnow() + datetime.timedelta(seconds=cache_time),
+ )
self.set_header("Cache-Control", "max-age=" + str(cache_time))
self.set_extra_headers(self.path)
.. versionadded:: 3.1
"""
# If client sent If-None-Match, use it, ignore If-Modified-Since
- if self.request.headers.get('If-None-Match'):
+ if self.request.headers.get("If-None-Match"):
return self.check_etag_header()
# Check the If-Modified-Since, and don't send the result if the
# The trailing slash also needs to be temporarily added back
# the requested path so a request to root/ will match.
if not (absolute_path + os.path.sep).startswith(root):
- raise HTTPError(403, "%s is not in root static directory",
- self.path)
- if (os.path.isdir(absolute_path) and
- self.default_filename is not None):
+ raise HTTPError(403, "%s is not in root static directory", self.path)
+ if os.path.isdir(absolute_path) and self.default_filename is not None:
# need to look at the request.path here for when path is empty
# but there is some prefix to the path that was already
# trimmed by the routing
return absolute_path
@classmethod
- def get_content(cls, abspath: str,
- start: int=None, end: int=None) -> Generator[bytes, None, None]:
+ def get_content(
+ cls, abspath: str, start: int = None, end: int = None
+ ) -> Generator[bytes, None, None]:
"""Retrieve the content of the requested resource which is located
at the given absolute path.
def _stat(self) -> os.stat_result:
assert self.absolute_path is not None
- if not hasattr(self, '_stat_result'):
+ if not hasattr(self, "_stat_result"):
self._stat_result = os.stat(self.absolute_path)
return self._stat_result
# consistency with the past (and because we have a unit test
# that relies on this), we truncate the float here, although
# I'm not sure that's the right thing to do.
- modified = datetime.datetime.utcfromtimestamp(
- int(stat_result.st_mtime))
+ modified = datetime.datetime.utcfromtimestamp(int(stat_result.st_mtime))
return modified
def get_content_type(self) -> str:
"""For subclass to add extra headers to the response"""
pass
- def get_cache_time(self, path: str, modified: Optional[datetime.datetime],
- mime_type: str) -> int:
+ def get_cache_time(
+ self, path: str, modified: Optional[datetime.datetime], mime_type: str
+ ) -> int:
"""Override to customize cache control behavior.
Return a positive number of seconds to make the result
return self.CACHE_MAX_AGE if "v" in self.request.arguments else 0
@classmethod
- def make_static_url(cls, settings: Dict[str, Any], path: str,
- include_version: bool=True) -> str:
+ def make_static_url(
+ cls, settings: Dict[str, Any], path: str, include_version: bool = True
+ ) -> str:
"""Constructs a versioned url for the given path.
This method may be overridden in subclasses (but note that it
file corresponding to the given ``path``.
"""
- url = settings.get('static_url_prefix', '/static/') + path
+ url = settings.get("static_url_prefix", "/static/") + path
if not include_version:
return url
if not version_hash:
return url
- return '%s?v=%s' % (url, version_hash)
+ return "%s?v=%s" % (url, version_hash)
def parse_url_path(self, url_path: str) -> str:
"""Converts a static URL path into a filesystem path.
`get_content_version` is now preferred as it allows the base
class to handle caching of the result.
"""
- abs_path = cls.get_absolute_path(settings['static_path'], path)
+ abs_path = cls.get_absolute_path(settings["static_path"], path)
return cls._get_cached_version(abs_path)
@classmethod
(r".*", FallbackHandler, dict(fallback=wsgi_app),
])
"""
- def initialize(self, # type: ignore
- fallback: Callable[[httputil.HTTPServerRequest], None]) -> None:
+
+ def initialize( # type: ignore
+ self, fallback: Callable[[httputil.HTTPServerRequest], None]
+ ) -> None:
self.fallback = fallback
def prepare(self) -> None:
or interact with them directly; the framework chooses which transforms
(if any) to apply.
"""
+
def __init__(self, request: httputil.HTTPServerRequest) -> None:
pass
def transform_first_chunk(
- self, status_code: int, headers: httputil.HTTPHeaders,
- chunk: bytes, finishing: bool
+ self,
+ status_code: int,
+ headers: httputil.HTTPHeaders,
+ chunk: bytes,
+ finishing: bool,
) -> Tuple[int, httputil.HTTPHeaders, bytes]:
return status_code, headers, chunk
of just a whitelist. (the whitelist is still used for certain
non-text mime types).
"""
+
# Whitelist of compressible mime types (in addition to any types
# beginning with "text/").
- CONTENT_TYPES = set(["application/javascript", "application/x-javascript",
- "application/xml", "application/atom+xml",
- "application/json", "application/xhtml+xml",
- "image/svg+xml"])
+ CONTENT_TYPES = set(
+ [
+ "application/javascript",
+ "application/x-javascript",
+ "application/xml",
+ "application/atom+xml",
+ "application/json",
+ "application/xhtml+xml",
+ "image/svg+xml",
+ ]
+ )
# Python's GzipFile defaults to level 9, while most other gzip
# tools (including gzip itself) default to 6, which is probably a
# better CPU/size tradeoff.
self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "")
def _compressible_type(self, ctype: str) -> bool:
- return ctype.startswith('text/') or ctype in self.CONTENT_TYPES
+ return ctype.startswith("text/") or ctype in self.CONTENT_TYPES
def transform_first_chunk(
- self, status_code: int, headers: httputil.HTTPHeaders,
- chunk: bytes, finishing: bool
+ self,
+ status_code: int,
+ headers: httputil.HTTPHeaders,
+ chunk: bytes,
+ finishing: bool,
) -> Tuple[int, httputil.HTTPHeaders, bytes]:
# TODO: can/should this type be inherited from the superclass?
- if 'Vary' in headers:
- headers['Vary'] += ', Accept-Encoding'
+ if "Vary" in headers:
+ headers["Vary"] += ", Accept-Encoding"
else:
- headers['Vary'] = 'Accept-Encoding'
+ headers["Vary"] = "Accept-Encoding"
if self._gzipping:
ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
- self._gzipping = self._compressible_type(ctype) and \
- (not finishing or len(chunk) >= self.MIN_LENGTH) and \
- ("Content-Encoding" not in headers)
+ self._gzipping = (
+ self._compressible_type(ctype)
+ and (not finishing or len(chunk) >= self.MIN_LENGTH)
+ and ("Content-Encoding" not in headers)
+ )
if self._gzipping:
headers["Content-Encoding"] = "gzip"
self._gzip_value = BytesIO()
- self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value,
- compresslevel=self.GZIP_LEVEL)
+ self._gzip_file = gzip.GzipFile(
+ mode="w", fileobj=self._gzip_value, compresslevel=self.GZIP_LEVEL
+ )
chunk = self.transform_chunk(chunk, finishing)
if "Content-Length" in headers:
# The original content length is no longer correct.
def authenticated(
- method: Callable[..., Optional[Awaitable[None]]]
+ method: Callable[..., Optional[Awaitable[None]]]
) -> Callable[..., Optional[Awaitable[None]]]:
"""Decorate methods with this to require that the user be logged in.
will add a `next` parameter so the login page knows where to send
you once you're logged in.
"""
+
@functools.wraps(method)
- def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]:
+ def wrapper(
+ self: RequestHandler, *args: Any, **kwargs: Any
+ ) -> Optional[Awaitable[None]]:
if not self.current_user:
if self.request.method in ("GET", "HEAD"):
url = self.get_login_url()
return None
raise HTTPError(403)
return method(self, *args, **kwargs)
+
return wrapper
Subclasses of UIModule must override the `render` method.
"""
+
def __init__(self, handler: RequestHandler) -> None:
self.handler = handler
self.request = handler.request
per instantiation of the template, so they must not depend on
any arguments to the template.
"""
+
def __init__(self, handler: RequestHandler) -> None:
super(TemplateModule, self).__init__(handler)
# keep resources in both a list and a dict to preserve order
self._resource_dict[path] = kwargs
else:
if self._resource_dict[path] != kwargs:
- raise ValueError("set_resources called with different "
- "resources for the same template")
+ raise ValueError(
+ "set_resources called with different "
+ "resources for the same template"
+ )
return ""
- return self.render_string(path, set_resources=set_resources,
- **kwargs)
+
+ return self.render_string(path, set_resources=set_resources, **kwargs)
def _get_resources(self, key: str) -> Iterable[str]:
return (r[key] for r in self._resource_list if key in r)
class _UIModuleNamespace(object):
"""Lazy namespace which creates UIModule proxies bound to a handler."""
- def __init__(self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]]) -> None:
+
+ def __init__(
+ self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]]
+ ) -> None:
self.handler = handler
self.ui_modules = ui_modules
raise AttributeError(str(e))
-def create_signed_value(secret: _CookieSecretTypes,
- name: str, value: Union[str, bytes],
- version: int=None, clock: Callable[[], float]=None,
- key_version: int=None) -> bytes:
+def create_signed_value(
+ secret: _CookieSecretTypes,
+ name: str,
+ value: Union[str, bytes],
+ version: int = None,
+ clock: Callable[[], float] = None,
+ key_version: int = None,
+) -> bytes:
if version is None:
version = DEFAULT_SIGNED_VALUE_VERSION
if clock is None:
# - signature (hex-encoded; no length prefix)
def format_field(s: Union[str, bytes]) -> bytes:
return utf8("%d:" % len(s)) + utf8(s)
- to_sign = b"|".join([
- b"2",
- format_field(str(key_version or 0)),
- format_field(timestamp),
- format_field(name),
- format_field(value),
- b''])
+
+ to_sign = b"|".join(
+ [
+ b"2",
+ format_field(str(key_version or 0)),
+ format_field(timestamp),
+ format_field(name),
+ format_field(value),
+ b"",
+ ]
+ )
if isinstance(secret, dict):
- assert key_version is not None, 'Key version must be set when sign key dict is used'
- assert version >= 2, 'Version must be at least 2 for key version support'
+ assert (
+ key_version is not None
+ ), "Key version must be set when sign key dict is used"
+ assert version >= 2, "Version must be at least 2 for key version support"
secret = secret[key_version]
signature = _create_signature_v2(secret, to_sign)
return version
-def decode_signed_value(secret: _CookieSecretTypes,
- name: str, value: Union[None, str, bytes], max_age_days: int=31,
- clock: Callable[[], float]=None, min_version: int=None) -> Optional[bytes]:
+def decode_signed_value(
+ secret: _CookieSecretTypes,
+ name: str,
+ value: Union[None, str, bytes],
+ max_age_days: int = 31,
+ clock: Callable[[], float] = None,
+ min_version: int = None,
+) -> Optional[bytes]:
if clock is None:
clock = time.time
if min_version is None:
return None
if version == 1:
assert not isinstance(secret, dict)
- return _decode_signed_value_v1(secret, name, value,
- max_age_days, clock)
+ return _decode_signed_value_v1(secret, name, value, max_age_days, clock)
elif version == 2:
- return _decode_signed_value_v2(secret, name, value,
- max_age_days, clock)
+ return _decode_signed_value_v2(secret, name, value, max_age_days, clock)
else:
return None
-def _decode_signed_value_v1(secret: Union[str, bytes], name: str, value: bytes, max_age_days: int,
- clock: Callable[[], float]) -> Optional[bytes]:
+def _decode_signed_value_v1(
+ secret: Union[str, bytes],
+ name: str,
+ value: bytes,
+ max_age_days: int,
+ clock: Callable[[], float],
+) -> Optional[bytes]:
parts = utf8(value).split(b"|")
if len(parts) != 3:
return None
# digits from the payload to the timestamp without altering the
# signature. For backwards compatibility, sanity-check timestamp
# here instead of modifying _cookie_signature.
- gen_log.warning("Cookie timestamp in future; possible tampering %r",
- value)
+ gen_log.warning("Cookie timestamp in future; possible tampering %r", value)
return None
if parts[1].startswith(b"0"):
gen_log.warning("Tampered cookie %r", value)
def _decode_fields_v2(value: bytes) -> Tuple[int, bytes, bytes, bytes, bytes]:
def _consume_field(s: bytes) -> Tuple[bytes, bytes]:
- length, _, rest = s.partition(b':')
+ length, _, rest = s.partition(b":")
n = int(length)
field_value = rest[:n]
# In python 3, indexing bytes returns small integers; we must
# use a slice to get a byte string as in python 2.
- if rest[n:n + 1] != b'|':
+ if rest[n : n + 1] != b"|":
raise ValueError("malformed v2 signed value field")
- rest = rest[n + 1:]
+ rest = rest[n + 1 :]
return field_value, rest
rest = value[2:] # remove version number
return int(key_version), timestamp, name_field, value_field, passed_sig
-def _decode_signed_value_v2(secret: _CookieSecretTypes,
- name: str, value: bytes, max_age_days: int,
- clock: Callable[[], float]) -> Optional[bytes]:
+def _decode_signed_value_v2(
+ secret: _CookieSecretTypes,
+ name: str,
+ value: bytes,
+ max_age_days: int,
+ clock: Callable[[], float],
+) -> Optional[bytes]:
try:
- key_version, timestamp_bytes, name_field, value_field, passed_sig = _decode_fields_v2(value)
+ key_version, timestamp_bytes, name_field, value_field, passed_sig = _decode_fields_v2(
+ value
+ )
except ValueError:
return None
- signed_string = value[:-len(passed_sig)]
+ signed_string = value[: -len(passed_sig)]
if isinstance(secret, dict):
try:
from tornado.tcpclient import TCPClient
from tornado.util import _websocket_mask
-from typing import (TYPE_CHECKING, cast, Any, Optional, Dict, Union, List, Awaitable,
- Callable, Generator, Tuple, Type)
+from typing import (
+ TYPE_CHECKING,
+ cast,
+ Any,
+ Optional,
+ Dict,
+ Union,
+ List,
+ Awaitable,
+ Callable,
+ Generator,
+ Tuple,
+ Type,
+)
from types import TracebackType
+
if TYPE_CHECKING:
from tornado.iostream import IOStream # noqa: F401
from typing_extensions import Protocol
pass
class _Decompressor(Protocol):
- unconsumed_tail = b'' # type: bytes
+ unconsumed_tail = b"" # type: bytes
def decompress(self, data: bytes, max_length: int) -> bytes:
pass
def close_reason(self, value: Optional[str]) -> None:
pass
- def on_message(self, message: Union[str, bytes]) -> Optional['Awaitable[None]']:
+ def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
pass
def on_ping(self, data: bytes) -> None:
def on_pong(self, data: bytes) -> None:
pass
- def log_exception(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[TracebackType]) -> None:
+ def log_exception(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
pass
.. versionadded:: 3.2
"""
+
pass
Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
``websocket_max_message_size``.
"""
- def __init__(self, application: tornado.web.Application, request: httputil.HTTPServerRequest,
- **kwargs: Any) -> None:
+
+ def __init__(
+ self,
+ application: tornado.web.Application,
+ request: httputil.HTTPServerRequest,
+ **kwargs: Any
+ ) -> None:
super(WebSocketHandler, self).__init__(application, request, **kwargs)
self.ws_connection = None # type: Optional[WebSocketProtocol]
self.close_code = None # type: Optional[int]
self.open_kwargs = kwargs
# Upgrade header should be present and should be equal to WebSocket
- if self.request.headers.get("Upgrade", "").lower() != 'websocket':
+ if self.request.headers.get("Upgrade", "").lower() != "websocket":
self.set_status(400)
- log_msg = "Can \"Upgrade\" only to \"WebSocket\"."
+ log_msg = 'Can "Upgrade" only to "WebSocket".'
self.finish(log_msg)
gen_log.debug(log_msg)
return
# Some proxy servers/load balancers
# might mess with it.
headers = self.request.headers
- connection = map(lambda s: s.strip().lower(),
- headers.get("Connection", "").split(","))
- if 'upgrade' not in connection:
+ connection = map(
+ lambda s: s.strip().lower(), headers.get("Connection", "").split(",")
+ )
+ if "upgrade" not in connection:
self.set_status(400)
- log_msg = "\"Connection\" must be \"Upgrade\"."
+ log_msg = '"Connection" must be "Upgrade".'
self.finish(log_msg)
gen_log.debug(log_msg)
return
Set websocket_ping_interval = 0 to disable pings.
"""
- return self.settings.get('websocket_ping_interval', None)
+ return self.settings.get("websocket_ping_interval", None)
@property
def ping_timeout(self) -> Optional[float]:
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
Default is max of 3 pings or 30 seconds.
"""
- return self.settings.get('websocket_ping_timeout', None)
+ return self.settings.get("websocket_ping_timeout", None)
@property
def max_message_size(self) -> int:
Default is 10MiB.
"""
- return self.settings.get('websocket_max_message_size', _default_max_message_size)
+ return self.settings.get(
+ "websocket_max_message_size", _default_max_message_size
+ )
- def write_message(self, message: Union[bytes, str, Dict[str, Any]],
- binary: bool=False) -> 'Future[None]':
+ def write_message(
+ self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False
+ ) -> "Future[None]":
"""Sends the given message to the client of this Web Socket.
The message may be either a string or a dict (which will be
"""
raise NotImplementedError
- def ping(self, data: Union[str, bytes]=b'') -> None:
+ def ping(self, data: Union[str, bytes] = b"") -> None:
"""Send ping frame to the remote end.
The data argument allows a small amount of data (up to 125
"""
pass
- def close(self, code: int=None, reason: str=None) -> None:
+ def close(self, code: int = None, reason: str = None) -> None:
"""Closes this Web Socket.
Once the close handshake is successful the socket will be closed.
# we can close the connection more gracefully.
self.stream.close()
- def get_websocket_protocol(self) -> Optional['WebSocketProtocol']:
+ def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
websocket_version = self.request.headers.get("Sec-WebSocket-Version")
if websocket_version in ("7", "8", "13"):
return WebSocketProtocol13(
- self, compression_options=self.get_compression_options())
+ self, compression_options=self.get_compression_options()
+ )
return None
def _attach_stream(self) -> None:
self.stream = self.detach()
self.stream.set_close_callback(self.on_connection_close)
# disable non-WS methods
- for method in ["write", "redirect", "set_header", "set_cookie",
- "set_status", "flush", "finish"]:
+ for method in [
+ "write",
+ "redirect",
+ "set_header",
+ "set_cookie",
+ "set_status",
+ "flush",
+ "finish",
+ ]:
setattr(self, method, _raise_not_supported_for_websockets)
class WebSocketProtocol(abc.ABC):
"""Base class for WebSocket protocol versions.
"""
- def __init__(self, handler: '_WebSocketConnection') -> None:
+
+ def __init__(self, handler: "_WebSocketConnection") -> None:
self.handler = handler
self.stream = handler.stream
self.client_terminated = False
self.server_terminated = False
- def _run_callback(self, callback: Callable,
- *args: Any, **kwargs: Any) -> Optional['Future[Any]']:
+ def _run_callback(
+ self, callback: Callable, *args: Any, **kwargs: Any
+ ) -> Optional["Future[Any]"]:
"""Runs the given callback with exception handling.
If the callback is a coroutine, returns its Future. On error, aborts the
self.close() # let the subclass cleanup
@abc.abstractmethod
- def close(self, code: int=None, reason: str=None) -> None:
+ def close(self, code: int = None, reason: str = None) -> None:
raise NotImplementedError()
@abc.abstractmethod
raise NotImplementedError()
@abc.abstractmethod
- def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
+ def write_message(
+ self, message: Union[str, bytes], binary: bool = False
+ ) -> "Future[None]":
raise NotImplementedError()
@property
# WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13
# boundary is currently pretty ad-hoc.
@abc.abstractmethod
- def _process_server_headers(self, key: Union[str, bytes],
- headers: httputil.HTTPHeaders) -> None:
+ def _process_server_headers(
+ self, key: Union[str, bytes], headers: httputil.HTTPHeaders
+ ) -> None:
raise NotImplementedError()
@abc.abstractmethod
raise NotImplementedError()
@abc.abstractmethod
- def _receive_frame_loop(self) -> 'Future[None]':
+ def _receive_frame_loop(self) -> "Future[None]":
raise NotImplementedError()
class _PerMessageDeflateCompressor(object):
- def __init__(self, persistent: bool, max_wbits: Optional[int],
- compression_options: Dict[str, Any]=None) -> None:
+ def __init__(
+ self,
+ persistent: bool,
+ max_wbits: Optional[int],
+ compression_options: Dict[str, Any] = None,
+ ) -> None:
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
# There is no symbolic constant for the minimum wbits value.
if not (8 <= max_wbits <= zlib.MAX_WBITS):
- raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
- max_wbits, zlib.MAX_WBITS)
+ raise ValueError(
+ "Invalid max_wbits value %r; allowed range 8-%d",
+ max_wbits,
+ zlib.MAX_WBITS,
+ )
self._max_wbits = max_wbits
- if compression_options is None or 'compression_level' not in compression_options:
+ if (
+ compression_options is None
+ or "compression_level" not in compression_options
+ ):
self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL
else:
- self._compression_level = compression_options['compression_level']
+ self._compression_level = compression_options["compression_level"]
- if compression_options is None or 'mem_level' not in compression_options:
+ if compression_options is None or "mem_level" not in compression_options:
self._mem_level = 8
else:
- self._mem_level = compression_options['mem_level']
+ self._mem_level = compression_options["mem_level"]
if persistent:
self._compressor = self._create_compressor() # type: Optional[_Compressor]
else:
self._compressor = None
- def _create_compressor(self) -> '_Compressor':
- return zlib.compressobj(self._compression_level,
- zlib.DEFLATED, -self._max_wbits, self._mem_level)
+ def _create_compressor(self) -> "_Compressor":
+ return zlib.compressobj(
+ self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level
+ )
def compress(self, data: bytes) -> bytes:
compressor = self._compressor or self._create_compressor()
- data = (compressor.compress(data) +
- compressor.flush(zlib.Z_SYNC_FLUSH))
- assert data.endswith(b'\x00\x00\xff\xff')
+ data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH)
+ assert data.endswith(b"\x00\x00\xff\xff")
return data[:-4]
class _PerMessageDeflateDecompressor(object):
- def __init__(self, persistent: bool, max_wbits: Optional[int], max_message_size: int,
- compression_options: Dict[str, Any]=None) -> None:
+ def __init__(
+ self,
+ persistent: bool,
+ max_wbits: Optional[int],
+ max_message_size: int,
+ compression_options: Dict[str, Any] = None,
+ ) -> None:
self._max_message_size = max_message_size
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
if not (8 <= max_wbits <= zlib.MAX_WBITS):
- raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
- max_wbits, zlib.MAX_WBITS)
+ raise ValueError(
+ "Invalid max_wbits value %r; allowed range 8-%d",
+ max_wbits,
+ zlib.MAX_WBITS,
+ )
self._max_wbits = max_wbits
if persistent:
- self._decompressor = self._create_decompressor() # type: Optional[_Decompressor]
+ self._decompressor = (
+ self._create_decompressor()
+ ) # type: Optional[_Decompressor]
else:
self._decompressor = None
- def _create_decompressor(self) -> '_Decompressor':
+ def _create_decompressor(self) -> "_Decompressor":
return zlib.decompressobj(-self._max_wbits)
def decompress(self, data: bytes) -> bytes:
decompressor = self._decompressor or self._create_decompressor()
- result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
+ result = decompressor.decompress(
+ data + b"\x00\x00\xff\xff", self._max_message_size
+ )
if decompressor.unconsumed_tail:
raise _DecompressTooLargeError()
return result
This class supports versions 7 and 8 of the protocol in addition to the
final version 13.
"""
+
# Bit masks for the first byte of a frame.
FIN = 0x80
RSV1 = 0x40
RSV2 = 0x20
RSV3 = 0x10
RSV_MASK = RSV1 | RSV2 | RSV3
- OPCODE_MASK = 0x0f
+ OPCODE_MASK = 0x0F
stream = None # type: IOStream
- def __init__(self, handler: '_WebSocketConnection', mask_outgoing: bool=False,
- compression_options: Dict[str, Any]=None) -> None:
+ def __init__(
+ self,
+ handler: "_WebSocketConnection",
+ mask_outgoing: bool = False,
+ compression_options: Dict[str, Any] = None,
+ ) -> None:
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
try:
self._accept_connection(handler)
except ValueError:
- gen_log.debug("Malformed WebSocket request received",
- exc_info=True)
+ gen_log.debug("Malformed WebSocket request received", exc_info=True)
self._abort()
return
def _challenge_response(self, handler: WebSocketHandler) -> str:
return WebSocketProtocol13.compute_accept_value(
- cast(str, handler.request.headers.get("Sec-Websocket-Key")))
+ cast(str, handler.request.headers.get("Sec-Websocket-Key"))
+ )
@gen.coroutine
- def _accept_connection(self, handler: WebSocketHandler) -> Generator[Any, Any, None]:
+ def _accept_connection(
+ self, handler: WebSocketHandler
+ ) -> Generator[Any, Any, None]:
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
- subprotocols = [s.strip() for s in subprotocol_header.split(',')]
+ subprotocols = [s.strip() for s in subprotocol_header.split(",")]
else:
subprotocols = []
self.selected_subprotocol = handler.select_subprotocol(subprotocols)
extensions = self._parse_extensions_header(handler.request.headers)
for ext in extensions:
- if (ext[0] == 'permessage-deflate' and
- self._compression_options is not None):
+ if ext[0] == "permessage-deflate" and self._compression_options is not None:
# TODO: negotiate parameters if compression_options
# specifies limits.
- self._create_compressors('server', ext[1], self._compression_options)
- if ('client_max_window_bits' in ext[1] and
- ext[1]['client_max_window_bits'] is None):
+ self._create_compressors("server", ext[1], self._compression_options)
+ if (
+ "client_max_window_bits" in ext[1]
+ and ext[1]["client_max_window_bits"] is None
+ ):
# Don't echo an offered client_max_window_bits
# parameter with no value.
- del ext[1]['client_max_window_bits']
- handler.set_header("Sec-WebSocket-Extensions",
- httputil._encode_header(
- 'permessage-deflate', ext[1]))
+ del ext[1]["client_max_window_bits"]
+ handler.set_header(
+ "Sec-WebSocket-Extensions",
+ httputil._encode_header("permessage-deflate", ext[1]),
+ )
break
handler.clear_header("Content-Type")
self.stream = handler.stream
self.start_pinging()
- open_result = self._run_callback(handler.open, *handler.open_args,
- **handler.open_kwargs)
+ open_result = self._run_callback(
+ handler.open, *handler.open_args, **handler.open_kwargs
+ )
if open_result is not None:
yield open_result
yield self._receive_frame_loop()
def _parse_extensions_header(
- self, headers: httputil.HTTPHeaders
+ self, headers: httputil.HTTPHeaders
) -> List[Tuple[str, Dict[str, str]]]:
- extensions = headers.get("Sec-WebSocket-Extensions", '')
+ extensions = headers.get("Sec-WebSocket-Extensions", "")
if extensions:
- return [httputil._parse_header(e.strip())
- for e in extensions.split(',')]
+ return [httputil._parse_header(e.strip()) for e in extensions.split(",")]
return []
- def _process_server_headers(self, key: Union[str, bytes],
- headers: httputil.HTTPHeaders) -> None:
+ def _process_server_headers(
+ self, key: Union[str, bytes], headers: httputil.HTTPHeaders
+ ) -> None:
"""Process the headers sent by the server to this client connection.
'key' is the websocket handshake challenge/response key.
"""
- assert headers['Upgrade'].lower() == 'websocket'
- assert headers['Connection'].lower() == 'upgrade'
+ assert headers["Upgrade"].lower() == "websocket"
+ assert headers["Connection"].lower() == "upgrade"
accept = self.compute_accept_value(key)
- assert headers['Sec-Websocket-Accept'] == accept
+ assert headers["Sec-Websocket-Accept"] == accept
extensions = self._parse_extensions_header(headers)
for ext in extensions:
- if (ext[0] == 'permessage-deflate' and
- self._compression_options is not None):
- self._create_compressors('client', ext[1])
+ if ext[0] == "permessage-deflate" and self._compression_options is not None:
+ self._create_compressors("client", ext[1])
else:
raise ValueError("unsupported extension %r", ext)
- self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None)
+ self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None)
- def _get_compressor_options(self, side: str, agreed_parameters: Dict[str, Any],
- compression_options: Dict[str, Any]=None) -> Dict[str, Any]:
+ def _get_compressor_options(
+ self,
+ side: str,
+ agreed_parameters: Dict[str, Any],
+ compression_options: Dict[str, Any] = None,
+ ) -> Dict[str, Any]:
"""Converts a websocket agreed_parameters set to keyword arguments
for our compressor objects.
"""
- options = dict(persistent=(side + '_no_context_takeover') not in agreed_parameters) \
- # type: Dict[str, Any]
- wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
+ options = dict(
+ persistent=(side + "_no_context_takeover") not in agreed_parameters
+ ) # type: Dict[str, Any]
+ wbits_header = agreed_parameters.get(side + "_max_window_bits", None)
if wbits_header is None:
- options['max_wbits'] = zlib.MAX_WBITS
+ options["max_wbits"] = zlib.MAX_WBITS
else:
- options['max_wbits'] = int(wbits_header)
- options['compression_options'] = compression_options
+ options["max_wbits"] = int(wbits_header)
+ options["compression_options"] = compression_options
return options
- def _create_compressors(self, side: str, agreed_parameters: Dict[str, Any],
- compression_options: Dict[str, Any]=None) -> None:
+ def _create_compressors(
+ self,
+ side: str,
+ agreed_parameters: Dict[str, Any],
+ compression_options: Dict[str, Any] = None,
+ ) -> None:
# TODO: handle invalid parameters gracefully
- allowed_keys = set(['server_no_context_takeover',
- 'client_no_context_takeover',
- 'server_max_window_bits',
- 'client_max_window_bits'])
+ allowed_keys = set(
+ [
+ "server_no_context_takeover",
+ "client_no_context_takeover",
+ "server_max_window_bits",
+ "client_max_window_bits",
+ ]
+ )
for key in agreed_parameters:
if key not in allowed_keys:
raise ValueError("unsupported compression parameter %r" % key)
- other_side = 'client' if (side == 'server') else 'server'
+ other_side = "client" if (side == "server") else "server"
self._compressor = _PerMessageDeflateCompressor(
- **self._get_compressor_options(side, agreed_parameters, compression_options))
+ **self._get_compressor_options(side, agreed_parameters, compression_options)
+ )
self._decompressor = _PerMessageDeflateDecompressor(
max_message_size=self.handler.max_message_size,
- **self._get_compressor_options(other_side, agreed_parameters, compression_options))
-
- def _write_frame(self, fin: bool, opcode: int, data: bytes, flags: int=0) -> 'Future[None]':
+ **self._get_compressor_options(
+ other_side, agreed_parameters, compression_options
+ )
+ )
+
+ def _write_frame(
+ self, fin: bool, opcode: int, data: bytes, flags: int = 0
+ ) -> "Future[None]":
data_len = len(data)
if opcode & 0x8:
# All control frames MUST have a payload length of 125
self._wire_bytes_out += len(frame)
return self.stream.write(frame)
- def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
+ def write_message(
+ self, message: Union[str, bytes], binary: bool = False
+ ) -> "Future[None]":
"""Sends the given message to the client of this Web Socket."""
if binary:
opcode = 0x2
yield fut
except StreamClosedError:
raise WebSocketClosedError()
+
return wrapper()
def write_ping(self, data: bytes) -> None:
self._abort()
return
is_masked = bool(mask_payloadlen & 0x80)
- payloadlen = mask_payloadlen & 0x7f
+ payloadlen = mask_payloadlen & 0x7F
# Parse and validate the length.
if opcode_is_control and payloadlen >= 126:
if handled_future is not None:
yield handled_future
- def _handle_message(self, opcode: int, data: bytes) -> Optional['Future[None]']:
+ def _handle_message(self, opcode: int, data: bytes) -> Optional["Future[None]"]:
"""Execute on_message, returning its Future if it is a coroutine."""
if self.client_terminated:
return None
# Close
self.client_terminated = True
if len(data) >= 2:
- self.handler.close_code = struct.unpack('>H', data[:2])[0]
+ self.handler.close_code = struct.unpack(">H", data[:2])[0]
if len(data) > 2:
self.handler.close_reason = to_unicode(data[2:])
# Echo the received close code, if any (RFC 6455 section 5.5.1).
self._abort()
return None
- def close(self, code: int=None, reason: str=None) -> None:
+ def close(self, code: int = None, reason: str = None) -> None:
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
if code is None and reason is not None:
code = 1000 # "normal closure" status code
if code is None:
- close_data = b''
+ close_data = b""
else:
- close_data = struct.pack('>H', code)
+ close_data = struct.pack(">H", code)
if reason is not None:
close_data += utf8(reason)
try:
# Give the client a few seconds to complete a clean shutdown,
# otherwise just close the connection.
self._waiting = self.stream.io_loop.add_timeout(
- self.stream.io_loop.time() + 5, self._abort)
+ self.stream.io_loop.time() + 5, self._abort
+ )
def is_closing(self) -> bool:
"""Return true if this connection is closing.
initiated its closing handshake or if the stream has been
shut down uncleanly.
"""
- return (self.stream.closed() or
- self.client_terminated or
- self.server_terminated)
+ return self.stream.closed() or self.client_terminated or self.server_terminated
@property
def ping_interval(self) -> Optional[float]:
if self.ping_interval > 0:
self.last_ping = self.last_pong = IOLoop.current().time()
self.ping_callback = PeriodicCallback(
- self.periodic_ping, self.ping_interval * 1000)
+ self.periodic_ping, self.ping_interval * 1000
+ )
self.ping_callback.start()
def periodic_ping(self) -> None:
since_last_ping = now - self.last_ping
assert self.ping_interval is not None
assert self.ping_timeout is not None
- if (since_last_ping < 2 * self.ping_interval and
- since_last_pong > self.ping_timeout):
+ if (
+ since_last_ping < 2 * self.ping_interval
+ and since_last_pong > self.ping_timeout
+ ):
self.close()
return
- self.write_ping(b'')
+ self.write_ping(b"")
self.last_ping = now
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
+
protocol = None # type: WebSocketProtocol
- def __init__(self, request: httpclient.HTTPRequest,
- on_message_callback: Callable[[Union[None, str, bytes]], None]=None,
- compression_options: Dict[str, Any]=None,
- ping_interval: float=None, ping_timeout: float=None,
- max_message_size: int=_default_max_message_size,
- subprotocols: Optional[List[str]]=[]) -> None:
+ def __init__(
+ self,
+ request: httpclient.HTTPRequest,
+ on_message_callback: Callable[[Union[None, str, bytes]], None] = None,
+ compression_options: Dict[str, Any] = None,
+ ping_interval: float = None,
+ ping_timeout: float = None,
+ max_message_size: int = _default_max_message_size,
+ subprotocols: Optional[List[str]] = [],
+ ) -> None:
self.compression_options = compression_options
self.connect_future = Future() # type: Future[WebSocketClientConnection]
self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]]
self.ping_timeout = ping_timeout
self.max_message_size = max_message_size
- scheme, sep, rest = request.url.partition(':')
- scheme = {'ws': 'http', 'wss': 'https'}[scheme]
+ scheme, sep, rest = request.url.partition(":")
+ scheme = {"ws": "http", "wss": "https"}[scheme]
request.url = scheme + sep + rest
- request.headers.update({
- 'Upgrade': 'websocket',
- 'Connection': 'Upgrade',
- 'Sec-WebSocket-Key': self.key,
- 'Sec-WebSocket-Version': '13',
- })
+ request.headers.update(
+ {
+ "Upgrade": "websocket",
+ "Connection": "Upgrade",
+ "Sec-WebSocket-Key": self.key,
+ "Sec-WebSocket-Version": "13",
+ }
+ )
if subprotocols is not None:
- request.headers['Sec-WebSocket-Protocol'] = ','.join(subprotocols)
+ request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
if self.compression_options is not None:
# Always offer to let the server set our max_wbits (and even though
# we don't offer it, we will accept a client_no_context_takeover
# from the server).
# TODO: set server parameters for deflate extension
# if requested in self.compression_options.
- request.headers['Sec-WebSocket-Extensions'] = (
- 'permessage-deflate; client_max_window_bits')
+ request.headers[
+ "Sec-WebSocket-Extensions"
+ ] = "permessage-deflate; client_max_window_bits"
self.tcp_client = TCPClient()
super(WebSocketClientConnection, self).__init__(
- None, request, lambda: None, self._on_http_response,
- 104857600, self.tcp_client, 65536, 104857600)
-
- def close(self, code: int=None, reason: str=None) -> None:
+ None,
+ request,
+ lambda: None,
+ self._on_http_response,
+ 104857600,
+ self.tcp_client,
+ 65536,
+ 104857600,
+ )
+
+ def close(self, code: int = None, reason: str = None) -> None:
"""Closes the websocket connection.
``code`` and ``reason`` are documented under
if response.error:
self.connect_future.set_exception(response.error)
else:
- self.connect_future.set_exception(WebSocketError(
- "Non-websocket response"))
-
- def headers_received(self, start_line: Union[httputil.RequestStartLine,
- httputil.ResponseStartLine],
- headers: httputil.HTTPHeaders) -> None:
+ self.connect_future.set_exception(
+ WebSocketError("Non-websocket response")
+ )
+
+ def headers_received(
+ self,
+ start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders,
+ ) -> None:
assert isinstance(start_line, httputil.ResponseStartLine)
if start_line.code != 101:
return super(WebSocketClientConnection, self).headers_received(
- start_line, headers)
+ start_line, headers
+ )
self.headers = headers
self.protocol = self.get_websocket_protocol()
future_set_result_unless_cancelled(self.connect_future, self)
- def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
+ def write_message(
+ self, message: Union[str, bytes], binary: bool = False
+ ) -> "Future[None]":
"""Sends a message to the WebSocket server.
If the stream is closed, raises `WebSocketClosedError`.
return self.protocol.write_message(message, binary=binary)
def read_message(
- self, callback: Callable[['Future[Union[None, str, bytes]]'], None]=None
- ) -> 'Future[Union[None, str, bytes]]':
+ self, callback: Callable[["Future[Union[None, str, bytes]]"], None] = None
+ ) -> "Future[Union[None, str, bytes]]":
"""Reads a message from the WebSocket server.
If on_message_callback was specified at WebSocket
self.io_loop.add_future(future, callback)
return future
- def on_message(self, message: Union[str, bytes]) -> Optional['Future[None]']:
+ def on_message(self, message: Union[str, bytes]) -> Optional["Future[None]"]:
return self._on_message(message)
- def _on_message(self, message: Union[None, str, bytes]) -> Optional['Future[None]']:
+ def _on_message(self, message: Union[None, str, bytes]) -> Optional["Future[None]"]:
if self._on_message_callback:
self._on_message_callback(message)
return None
else:
return self.read_queue.put(message)
- def ping(self, data: bytes=b'') -> None:
+ def ping(self, data: bytes = b"") -> None:
"""Send ping frame to the remote end.
The data argument allows a small amount of data (up to 125
pass
def get_websocket_protocol(self) -> WebSocketProtocol:
- return WebSocketProtocol13(self, mask_outgoing=True,
- compression_options=self.compression_options)
+ return WebSocketProtocol13(
+ self, mask_outgoing=True, compression_options=self.compression_options
+ )
@property
def selected_subprotocol(self) -> Optional[str]:
"""
return self.protocol.selected_subprotocol
- def log_exception(self, typ: Optional[Type[BaseException]],
- value: Optional[BaseException],
- tb: Optional[TracebackType]) -> None:
+ def log_exception(
+ self,
+ typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
assert typ is not None
assert value is not None
app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))
def websocket_connect(
- url: Union[str, httpclient.HTTPRequest],
- callback: Callable[['Future[WebSocketClientConnection]'], None]=None,
- connect_timeout: float=None,
- on_message_callback: Callable[[Union[None, str, bytes]], None]=None,
- compression_options: Dict[str, Any]=None,
- ping_interval: float=None, ping_timeout: float=None,
- max_message_size: int=_default_max_message_size, subprotocols: List[str]=None
-) -> 'Future[WebSocketClientConnection]':
+ url: Union[str, httpclient.HTTPRequest],
+ callback: Callable[["Future[WebSocketClientConnection]"], None] = None,
+ connect_timeout: float = None,
+ on_message_callback: Callable[[Union[None, str, bytes]], None] = None,
+ compression_options: Dict[str, Any] = None,
+ ping_interval: float = None,
+ ping_timeout: float = None,
+ max_message_size: int = _default_max_message_size,
+ subprotocols: List[str] = None,
+) -> "Future[WebSocketClientConnection]":
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
request.headers = httputil.HTTPHeaders(request.headers)
else:
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
- request = cast(httpclient.HTTPRequest, httpclient._RequestProxy(
- request, httpclient.HTTPRequest._DEFAULTS))
- conn = WebSocketClientConnection(request,
- on_message_callback=on_message_callback,
- compression_options=compression_options,
- ping_interval=ping_interval,
- ping_timeout=ping_timeout,
- max_message_size=max_message_size,
- subprotocols=subprotocols)
+ request = cast(
+ httpclient.HTTPRequest,
+ httpclient._RequestProxy(request, httpclient.HTTPRequest._DEFAULTS),
+ )
+ conn = WebSocketClientConnection(
+ request,
+ on_message_callback=on_message_callback,
+ compression_options=compression_options,
+ ping_interval=ping_interval,
+ ping_timeout=ping_timeout,
+ max_message_size=max_message_size,
+ subprotocols=subprotocols,
+ )
if callback is not None:
IOLoop.current().add_future(conn.connect_future, callback)
return conn.connect_future
from typing import List, Tuple, Optional, Type, Callable, Any, Dict, Text
from types import TracebackType
import typing
+
if typing.TYPE_CHECKING:
from wsgiref.types import WSGIApplication as WSGIAppType # noqa: F401
# here to minimize the temptation to use it in non-wsgi contexts.
def to_wsgi_str(s: bytes) -> str:
assert isinstance(s, bytes)
- return s.decode('latin1')
+ return s.decode("latin1")
class WSGIContainer(object):
Tornado and WSGI apps in the same server. See
https://github.com/bdarnell/django-tornado-demo for a complete example.
"""
- def __init__(self, wsgi_application: 'WSGIAppType') -> None:
+
+ def __init__(self, wsgi_application: "WSGIAppType") -> None:
self.wsgi_application = wsgi_application
def __call__(self, request: httputil.HTTPServerRequest) -> None:
response = [] # type: List[bytes]
def start_response(
- status: str, headers: List[Tuple[str, str]],
- exc_info: Optional[Tuple[Optional[Type[BaseException]],
- Optional[BaseException],
- Optional[TracebackType]]]=None
+ status: str,
+ headers: List[Tuple[str, str]],
+ exc_info: Optional[
+ Tuple[
+ Optional[Type[BaseException]],
+ Optional[BaseException],
+ Optional[TracebackType],
+ ]
+ ] = None,
) -> Callable[[bytes], Any]:
data["status"] = status
data["headers"] = headers
return response.append
+
app_response = self.wsgi_application(
- WSGIContainer.environ(request), start_response)
+ WSGIContainer.environ(request), start_response
+ )
try:
response.extend(app_response)
body = b"".join(response)
if not data:
raise Exception("WSGI app did not call start_response")
- status_code_str, reason = data["status"].split(' ', 1)
+ status_code_str, reason = data["status"].split(" ", 1)
status_code = int(status_code_str)
headers = data["headers"] # type: List[Tuple[str, str]]
header_set = set(k.lower() for (k, v) in headers)
environ = {
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": "",
- "PATH_INFO": to_wsgi_str(escape.url_unescape(
- request.path, encoding=None, plus=False)),
+ "PATH_INFO": to_wsgi_str(
+ escape.url_unescape(request.path, encoding=None, plus=False)
+ ),
"QUERY_STRING": request.query,
"REMOTE_ADDR": request.remote_ip,
"SERVER_NAME": host,
request_time = 1000.0 * request.request_time()
assert request.method is not None
assert request.uri is not None
- summary = request.method + " " + request.uri + " (" + \
- request.remote_ip + ")"
+ summary = request.method + " " + request.uri + " (" + request.remote_ip + ")"
log_method("%d %s %.2fms", status_code, summary, request_time)