]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
httputil: Type-annotate all methods
authorBen Darnell <ben@bendarnell.com>
Sat, 21 Jul 2018 22:07:16 +0000 (18:07 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 21 Jul 2018 23:09:05 +0000 (19:09 -0400)
setup.cfg
tornado/escape.py
tornado/httputil.py

index da6ebfab97bd6d176a8c443964850245fd5d32e6..c5fc02484d2da2ceb5190fa371a6e4c10a1b1063 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,3 +3,6 @@ python_version = 3.5
 
 [mypy-tornado.util]
 disallow_untyped_defs = True
+
+[mypy-tornado.httputil]
+disallow_untyped_defs = True
index 382133f5b5920b47759999d9683b79780ce7d5b7..17754b3613438dee0bf113087c458565ff6f44c2 100644 (file)
@@ -142,8 +142,22 @@ def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
 _UTF8_TYPES = (bytes, type(None))
 
 
-def utf8(value):
-    # type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
+@typing.overload
+def utf8(value: bytes) -> bytes:
+    pass
+
+
+@typing.overload  # noqa: F811
+def utf8(value: str) -> bytes:
+    pass
+
+
+@typing.overload  # noqa: F811
+def utf8(value: None) -> None:
+    pass
+
+
+def utf8(value):  # noqa: F811
     """Converts a string argument to a byte string.
 
     If the argument is already a byte string or None, it is returned unchanged.
index 0e498235b03991c979bafcd065c15b3ed49e2c58..72f1eb6b73920967d412ccaea63b44780e1d4695 100644 (file)
@@ -25,7 +25,7 @@ import copy
 import datetime
 import email.utils
 from http.client import responses
-import http.cookies as Cookie
+import http.cookies
 import numbers
 import re
 from ssl import SSLError
@@ -42,7 +42,13 @@ from tornado.util import ObjectDict, unicode_type
 # Reference it so pyflakes doesn't complain.
 responses
 
-import typing  # noqa: F401
+import typing
+from typing import (Tuple, Iterable, List, Mapping, Iterator, Dict, Union, Optional,
+                    Awaitable, Generator)
+
+if typing.TYPE_CHECKING:
+    from typing import Deque  # noqa
+    import unittest  # noqa
 
 
 # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
@@ -61,12 +67,12 @@ class _NormalizedHeaderCache(dict):
     >>> normalized_headers["coNtent-TYPE"]
     'Content-Type'
     """
-    def __init__(self, size):
+    def __init__(self, size: int) -> None:
         super(_NormalizedHeaderCache, self).__init__()
         self.size = size
-        self.queue = collections.deque()
+        self.queue = collections.deque()  # type: Deque[str]
 
-    def __missing__(self, key):
+    def __missing__(self, key: str) -> str:
         normalized = "-".join([w.capitalize() for w in key.split("-")])
         self[key] = normalized
         self.queue.append(key)
@@ -110,7 +116,19 @@ class HTTPHeaders(collections.MutableMapping):
     Set-Cookie: A=B
     Set-Cookie: C=D
     """
-    def __init__(self, *args, **kwargs):
+    @typing.overload
+    def __init__(self, __arg: Mapping[str, List[str]]) -> None:
+        pass
+
+    @typing.overload  # noqa: F811
+    def __init__(self, *args: Tuple[str, str]) -> None:
+        pass
+
+    @typing.overload  # noqa: F811
+    def __init__(self, **kwargs: str) -> None:
+        pass
+
+    def __init__(self, *args: typing.Any, **kwargs: str) -> None:  # noqa: F811
         self._dict = {}  # type: typing.Dict[str, str]
         self._as_list = {}  # type: typing.Dict[str, typing.List[str]]
         self._last_key = None
@@ -125,8 +143,7 @@ class HTTPHeaders(collections.MutableMapping):
 
     # new public methods
 
-    def add(self, name, value):
-        # type: (str, str) -> None
+    def add(self, name: str, value: str) -> None:
         """Adds a new value for the given key."""
         norm_name = _normalized_headers[name]
         self._last_key = norm_name
@@ -137,13 +154,12 @@ class HTTPHeaders(collections.MutableMapping):
         else:
             self[norm_name] = value
 
-    def get_list(self, name):
+    def get_list(self, name: str) -> List[str]:
         """Returns all values for the given header as a list."""
         norm_name = _normalized_headers[name]
         return self._as_list.get(norm_name, [])
 
-    def get_all(self):
-        # type: () -> typing.Iterable[typing.Tuple[str, str]]
+    def get_all(self) -> Iterable[Tuple[str, str]]:
         """Returns an iterable of all (name, value) pairs.
 
         If a header has multiple values, multiple pairs will be
@@ -153,7 +169,7 @@ class HTTPHeaders(collections.MutableMapping):
             for value in values:
                 yield (name, value)
 
-    def parse_line(self, line):
+    def parse_line(self, line: str) -> None:
         """Updates the dictionary with a single header line.
 
         >>> h = HTTPHeaders()
@@ -176,7 +192,7 @@ class HTTPHeaders(collections.MutableMapping):
             self.add(name, value.strip())
 
     @classmethod
-    def parse(cls, headers):
+    def parse(cls, headers: str) -> 'HTTPHeaders':
         """Returns a dictionary from HTTP header text.
 
         >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
@@ -197,27 +213,26 @@ class HTTPHeaders(collections.MutableMapping):
 
     # MutableMapping abstract method implementations.
 
-    def __setitem__(self, name, value):
+    def __setitem__(self, name: str, value: str) -> None:
         norm_name = _normalized_headers[name]
         self._dict[norm_name] = value
         self._as_list[norm_name] = [value]
 
-    def __getitem__(self, name):
-        # type: (str) -> str
+    def __getitem__(self, name: str) -> str:
         return self._dict[_normalized_headers[name]]
 
-    def __delitem__(self, name):
+    def __delitem__(self, name: str) -> None:
         norm_name = _normalized_headers[name]
         del self._dict[norm_name]
         del self._as_list[norm_name]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._dict)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[typing.Any]:
         return iter(self._dict)
 
