]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Allow custom websocket upgrade response headers via upgrade_response_headers
authorJames Maier <James.Maier@viasat.com>
Tue, 20 Dec 2016 20:04:12 +0000 (15:04 -0500)
committerJames Maier <James.Maier@viasat.com>
Tue, 20 Dec 2016 20:04:12 +0000 (15:04 -0500)
tornado/test/websocket_test.py
tornado/websocket.py

index ed5c7070fc43fcf8fd9fb2b0fc8b575d35c5fd7b..60e5fd65fba026e75ab084293f886d7fa524bf34 100644 (file)
@@ -1,5 +1,7 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
+import random
+import string
 import traceback
 
 from tornado.concurrent import Future
@@ -67,6 +69,15 @@ class HeaderHandler(TestWebSocketHandler):
         self.write_message(self.request.headers.get('X-Test', ''))
 
 
+class HeaderEchoHandler(TestWebSocketHandler):
+    def upgrade_response_headers(self):
+        return ''.join(
+            "{}: {}\r\n".format(k, v)
+            for k, v in self.request.headers.get_all()
+            if k.lower().startswith('x-test')
+        )
+
+
 class NonWebSocketHandler(RequestHandler):
     def get(self):
         self.write('ok')
@@ -118,6 +129,8 @@ class WebSocketTest(WebSocketBaseTestCase):
             ('/echo', EchoHandler, dict(close_future=self.close_future)),
             ('/non_ws', NonWebSocketHandler),
             ('/header', HeaderHandler, dict(close_future=self.close_future)),
+            ('/header_echo', HeaderEchoHandler,
+             dict(close_future=self.close_future)),
             ('/close_reason', CloseReasonHandler,
              dict(close_future=self.close_future)),
             ('/error_in_on_message', ErrorInOnMessageHandler,
@@ -221,6 +234,23 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(response, 'hello')
         yield self.close(ws)
 
+    @gen_test
+    def test_websocket_header_echo(self):
+        # Ensure that headers can be returned in the response.
+        # Specifically, that arbitrary headers passed through websocket_connect
+        # can be returned.
+        random_str = ''.join(random.choice(string.ascii_lowercase)
+                             for i in range(10))
+        ws = yield websocket_connect(
+            HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
+                        headers={'X-Test-Hello': 'hello',
+                                 'X-Test-Goodbye': 'goodbye',
+                                 'X-Test-Random': random_str}))
+        self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
+        self.assertEqual(ws.headers.get('X-Test-Goodbye'), 'goodbye')
+        self.assertEqual(ws.headers.get('X-Test-Random'), random_str)
+        yield self.close(ws)
+
     @gen_test
     def test_server_close_reason(self):
         ws = yield self.ws_connect('/close_reason')
index 6e1220b3ecbcf6c13a71b8e07c020aff6ea15e49..8f5ed7eef99062b370f915e764748f38ba43af13 100644 (file)
@@ -65,6 +65,9 @@ class WebSocketHandler(tornado.web.RequestHandler):
     override `open` and `on_close` to handle opened and closed
     connections.
 
+    Custom upgrade response headers can be sent by overriding
+    `upgrade_response_headers`.
+
     See http://dev.w3.org/html5/websockets/ for details on the
     JavaScript interface.  The protocol is specified at
     http://tools.ietf.org/html/rfc6455.
@@ -238,6 +241,12 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         return None
 
+    def upgrade_response_headers(self):
+        """Override to return additional headers to send in the websocket
+           upgrade response.
+        """
+        return ""
+
     def open(self, *args, **kwargs):
         """Invoked when a new WebSocket is opened.
 
@@ -598,9 +607,10 @@ class WebSocketProtocol13(WebSocketProtocol):
             "Upgrade: websocket\r\n"
             "Connection: Upgrade\r\n"
             "Sec-WebSocket-Accept: %s\r\n"
-            "%s%s"
+            "%s%s%s"
             "\r\n" % (self._challenge_response(),
-                      subprotocol_header, extension_header)))
+                      subprotocol_header, extension_header,
+                      self.handler.upgrade_response_headers())))
 
         self._run_callback(self.handler.open, *self.handler.open_args,
                            **self.handler.open_kwargs)