Removed support for the draft 76 protocol version.
"""
+import abc
import base64
import hashlib
import os
from tornado import gen, httpclient, httputil
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.iostream import StreamClosedError
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
from tornado import simple_httpclient
from tornado.queues import Queue
from tornado.tcpclient import TCPClient
from tornado.util import _websocket_mask
+from typing import (TYPE_CHECKING, cast, Any, Optional, Dict, Union, List, Awaitable,
+ Callable, Generator, Tuple, Type)
+from types import TracebackType
+if TYPE_CHECKING:
+ from tornado.iostream import IOStream # noqa: F401
+ from typing_extensions import Protocol
+
+ # The zlib compressor types aren't actually exposed anywhere
+ # publicly, so declare protocols for the portions we use.
+ class _Compressor(Protocol):
+ def compress(self, data: bytes) -> bytes:
+ pass
+
+ def flush(self, mode: int) -> bytes:
+ pass
+
+ class _Decompressor(Protocol):
+ unconsumed_tail = b'' # type: bytes
+
+ def decompress(self, data: bytes, max_length: int) -> bytes:
+ pass
+
+ class _WebSocketConnection(Protocol):
+ # The common base interface implemented by WebSocketHandler on
+ # the server side and WebSocketClientConnection on the client
+ # side.
+ @property
+ def stream(self) -> Optional[IOStream]:
+ pass
+
+ @property
+ def ping_interval(self) -> Optional[float]:
+ pass
+
+ @property
+ def ping_timeout(self) -> Optional[float]:
+ pass
+
+ @property
+ def max_message_size(self) -> int:
+ pass
+
+ @property
+ def close_code(self) -> Optional[int]:
+ pass
+
+ @close_code.setter
+ def close_code(self, value: Optional[int]) -> None:
+ pass
+
+ @property
+ def close_reason(self) -> Optional[str]:
+ pass
+
+ @close_reason.setter
+ def close_reason(self, value: Optional[str]) -> None:
+ pass
+
+ def on_message(self, message: Union[str, bytes]) -> Optional['Awaitable[None]']:
+ pass
+
+ def on_ping(self, data: bytes) -> None:
+ pass
+
+ def on_pong(self, data: bytes) -> None:
+ pass
+
+ def log_exception(self, typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType]) -> None:
+ pass
+
+
_default_max_message_size = 10 * 1024 * 1024
Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
``websocket_max_message_size``.
"""
- def __init__(self, application, request, **kwargs):
+ def __init__(self, application: tornado.web.Application, request: httputil.HTTPServerRequest,
+ **kwargs: Any) -> None:
super(WebSocketHandler, self).__init__(application, request, **kwargs)
- self.ws_connection = None
- self.close_code = None
- self.close_reason = None
- self.stream = None
+ self.ws_connection = None # type: Optional[WebSocketProtocol]
+ self.close_code = None # type: Optional[int]
+ self.close_reason = None # type: Optional[str]
+ self.stream = None # type: Optional[IOStream]
self._on_close_called = False
- def get(self, *args, **kwargs):
+ def get(self, *args: Any, **kwargs: Any) -> None:
self.open_args = args
self.open_kwargs = kwargs
self.ws_connection = self.get_websocket_protocol()
if self.ws_connection:
- self.ws_connection.accept_connection()
+ self.ws_connection.accept_connection(self)
else:
self.set_status(426, "Upgrade Required")
self.set_header("Sec-WebSocket-Version", "7, 8, 13")
stream = None
@property
- def ping_interval(self):
+ def ping_interval(self) -> Optional[float]:
"""The interval for websocket keep-alive pings.
Set websocket_ping_interval = 0 to disable pings.
return self.settings.get('websocket_ping_interval', None)
@property
- def ping_timeout(self):
+ def ping_timeout(self) -> Optional[float]:
"""If no ping is received in this many seconds,
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
Default is max of 3 pings or 30 seconds.
return self.settings.get('websocket_ping_timeout', None)
@property
- def max_message_size(self):
+ def max_message_size(self) -> int:
"""Maximum allowed message size.
If the remote peer sends a message larger than this, the connection
"""
return self.settings.get('websocket_max_message_size', _default_max_message_size)
- def write_message(self, message, binary=False):
+ def write_message(self, message: Union[bytes, str, Dict[str, Any]],
+ binary: bool=False) -> 'Future[None]':
"""Sends the given message to the client of this Web Socket.
The message may be either a string or a dict (which will be
message = tornado.escape.json_encode(message)
return self.ws_connection.write_message(message, binary=binary)
- def select_subprotocol(self, subprotocols):
+ def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
"""Override to implement subprotocol negotiation.
``subprotocols`` is a list of strings identifying the
return None
@property
- def selected_subprotocol(self):
+ def selected_subprotocol(self) -> Optional[str]:
"""The subprotocol returned by `select_subprotocol`.
.. versionadded:: 5.1
"""
+ assert self.ws_connection is not None
return self.ws_connection.selected_subprotocol
- def get_compression_options(self):
+ def get_compression_options(self) -> Optional[Dict[str, Any]]:
"""Override to return compression options for the connection.
If this method returns None (the default), compression will
# TODO: Add wbits option.
return None
- def open(self, *args, **kwargs):
+ def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
"""Invoked when a new WebSocket is opened.
The arguments to `open` are extracted from the `tornado.web.URLSpec`
"""
pass
- def on_message(self, message):
+ def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
"""Handle incoming messages on the WebSocket
This method must be overridden.
"""
raise NotImplementedError
- def ping(self, data=b''):
+ def ping(self, data: Union[str, bytes]=b'') -> None:
"""Send ping frame to the remote end.
The data argument allows a small amount of data (up to 125
raise WebSocketClosedError()
self.ws_connection.write_ping(data)
- def on_pong(self, data):
+ def on_pong(self, data: bytes) -> None:
"""Invoked when the response to a ping frame is received."""
pass
- def on_ping(self, data):
+ def on_ping(self, data: bytes) -> None:
"""Invoked when the a ping frame is received."""
pass
- def on_close(self):
+ def on_close(self) -> None:
"""Invoked when the WebSocket is closed.
If the connection was closed cleanly and a status code or reason
"""
pass
- def close(self, code=None, reason=None):
+ def close(self, code: int=None, reason: str=None) -> None:
"""Closes this Web Socket.
Once the close handshake is successful the socket will be closed.
self.ws_connection.close(code, reason)
self.ws_connection = None
- def check_origin(self, origin):
+ def check_origin(self, origin: str) -> bool:
"""Override to enable support for allowing alternate origins.
The ``origin`` argument is the value of the ``Origin`` HTTP
# Check to see that origin matches host directly, including ports
return origin == host
- def set_nodelay(self, value):
+ def set_nodelay(self, value: bool) -> None:
"""Set the no-delay flag for this stream.
By default, small messages may be delayed and/or combined to minimize
.. versionadded:: 3.1
"""
+ assert self.stream is not None
self.stream.set_nodelay(value)
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
if self.ws_connection:
self.ws_connection.on_connection_close()
self.ws_connection = None
self.on_close()
self._break_cycles()
- def _break_cycles(self):
+ def _break_cycles(self) -> None:
# WebSocketHandlers call finish() early, but we don't want to
# break up reference cycles (which makes it impossible to call
# self.render_string) until after we've really closed the
if self.get_status() != 101 or self._on_close_called:
super(WebSocketHandler, self)._break_cycles()
- def send_error(self, *args, **kwargs):
+ def send_error(self, *args: Any, **kwargs: Any) -> None:
if self.stream is None:
super(WebSocketHandler, self).send_error(*args, **kwargs)
else:
# we can close the connection more gracefully.
self.stream.close()
- def get_websocket_protocol(self):
+ def get_websocket_protocol(self) -> Optional['WebSocketProtocol']:
websocket_version = self.request.headers.get("Sec-WebSocket-Version")
if websocket_version in ("7", "8", "13"):
return WebSocketProtocol13(
self, compression_options=self.get_compression_options())
+ return None
- def _attach_stream(self):
+ def _attach_stream(self) -> None:
self.stream = self.detach()
self.stream.set_close_callback(self.on_connection_close)
# disable non-WS methods
setattr(self, method, _raise_not_supported_for_websockets)
-def _raise_not_supported_for_websockets(*args, **kwargs):
+def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
raise RuntimeError("Method not supported for Web Sockets")
-class WebSocketProtocol(object):
+class WebSocketProtocol(abc.ABC):
"""Base class for WebSocket protocol versions.
"""
- def __init__(self, handler):
+ def __init__(self, handler: '_WebSocketConnection') -> None:
self.handler = handler
- self.request = handler.request
self.stream = handler.stream
self.client_terminated = False
self.server_terminated = False
- def _run_callback(self, callback, *args, **kwargs):
+ def _run_callback(self, callback: Callable,
+ *args: Any, **kwargs: Any) -> Optional['Future[Any]']:
"""Runs the given callback with exception handling.
If the callback is a coroutine, returns its Future. On error, aborts the
except Exception:
self.handler.log_exception(*sys.exc_info())
self._abort()
+ return None
else:
if result is not None:
result = gen.convert_yielded(result)
+ assert self.stream is not None
self.stream.io_loop.add_future(result, lambda f: f.result())
return result
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
self._abort()
- def _abort(self):
+ def _abort(self) -> None:
"""Instantly aborts the WebSocket connection by closing the socket"""
self.client_terminated = True
self.server_terminated = True
- self.stream.close() # forcibly tear down the connection
+ if self.stream is not None:
+ self.stream.close() # forcibly tear down the connection
self.close() # let the subclass cleanup
+ @abc.abstractmethod
+ def close(self, code: int=None, reason: str=None) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def is_closing(self) -> bool:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def accept_connection(self, handler: WebSocketHandler) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
+ raise NotImplementedError()
+
+ @property
+ @abc.abstractmethod
+ def selected_subprotocol(self) -> Optional[str]:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def write_ping(self, data: bytes) -> None:
+ raise NotImplementedError()
+
+ # The entry points below are used by WebSocketClientConnection,
+ # which was introduced after we only supported a single version of
+ # WebSocketProtocol. The WebSocketProtocol/WebSocketProtocol13
+ # boundary is currently pretty ad-hoc.
+ @abc.abstractmethod
+ def _process_server_headers(self, key: Union[str, bytes],
+ headers: httputil.HTTPHeaders) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def start_pinging(self) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _receive_frame_loop(self) -> 'Future[None]':
+ raise NotImplementedError()
+
class _PerMessageDeflateCompressor(object):
- def __init__(self, persistent, max_wbits, compression_options=None):
+ def __init__(self, persistent: bool, max_wbits: Optional[int],
+ compression_options: Dict[str, Any]=None) -> None:
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
# There is no symbolic constant for the minimum wbits value.
self._mem_level = compression_options['mem_level']
if persistent:
- self._compressor = self._create_compressor()
+ self._compressor = self._create_compressor() # type: Optional[_Compressor]
else:
self._compressor = None
- def _create_compressor(self):
+ def _create_compressor(self) -> '_Compressor':
return zlib.compressobj(self._compression_level,
zlib.DEFLATED, -self._max_wbits, self._mem_level)
- def compress(self, data):
+ def compress(self, data: bytes) -> bytes:
compressor = self._compressor or self._create_compressor()
data = (compressor.compress(data) +
compressor.flush(zlib.Z_SYNC_FLUSH))
class _PerMessageDeflateDecompressor(object):
- def __init__(self, persistent, max_wbits, max_message_size, compression_options=None):
+ def __init__(self, persistent: bool, max_wbits: Optional[int], max_message_size: int,
+ compression_options: Dict[str, Any]=None) -> None:
self._max_message_size = max_message_size
if max_wbits is None:
max_wbits = zlib.MAX_WBITS
max_wbits, zlib.MAX_WBITS)
self._max_wbits = max_wbits
if persistent:
- self._decompressor = self._create_decompressor()
+ self._decompressor = self._create_decompressor() # type: Optional[_Decompressor]
else:
self._decompressor = None
- def _create_decompressor(self):
+ def _create_decompressor(self) -> '_Decompressor':
return zlib.decompressobj(-self._max_wbits)
- def decompress(self, data):
+ def decompress(self, data: bytes) -> bytes:
decompressor = self._decompressor or self._create_decompressor()
result = decompressor.decompress(data + b'\x00\x00\xff\xff', self._max_message_size)
if decompressor.unconsumed_tail:
RSV_MASK = RSV1 | RSV2 | RSV3
OPCODE_MASK = 0x0f
- def __init__(self, handler, mask_outgoing=False,
- compression_options=None):
+ stream = None # type: IOStream
+
+ def __init__(self, handler: '_WebSocketConnection', mask_outgoing: bool=False,
+ compression_options: Dict[str, Any]=None) -> None:
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
self._final_frame = False
self._frame_opcode = None
self._masked_frame = None
- self._frame_mask = None
+ self._frame_mask = None # type: Optional[bytes]
self._frame_length = None
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
- self._waiting = None
+ self._waiting = None # type: object
self._compression_options = compression_options
- self._decompressor = None
- self._compressor = None
- self._frame_compressed = None
+ self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor]
+ self._compressor = None # type: Optional[_PerMessageDeflateCompressor]
+ self._frame_compressed = None # type: Optional[bool]
# The total uncompressed size of all messages received or sent.
# Unicode messages are encoded to utf8.
# Only for testing; subject to change.
# the effect of compression, frame overhead, and control frames.
self._wire_bytes_in = 0
self._wire_bytes_out = 0
- self.ping_callback = None
- self.last_ping = 0
- self.last_pong = 0
+ self.ping_callback = None # type: Optional[PeriodicCallback]
+ self.last_ping = 0.0
+ self.last_pong = 0.0
- def accept_connection(self):
+ # Use a property for this to satisfy the abc.
+ @property
+ def selected_subprotocol(self) -> Optional[str]:
+ return self._selected_subprotocol
+
+ @selected_subprotocol.setter
+ def selected_subprotocol(self, value: Optional[str]) -> None:
+ self._selected_subprotocol = value
+
+ def accept_connection(self, handler: WebSocketHandler) -> None:
try:
- self._handle_websocket_headers()
+ self._handle_websocket_headers(handler)
except ValueError:
- self.handler.set_status(400)
+ handler.set_status(400)
log_msg = "Missing/Invalid WebSocket headers"
- self.handler.finish(log_msg)
+ handler.finish(log_msg)
gen_log.debug(log_msg)
return
try:
- self._accept_connection()
+ self._accept_connection(handler)
except ValueError:
gen_log.debug("Malformed WebSocket request received",
exc_info=True)
self._abort()
return
- def _handle_websocket_headers(self):
+ def _handle_websocket_headers(self, handler: WebSocketHandler) -> None:
"""Verifies all invariant- and required headers
If a header is missing or have an incorrect value ValueError will be
raised
"""
fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
- if not all(map(lambda f: self.request.headers.get(f), fields)):
+ if not all(map(lambda f: handler.request.headers.get(f), fields)):
raise ValueError("Missing/Invalid WebSocket headers")
@staticmethod
- def compute_accept_value(key):
+ def compute_accept_value(key: Union[str, bytes]) -> str:
"""Computes the value for the Sec-WebSocket-Accept header,
given the value for Sec-WebSocket-Key.
"""
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
return native_str(base64.b64encode(sha1.digest()))
- def _challenge_response(self):
+ def _challenge_response(self, handler: WebSocketHandler) -> str:
return WebSocketProtocol13.compute_accept_value(
- self.request.headers.get("Sec-Websocket-Key"))
+ cast(str, handler.request.headers.get("Sec-Websocket-Key")))
@gen.coroutine
- def _accept_connection(self):
- subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol")
+ def _accept_connection(self, handler: WebSocketHandler) -> Generator[Any, Any, None]:
+ subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
subprotocols = [s.strip() for s in subprotocol_header.split(',')]
else:
subprotocols = []
- self.selected_subprotocol = self.handler.select_subprotocol(subprotocols)
+ self.selected_subprotocol = handler.select_subprotocol(subprotocols)
if self.selected_subprotocol:
assert self.selected_subprotocol in subprotocols
- self.handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)
+ handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)
- extensions = self._parse_extensions_header(self.request.headers)
+ extensions = self._parse_extensions_header(handler.request.headers)
for ext in extensions:
if (ext[0] == 'permessage-deflate' and
self._compression_options is not None):
# Don't echo an offered client_max_window_bits
# parameter with no value.
del ext[1]['client_max_window_bits']
- self.handler.set_header("Sec-WebSocket-Extensions",
- httputil._encode_header(
- 'permessage-deflate', ext[1]))
+ handler.set_header("Sec-WebSocket-Extensions",
+ httputil._encode_header(
+ 'permessage-deflate', ext[1]))
break
- self.handler.clear_header("Content-Type")
- self.handler.set_status(101)
- self.handler.set_header("Upgrade", "websocket")
- self.handler.set_header("Connection", "Upgrade")
- self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response())
- self.handler.finish()
+ handler.clear_header("Content-Type")
+ handler.set_status(101)
+ handler.set_header("Upgrade", "websocket")
+ handler.set_header("Connection", "Upgrade")
+ handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler))
+ handler.finish()
- self.handler._attach_stream()
- self.stream = self.handler.stream
+ handler._attach_stream()
+ assert handler.stream is not None
+ self.stream = handler.stream
self.start_pinging()
- open_result = self._run_callback(self.handler.open, *self.handler.open_args,
- **self.handler.open_kwargs)
+ open_result = self._run_callback(handler.open, *handler.open_args,
+ **handler.open_kwargs)
if open_result is not None:
yield open_result
yield self._receive_frame_loop()
- def _parse_extensions_header(self, headers):
+ def _parse_extensions_header(
+ self, headers: httputil.HTTPHeaders
+ ) -> List[Tuple[str, Dict[str, str]]]:
extensions = headers.get("Sec-WebSocket-Extensions", '')
if extensions:
return [httputil._parse_header(e.strip())
for e in extensions.split(',')]
return []
- def _process_server_headers(self, key, headers):
+ def _process_server_headers(self, key: Union[str, bytes],
+ headers: httputil.HTTPHeaders) -> None:
"""Process the headers sent by the server to this client connection.
'key' is the websocket handshake challenge/response key.
self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None)
- def _get_compressor_options(self, side, agreed_parameters, compression_options=None):
+ def _get_compressor_options(self, side: str, agreed_parameters: Dict[str, Any],
+ compression_options: Dict[str, Any]=None) -> Dict[str, Any]:
"""Converts a websocket agreed_parameters set to keyword arguments
for our compressor objects.
"""
- options = dict(
- persistent=(side + '_no_context_takeover') not in agreed_parameters)
+ options = dict(persistent=(side + '_no_context_takeover') not in agreed_parameters) \
+ # type: Dict[str, Any]
wbits_header = agreed_parameters.get(side + '_max_window_bits', None)
if wbits_header is None:
options['max_wbits'] = zlib.MAX_WBITS
options['compression_options'] = compression_options
return options
- def _create_compressors(self, side, agreed_parameters, compression_options=None):
+ def _create_compressors(self, side: str, agreed_parameters: Dict[str, Any],
+ compression_options: Dict[str, Any]=None) -> None:
# TODO: handle invalid parameters gracefully
allowed_keys = set(['server_no_context_takeover',
'client_no_context_takeover',
max_message_size=self.handler.max_message_size,
**self._get_compressor_options(other_side, agreed_parameters, compression_options))
- def _write_frame(self, fin, opcode, data, flags=0):
+ def _write_frame(self, fin: bool, opcode: int, data: bytes, flags: int=0) -> 'Future[None]':
data_len = len(data)
if opcode & 0x8:
# All control frames MUST have a payload length of 125
self._wire_bytes_out += len(frame)
return self.stream.write(frame)
- def write_message(self, message, binary=False):
+ def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
"""Sends the given message to the client of this Web Socket."""
if binary:
opcode = 0x2
raise WebSocketClosedError()
@gen.coroutine
- def wrapper():
+ def wrapper() -> Generator[Any, Any, None]:
try:
yield fut
except StreamClosedError:
raise WebSocketClosedError()
return wrapper()
- def write_ping(self, data):
+ def write_ping(self, data: bytes) -> None:
"""Send ping frame."""
assert isinstance(data, bytes)
self._write_frame(True, 0x9, data)
@gen.coroutine
- def _receive_frame_loop(self):
+ def _receive_frame_loop(self) -> Generator[Any, Any, None]:
try:
while not self.client_terminated:
yield self._receive_frame()
except StreamClosedError:
self._abort()
- def _read_bytes(self, n):
+ def _read_bytes(self, n: int) -> Awaitable[bytes]:
self._wire_bytes_in += n
return self.stream.read_bytes(n)
@gen.coroutine
- def _receive_frame(self):
+ def _receive_frame(self) -> Generator[Any, Any, None]:
# Read the frame header.
data = yield self._read_bytes(2)
header, mask_payloadlen = struct.unpack("BB", data)
self._frame_mask = yield self._read_bytes(4)
data = yield self._read_bytes(payloadlen)
if is_masked:
+ assert self._frame_mask is not None
data = _websocket_mask(self._frame_mask, data)
# Decide what to do with this frame.
if handled_future is not None:
yield handled_future
- def _handle_message(self, opcode, data):
+ def _handle_message(self, opcode: int, data: bytes) -> Optional['Future[None]']:
"""Execute on_message, returning its Future if it is a coroutine."""
if self.client_terminated:
- return
+ return None
if self._frame_compressed:
+ assert self._decompressor is not None
try:
data = self._decompressor.decompress(data)
except _DecompressTooLargeError:
self.close(1009, "message too big after decompression")
self._abort()
- return
+ return None
if opcode == 0x1:
# UTF-8 data
decoded = data.decode("utf-8")
except UnicodeDecodeError:
self._abort()
- return
+ return None
return self._run_callback(self.handler.on_message, decoded)
elif opcode == 0x2:
# Binary data
return self._run_callback(self.handler.on_pong, data)
else:
self._abort()
+ return None
- def close(self, code=None, reason=None):
+ def close(self, code: int=None, reason: str=None) -> None:
"""Closes the WebSocket connection."""
if not self.server_terminated:
if not self.stream.closed():
self._waiting = self.stream.io_loop.add_timeout(
self.stream.io_loop.time() + 5, self._abort)
- def is_closing(self):
+ def is_closing(self) -> bool:
"""Return true if this connection is closing.
The connection is considered closing if either side has
self.server_terminated)
@property
- def ping_interval(self):
+ def ping_interval(self) -> Optional[float]:
interval = self.handler.ping_interval
if interval is not None:
return interval
return 0
@property
- def ping_timeout(self):
+ def ping_timeout(self) -> Optional[float]:
timeout = self.handler.ping_timeout
if timeout is not None:
return timeout
+ assert self.ping_interval is not None
return max(3 * self.ping_interval, 30)
- def start_pinging(self):
+ def start_pinging(self) -> None:
"""Start sending periodic pings to keep the connection alive"""
+ assert self.ping_interval is not None
if self.ping_interval > 0:
self.last_ping = self.last_pong = IOLoop.current().time()
self.ping_callback = PeriodicCallback(
self.periodic_ping, self.ping_interval * 1000)
self.ping_callback.start()
- def periodic_ping(self):
+ def periodic_ping(self) -> None:
"""Send a ping to keep the websocket alive
Called periodically if the websocket_ping_interval is set and non-zero.
now = IOLoop.current().time()
since_last_pong = now - self.last_pong
since_last_ping = now - self.last_ping
+ assert self.ping_interval is not None
+ assert self.ping_timeout is not None
if (since_last_ping < 2 * self.ping_interval and
since_last_pong > self.ping_timeout):
self.close()
This class should not be instantiated directly; use the
`websocket_connect` function instead.
"""
- def __init__(self, request, on_message_callback=None,
- compression_options=None, ping_interval=None, ping_timeout=None,
- max_message_size=None, subprotocols=[]):
+ protocol = None # type: WebSocketProtocol
+
+ def __init__(self, request: httpclient.HTTPRequest,
+ on_message_callback: Callable[[Union[None, str, bytes]], None]=None,
+ compression_options: Dict[str, Any]=None,
+ ping_interval: float=None, ping_timeout: float=None,
+ max_message_size: int=_default_max_message_size,
+ subprotocols: Optional[List[str]]=[]) -> None:
self.compression_options = compression_options
- self.connect_future = Future()
- self.protocol = None
- self.read_queue = Queue(1)
+ self.connect_future = Future() # type: Future[WebSocketClientConnection]
+ self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]]
self.key = base64.b64encode(os.urandom(16))
self._on_message_callback = on_message_callback
- self.close_code = self.close_reason = None
+ self.close_code = None # type: Optional[int]
+ self.close_reason = None # type: Optional[str]
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.max_message_size = max_message_size
None, request, lambda: None, self._on_http_response,
104857600, self.tcp_client, 65536, 104857600)
- def close(self, code=None, reason=None):
+ def close(self, code: int=None, reason: str=None) -> None:
"""Closes the websocket connection.
``code`` and ``reason`` are documented under
"""
if self.protocol is not None:
self.protocol.close(code, reason)
- self.protocol = None
+ self.protocol = None # type: ignore
- def on_connection_close(self):
+ def on_connection_close(self) -> None:
if not self.connect_future.done():
self.connect_future.set_exception(StreamClosedError())
- self.on_message(None)
+ self._on_message(None)
self.tcp_client.close()
super(WebSocketClientConnection, self).on_connection_close()
- def _on_http_response(self, response):
+ def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
if not self.connect_future.done():
if response.error:
self.connect_future.set_exception(response.error)
self.connect_future.set_exception(WebSocketError(
"Non-websocket response"))
- def headers_received(self, start_line, headers):
+ def headers_received(self, start_line: Union[httputil.RequestStartLine,
+ httputil.ResponseStartLine],
+ headers: httputil.HTTPHeaders) -> None:
+ assert isinstance(start_line, httputil.ResponseStartLine)
if start_line.code != 101:
return super(WebSocketClientConnection, self).headers_received(
start_line, headers)
# we set on the http request. This deactivates the error handling
# in simple_httpclient that would otherwise interfere with our
# ability to see exceptions.
- self.final_callback = None
+ self.final_callback = None # type: ignore
future_set_result_unless_cancelled(self.connect_future, self)
- def write_message(self, message, binary=False):
+ def write_message(self, message: Union[str, bytes], binary: bool=False) -> 'Future[None]':
"""Sends a message to the WebSocket server.
If the stream is closed, raises `WebSocketClosedError`.
"""
return self.protocol.write_message(message, binary=binary)
- def read_message(self, callback=None):
+ def read_message(
+ self, callback: Callable[['Future[Union[None, str, bytes]]'], None]=None
+ ) -> 'Future[Union[None, str, bytes]]':
"""Reads a message from the WebSocket server.
If on_message_callback was specified at WebSocket
self.io_loop.add_future(future, callback)
return future
- def on_message(self, message):
+ def on_message(self, message: Union[str, bytes]) -> Optional['Future[None]']:
+ return self._on_message(message)
+
+ def _on_message(self, message: Union[None, str, bytes]) -> Optional['Future[None]']:
if self._on_message_callback:
self._on_message_callback(message)
+ return None
else:
return self.read_queue.put(message)
- def ping(self, data=b''):
+ def ping(self, data: bytes=b'') -> None:
"""Send ping frame to the remote end.
The data argument allows a small amount of data (up to 125
raise WebSocketClosedError()
self.protocol.write_ping(data)
- def on_pong(self, data):
+ def on_pong(self, data: bytes) -> None:
pass
- def on_ping(self, data):
+ def on_ping(self, data: bytes) -> None:
pass
- def get_websocket_protocol(self):
+ def get_websocket_protocol(self) -> WebSocketProtocol:
return WebSocketProtocol13(self, mask_outgoing=True,
compression_options=self.compression_options)
@property
- def selected_subprotocol(self):
+ def selected_subprotocol(self) -> Optional[str]:
"""The subprotocol selected by the server.
.. versionadded:: 5.1
"""
return self.protocol.selected_subprotocol
-
-def websocket_connect(url, callback=None, connect_timeout=None,
- on_message_callback=None, compression_options=None,
- ping_interval=None, ping_timeout=None,
- max_message_size=_default_max_message_size, subprotocols=None):
+ def log_exception(self, typ: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ tb: Optional[TracebackType]) -> None:
+ assert typ is not None
+ assert value is not None
+ app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))
+
+
+def websocket_connect(
+ url: Union[str, httpclient.HTTPRequest],
+ callback: Callable[['Future[WebSocketClientConnection]'], None]=None,
+ connect_timeout: float=None,
+ on_message_callback: Callable[[Union[None, str, bytes]], None]=None,
+ compression_options: Dict[str, Any]=None,
+ ping_interval: float=None, ping_timeout: float=None,
+ max_message_size: int=_default_max_message_size, subprotocols: List[str]=None
+) -> 'Future[WebSocketClientConnection]':
"""Client-side websocket support.
Takes a url and returns a Future whose result is a
request.headers = httputil.HTTPHeaders(request.headers)
else:
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
- request = httpclient._RequestProxy(
- request, httpclient.HTTPRequest._DEFAULTS)
+ request = cast(httpclient.HTTPRequest, httpclient._RequestProxy(
+ request, httpclient.HTTPRequest._DEFAULTS))
conn = WebSocketClientConnection(request,
on_message_callback=on_message_callback,
compression_options=compression_options,