-    def copy(self):
+    def copy(self) -> 'HTTPHeaders':
         # defined in dict but not in MutableMapping.
         return HTTPHeaders(self)
 
@@ -226,7 +241,7 @@ class HTTPHeaders(collections.MutableMapping):
     # the appearance that HTTPHeaders is a single container.
     __copy__ = copy
 
-    def __str__(self):
+    def __str__(self) -> str:
         lines = []
         for name, value in self.get_all():
             lines.append("%s: %s\n" % (name, value))
@@ -327,9 +342,13 @@ class HTTPServerRequest(object):
     .. versionchanged:: 4.0
        Moved from ``tornado.httpserver.HTTPRequest``.
     """
-    def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
-                 body=None, host=None, files=None, connection=None,
-                 start_line=None, server_connection=None):
+    path = None  # type: str
+    query = None  # type: str
+
+    def __init__(self, method: str=None, uri: str=None, version: str="HTTP/1.0",
+                 headers: HTTPHeaders=None, body: bytes=None, host: str=None,
+                 files: Dict[str, 'HTTPFile']=None, connection: 'HTTPConnection'=None,
+                 start_line: 'RequestStartLine'=None, server_connection: object=None) -> None:
         if start_line is not None:
             method, uri, version = start_line
         self.method = method
@@ -351,16 +370,17 @@ class HTTPServerRequest(object):
         self._start_time = time.time()
         self._finish_time = None
 
-        self.path, sep, self.query = uri.partition('?')
+        if uri is not None:
+            self.path, sep, self.query = uri.partition('?')
         self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
         self.query_arguments = copy.deepcopy(self.arguments)
-        self.body_arguments = {}
+        self.body_arguments = {}  # type: Dict[str, List[bytes]]
 
     @property
-    def cookies(self):
-        """A dictionary of Cookie.Morsel objects."""
+    def cookies(self) -> Dict[str, http.cookies.Morsel]:
+        """A dictionary of ``http.cookies.Morsel`` objects."""
         if not hasattr(self, "_cookies"):
-            self._cookies = Cookie.SimpleCookie()
+            self._cookies = http.cookies.SimpleCookie()
             if "Cookie" in self.headers:
                 try:
                     parsed = parse_cookie(self.headers["Cookie"])
@@ -377,18 +397,18 @@ class HTTPServerRequest(object):
                             pass
         return self._cookies
 
-    def full_url(self):
+    def full_url(self) -> str:
         """Reconstructs the full URL for this request."""
         return self.protocol + "://" + self.host + self.uri
 
-    def request_time(self):
+    def request_time(self) -> float:
         """Returns the amount of time it took for this request to execute."""
         if self._finish_time is None:
             return time.time() - self._start_time
         else:
             return self._finish_time - self._start_time
 
-    def get_ssl_certificate(self, binary_form=False):
+    def get_ssl_certificate(self, binary_form: bool=False) -> Union[None, Dict, bytes]:
         """Returns the client's SSL certificate, if any.
 
         To use client certificates, the HTTPServer's
