From: James Maier Date: Fri, 23 Dec 2016 16:01:42 +0000 (-0500) Subject: Do WebSocket upgrade using RequestHandler, allowing set_default_headers X-Git-Tag: v4.5.0~12^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=85f3d8bc6f349a38683edc1cb652a70b1c426153;p=thirdparty%2Ftornado.git Do WebSocket upgrade using RequestHandler, allowing set_default_headers --- diff --git a/docs/websocket.rst b/docs/websocket.rst index 836bb3df4..5d36d2349 100644 --- a/docs/websocket.rst +++ b/docs/websocket.rst @@ -22,7 +22,7 @@ .. automethod:: WebSocketHandler.write_message .. automethod:: WebSocketHandler.close - .. automethod:: WebSocketHandler.upgrade_response_headers + .. automethod:: WebSocketHandler.set_default_headers Configuration ------------- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 60e5fd65f..acd61fad1 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -70,12 +70,11 @@ class HeaderHandler(TestWebSocketHandler): class HeaderEchoHandler(TestWebSocketHandler): - def upgrade_response_headers(self): - return ''.join( - "{}: {}\r\n".format(k, v) - for k, v in self.request.headers.get_all() - if k.lower().startswith('x-test') - ) + def set_default_headers(self): + for k, v in self.request.headers.get_all(): + if k.lower().startswith('x-test'): + self.set_header(k, v) + self.set_header("X-Extra-Response-Header", "Extra-Response-Value") class NonWebSocketHandler(RequestHandler): @@ -249,6 +248,7 @@ class WebSocketTest(WebSocketBaseTestCase): self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello') self.assertEqual(ws.headers.get('X-Test-Goodbye'), 'goodbye') self.assertEqual(ws.headers.get('X-Test-Random'), random_str) + self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value') yield self.close(ws) @gen_test diff --git a/tornado/websocket.py b/tornado/websocket.py index 8f5ed7eef..e4a65a928 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -66,7 +66,7 @@ class WebSocketHandler(tornado.web.RequestHandler): connections. Custom upgrade response headers can be sent by overriding - `upgrade_response_headers`. + `set_default_headers`. See http://dev.w3.org/html5/websockets/ for details on the JavaScript interface. The protocol is specified at @@ -127,12 +127,12 @@ class WebSocketHandler(tornado.web.RequestHandler): to accept it before the websocket connection will succeed. """ def __init__(self, application, request, **kwargs): - super(WebSocketHandler, self).__init__(application, request, **kwargs) self.ws_connection = None self.close_code = None self.close_reason = None self.stream = None self._on_close_called = False + super(WebSocketHandler, self).__init__(application, request, **kwargs) @tornado.web.asynchronous def get(self, *args, **kwargs): @@ -179,18 +179,13 @@ class WebSocketHandler(tornado.web.RequestHandler): gen_log.debug(log_msg) return - self.stream = self.request.connection.detach() - self.stream.set_close_callback(self.on_connection_close) - self.ws_connection = self.get_websocket_protocol() if self.ws_connection: self.ws_connection.accept_connection() else: - if not self.stream.closed(): - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 426 Upgrade Required\r\n" - "Sec-WebSocket-Version: 7, 8, 13\r\n\r\n")) - self.stream.close() + self.set_status(426) + self.set_header("Sec-WebSocket-Version", "7, 8, 13") + self.finish() def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -241,11 +236,16 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return None - def upgrade_response_headers(self): - """Override to return additional headers to send in the websocket - upgrade response. + def set_default_headers(self): + """Override this to set HTTP headers at the beginning of the + WebSocket upgrade request. + + For example, this is the place to set a custom ``Server`` header. + Note that setting such headers in the normal flow of request + processing may not do what you want, since headers may be reset + during error handling. """ - return "" + pass def open(self, *args, **kwargs): """Invoked when a new WebSocket is opened. @@ -402,6 +402,10 @@ class WebSocketHandler(tornado.web.RequestHandler): return WebSocketProtocol13( self, compression_options=self.get_compression_options()) + def _attach_stream(self): + self.stream = self.request.connection.detach() + self.stream.set_close_callback(self.on_connection_close) + def _wrap_method(method): def _disallow_for_websocket(self, *args, **kwargs): @@ -578,8 +582,7 @@ class WebSocketProtocol13(WebSocketProtocol): selected = self.handler.select_subprotocol(subprotocols) if selected: assert selected in subprotocols - subprotocol_header = ("Sec-WebSocket-Protocol: %s\r\n" - % selected) + self.handler.set_header("Sec-WebSocket-Protocol", selected) extension_header = '' extensions = self._parse_extensions_header(self.request.headers) @@ -594,23 +597,20 @@ class WebSocketProtocol13(WebSocketProtocol): # Don't echo an offered client_max_window_bits # parameter with no value. del ext[1]['client_max_window_bits'] - extension_header = ('Sec-WebSocket-Extensions: %s\r\n' % - httputil._encode_header( - 'permessage-deflate', ext[1])) + self.handler.set_header("Sec-WebSocket-Extensions", + httputil._encode_header( + 'permessage-deflate', ext[1])) break - if self.stream.closed(): - self._abort() - return - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: websocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: %s\r\n" - "%s%s%s" - "\r\n" % (self._challenge_response(), - subprotocol_header, extension_header, - self.handler.upgrade_response_headers()))) + self.handler.clear_header("Content-Type") + self.handler.set_status(101) + self.handler.set_header("Upgrade", "websocket") + self.handler.set_header("Connection", "Upgrade") + self.handler.set_header("Sec-WebSocket-Accept", self._challenge_response()) + self.handler.finish() + + self.handler._attach_stream() + self.stream = self.handler.stream self._run_callback(self.handler.open, *self.handler.open_args, **self.handler.open_kwargs)