From: Ben Darnell Date: Sat, 29 Sep 2018 20:50:52 +0000 (-0400) Subject: web,routing: Add type annotations X-Git-Tag: v6.0.0b1~28^2~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=28c02b5ebfd5943185e25b3c22855de7b527542f;p=thirdparty%2Ftornado.git web,routing: Add type annotations --- diff --git a/setup.cfg b/setup.cfg index b34641c1d..213bc1186 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ python_version = 3.5 [mypy-tornado.*,tornado.platform.*] disallow_untyped_defs = True -[mypy-tornado.auth,tornado.routing,tornado.web,tornado.websocket,tornado.wsgi] +[mypy-tornado.auth,tornado.websocket,tornado.wsgi] disallow_untyped_defs = False # It's generally too tedious to require type annotations in tests, but diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 870cff793..101166635 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -350,7 +350,7 @@ class HTTP1Connection(httputil.HTTPConnection): def write_headers(self, start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], - headers: httputil.HTTPHeaders, chunk: bytes=None) -> Awaitable[None]: + headers: httputil.HTTPHeaders, chunk: bytes=None) -> 'Future[None]': """Implements `.HTTPConnection.write_headers`.""" lines = [] if self.is_client: @@ -438,7 +438,7 @@ class HTTP1Connection(httputil.HTTPConnection): else: return chunk - def write(self, chunk: bytes) -> Awaitable[None]: + def write(self, chunk: bytes) -> 'Future[None]': """Implements `.HTTPConnection.write`. For backwards compatibility it is allowed but deprecated to diff --git a/tornado/httputil.py b/tornado/httputil.py index d86f762f6..c88e862a2 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -43,10 +43,11 @@ responses import typing from typing import (Tuple, Iterable, List, Mapping, Iterator, Dict, Union, Optional, - Awaitable, Generator) + Awaitable, Generator, AnyStr) if typing.TYPE_CHECKING: from typing import Deque # noqa + from asyncio import Future # noqa import unittest # noqa @@ -348,6 +349,9 @@ class HTTPServerRequest(object): path = None # type: str query = None # type: str + # HACK: Used for stream_request_body + _body_future = None # type: Future[None] + def __init__(self, method: str=None, uri: str=None, version: str="HTTP/1.0", headers: HTTPHeaders=None, body: bytes=None, host: str=None, files: Dict[str, List['HTTPFile']]=None, connection: 'HTTPConnection'=None, @@ -503,6 +507,7 @@ class HTTPMessageDelegate(object): .. versionadded:: 4.0 """ + # TODO: genericize this class to avoid exposing the Union. def headers_received(self, start_line: Union['RequestStartLine', 'ResponseStartLine'], headers: HTTPHeaders) -> Optional[Awaitable[None]]: """Called when the HTTP headers have been received and parsed. @@ -545,7 +550,7 @@ class HTTPConnection(object): .. versionadded:: 4.0 """ def write_headers(self, start_line: Union['RequestStartLine', 'ResponseStartLine'], - headers: HTTPHeaders, chunk: bytes=None) -> Awaitable[None]: + headers: HTTPHeaders, chunk: bytes=None) -> 'Future[None]': """Write an HTTP header block. :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`. @@ -556,7 +561,7 @@ class HTTPConnection(object): The ``version`` field of ``start_line`` is ignored. - Returns an awaitable for flow control. + Returns a future for flow control. .. versionchanged:: 6.0 @@ -564,10 +569,10 @@ class HTTPConnection(object): """ raise NotImplementedError() - def write(self, chunk: bytes) -> Awaitable[None]: + def write(self, chunk: bytes) -> 'Future[None]': """Writes a chunk of body data. - Returns an awaitable for flow control. + Returns a future for flow control. .. versionchanged:: 6.0 @@ -970,7 +975,7 @@ def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]: return (host, port) -def qs_to_qsl(qs: Dict[str, List[str]]) -> Iterable[Tuple[str, str]]: +def qs_to_qsl(qs: Dict[str, List[AnyStr]]) -> Iterable[Tuple[str, AnyStr]]: """Generator converting a result of ``parse_qs`` back to name-value pairs. .. versionadded:: 5.0 diff --git a/tornado/routing.py b/tornado/routing.py index 2af2fcad5..973b30473 100644 --- a/tornado/routing.py +++ b/tornado/routing.py @@ -184,17 +184,14 @@ from tornado.escape import url_escape, url_unescape, utf8 from tornado.log import app_log from tornado.util import basestring_type, import_object, re_unescape, unicode_type -try: - import typing # noqa -except ImportError: - pass +from typing import Any, Union, Optional, Awaitable, List, Dict, Pattern, Tuple, overload class Router(httputil.HTTPServerConnectionDelegate): """Abstract router interface.""" - def find_handler(self, request, **kwargs): - # type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate + def find_handler(self, request: httputil.HTTPServerRequest, + **kwargs: Any) -> Optional[httputil.HTTPMessageDelegate]: """Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate` that can serve the request. Routing implementations may pass additional kwargs to extend the routing logic. @@ -206,7 +203,8 @@ class Router(httputil.HTTPServerConnectionDelegate): """ raise NotImplementedError() - def start_request(self, server_conn, request_conn): + def start_request(self, server_conn: object, + request_conn: httputil.HTTPConnection) -> httputil.HTTPMessageDelegate: return _RoutingDelegate(self, server_conn, request_conn) @@ -215,7 +213,7 @@ class ReversibleRouter(Router): and support reversing them to original urls. """ - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> Optional[str]: """Returns url string for a given route name and arguments or ``None`` if no match is found. @@ -227,13 +225,17 @@ class ReversibleRouter(Router): class _RoutingDelegate(httputil.HTTPMessageDelegate): - def __init__(self, router, server_conn, request_conn): + def __init__(self, router: Router, server_conn: object, + request_conn: httputil.HTTPConnection) -> None: self.server_conn = server_conn self.request_conn = request_conn - self.delegate = None + self.delegate = None # type: Optional[httputil.HTTPMessageDelegate] self.router = router # type: Router - def headers_received(self, start_line, headers): + def headers_received(self, start_line: Union[httputil.RequestStartLine, + httputil.ResponseStartLine], + headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: + assert isinstance(start_line, httputil.RequestStartLine) request = httputil.HTTPServerRequest( connection=self.request_conn, server_connection=self.server_conn, @@ -247,30 +249,42 @@ class _RoutingDelegate(httputil.HTTPMessageDelegate): return self.delegate.headers_received(start_line, headers) - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: + assert self.delegate is not None return self.delegate.data_received(chunk) - def finish(self): + def finish(self) -> None: + assert self.delegate is not None self.delegate.finish() - def on_connection_close(self): + def on_connection_close(self) -> None: + assert self.delegate is not None self.delegate.on_connection_close() class _DefaultMessageDelegate(httputil.HTTPMessageDelegate): - def __init__(self, connection): + def __init__(self, connection: httputil.HTTPConnection) -> None: self.connection = connection - def finish(self): + def finish(self) -> None: self.connection.write_headers( httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"), httputil.HTTPHeaders()) self.connection.finish() +# _RuleList can either contain pre-constructed Rules or a sequence of +# arguments to be passed to the Rule constructor. +_RuleList = List[Union['Rule', + List[Any], # Can't do detailed typechecking of lists. + Tuple[Union[str, 'Matcher'], Any], + Tuple[Union[str, 'Matcher'], Any, Dict[str, Any]], + Tuple[Union[str, 'Matcher'], Any, Dict[str, Any], str]]] + + class RuleRouter(Router): """Rule-based router implementation.""" - def __init__(self, rules=None): + def __init__(self, rules: _RuleList=None) -> None: """Constructs a router from an ordered list of rules:: RuleRouter([ @@ -297,11 +311,11 @@ class RuleRouter(Router): :arg rules: a list of `Rule` instances or tuples of `Rule` constructor arguments. """ - self.rules = [] # type: typing.List[Rule] + self.rules = [] # type: List[Rule] if rules: self.add_rules(rules) - def add_rules(self, rules): + def add_rules(self, rules: _RuleList) -> None: """Appends new rules to the router. :arg rules: a list of Rule instances (or tuples of arguments, which are @@ -317,7 +331,7 @@ class RuleRouter(Router): self.rules.append(self.process_rule(rule)) - def process_rule(self, rule): + def process_rule(self, rule: 'Rule') -> 'Rule': """Override this method for additional preprocessing of each rule. :arg Rule rule: a rule to be processed. @@ -325,7 +339,8 @@ class RuleRouter(Router): """ return rule - def find_handler(self, request, **kwargs): + def find_handler(self, request: httputil.HTTPServerRequest, + **kwargs: Any) -> Optional[httputil.HTTPMessageDelegate]: for rule in self.rules: target_params = rule.matcher.match(request) if target_params is not None: @@ -340,7 +355,8 @@ class RuleRouter(Router): return None - def get_target_delegate(self, target, request, **target_params): + def get_target_delegate(self, target: Any, request: httputil.HTTPServerRequest, + **target_params: Any) -> Optional[httputil.HTTPMessageDelegate]: """Returns an instance of `~.httputil.HTTPMessageDelegate` for a Rule's target. This method is called by `~.find_handler` and can be extended to provide additional target types. @@ -354,9 +370,11 @@ class RuleRouter(Router): return target.find_handler(request, **target_params) elif isinstance(target, httputil.HTTPServerConnectionDelegate): + assert request.connection is not None return target.start_request(request.server_connection, request.connection) elif callable(target): + assert request.connection is not None return _CallableAdapter( partial(target, **target_params), request.connection ) @@ -372,11 +390,11 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter): in a rule's matcher (see `Matcher.reverse`). """ - def __init__(self, rules=None): - self.named_rules = {} # type: typing.Dict[str, Any] + def __init__(self, rules: _RuleList=None) -> None: + self.named_rules = {} # type: Dict[str, Any] super(ReversibleRuleRouter, self).__init__(rules) - def process_rule(self, rule): + def process_rule(self, rule: 'Rule') -> 'Rule': rule = super(ReversibleRuleRouter, self).process_rule(rule) if rule.name: @@ -388,7 +406,7 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter): return rule - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> Optional[str]: if name in self.named_rules: return self.named_rules[name].matcher.reverse(*args) @@ -404,7 +422,8 @@ class ReversibleRuleRouter(ReversibleRouter, RuleRouter): class Rule(object): """A routing rule.""" - def __init__(self, matcher, target, target_kwargs=None, name=None): + def __init__(self, matcher: 'Matcher', target: Any, + target_kwargs: Dict[str, Any]=None, name: str=None) -> None: """Constructs a Rule instance. :arg Matcher matcher: a `Matcher` instance used for determining @@ -431,10 +450,10 @@ class Rule(object): self.target_kwargs = target_kwargs if target_kwargs else {} self.name = name - def reverse(self, *args): + def reverse(self, *args: Any) -> Optional[str]: return self.matcher.reverse(*args) - def __repr__(self): + def __repr__(self) -> str: return '%s(%r, %s, kwargs=%r, name=%r)' % \ (self.__class__.__name__, self.matcher, self.target, self.target_kwargs, self.name) @@ -443,7 +462,7 @@ class Rule(object): class Matcher(object): """Represents a matcher for request features.""" - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: """Matches current instance against the request. :arg httputil.HTTPServerRequest request: current HTTP request @@ -455,7 +474,7 @@ class Matcher(object): ``None`` must be returned to indicate that there is no match.""" raise NotImplementedError() - def reverse(self, *args): + def reverse(self, *args: Any) -> Optional[str]: """Reconstructs full url from matcher instance and additional arguments.""" return None @@ -463,14 +482,14 @@ class Matcher(object): class AnyMatches(Matcher): """Matches any request.""" - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: return {} class HostMatches(Matcher): """Matches requests from hosts specified by ``host_pattern`` regex.""" - def __init__(self, host_pattern): + def __init__(self, host_pattern: Union[str, Pattern]) -> None: if isinstance(host_pattern, basestring_type): if not host_pattern.endswith("$"): host_pattern += "$" @@ -478,7 +497,7 @@ class HostMatches(Matcher): else: self.host_pattern = host_pattern - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: if self.host_pattern.match(request.host_name): return {} @@ -490,11 +509,11 @@ class DefaultHostMatches(Matcher): Always returns no match if ``X-Real-Ip`` header is present. """ - def __init__(self, application, host_pattern): + def __init__(self, application: Any, host_pattern: Pattern) -> None: self.application = application self.host_pattern = host_pattern - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: # Look for default host if not behind load balancer (for debugging) if "X-Real-Ip" not in request.headers: if self.host_pattern.match(self.application.default_host): @@ -505,7 +524,7 @@ class DefaultHostMatches(Matcher): class PathMatches(Matcher): """Matches requests with paths specified by ``path_pattern`` regex.""" - def __init__(self, path_pattern): + def __init__(self, path_pattern: Union[str, Pattern]) -> None: if isinstance(path_pattern, basestring_type): if not path_pattern.endswith('$'): path_pattern += '$' @@ -519,14 +538,15 @@ class PathMatches(Matcher): self._path, self._group_count = self._find_groups() - def match(self, request): + def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]: match = self.regex.match(request.path) if match is None: return None if not self.regex.groups: return {} - path_args, path_kwargs = [], {} + path_args = [] # type: List[bytes] + path_kwargs = {} # type: Dict[str, bytes] # Pass matched groups to the handler. Since # match.groups() includes both named and @@ -541,7 +561,7 @@ class PathMatches(Matcher): return dict(path_args=path_args, path_kwargs=path_kwargs) - def reverse(self, *args): + def reverse(self, *args: Any) -> Optional[str]: if self._path is None: raise ValueError("Cannot reverse url regex " + self.regex.pattern) assert len(args) == self._group_count, "required number of arguments " \ @@ -555,7 +575,7 @@ class PathMatches(Matcher): converted_args.append(url_escape(utf8(a), plus=False)) return self._path % tuple(converted_args) - def _find_groups(self): + def _find_groups(self) -> Tuple[Optional[str], Optional[int]]: """Returns a tuple (reverse string, group count) for a url. For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method @@ -597,7 +617,8 @@ class URLSpec(Rule): `URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for backwards compatibility. """ - def __init__(self, pattern, handler, kwargs=None, name=None): + def __init__(self, pattern: Union[str, Pattern], handler: Any, + kwargs: Dict[str, Any]=None, name: str=None) -> None: """Parameters: * ``pattern``: Regular expression to be matched. Any capturing @@ -615,19 +636,30 @@ class URLSpec(Rule): `~.web.Application.reverse_url`. """ - super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name) + matcher = PathMatches(pattern) + super(URLSpec, self).__init__(matcher, handler, kwargs, name) - self.regex = self.matcher.regex + self.regex = matcher.regex self.handler_class = self.target self.kwargs = kwargs - def __repr__(self): + def __repr__(self) -> str: return '%s(%r, %s, kwargs=%r, name=%r)' % \ (self.__class__.__name__, self.regex.pattern, self.handler_class, self.kwargs, self.name) -def _unquote_or_none(s): +@overload +def _unquote_or_none(s: str) -> bytes: + pass + + +@overload # noqa: F811 +def _unquote_or_none(s: None) -> None: + pass + + +def _unquote_or_none(s: Optional[str]) -> Optional[bytes]: # noqa: F811 """None-safe wrapper around url_unescape to handle unmatched optional groups correctly. diff --git a/tornado/test/httpclient_test.py b/tornado/test/httpclient_test.py index 1b0a43f8b..d6a519d33 100644 --- a/tornado/test/httpclient_test.py +++ b/tornado/test/httpclient_test.py @@ -109,7 +109,7 @@ class AllMethodsHandler(RequestHandler): def method(self): self.write(self.request.method) - get = post = put = delete = options = patch = other = method + get = post = put = delete = options = patch = other = method # type: ignore class SetHeaderHandler(RequestHandler): diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index 4f7228abb..41ccd6c94 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -1928,7 +1928,7 @@ class AllHTTPMethodsTest(SimpleHandlerTestCase): def method(self): self.write(self.request.method) - get = delete = options = post = put = method + get = delete = options = post = put = method # type: ignore def test_standard_methods(self): response = self.fetch('/', method='HEAD') diff --git a/tornado/web.py b/tornado/web.py index 76337ea83..f54f6f187 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -72,7 +72,6 @@ import mimetypes import numbers import os.path import re -import stat import sys import threading import time @@ -88,26 +87,34 @@ from tornado import gen from tornado.httpserver import HTTPServer from tornado import httputil from tornado import iostream +import tornado.locale from tornado import locale from tornado.log import access_log, app_log, gen_log from tornado import template from tornado.escape import utf8, _unicode from tornado.routing import (AnyMatches, DefaultHostMatches, HostMatches, ReversibleRouter, Rule, ReversibleRuleRouter, - URLSpec) + URLSpec, _RuleList) from tornado.util import ObjectDict, unicode_type, _websocket_mask url = URLSpec -try: - import typing # noqa +from typing import (Dict, Any, Union, Optional, Awaitable, Tuple, List, Callable, Iterable, + Generator, Type, cast) +from types import TracebackType +import typing +if typing.TYPE_CHECKING: + from typing import Set # noqa: F401 - # The following types are accepted by RequestHandler.set_header - # and related methods. - _HeaderTypes = typing.Union[bytes, unicode_type, - numbers.Integral, datetime.datetime] -except ImportError: - pass + +# The following types are accepted by RequestHandler.set_header +# and related methods. +_HeaderTypes = Union[bytes, unicode_type, + int, numbers.Integral, datetime.datetime] + +_CookieSecretTypes = Union[str, bytes, + Dict[int, str], + Dict[int, bytes]] MIN_SUPPORTED_SIGNED_VALUE_VERSION = 1 @@ -143,6 +150,13 @@ May be overridden by passing a ``min_version`` keyword argument. """ +class _ArgDefaultMarker: + pass + + +_ARG_DEFAULT = _ArgDefaultMarker() + + class RequestHandler(object): """Base class for HTTP request handlers. @@ -152,11 +166,19 @@ class RequestHandler(object): SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PATCH", "PUT", "OPTIONS") - _template_loaders = {} # type: typing.Dict[str, template.BaseLoader] + _template_loaders = {} # type: Dict[str, template.BaseLoader] _template_loader_lock = threading.Lock() _remove_control_chars_regex = re.compile(r"[\x00-\x08\x0e-\x1f]") - def __init__(self, application, request, **kwargs): + _stream_request_body = False + + # Will be set in _execute. + _transforms = None # type: List[OutputTransform] + path_args = None # type: List[str] + path_kwargs = None # type: Dict[str, str] + + def __init__(self, application: 'Application', request: httputil.HTTPServerRequest, + **kwargs: Any) -> None: super(RequestHandler, self).__init__() self.application = application @@ -164,11 +186,7 @@ class RequestHandler(object): self._headers_written = False self._finished = False self._auto_finish = True - self._transforms = None # will be set in _execute self._prepared_future = None - self._headers = None # type: httputil.HTTPHeaders - self.path_args = None - self.path_kwargs = None self.ui = ObjectDict((n, self._ui_method(m)) for n, m in application.ui_methods.items()) # UIModules are available as both `modules` and `_tt_modules` in the @@ -180,10 +198,12 @@ class RequestHandler(object): application.ui_modules) self.ui["modules"] = self.ui["_tt_modules"] self.clear() - self.request.connection.set_close_callback(self.on_connection_close) - self.initialize(**kwargs) + assert self.request.connection is not None + # TODO: need to add set_close_callback to HTTPConnection interface + self.request.connection.set_close_callback(self.on_connection_close) # type: ignore + self.initialize(**kwargs) # type: ignore - def initialize(self): + def initialize(self) -> None: """Hook for subclass initialization. Called for each request. A dictionary passed as the third argument of a url spec will be @@ -205,32 +225,32 @@ class RequestHandler(object): pass @property - def settings(self): + def settings(self) -> Dict[str, Any]: """An alias for `self.application.settings `.""" return self.application.settings - def head(self, *args, **kwargs): + def head(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def get(self, *args, **kwargs): + def get(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def post(self, *args, **kwargs): + def post(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def delete(self, *args, **kwargs): + def delete(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def patch(self, *args, **kwargs): + def patch(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def put(self, *args, **kwargs): + def put(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def options(self, *args, **kwargs): + def options(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: raise HTTPError(405) - def prepare(self): + def prepare(self) -> Optional[Awaitable[None]]: """Called at the beginning of a request before `get`/`post`/etc. Override this method to perform common initialization regardless @@ -246,7 +266,7 @@ class RequestHandler(object): """ pass - def on_finish(self): + def on_finish(self) -> None: """Called after the end of a request. Override this method to perform cleanup, logging, etc. @@ -256,7 +276,7 @@ class RequestHandler(object): """ pass - def on_connection_close(self): + def on_connection_close(self) -> None: """Called in async handlers if the client closed the connection. Override this to clean up resources associated with @@ -271,11 +291,11 @@ class RequestHandler(object): connection. """ if _has_stream_request_body(self.__class__): - if not self.request.body.done(): - self.request.body.set_exception(iostream.StreamClosedError()) - self.request.body.exception() + if not self.request._body_future.done(): + self.request._body_future.set_exception(iostream.StreamClosedError()) + self.request._body_future.exception() - def clear(self): + def clear(self) -> None: """Resets all headers and content for this response.""" self._headers = httputil.HTTPHeaders({ "Server": "TornadoServer/%s" % tornado.version, @@ -283,11 +303,11 @@ class RequestHandler(object): "Date": httputil.format_timestamp(time.time()), }) self.set_default_headers() - self._write_buffer = [] + self._write_buffer = [] # type: List[bytes] self._status_code = 200 self._reason = httputil.responses[200] - def set_default_headers(self): + def set_default_headers(self) -> None: """Override this to set HTTP headers at the beginning of the request. For example, this is the place to set a custom ``Server`` header. @@ -297,7 +317,7 @@ class RequestHandler(object): """ pass - def set_status(self, status_code, reason=None): + def set_status(self, status_code: int, reason: str=None) -> None: """Sets the status code for our response. :arg int status_code: Response status code. @@ -316,12 +336,11 @@ class RequestHandler(object): else: self._reason = httputil.responses.get(status_code, "Unknown") - def get_status(self): + def get_status(self) -> int: """Returns the status code for our response.""" return self._status_code - def set_header(self, name, value): - # type: (str, _HeaderTypes) -> None + def set_header(self, name: str, value: _HeaderTypes) -> None: """Sets the given response header name and value. If a datetime is given, we automatically format it according to the @@ -330,8 +349,7 @@ class RequestHandler(object): """ self._headers[name] = self._convert_header_value(value) - def add_header(self, name, value): - # type: (str, _HeaderTypes) -> None + def add_header(self, name: str, value: _HeaderTypes) -> None: """Adds the given response header and value. Unlike `set_header`, `add_header` may be called multiple times @@ -339,7 +357,7 @@ class RequestHandler(object): """ self._headers.add(name, self._convert_header_value(value)) - def clear_header(self, name): + def clear_header(self, name: str) -> None: """Clears an outgoing header, undoing a previous `set_header` call. Note that this method does not apply to multi-valued headers @@ -350,9 +368,7 @@ class RequestHandler(object): _INVALID_HEADER_CHAR_RE = re.compile(r"[\x00-\x1f]") - def _convert_header_value(self, value): - # type: (_HeaderTypes) -> str - + def _convert_header_value(self, value: _HeaderTypes) -> str: # Convert the input value to a str. This type check is a bit # subtle: The bytes case only executes on python 3, and the # unicode case only executes on python 2, because the other @@ -380,9 +396,8 @@ class RequestHandler(object): raise ValueError("Unsafe header value %r", retval) return retval - _ARG_DEFAULT = object() - - def get_argument(self, name, default=_ARG_DEFAULT, strip=True): + def get_argument(self, name: str, default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT, + strip: bool=True) -> Optional[str]: """Returns the value of the argument with the given name. If default is not provided, the argument is considered to be @@ -395,7 +410,7 @@ class RequestHandler(object): """ return self._get_argument(name, default, self.request.arguments, strip) - def get_arguments(self, name, strip=True): + def get_arguments(self, name: str, strip: bool=True) -> List[str]: """Returns a list of the arguments with the given name. If the argument is not present, returns an empty list. @@ -410,7 +425,9 @@ class RequestHandler(object): return self._get_arguments(name, self.request.arguments, strip) - def get_body_argument(self, name, default=_ARG_DEFAULT, strip=True): + def get_body_argument(self, name: str, + default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT, + strip: bool=True) -> Optional[str]: """Returns the value of the argument with the given name from the request body. @@ -427,7 +444,7 @@ class RequestHandler(object): return self._get_argument(name, default, self.request.body_arguments, strip) - def get_body_arguments(self, name, strip=True): + def get_body_arguments(self, name: str, strip: bool=True) -> List[str]: """Returns a list of the body arguments with the given name. If the argument is not present, returns an empty list. @@ -438,7 +455,9 @@ class RequestHandler(object): """ return self._get_arguments(name, self.request.body_arguments, strip) - def get_query_argument(self, name, default=_ARG_DEFAULT, strip=True): + def get_query_argument(self, name: str, + default: Union[None, str, _ArgDefaultMarker]=_ARG_DEFAULT, + strip: bool=True) -> Optional[str]: """Returns the value of the argument with the given name from the request query string. @@ -455,7 +474,7 @@ class RequestHandler(object): return self._get_argument(name, default, self.request.query_arguments, strip) - def get_query_arguments(self, name, strip=True): + def get_query_arguments(self, name: str, strip: bool=True) -> List[str]: """Returns a list of the query arguments with the given name. If the argument is not present, returns an empty list. @@ -466,28 +485,30 @@ class RequestHandler(object): """ return self._get_arguments(name, self.request.query_arguments, strip) - def _get_argument(self, name, default, source, strip=True): + def _get_argument(self, name: str, default: Union[None, str, _ArgDefaultMarker], + source: Dict[str, List[bytes]], strip: bool=True) -> Optional[str]: args = self._get_arguments(name, source, strip=strip) if not args: - if default is self._ARG_DEFAULT: + if isinstance(default, _ArgDefaultMarker): raise MissingArgumentError(name) return default return args[-1] - def _get_arguments(self, name, source, strip=True): + def _get_arguments(self, name: str, source: Dict[str, List[bytes]], + strip: bool=True) -> List[str]: values = [] for v in source.get(name, []): - v = self.decode_argument(v, name=name) - if isinstance(v, unicode_type): + s = self.decode_argument(v, name=name) + if isinstance(s, unicode_type): # Get rid of any weird control chars (unless decoding gave # us bytes, in which case leave it alone) - v = RequestHandler._remove_control_chars_regex.sub(" ", v) + s = RequestHandler._remove_control_chars_regex.sub(" ", s) if strip: - v = v.strip() - values.append(v) + s = s.strip() + values.append(s) return values - def decode_argument(self, value, name=None): + def decode_argument(self, value: bytes, name: str=None) -> str: """Decodes an argument from the request. The argument has been percent-decoded and is now a byte string. @@ -507,12 +528,12 @@ class RequestHandler(object): (name or "url", value[:40])) @property - def cookies(self): + def cookies(self) -> Dict[str, http.cookies.Morsel]: """An alias for `self.request.cookies <.httputil.HTTPServerRequest.cookies>`.""" return self.request.cookies - def get_cookie(self, name, default=None): + def get_cookie(self, name: str, default: str=None) -> Optional[str]: """Returns the value of the request cookie with the given name. If the named cookie is not present, returns ``default``. @@ -525,8 +546,10 @@ class RequestHandler(object): return self.request.cookies[name].value return default - def set_cookie(self, name, value, domain=None, expires=None, path="/", - expires_days=None, **kwargs): + def set_cookie(self, name: str, value: Union[str, bytes], domain: str=None, + expires: Union[float, Tuple, datetime.datetime]=None, + path: str="/", + expires_days: int=None, **kwargs: Any) -> None: """Sets an outgoing cookie name/value with the given options. Newly-set cookies are not immediately visible via `get_cookie`; @@ -573,7 +596,7 @@ class RequestHandler(object): morsel[k] = v - def clear_cookie(self, name, path="/", domain=None): + def clear_cookie(self, name: str, path: str="/", domain: str=None) -> None: """Deletes the cookie with the given name. Due to limitations of the cookie protocol, you must pass the same @@ -588,7 +611,7 @@ class RequestHandler(object): self.set_cookie(name, value="", path=path, expires=expires, domain=domain) - def clear_all_cookies(self, path="/", domain=None): + def clear_all_cookies(self, path: str="/", domain: str=None) -> None: """Deletes all the cookies the user sent with this request. See `clear_cookie` for more information on the path and domain @@ -604,8 +627,8 @@ class RequestHandler(object): for name in self.request.cookies: self.clear_cookie(name, path=path, domain=domain) - def set_secure_cookie(self, name, value, expires_days=30, version=None, - **kwargs): + def set_secure_cookie(self, name: str, value: Union[str, bytes], expires_days: int=30, + version: int=None, **kwargs: Any) -> None: """Signs and timestamps a cookie so it cannot be forged. You must specify the ``cookie_secret`` setting in your Application @@ -633,7 +656,7 @@ class RequestHandler(object): version=version), expires_days=expires_days, **kwargs) - def create_signed_value(self, name, value, version=None): + def create_signed_value(self, name: str, value: Union[str, bytes], version: int=None) -> bytes: """Signs and timestamps a string so it cannot be forged. Normally used via set_secure_cookie, but provided as a separate @@ -656,8 +679,8 @@ class RequestHandler(object): return create_signed_value(secret, name, value, version=version, key_version=key_version) - def get_secure_cookie(self, name, value=None, max_age_days=31, - min_version=None): + def get_secure_cookie(self, name: str, value: str=None, max_age_days: int=31, + min_version: int=None) -> Optional[bytes]: """Returns the given signed cookie if it validates, or None. The decoded cookie value is returned as a byte string (unlike @@ -679,7 +702,7 @@ class RequestHandler(object): name, value, max_age_days=max_age_days, min_version=min_version) - def get_secure_cookie_key_version(self, name, value=None): + def get_secure_cookie_key_version(self, name: str, value: str=None) -> Optional[int]: """Returns the signing key version of the secure cookie. The version is returned as int. @@ -687,9 +710,11 @@ class RequestHandler(object): self.require_setting("cookie_secret", "secure cookies") if value is None: value = self.get_cookie(name) + if value is None: + return None return get_signature_key_version(value) - def redirect(self, url, permanent=False, status=None): + def redirect(self, url: str, permanent: bool=False, status: int=None) -> None: """Sends a redirect to the given (optionally relative) URL. If the ``status`` argument is specified, that value is used as the @@ -707,7 +732,7 @@ class RequestHandler(object): self.set_header("Location", utf8(url)) self.finish() - def write(self, chunk): + def write(self, chunk: Union[str, bytes, dict]) -> None: """Writes the given chunk to the output buffer. To write the output to the network, use the flush() method below. @@ -737,7 +762,7 @@ class RequestHandler(object): chunk = utf8(chunk) self._write_buffer.append(chunk) - def render(self, template_name, **kwargs): + def render(self, template_name: str, **kwargs: Any) -> 'Future[None]': """Renders the template with the given arguments as the response. ``render()`` calls ``finish()``, so no other output methods can be called @@ -768,7 +793,7 @@ class RequestHandler(object): file_part = module.javascript_files() if file_part: if isinstance(file_part, (unicode_type, bytes)): - js_files.append(file_part) + js_files.append(_unicode(file_part)) else: js_files.extend(file_part) embed_part = module.embedded_css() @@ -777,7 +802,7 @@ class RequestHandler(object): file_part = module.css_files() if file_part: if isinstance(file_part, (unicode_type, bytes)): - css_files.append(file_part) + css_files.append(_unicode(file_part)) else: css_files.extend(file_part) head_part = module.html_head() @@ -793,17 +818,17 @@ class RequestHandler(object): sloc = html.rindex(b'') html = html[:sloc] + utf8(js) + b'\n' + html[sloc:] if js_embed: - js = self.render_embed_js(js_embed) + js_bytes = self.render_embed_js(js_embed) sloc = html.rindex(b'') - html = html[:sloc] + js + b'\n' + html[sloc:] + html = html[:sloc] + js_bytes + b'\n' + html[sloc:] if css_files: css = self.render_linked_css(css_files) hloc = html.index(b'') html = html[:hloc] + utf8(css) + b'\n' + html[hloc:] if css_embed: - css = self.render_embed_css(css_embed) + css_bytes = self.render_embed_css(css_embed) hloc = html.index(b'') - html = html[:hloc] + css + b'\n' + html[hloc:] + html = html[:hloc] + css_bytes + b'\n' + html[hloc:] if html_heads: hloc = html.index(b'') html = html[:hloc] + b''.join(html_heads) + b'\n' + html[hloc:] @@ -812,14 +837,14 @@ class RequestHandler(object): html = html[:hloc] + b''.join(html_bodies) + b'\n' + html[hloc:] return self.finish(html) - def render_linked_js(self, js_files): + def render_linked_js(self, js_files: Iterable[str]) -> str: """Default method used to render the final js links for the rendered webpage. Override this method in a sub-classed controller to change the output. """ paths = [] - unique_paths = set() + unique_paths = set() # type: Set[str] for path in js_files: if not is_absolute(path): @@ -832,7 +857,7 @@ class RequestHandler(object): '" type="text/javascript">' for p in paths) - def render_embed_js(self, js_embed): + def render_embed_js(self, js_embed: Iterable[bytes]) -> bytes: """Default method used to render the final embedded js for the rendered webpage. @@ -841,14 +866,14 @@ class RequestHandler(object): return b'' - def render_linked_css(self, css_files): + def render_linked_css(self, css_files: Iterable[str]) -> str: """Default method used to render the final css links for the rendered webpage. Override this method in a sub-classed controller to change the output. """ paths = [] - unique_paths = set() + unique_paths = set() # type: Set[str] for path in css_files: if not is_absolute(path): @@ -861,7 +886,7 @@ class RequestHandler(object): 'type="text/css" rel="stylesheet"/>' for p in paths) - def render_embed_css(self, css_embed): + def render_embed_css(self, css_embed: Iterable[bytes]) -> bytes: """Default method used to render the final embedded css for the rendered webpage. @@ -870,7 +895,7 @@ class RequestHandler(object): return b'' - def render_string(self, template_name, **kwargs): + def render_string(self, template_name: str, **kwargs: Any) -> bytes: """Generate the given template with the given arguments. We return the generated byte string (in utf8). To generate and @@ -883,6 +908,7 @@ class RequestHandler(object): web_file = frame.f_code.co_filename while frame.f_code.co_filename == web_file: frame = frame.f_back + assert frame.f_code.co_filename is not None template_path = os.path.dirname(frame.f_code.co_filename) with RequestHandler._template_loader_lock: if template_path not in RequestHandler._template_loaders: @@ -895,7 +921,7 @@ class RequestHandler(object): namespace.update(kwargs) return t.generate(**namespace) - def get_template_namespace(self): + def get_template_namespace(self) -> Dict[str, Any]: """Returns a dictionary to be used as the default template namespace. May be overridden by subclasses to add or modify values. @@ -918,7 +944,7 @@ class RequestHandler(object): namespace.update(self.ui) return namespace - def create_template_loader(self, template_path): + def create_template_loader(self, template_path: str) -> template.BaseLoader: """Returns a new template loader for the given path. May be overridden by subclasses. By default returns a @@ -939,7 +965,7 @@ class RequestHandler(object): kwargs["whitespace"] = settings["template_whitespace"] return template.Loader(template_path, **kwargs) - def flush(self, include_footers=False): + def flush(self, include_footers: bool=False) -> 'Future[None]': """Flushes the current output buffer to the network. The ``callback`` argument, if given, can be used for flow control: @@ -955,18 +981,20 @@ class RequestHandler(object): The ``callback`` argument was removed. """ + assert self.request.connection is not None chunk = b"".join(self._write_buffer) self._write_buffer = [] if not self._headers_written: self._headers_written = True for transform in self._transforms: + assert chunk is not None self._status_code, self._headers, chunk = \ transform.transform_first_chunk( self._status_code, self._headers, chunk, include_footers) # Ignore the chunk and only write the headers for HEAD requests if self.request.method == "HEAD": - chunk = None + chunk = b'' # Finalize the cookie headers (which have been stored in a side # object so an outgoing cookie could be overwritten before it @@ -987,11 +1015,11 @@ class RequestHandler(object): if self.request.method != "HEAD": return self.request.connection.write(chunk) else: - future = Future() + future = Future() # type: Future[None] future.set_result(None) return future - def finish(self, chunk=None): + def finish(self, chunk: Union[str, bytes, dict]=None) -> 'Future[None]': """Finishes this response, ending the HTTP request. Passing a ``chunk`` to ``finish()`` is equivalent to passing that @@ -1030,12 +1058,12 @@ class RequestHandler(object): content_length = sum(len(part) for part in self._write_buffer) self.set_header("Content-Length", content_length) - if hasattr(self.request, "connection"): - # Now that the request is finished, clear the callback we - # set on the HTTPConnection (which would otherwise prevent the - # garbage collection of the RequestHandler when there - # are keepalive connections) - self.request.connection.set_close_callback(None) + assert self.request.connection is not None + # Now that the request is finished, clear the callback we + # set on the HTTPConnection (which would otherwise prevent the + # garbage collection of the RequestHandler when there + # are keepalive connections) + self.request.connection.set_close_callback(None) # type: ignore future = self.flush(include_footers=True) self.request.connection.finish() @@ -1045,7 +1073,7 @@ class RequestHandler(object): self._break_cycles() return future - def detach(self): + def detach(self) -> iostream.IOStream: """Take control of the underlying stream. Returns the underlying `.IOStream` object and stops all @@ -1057,14 +1085,15 @@ class RequestHandler(object): .. versionadded:: 5.1 """ self._finished = True - return self.request.connection.detach() + # TODO: add detach to HTTPConnection? + return self.request.connection.detach() # type: ignore - def _break_cycles(self): + def _break_cycles(self) -> None: # Break up a reference cycle between this handler and the # _ui_module closures to allow for faster GC on CPython. - self.ui = None + self.ui = None # type: ignore - def send_error(self, status_code=500, **kwargs): + def send_error(self, status_code: int=500, **kwargs: Any) -> None: """Sends the given HTTP error code to the browser. If `flush()` has already been called, it is not possible to send @@ -1103,7 +1132,7 @@ class RequestHandler(object): if not self._finished: self.finish() - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs: Any) -> None: """Override to implement custom error pages. ``write_error`` may call `write`, `render`, `set_header`, etc @@ -1129,7 +1158,7 @@ class RequestHandler(object): }) @property - def locale(self): + def locale(self) -> tornado.locale.Locale: """The locale for the current session. Determined by either `get_user_locale`, which you can override to @@ -1141,17 +1170,19 @@ class RequestHandler(object): Added a property setter. """ if not hasattr(self, "_locale"): - self._locale = self.get_user_locale() - if not self._locale: + loc = self.get_user_locale() + if loc is not None: + self._locale = loc + else: self._locale = self.get_browser_locale() assert self._locale return self._locale @locale.setter - def locale(self, value): + def locale(self, value: tornado.locale.Locale) -> None: self._locale = value - def get_user_locale(self): + def get_user_locale(self) -> Optional[tornado.locale.Locale]: """Override to determine the locale from the authenticated user. If None is returned, we fall back to `get_browser_locale()`. @@ -1161,7 +1192,7 @@ class RequestHandler(object): """ return None - def get_browser_locale(self, default="en_US"): + def get_browser_locale(self, default: str="en_US") -> tornado.locale.Locale: """Determines the user's locale from ``Accept-Language`` header. See http://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.4 @@ -1186,7 +1217,7 @@ class RequestHandler(object): return locale.get(default) @property - def current_user(self): + def current_user(self) -> Any: """The authenticated user for this request. This is set in one of two ways: @@ -1222,17 +1253,17 @@ class RequestHandler(object): return self._current_user @current_user.setter - def current_user(self, value): + def current_user(self, value: Any) -> None: self._current_user = value - def get_current_user(self): + def get_current_user(self) -> Any: """Override to determine the current user from, e.g., a cookie. This method may not be a coroutine. """ return None - def get_login_url(self): + def get_login_url(self) -> str: """Override to customize the login URL based on the request. By default, we use the ``login_url`` application setting. @@ -1240,7 +1271,7 @@ class RequestHandler(object): self.require_setting("login_url", "@tornado.web.authenticated") return self.application.settings["login_url"] - def get_template_path(self): + def get_template_path(self) -> Optional[str]: """Override to customize template path for each handler. By default, we use the ``template_path`` application setting. @@ -1249,7 +1280,7 @@ class RequestHandler(object): return self.application.settings.get("template_path") @property - def xsrf_token(self): + def xsrf_token(self) -> bytes: """The XSRF-prevention token for the current user/session. To prevent cross-site request forgery, we set an '_xsrf' cookie @@ -1304,7 +1335,7 @@ class RequestHandler(object): **cookie_kwargs) return self._xsrf_token - def _get_raw_xsrf_token(self): + def _get_raw_xsrf_token(self) -> Tuple[Optional[int], bytes, float]: """Read or generate the xsrf token in its raw form. The raw_xsrf_token is a tuple containing: @@ -1325,10 +1356,13 @@ class RequestHandler(object): version = None token = os.urandom(16) timestamp = time.time() + assert token is not None + assert timestamp is not None self._raw_xsrf_token = (version, token, timestamp) return self._raw_xsrf_token - def _decode_xsrf_token(self, cookie): + def _decode_xsrf_token(self, cookie: str) -> Tuple[ + Optional[int], Optional[bytes], Optional[float]]: """Convert a cookie string into a the tuple form returned by _get_raw_xsrf_token. """ @@ -1339,12 +1373,12 @@ class RequestHandler(object): if m: version = int(m.group(1)) if version == 2: - _, mask, masked_token, timestamp = cookie.split("|") + _, mask_str, masked_token, timestamp_str = cookie.split("|") - mask = binascii.a2b_hex(utf8(mask)) + mask = binascii.a2b_hex(utf8(mask_str)) token = _websocket_mask( mask, binascii.a2b_hex(utf8(masked_token))) - timestamp = int(timestamp) + timestamp = int(timestamp_str) return version, token, timestamp else: # Treat unknown versions as not present instead of failing. @@ -1364,7 +1398,7 @@ class RequestHandler(object): exc_info=True) return None, None, None - def check_xsrf_cookie(self): + def check_xsrf_cookie(self) -> None: """Verifies that the ``_xsrf`` cookie matches the ``_xsrf`` argument. To prevent cross-site request forgery, we set an ``_xsrf`` @@ -1398,10 +1432,10 @@ class RequestHandler(object): _, expected_token, _ = self._get_raw_xsrf_token() if not token: raise HTTPError(403, "'_xsrf' argument has invalid format") - if not _time_independent_equals(utf8(token), utf8(expected_token)): + if not hmac.compare_digest(utf8(token), utf8(expected_token)): raise HTTPError(403, "XSRF cookie does not match POST argument") - def xsrf_form_html(self): + def xsrf_form_html(self) -> str: """An HTML ```` element to be included with all POST forms. It defines the ``_xsrf`` input value, which we check on all POST @@ -1417,7 +1451,7 @@ class RequestHandler(object): return '' - def static_url(self, path, include_host=None, **kwargs): + def static_url(self, path: str, include_host: bool=None, **kwargs: Any) -> str: """Returns a static URL for the given relative static file path. This method requires you set the ``static_path`` setting in your @@ -1452,17 +1486,17 @@ class RequestHandler(object): return base + get_url(self.settings, path, **kwargs) - def require_setting(self, name, feature="this feature"): + def require_setting(self, name: str, feature: str="this feature") -> None: """Raises an exception if the given app setting is not defined.""" if not self.application.settings.get(name): raise Exception("You must define the '%s' setting in your " "application to use %s" % (name, feature)) - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> str: """Alias for `Application.reverse_url`.""" return self.application.reverse_url(name, *args) - def compute_etag(self): + def compute_etag(self) -> Optional[str]: """Computes the etag header to be used for this request. By default uses a hash of the content written so far. @@ -1475,7 +1509,7 @@ class RequestHandler(object): hasher.update(part) return '"%s"' % hasher.hexdigest() - def set_etag_header(self): + def set_etag_header(self) -> None: """Sets the response's Etag header using ``self.compute_etag()``. Note: no header will be set if ``compute_etag()`` returns ``None``. @@ -1486,7 +1520,7 @@ class RequestHandler(object): if etag is not None: self.set_header("Etag", etag) - def check_etag_header(self): + def check_etag_header(self) -> bool: """Checks the ``Etag`` header against requests's ``If-None-Match``. Returns ``True`` if the request's Etag matches and a 304 should be @@ -1518,7 +1552,7 @@ class RequestHandler(object): match = True else: # Use a weak comparison when comparing entity-tags. - def val(x): + def val(x: bytes) -> bytes: return x[2:] if x.startswith(b'W/') else x for etag in etags: @@ -1528,7 +1562,8 @@ class RequestHandler(object): return match @gen.coroutine - def _execute(self, transforms, *args, **kwargs): + def _execute(self, transforms: List['OutputTransform'], *args: bytes, + **kwargs: bytes) -> Generator[Any, Any, None]: """Executes this request with the given output transforms.""" self._transforms = transforms try: @@ -1559,7 +1594,7 @@ class RequestHandler(object): # result; the data has been passed to self.data_received # instead. try: - yield self.request.body + yield self.request._body_future except iostream.StreamClosedError: return @@ -1584,14 +1619,14 @@ class RequestHandler(object): # in a finally block to avoid GC issues prior to Python 3.4. self._prepared_future.set_result(None) - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: """Implement this method to handle streamed request data. Requires the `.stream_request_body` decorator. """ raise NotImplementedError() - def _log(self): + def _log(self) -> None: """Logs the current request. Sort of deprecated since this functionality was moved to the @@ -1600,11 +1635,11 @@ class RequestHandler(object): """ self.application.log_request(self) - def _request_summary(self): + def _request_summary(self) -> str: return "%s %s (%s)" % (self.request.method, self.request.uri, self.request.remote_ip) - def _handle_request_exception(self, e): + def _handle_request_exception(self, e: BaseException) -> None: if isinstance(e, Finish): # Not an error; just finish the request without logging. if not self._finished: @@ -1626,7 +1661,9 @@ class RequestHandler(object): else: self.send_error(500, exc_info=sys.exc_info()) - def log_exception(self, typ, value, tb): + def log_exception(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType]) -> None: """Override to customize logging of uncaught exceptions. By default logs instances of `HTTPError` as warnings without @@ -1643,23 +1680,23 @@ class RequestHandler(object): list(value.args)) gen_log.warning(format, *args) else: - app_log.error("Uncaught exception %s\n%r", self._request_summary(), + app_log.error("Uncaught exception %s\n%r", self._request_summary(), # type: ignore self.request, exc_info=(typ, value, tb)) - def _ui_module(self, name, module): - def render(*args, **kwargs): + def _ui_module(self, name: str, module: Type['UIModule']) -> Callable[..., str]: + def render(*args: Any, **kwargs: Any) -> str: if not hasattr(self, "_active_modules"): - self._active_modules = {} + self._active_modules = {} # type: Dict[str, UIModule] if name not in self._active_modules: self._active_modules[name] = module(self) rendered = self._active_modules[name].render(*args, **kwargs) return rendered return render - def _ui_method(self, method): + def _ui_method(self, method: Callable[..., str]) -> Callable[..., str]: return lambda *args, **kwargs: method(self, *args, **kwargs) - def _clear_headers_for_304(self): + def _clear_headers_for_304(self) -> None: # 304 responses should not contain entity headers (defined in # http://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.1) # not explicitly allowed by @@ -1671,7 +1708,7 @@ class RequestHandler(object): self.clear_header(h) -def stream_request_body(cls): +def stream_request_body(cls: Type[RequestHandler]) -> Type[RequestHandler]: """Apply to `RequestHandler` subclasses to enable streaming body support. This decorator implies the following changes: @@ -1698,13 +1735,15 @@ def stream_request_body(cls): return cls -def _has_stream_request_body(cls): +def _has_stream_request_body(cls: Type[RequestHandler]) -> bool: if not issubclass(cls, RequestHandler): raise TypeError("expected subclass of RequestHandler, got %r", cls) - return getattr(cls, '_stream_request_body', False) + return cls._stream_request_body -def removeslash(method): +def removeslash( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: """Use this decorator to remove trailing slashes from the request path. For example, a request to ``/foo/`` would redirect to ``/foo`` with this @@ -1712,7 +1751,7 @@ def removeslash(method): like ``r'/foo/*'`` in conjunction with using the decorator. """ @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]: if self.request.path.endswith("/"): if self.request.method in ("GET", "HEAD"): uri = self.request.path.rstrip("/") @@ -1720,14 +1759,16 @@ def removeslash(method): if self.request.query: uri += "?" + self.request.query self.redirect(uri, permanent=True) - return + return None else: raise HTTPError(404) return method(self, *args, **kwargs) return wrapper -def addslash(method): +def addslash( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: """Use this decorator to add a missing trailing slash to the request path. For example, a request to ``/foo`` would redirect to ``/foo/`` with this @@ -1735,14 +1776,14 @@ def addslash(method): like ``r'/foo/?'`` in conjunction with using the decorator. """ @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]: if not self.request.path.endswith("/"): if self.request.method in ("GET", "HEAD"): uri = self.request.path + "/" if self.request.query: uri += "?" + self.request.query self.redirect(uri, permanent=True) - return + return None raise HTTPError(404) return method(self, *args, **kwargs) return wrapper @@ -1759,20 +1800,21 @@ class _ApplicationRouter(ReversibleRuleRouter): `_ApplicationRouter` instance. """ - def __init__(self, application, rules=None): + def __init__(self, application: 'Application', rules: _RuleList=None) -> None: assert isinstance(application, Application) self.application = application super(_ApplicationRouter, self).__init__(rules) - def process_rule(self, rule): + def process_rule(self, rule: Rule) -> Rule: rule = super(_ApplicationRouter, self).process_rule(rule) if isinstance(rule.target, (list, tuple)): - rule.target = _ApplicationRouter(self.application, rule.target) + rule.target = _ApplicationRouter(self.application, rule.target) # type: ignore return rule - def get_target_delegate(self, target, request, **target_params): + def get_target_delegate(self, target: Any, request: httputil.HTTPServerRequest, + **target_params: Any) -> Optional[httputil.HTTPMessageDelegate]: if isclass(target) and issubclass(target, RequestHandler): return self.application.get_handler_delegate(request, target, **target_params) @@ -1862,10 +1904,10 @@ class Application(ReversibleRouter): Integration with the new `tornado.routing` module. """ - def __init__(self, handlers=None, default_host=None, transforms=None, - **settings): + def __init__(self, handlers: _RuleList=None, default_host: str=None, + transforms: List[Type['OutputTransform']]=None, **settings: Any) -> None: if transforms is None: - self.transforms = [] + self.transforms = [] # type: List[Type[OutputTransform]] if settings.get("compress_response") or settings.get("gzip"): self.transforms.append(GZipContentEncoding) else: @@ -1876,7 +1918,7 @@ class Application(ReversibleRouter): 'xsrf_form_html': _xsrf_form_html, 'Template': TemplateModule, } - self.ui_methods = {} + self.ui_methods = {} # type: Dict[str, Callable[..., str]] self._load_ui_modules(settings.get("ui_modules", {})) self._load_ui_methods(settings.get("ui_methods", {})) if self.settings.get("static_path"): @@ -1909,7 +1951,7 @@ class Application(ReversibleRouter): from tornado import autoreload autoreload.start() - def listen(self, port, address="", **kwargs): + def listen(self, port: int, address: str="", **kwargs: Any) -> HTTPServer: """Starts an HTTP server for this application on the given port. This is a convenience alias for creating an `.HTTPServer` @@ -1932,7 +1974,7 @@ class Application(ReversibleRouter): server.listen(port, address) return server - def add_handlers(self, host_pattern, host_handlers): + def add_handlers(self, host_pattern: str, host_handlers: _RuleList) -> None: """Appends the given handlers to our handler list. Host patterns are processed sequentially in the order they were @@ -1949,10 +1991,10 @@ class Application(ReversibleRouter): host_handlers )]) - def add_transform(self, transform_class): + def add_transform(self, transform_class: Type['OutputTransform']) -> None: self.transforms.append(transform_class) - def _load_ui_methods(self, methods): + def _load_ui_methods(self, methods: Any) -> None: if isinstance(methods, types.ModuleType): self._load_ui_methods(dict((n, getattr(methods, n)) for n in dir(methods))) @@ -1965,7 +2007,7 @@ class Application(ReversibleRouter): and name[0].lower() == name[0]: self.ui_methods[name] = fn - def _load_ui_modules(self, modules): + def _load_ui_modules(self, modules: Any) -> None: if isinstance(modules, types.ModuleType): self._load_ui_modules(dict((n, getattr(modules, n)) for n in dir(modules))) @@ -1981,15 +2023,16 @@ class Application(ReversibleRouter): except TypeError: pass - def __call__(self, request): + def __call__(self, request: httputil.HTTPServerRequest) -> Optional[Awaitable[None]]: # Legacy HTTPServer interface dispatcher = self.find_handler(request) return dispatcher.execute() - def find_handler(self, request, **kwargs): + def find_handler(self, request: httputil.HTTPServerRequest, + **kwargs: Any) -> '_HandlerDelegate': route = self.default_router.find_handler(request) if route is not None: - return route + return cast('_HandlerDelegate', route) if self.settings.get('default_handler_class'): return self.get_handler_delegate( @@ -2000,8 +2043,11 @@ class Application(ReversibleRouter): return self.get_handler_delegate( request, ErrorHandler, {'status_code': 404}) - def get_handler_delegate(self, request, target_class, target_kwargs=None, - path_args=None, path_kwargs=None): + def get_handler_delegate(self, request: httputil.HTTPServerRequest, + target_class: Type[RequestHandler], + target_kwargs: Dict[str, Any]=None, + path_args: List[bytes]=None, + path_kwargs: Dict[str, bytes]=None) -> '_HandlerDelegate': """Returns `~.httputil.HTTPMessageDelegate` that can serve a request for application and `RequestHandler` subclass. @@ -2015,7 +2061,7 @@ class Application(ReversibleRouter): return _HandlerDelegate( self, request, target_class, target_kwargs, path_args, path_kwargs) - def reverse_url(self, name, *args): + def reverse_url(self, name: str, *args: Any) -> str: """Returns a URL path for handler named ``name`` The handler must be added to the application as a named `URLSpec`. @@ -2030,7 +2076,7 @@ class Application(ReversibleRouter): raise KeyError("%s not found in named urls" % name) - def log_request(self, handler): + def log_request(self, handler: RequestHandler) -> None: """Writes a completed HTTP request to the logs. By default writes to the python root logger. To change @@ -2053,8 +2099,9 @@ class Application(ReversibleRouter): class _HandlerDelegate(httputil.HTTPMessageDelegate): - def __init__(self, application, request, handler_class, handler_kwargs, - path_args, path_kwargs): + def __init__(self, application: Application, request: httputil.HTTPServerRequest, + handler_class: Type[RequestHandler], handler_kwargs: Optional[Dict[str, Any]], + path_args: Optional[List[bytes]], path_kwargs: Optional[Dict[str, bytes]]) -> None: self.application = application self.connection = request.connection self.request = request @@ -2062,35 +2109,39 @@ class _HandlerDelegate(httputil.HTTPMessageDelegate): self.handler_kwargs = handler_kwargs or {} self.path_args = path_args or [] self.path_kwargs = path_kwargs or {} - self.chunks = [] + self.chunks = [] # type: List[bytes] self.stream_request_body = _has_stream_request_body(self.handler_class) - def headers_received(self, start_line, headers): + def headers_received(self, start_line: Union[httputil.RequestStartLine, + httputil.ResponseStartLine], + headers: httputil.HTTPHeaders) -> Optional[Awaitable[None]]: if self.stream_request_body: - self.request.body = Future() + self.request._body_future = Future() return self.execute() + return None - def data_received(self, data): + def data_received(self, data: bytes) -> Optional[Awaitable[None]]: if self.stream_request_body: return self.handler.data_received(data) else: self.chunks.append(data) + return None - def finish(self): + def finish(self) -> None: if self.stream_request_body: - future_set_result_unless_cancelled(self.request.body, None) + future_set_result_unless_cancelled(self.request._body_future, None) else: self.request.body = b''.join(self.chunks) self.request._parse_body() self.execute() - def on_connection_close(self): + def on_connection_close(self) -> None: if self.stream_request_body: self.handler.on_connection_close() else: - self.chunks = None + self.chunks = None # type: ignore - def execute(self): + def execute(self) -> Optional[Awaitable[None]]: # If template cache is disabled (usually in the debug mode), # re-compile templates and reload static files on every # request so you don't need to restart to see changes @@ -2144,7 +2195,8 @@ class HTTPError(Exception): determined automatically from ``status_code``, but can be used to use a non-standard numeric code. """ - def __init__(self, status_code=500, log_message=None, *args, **kwargs): + def __init__(self, status_code: int=500, log_message: str=None, + *args: Any, **kwargs: Any) -> None: self.status_code = status_code self.log_message = log_message self.args = args @@ -2152,7 +2204,7 @@ class HTTPError(Exception): if log_message and not args: self.log_message = log_message.replace('%', '%%') - def __str__(self): + def __str__(self) -> str: message = "HTTP %d: %s" % ( self.status_code, self.reason or httputil.responses.get(self.status_code, 'Unknown')) @@ -2197,7 +2249,7 @@ class MissingArgumentError(HTTPError): .. versionadded:: 3.1 """ - def __init__(self, arg_name): + def __init__(self, arg_name: str) -> None: super(MissingArgumentError, self).__init__( 400, 'Missing argument %s' % arg_name) self.arg_name = arg_name @@ -2205,13 +2257,13 @@ class MissingArgumentError(HTTPError): class ErrorHandler(RequestHandler): """Generates an error response with ``status_code`` for all requests.""" - def initialize(self, status_code): + def initialize(self, status_code: int) -> None: # type: ignore self.set_status(status_code) - def prepare(self): + def prepare(self) -> None: raise HTTPError(self._status_code) - def check_xsrf_cookie(self): + def check_xsrf_cookie(self) -> None: # POSTs to an ErrorHandler don't actually have side effects, # so we don't need to check the xsrf token. This allows POSTs # to the wrong url to return a 404 instead of 403. @@ -2250,15 +2302,16 @@ class RedirectHandler(RequestHandler): If any query arguments are present, they will be copied to the destination URL. """ - def initialize(self, url, permanent=True): + def initialize(self, url: str, permanent: bool=True) -> None: # type: ignore self._url = url self._permanent = permanent - def get(self, *args): + def get(self, *args: Any) -> None: # type: ignore to_url = self._url.format(*args) if self.request.query_arguments: + # TODO: figure out typing for the next line. to_url = httputil.url_concat( - to_url, list(httputil.qs_to_qsl(self.request.query_arguments))) + to_url, list(httputil.qs_to_qsl(self.request.query_arguments))) # type: ignore self.redirect(to_url, permanent=self._permanent) @@ -2330,23 +2383,23 @@ class StaticFileHandler(RequestHandler): """ CACHE_MAX_AGE = 86400 * 365 * 10 # 10 years - _static_hashes = {} # type: typing.Dict + _static_hashes = {} # type: Dict[str, Optional[str]] _lock = threading.Lock() # protects _static_hashes - def initialize(self, path, default_filename=None): + def initialize(self, path: str, default_filename: str=None) -> None: # type: ignore self.root = path self.default_filename = default_filename @classmethod - def reset(cls): + def reset(cls) -> None: with cls._lock: cls._static_hashes = {} - def head(self, path): + def head(self, path: str) -> 'Future[None]': # type: ignore return self.get(path, include_body=False) @gen.coroutine - def get(self, path, include_body=True): + def get(self, path: str, include_body: bool=True) -> Generator[Any, Any, None]: # Set up our path instance variables. self.path = self.parse_url_path(path) del path # make sure we don't refer to path instead of self.path again @@ -2421,7 +2474,7 @@ class StaticFileHandler(RequestHandler): else: assert self.request.method == "HEAD" - def compute_etag(self): + def compute_etag(self) -> Optional[str]: """Sets the ``Etag`` header based on static url version. This allows efficient ``If-None-Match`` checks against cached @@ -2430,12 +2483,13 @@ class StaticFileHandler(RequestHandler): .. versionadded:: 3.1 """ + assert self.absolute_path is not None version_hash = self._get_cached_version(self.absolute_path) if not version_hash: return None return '"%s"' % (version_hash, ) - def set_headers(self): + def set_headers(self) -> None: """Sets the content and caching headers on the response. .. versionadded:: 3.1 @@ -2459,7 +2513,7 @@ class StaticFileHandler(RequestHandler): self.set_extra_headers(self.path) - def should_return_304(self): + def should_return_304(self) -> bool: """Returns True if the headers indicate that we should return 304. .. versionadded:: 3.1 @@ -2475,13 +2529,14 @@ class StaticFileHandler(RequestHandler): date_tuple = email.utils.parsedate(ims_value) if date_tuple is not None: if_since = datetime.datetime(*date_tuple[:6]) + assert self.modified is not None if if_since >= self.modified: return True return False @classmethod - def get_absolute_path(cls, root, path): + def get_absolute_path(cls, root: str, path: str) -> str: """Returns the absolute location of ``path`` relative to ``root``. ``root`` is the path configured for this `StaticFileHandler` @@ -2497,7 +2552,7 @@ class StaticFileHandler(RequestHandler): abspath = os.path.abspath(os.path.join(root, path)) return abspath - def validate_absolute_path(self, root, absolute_path): + def validate_absolute_path(self, root: str, absolute_path: str) -> Optional[str]: """Validate and return the absolute path. ``root`` is the configured path for the `StaticFileHandler`, @@ -2541,7 +2596,7 @@ class StaticFileHandler(RequestHandler): # trimmed by the routing if not self.request.path.endswith("/"): self.redirect(self.request.path + "/", permanent=True) - return + return None absolute_path = os.path.join(absolute_path, self.default_filename) if not os.path.exists(absolute_path): raise HTTPError(404) @@ -2550,7 +2605,8 @@ class StaticFileHandler(RequestHandler): return absolute_path @classmethod - def get_content(cls, abspath, start=None, end=None): + def get_content(cls, abspath: str, + start: int=None, end: int=None) -> Generator[bytes, None, None]: """Retrieve the content of the requested resource which is located at the given absolute path. @@ -2569,7 +2625,7 @@ class StaticFileHandler(RequestHandler): if start is not None: file.seek(start) if end is not None: - remaining = end - (start or 0) + remaining = end - (start or 0) # type: Optional[int] else: remaining = None while True: @@ -2587,7 +2643,7 @@ class StaticFileHandler(RequestHandler): return @classmethod - def get_content_version(cls, abspath): + def get_content_version(cls, abspath: str) -> str: """Returns a version string for the resource at the given path. This class method may be overridden by subclasses. The @@ -2604,12 +2660,13 @@ class StaticFileHandler(RequestHandler): hasher.update(chunk) return hasher.hexdigest() - def _stat(self): + def _stat(self) -> os.stat_result: + assert self.absolute_path is not None if not hasattr(self, '_stat_result'): self._stat_result = os.stat(self.absolute_path) return self._stat_result - def get_content_size(self): + def get_content_size(self) -> int: """Retrieve the total size of the resource at the given path. This method may be overridden by subclasses. @@ -2621,9 +2678,9 @@ class StaticFileHandler(RequestHandler): partial results are requested. """ stat_result = self._stat() - return stat_result[stat.ST_SIZE] + return stat_result.st_size - def get_modified_time(self): + def get_modified_time(self) -> Optional[datetime.datetime]: """Returns the time that ``self.absolute_path`` was last modified. May be overridden in subclasses. Should return a `~datetime.datetime` @@ -2632,15 +2689,24 @@ class StaticFileHandler(RequestHandler): .. versionadded:: 3.1 """ stat_result = self._stat() + # NOTE: Historically, this used stat_result[stat.ST_MTIME], + # which truncates the fractional portion of the timestamp. It + # was changed from that form to stat_result.st_mtime to + # satisfy mypy (which disallows the bracket operator), but the + # latter form returns a float instead of an int. For + # consistency with the past (and because we have a unit test + # that relies on this), we truncate the float here, although + # I'm not sure that's the right thing to do. modified = datetime.datetime.utcfromtimestamp( - stat_result[stat.ST_MTIME]) + int(stat_result.st_mtime)) return modified - def get_content_type(self): + def get_content_type(self) -> str: """Returns the ``Content-Type`` header to be used for this request. .. versionadded:: 3.1 """ + assert self.absolute_path is not None mime_type, encoding = mimetypes.guess_type(self.absolute_path) # per RFC 6713, use the appropriate type for a gzip compressed file if encoding == "gzip": @@ -2656,11 +2722,12 @@ class StaticFileHandler(RequestHandler): else: return "application/octet-stream" - def set_extra_headers(self, path): + def set_extra_headers(self, path: str) -> None: """For subclass to add extra headers to the response""" pass - def get_cache_time(self, path, modified, mime_type): + def get_cache_time(self, path: str, modified: Optional[datetime.datetime], + mime_type: str) -> int: """Override to customize cache control behavior. Return a positive number of seconds to make the result @@ -2674,7 +2741,8 @@ class StaticFileHandler(RequestHandler): return self.CACHE_MAX_AGE if "v" in self.request.arguments else 0 @classmethod - def make_static_url(cls, settings, path, include_version=True): + def make_static_url(cls, settings: Dict[str, Any], path: str, + include_version: bool=True) -> str: """Constructs a versioned url for the given path. This method may be overridden in subclasses (but note that it @@ -2703,7 +2771,7 @@ class StaticFileHandler(RequestHandler): return '%s?v=%s' % (url, version_hash) - def parse_url_path(self, url_path): + def parse_url_path(self, url_path: str) -> str: """Converts a static URL path into a filesystem path. ``url_path`` is the path component of the URL with @@ -2717,7 +2785,7 @@ class StaticFileHandler(RequestHandler): return url_path @classmethod - def get_version(cls, settings, path): + def get_version(cls, settings: Dict[str, Any], path: str) -> Optional[str]: """Generate the version string to be used in static URLs. ``settings`` is the `Application.settings` dictionary and ``path`` @@ -2734,7 +2802,7 @@ class StaticFileHandler(RequestHandler): return cls._get_cached_version(abs_path) @classmethod - def _get_cached_version(cls, abs_path): + def _get_cached_version(cls, abs_path: str) -> Optional[str]: with cls._lock: hashes = cls._static_hashes if abs_path not in hashes: @@ -2765,10 +2833,11 @@ class FallbackHandler(RequestHandler): (r".*", FallbackHandler, dict(fallback=wsgi_app), ]) """ - def initialize(self, fallback): + def initialize(self, # type: ignore + fallback: Callable[[httputil.HTTPServerRequest], None]) -> None: self.fallback = fallback - def prepare(self): + def prepare(self) -> None: self.fallback(self.request) self._finished = True self.on_finish() @@ -2781,14 +2850,16 @@ class OutputTransform(object): or interact with them directly; the framework chooses which transforms (if any) to apply. """ - def __init__(self, request): + def __init__(self, request: httputil.HTTPServerRequest) -> None: pass - def transform_first_chunk(self, status_code, headers, chunk, finishing): - # type: (int, httputil.HTTPHeaders, bytes, bool) -> typing.Tuple[int, httputil.HTTPHeaders, bytes] # noqa: E501 + def transform_first_chunk( + self, status_code: int, headers: httputil.HTTPHeaders, + chunk: bytes, finishing: bool + ) -> Tuple[int, httputil.HTTPHeaders, bytes]: return status_code, headers, chunk - def transform_chunk(self, chunk, finishing): + def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: return chunk @@ -2819,14 +2890,16 @@ class GZipContentEncoding(OutputTransform): # regardless of size. MIN_LENGTH = 1024 - def __init__(self, request): + def __init__(self, request: httputil.HTTPServerRequest) -> None: self._gzipping = "gzip" in request.headers.get("Accept-Encoding", "") - def _compressible_type(self, ctype): + def _compressible_type(self, ctype: str) -> bool: return ctype.startswith('text/') or ctype in self.CONTENT_TYPES - def transform_first_chunk(self, status_code, headers, chunk, finishing): - # type: (int, httputil.HTTPHeaders, bytes, bool) -> typing.Tuple[int, httputil.HTTPHeaders, bytes] # noqa: E501 + def transform_first_chunk( + self, status_code: int, headers: httputil.HTTPHeaders, + chunk: bytes, finishing: bool + ) -> Tuple[int, httputil.HTTPHeaders, bytes]: # TODO: can/should this type be inherited from the superclass? if 'Vary' in headers: headers['Vary'] += ', Accept-Encoding' @@ -2854,7 +2927,7 @@ class GZipContentEncoding(OutputTransform): del headers["Content-Length"] return status_code, headers, chunk - def transform_chunk(self, chunk, finishing): + def transform_chunk(self, chunk: bytes, finishing: bool) -> bytes: if self._gzipping: self._gzip_file.write(chunk) if finishing: @@ -2867,7 +2940,9 @@ class GZipContentEncoding(OutputTransform): return chunk -def authenticated(method): +def authenticated( + method: Callable[..., Optional[Awaitable[None]]] +) -> Callable[..., Optional[Awaitable[None]]]: """Decorate methods with this to require that the user be logged in. If the user is not logged in, they will be redirected to the configured @@ -2879,7 +2954,7 @@ def authenticated(method): you once you're logged in. """ @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper(self: RequestHandler, *args: Any, **kwargs: Any) -> Optional[Awaitable[None]]: if not self.current_user: if self.request.method in ("GET", "HEAD"): url = self.get_login_url() @@ -2888,10 +2963,11 @@ def authenticated(method): # if login url is absolute, make next absolute too next_url = self.request.full_url() else: + assert self.request.uri is not None next_url = self.request.uri url += "?" + urlencode(dict(next=next_url)) self.redirect(url) - return + return None raise HTTPError(403) return method(self, *args, **kwargs) return wrapper @@ -2906,26 +2982,26 @@ class UIModule(object): Subclasses of UIModule must override the `render` method. """ - def __init__(self, handler): + def __init__(self, handler: RequestHandler) -> None: self.handler = handler self.request = handler.request self.ui = handler.ui self.locale = handler.locale @property - def current_user(self): + def current_user(self) -> Any: return self.handler.current_user - def render(self, *args, **kwargs): + def render(self, *args: Any, **kwargs: Any) -> str: """Override in subclasses to return this module's output.""" raise NotImplementedError() - def embedded_javascript(self): + def embedded_javascript(self) -> Optional[str]: """Override to return a JavaScript string to be embedded in the page.""" return None - def javascript_files(self): + def javascript_files(self) -> Optional[Iterable[str]]: """Override to return a list of JavaScript files needed by this module. If the return values are relative paths, they will be passed to @@ -2933,12 +3009,12 @@ class UIModule(object): """ return None - def embedded_css(self): + def embedded_css(self) -> Optional[str]: """Override to return a CSS string that will be embedded in the page.""" return None - def css_files(self): + def css_files(self) -> Optional[Iterable[str]]: """Override to returns a list of CSS files required by this module. If the return values are relative paths, they will be passed to @@ -2946,30 +3022,30 @@ class UIModule(object): """ return None - def html_head(self): + def html_head(self) -> Optional[str]: """Override to return an HTML string that will be put in the element. """ return None - def html_body(self): + def html_body(self) -> Optional[str]: """Override to return an HTML string that will be put at the end of the element. """ return None - def render_string(self, path, **kwargs): + def render_string(self, path: str, **kwargs: Any) -> bytes: """Renders a template and returns it as a string.""" return self.handler.render_string(path, **kwargs) class _linkify(UIModule): - def render(self, text, **kwargs): + def render(self, text: str, **kwargs: Any) -> str: # type: ignore return escape.linkify(text, **kwargs) class _xsrf_form_html(UIModule): - def render(self): + def render(self) -> str: # type: ignore return self.handler.xsrf_form_html() @@ -2988,14 +3064,14 @@ class TemplateModule(UIModule): per instantiation of the template, so they must not depend on any arguments to the template. """ - def __init__(self, handler): + def __init__(self, handler: RequestHandler) -> None: super(TemplateModule, self).__init__(handler) # keep resources in both a list and a dict to preserve order - self._resource_list = [] - self._resource_dict = {} + self._resource_list = [] # type: List[Dict[str, Any]] + self._resource_dict = {} # type: Dict[str, Dict[str, Any]] - def render(self, path, **kwargs): - def set_resources(**kwargs): + def render(self, path: str, **kwargs: Any) -> bytes: # type: ignore + def set_resources(**kwargs: Any) -> str: if path not in self._resource_dict: self._resource_list.append(kwargs) self._resource_dict[path] = kwargs @@ -3007,13 +3083,13 @@ class TemplateModule(UIModule): return self.render_string(path, set_resources=set_resources, **kwargs) - def _get_resources(self, key): + def _get_resources(self, key: str) -> Iterable[str]: return (r[key] for r in self._resource_list if key in r) - def embedded_javascript(self): + def embedded_javascript(self) -> str: return "\n".join(self._get_resources("embedded_javascript")) - def javascript_files(self): + def javascript_files(self) -> Iterable[str]: result = [] for f in self._get_resources("javascript_files"): if isinstance(f, (unicode_type, bytes)): @@ -3022,10 +3098,10 @@ class TemplateModule(UIModule): result.extend(f) return result - def embedded_css(self): + def embedded_css(self) -> str: return "\n".join(self._get_resources("embedded_css")) - def css_files(self): + def css_files(self) -> Iterable[str]: result = [] for f in self._get_resources("css_files"): if isinstance(f, (unicode_type, bytes)): @@ -3034,47 +3110,33 @@ class TemplateModule(UIModule): result.extend(f) return result - def html_head(self): + def html_head(self) -> str: return "".join(self._get_resources("html_head")) - def html_body(self): + def html_body(self) -> str: return "".join(self._get_resources("html_body")) class _UIModuleNamespace(object): """Lazy namespace which creates UIModule proxies bound to a handler.""" - def __init__(self, handler, ui_modules): + def __init__(self, handler: RequestHandler, ui_modules: Dict[str, Type[UIModule]]) -> None: self.handler = handler self.ui_modules = ui_modules - def __getitem__(self, key): + def __getitem__(self, key: str) -> Callable[..., str]: return self.handler._ui_module(key, self.ui_modules[key]) - def __getattr__(self, key): + def __getattr__(self, key: str) -> Callable[..., str]: try: return self[key] except KeyError as e: raise AttributeError(str(e)) -if hasattr(hmac, 'compare_digest'): # python 3.3 - _time_independent_equals = hmac.compare_digest -else: - def _time_independent_equals(a, b): - if len(a) != len(b): - return False - result = 0 - if isinstance(a[0], int): # python3 byte strings - for x, y in zip(a, b): - result |= x ^ y - else: # python2 - for x, y in zip(a, b): - result |= ord(x) ^ ord(y) - return result == 0 - - -def create_signed_value(secret, name, value, version=None, clock=None, - key_version=None): +def create_signed_value(secret: _CookieSecretTypes, + name: str, value: Union[str, bytes], + version: int=None, clock: Callable[[], float]=None, + key_version: int=None) -> bytes: if version is None: version = DEFAULT_SIGNED_VALUE_VERSION if clock is None: @@ -3083,6 +3145,7 @@ def create_signed_value(secret, name, value, version=None, clock=None, timestamp = utf8(str(int(clock()))) value = base64.b64encode(utf8(value)) if version == 1: + assert not isinstance(secret, dict) signature = _create_signature_v1(secret, name, value, timestamp) value = b"|".join([value, timestamp, signature]) return value @@ -3101,7 +3164,7 @@ def create_signed_value(secret, name, value, version=None, clock=None, # - name (not encoded; assumed to be ~alphanumeric) # - value (base64-encoded) # - signature (hex-encoded; no length prefix) - def format_field(s): + def format_field(s: Union[str, bytes]) -> bytes: return utf8("%d:" % len(s)) + utf8(s) to_sign = b"|".join([ b"2", @@ -3127,7 +3190,7 @@ def create_signed_value(secret, name, value, version=None, clock=None, _signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$") -def _get_version(value): +def _get_version(value: bytes) -> int: # Figures out what version value is. Version 1 did not include an # explicit version field and started with arbitrary base64 data, # which makes this tricky. @@ -3150,8 +3213,9 @@ def _get_version(value): return version -def decode_signed_value(secret, name, value, max_age_days=31, - clock=None, min_version=None): +def decode_signed_value(secret: _CookieSecretTypes, + name: str, value: Union[None, str, bytes], max_age_days: int=31, + clock: Callable[[], float]=None, min_version: int=None) -> Optional[bytes]: if clock is None: clock = time.time if min_version is None: @@ -3167,6 +3231,7 @@ def decode_signed_value(secret, name, value, max_age_days=31, if version < min_version: return None if version == 1: + assert not isinstance(secret, dict) return _decode_signed_value_v1(secret, name, value, max_age_days, clock) elif version == 2: @@ -3176,12 +3241,13 @@ def decode_signed_value(secret, name, value, max_age_days=31, return None -def _decode_signed_value_v1(secret, name, value, max_age_days, clock): +def _decode_signed_value_v1(secret: Union[str, bytes], name: str, value: bytes, max_age_days: int, + clock: Callable[[], float]) -> Optional[bytes]: parts = utf8(value).split(b"|") if len(parts) != 3: return None signature = _create_signature_v1(secret, name, parts[0], parts[1]) - if not _time_independent_equals(parts[2], signature): + if not hmac.compare_digest(parts[2], signature): gen_log.warning("Invalid cookie signature %r", value) return None timestamp = int(parts[1]) @@ -3206,8 +3272,8 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock): return None -def _decode_fields_v2(value): - def _consume_field(s): +def _decode_fields_v2(value: bytes) -> Tuple[int, bytes, bytes, bytes, bytes]: + def _consume_field(s: bytes) -> Tuple[bytes, bytes]: length, _, rest = s.partition(b':') n = int(length) field_value = rest[:n] @@ -3226,9 +3292,11 @@ def _decode_fields_v2(value): return int(key_version), timestamp, name_field, value_field, passed_sig -def _decode_signed_value_v2(secret, name, value, max_age_days, clock): +def _decode_signed_value_v2(secret: _CookieSecretTypes, + name: str, value: bytes, max_age_days: int, + clock: Callable[[], float]) -> Optional[bytes]: try: - key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value) + key_version, timestamp_bytes, name_field, value_field, passed_sig = _decode_fields_v2(value) except ValueError: return None signed_string = value[:-len(passed_sig)] @@ -3240,11 +3308,11 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): return None expected_sig = _create_signature_v2(secret, signed_string) - if not _time_independent_equals(passed_sig, expected_sig): + if not hmac.compare_digest(passed_sig, expected_sig): return None if name_field != utf8(name): return None - timestamp = int(timestamp) + timestamp = int(timestamp_bytes) if timestamp < clock() - max_age_days * 86400: # The signature has expired. return None @@ -3254,7 +3322,7 @@ def _decode_signed_value_v2(secret, name, value, max_age_days, clock): return None -def get_signature_key_version(value): +def get_signature_key_version(value: Union[str, bytes]) -> Optional[int]: value = utf8(value) version = _get_version(value) if version < 2: @@ -3267,18 +3335,18 @@ def get_signature_key_version(value): return key_version -def _create_signature_v1(secret, *parts): +def _create_signature_v1(secret: Union[str, bytes], *parts: Union[str, bytes]) -> bytes: hash = hmac.new(utf8(secret), digestmod=hashlib.sha1) for part in parts: hash.update(utf8(part)) return utf8(hash.hexdigest()) -def _create_signature_v2(secret, s): +def _create_signature_v2(secret: Union[str, bytes], s: bytes) -> bytes: hash = hmac.new(utf8(secret), digestmod=hashlib.sha256) hash.update(utf8(s)) return utf8(hash.hexdigest()) -def is_absolute(path): +def is_absolute(path: str) -> bool: return any(path.startswith(x) for x in ["/", "http:", "https:"])