@@ -408,12 +428,15 @@ class HTTPServerRequest(object):
         http://docs.python.org/library/ssl.html#sslsocket-objects
         """
         try:
-            return self.connection.stream.socket.getpeercert(
+            if self.connection is None:
+                return None
+            # TODO: add a method to HTTPConnection for this so it can work with HTTP/2
+            return self.connection.stream.socket.getpeercert(  # type: ignore
                 binary_form=binary_form)
         except SSLError:
             return None
 
-    def _parse_body(self):
+    def _parse_body(self) -> None:
         parse_body_arguments(
             self.headers.get("Content-Type", ""), self.body,
             self.body_arguments, self.files,
@@ -422,7 +445,7 @@ class HTTPServerRequest(object):
         for k, v in self.body_arguments.items():
             self.arguments.setdefault(k, []).extend(v)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
         args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
         return "%s(%s)" % (self.__class__.__name__, args)
@@ -450,7 +473,8 @@ class HTTPServerConnectionDelegate(object):
 
     .. versionadded:: 4.0
     """
-    def start_request(self, server_conn, request_conn):
+    def start_request(self, server_conn: object,
+                      request_conn: 'HTTPConnection')-> 'HTTPMessageDelegate':
         """This method is called by the server when a new request has started.
 
         :arg server_conn: is an opaque object representing the long-lived
@@ -462,7 +486,7 @@ class HTTPServerConnectionDelegate(object):
         """
         raise NotImplementedError()
 
-    def on_close(self, server_conn):
+    def on_close(self, server_conn: object) -> None:
         """This method is called when a connection has been closed.
 
         :arg server_conn: is a server connection that has previously been
@@ -476,7 +500,8 @@ class HTTPMessageDelegate(object):
 
     .. versionadded:: 4.0
     """
-    def headers_received(self, start_line, headers):
+    def headers_received(self, start_line: Union['RequestStartLine', 'ResponseStartLine'],
+                         headers: HTTPHeaders) -> Optional[Awaitable[None]]:
         """Called when the HTTP headers have been received and parsed.
 
         :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`
@@ -491,18 +516,18 @@ class HTTPMessageDelegate(object):
         """
         pass
 
-    def data_received(self, chunk):
+    def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
         """Called when a chunk of data has been received.
 
         May return a `.Future` for flow control.
         """
         pass
 
-    def finish(self):
+    def finish(self) -> None:
         """Called after the last chunk of data has been received."""
         pass
 
-    def on_connection_close(self):
+    def on_connection_close(self) -> None:
         """Called if the connection is closed without finishing the request.
 
         If ``headers_received`` is called, either ``finish`` or
@@ -516,7 +541,8 @@ class HTTPConnection(object):
 
     .. versionadded:: 4.0
     """
-    def write_headers(self, start_line, headers, chunk=None):
+    def write_headers(self, start_line: Union['RequestStartLine', 'ResponseStartLine'],
+                      headers: HTTPHeaders, chunk: bytes=None) -> Awaitable[None]:
         """Write an HTTP header block.
 
         :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`.
@@ -524,11 +550,10 @@ class HTTPConnection(object):
         :arg chunk: the first (optional) chunk of data.  This is an optimization
             so that small responses can be written in the same call as their
             headers.
-        :arg callback: a callback to be run when the write is complete.
 
         The ``version`` field of ``start_line`` is ignored.
 
-        Returns a `.Future` if no callback is given.
+        Returns an awaitable for flow control.
 
         .. versionchanged:: 6.0
 
@@ -536,11 +561,10 @@ class HTTPConnection(object):
         """
         raise NotImplementedError()
 
-    def write(self, chunk):
+    def write(self, chunk: bytes) -> Awaitable[None]:
         """Writes a chunk of body data.
 
-        The callback will be run when the write is complete.  If no callback
-        is given, returns a Future.
+        Returns an awaitable for flow control.
 
         .. versionchanged:: 6.0
 
@@ -548,13 +572,14 @@ class HTTPConnection(object):
         """
         raise NotImplementedError()
 
-    def finish(self):
+    def finish(self) -> None:
         """Indicates that the last body data has been written.
         """
         raise NotImplementedError()
 
 
-def url_concat(url, args):
+def url_concat(url: str, args: Union[Dict[str, str], List[Tuple[str, str]],
+                                     Tuple[Tuple[str, str], ...]]) -> str:
     """Concatenate url and arguments regardless of whether
     url has existing query parameters.
 
@@ -605,7 +630,7 @@ class HTTPFile(ObjectDict):
     pass
 
 
-def _parse_request_range(range_header):
+def _parse_request_range(range_header: str) -> Optional[Tuple[Optional[int], Optional[int]]]:
     """Parses a Range header.
 
     Returns either ``None`` or tuple ``(start, end)``.
@@ -654,7 +679,7 @@ def _parse_request_range(range_header):
     return (start, end)
 
 
-def _get_content_range(start, end, total):
+def _get_content_range(start: Optional[int], end: Optional[int], total: int) -> str:
     """Returns a suitable Content-Range header:
 
     >>> print(_get_content_range(None, 1, 4))
@@ -669,14 +694,15 @@ def _get_content_range(start, end, total):
     return "bytes %s-%s/%s" % (start, end, total)
 
 
-def _int_or_none(val):
+def _int_or_none(val: str) -> Optional[int]:
     val = val.strip()
     if val == "":
         return None
     return int(val)
 
 
-def parse_body_arguments(content_type, body, arguments, files, headers=None):
+def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, List[bytes]],
+                         files: Dict[str, HTTPFile], headers: HTTPHeaders=None) -> None:
     """Parses a form request body.
 
     Supports ``application/x-www-form-urlencoded`` and
@@ -712,7 +738,8 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
             gen_log.warning("Invalid multipart/form-data: %s", e)
 
 
-def parse_multipart_form_data(boundary, data, arguments, files):
+def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str, List[bytes]],
+                              files: Dict[str, HTTPFile]) -> None:
     """Parses a ``multipart/form-data`` body.
 
     The ``boundary`` and ``data`` parameters are both byte strings.
@@ -763,7 +790,7 @@ def parse_multipart_form_data(boundary, data, arguments, files):
             arguments.setdefault(name, []).append(value)
 
 
-def format_timestamp(ts):
+def format_timestamp(ts: Union[numbers.Real, tuple, time.struct_time, datetime.datetime]) -> str:
     """Formats a timestamp in the format used by HTTP.
 
     The argument may be a numeric timestamp as returned by `time.time`,
@@ -774,21 +801,21 @@ def format_timestamp(ts):
     'Sun, 27 Jan 2013 18:43:20 GMT'
     """
     if isinstance(ts, numbers.Real):
-        pass
+        time_float = typing.cast(float, ts)
     elif isinstance(ts, (tuple, time.struct_time)):
-        ts = calendar.timegm(ts)
+        time_float = calendar.timegm(ts)
     elif isinstance(ts, datetime.datetime):
-        ts = calendar.timegm(ts.utctimetuple())
+        time_float = calendar.timegm(ts.utctimetuple())
     else:
         raise TypeError("unknown timestamp type: %r" % ts)
-    return email.utils.formatdate(ts, usegmt=True)
+    return email.utils.formatdate(time_float, usegmt=True)
 
 
 RequestStartLine = collections.namedtuple(
     'RequestStartLine', ['method', 'path', 'version'])
 
 
-def parse_request_start_line(line):
+def parse_request_start_line(line: str) -> RequestStartLine:
     """Returns a (method, path, version) tuple for an HTTP 1.x request line.
 
     The response is a `collections.namedtuple`.
@@ -812,7 +839,7 @@ ResponseStartLine = collections.namedtuple(
     'ResponseStartLine', ['version', 'code', 'reason'])
 
 
-def parse_response_start_line(line):
+def parse_response_start_line(line: str) -> ResponseStartLine:
     """Returns a (version, code, reason) tuple for an HTTP 1.x response line.
 
     The response is a `collections.namedtuple`.
@@ -835,7 +862,7 @@ def parse_response_start_line(line):
 # RFC 2231/5987 format.
 
 
-def _parseparam(s):
+def _parseparam(s: str) -> Generator[str, None, None]:
     while s[:1] == ';':
         s = s[1:]
         end = s.find(';')
@@ -848,7 +875,7 @@ def _parseparam(s):
         s = s[end:]
 
 
-def _parse_header(line):
+def _parse_header(line: str) -> Tuple[str, Dict[str, str]]:
     r"""Parse a Content-type like header.
 
     Return the main content-type and a dictionary of options.
@@ -872,18 +899,18 @@ def _parse_header(line):
             name = p[:i].strip().lower()
             value = p[i + 1:].strip()
             params.append((name, native_str(value)))
