]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Do WebSocket upgrade using RequestHandler, allowing set_default_headers
authorJames Maier <James.Maier@viasat.com>
Fri, 23 Dec 2016 16:01:42 +0000 (11:01 -0500)
committerJames Maier <James.Maier@viasat.com>
Fri, 30 Dec 2016 05:30:39 +0000 (00:30 -0500)
docs/websocket.rst
tornado/test/websocket_test.py
tornado/websocket.py

index 836bb3df48d62d71ce518bb3786b20504b13d01d..5d36d2349fb18164993d380fac5cbcb93d127767 100644 (file)
@@ -22,7 +22,7 @@
 
    .. automethod:: WebSocketHandler.write_message
    .. automethod:: WebSocketHandler.close
-   .. automethod:: WebSocketHandler.upgrade_response_headers
+   .. automethod:: WebSocketHandler.set_default_headers
 
    Configuration
    -------------
index 60e5fd65fba026e75ab084293f886d7fa524bf34..acd61fad14fca6cc5f9a4656646eaf43df6d152b 100644 (file)
@@ -70,12 +70,11 @@ class HeaderHandler(TestWebSocketHandler):
 
 
 class HeaderEchoHandler(TestWebSocketHandler):
-    def upgrade_response_headers(self):
-        return ''.join(
-            "{}: {}\r\n".format(k, v)
-            for k, v in self.request.headers.get_all()
-            if k.lower().startswith('x-test')
-        )
+    def set_default_headers(self):
+        for k, v in self.request.headers.get_all():
+            if k.lower().startswith('x-test'):
+                self.set_header(k, v)
+        self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
 
 
 class NonWebSocketHandler(RequestHandler):
@@ -249,6 +248,7 @@ class WebSocketTest(WebSocketBaseTestCase):
         self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
         self.assertEqual(ws.headers.get('X-Test-Goodbye'), 'goodbye')
         self.assertEqual(ws.headers.get('X-Test-Random'), random_str)
+        self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
         yield self.close(ws)
 
     @gen_test
index 8f5ed7eef99062b370f915e764748f38ba43af13..e4a65a9287eec947a3202b92bd00387e76169715 100644 (file)
@@ -66,7 +66,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
     connections.
 
     Custom upgrade response headers can be sent by overriding
-    `upgrade_response_headers`.
+    `set_default_headers`.
 
     See http://dev.w3.org/html5/websockets/ for details on the
     JavaScript interface.  The protocol is specified at
@@ -127,12 +127,12 @@ class WebSocketHandler(tornado.web.RequestHandler):
     to accept it before the websocket connection will succeed.
     """
     def __init__(self, application, request, **kwargs):
-        super(WebSocketHandler, self).__init__(application, request, **kwargs)
         self.ws_connection = None
         self.close_code = None
         self.close_reason = None
         self.stream = None
         self._on_close_called = False
+        super(WebSocketHandler, self).__init__(application, request, **kwargs)
 
     @tornado.web.asynchronous
     def get(self, *args, **kwargs):
@@ -179,18 +179,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
             gen_log.debug(log_msg)
             return
 
-        self.stream = self.request.connection.detach()
-        self.stream.set_close_callback(self.on_connection_close)
-
         self.ws_connection = self.get_websocket_protocol()
         if self.ws_connection:
             self.ws_connection.accept_connection()
         else:
-            if not self.stream.closed():
-                self.stream.write(tornado.escape.utf8(
-                    "HTTP/1.1 426 Upgrade Required\r\n"
-                    "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n"))
-                self.stream.close()
+            self.set_status(426)
+            self.set_header("Sec-WebSocket-Version", "7, 8, 13")
+            self.finish()
 
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket.
@@ -241,11 +236,16 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         return None
 
-    def upgrade_response_headers(self):
-        """Override to return additional headers to send in the websocket
-           upgrade response.
+    def set_default_headers(self):
+        """Override this to set HTTP headers at the beginning of the
+           WebSocket upgrade request.
+
+        For example, this is the place to set a custom ``Server`` header.
+        Note that setting such headers in the normal flow of request
+        processing may not do what you want, since headers may be reset
+        during error handling.
         """
-        return ""
+        pass
 
     def open(self, *args, **kwargs):
         """Invoked when a new WebSocket is opened.
@@ -402,6 +402,10 @@ class WebSocketHandler(tornado.web.RequestHandler):
             return WebSocketProtocol13(
                 self, compression_options=self.get_compression_options())
 
+    def _attach_stream(self):
+        self.stream = self.request.connection.detach()
+        self.stream.set_close_callback(self.on_connection_close)
+
 
 def _wrap_method(method):
     def _disallow_for_websocket(self, *args, **kwargs):
@@ -578,8 +582,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             selected = self.handler.select_subprotocol(subprotocols)
             if selected:
                 assert selected in subprotocols
-                subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n"
-                                      % selected)
+                self.handler.set_header("Sec-WebSocket-Protocol", selected)
 
         extension_header = ''
         extensions = self._parse_extensions_header(self.request.headers)
@@ -594,23 +597,20 @@ class WebSocketProtocol13(WebSocketProtocol):
                     # Don't echo an offered client_max_window_bits
                     # parameter with no value.
                     del ext[1]['client_max_window_bits']
-                extension_header = ('Sec-WebSocket-Extensions: %s\r\n' %
-                                    httputil._encode_header(
-                                        'permessage-deflate', ext[1]))
+                self.handler.set_header("Sec-WebSocket-Extensions",
+                                        httputil._encode_header(
+                                            'permessage-deflate', ext[1]))
                 break
 
-        if self.stream.closed():
-            self._abort()
-            return
-        self.stream.write(tornado.escape.utf8(
-            "HTTP/1.1 101 Switching Protocols\r\n"
-            "Upgrade: websocket\r\n"
-            "Connection: Upgrade\r\n"
-            "Sec-WebSocket-Accept: %s\r\n"
-            "%s%s%s"
-            "\r\n" % (self._challenge_response(),
-                      subprotocol_header, extension_header,
-                      self.handler.upgrade_response_headers())))
+        self.handler.clear_header("Content-Type")
+        self.handler.set_status(101)
+        self.handler.set_header("Upgrade", "websocket")
+        self.handler.set_header("Connection", "Upgrade")
+        self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response())
+        self.handler.finish()
+
+        self.handler._attach_stream()
+        self.stream = self.handler.stream
 
         self._run_callback(self.handler.open, *self.handler.open_args,
                            **self.handler.open_kwargs)