From: Ben Darnell Date: Sat, 12 May 2018 18:43:58 +0000 (-0400) Subject: websocket: Improve subprotocol support X-Git-Tag: v5.1.0b1~18^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fac04e0acfb4a79a819c4905380b95f545e3719b;p=thirdparty%2Ftornado.git websocket: Improve subprotocol support - Add client-side subprotocol option - Add selected_subprotocol attribute to client and server objects - Call select_subprotocol exactly once instead of only on non-empty - Fix bug in previous select_subprotocol change when multiple subprotocols are offered - Add tests Updates #2281 --- diff --git a/docs/releases/v5.1.0.rst b/docs/releases/v5.1.0.rst index feeaec3fd..be2c244df 100644 --- a/docs/releases/v5.1.0.rst +++ b/docs/releases/v5.1.0.rst @@ -135,9 +135,12 @@ Deprecation notice `tornado.websocket` ~~~~~~~~~~~~~~~~~~~ -- The `.WebSocketHandler.select_subprotocol` method is now called only - when a subprotocol header is provided (previously it would be called - with a list containing an empty string). +- `.websocket_connect` now supports subprotocols. +- `.WebSocketHandler` and `.WebSocketClientConnection` now have + ``selected_subprotocol`` attributes to see the subprotocol in use. +- The `.WebSocketHandler.select_subprotocol` method is now called with + an empty list instead of a list containing an empty string if no + subprotocols were requested by the client. - The ``data`` argument to `.WebSocketHandler.ping` is now optional. - Client-side websocket connections no longer buffer more than one message in memory at a time. diff --git a/docs/websocket.rst b/docs/websocket.rst index 96255589b..76bc05227 100644 --- a/docs/websocket.rst +++ b/docs/websocket.rst @@ -16,6 +16,7 @@ .. automethod:: WebSocketHandler.on_message .. automethod:: WebSocketHandler.on_close .. automethod:: WebSocketHandler.select_subprotocol + .. autoattribute:: WebSocketHandler.selected_subprotocol .. automethod:: WebSocketHandler.on_ping Output diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 4fb918ec9..ecb7123f9 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -143,6 +143,25 @@ class RenderMessageHandler(TestWebSocketHandler): self.write_message(self.render_string('message.html', message=message)) +class SubprotocolHandler(TestWebSocketHandler): + def initialize(self, **kwargs): + super(SubprotocolHandler, self).initialize(**kwargs) + self.select_subprotocol_called = False + + def select_subprotocol(self, subprotocols): + if self.select_subprotocol_called: + raise Exception("select_subprotocol called twice") + self.select_subprotocol_called = True + if 'goodproto' in subprotocols: + return 'goodproto' + return None + + def open(self): + if not self.select_subprotocol_called: + raise Exception("select_subprotocol not called") + self.write_message("subprotocol=%s" % self.selected_subprotocol) + + class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): @@ -183,6 +202,8 @@ class WebSocketTest(WebSocketBaseTestCase): dict(close_future=self.close_future)), ('/render', RenderMessageHandler, dict(close_future=self.close_future)), + ('/subprotocol', SubprotocolHandler, + dict(close_future=self.close_future)), ], template_loader=DictLoader({ 'message.html': '{{ message }}', })) @@ -443,6 +464,22 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(cm.exception.code, 403) + @gen_test + def test_subprotocols(self): + ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto']) + self.assertEqual(ws.selected_subprotocol, 'goodproto') + res = yield ws.read_message() + self.assertEqual(res, 'subprotocol=goodproto') + yield self.close(ws) + + @gen_test + def test_subprotocols_not_offered(self): + ws = yield self.ws_connect('/subprotocol') + self.assertIs(ws.selected_subprotocol, None) + res = yield ws.read_message() + self.assertEqual(res, 'subprotocol=None') + yield self.close(ws) + if sys.version_info >= (3, 5): NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """ diff --git a/tornado/websocket.py b/tornado/websocket.py index f01572d9a..e77e6623d 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -256,18 +256,38 @@ class WebSocketHandler(tornado.web.RequestHandler): return self.ws_connection.write_message(message, binary=binary) def select_subprotocol(self, subprotocols): - """Invoked when a new WebSocket requests specific subprotocols. + """Override to implement subprotocol negotiation. ``subprotocols`` is a list of strings identifying the subprotocols proposed by the client. This method may be overridden to return one of those strings to select it, or - ``None`` to not select a subprotocol. Failure to select a - subprotocol does not automatically abort the connection, - although clients may close the connection if none of their - proposed subprotocols was selected. + ``None`` to not select a subprotocol. + + Failure to select a subprotocol does not automatically abort + the connection, although clients may close the connection if + none of their proposed subprotocols was selected. + + The list may be empty, in which case this method must return + None. This method is always called exactly once even if no + subprotocols were proposed so that the handler can be advised + of this fact. + + .. versionchanged:: 5.1 + + Previously, this method was called with a list containing + an empty string instead of an empty list if no subprotocols + were proposed by the client. """ return None + @property + def selected_subprotocol(self): + """The subprotocol returned by `select_subprotocol`. + + .. versionadded:: 5.1 + """ + return self.ws_connection.selected_subprotocol + def get_compression_options(self): """Override to return compression options for the connection. @@ -675,12 +695,15 @@ class WebSocketProtocol13(WebSocketProtocol): self.request.headers.get("Sec-Websocket-Key")) def _accept_connection(self): - subprotocols = [s.strip() for s in self.request.headers.get_list("Sec-WebSocket-Protocol")] - if subprotocols: - selected = self.handler.select_subprotocol(subprotocols) - if selected: - assert selected in subprotocols - self.handler.set_header("Sec-WebSocket-Protocol", selected) + subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol") + if subprotocol_header: + subprotocols = [s.strip() for s in subprotocol_header.split(',')] + else: + subprotocols = [] + self.selected_subprotocol = self.handler.select_subprotocol(subprotocols) + if self.selected_subprotocol: + assert self.selected_subprotocol in subprotocols + self.handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol) extensions = self._parse_extensions_header(self.request.headers) for ext in extensions: @@ -739,6 +762,8 @@ class WebSocketProtocol13(WebSocketProtocol): else: raise ValueError("unsupported extension %r", ext) + self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None) + def _get_compressor_options(self, side, agreed_parameters, compression_options=None): """Converts a websocket agreed_parameters set to keyword arguments for our compressor objects. @@ -1056,7 +1081,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): """ def __init__(self, request, on_message_callback=None, compression_options=None, ping_interval=None, ping_timeout=None, - max_message_size=None): + max_message_size=None, subprotocols=[]): self.compression_options = compression_options self.connect_future = Future() self.protocol = None @@ -1077,6 +1102,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): 'Sec-WebSocket-Key': self.key, 'Sec-WebSocket-Version': '13', }) + if subprotocols is not None: + request.headers['Sec-WebSocket-Protocol'] = ','.join(subprotocols) if self.compression_options is not None: # Always offer to let the server set our max_wbits (and even though # we don't offer it, we will accept a client_no_context_takeover @@ -1211,11 +1238,19 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): return WebSocketProtocol13(self, mask_outgoing=True, compression_options=self.compression_options) + @property + def selected_subprotocol(self): + """The subprotocol selected by the server. + + .. versionadded:: 5.1 + """ + return self.protocol.selected_subprotocol + def websocket_connect(url, callback=None, connect_timeout=None, on_message_callback=None, compression_options=None, ping_interval=None, ping_timeout=None, - max_message_size=None): + max_message_size=None, subprotocols=None): """Client-side websocket support. Takes a url and returns a Future whose result is a @@ -1238,6 +1273,11 @@ def websocket_connect(url, callback=None, connect_timeout=None, ``websocket_connect``. In both styles, a message of ``None`` indicates that the connection has been closed. + ``subprotocols`` may be a list of strings specifying proposed + subprotocols. The selected protocol may be found on the + ``selected_subprotocol`` attribute of the connection object + when the connection is complete. + .. versionchanged:: 3.2 Also accepts ``HTTPRequest`` objects in place of urls. @@ -1250,6 +1290,9 @@ def websocket_connect(url, callback=None, connect_timeout=None, .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. + + .. versionchanged:: 5.1 + Added the ``subprotocols`` argument. """ if isinstance(url, httpclient.HTTPRequest): assert connect_timeout is None @@ -1266,7 +1309,8 @@ def websocket_connect(url, callback=None, connect_timeout=None, compression_options=compression_options, ping_interval=ping_interval, ping_timeout=ping_timeout, - max_message_size=max_message_size) + max_message_size=max_message_size, + subprotocols=subprotocols) if callback is not None: IOLoop.current().add_future(conn.connect_future, callback) return conn.connect_future