httpclients: Add type annotations
author    Ben Darnell <ben@bendarnell.com>  Sat, 29 Sep 2018 02:43:48 +0000 (22:43 -0400)
committer Ben Darnell <ben@bendarnell.com>  Sat, 29 Sep 2018 02:43:48 +0000 (22:43 -0400)
setup.cfg
tornado/curl_httpclient.py
tornado/httpclient.py
tornado/simple_httpclient.py
tornado/test/httpclient_test.py

index 527646a6344f12a302d74e0834f3bd026993cf83..36b6401022e922b0e775ee8e7720b373e5544659 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,7 +7,7 @@ python_version = 3.5
 [mypy-tornado.*,tornado.platform.*]
 disallow_untyped_defs = True
 
-[mypy-tornado.auth,tornado.curl_httpclient,tornado.httpclient,tornado.routing,tornado.simple_httpclient,tornado.template,tornado.web,tornado.websocket,tornado.wsgi]
+[mypy-tornado.auth,tornado.routing,tornado.template,tornado.web,tornado.websocket,tornado.wsgi]
 disallow_untyped_defs = False
 
 # It's generally too tedious to require type annotations in tests, but
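
For context: dropping these modules from the override above means the stricter [mypy-tornado.*] block (disallow_untyped_defs = True) now governs tornado.httpclient, tornado.simple_httpclient, and tornado.curl_httpclient, so every function they define must be annotated. A rough sketch of what that flag enforces, with made-up function names:

    # With disallow_untyped_defs = True, mypy rejects an unannotated definition:
    def fetch_untyped(url):  # error: Function is missing a type annotation
        raise NotImplementedError()

    # ...and accepts the fully annotated style this commit adds throughout:
    def fetch_typed(url: str, connect_timeout: float = 20.0) -> bytes:
        raise NotImplementedError()
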
index 13e1f136c9683e2b8fac588993eed10ad55d4f18..7f13403cdb40f989cd0f331ec3941321f8c79207 100644 (file)
--- a/tornado/curl_httpclient.py
+++ b/tornado/curl_httpclient.py
@@ -27,22 +27,30 @@ from tornado import httputil
 from tornado import ioloop
 
 from tornado.escape import utf8, native_str
-from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main
+from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient, main
+from tornado.log import app_log
+
+from typing import Dict, Any, Callable, Union
+import typing
+if typing.TYPE_CHECKING:
+    from typing import Deque, Tuple, Optional  # noqa: F401
 
 curl_log = logging.getLogger('tornado.curl_httpclient')
 
 
 class CurlAsyncHTTPClient(AsyncHTTPClient):
-    def initialize(self, max_clients=10, defaults=None):
+    def initialize(self, max_clients: int=10,  # type: ignore
+                   defaults: Dict[str, Any]=None) -> None:
         super(CurlAsyncHTTPClient, self).initialize(defaults=defaults)
         self._multi = pycurl.CurlMulti()
         self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
         self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
         self._curls = [self._curl_create() for i in range(max_clients)]
         self._free_list = self._curls[:]
-        self._requests = collections.deque()
-        self._fds = {}
-        self._timeout = None
+        self._requests = collections.deque() \
+            # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]]
+        self._fds = {}  # type: Dict[int, int]
+        self._timeout = None  # type: Optional[object]
 
         # libcurl has bugs that sometimes cause it to not report all
         # relevant file descriptors and timeouts to TIMERFUNCTION/
@@ -61,7 +69,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         self._multi.add_handle(dummy_curl_handle)
         self._multi.remove_handle(dummy_curl_handle)
 
-    def close(self):
+    def close(self) -> None:
         self._force_timeout_callback.stop()
         if self._timeout is not None:
             self.io_loop.remove_timeout(self._timeout)
@@ -73,15 +81,15 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         # Set below properties to None to reduce the reference count of current
         # instance, because those properties hold some methods of current
         # instance that will cause a circular reference.
-        self._force_timeout_callback = None
+        self._force_timeout_callback = None  # type: ignore
         self._multi = None
 
-    def fetch_impl(self, request, callback):
+    def fetch_impl(self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]) -> None:
         self._requests.append((request, callback, self.io_loop.time()))
         self._process_queue()
         self._set_timeout(0)
 
-    def _handle_socket(self, event, fd, multi, data):
+    def _handle_socket(self, event: int, fd: int, multi: Any, data: bytes) -> None:
         """Called by libcurl when it wants to change the file descriptors
         it cares about.
         """
@@ -111,14 +119,14 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                                      ioloop_event)
             self._fds[fd] = ioloop_event
 
-    def _set_timeout(self, msecs):
+    def _set_timeout(self, msecs: int) -> None:
         """Called by libcurl to schedule a timeout."""
         if self._timeout is not None:
             self.io_loop.remove_timeout(self._timeout)
         self._timeout = self.io_loop.add_timeout(
             self.io_loop.time() + msecs / 1000.0, self._handle_timeout)
 
-    def _handle_events(self, fd, events):
+    def _handle_events(self, fd: int, events: int) -> None:
         """Called by IOLoop when there is activity on one of our
         file descriptors.
         """
@@ -136,7 +144,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                 break
         self._finish_pending_requests()
 
-    def _handle_timeout(self):
+    def _handle_timeout(self) -> None:
         """Called by IOLoop when the requested timeout has passed."""
         self._timeout = None
         while True:
@@ -166,7 +174,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         if new_timeout >= 0:
             self._set_timeout(new_timeout)
 
-    def _handle_force_timeout(self):
+    def _handle_force_timeout(self) -> None:
         """Called by IOLoop periodically to ask libcurl to process any
         events it may have forgotten about.
         """
@@ -179,7 +187,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                 break
         self._finish_pending_requests()
 
-    def _finish_pending_requests(self):
+    def _finish_pending_requests(self) -> None:
         """Process any requests that were completed by the last
         call to multi.socket_action.
         """
@@ -193,7 +201,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                 break
         self._process_queue()
 
-    def _process_queue(self):
+    def _process_queue(self) -> None:
         while True:
             started = 0
             while self._free_list and self._requests:
@@ -233,14 +241,16 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
             if not started:
                 break
 
-    def _finish(self, curl, curl_error=None, curl_message=None):
+    def _finish(self, curl: pycurl.Curl, curl_error: int=None, curl_message: str=None) -> None:
         info = curl.info
         curl.info = None
         self._multi.remove_handle(curl)
         self._free_list.append(curl)
         buffer = info["buffer"]
         if curl_error:
-            error = CurlError(curl_error, curl_message)
+            assert curl_message is not None
+            error = CurlError(curl_error, curl_message)  # type: Optional[CurlError]
+            assert error is not None
             code = error.code
             effective_url = None
             buffer.close()
@@ -273,10 +283,10 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         except Exception:
             self.handle_callback_exception(info["callback"])
 
-    def handle_callback_exception(self, callback):
-        self.io_loop.handle_callback_exception(callback)
+    def handle_callback_exception(self, callback: Any) -> None:
+        app_log.error("Exception in callback %r", callback, exc_info=True)
 
-    def _curl_create(self):
+    def _curl_create(self) -> pycurl.Curl:
         curl = pycurl.Curl()
         if curl_log.isEnabledFor(logging.DEBUG):
             curl.setopt(pycurl.VERBOSE, 1)
@@ -286,7 +296,8 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
             curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
         return curl
 
-    def _curl_setup_request(self, curl, request, buffer, headers):
+    def _curl_setup_request(self, curl: pycurl.Curl, request: HTTPRequest,
+                            buffer: BytesIO, headers: httputil.HTTPHeaders) -> None:
         curl.setopt(pycurl.URL, native_str(request.url))
 
         # libcurl's magic "Expect: 100-continue" behavior causes delays
@@ -312,14 +323,18 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                     functools.partial(self._curl_header_callback,
                                       headers, request.header_callback))
         if request.streaming_callback:
-            def write_function(chunk):
-                self.io_loop.add_callback(request.streaming_callback, chunk)
+            def write_function(b: Union[bytes, bytearray]) -> int:
+                assert request.streaming_callback is not None
+                self.io_loop.add_callback(request.streaming_callback, b)
+                return len(b)
         else:
             write_function = buffer.write
         curl.setopt(pycurl.WRITEFUNCTION, write_function)
         curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
         curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
+        assert request.connect_timeout is not None
         curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
+        assert request.request_timeout is not None
         curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
         if request.user_agent:
             curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
@@ -335,6 +350,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
             curl.setopt(pycurl.PROXY, request.proxy_host)
             curl.setopt(pycurl.PROXYPORT, request.proxy_port)
             if request.proxy_username:
+                assert request.proxy_password is not None
                 credentials = httputil.encode_username_password(request.proxy_username,
                                                                 request.proxy_password)
                 curl.setopt(pycurl.PROXYUSERPWD, credentials)
@@ -416,7 +432,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                 raise ValueError('Body must be None for GET request')
             request_buffer = BytesIO(utf8(request.body or ''))
 
-            def ioctl(cmd):
+            def ioctl(cmd: int) -> None:
                 if cmd == curl.IOCMD_RESTARTREAD:
                     request_buffer.seek(0)
             curl.setopt(pycurl.READFUNCTION, request_buffer.read)
@@ -428,6 +444,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
                 curl.setopt(pycurl.INFILESIZE, len(request.body or ''))
 
         if request.auth_username is not None:
+            assert request.auth_password is not None
             if request.auth_mode is None or request.auth_mode == "basic":
                 curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
             elif request.auth_mode == "digest":
@@ -453,7 +470,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         if request.ssl_options is not None:
             raise ValueError("ssl_options not supported in curl_httpclient")
 
-        if threading.activeCount() > 1:
+        if threading.active_count() > 1:
             # libcurl/pycurl is not thread-safe by default.  When multiple threads
             # are used, signals should be disabled.  This has the side effect
             # of disabling DNS timeouts in some environments (when libcurl is
@@ -466,8 +483,10 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         if request.prepare_curl_callback is not None:
             request.prepare_curl_callback(curl)
 
-    def _curl_header_callback(self, headers, header_callback, header_line):
-        header_line = native_str(header_line.decode('latin1'))
+    def _curl_header_callback(self, headers: httputil.HTTPHeaders,
+                              header_callback: Callable[[str], None],
+                              header_line_bytes: bytes) -> None:
+        header_line = native_str(header_line_bytes.decode('latin1'))
         if header_callback is not None:
             self.io_loop.add_callback(header_callback, header_line)
         # header_line as returned by curl includes the end-of-line characters.
@@ -484,7 +503,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
             return
         headers.parse_line(header_line)
 
-    def _curl_debug(self, debug_type, debug_msg):
+    def _curl_debug(self, debug_type: int, debug_msg: str) -> None:
         debug_types = ('I', '<', '>', '<', '>')
         if debug_type == 0:
             debug_msg = native_str(debug_msg)
@@ -498,7 +517,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
 
 
 class CurlError(HTTPError):
-    def __init__(self, errno, message):
+    def __init__(self, errno: int, message: str) -> None:
         HTTPError.__init__(self, 599, message)
         self.errno = errno
 
index d3a42cd3784a6166d0d1de3a7cd83e1bc70c9190..1a47ae9e222d1763d6cd78b5645ba5ac764c5c73 100644 (file)
--- a/tornado/httpclient.py
+++ b/tornado/httpclient.py
@@ -38,7 +38,10 @@ To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
     AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
 """
 
+import datetime
 import functools
+from io import BytesIO
+import ssl
 import time
 import weakref
 
@@ -48,6 +51,8 @@ from tornado import gen, httputil
 from tornado.ioloop import IOLoop
 from tornado.util import Configurable
 
+from typing import Type, Any, Union, Dict, Callable, Optional, cast
+
 
 class HTTPClient(object):
     """A blocking HTTP client.
@@ -78,7 +83,7 @@ class HTTPClient(object):
        Use `AsyncHTTPClient` instead.
 
     """
-    def __init__(self, async_client_class=None, **kwargs):
+    def __init__(self, async_client_class: Type['AsyncHTTPClient']=None, **kwargs: Any) -> None:
         # Initialize self._closed at the beginning of the constructor
         # so that an exception raised here doesn't lead to confusing
         # failures in __del__.
@@ -86,23 +91,27 @@ class HTTPClient(object):
         self._io_loop = IOLoop(make_current=False)
         if async_client_class is None:
             async_client_class = AsyncHTTPClient
+
         # Create the client while our IOLoop is "current", without
         # clobbering the thread's real current IOLoop (if any).
-        self._async_client = self._io_loop.run_sync(
-            gen.coroutine(lambda: async_client_class(**kwargs)))
+        async def make_client() -> 'AsyncHTTPClient':
+            await gen.sleep(0)
+            assert async_client_class is not None
+            return async_client_class(**kwargs)
+        self._async_client = self._io_loop.run_sync(make_client)
         self._closed = False
 
-    def __del__(self):
+    def __del__(self) -> None:
         self.close()
 
-    def close(self):
+    def close(self) -> None:
         """Closes the HTTPClient, freeing any resources used."""
         if not self._closed:
             self._async_client.close()
             self._io_loop.close()
             self._closed = True
 
-    def fetch(self, request, **kwargs):
+    def fetch(self, request: Union['HTTPRequest', str], **kwargs: Any) -> 'HTTPResponse':
         """Executes a request, returning an `HTTPResponse`.
 
         The request may be either a string URL or an `HTTPRequest` object.
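
As an aside, the Union['HTTPRequest', str] parameter type added here documents that both call forms type-check; a short usage sketch (the URL and surrounding code are illustrative, not part of the commit):

    from tornado.httpclient import HTTPClient, HTTPRequest

    client = HTTPClient()
    try:
        resp = client.fetch("https://www.example.com/")                      # plain URL
        resp = client.fetch(HTTPRequest("https://www.example.com/", "GET"))  # request object
        print(resp.code, len(resp.body or b""))
    finally:
        client.close()
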
@@ -155,23 +164,26 @@ class AsyncHTTPClient(Configurable):
        The ``io_loop`` argument (deprecated since version 4.1) has been removed.
 
     """
+
+    _instance_cache = None  # type: Dict[IOLoop, AsyncHTTPClient]
+
     @classmethod
-    def configurable_base(cls):
+    def configurable_base(cls) -> Type[Configurable]:
         return AsyncHTTPClient
 
     @classmethod
-    def configurable_default(cls):
+    def configurable_default(cls) -> Type[Configurable]:
         from tornado.simple_httpclient import SimpleAsyncHTTPClient
         return SimpleAsyncHTTPClient
 
     @classmethod
-    def _async_clients(cls):
+    def _async_clients(cls) -> Dict[IOLoop, 'AsyncHTTPClient']:
         attr_name = '_async_client_dict_' + cls.__name__
         if not hasattr(cls, attr_name):
             setattr(cls, attr_name, weakref.WeakKeyDictionary())
         return getattr(cls, attr_name)
 
-    def __new__(cls, force_instance=False, **kwargs):
+    def __new__(cls, force_instance: bool=False, **kwargs: Any) -> 'AsyncHTTPClient':
         io_loop = IOLoop.current()
         if force_instance:
             instance_cache = None
@@ -179,7 +191,7 @@ class AsyncHTTPClient(Configurable):
             instance_cache = cls._async_clients()
         if instance_cache is not None and io_loop in instance_cache:
             return instance_cache[io_loop]
-        instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs)
+        instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs)  # type: ignore
         # Make sure the instance knows which cache to remove itself from.
         # It can't simply call _async_clients() because we may be in
         # __new__(AsyncHTTPClient) but instance.__class__ may be
@@ -189,14 +201,14 @@ class AsyncHTTPClient(Configurable):
             instance_cache[instance.io_loop] = instance
         return instance
 
-    def initialize(self, defaults=None):
+    def initialize(self, defaults: Dict[str, Any]=None) -> None:
         self.io_loop = IOLoop.current()
         self.defaults = dict(HTTPRequest._DEFAULTS)
         if defaults is not None:
             self.defaults.update(defaults)
         self._closed = False
 
-    def close(self):
+    def close(self) -> None:
         """Destroys this HTTP client, freeing any file descriptors used.
 
         This method is **not needed in normal use** due to the way
@@ -217,7 +229,8 @@ class AsyncHTTPClient(Configurable):
                 raise RuntimeError("inconsistent AsyncHTTPClient cache")
             del self._instance_cache[self.io_loop]
 
-    def fetch(self, request, raise_error=True, **kwargs):
+    def fetch(self, request: Union[str, 'HTTPRequest'],
+              raise_error: bool=True, **kwargs: Any) -> 'Future[HTTPResponse]':
         """Executes a request, asynchronously returning an `HTTPResponse`.
 
         The request may be either a string URL or an `HTTPRequest` object.
@@ -257,23 +270,24 @@ class AsyncHTTPClient(Configurable):
         # so make sure we don't modify the caller's object.  This is also
         # where normal dicts get converted to HTTPHeaders objects.
         request.headers = httputil.HTTPHeaders(request.headers)
-        request = _RequestProxy(request, self.defaults)
-        future = Future()
+        request_proxy = _RequestProxy(request, self.defaults)
+        future = Future()  # type: Future[HTTPResponse]
 
-        def handle_response(response):
+        def handle_response(response: 'HTTPResponse') -> None:
             if response.error:
                 if raise_error or not response._error_is_response_code:
                     future.set_exception(response.error)
                     return
             future_set_result_unless_cancelled(future, response)
-        self.fetch_impl(request, handle_response)
+        self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response)
         return future
 
-    def fetch_impl(self, request, callback):
+    def fetch_impl(self, request: 'HTTPRequest',
+                   callback: Callable[['HTTPResponse'], None]) -> None:
         raise NotImplementedError()
 
     @classmethod
-    def configure(cls, impl, **kwargs):
+    def configure(cls, impl: Union[None, str, Type[Configurable]], **kwargs: Any) -> None:
         """Configures the `AsyncHTTPClient` subclass to use.
 
         ``AsyncHTTPClient()`` actually creates an instance of a subclass.
@@ -297,6 +311,7 @@ class AsyncHTTPClient(Configurable):
 
 class HTTPRequest(object):
     """HTTP client request object."""
+    _headers = None  # type: Union[Dict[str, str], httputil.HTTPHeaders]
 
     # Default values for HTTPRequest parameters.
     # Merged with the values on the request object by AsyncHTTPClient
@@ -311,20 +326,26 @@ class HTTPRequest(object):
         allow_nonstandard_methods=False,
         validate_cert=True)
 
-    def __init__(self, url, method="GET", headers=None, body=None,
-                 auth_username=None, auth_password=None, auth_mode=None,
-                 connect_timeout=None, request_timeout=None,
-                 if_modified_since=None, follow_redirects=None,
-                 max_redirects=None, user_agent=None, use_gzip=None,
-                 network_interface=None, streaming_callback=None,
-                 header_callback=None, prepare_curl_callback=None,
-                 proxy_host=None, proxy_port=None, proxy_username=None,
-                 proxy_password=None, proxy_auth_mode=None,
-                 allow_nonstandard_methods=None, validate_cert=None,
-                 ca_certs=None, allow_ipv6=None, client_key=None,
-                 client_cert=None, body_producer=None,
-                 expect_100_continue=False, decompress_response=None,
-                 ssl_options=None):
+    def __init__(self, url: str, method: str="GET",
+                 headers: Union[Dict[str, str], httputil.HTTPHeaders]=None,
+                 body: Union[bytes, str]=None,
+                 auth_username: str=None, auth_password: str=None, auth_mode: str=None,
+                 connect_timeout: float=None, request_timeout: float=None,
+                 if_modified_since: Union[float, datetime.datetime]=None,
+                 follow_redirects: bool=None,
+                 max_redirects: int=None, user_agent: str=None, use_gzip: bool=None,
+                 network_interface: str=None,
+                 streaming_callback: Callable[[bytes], None]=None,
+                 header_callback: Callable[[str], None]=None,
+                 prepare_curl_callback: Callable[[Any], None]=None,
+                 proxy_host: str=None, proxy_port: int=None, proxy_username: str=None,
+                 proxy_password: str=None, proxy_auth_mode: str=None,
+                 allow_nonstandard_methods: bool=None, validate_cert: bool=None,
+                 ca_certs: str=None, allow_ipv6: bool=None, client_key: str=None,
+                 client_cert: str=None,
+                 body_producer: Callable[[Callable[[bytes], None]], 'Future[None]']=None,
+                 expect_100_continue: bool=False, decompress_response: bool=None,
+                 ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None) -> None:
         r"""All parameters except ``url`` are optional.
 
         :arg str url: URL to fetch
@@ -460,7 +481,7 @@ class HTTPRequest(object):
         self.max_redirects = max_redirects
         self.user_agent = user_agent
         if decompress_response is not None:
-            self.decompress_response = decompress_response
+            self.decompress_response = decompress_response  # type: Optional[bool]
         else:
             self.decompress_response = use_gzip
         self.network_interface = network_interface
@@ -478,22 +499,25 @@ class HTTPRequest(object):
         self.start_time = time.time()
 
     @property
-    def headers(self):
-        return self._headers
+    def headers(self) -> httputil.HTTPHeaders:
+        # TODO: headers may actually be a plain dict until fairly late in
+        # the process (AsyncHTTPClient.fetch), but practically speaking,
+        # whenever the property is used they're already HTTPHeaders.
+        return self._headers  # type: ignore
 
     @headers.setter
-    def headers(self, value):
+    def headers(self, value: Union[Dict[str, str], httputil.HTTPHeaders]) -> None:
         if value is None:
             self._headers = httputil.HTTPHeaders()
         else:
-            self._headers = value
+            self._headers = value  # type: ignore
 
     @property
-    def body(self):
+    def body(self) -> bytes:
         return self._body
 
     @body.setter
-    def body(self, value):
+    def body(self, value: Union[bytes, str]) -> None:
         self._body = utf8(value)
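
The body property above is deliberately asymmetric: the setter accepts Union[bytes, str] and normalizes through utf8(), while the getter promises bytes. A stripped-down sketch of the same normalize-on-write pattern outside Tornado (names invented):

    from typing import Union


    def to_bytes(value: Union[bytes, str]) -> bytes:
        return value.encode("utf-8") if isinstance(value, str) else value


    class Message:
        def __init__(self, body: Union[bytes, str] = b"") -> None:
            self._body = to_bytes(body)

        @property
        def body(self) -> bytes:
            # Readers can rely on bytes...
            return self._body

        @body.setter
        def body(self, value: Union[bytes, str]) -> None:
            # ...while writers may pass str or bytes; normalize on the way in.
            self._body = to_bytes(value)
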
 
 
@@ -545,9 +569,16 @@ class HTTPResponse(object):
        is excluded in both implementations. ``request_time`` is now more accurate for
        ``curl_httpclient`` because it uses a monotonic clock when available.
     """
-    def __init__(self, request, code, headers=None, buffer=None,
-                 effective_url=None, error=None, request_time=None,
-                 time_info=None, reason=None, start_time=None):
+    # I'm not sure why these don't get type-inferred from the references in __init__.
+    error = None  # type: Optional[BaseException]
+    _error_is_response_code = False
+    request = None  # type: HTTPRequest
+
+    def __init__(self, request: HTTPRequest, code: int,
+                 headers: httputil.HTTPHeaders=None, buffer: BytesIO=None,
+                 effective_url: str=None, error: BaseException=None,
+                 request_time: float=None, time_info: Dict[str, float]=None,
+                 reason: str=None, start_time: float=None) -> None:
         if isinstance(request, _RequestProxy):
             self.request = request.request
         else:
@@ -559,7 +590,7 @@ class HTTPResponse(object):
         else:
             self.headers = httputil.HTTPHeaders()
         self.buffer = buffer
-        self._body = None
+        self._body = None  # type: Optional[bytes]
         if effective_url is None:
             self.effective_url = request.url
         else:
@@ -579,7 +610,7 @@ class HTTPResponse(object):
         self.time_info = time_info or {}
 
     @property
-    def body(self):
+    def body(self) -> Optional[bytes]:
         if self.buffer is None:
             return None
         elif self._body is None:
@@ -587,12 +618,12 @@ class HTTPResponse(object):
 
         return self._body
 
-    def rethrow(self):
+    def rethrow(self) -> None:
         """If there was an error on the request, raise an `HTTPError`."""
         if self.error:
             raise self.error
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
         return "%s(%s)" % (self.__class__.__name__, args)
 
@@ -617,13 +648,13 @@ class HTTPClientError(Exception):
        `tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains
        as an alias.
     """
-    def __init__(self, code, message=None, response=None):
+    def __init__(self, code: int, message: str=None, response: HTTPResponse=None) -> None:
         self.code = code
         self.message = message or httputil.responses.get(code, "Unknown")
         self.response = response
         super(HTTPClientError, self).__init__(code, message, response)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "HTTP %d: %s" % (self.code, self.message)
 
     # There is a cyclic reference between self and self.response,
@@ -641,11 +672,11 @@ class _RequestProxy(object):
 
     Used internally by AsyncHTTPClient implementations.
     """
-    def __init__(self, request, defaults):
+    def __init__(self, request: HTTPRequest, defaults: Optional[Dict[str, Any]]) -> None:
         self.request = request
         self.defaults = defaults
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         request_attr = getattr(self.request, name)
         if request_attr is not None:
             return request_attr
@@ -655,7 +686,7 @@ class _RequestProxy(object):
             return None
 
 
-def main():
+def main() -> None:
     from tornado.options import define, options, parse_command_line
     define("print_headers", type=bool, default=False)
     define("print_body", type=bool, default=True)
index d0df55271befa1747092d2e04e192c308488640d..473dd3b5e08bcfa3feb83903cdf7871a5bf75549 100644 (file)
--- a/tornado/simple_httpclient.py
+++ b/tornado/simple_httpclient.py
@@ -1,10 +1,11 @@
 from tornado.escape import _unicode
 from tornado import gen
-from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
+from tornado.httpclient import (HTTPResponse, HTTPError, AsyncHTTPClient, main,
+                                _RequestProxy, HTTPRequest)
 from tornado import httputil
 from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
 from tornado.ioloop import IOLoop
-from tornado.iostream import StreamClosedError
+from tornado.iostream import StreamClosedError, IOStream
 from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
 from tornado.log import gen_log
 from tornado.tcpclient import TCPClient
@@ -21,6 +22,12 @@ import time
 from io import BytesIO
 import urllib.parse
 
+from typing import Dict, Any, Generator, Callable, Optional, Type, Union
+from types import TracebackType
+import typing
+if typing.TYPE_CHECKING:
+    from typing import Deque, Tuple, List  # noqa: F401
+
 
 class HTTPTimeoutError(HTTPError):
     """Error raised by SimpleAsyncHTTPClient on timeout.
@@ -30,11 +37,11 @@ class HTTPTimeoutError(HTTPError):
 
     .. versionadded:: 5.1
     """
-    def __init__(self, message):
+    def __init__(self, message: str) -> None:
         super(HTTPTimeoutError, self).__init__(599, message=message)
 
-    def __str__(self):
-        return self.message
+    def __str__(self) -> str:
+        return self.message or "Timeout"
 
 
 class HTTPStreamClosedError(HTTPError):
@@ -48,11 +55,11 @@ class HTTPStreamClosedError(HTTPError):
 
     .. versionadded:: 5.1
     """
-    def __init__(self, message):
+    def __init__(self, message: str) -> None:
         super(HTTPStreamClosedError, self).__init__(599, message=message)
 
-    def __str__(self):
-        return self.message
+    def __str__(self) -> str:
+        return self.message or "Stream closed"
 
 
 class SimpleAsyncHTTPClient(AsyncHTTPClient):
@@ -64,10 +71,10 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
     are not reused, and callers cannot select the network interface to be
     used.
     """
-    def initialize(self, max_clients=10,
-                   hostname_mapping=None, max_buffer_size=104857600,
-                   resolver=None, defaults=None, max_header_size=None,
-                   max_body_size=None):
+    def initialize(self, max_clients: int=10,  # type: ignore
+                   hostname_mapping: Dict[str, str]=None, max_buffer_size: int=104857600,
+                   resolver: Resolver=None, defaults: Dict[str, Any]=None,
+                   max_header_size: int=None, max_body_size: int=None) -> None:
         """Creates a AsyncHTTPClient.
 
         Only a single AsyncHTTPClient instance exists per IOLoop
@@ -102,9 +109,11 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         """
         super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults)
         self.max_clients = max_clients
-        self.queue = collections.deque()
-        self.active = {}
-        self.waiting = {}
+        self.queue = collections.deque() \
+            # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]]
+        self.active = {}  # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]]
+        self.waiting = {} \
+            # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]]
         self.max_buffer_size = max_buffer_size
         self.max_header_size = max_header_size
         self.max_body_size = max_body_size
@@ -121,16 +130,18 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
                                              mapping=hostname_mapping)
         self.tcp_client = TCPClient(resolver=self.resolver)
 
-    def close(self):
+    def close(self) -> None:
         super(SimpleAsyncHTTPClient, self).close()
         if self.own_resolver:
             self.resolver.close()
         self.tcp_client.close()
 
-    def fetch_impl(self, request, callback):
+    def fetch_impl(self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]) -> None:
         key = object()
         self.queue.append((key, request, callback))
         if not len(self.active) < self.max_clients:
+            assert request.connect_timeout is not None
+            assert request.request_timeout is not None
             timeout_handle = self.io_loop.add_timeout(
                 self.io_loop.time() + min(request.connect_timeout,
                                           request.request_timeout),
@@ -144,7 +155,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
                           "%d active, %d queued requests." % (
                               len(self.active), len(self.queue)))
 
-    def _process_queue(self):
+    def _process_queue(self) -> None:
         while self.queue and len(self.active) < self.max_clients:
             key, request, callback = self.queue.popleft()
             if key not in self.waiting:
@@ -154,27 +165,28 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
             release_callback = functools.partial(self._release_fetch, key)
             self._handle_request(request, release_callback, callback)
 
-    def _connection_class(self):
+    def _connection_class(self) -> type:
         return _HTTPConnection
 
-    def _handle_request(self, request, release_callback, final_callback):
+    def _handle_request(self, request: HTTPRequest, release_callback: Callable[[], None],
+                        final_callback: Callable[[HTTPResponse], None]) -> None:
         self._connection_class()(
             self, request, release_callback,
             final_callback, self.max_buffer_size, self.tcp_client,
             self.max_header_size, self.max_body_size)
 
-    def _release_fetch(self, key):
+    def _release_fetch(self, key: object) -> None:
         del self.active[key]
         self._process_queue()
 
-    def _remove_timeout(self, key):
+    def _remove_timeout(self, key: object) -> None:
         if key in self.waiting:
             request, callback, timeout_handle = self.waiting[key]
             if timeout_handle is not None:
                 self.io_loop.remove_timeout(timeout_handle)
             del self.waiting[key]
 
-    def _on_timeout(self, key, info=None):
+    def _on_timeout(self, key: object, info: str=None) -> None:
         """Timeout callback of request.
 
         Construct a timeout HTTPResponse when a timeout occurs.
@@ -196,9 +208,10 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
 class _HTTPConnection(httputil.HTTPMessageDelegate):
     _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
 
-    def __init__(self, client, request, release_callback,
-                 final_callback, max_buffer_size, tcp_client,
-                 max_header_size, max_body_size):
+    def __init__(self, client: SimpleAsyncHTTPClient, request: HTTPRequest,
+                 release_callback: Callable[[], None],
+                 final_callback: Callable[[HTTPResponse], None], max_buffer_size: int,
+                 tcp_client: TCPClient, max_header_size: int, max_body_size: int) -> None:
         self.io_loop = IOLoop.current()
         self.start_time = self.io_loop.time()
         self.start_wall_time = time.time()
@@ -210,17 +223,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         self.tcp_client = tcp_client
         self.max_header_size = max_header_size
         self.max_body_size = max_body_size
-        self.code = None
-        self.headers = None
-        self.chunks = []
+        self.code = None  # type: Optional[int]
+        self.headers = None  # type: Optional[httputil.HTTPHeaders]
+        self.chunks = []  # type: List[bytes]
         self._decompressor = None
         # Timeout handle returned by IOLoop.add_timeout
-        self._timeout = None
+        self._timeout = None  # type: object
         self._sockaddr = None
         IOLoop.current().add_callback(self.run)
 
     @gen.coroutine
-    def run(self):
+    def run(self) -> Generator[Any, Any, None]:
         try:
             self.parsed = urllib.parse.urlsplit(_unicode(self.request.url))
             if self.parsed.scheme not in ("http", "https"):
@@ -292,12 +305,13 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                     username = self.request.auth_username
                     password = self.request.auth_password or ''
                 if username is not None:
+                    assert password is not None
                     if self.request.auth_mode not in (None, "basic"):
                         raise ValueError("unsupported auth_mode %s",
                                          self.request.auth_mode)
                     self.request.headers["Authorization"] = (
-                        b"Basic " + base64.b64encode(
-                            httputil.encode_username_password(username, password)))
+                        "Basic " + _unicode(base64.b64encode(
+                            httputil.encode_username_password(username, password))))
                 if self.request.user_agent:
                     self.request.headers["User-Agent"] = self.request.user_agent
                 if not self.request.allow_nonstandard_methods:
@@ -339,7 +353,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             if not self._handle_exception(*sys.exc_info()):
                 raise
 
-    def _get_ssl_options(self, scheme):
+    def _get_ssl_options(self, scheme: str) -> Union[None, Dict[str, Any], ssl.SSLContext]:
         if scheme == "https":
             if self.request.ssl_options is not None:
                 return self.request.ssl_options
@@ -365,7 +379,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             return ssl_ctx
         return None
 
-    def _on_timeout(self, info=None):
+    def _on_timeout(self, info: str=None) -> None:
         """Timeout callback of _HTTPConnection instance.
 
         Raise a `HTTPTimeoutError` when a timeout occurs.
@@ -378,12 +392,12 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             self._handle_exception(HTTPTimeoutError, HTTPTimeoutError(error_message),
                                    None)
 
-    def _remove_timeout(self):
+    def _remove_timeout(self) -> None:
         if self._timeout is not None:
             self.io_loop.remove_timeout(self._timeout)
             self._timeout = None
 
-    def _create_connection(self, stream):
+    def _create_connection(self, stream: IOStream) -> HTTP1Connection:
         stream.set_nodelay(True)
         connection = HTTP1Connection(
             stream, True,
@@ -391,12 +405,12 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 no_keep_alive=True,
                 max_header_size=self.max_header_size,
                 max_body_size=self.max_body_size,
-                decompress=self.request.decompress_response),
+                decompress=bool(self.request.decompress_response)),
             self._sockaddr)
         return connection
 
     @gen.coroutine
-    def _write_body(self, start_read):
+    def _write_body(self, start_read: bool) -> Generator[Any, Any, None]:
         if self.request.body is not None:
             self.connection.write(self.request.body)
         elif self.request.body_producer is not None:
@@ -411,20 +425,22 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 if not self._handle_exception(*sys.exc_info()):
                     raise
 
-    def _release(self):
+    def _release(self) -> None:
         if self.release_callback is not None:
             release_callback = self.release_callback
-            self.release_callback = None
+            self.release_callback = None  # type: ignore
             release_callback()
 
-    def _run_callback(self, response):
+    def _run_callback(self, response: HTTPResponse) -> None:
         self._release()
         if self.final_callback is not None:
             final_callback = self.final_callback
-            self.final_callback = None
+            self.final_callback = None  # type: ignore
             self.io_loop.add_callback(final_callback, response)
 
-    def _handle_exception(self, typ, value, tb):
+    def _handle_exception(self, typ: Optional[Type[BaseException]],
+                          value: Optional[BaseException],
+                          tb: Optional[TracebackType]) -> bool:
         if self.final_callback:
             self._remove_timeout()
             if isinstance(value, StreamClosedError):
@@ -450,7 +466,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             # pass it along, unless it's just the stream being closed.
             return isinstance(value, StreamClosedError)
 
-    def on_connection_close(self):
+    def on_connection_close(self) -> None:
         if self.final_callback is not None:
             message = "Connection closed"
             if self.stream.error:
@@ -460,7 +476,10 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             except HTTPStreamClosedError:
                 self._handle_exception(*sys.exc_info())
 
-    def headers_received(self, first_line, headers):
+    def headers_received(self, first_line: Union[httputil.ResponseStartLine,
+                                                 httputil.RequestStartLine],
+                         headers: httputil.HTTPHeaders) -> None:
+        assert isinstance(first_line, httputil.ResponseStartLine)
         if self.request.expect_100_continue and first_line.code == 100:
             self._write_body(False)
             return
@@ -478,12 +497,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 self.request.header_callback("%s: %s\r\n" % (k, v))
             self.request.header_callback('\r\n')
 
-    def _should_follow_redirect(self):
-        return (self.request.follow_redirects and
-                self.request.max_redirects > 0 and
-                self.code in (301, 302, 303, 307, 308))
+    def _should_follow_redirect(self) -> bool:
+        if self.request.follow_redirects:
+            assert self.request.max_redirects is not None
+            return (self.code in (301, 302, 303, 307, 308) and
+                    self.request.max_redirects > 0)
+        return False
 
-    def finish(self):
+    def finish(self) -> None:
+        assert self.code is not None
         data = b''.join(self.chunks)
         self._remove_timeout()
         original_request = getattr(self.request, "original_request",
@@ -533,10 +555,10 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         self._run_callback(response)
         self._on_end_request()
 
-    def _on_end_request(self):
+    def _on_end_request(self) -> None:
         self.stream.close()
 
-    def data_received(self, chunk):
+    def data_received(self, chunk: bytes) -> None:
         if self._should_follow_redirect():
             # We're going to follow a redirect so just discard the body.
             return
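
One more pattern from this file worth calling out: methods that keep the old @gen.coroutine decorator (run, _write_body) are annotated as returning Generator[Any, Any, None], because to the type checker they are generator functions rather than native coroutines. A minimal standalone sketch with an invented function:

    from typing import Any, Generator

    from tornado import gen
    from tornado.ioloop import IOLoop


    @gen.coroutine
    def wait_then_add(a: int, b: int) -> Generator[Any, Any, int]:
        # To mypy this is a generator; the third type parameter is the value
        # the decorated coroutine ultimately resolves to.
        yield gen.sleep(0.01)
        return a + b


    if __name__ == "__main__":
        print(IOLoop.current().run_sync(lambda: wait_then_add(1, 2)))  # prints 3
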
index 17cec16cf2a39c59d7def3fc91cf809c46d6b8bb..1b0a43f8b46c71353ab24a0f30d40aa0df4f1ce0 100644 (file)
--- a/tornado/test/httpclient_test.py
+++ b/tornado/test/httpclient_test.py
@@ -538,7 +538,7 @@ class RequestProxyTest(unittest.TestCase):
 
 class HTTPResponseTestCase(unittest.TestCase):
     def test_str(self):
-        response = HTTPResponse(HTTPRequest('http://example.com'),
+        response = HTTPResponse(HTTPRequest('http://example.com'),  # type: ignore
                                 200, headers={}, buffer=BytesIO())
         s = str(response)
         self.assertTrue(s.startswith('HTTPResponse('))
@@ -606,12 +606,12 @@ class HTTPRequestTestCase(unittest.TestCase):
 
     def test_headers_setter(self):
         request = HTTPRequest('http://example.com')
-        request.headers = {'bar': 'baz'}
+        request.headers = {'bar': 'baz'}  # type: ignore
         self.assertEqual(request.headers, {'bar': 'baz'})
 
     def test_null_headers_setter(self):
         request = HTTPRequest('http://example.com')
-        request.headers = None
+        request.headers = None  # type: ignore
         self.assertEqual(request.headers, {})
 
     def test_body(self):
@@ -620,7 +620,7 @@ class HTTPRequestTestCase(unittest.TestCase):
 
     def test_body_setter(self):
         request = HTTPRequest('http://example.com')
-        request.body = 'foo'
+        request.body = 'foo'  # type: ignore
         self.assertEqual(request.body, utf8('foo'))
 
     def test_if_modified_since(self):
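
Finally, the test changes above show the escape hatch for deliberately ill-typed assignments: headers and body are now declared with specific types, so tests that exercise the runtime coercion paths silence the checker with # type: ignore rather than weakening the annotations. The same idea in isolation, with an invented class:

    from typing import Optional


    class Box:
        value = None  # type: Optional[bytes]


    def test_runtime_coercion() -> None:
        box = Box()
        # The assignment below is wrong on purpose (it is exactly what the test
        # wants to exercise), so the checker is silenced here instead of
        # loosening Box's annotation.
        box.value = "not bytes"  # type: ignore
        assert box.value == "not bytes"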