-    params = email.utils.decode_params(params)
-    params.pop(0)  # get rid of the dummy again
+    decoded_params = email.utils.decode_params(params)
+    decoded_params.pop(0)  # get rid of the dummy again
     pdict = {}
-    for name, value in params:
-        value = email.utils.collapse_rfc2231_value(value)
+    for name, decoded_value in decoded_params:
+        value = email.utils.collapse_rfc2231_value(decoded_value)
         if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
             value = value[1:-1]
         pdict[name] = value
     return key, pdict
 
 
-def _encode_header(key, pdict):
+def _encode_header(key: str, pdict: Dict[str, str]) -> str:
     """Inverse of _parse_header.
 
     >>> _encode_header('permessage-deflate',
@@ -903,7 +930,7 @@ def _encode_header(key, pdict):
     return '; '.join(out)
 
 
-def encode_username_password(username, password):
+def encode_username_password(username: Union[str, bytes], password: Union[str, bytes]) -> bytes:
     """Encodes a username/password pair in the format used by HTTP auth.
 
     The return value is a byte string in the form ``username:password``.
@@ -918,11 +945,12 @@ def encode_username_password(username, password):
 
 
 def doctests():
+    # type: () -> unittest.TestSuite
     import doctest
     return doctest.DocTestSuite()
 
 
-def split_host_and_port(netloc):
+def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]:
     """Returns ``(host, port)`` tuple from ``netloc``.
 
     Returned ``port`` will be ``None`` if not present.
@@ -932,14 +960,14 @@ def split_host_and_port(netloc):
     match = re.match(r'^(.+):(\d+)$', netloc)
     if match:
         host = match.group(1)
-        port = int(match.group(2))
+        port = int(match.group(2))  # type: Optional[int]
     else:
         host = netloc
         port = None
     return (host, port)
 
 
-def qs_to_qsl(qs):
+def qs_to_qsl(qs: Dict[str, List[str]]) -> Iterable[Tuple[str, str]]:
     """Generator converting a result of ``parse_qs`` back to name-value pairs.
 
     .. versionadded:: 5.0
@@ -954,7 +982,7 @@ _QuotePatt = re.compile(r"[\\].")
 _nulljoin = ''.join
 
 
-def _unquote_cookie(str):
+def _unquote_cookie(s: str) -> str:
     """Handle double quotes and escaping in cookie values.
 
     This method is copied verbatim from the Python 3.5 standard
@@ -963,29 +991,29 @@ def _unquote_cookie(str):
     """
     # If there aren't any doublequotes,
     # then there can't be any special characters.  See RFC 2109.
-    if str is None or len(str) < 2:
-        return str
-    if str[0] != '"' or str[-1] != '"':
-        return str
+    if s is None or len(s) < 2:
+        return s
+    if s[0] != '"' or s[-1] != '"':
+        return s
 
     # We have to assume that we must decode this string.
     # Down to work.
 
     # Remove the "s
-    str = str[1:-1]
+    s = s[1:-1]
 
     # Check for special sequences.  Examples:
     #    \012 --> \n
     #    \"   --> "
     #
     i = 0
-    n = len(str)
+    n = len(s)
     res = []
     while 0 <= i < n:
-        o_match = _OctalPatt.search(str, i)
-        q_match = _QuotePatt.search(str, i)
+        o_match = _OctalPatt.search(s, i)
+        q_match = _QuotePatt.search(s, i)
         if not o_match and not q_match:              # Neither matched
-            res.append(str[i:])
+            res.append(s[i:])
             break
         # else:
         j = k = -1
@@ -994,17 +1022,17 @@ def _unquote_cookie(str):
         if q_match:
             k = q_match.start(0)
         if q_match and (not o_match or k < j):     # QuotePatt matched
-            res.append(str[i:k])
-            res.append(str[k + 1])
+            res.append(s[i:k])
+            res.append(s[k + 1])
             i = k + 2
         else:                                      # OctalPatt matched
-            res.append(str[i:j])
-            res.append(chr(int(str[j + 1:j + 4], 8)))
+            res.append(s[i:j])
+            res.append(chr(int(s[j + 1:j + 4], 8)))
             i = j + 4
     return _nulljoin(res)
 
 
-def parse_cookie(cookie):
+def parse_cookie(cookie: str) -> Dict[str, str]:
     """Parse a ``Cookie`` HTTP header into a dict of name/value pairs.
 
     This function attempts to mimic browser cookie parsing behavior;