From: Ben Darnell Date: Sun, 21 Oct 2018 19:13:56 +0000 (-0400) Subject: websocket: Narrow the websocket handler interface X-Git-Tag: v6.0.0b1~12^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e719d82050e493080eca5f8e21c4bda2bcc9c25e;p=thirdparty%2Ftornado.git websocket: Narrow the websocket handler interface Pass close arguments via a new method instead of setting attributes. Extract a "params" struct. --- diff --git a/tornado/websocket.py b/tornado/websocket.py index efede9b51..e72b2b456 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -72,39 +72,13 @@ if TYPE_CHECKING: def decompress(self, data: bytes, max_length: int) -> bytes: pass - class _WebSocketConnection(Protocol): + class _WebSocketDelegate(Protocol): # The common base interface implemented by WebSocketHandler on # the server side and WebSocketClientConnection on the client # side. - @property - def ping_interval(self) -> Optional[float]: - pass - - @property - def ping_timeout(self) -> Optional[float]: - pass - - @property - def max_message_size(self) -> int: - pass - - @property - def close_code(self) -> Optional[int]: - pass - - @close_code.setter - def close_code(self, value: Optional[int]) -> None: - pass - - @property - def close_reason(self) -> Optional[str]: - pass - - @close_reason.setter - def close_reason(self, value: Optional[str]) -> None: - pass - - def on_connection_close(self) -> None: + def on_ws_connection_close( + self, close_code: int = None, close_reason: str = None + ) -> None: pass def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]: @@ -145,6 +119,20 @@ class _DecompressTooLargeError(Exception): pass +class _WebSocketParams(object): + def __init__( + self, + ping_interval: float = None, + ping_timeout: float = None, + max_message_size: int = _default_max_message_size, + compression_options: Dict[str, Any] = None, + ) -> None: + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.max_message_size = max_message_size + self.compression_options = compression_options + + class WebSocketHandler(tornado.web.RequestHandler): """Subclass this class to create a basic WebSocket handler. @@ -583,6 +571,13 @@ class WebSocketHandler(tornado.web.RequestHandler): self.on_close() self._break_cycles() + def on_ws_connection_close( + self, close_code: int = None, close_reason: str = None + ) -> None: + self.close_code = close_code + self.close_reason = close_reason + self.on_connection_close() + def _break_cycles(self) -> None: # WebSocketHandlers call finish() early, but we don't want to # break up reference cycles (which makes it impossible to call @@ -605,9 +600,13 @@ class WebSocketHandler(tornado.web.RequestHandler): def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]: websocket_version = self.request.headers.get("Sec-WebSocket-Version") if websocket_version in ("7", "8", "13"): - return WebSocketProtocol13( - self, compression_options=self.get_compression_options() + params = _WebSocketParams( + ping_interval=self.ping_interval, + ping_timeout=self.ping_timeout, + max_message_size=self.max_message_size, + compression_options=self.get_compression_options(), ) + return WebSocketProtocol13(self, False, params) return None def _detach_stream(self) -> IOStream: @@ -633,7 +632,7 @@ class WebSocketProtocol(abc.ABC): """Base class for WebSocket protocol versions. """ - def __init__(self, handler: "_WebSocketConnection") -> None: + def __init__(self, handler: "_WebSocketDelegate") -> None: self.handler = handler self.stream = None # type: Optional[IOStream] self.client_terminated = False @@ -822,12 +821,13 @@ class WebSocketProtocol13(WebSocketProtocol): def __init__( self, - handler: "_WebSocketConnection", - mask_outgoing: bool = False, - compression_options: Dict[str, Any] = None, + handler: "_WebSocketDelegate", + mask_outgoing: bool, + params: _WebSocketParams, ) -> None: WebSocketProtocol.__init__(self, handler) self.mask_outgoing = mask_outgoing + self.params = params self._final_frame = False self._frame_opcode = None self._masked_frame = None @@ -836,7 +836,7 @@ class WebSocketProtocol13(WebSocketProtocol): self._fragmented_message_buffer = None self._fragmented_message_opcode = None self._waiting = None # type: object - self._compression_options = compression_options + self._compression_options = params.compression_options self._decompressor = None # type: Optional[_PerMessageDeflateDecompressor] self._compressor = None # type: Optional[_PerMessageDeflateCompressor] self._frame_compressed = None # type: Optional[bool] @@ -852,6 +852,8 @@ class WebSocketProtocol13(WebSocketProtocol): self.ping_callback = None # type: Optional[PeriodicCallback] self.last_ping = 0.0 self.last_pong = 0.0 + self.close_code = None # type: Optional[int] + self.close_reason = None # type: Optional[str] # Use a property for this to satisfy the abc. @property @@ -1026,7 +1028,7 @@ class WebSocketProtocol13(WebSocketProtocol): **self._get_compressor_options(side, agreed_parameters, compression_options) ) self._decompressor = _PerMessageDeflateDecompressor( - max_message_size=self.handler.max_message_size, + max_message_size=self.params.max_message_size, **self._get_compressor_options( other_side, agreed_parameters, compression_options ) @@ -1111,7 +1113,7 @@ class WebSocketProtocol13(WebSocketProtocol): yield self._receive_frame() except StreamClosedError: self._abort() - self.handler.on_connection_close() + self.handler.on_ws_connection_close(self.close_code, self.close_reason) def _read_bytes(self, n: int) -> Awaitable[bytes]: self._wire_bytes_in += n @@ -1155,7 +1157,7 @@ class WebSocketProtocol13(WebSocketProtocol): new_len = payloadlen if self._fragmented_message_buffer is not None: new_len += len(self._fragmented_message_buffer) - if new_len > self.handler.max_message_size: + if new_len > self.params.max_message_size: self.close(1009, "message too big") self._abort() return @@ -1232,11 +1234,11 @@ class WebSocketProtocol13(WebSocketProtocol): # Close self.client_terminated = True if len(data) >= 2: - self.handler.close_code = struct.unpack(">H", data[:2])[0] + self.close_code = struct.unpack(">H", data[:2])[0] if len(data) > 2: - self.handler.close_reason = to_unicode(data[2:]) + self.close_reason = to_unicode(data[2:]) # Echo the received close code, if any (RFC 6455 section 5.5.1). - self.close(self.handler.close_code) + self.close(self.close_code) elif opcode == 0x9: # Ping try: @@ -1292,14 +1294,14 @@ class WebSocketProtocol13(WebSocketProtocol): @property def ping_interval(self) -> Optional[float]: - interval = self.handler.ping_interval + interval = self.params.ping_interval if interval is not None: return interval return 0 @property def ping_timeout(self) -> Optional[float]: - timeout = self.handler.ping_timeout + timeout = self.params.ping_timeout if timeout is not None: return timeout assert self.ping_interval is not None @@ -1362,16 +1364,18 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): max_message_size: int = _default_max_message_size, subprotocols: Optional[List[str]] = [], ) -> None: - self.compression_options = compression_options self.connect_future = Future() # type: Future[WebSocketClientConnection] self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] self.key = base64.b64encode(os.urandom(16)) self._on_message_callback = on_message_callback self.close_code = None # type: Optional[int] self.close_reason = None # type: Optional[str] - self.ping_interval = ping_interval - self.ping_timeout = ping_timeout - self.max_message_size = max_message_size + self.params = _WebSocketParams( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + max_message_size=max_message_size, + compression_options=compression_options, + ) scheme, sep, rest = request.url.partition(":") scheme = {"ws": "http", "wss": "https"}[scheme] @@ -1386,7 +1390,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): ) if subprotocols is not None: request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols) - if self.compression_options is not None: + if compression_options is not None: # Always offer to let the server set our max_wbits (and even though # we don't offer it, we will accept a client_no_context_takeover # from the server). @@ -1431,6 +1435,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): self.tcp_client.close() super(WebSocketClientConnection, self).on_connection_close() + def on_ws_connection_close( + self, close_code: int = None, close_reason: str = None + ) -> None: + self.close_code = close_code + self.close_reason = close_reason + self.on_connection_close() + def _on_http_response(self, response: httpclient.HTTPResponse) -> None: if not self.connect_future.done(): if response.error: @@ -1541,9 +1552,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): pass def get_websocket_protocol(self) -> WebSocketProtocol: - return WebSocketProtocol13( - self, mask_outgoing=True, compression_options=self.compression_options - ) + return WebSocketProtocol13(self, mask_outgoing=True, params=self.params) @property def selected_subprotocol(self) -> Optional[str]: