From: Ben Darnell Date: Mon, 1 Oct 2018 02:20:00 +0000 (-0400) Subject: websocket: Add type annotations X-Git-Tag: v6.0.0b1~28^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=efcb7d83c0c362b3079b56db256b0e0a67eed6ad;p=thirdparty%2Ftornado.git websocket: Add type annotations This is more invasive than usual because it defines a Protocol to be shared between WebSocketHandler and WebSocketClientConnection (and fixes a bug: an uncaught exception in the callback mode of WebSocketClientConnection would fail due to the missing log_exception method). Fixes #2181 --- diff --git a/setup.cfg b/setup.cfg index dfe5d3ad9..d24be9b72 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,9 +7,6 @@ python_version = 3.5 [mypy-tornado.*,tornado.platform.*] disallow_untyped_defs = True -[mypy-tornado.websocket] -disallow_untyped_defs = False - # It's generally too tedious to require type annotations in tests, but # we do want to type check them as much as type inference allows. [mypy-tornado.test.*] diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 473dd3b5e..15ab6e10c 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -208,7 +208,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): class _HTTPConnection(httputil.HTTPMessageDelegate): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) - def __init__(self, client: SimpleAsyncHTTPClient, request: HTTPRequest, + def __init__(self, client: Optional[SimpleAsyncHTTPClient], request: HTTPRequest, release_callback: Callable[[], None], final_callback: Callable[[HTTPResponse], None], max_buffer_size: int, tcp_client: TCPClient, max_header_size: int, max_body_size: int) -> None: diff --git a/tornado/websocket.py b/tornado/websocket.py index d2994c1c9..6600d8ca4 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -16,6 +16,7 @@ the protocol (known as "draft 76") and are not compatible with this module. Removed support for the draft 76 protocol version. """ +import abc import base64 import hashlib import os @@ -31,12 +32,85 @@ from tornado.escape import utf8, native_str, to_unicode from tornado import gen, httpclient, httputil from tornado.ioloop import IOLoop, PeriodicCallback from tornado.iostream import StreamClosedError -from tornado.log import gen_log +from tornado.log import gen_log, app_log from tornado import simple_httpclient from tornado.queues import Queue from tornado.tcpclient import TCPClient from tornado.util import _websocket_mask +from typing import (TYPE_CHECKING, cast, Any, Optional, Dict, Union, List, Awaitable, + Callable, Generator, Tuple, Type) +from types import TracebackType +if TYPE_CHECKING: + from tornado.iostream import IOStream # noqa: F401 + from typing_extensions import Protocol + + # The zlib compressor types aren't actually exposed anywhere + # publicly, so declare protocols for the portions we use. + class _Compressor(Protocol): + def compress(self, data: bytes) -> bytes: + pass + + def flush(self, mode: int) -> bytes: + pass + + class _Decompressor(Protocol): + unconsumed_tail = b'' # type: bytes + + def decompress(self, data: bytes, max_length: int) -> bytes: + pass + + class _WebSocketConnection(Protocol): + # The common base interface implemented by WebSocketHandler on + # the server side and WebSocketClientConnection on the client + # side. + @property + def stream(self) -> Optional[IOStream]: + pass + + @property + def ping_interval(self) -> Optional[float]: + pass + + @property + def ping_timeout(self) -> Optional[float]: + pass + + @property + def max_message_size(self) -> int: + pass + + @property + def close_code(self) -> Optional[int]: + pass + + @close_code.setter + def close_code(self, value: Optional[int]) -> None: + pass + + @property + def close_reason(self) -> Optional[str]: + pass + + @close_reason.setter + def close_reason(self, value: Optional[str]) -> None: + pass + + def on_message(self, message: Union[str, bytes]) -> Optional['Awaitable[None]']: + pass + + def on_ping(self, data: bytes) -> None: + pass + + def on_pong(self, data: bytes) -> None: + pass + + def log_exception(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType]) -> None: + pass + + _default_max_message_size = 10 * 1024 * 1024 @@ -137,15 +211,16 @@ class WebSocketHandler(tornado.web.RequestHandler): Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and ``websocket_max_message_size``. """ - def __init__(self, application, request, **kwargs): + def __init__(self, application: tornado.web.Application, request: httputil.HTTPServerRequest, + **kwargs: Any) -> None: super(WebSocketHandler, self).__init__(application, request, **kwargs) - self.ws_connection = None - self.close_code = None - self.close_reason = None - self.stream = None + self.ws_connection = None # type: Optional[WebSocketProtocol] + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] + self.stream = None # type: Optional[IOStream] self._on_close_called = False - def get(self, *args, **kwargs): + def get(self, *args: Any, **kwargs: Any) -> None: self.open_args = args self.open_kwargs = kwargs @@ -191,7 +266,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection = self.get_websocket_protocol() if self.ws_connection: - self.ws_connection.accept_connection() + self.ws_connection.accept_connection(self) else: self.set_status(426, "Upgrade Required") self.set_header("Sec-WebSocket-Version", "7, 8, 13") @@ -200,7 +275,7 @@ class WebSocketHandler(tornado.web.RequestHandler): stream = None @property - def ping_interval(self): + def ping_interval(self) -> Optional[float]: """The interval for websocket keep-alive pings. Set websocket_ping_interval = 0 to disable pings. @@ -208,7 +283,7 @@ class WebSocketHandler(tornado.web.RequestHandler): return self.settings.get('websocket_ping_interval', None) @property - def ping_timeout(self): + def ping_timeout(self) -> Optional[float]: """If no ping is received in this many seconds, close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). Default is max of 3 pings or 30 seconds. @@ -216,7 +291,7 @@ class WebSocketHandler(tornado.web.RequestHandler): return self.settings.get('websocket_ping_timeout', None) @property - def max_message_size(self): + def max_message_size(self) -> int: """Maximum allowed message size. If the remote peer sends a message larger than this, the connection @@ -226,7 +301,8 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return self.settings.get('websocket_max_message_size', _default_max_message_size) - def write_message(self, message, binary=False): + def write_message(self, message: Union[bytes, str, Dict[str, Any]], + binary: bool=False) -> 'Future[None]': """Sends the given message to the client of this Web Socket. The message may be either a string or a dict (which will be @@ -254,7 +330,7 @@ class WebSocketHandler(tornado.web.RequestHandler): message = tornado.escape.json_encode(message) return self.ws_connection.write_message(message, binary=binary) - def select_subprotocol(self, subprotocols): + def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]: """Override to implement subprotocol negotiation. ``subprotocols`` is a list of strings identifying the @@ -280,14 +356,15 @@ class WebSocketHandler(tornado.web.RequestHandler): return None @property - def selected_subprotocol(self): + def selected_subprotocol(self) -> Optional[str]: """The subprotocol returned by `select_subprotocol`. .. versionadded:: 5.1 """ + assert self.ws_connection is not None return self.ws_connection.selected_subprotocol - def get_compression_options(self): + def get_compression_options(self) -> Optional[Dict[str, Any]]: """Override to return compression options for the connection. If this method returns None (the default), compression will @@ -311,7 +388,7 @@ class WebSocketHandler(tornado.web.RequestHandler): # TODO: Add wbits option. return None - def open(self, *args, **kwargs): + def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: """Invoked when a new WebSocket is opened. The arguments to `open` are extracted from the `tornado.web.URLSpec` @@ -327,7 +404,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ pass - def on_message(self, message): + def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]: """Handle incoming messages on the WebSocket This method must be overridden. @@ -338,7 +415,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ raise NotImplementedError - def ping(self, data=b''): + def ping(self, data: Union[str, bytes]=b'') -> None: """Send ping frame to the remote end. The data argument allows a small amount of data (up to 125 @@ -359,15 +436,15 @@ class WebSocketHandler(tornado.web.RequestHandler): raise WebSocketClosedError() self.ws_connection.write_ping(data) - def on_pong(self, data): + def on_pong(self, data: bytes) -> None: """Invoked when the response to a ping frame is received.""" pass - def on_ping(self, data): + def on_ping(self, data: bytes) -> None: """Invoked when the a ping frame is received.""" pass - def on_close(self): + def on_close(self) -> None: """Invoked when the WebSocket is closed. If the connection was closed cleanly and a status code or reason @@ -380,7 +457,7 @@ class WebSocketHandler(tornado.web.RequestHandler): """ pass - def close(self, code=None, reason=None): + def close(self, code: int=None, reason: str=None) -> None: """Closes this Web Socket. Once the close handshake is successful the socket will be closed. @@ -400,7 +477,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection.close(code, reason) self.ws_connection = None - def check_origin(self, origin): + def check_origin(self, origin: str) -> bool: """Override to enable support for allowing alternate origins. The ``origin`` argument is the value of the ``Origin`` HTTP @@ -456,7 +533,7 @@ class WebSocketHandler(tornado.web.RequestHandler): # Check to see that origin matches host directly, including ports return origin == host - def set_nodelay(self, value): + def set_nodelay(self, value: bool) -> None: """Set the no-delay flag for this stream. By default, small messages may be delayed and/or combined to minimize @@ -470,9 +547,10 @@ class WebSocketHandler(tornado.web.RequestHandler): .. versionadded:: 3.1 """ + assert self.stream is not None self.stream.set_nodelay(value) - def on_connection_close(self): + def on_connection_close(self) -> None: if self.ws_connection: self.ws_connection.on_connection_close() self.ws_connection = None @@ -481,7 +559,7 @@ class WebSocketHandler(tornado.web.RequestHandler): self.on_close() self._break_cycles() - def _break_cycles(self): + def _break_cycles(self) -> None: # WebSocketHandlers call finish() early, but we don't want to # break up reference cycles (which makes it impossible to call # self.render_string) until after we've really closed the @@ -490,7 +568,7 @@ class WebSocketHandler(tornado.web.RequestHandler): if self.get_status() != 101 or self._on_close_called: super(WebSocketHandler, self)._break_cycles() - def send_error(self, *args, **kwargs): + def send_error(self, *args: Any, **kwargs: Any) -> None: if self.stream is None: super(WebSocketHandler, self).send_error(*args, **kwargs) else: @@ -500,13 +578,14 @@ class WebSocketHandler(tornado.web.RequestHandler): # we can close the connection more gracefully. self.stream.close() - def get_websocket_protocol(self): + def get_websocket_protocol(self) -> Optional['WebSocketProtocol']: websocket_version = self.request.headers.get("Sec-WebSocket-Version") if websocket_version in ("7", "8", "13"): return WebSocketProtocol13( self, compression_options=self.get_compression_options()) + return None - def _attach_stream(self): + def _attach_stream(self) -> None: self.stream = self.detach() self.stream.set_close_callback(self.on_connection_close) # disable non-WS methods @@ -515,21 +594,21 @@ class WebSocketHandler(tornado.web.RequestHandler): setattr(self, method, _raise_not_supported_for_websockets) -def _raise_not_supported_for_websockets(*args, **kwargs): +def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None: raise RuntimeError("Method not supported for Web Sockets") -class WebSocketProtocol(object): +class WebSocketProtocol(abc.ABC): """Base class for WebSocket protocol versions. """ - def __init__(self, handler): + def __init__(self, handler: '_WebSocketConnection') -> None: self.handler = handler - self.request = handler.request self.stream = handler.stream self.client_terminated = False self.server_terminated = False - def _run_callback(self, callback, *args, **kwargs): + def _run_callback(self, callback: Callable, + *args: Any, **kwargs: Any) -> Optional['Future[Any]']: """Runs the given callback with exception handling. If the callback is a coroutine, returns its Future. On error, aborts the @@ -540,25 +619,71 @@ class WebSocketProtocol(object): except Exception: self.handler.log_exception(*sys.exc_info()) self._abort() + return None else: if result is not None: result = gen.convert_yielded(result) + assert self.stream is not None self.stream.io_loop.add_future(result, lambda f: f.result()) return result - def on_connection_close(self): + def on_connection_close(self) -> None: self._abort() - def _abort(self): + def _abort(self) -> None: """Instantly aborts the WebSocket connection by closing the socket""" self.client_terminated = True self.server_terminated = True - self.stream.close() # forcibly tear down the connection + if self.stream is not None: + self.stream.close() # forcibly tear down the connection self.close() # let the subclass cleanup + @abc.abstractmethod + def close(self, code: int=None, reason: str=None) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def is_closing(self) -> bool: + raise NotImplementedError() + + @abc.abstractmethod + def accept_connection(self, handler: WebSocketHandler) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]': + raise NotImplementedError() + + @property + @abc.abstractmethod + def selected_subprotocol(self) -> Optional[str]: + raise NotImplementedError() + + @abc.abstractmethod + def write_ping(self, data: bytes) -> None: + raise NotImplementedError() + + # The entry points below are used by WebSocketClientConnection, + # which was introduced after we only supported a single version of + # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13 + # boundary is currently pretty ad-hoc. + @abc.abstractmethod + def _process_server_headers(self, key: Union[str, bytes], + headers: httputil.HTTPHeaders) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def start_pinging(self) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def _receive_frame_loop(self) -> 'Future[None]': + raise NotImplementedError() + class _PerMessageDeflateCompressor(object): - def __init__(self, persistent, max_wbits, compression_options=None): + def __init__(self, persistent: bool, max_wbits: Optional[int], + compression_options: Dict[str, Any]=None) -> None: if max_wbits is None: max_wbits = zlib.MAX_WBITS # There is no symbolic constant for the minimum wbits value. @@ -578,15 +703,15 @@ class _PerMessageDeflateCompressor(object): self._mem_level = compression_options['mem_level'] if persistent: - self._compressor = self._create_compressor() + self._compressor = self._create_compressor() # type: Optional[_Compressor] else: self._compressor = None - def _create_compressor(self): + def _create_compressor(self) -> '_Compressor': return zlib.compressobj(self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level) - def compress(self, data): + def compress(self, data: bytes) -> bytes: compressor = self._compressor or self._create_compressor() data = (compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH)) @@ -595,7 +720,8 @@ class _PerMessageDeflateCompressor(object): class _PerMessageDeflateDecompressor(object): - def __init__(self, persistent, max_wbits, max_message_size, compression_options=None): + def __init__(self, persistent: bool, max_wbits: Optional[int], max_message_size: int, + compression_options: Dict[str, Any]=None) -> None: self._max_message_size = max_message_size if max_wbits is None: max_wbits = zlib.MAX_WBITS @@ -604,14 +730,14 @@ class _PerMessageDeflateDecompressor(object): max_wbits, zlib.MAX_WBITS) self._max_wbits = max_wbits if persistent: - self._decompressor = self._create_decompressor() + self._decompressor = self._create_decompressor() # type: Optional[_Decompressor] else: self._decompressor = None - def _create_decompressor(self): + def _create_decompressor(self) -> '_Decompressor': return zlib.decompressobj(-self._max_wbits) - def decompress(self, data): + def decompress(self, data: bytes) -> bytes: decompressor = self._decompressor or self._create_decompressor() result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size) if decompressor.unconsumed_tail: @@ -633,22 +759,24 @@ class WebSocketProtocol13(WebSocketProtocol): RSV_MASK = RSV1 | RSV2 | RSV3 OPCODE_MASK = 0x0f - def __init__(self, handler, mask_outgoing=False, - compression_options=None): + stream = None # type: IOStream + + def __init__(self, handler: '_WebSocketConnection', mask_outgoing: bool=False, + compression_options: Dict[str, Any]=None) -> None: WebSocketProtocol.__init__(self, handler) self.mask_outgoing = mask_outgoing self._final_frame = False self._frame_opcode = None self._masked_frame = None - self._frame_mask = None + self._frame_mask = None # type: Optional[bytes] self._frame_length = None self._fragmented_message_buffer = None self._fragmented_message_opcode = None - self._waiting = None + self._waiting = None # type: object self._compression_options = compression_options - self._decompressor = None - self._compressor = None - self._frame_compressed = None + self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor] + self._compressor = None # type: Optional[_PerMessageDeflateCompressor] + self._frame_compressed = None # type: Optional[bool] # The total uncompressed size of all messages received or sent. # Unicode messages are encoded to utf8. # Only for testing; subject to change. @@ -658,40 +786,49 @@ class WebSocketProtocol13(WebSocketProtocol): # the effect of compression, frame overhead, and control frames. self._wire_bytes_in = 0 self._wire_bytes_out = 0 - self.ping_callback = None - self.last_ping = 0 - self.last_pong = 0 + self.ping_callback = None # type: Optional[PeriodicCallback] + self.last_ping = 0.0 + self.last_pong = 0.0 - def accept_connection(self): + # Use a property for this to satisfy the abc. + @property + def selected_subprotocol(self) -> Optional[str]: + return self._selected_subprotocol + + @selected_subprotocol.setter + def selected_subprotocol(self, value: Optional[str]) -> None: + self._selected_subprotocol = value + + def accept_connection(self, handler: WebSocketHandler) -> None: try: - self._handle_websocket_headers() + self._handle_websocket_headers(handler) except ValueError: - self.handler.set_status(400) + handler.set_status(400) log_msg = "Missing/Invalid WebSocket headers" - self.handler.finish(log_msg) + handler.finish(log_msg) gen_log.debug(log_msg) return try: - self._accept_connection() + self._accept_connection(handler) except ValueError: gen_log.debug("Malformed WebSocket request received", exc_info=True) self._abort() return - def _handle_websocket_headers(self): + def _handle_websocket_headers(self, handler: WebSocketHandler) -> None: """Verifies all invariant- and required headers If a header is missing or have an incorrect value ValueError will be raised """ fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version") - if not all(map(lambda f: self.request.headers.get(f), fields)): + if not all(map(lambda f: handler.request.headers.get(f), fields)): raise ValueError("Missing/Invalid WebSocket headers") @staticmethod - def compute_accept_value(key): + def compute_accept_value(key: Union[str, bytes]) -> str: """Computes the value for the Sec-WebSocket-Accept header, given the value for Sec-WebSocket-Key. """ @@ -700,23 +837,23 @@ class WebSocketProtocol13(WebSocketProtocol): sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value return native_str(base64.b64encode(sha1.digest())) - def _challenge_response(self): + def _challenge_response(self, handler: WebSocketHandler) -> str: return WebSocketProtocol13.compute_accept_value( - self.request.headers.get("Sec-Websocket-Key")) + cast(str, handler.request.headers.get("Sec-Websocket-Key"))) @gen.coroutine - def _accept_connection(self): - subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol") + def _accept_connection(self, handler: WebSocketHandler) -> Generator[Any, Any, None]: + subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") if subprotocol_header: subprotocols = [s.strip() for s in subprotocol_header.split(',')] else: subprotocols = [] - self.selected_subprotocol = self.handler.select_subprotocol(subprotocols) + self.selected_subprotocol = handler.select_subprotocol(subprotocols) if self.selected_subprotocol: assert self.selected_subprotocol in subprotocols - self.handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) + handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) - extensions = self._parse_extensions_header(self.request.headers) + extensions = self._parse_extensions_header(handler.request.headers) for ext in extensions: if (ext[0] == 'permessage-deflate' and self._compression_options is not None): @@ -728,36 +865,40 @@ class WebSocketProtocol13(WebSocketProtocol): # Don't echo an offered client_max_window_bits # parameter with no value. del ext[1]['client_max_window_bits'] - self.handler.set_header("Sec-WebSocket-Extensions", - httputil._encode_header( - 'permessage-deflate', ext[1])) + handler.set_header("Sec-WebSocket-Extensions", + httputil._encode_header( + 'permessage-deflate', ext[1])) break - self.handler.clear_header("Content-Type") - self.handler.set_status(101) - self.handler.set_header("Upgrade", "websocket") - self.handler.set_header("Connection", "Upgrade") - self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response()) - self.handler.finish() + handler.clear_header("Content-Type") + handler.set_status(101) + handler.set_header("Upgrade", "websocket") + handler.set_header("Connection", "Upgrade") + handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler)) + handler.finish() - self.handler._attach_stream() - self.stream = self.handler.stream + handler._attach_stream() + assert handler.stream is not None + self.stream = handler.stream self.start_pinging() - open_result = self._run_callback(self.handler.open, *self.handler.open_args, - **self.handler.open_kwargs) + open_result = self._run_callback(handler.open, *handler.open_args, + **handler.open_kwargs) if open_result is not None: yield open_result yield self._receive_frame_loop() - def _parse_extensions_header(self, headers): + def _parse_extensions_header( + self, headers: httputil.HTTPHeaders + ) -> List[Tuple[str, Dict[str, str]]]: extensions = headers.get("Sec-WebSocket-Extensions", '') if extensions: return [httputil._parse_header(e.strip()) for e in extensions.split(',')] return [] - def _process_server_headers(self, key, headers): + def _process_server_headers(self, key: Union[str, bytes], + headers: httputil.HTTPHeaders) -> None: """Process the headers sent by the server to this client connection. 'key' is the websocket handshake challenge/response key. @@ -777,12 +918,13 @@ class WebSocketProtocol13(WebSocketProtocol): self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None) - def _get_compressor_options(self, side, agreed_parameters, compression_options=None): + def _get_compressor_options(self, side: str, agreed_parameters: Dict[str, Any], + compression_options: Dict[str, Any]=None) -> Dict[str, Any]: """Converts a websocket agreed_parameters set to keyword arguments for our compressor objects. """ - options = dict( - persistent=(side + '_no_context_takeover') not in agreed_parameters) + options = dict(persistent=(side + '_no_context_takeover') not in agreed_parameters) \ + # type: Dict[str, Any] wbits_header = agreed_parameters.get(side + '_max_window_bits', None) if wbits_header is None: options['max_wbits'] = zlib.MAX_WBITS @@ -791,7 +933,8 @@ class WebSocketProtocol13(WebSocketProtocol): options['compression_options'] = compression_options return options - def _create_compressors(self, side, agreed_parameters, compression_options=None): + def _create_compressors(self, side: str, agreed_parameters: Dict[str, Any], + compression_options: Dict[str, Any]=None) -> None: # TODO: handle invalid parameters gracefully allowed_keys = set(['server_no_context_takeover', 'client_no_context_takeover', @@ -807,7 +950,7 @@ class WebSocketProtocol13(WebSocketProtocol): max_message_size=self.handler.max_message_size, **self._get_compressor_options(other_side, agreed_parameters, compression_options)) - def _write_frame(self, fin, opcode, data, flags=0): + def _write_frame(self, fin: bool, opcode: int, data: bytes, flags: int=0) -> 'Future[None]': data_len = len(data) if opcode & 0x8: # All control frames MUST have a payload length of 125 @@ -838,7 +981,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._wire_bytes_out += len(frame) return self.stream.write(frame) - def write_message(self, message, binary=False): + def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]': """Sends the given message to the client of this Web Socket.""" if binary: opcode = 0x2 @@ -862,32 +1005,32 @@ class WebSocketProtocol13(WebSocketProtocol): raise WebSocketClosedError() @gen.coroutine - def wrapper(): + def wrapper() -> Generator[Any, Any, None]: try: yield fut except StreamClosedError: raise WebSocketClosedError() return wrapper() - def write_ping(self, data): + def write_ping(self, data: bytes) -> None: """Send ping frame.""" assert isinstance(data, bytes) self._write_frame(True, 0x9, data) @gen.coroutine - def _receive_frame_loop(self): + def _receive_frame_loop(self) -> Generator[Any, Any, None]: try: while not self.client_terminated: yield self._receive_frame() except StreamClosedError: self._abort() - def _read_bytes(self, n): + def _read_bytes(self, n: int) -> Awaitable[bytes]: self._wire_bytes_in += n return self.stream.read_bytes(n) @gen.coroutine - def _receive_frame(self): + def _receive_frame(self) -> Generator[Any, Any, None]: # Read the frame header. data = yield self._read_bytes(2) header, mask_payloadlen = struct.unpack("BB", data) @@ -934,6 +1077,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._frame_mask = yield self._read_bytes(4) data = yield self._read_bytes(payloadlen) if is_masked: + assert self._frame_mask is not None data = _websocket_mask(self._frame_mask, data) # Decide what to do with this frame. @@ -969,18 +1113,19 @@ class WebSocketProtocol13(WebSocketProtocol): if handled_future is not None: yield handled_future - def _handle_message(self, opcode, data): + def _handle_message(self, opcode: int, data: bytes) -> Optional['Future[None]']: """Execute on_message, returning its Future if it is a coroutine.""" if self.client_terminated: - return + return None if self._frame_compressed: + assert self._decompressor is not None try: data = self._decompressor.decompress(data) except _DecompressTooLargeError: self.close(1009, "message too big after decompression") self._abort() - return + return None if opcode == 0x1: # UTF-8 data @@ -989,7 +1134,7 @@ class WebSocketProtocol13(WebSocketProtocol): decoded = data.decode("utf-8") except UnicodeDecodeError: self._abort() - return + return None return self._run_callback(self.handler.on_message, decoded) elif opcode == 0x2: # Binary data @@ -1017,8 +1162,9 @@ class WebSocketProtocol13(WebSocketProtocol): return self._run_callback(self.handler.on_pong, data) else: self._abort() + return None - def close(self, code=None, reason=None): + def close(self, code: int=None, reason: str=None) -> None: """Closes the WebSocket connection.""" if not self.server_terminated: if not self.stream.closed(): @@ -1046,7 +1192,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._waiting = self.stream.io_loop.add_timeout( self.stream.io_loop.time() + 5, self._abort) - def is_closing(self): + def is_closing(self) -> bool: """Return true if this connection is closing. The connection is considered closing if either side has @@ -1058,28 +1204,30 @@ class WebSocketProtocol13(WebSocketProtocol): self.server_terminated) @property - def ping_interval(self): + def ping_interval(self) -> Optional[float]: interval = self.handler.ping_interval if interval is not None: return interval return 0 @property - def ping_timeout(self): + def ping_timeout(self) -> Optional[float]: timeout = self.handler.ping_timeout if timeout is not None: return timeout + assert self.ping_interval is not None return max(3 * self.ping_interval, 30) - def start_pinging(self): + def start_pinging(self) -> None: """Start sending periodic pings to keep the connection alive""" + assert self.ping_interval is not None if self.ping_interval > 0: self.last_ping = self.last_pong = IOLoop.current().time() self.ping_callback = PeriodicCallback( self.periodic_ping, self.ping_interval * 1000) self.ping_callback.start() - def periodic_ping(self): + def periodic_ping(self) -> None: """Send a ping to keep the websocket alive Called periodically if the websocket_ping_interval is set and non-zero. @@ -1094,6 +1242,8 @@ class WebSocketProtocol13(WebSocketProtocol): now = IOLoop.current().time() since_last_pong = now - self.last_pong since_last_ping = now - self.last_ping + assert self.ping_interval is not None + assert self.ping_timeout is not None if (since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout): self.close() @@ -1109,16 +1259,21 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): This class should not be instantiated directly; use the `websocket_connect` function instead. """ - def __init__(self, request, on_message_callback=None, - compression_options=None, ping_interval=None, ping_timeout=None, - max_message_size=None, subprotocols=[]): + protocol = None # type: WebSocketProtocol + + def __init__(self, request: httpclient.HTTPRequest, + on_message_callback: Callable[[Union[None, str, bytes]], None]=None, + compression_options: Dict[str, Any]=None, + ping_interval: float=None, ping_timeout: float=None, + max_message_size: int=_default_max_message_size, + subprotocols: Optional[List[str]]=[]) -> None: self.compression_options = compression_options - self.connect_future = Future() - self.protocol = None - self.read_queue = Queue(1) + self.connect_future = Future() # type: Future[WebSocketClientConnection] + self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] self.key = base64.b64encode(os.urandom(16)) self._on_message_callback = on_message_callback - self.close_code = self.close_reason = None + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.max_message_size = max_message_size @@ -1148,7 +1303,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): None, request, lambda: None, self._on_http_response, 104857600, self.tcp_client, 65536, 104857600) - def close(self, code=None, reason=None): + def close(self, code: int=None, reason: str=None) -> None: """Closes the websocket connection. ``code`` and ``reason`` are documented under @@ -1162,16 +1317,16 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): """ if self.protocol is not None: self.protocol.close(code, reason) - self.protocol = None + self.protocol = None # type: ignore - def on_connection_close(self): + def on_connection_close(self) -> None: if not self.connect_future.done(): self.connect_future.set_exception(StreamClosedError()) - self.on_message(None) + self._on_message(None) self.tcp_client.close() super(WebSocketClientConnection, self).on_connection_close() - def _on_http_response(self, response): + def _on_http_response(self, response: httpclient.HTTPResponse) -> None: if not self.connect_future.done(): if response.error: self.connect_future.set_exception(response.error) @@ -1179,7 +1334,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.connect_future.set_exception(WebSocketError( "Non-websocket response")) - def headers_received(self, start_line, headers): + def headers_received(self, start_line: Union[httputil.RequestStartLine, + httputil.ResponseStartLine], + headers: httputil.HTTPHeaders) -> None: + assert isinstance(start_line, httputil.ResponseStartLine) if start_line.code != 101: return super(WebSocketClientConnection, self).headers_received( start_line, headers) @@ -1200,11 +1358,11 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): # we set on the http request. This deactivates the error handling # in simple_httpclient that would otherwise interfere with our # ability to see exceptions. - self.final_callback = None + self.final_callback = None # type: ignore future_set_result_unless_cancelled(self.connect_future, self) - def write_message(self, message, binary=False): + def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]': """Sends a message to the WebSocket server. If the stream is closed, raises `WebSocketClosedError`. @@ -1216,7 +1374,9 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): """ return self.protocol.write_message(message, binary=binary) - def read_message(self, callback=None): + def read_message( + self, callback: Callable[['Future[Union[None, str, bytes]]'], None]=None + ) -> 'Future[Union[None, str, bytes]]': """Reads a message from the WebSocket server. If on_message_callback was specified at WebSocket @@ -1233,13 +1393,17 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.io_loop.add_future(future, callback) return future - def on_message(self, message): + def on_message(self, message: Union[str, bytes]) -> Optional['Future[None]']: + return self._on_message(message) + + def _on_message(self, message: Union[None, str, bytes]) -> Optional['Future[None]']: if self._on_message_callback: self._on_message_callback(message) + return None else: return self.read_queue.put(message) - def ping(self, data=b''): + def ping(self, data: bytes=b'') -> None: """Send ping frame to the remote end. The data argument allows a small amount of data (up to 125 @@ -1258,29 +1422,41 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): raise WebSocketClosedError() self.protocol.write_ping(data) - def on_pong(self, data): + def on_pong(self, data: bytes) -> None: pass - def on_ping(self, data): + def on_ping(self, data: bytes) -> None: pass - def get_websocket_protocol(self): + def get_websocket_protocol(self) -> WebSocketProtocol: return WebSocketProtocol13(self, mask_outgoing=True, compression_options=self.compression_options) @property - def selected_subprotocol(self): + def selected_subprotocol(self) -> Optional[str]: """The subprotocol selected by the server. .. versionadded:: 5.1 """ return self.protocol.selected_subprotocol - -def websocket_connect(url, callback=None, connect_timeout=None, - on_message_callback=None, compression_options=None, - ping_interval=None, ping_timeout=None, - max_message_size=_default_max_message_size, subprotocols=None): + def log_exception(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[TracebackType]) -> None: + assert typ is not None + assert value is not None + app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb)) + + +def websocket_connect( + url: Union[str, httpclient.HTTPRequest], + callback: Callable[['Future[WebSocketClientConnection]'], None]=None, + connect_timeout: float=None, + on_message_callback: Callable[[Union[None, str, bytes]], None]=None, + compression_options: Dict[str, Any]=None, + ping_interval: float=None, ping_timeout: float=None, + max_message_size: int=_default_max_message_size, subprotocols: List[str]=None +) -> 'Future[WebSocketClientConnection]': """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -1332,8 +1508,8 @@ def websocket_connect(url, callback=None, connect_timeout=None, request.headers = httputil.HTTPHeaders(request.headers) else: request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout) - request = httpclient._RequestProxy( - request, httpclient.HTTPRequest._DEFAULTS) + request = cast(httpclient.HTTPRequest, httpclient._RequestProxy( + request, httpclient.HTTPRequest._DEFAULTS)) conn = WebSocketClientConnection(request, on_message_callback=on_message_callback, compression_options=compression_options,