]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
websocket: Improve subprotocol support
authorBen Darnell <ben@bendarnell.com>
Sat, 12 May 2018 18:43:58 +0000 (14:43 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 12 May 2018 18:43:58 +0000 (14:43 -0400)
- Add client-side subprotocol option
- Add selected_subprotocol attribute to client and server objects
- Call select_subprotocol exactly once instead of only on non-empty
- Fix bug in previous select_subprotocol change when multiple
  subprotocols are offered
- Add tests

Updates #2281

docs/releases/v5.1.0.rst
docs/websocket.rst
tornado/test/websocket_test.py
tornado/websocket.py

index feeaec3fdd6a99b7b6a44040137283ab90f5e4c5..be2c244dfe54a0a289986e69ab07be0bff958fe6 100644 (file)
@@ -135,9 +135,12 @@ Deprecation notice
 `tornado.websocket`
 ~~~~~~~~~~~~~~~~~~~
 
-- The `.WebSocketHandler.select_subprotocol` method is now called only
-  when a subprotocol header is provided (previously it would be called
-  with a list containing an empty string).
+- `.websocket_connect` now supports subprotocols.
+- `.WebSocketHandler` and `.WebSocketClientConnection` now have
+  ``selected_subprotocol`` attributes to see the subprotocol in use.
+- The `.WebSocketHandler.select_subprotocol` method is now called with
+  an empty list instead of a list containing an empty string if no
+  subprotocols were requested by the client.
 - The ``data`` argument to `.WebSocketHandler.ping` is now optional.
 - Client-side websocket connections no longer buffer more than one
   message in memory at a time.
index 96255589bed3d92aacb2ea72ddb71032c8488814..76bc05227e115482f88f856af76450bcd7da8fb2 100644 (file)
@@ -16,6 +16,7 @@
    .. automethod:: WebSocketHandler.on_message
    .. automethod:: WebSocketHandler.on_close
    .. automethod:: WebSocketHandler.select_subprotocol
+   .. autoattribute:: WebSocketHandler.selected_subprotocol
    .. automethod:: WebSocketHandler.on_ping
 
    Output
index 4fb918ec946f0bd64edf6f6d9e3b787fb5bf26f9..ecb7123f977ac5e481b0b32f12a733f22c4e5f31 100644 (file)
@@ -143,6 +143,25 @@ class RenderMessageHandler(TestWebSocketHandler):
         self.write_message(self.render_string('message.html', message=message))
 
 
+class SubprotocolHandler(TestWebSocketHandler):
+    def initialize(self, **kwargs):
+        super(SubprotocolHandler, self).initialize(**kwargs)
+        self.select_subprotocol_called = False
+
+    def select_subprotocol(self, subprotocols):
+        if self.select_subprotocol_called:
+            raise Exception("select_subprotocol called twice")
+        self.select_subprotocol_called = True
+        if 'goodproto' in subprotocols:
+            return 'goodproto'
+        return None
+
+    def open(self):
+        if not self.select_subprotocol_called:
+            raise Exception("select_subprotocol not called")
+        self.write_message("subprotocol=%s" % self.selected_subprotocol)
+
+
 class WebSocketBaseTestCase(AsyncHTTPTestCase):
     @gen.coroutine
     def ws_connect(self, path, **kwargs):
@@ -183,6 +202,8 @@ class WebSocketTest(WebSocketBaseTestCase):
              dict(close_future=self.close_future)),
             ('/render', RenderMessageHandler,
              dict(close_future=self.close_future)),
+            ('/subprotocol', SubprotocolHandler,
+             dict(close_future=self.close_future)),
         ], template_loader=DictLoader({
             'message.html': '<b>{{ message }}</b>',
         }))
@@ -443,6 +464,22 @@ class WebSocketTest(WebSocketBaseTestCase):
 
         self.assertEqual(cm.exception.code, 403)
 
+    @gen_test
+    def test_subprotocols(self):
+        ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
+        self.assertEqual(ws.selected_subprotocol, 'goodproto')
+        res = yield ws.read_message()
+        self.assertEqual(res, 'subprotocol=goodproto')
+        yield self.close(ws)
+
+    @gen_test
+    def test_subprotocols_not_offered(self):
+        ws = yield self.ws_connect('/subprotocol')
+        self.assertIs(ws.selected_subprotocol, None)
+        res = yield ws.read_message()
+        self.assertEqual(res, 'subprotocol=None')
+        yield self.close(ws)
+
 
 if sys.version_info >= (3, 5):
     NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
index f01572d9a1e79c86c23c5c20898be5f06cb5555b..e77e6623dc36f4f2c937a51e2559ba00234e3bc9 100644 (file)
@@ -256,18 +256,38 @@ class WebSocketHandler(tornado.web.RequestHandler):
         return self.ws_connection.write_message(message, binary=binary)
 
     def select_subprotocol(self, subprotocols):
-        """Invoked when a new WebSocket requests specific subprotocols.
+        """Override to implement subprotocol negotiation.
 
         ``subprotocols`` is a list of strings identifying the
         subprotocols proposed by the client.  This method may be
         overridden to return one of those strings to select it, or
-        ``None`` to not select a subprotocol.  Failure to select a
-        subprotocol does not automatically abort the connection,
-        although clients may close the connection if none of their
-        proposed subprotocols was selected.
+        ``None`` to not select a subprotocol.
+
+        Failure to select a subprotocol does not automatically abort
+        the connection, although clients may close the connection if
+        none of their proposed subprotocols was selected.
+
+        The list may be empty, in which case this method must return
+        None. This method is always called exactly once even if no
+        subprotocols were proposed so that the handler can be advised
+        of this fact.
+
+        .. versionchanged:: 5.1
+
+           Previously, this method was called with a list containing
+           an empty string instead of an empty list if no subprotocols
+           were proposed by the client.
         """
         return None
 
+    @property
+    def selected_subprotocol(self):
+        """The subprotocol returned by `select_subprotocol`.
+
+        .. versionadded:: 5.1
+        """
+        return self.ws_connection.selected_subprotocol
+
     def get_compression_options(self):
         """Override to return compression options for the connection.
 
@@ -675,12 +695,15 @@ class WebSocketProtocol13(WebSocketProtocol):
             self.request.headers.get("Sec-Websocket-Key"))
 
     def _accept_connection(self):
-        subprotocols = [s.strip() for s in self.request.headers.get_list("Sec-WebSocket-Protocol")]
-        if subprotocols:
-            selected = self.handler.select_subprotocol(subprotocols)
-            if selected:
-                assert selected in subprotocols
-                self.handler.set_header("Sec-WebSocket-Protocol", selected)
+        subprotocol_header = self.request.headers.get("Sec-WebSocket-Protocol")
+        if subprotocol_header:
+            subprotocols = [s.strip() for s in subprotocol_header.split(',')]
+        else:
+            subprotocols = []
+        self.selected_subprotocol = self.handler.select_subprotocol(subprotocols)
+        if self.selected_subprotocol:
+            assert self.selected_subprotocol in subprotocols
+            self.handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)
 
         extensions = self._parse_extensions_header(self.request.headers)
         for ext in extensions:
@@ -739,6 +762,8 @@ class WebSocketProtocol13(WebSocketProtocol):
             else:
                 raise ValueError("unsupported extension %r", ext)
 
+        self.selected_subprotocol = headers.get('Sec-WebSocket-Protocol', None)
+
     def _get_compressor_options(self, side, agreed_parameters, compression_options=None):
         """Converts a websocket agreed_parameters set to keyword arguments
         for our compressor objects.
@@ -1056,7 +1081,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     """
     def __init__(self, request, on_message_callback=None,
                  compression_options=None, ping_interval=None, ping_timeout=None,
-                 max_message_size=None):
+                 max_message_size=None, subprotocols=[]):
         self.compression_options = compression_options
         self.connect_future = Future()
         self.protocol = None
@@ -1077,6 +1102,8 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             'Sec-WebSocket-Key': self.key,
             'Sec-WebSocket-Version': '13',
         })
+        if subprotocols is not None:
+            request.headers['Sec-WebSocket-Protocol'] = ','.join(subprotocols)
         if self.compression_options is not None:
             # Always offer to let the server set our max_wbits (and even though
             # we don't offer it, we will accept a client_no_context_takeover
@@ -1211,11 +1238,19 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
         return WebSocketProtocol13(self, mask_outgoing=True,
                                    compression_options=self.compression_options)
 
+    @property
+    def selected_subprotocol(self):
+        """The subprotocol selected by the server.
+
+        .. versionadded:: 5.1
+        """
+        return self.protocol.selected_subprotocol
+
 
 def websocket_connect(url, callback=None, connect_timeout=None,
                       on_message_callback=None, compression_options=None,
                       ping_interval=None, ping_timeout=None,
-                      max_message_size=None):
+                      max_message_size=None, subprotocols=None):
     """Client-side websocket support.
 
     Takes a url and returns a Future whose result is a
@@ -1238,6 +1273,11 @@ def websocket_connect(url, callback=None, connect_timeout=None,
     ``websocket_connect``. In both styles, a message of ``None``
     indicates that the connection has been closed.
 
+    ``subprotocols`` may be a list of strings specifying proposed
+    subprotocols. The selected protocol may be found on the
+    ``selected_subprotocol`` attribute of the connection object
+    when the connection is complete.
+
     .. versionchanged:: 3.2
        Also accepts ``HTTPRequest`` objects in place of urls.
 
@@ -1250,6 +1290,9 @@ def websocket_connect(url, callback=None, connect_timeout=None,
 
     .. versionchanged:: 5.0
        The ``io_loop`` argument (deprecated since version 4.1) has been removed.
+
+    .. versionchanged:: 5.1
+       Added the ``subprotocols`` argument.
     """
     if isinstance(url, httpclient.HTTPRequest):
         assert connect_timeout is None
@@ -1266,7 +1309,8 @@ def websocket_connect(url, callback=None, connect_timeout=None,
                                      compression_options=compression_options,
                                      ping_interval=ping_interval,
                                      ping_timeout=ping_timeout,
-                                     max_message_size=max_message_size)
+                                     max_message_size=max_message_size,
+                                     subprotocols=subprotocols)
     if callback is not None:
         IOLoop.current().add_future(conn.connect_future, callback)
     return conn.connect_future