def decompress(self, data: bytes, max_length: int) -> bytes:
pass
- class _WebSocketConnection(Protocol):
+ class _WebSocketDelegate(Protocol):
# The common base interface implemented by WebSocketHandler on
# the server side and WebSocketClientConnection on the client
# side.
- @property
- def ping_interval(self) -> Optional[float]:
- pass
-
- @property
- def ping_timeout(self) -> Optional[float]:
- pass
-
- @property
- def max_message_size(self) -> int:
- pass
-
- @property
- def close_code(self) -> Optional[int]:
- pass
-
- @close_code.setter
- def close_code(self, value: Optional[int]) -> None:
- pass
-
- @property
- def close_reason(self) -> Optional[str]:
- pass
-
- @close_reason.setter
- def close_reason(self, value: Optional[str]) -> None:
- pass
-
- def on_connection_close(self) -> None:
+ def on_ws_connection_close(
+ self, close_code: int = None, close_reason: str = None
+ ) -> None:
pass
def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
pass
+class _WebSocketParams(object):
+ def __init__(
+ self,
+ ping_interval: float = None,
+ ping_timeout: float = None,
+ max_message_size: int = _default_max_message_size,
+ compression_options: Dict[str, Any] = None,
+ ) -> None:
+ self.ping_interval = ping_interval
+ self.ping_timeout = ping_timeout
+ self.max_message_size = max_message_size
+ self.compression_options = compression_options
+
+
class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
self.on_close()
self._break_cycles()
+ def on_ws_connection_close(
+ self, close_code: int = None, close_reason: str = None
+ ) -> None:
+ self.close_code = close_code
+ self.close_reason = close_reason
+ self.on_connection_close()
+
def _break_cycles(self) -> None:
# WebSocketHandlers call finish() early, but we don't want to
# break up reference cycles (which makes it impossible to call
def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
websocket_version = self.request.headers.get("Sec-WebSocket-Version")
if websocket_version in ("7", "8", "13"):
- return WebSocketProtocol13(
- self, compression_options=self.get_compression_options()
+ params = _WebSocketParams(
+ ping_interval=self.ping_interval,
+ ping_timeout=self.ping_timeout,
+ max_message_size=self.max_message_size,
+ compression_options=self.get_compression_options(),
)
+ return WebSocketProtocol13(self, False, params)
return None
def _detach_stream(self) -> IOStream:
"""Base class for WebSocket protocol versions.
"""
- def __init__(self, handler: "_WebSocketConnection") -> None:
+ def __init__(self, handler: "_WebSocketDelegate") -> None:
self.handler = handler
self.stream = None # type: Optional[IOStream]
self.client_terminated = False
def __init__(
self,
- handler: "_WebSocketConnection",
- mask_outgoing: bool = False,
- compression_options: Dict[str, Any] = None,
+ handler: "_WebSocketDelegate",
+ mask_outgoing: bool,
+ params: _WebSocketParams,
) -> None:
WebSocketProtocol.__init__(self, handler)
self.mask_outgoing = mask_outgoing
+ self.params = params
self._final_frame = False
self._frame_opcode = None
self._masked_frame = None
self._fragmented_message_buffer = None
self._fragmented_message_opcode = None
self._waiting = None # type: object
- self._compression_options = compression_options
+ self._compression_options = params.compression_options
self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor]
self._compressor = None # type: Optional[_PerMessageDeflateCompressor]
self._frame_compressed = None # type: Optional[bool]
self.ping_callback = None # type: Optional[PeriodicCallback]
self.last_ping = 0.0
self.last_pong = 0.0
+ self.close_code = None # type: Optional[int]
+ self.close_reason = None # type: Optional[str]
# Use a property for this to satisfy the abc.
@property
**self._get_compressor_options(side, agreed_parameters, compression_options)
)
self._decompressor = _PerMessageDeflateDecompressor(
- max_message_size=self.handler.max_message_size,
+ max_message_size=self.params.max_message_size,
**self._get_compressor_options(
other_side, agreed_parameters, compression_options
)
yield self._receive_frame()
except StreamClosedError:
self._abort()
- self.handler.on_connection_close()
+ self.handler.on_ws_connection_close(self.close_code, self.close_reason)
def _read_bytes(self, n: int) -> Awaitable[bytes]:
self._wire_bytes_in += n
new_len = payloadlen
if self._fragmented_message_buffer is not None:
new_len += len(self._fragmented_message_buffer)
- if new_len > self.handler.max_message_size:
+ if new_len > self.params.max_message_size:
self.close(1009, "message too big")
self._abort()
return
# Close
self.client_terminated = True
if len(data) >= 2:
- self.handler.close_code = struct.unpack(">H", data[:2])[0]
+ self.close_code = struct.unpack(">H", data[:2])[0]
if len(data) > 2:
- self.handler.close_reason = to_unicode(data[2:])
+ self.close_reason = to_unicode(data[2:])
# Echo the received close code, if any (RFC 6455 section 5.5.1).
- self.close(self.handler.close_code)
+ self.close(self.close_code)
elif opcode == 0x9:
# Ping
try:
@property
def ping_interval(self) -> Optional[float]:
- interval = self.handler.ping_interval
+ interval = self.params.ping_interval
if interval is not None:
return interval
return 0
@property
def ping_timeout(self) -> Optional[float]:
- timeout = self.handler.ping_timeout
+ timeout = self.params.ping_timeout
if timeout is not None:
return timeout
assert self.ping_interval is not None
max_message_size: int = _default_max_message_size,
subprotocols: Optional[List[str]] = [],
) -> None:
- self.compression_options = compression_options
self.connect_future = Future() # type: Future[WebSocketClientConnection]
self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]]
self.key = base64.b64encode(os.urandom(16))
self._on_message_callback = on_message_callback
self.close_code = None # type: Optional[int]
self.close_reason = None # type: Optional[str]
- self.ping_interval = ping_interval
- self.ping_timeout = ping_timeout
- self.max_message_size = max_message_size
+ self.params = _WebSocketParams(
+ ping_interval=ping_interval,
+ ping_timeout=ping_timeout,
+ max_message_size=max_message_size,
+ compression_options=compression_options,
+ )
scheme, sep, rest = request.url.partition(":")
scheme = {"ws": "http", "wss": "https"}[scheme]
)
if subprotocols is not None:
request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
- if self.compression_options is not None:
+ if compression_options is not None:
# Always offer to let the server set our max_wbits (and even though
# we don't offer it, we will accept a client_no_context_takeover
# from the server).
self.tcp_client.close()
super(WebSocketClientConnection, self).on_connection_close()
+ def on_ws_connection_close(
+ self, close_code: int = None, close_reason: str = None
+ ) -> None:
+ self.close_code = close_code
+ self.close_reason = close_reason
+ self.on_connection_close()
+
def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
if not self.connect_future.done():
if response.error:
pass
def get_websocket_protocol(self) -> WebSocketProtocol:
- return WebSocketProtocol13(
- self, mask_outgoing=True, compression_options=self.compression_options
- )
+ return WebSocketProtocol13(self, mask_outgoing=True, params=self.params)
@property
def selected_subprotocol(self) -> Optional[str]: