]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Narrow the websocket handler interface
authorBen Darnell <ben@bendarnell.com>
Sun, 21 Oct 2018 19:13:56 +0000 (15:13 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Dec 2018 03:17:57 +0000 (22:17 -0500)
Pass close arguments via a new method instead of setting attributes.
Extract a "params" struct.

tornado/websocket.py

index efede9b51ddcd6270d46761d4878232986341f26..e72b2b456f2beddb63db9dc884c23c087dc6784a 100644 (file)
@@ -72,39 +72,13 @@ if TYPE_CHECKING:
         def decompress(self, data: bytes, max_length: int) -> bytes:
             pass
 
-    class _WebSocketConnection(Protocol):
+    class _WebSocketDelegate(Protocol):
         # The common base interface implemented by WebSocketHandler on
         # the server side and WebSocketClientConnection on the client
         # side.
-        @property
-        def ping_interval(self) -> Optional[float]:
-            pass
-
-        @property
-        def ping_timeout(self) -> Optional[float]:
-            pass
-
-        @property
-        def max_message_size(self) -> int:
-            pass
-
-        @property
-        def close_code(self) -> Optional[int]:
-            pass
-
-        @close_code.setter
-        def close_code(self, value: Optional[int]) -> None:
-            pass
-
-        @property
-        def close_reason(self) -> Optional[str]:
-            pass
-
-        @close_reason.setter
-        def close_reason(self, value: Optional[str]) -> None:
-            pass
-
-        def on_connection_close(self) -> None:
+        def on_ws_connection_close(
+            self, close_code: int = None, close_reason: str = None
+        ) -> None:
             pass
 
         def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
@@ -145,6 +119,20 @@ class _DecompressTooLargeError(Exception):
     pass
 
 
+class _WebSocketParams(object):
+    def __init__(
+        self,
+        ping_interval: float = None,
+        ping_timeout: float = None,
+        max_message_size: int = _default_max_message_size,
+        compression_options: Dict[str, Any] = None,
+    ) -> None:
+        self.ping_interval = ping_interval
+        self.ping_timeout = ping_timeout
+        self.max_message_size = max_message_size
+        self.compression_options = compression_options
+
+
 class WebSocketHandler(tornado.web.RequestHandler):
     """Subclass this class to create a basic WebSocket handler.
 
@@ -583,6 +571,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.on_close()
             self._break_cycles()
 
+    def on_ws_connection_close(
+        self, close_code: int = None, close_reason: str = None
+    ) -> None:
+        self.close_code = close_code
+        self.close_reason = close_reason
+        self.on_connection_close()
+
     def _break_cycles(self) -> None:
         # WebSocketHandlers call finish() early, but we don't want to
         # break up reference cycles (which makes it impossible to call
@@ -605,9 +600,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
     def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
         websocket_version = self.request.headers.get("Sec-WebSocket-Version")
         if websocket_version in ("7", "8", "13"):
-            return WebSocketProtocol13(
-                self, compression_options=self.get_compression_options()
+            params = _WebSocketParams(
+                ping_interval=self.ping_interval,
+                ping_timeout=self.ping_timeout,
+                max_message_size=self.max_message_size,
+                compression_options=self.get_compression_options(),
             )
+            return WebSocketProtocol13(self, False, params)
         return None
 
     def _detach_stream(self) -> IOStream:
@@ -633,7 +632,7 @@ class WebSocketProtocol(abc.ABC):
     """Base class for WebSocket protocol versions.
     """
 
-    def __init__(self, handler: "_WebSocketConnection") -> None:
+    def __init__(self, handler: "_WebSocketDelegate") -> None:
         self.handler = handler
         self.stream = None  # type: Optional[IOStream]
         self.client_terminated = False
@@ -822,12 +821,13 @@ class WebSocketProtocol13(WebSocketProtocol):
 
     def __init__(
         self,
-        handler: "_WebSocketConnection",
-        mask_outgoing: bool = False,
-        compression_options: Dict[str, Any] = None,
+        handler: "_WebSocketDelegate",
+        mask_outgoing: bool,
+        params: _WebSocketParams,
     ) -> None:
         WebSocketProtocol.__init__(self, handler)
         self.mask_outgoing = mask_outgoing
+        self.params = params
         self._final_frame = False
         self._frame_opcode = None
         self._masked_frame = None
@@ -836,7 +836,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         self._fragmented_message_buffer = None
         self._fragmented_message_opcode = None
         self._waiting = None  # type: object
-        self._compression_options = compression_options
+        self._compression_options = params.compression_options
         self._decompressor = None  # type: Optional[_PerMessageDeflateDecompressor]
         self._compressor = None  # type: Optional[_PerMessageDeflateCompressor]
         self._frame_compressed = None  # type: Optional[bool]
@@ -852,6 +852,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         self.ping_callback = None  # type: Optional[PeriodicCallback]
         self.last_ping = 0.0
         self.last_pong = 0.0
+        self.close_code = None  # type: Optional[int]
+        self.close_reason = None  # type: Optional[str]
 
     # Use a property for this to satisfy the abc.
     @property
@@ -1026,7 +1028,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             **self._get_compressor_options(side, agreed_parameters, compression_options)
         )
         self._decompressor = _PerMessageDeflateDecompressor(
-            max_message_size=self.handler.max_message_size,
+            max_message_size=self.params.max_message_size,
             **self._get_compressor_options(
                 other_side, agreed_parameters, compression_options
             )
@@ -1111,7 +1113,7 @@ class WebSocketProtocol13(WebSocketProtocol):
                 yield self._receive_frame()
         except StreamClosedError:
             self._abort()
-        self.handler.on_connection_close()
+        self.handler.on_ws_connection_close(self.close_code, self.close_reason)
 
     def _read_bytes(self, n: int) -> Awaitable[bytes]:
         self._wire_bytes_in += n
@@ -1155,7 +1157,7 @@ class WebSocketProtocol13(WebSocketProtocol):
         new_len = payloadlen
         if self._fragmented_message_buffer is not None:
             new_len += len(self._fragmented_message_buffer)
-        if new_len > self.handler.max_message_size:
+        if new_len > self.params.max_message_size:
             self.close(1009, "message too big")
             self._abort()
             return
@@ -1232,11 +1234,11 @@ class WebSocketProtocol13(WebSocketProtocol):
             # Close
             self.client_terminated = True
             if len(data) >= 2:
-                self.handler.close_code = struct.unpack(">H", data[:2])[0]
+                self.close_code = struct.unpack(">H", data[:2])[0]
             if len(data) > 2:
-                self.handler.close_reason = to_unicode(data[2:])
+                self.close_reason = to_unicode(data[2:])
             # Echo the received close code, if any (RFC 6455 section 5.5.1).
-            self.close(self.handler.close_code)
+            self.close(self.close_code)
         elif opcode == 0x9:
             # Ping
             try:
@@ -1292,14 +1294,14 @@ class WebSocketProtocol13(WebSocketProtocol):
 
     @property
     def ping_interval(self) -> Optional[float]:
-        interval = self.handler.ping_interval
+        interval = self.params.ping_interval
         if interval is not None:
             return interval
         return 0
 
     @property
     def ping_timeout(self) -> Optional[float]:
-        timeout = self.handler.ping_timeout
+        timeout = self.params.ping_timeout
         if timeout is not None:
             return timeout
         assert self.ping_interval is not None
@@ -1362,16 +1364,18 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         max_message_size: int = _default_max_message_size,
         subprotocols: Optional[List[str]] = [],
     ) -> None:
-        self.compression_options = compression_options
         self.connect_future = Future()  # type: Future[WebSocketClientConnection]
         self.read_queue = Queue(1)  # type: Queue[Union[None, str, bytes]]
         self.key = base64.b64encode(os.urandom(16))
         self._on_message_callback = on_message_callback
         self.close_code = None  # type: Optional[int]
         self.close_reason = None  # type: Optional[str]
-        self.ping_interval = ping_interval
-        self.ping_timeout = ping_timeout
-        self.max_message_size = max_message_size
+        self.params = _WebSocketParams(
+            ping_interval=ping_interval,
+            ping_timeout=ping_timeout,
+            max_message_size=max_message_size,
+            compression_options=compression_options,
+        )
 
         scheme, sep, rest = request.url.partition(":")
         scheme = {"ws": "http", "wss": "https"}[scheme]
@@ -1386,7 +1390,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         )
         if subprotocols is not None:
             request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
-        if self.compression_options is not None:
+        if compression_options is not None:
             # Always offer to let the server set our max_wbits (and even though
             # we don't offer it, we will accept a client_no_context_takeover
             # from the server).
@@ -1431,6 +1435,13 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         self.tcp_client.close()
         super(WebSocketClientConnection, self).on_connection_close()
 
+    def on_ws_connection_close(
+        self, close_code: int = None, close_reason: str = None
+    ) -> None:
+        self.close_code = close_code
+        self.close_reason = close_reason
+        self.on_connection_close()
+
     def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
         if not self.connect_future.done():
             if response.error:
@@ -1541,9 +1552,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         pass
 
     def get_websocket_protocol(self) -> WebSocketProtocol:
-        return WebSocketProtocol13(
-            self, mask_outgoing=True, compression_options=self.compression_options
-        )
+        return WebSocketProtocol13(self, mask_outgoing=True, params=self.params)
 
     @property
     def selected_subprotocol(self) -> Optional[str]: