]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Add type annotations
authorBen Darnell <ben@bendarnell.com>
Mon, 1 Oct 2018 02:20:00 +0000 (22:20 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 1 Oct 2018 02:20:00 +0000 (22:20 -0400)
This is more invasive than usual because it defines a Protocol to be
shared between WebSocketHandler and WebSocketClientConnection (and
fixes a bug: an uncaught exception in the callback mode of
WebSocketClientConnection would fail due to the missing log_exception
method).

Fixes #2181

setup.cfg
tornado/simple_httpclient.py
tornado/websocket.py

index dfe5d3ad92ad3b435c42b16e61f9e756bef265e7..d24be9b729bf53d002b0bc33dbc19bc65bc442e8 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,9 +7,6 @@ python_version = 3.5
 [mypy-tornado.*,tornado.platform.*]
 disallow_untyped_defs = True
 
-[mypy-tornado.websocket]
-disallow_untyped_defs = False
-
 # It's generally too tedious to require type annotations in tests, but
 # we do want to type check them as much as type inference allows.
 [mypy-tornado.test.*]
index 473dd3b5e08bcfa3feb83903cdf7871a5bf75549..15ab6e10c4542ce3d12852c05d6a227de42fc391 100644 (file)
@@ -208,7 +208,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
 class _HTTPConnection(httputil.HTTPMessageDelegate):
     _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
 
-    def __init__(self, client: SimpleAsyncHTTPClient, request: HTTPRequest,
+    def __init__(self, client: Optional[SimpleAsyncHTTPClient], request: HTTPRequest,
                  release_callback: Callable[[], None],
                  final_callback: Callable[[HTTPResponse], None], max_buffer_size: int,
                  tcp_client: TCPClient, max_header_size: int, max_body_size: int) -> None:
index d2994c1c90d5129824516459a40cdd652353f03c..6600d8ca40111c4574922626d7c0710b2c3ac37b 100644 (file)
@@ -16,6 +16,7 @@ the protocol (known as "draft 76") and are not compatible with this module.
    Removed support for the draft 76 protocol version.
 """
 
+import abc
 import base64
 import hashlib
 import os
@@ -31,12 +32,85 @@ from tornado.escape import utf8, native_str, to_unicode
 from tornado import gen, httpclient, httputil
 from tornado.ioloop import IOLoop, PeriodicCallback
 from tornado.iostream import StreamClosedError
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
 from tornado import simple_httpclient
 from tornado.queues import Queue
 from tornado.tcpclient import TCPClient
 from tornado.util import _websocket_mask
 
+from typing import (TYPE_CHECKING, cast, Any, Optional, Dict, Union, List, Awaitable,
+                    Callable, Generator, Tuple, Type)
+from types import TracebackType
+if TYPE_CHECKING:
+    from tornado.iostream import IOStream  # noqa: F401
+    from typing_extensions import Protocol
+
+    # The zlib compressor types aren't actually exposed anywhere
+    # publicly, so declare protocols for the portions we use.
+    class _Compressor(Protocol):
+        def compress(self, data: bytes) -> bytes:
+            pass
+
+        def flush(self, mode: int) -> bytes:
+            pass
+
+    class _Decompressor(Protocol):
+        unconsumed_tail = b''  # type: bytes
+
+        def decompress(self, data: bytes, max_length: int) -> bytes:
+            pass
+
+    class _WebSocketConnection(Protocol):
+        # The common base interface implemented by WebSocketHandler on
+        # the server side and WebSocketClientConnection on the client
+        # side.
+        @property
+        def stream(self) -> Optional[IOStream]:
+            pass
+
+        @property
+        def ping_interval(self) -> Optional[float]:
+            pass
+
+        @property
+        def ping_timeout(self) -> Optional[float]:
+            pass
+
+        @property
+        def max_message_size(self) -> int:
+            pass
+
+        @property
+        def close_code(self) -> Optional[int]:
+            pass
+
+        @close_code.setter
+        def close_code(self, value: Optional[int]) -> None:
+            pass
+
+        @property
+        def close_reason(self) -> Optional[str]:
+            pass
+
+        @close_reason.setter
+        def close_reason(self, value: Optional[str]) -> None:
+            pass
+
+        def on_message(self, message: Union[str, bytes]) -> Optional['Awaitable[None]']:
+            pass
+
+        def on_ping(self, data: bytes) -> None:
+            pass
+
+        def on_pong(self, data: bytes) -> None:
+            pass
+
+        def log_exception(self, typ: Optional[Type[BaseException]],
+                          value: Optional[BaseException],
+                          tb: Optional[TracebackType]) -> None:
+            pass
+
+
 _default_max_message_size = 10 * 1024 * 1024
 
 
@@ -137,15 +211,16 @@ class WebSocketHandler(tornado.web.RequestHandler):
        Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
        ``websocket_max_message_size``.
     """
-    def __init__(self, application, request, **kwargs):
+    def __init__(self, application: tornado.web.Application, request: httputil.HTTPServerRequest,
+                 **kwargs: Any) -> None:
         super(WebSocketHandler, self).__init__(application, request, **kwargs)
-        self.ws_connection = None
-        self.close_code = None
-        self.close_reason = None
-        self.stream = None
+        self.ws_connection = None  # type: Optional[WebSocketProtocol]
+        self.close_code = None  # type: Optional[int]
+        self.close_reason = None  # type: Optional[str]
+        self.stream = None  # type: Optional[IOStream]
         self._on_close_called = False
 
-    def get(self, *args, **kwargs):
+    def get(self, *args: Any, **kwargs: Any) -> None:
         self.open_args = args
         self.open_kwargs = kwargs
 
@@ -191,7 +266,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
 
         self.ws_connection = self.get_websocket_protocol()
         if self.ws_connection:
-            self.ws_connection.accept_connection()
+            self.ws_connection.accept_connection(self)
         else:
             self.set_status(426, "Upgrade Required")
             self.set_header("Sec-WebSocket-Version", "7, 8, 13")
@@ -200,7 +275,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
     stream = None
 
     @property
-    def ping_interval(self):
+    def ping_interval(self) -> Optional[float]:
         """The interval for websocket keep-alive pings.
 
         Set websocket_ping_interval = 0 to disable pings.
@@ -208,7 +283,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         return self.settings.get('websocket_ping_interval', None)
 
     @property
-    def ping_timeout(self):
+    def ping_timeout(self) -> Optional[float]:
         """If no ping is received in this many seconds,
         close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
         Default is max of 3 pings or 30 seconds.
@@ -216,7 +291,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         return self.settings.get('websocket_ping_timeout', None)
 
     @property
-    def max_message_size(self):
+    def max_message_size(self) -> int:
         """Maximum allowed message size.
 
         If the remote peer sends a message larger than this, the connection
@@ -226,7 +301,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         return self.settings.get('websocket_max_message_size', _default_max_message_size)
 
-    def write_message(self, message, binary=False):
+    def write_message(self, message: Union[bytes, str, Dict[str, Any]],
+                      binary: bool=False) -> 'Future[None]':
         """Sends the given message to the client of this Web Socket.
 
         The message may be either a string or a dict (which will be
@@ -254,7 +330,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
             message = tornado.escape.json_encode(message)
         return self.ws_connection.write_message(message, binary=binary)
 
-    def select_subprotocol(self, subprotocols):
+    def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
         """Override to implement subprotocol negotiation.
 
         ``subprotocols`` is a list of strings identifying the
@@ -280,14 +356,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
         return None
 
     @property
-    def selected_subprotocol(self):
+    def selected_subprotocol(self) -> Optional[str]:
         """The subprotocol returned by `select_subprotocol`.
 
         .. versionadded:: 5.1
         """
+        assert self.ws_connection is not None
         return self.ws_connection.selected_subprotocol
 
-    def get_compression_options(self):
+    def get_compression_options(self) -> Optional[Dict[str, Any]]:
         """Override to return compression options for the connection.
 
         If this method returns None (the default), compression will
@@ -311,7 +388,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         # TODO: Add wbits option.
         return None
 
-    def open(self, *args, **kwargs):
+    def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
         """Invoked when a new WebSocket is opened.
 
         The arguments to `open` are extracted from the `tornado.web.URLSpec`
@@ -327,7 +404,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         pass
 
-    def on_message(self, message):
+    def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
         """Handle incoming messages on the WebSocket
 
         This method must be overridden.
@@ -338,7 +415,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         raise NotImplementedError
 
-    def ping(self, data=b''):
+    def ping(self, data: Union[str, bytes]=b'') -> None:
         """Send ping frame to the remote end.
 
         The data argument allows a small amount of data (up to 125
@@ -359,15 +436,15 @@ class WebSocketHandler(tornado.web.RequestHandler):
             raise WebSocketClosedError()
         self.ws_connection.write_ping(data)
 
-    def on_pong(self, data):
+    def on_pong(self, data: bytes) -> None:
         """Invoked when the response to a ping frame is received."""
         pass
 
-    def on_ping(self, data):
+    def on_ping(self, data: bytes) -> None:
         """Invoked when the a ping frame is received."""
         pass
 
-    def on_close(self):
+    def on_close(self) -> None:
         """Invoked when the WebSocket is closed.
 
         If the connection was closed cleanly and a status code or reason
@@ -380,7 +457,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         pass
 
-    def close(self, code=None, reason=None):
+    def close(self, code: int=None, reason: str=None) -> None:
         """Closes this Web Socket.
 
         Once the close handshake is successful the socket will be closed.
@@ -400,7 +477,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.ws_connection.close(code, reason)
             self.ws_connection = None
 
-    def check_origin(self, origin):
+    def check_origin(self, origin: str) -> bool:
         """Override to enable support for allowing alternate origins.
 
         The ``origin`` argument is the value of the ``Origin`` HTTP
@@ -456,7 +533,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         # Check to see that origin matches host directly, including ports
         return origin == host
 
-    def set_nodelay(self, value):
+    def set_nodelay(self, value: bool) -> None:
         """Set the no-delay flag for this stream.
 
         By default, small messages may be delayed and/or combined to minimize
@@ -470,9 +547,10 @@ class WebSocketHandler(tornado.web.RequestHandler):
 
         .. versionadded:: 3.1
         """
+        assert self.stream is not None
         self.stream.set_nodelay(value)
 
-    def on_connection_close(self):
+    def on_connection_close(self) -> None:
         if self.ws_connection:
             self.ws_connection.on_connection_close()
             self.ws_connection = None
@@ -481,7 +559,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.on_close()
             self._break_cycles()
 
-    def _break_cycles(self):
+    def _break_cycles(self) -> None:
         # WebSocketHandlers call finish() early, but we don't want to
         # break up reference cycles (which makes it impossible to call
         # self.render_string) until after we've really closed the
@@ -490,7 +568,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
         if self.get_status() != 101 or self._on_close_called:
             super(WebSocketHandler, self)._break_cycles()
 
-    def send_error(self, *args, **kwargs):
+    def send_error(self, *args: Any, **kwargs: Any) -> None:
         if self.stream is None:
             super(WebSocketHandler, self).send_error(*args, **kwargs)
         else:
@@ -500,13 +578,14 @@ class WebSocketHandler(tornado.web.RequestHandler):
             # we can close the connection more gracefully.
             self.stream.close()
 
-    def get_websocket_protocol(self):
+    def get_websocket_protocol(self) -> Optional['WebSocketProtocol']:
         websocket_version = self.request.headers.get("Sec-WebSocket-Version")
         if websocket_version in ("7", "8", "13"):
             return WebSocketProtocol13(
                 self, compression_options=self.get_compression_options())
+        return None
 
-    def _attach_stream(self):
+    def _attach_stream(self) -> None:
         self.stream = self.detach()
         self.stream.set_close_callback(self.on_connection_close)
         # disable non-WS methods
@@ -515,21 +594,21 @@ class WebSocketHandler(tornado.web.RequestHandler):
             setattr(self, method, _raise_not_supported_for_websockets)
 
 
-def _raise_not_supported_for_websockets(*args, **kwargs):
+def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
     raise RuntimeError("Method not supported for Web Sockets")
 
 
-class WebSocketProtocol(object):
+class WebSocketProtocol(abc.ABC):
     """Base class for WebSocket protocol versions.
     """
-    def __init__(self, handler):
+    def __init__(self, handler: '_WebSocketConnection') -> None:
         self.handler = handler
-        self.request = handler.request
         self.stream = handler.stream
         self.client_terminated = False
         self.server_terminated = False
 
-    def _run_callback(self, callback, *args, **kwargs):
+    def _run_callback(self, callback: Callable,
+                      *args: Any, **kwargs: Any) -> Optional['Future[Any]']:
         """Runs the given callback with exception handling.
 
         If the callback is a coroutine, returns its Future. On error, aborts the
@@ -540,25 +619,71 @@ class WebSocketProtocol(object):
         except Exception:
             self.handler.log_exception(*sys.exc_info())
             self._abort()
+            return None
         else:
             if result is not None:
                 result = gen.convert_yielded(result)
+                assert self.stream is not None
                 self.stream.io_loop.add_future(result, lambda f: f.result())
             return result
 
-    def on_connection_close(self):
+    def on_connection_close(self) -> None:
         self._abort()
 
-    def _abort(self):
+    def _abort(self) -> None:
         """Instantly aborts the WebSocket connection by closing the socket"""
         self.client_terminated = True
         self.server_terminated = True
-        self.stream.close()  # forcibly tear down the connection
+        if self.stream is not None:
+            self.stream.close()  # forcibly tear down the connection
         self.close()  # let the subclass cleanup
 
+    @abc.abstractmethod
+    def close(self, code: int=None, reason: str=None) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def is_closing(self) -> bool:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def accept_connection(self, handler: WebSocketHandler) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
+        raise NotImplementedError()
+
+    @property
+    @abc.abstractmethod
+    def selected_subprotocol(self) -> Optional[str]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def write_ping(self, data: bytes) -> None:
+        raise NotImplementedError()
+
+    # The entry points below are used by WebSocketClientConnection,
+    # which was introduced after we only supported a single version of
+    # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13
+    # boundary is currently pretty ad-hoc.
+    @abc.abstractmethod
+    def _process_server_headers(self, key: Union[str, bytes],
+                                headers: httputil.HTTPHeaders) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def start_pinging(self) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _receive_frame_loop(self) -> 'Future[None]':
+        raise NotImplementedError()
+
 
 class _PerMessageDeflateCompressor(object):
-    def __init__(self, persistent, max_wbits, compression_options=None):
+    def __init__(self, persistent: bool, max_wbits: Optional[int],
+                 compression_options: Dict[str, Any]=None) -> None:
         if max_wbits is None:
             max_wbits = zlib.MAX_WBITS
         # There is no symbolic constant for the minimum wbits value.
@@ -578,15 +703,15 @@ class _PerMessageDeflateCompressor(object):
             self._mem_level = compression_options['mem_level']
 
         if persistent:
-            self._compressor = self._create_compressor()
+            self._compressor = self._create_compressor()  # type: Optional[_Compressor]
         else:
             self._compressor = None
 
-    def _create_compressor(self):
+    def _create_compressor(self) -> '_Compressor':
         return zlib.compressobj(self._compression_level,
                                 zlib.DEFLATED, -self._max_wbits, self._mem_level)
 
-    def compress(self, data):
+    def compress(self, data: bytes) -> bytes:
         compressor = self._compressor or self._create_compressor()
         data = (compressor.compress(data) +
                 compressor.flush(zlib.Z_SYNC_FLUSH))
@@ -595,7 +720,8 @@ class _PerMessageDeflateCompressor(object):
 
 
 class _PerMessageDeflateDecompressor(object):
-    def __init__(self, persistent, max_wbits, max_message_size, compression_options=None):
+    def __init__(self, persistent: bool, max_wbits: Optional[int], max_message_size: int,
+                 compression_options: Dict[str, Any]=None) -> None:
         self._max_message_size = max_message_size
         if max_wbits is None:
             max_wbits = zlib.MAX_WBITS
@@ -604,14 +730,14 @@ class _PerMessageDeflateDecompressor(object):
                              max_wbits, zlib.MAX_WBITS)
         self._max_wbits = max_wbits
         if persistent:
-            self._decompressor = self._create_decompressor()
+            self._decompressor = self._create_decompressor()  # type: Optional[_Decompressor]
         else:
             self._decompressor = None
 
-    def _create_decompressor(self):
+    def _create_decompressor(self) -> '_Decompressor':
         return zlib.decompressobj(-self._max_wbits)
 
-    def decompress(self, data):
+    def decompress(self, data: bytes) -> bytes:
         decompressor = self._decompressor or self._create_decompressor()
         result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
         if decompressor.unconsumed_tail:
@@ -633,22 +759,24 @@ class WebSocketProtocol13(WebSocketProtocol):
     RSV_MASK = RSV1 | RSV2 | RSV3
     OPCODE_MASK = 0x0f
 
-    def __init__(self, handler, mask_outgoing=False,
-                 compression_options=None):
+    stream = None  # type: IOStream
+
+    def __init__(self, handler: '_WebSocketConnection', mask_outgoing: bool=False,
+                 compression_options: Dict[str, Any]=None) -> None:
         WebSocketProtocol.__init__(self, handler)
         self.mask_outgoing = mask_outgoing
         self._final_frame = False
         self._frame_opcode = None
         self._masked_frame = None
-        self._frame_mask = None
+        self._frame_mask = None  # type: Optional[bytes]
         self._frame_length = None
         self._fragmented_message_buffer = None
         self._fragmented_message_opcode = None
-        self._waiting = None
+        self._waiting = None  # type: object
         self._compression_options = compression_options
-        self._decompressor = None
-        self._compressor = None
-        self._frame_compressed = None
+        self._decompressor = None  # type: Optional[_PerMessageDeflateDecompressor]
+        self._compressor = None  # type: Optional[_PerMessageDeflateCompressor]
+        self._frame_compressed = None  # type: Optional[bool]
         # The total uncompressed size of all messages received or sent.
         # Unicode messages are encoded to utf8.
         # Only for testing; subject to change.
@@ -658,40 +786,49 @@ class WebSocketProtocol13(WebSocketProtocol):
         # the effect of compression, frame overhead, and control frames.
         self._wire_bytes_in = 0
         self._wire_bytes_out = 0
-        self.ping_callback = None
-        self.last_ping = 0
-        self.last_pong = 0
+        self.ping_callback = None  # type: Optional[PeriodicCallback]
+        self.last_ping = 0.0
+        self.last_pong = 0.0
 
-    def accept_connection(self):
+    # Use a property for this to satisfy the abc.
+    @property
+    def selected_subprotocol(self) -> Optional[str]:
+        return self._selected_subprotocol
+
+    @selected_subprotocol.setter
+    def selected_subprotocol(self, value: Optional[str]) -> None:
+        self._selected_subprotocol = value
+
+    def accept_connection(self, handler: WebSocketHandler) -> None:
         try:
-            self._handle_websocket_headers()
+            self._handle_websocket_headers(handler)
         except ValueError:
-            self.handler.set_status(400)
+            handler.set_status(400)
             log_msg = "Missing/Invalid WebSocket headers"
-            self.handler.finish(log_msg)
+            handler.finish(log_msg)
             gen_log.debug(log_msg)
             return
 
         try:
-            self._accept_connection()
+            self._accept_connection(handler)
         except ValueError:
             gen_log.debug("Malformed WebSocket request received",
                           exc_info=True)
             self._abort()
             return
 
-    def _handle_websocket_headers(self):
+    def _handle_websocket_headers(self, handler: WebSocketHandler) -> None:
         """Verifies all invariant- and required headers
 
         If a header is missing or have an incorrect value ValueError will be
         raised
         """
         fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
-        if not all(map(lambda f: self.request.headers.get(f), fields)):
+        if not all(map(lambda f: handler.request.headers.get(f), fields)):
             raise ValueError("Missing/Invalid WebSocket headers")
 
     @staticmethod
-    def compute_accept_value(key):
+    def compute_accept_value(key: Union[str, bytes]) -> str:
         """Computes the value for the Sec-WebSocket-Accept header,
         given the value for Sec-WebSocket-Key.
         """
@@ -700,23 +837,23 @@ class WebSocketProtocol13(WebSocketProtocol):
         sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")  # Magic value
         return native_str(base64.b64encode(sha1.digest()))
 
-    def _challenge_response(self):
+    def _challenge_response(self, handler: WebSocketHandler) -> str:
         return WebSocketProtocol13.compute_accept_value(
-            self.request.headers.get("Sec-Websocket-Key"))
+            cast(str, handler.request.headers.get("Sec-Websocket-Key")))
 
     @gen.coroutine
-    def _accept_connection(self):
-        subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol")
+    def _accept_connection(self, handler: WebSocketHandler) -> Generator[Any, Any, None]:
+        subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
         if subprotocol_header:
             subprotocols = [s.strip() for s in subprotocol_header.split(',')]
         else:
             subprotocols = []
-        self.selected_subprotocol = self.handler.select_subprotocol(subprotocols)
+        self.selected_subprotocol = handler.select_subprotocol(subprotocols)
         if self.selected_subprotocol:
             assert self.selected_subprotocol in subprotocols
-            self.handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)
+            handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)
 
-        extensions = self._parse_extensions_header(self.request.headers)
+        extensions = self._parse_extensions_header(handler.request.headers)
         for ext in extensions:
             if (ext[0] == 'permessage-deflate' and
                     self._compression_options is not None):
@@ -728,36 +865,40 @@ class WebSocketProtocol13(WebSocketProtocol):
                     # Don't echo an offered client_max_window_bits
                     # parameter with no value.
                     del ext[1]['client_max_window_bits']
-                self.handler.set_header("Sec-WebSocket-Extensions",
-                                        httputil._encode_header(
-                                            'permessage-deflate', ext[1]))
+                handler.set_header("Sec-WebSocket-Extensions",
+                                   httputil._encode_header(
+                                       'permessage-deflate', ext[1]))
                 break
 
-        self.handler.clear_header("Content-Type")
-        self.handler.set_status(101)
-        self.handler.set_header("Upgrade", "websocket")
-        self.handler.set_header("Connection", "Upgrade")
-        self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response())
-        self.handler.finish()
+        handler.clear_header("Content-Type")
+        handler.set_status(101)
+        handler.set_header("Upgrade", "websocket")
+        handler.set_header("Connection", "Upgrade")
+        handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler))
+        handler.finish()
 
-        self.handler._attach_stream()
-        self.stream = self.handler.stream
+        handler._attach_stream()
+        assert handler.stream is not None
+        self.stream = handler.stream
 
         self.start_pinging()
-        open_result = self._run_callback(self.handler.open, *self.handler.open_args,
-                                         **self.handler.open_kwargs)
+        open_result = self._run_callback(handler.open, *handler.open_args,
+                                         **handler.open_kwargs)
         if open_result is not None:
             yield open_result
         yield self._receive_frame_loop()
 
-    def _parse_extensions_header(self, headers):
+    def _parse_extensions_header(
+            self, headers: httputil.HTTPHeaders
+    ) -> List[Tuple[str, Dict[str, str]]]:
         extensions = headers.get("Sec-WebSocket-Extensions", '')
         if extensions:
             return [httputil._parse_header(e.strip())
                     for e in extensions.split(',')]
         return []
 
-    def _process_server_headers(self, key, headers):
+    def _process_server_headers(self, key: Union[str, bytes],
+                                headers: httputil.HTTPHeaders) -> None:
         """Process the headers sent by the server to this client connection.
 
         'key' is the websocket handshake challenge/response key.
@@ -777,12 +918,13 @@ class WebSocketProtocol13(WebSocketProtocol):
 
         self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None)
 
-    def _get_compressor_options(self, side, agreed_parameters, compression_options=None):
+    def _get_compressor_options(self, side: str, agreed_parameters: Dict[str, Any],
+                                compression_options: Dict[str, Any]=None) -> Dict[str, Any]:
         """Converts a websocket agreed_parameters set to keyword arguments
         for our compressor objects.
         """
-        options = dict(
-            persistent=(side + '_no_context_takeover') not in agreed_parameters)
+        options = dict(persistent=(side + '_no_context_takeover') not in agreed_parameters) \
+            # type: Dict[str, Any]
         wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
         if wbits_header is None:
             options['max_wbits'] = zlib.MAX_WBITS
@@ -791,7 +933,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         options['compression_options'] = compression_options
         return options
 
-    def _create_compressors(self, side, agreed_parameters, compression_options=None):
+    def _create_compressors(self, side: str, agreed_parameters: Dict[str, Any],
+                            compression_options: Dict[str, Any]=None) -> None:
         # TODO: handle invalid parameters gracefully
         allowed_keys = set(['server_no_context_takeover',
                             'client_no_context_takeover',
@@ -807,7 +950,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             max_message_size=self.handler.max_message_size,
             **self._get_compressor_options(other_side, agreed_parameters, compression_options))
 
-    def _write_frame(self, fin, opcode, data, flags=0):
+    def _write_frame(self, fin: bool, opcode: int, data: bytes, flags: int=0) -> 'Future[None]':
         data_len = len(data)
         if opcode & 0x8:
             # All control frames MUST have a payload length of 125
@@ -838,7 +981,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._wire_bytes_out += len(frame)
         return self.stream.write(frame)
 
-    def write_message(self, message, binary=False):
+    def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
         """Sends the given message to the client of this Web Socket."""
         if binary:
             opcode = 0x2
@@ -862,32 +1005,32 @@ class WebSocketProtocol13(WebSocketProtocol):
             raise WebSocketClosedError()
 
         @gen.coroutine
-        def wrapper():
+        def wrapper() -> Generator[Any, Any, None]:
             try:
                 yield fut
             except StreamClosedError:
                 raise WebSocketClosedError()
         return wrapper()
 
-    def write_ping(self, data):
+    def write_ping(self, data: bytes) -> None:
         """Send ping frame."""
         assert isinstance(data, bytes)
         self._write_frame(True, 0x9, data)
 
     @gen.coroutine
-    def _receive_frame_loop(self):
+    def _receive_frame_loop(self) -> Generator[Any, Any, None]:
         try:
             while not self.client_terminated:
                 yield self._receive_frame()
         except StreamClosedError:
             self._abort()
 
-    def _read_bytes(self, n):
+    def _read_bytes(self, n: int) -> Awaitable[bytes]:
         self._wire_bytes_in += n
         return self.stream.read_bytes(n)
 
     @gen.coroutine
-    def _receive_frame(self):
+    def _receive_frame(self) -> Generator[Any, Any, None]:
         # Read the frame header.
         data = yield self._read_bytes(2)
         header, mask_payloadlen = struct.unpack("BB", data)
@@ -934,6 +1077,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._frame_mask = yield self._read_bytes(4)
         data = yield self._read_bytes(payloadlen)
         if is_masked:
+            assert self._frame_mask is not None
             data = _websocket_mask(self._frame_mask, data)
 
         # Decide what to do with this frame.
@@ -969,18 +1113,19 @@ class WebSocketProtocol13(WebSocketProtocol):
             if handled_future is not None:
                 yield handled_future
 
-    def _handle_message(self, opcode, data):
+    def _handle_message(self, opcode: int, data: bytes) -> Optional['Future[None]']:
         """Execute on_message, returning its Future if it is a coroutine."""
         if self.client_terminated:
-            return
+            return None
 
         if self._frame_compressed:
+            assert self._decompressor is not None
             try:
                 data = self._decompressor.decompress(data)
             except _DecompressTooLargeError:
                 self.close(1009, "message too big after decompression")
                 self._abort()
-                return
+                return None
 
         if opcode == 0x1:
             # UTF-8 data
@@ -989,7 +1134,7 @@ class WebSocketProtocol13(WebSocketProtocol):
                 decoded = data.decode("utf-8")
             except UnicodeDecodeError:
                 self._abort()
-                return
+                return None
             return self._run_callback(self.handler.on_message, decoded)
         elif opcode == 0x2:
             # Binary data
@@ -1017,8 +1162,9 @@ class WebSocketProtocol13(WebSocketProtocol):
             return self._run_callback(self.handler.on_pong, data)
         else:
             self._abort()
+        return None
 
-    def close(self, code=None, reason=None):
+    def close(self, code: int=None, reason: str=None) -> None:
         """Closes the WebSocket connection."""
         if not self.server_terminated:
             if not self.stream.closed():
@@ -1046,7 +1192,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._waiting = self.stream.io_loop.add_timeout(
                 self.stream.io_loop.time() + 5, self._abort)
 
-    def is_closing(self):
+    def is_closing(self) -> bool:
         """Return true if this connection is closing.
 
         The connection is considered closing if either side has
@@ -1058,28 +1204,30 @@ class WebSocketProtocol13(WebSocketProtocol):
                 self.server_terminated)
 
     @property
-    def ping_interval(self):
+    def ping_interval(self) -> Optional[float]:
         interval = self.handler.ping_interval
         if interval is not None:
             return interval
         return 0
 
     @property
-    def ping_timeout(self):
+    def ping_timeout(self) -> Optional[float]:
         timeout = self.handler.ping_timeout
         if timeout is not None:
             return timeout
+        assert self.ping_interval is not None
         return max(3 * self.ping_interval, 30)
 
-    def start_pinging(self):
+    def start_pinging(self) -> None:
         """Start sending periodic pings to keep the connection alive"""
+        assert self.ping_interval is not None
         if self.ping_interval > 0:
             self.last_ping = self.last_pong = IOLoop.current().time()
             self.ping_callback = PeriodicCallback(
                 self.periodic_ping, self.ping_interval * 1000)
             self.ping_callback.start()
 
-    def periodic_ping(self):
+    def periodic_ping(self) -> None:
         """Send a ping to keep the websocket alive
 
         Called periodically if the websocket_ping_interval is set and non-zero.
@@ -1094,6 +1242,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         now = IOLoop.current().time()
         since_last_pong = now - self.last_pong
         since_last_ping = now - self.last_ping
+        assert self.ping_interval is not None
+        assert self.ping_timeout is not None
         if (since_last_ping < 2 * self.ping_interval and
                 since_last_pong > self.ping_timeout):
             self.close()
@@ -1109,16 +1259,21 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     This class should not be instantiated directly; use the
     `websocket_connect` function instead.
     """
-    def __init__(self, request, on_message_callback=None,
-                 compression_options=None, ping_interval=None, ping_timeout=None,
-                 max_message_size=None, subprotocols=[]):
+    protocol = None  # type: WebSocketProtocol
+
+    def __init__(self, request: httpclient.HTTPRequest,
+                 on_message_callback: Callable[[Union[None, str, bytes]], None]=None,
+                 compression_options: Dict[str, Any]=None,
+                 ping_interval: float=None, ping_timeout: float=None,
+                 max_message_size: int=_default_max_message_size,
+                 subprotocols: Optional[List[str]]=[]) -> None:
         self.compression_options = compression_options
-        self.connect_future = Future()
-        self.protocol = None
-        self.read_queue = Queue(1)
+        self.connect_future = Future()  # type: Future[WebSocketClientConnection]
+        self.read_queue = Queue(1)  # type: Queue[Union[None, str, bytes]]
         self.key = base64.b64encode(os.urandom(16))
         self._on_message_callback = on_message_callback
-        self.close_code = self.close_reason = None
+        self.close_code = None  # type: Optional[int]
+        self.close_reason = None  # type: Optional[str]
         self.ping_interval = ping_interval
         self.ping_timeout = ping_timeout
         self.max_message_size = max_message_size
@@ -1148,7 +1303,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             None, request, lambda: None, self._on_http_response,
             104857600, self.tcp_client, 65536, 104857600)
 
-    def close(self, code=None, reason=None):
+    def close(self, code: int=None, reason: str=None) -> None:
         """Closes the websocket connection.
 
         ``code`` and ``reason`` are documented under
@@ -1162,16 +1317,16 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         """
         if self.protocol is not None:
             self.protocol.close(code, reason)
-            self.protocol = None
+            self.protocol = None  # type: ignore
 
-    def on_connection_close(self):
+    def on_connection_close(self) -> None:
         if not self.connect_future.done():
             self.connect_future.set_exception(StreamClosedError())
-        self.on_message(None)
+        self._on_message(None)
         self.tcp_client.close()
         super(WebSocketClientConnection, self).on_connection_close()
 
-    def _on_http_response(self, response):
+    def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
         if not self.connect_future.done():
             if response.error:
                 self.connect_future.set_exception(response.error)
@@ -1179,7 +1334,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
                 self.connect_future.set_exception(WebSocketError(
                     "Non-websocket response"))
 
-    def headers_received(self, start_line, headers):
+    def headers_received(self, start_line: Union[httputil.RequestStartLine,
+                                                 httputil.ResponseStartLine],
+                         headers: httputil.HTTPHeaders) -> None:
+        assert isinstance(start_line, httputil.ResponseStartLine)
         if start_line.code != 101:
             return super(WebSocketClientConnection, self).headers_received(
                 start_line, headers)
@@ -1200,11 +1358,11 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         # we set on the http request.  This deactivates the error handling
         # in simple_httpclient that would otherwise interfere with our
         # ability to see exceptions.
-        self.final_callback = None
+        self.final_callback = None  # type: ignore
 
         future_set_result_unless_cancelled(self.connect_future, self)
 
-    def write_message(self, message, binary=False):
+    def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
         """Sends a message to the WebSocket server.
 
         If the stream is closed, raises `WebSocketClosedError`.
@@ -1216,7 +1374,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         """
         return self.protocol.write_message(message, binary=binary)
 
-    def read_message(self, callback=None):
+    def read_message(
+            self, callback: Callable[['Future[Union[None, str, bytes]]'], None]=None
+    ) -> 'Future[Union[None, str, bytes]]':
         """Reads a message from the WebSocket server.
 
         If on_message_callback was specified at WebSocket
@@ -1233,13 +1393,17 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             self.io_loop.add_future(future, callback)
         return future
 
-    def on_message(self, message):
+    def on_message(self, message: Union[str, bytes]) -> Optional['Future[None]']:
+        return self._on_message(message)
+
+    def _on_message(self, message: Union[None, str, bytes]) -> Optional['Future[None]']:
         if self._on_message_callback:
             self._on_message_callback(message)
+            return None
         else:
             return self.read_queue.put(message)
 
-    def ping(self, data=b''):
+    def ping(self, data: bytes=b'') -> None:
         """Send ping frame to the remote end.
 
         The data argument allows a small amount of data (up to 125
@@ -1258,29 +1422,41 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             raise WebSocketClosedError()
         self.protocol.write_ping(data)
 
-    def on_pong(self, data):
+    def on_pong(self, data: bytes) -> None:
         pass
 
-    def on_ping(self, data):
+    def on_ping(self, data: bytes) -> None:
         pass
 
-    def get_websocket_protocol(self):
+    def get_websocket_protocol(self) -> WebSocketProtocol:
         return WebSocketProtocol13(self, mask_outgoing=True,
                                    compression_options=self.compression_options)
 
     @property
-    def selected_subprotocol(self):
+    def selected_subprotocol(self) -> Optional[str]:
         """The subprotocol selected by the server.
 
         .. versionadded:: 5.1
         """
         return self.protocol.selected_subprotocol
 
-
-def websocket_connect(url, callback=None, connect_timeout=None,
-                      on_message_callback=None, compression_options=None,
-                      ping_interval=None, ping_timeout=None,
-                      max_message_size=_default_max_message_size, subprotocols=None):
+    def log_exception(self, typ: Optional[Type[BaseException]],
+                      value: Optional[BaseException],
+                      tb: Optional[TracebackType]) -> None:
+        assert typ is not None
+        assert value is not None
+        app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))
+
+
+def websocket_connect(
+        url: Union[str, httpclient.HTTPRequest],
+        callback: Callable[['Future[WebSocketClientConnection]'], None]=None,
+        connect_timeout: float=None,
+        on_message_callback: Callable[[Union[None, str, bytes]], None]=None,
+        compression_options: Dict[str, Any]=None,
+        ping_interval: float=None, ping_timeout: float=None,
+        max_message_size: int=_default_max_message_size, subprotocols: List[str]=None
+) -> 'Future[WebSocketClientConnection]':
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a
@@ -1332,8 +1508,8 @@ def websocket_connect(url, callback=None, connect_timeout=None,
         request.headers = httputil.HTTPHeaders(request.headers)
     else:
         request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
-    request = httpclient._RequestProxy(
-        request, httpclient.HTTPRequest._DEFAULTS)
+    request = cast(httpclient.HTTPRequest, httpclient._RequestProxy(
+        request, httpclient.HTTPRequest._DEFAULTS))
     conn = WebSocketClientConnection(request,
                                      on_message_callback=on_message_callback,
                                      compression_options=compression_options,