From: Florian Diebold Date: Sat, 9 Jul 2011 20:33:21 +0000 (+0200) Subject: Refactor WebSocket support to prepare for multiple protocol versions. X-Git-Tag: v2.1.0~78^2~3^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=48e60770c5e27c06de7e8c176bff1f34418e566e;p=thirdparty%2Ftornado.git Refactor WebSocket support to prepare for multiple protocol versions. All protocol-specific functions are moved to the former WebSocketRequest class, which is renamed to WebSocketProtocol76. The WebSocketHandler chooses the right WebSocketProtocol implementation by looking at the request headers. --- diff --git a/tornado/websocket.py b/tornado/websocket.py index 3fcde801f..35bf3a640 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -69,59 +69,19 @@ class WebSocketHandler(tornado.web.RequestHandler): tornado.web.RequestHandler.__init__(self, application, request, **kwargs) self.stream = request.connection.stream - self.client_terminated = False - self._waiting = None def _execute(self, transforms, *args, **kwargs): self.open_args = args self.open_kwargs = kwargs - try: - self.ws_request = WebSocketRequest(self.request) - except ValueError: - logging.debug("Malformed WebSocket request received") - self._abort() - return - scheme = "wss" if self.request.protocol == "https" else "ws" - # Write the initial headers before attempting to read the challenge. - # This is necessary when using proxies (such as HAProxy), which - # need to see the Upgrade headers before passing through the - # non-HTTP traffic that follows. - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Server: TornadoServer/%(version)s\r\n" - "Sec-WebSocket-Origin: %(origin)s\r\n" - "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n\r\n" % (dict( - version=tornado.version, - origin=self.request.headers["Origin"], - scheme=scheme, - host=self.request.host, - uri=self.request.uri)))) - self.stream.read_bytes(8, self._handle_challenge) - - def _handle_challenge(self, challenge): - try: - challenge_response = self.ws_request.challenge_response(challenge) - except ValueError: - logging.debug("Malformed key data in WebSocket request") - self._abort() - return - self._write_response(challenge_response) - - def _write_response(self, challenge): - self.stream.write(challenge) - self.async_callback(self.open)(*self.open_args, **self.open_kwargs) - self._receive_message() + if ("Sec-WebSocket-Version" in self.request.headers and + self.request.headers["Sec-WebSocket-Version"] == "8"): + logging.error("WebSocket protocol 8 request!") + else: + self.ws_connection = WebSocketProtocol76(self) def write_message(self, message): """Sends the given message to the client of this Web Socket.""" - if isinstance(message, dict): - message = tornado.escape.json_encode(message) - if isinstance(message, unicode): - message = message.encode("utf-8") - assert isinstance(message, bytes_type) - self.stream.write(b("\x00") + message + b("\xff")) + self.ws_connection.write_message(message) def open(self, *args, **kwargs): """Invoked when a new WebSocket is opened.""" @@ -138,24 +98,53 @@ class WebSocketHandler(tornado.web.RequestHandler): """Invoked when the WebSocket is closed.""" pass - def close(self): """Closes this Web Socket. Once the close handshake is successful the socket will be closed. """ - if self.client_terminated and self._waiting: - tornado.ioloop.IOLoop.instance().remove_timeout(self._waiting) - self.stream.close() - else: - self.stream.write("\xff\x00") - self._waiting = tornado.ioloop.IOLoop.instance().add_timeout( - time.time() + 5, self._abort) + self.ws_connection.close() + + def async_callback(self, callback, *args, **kwargs): + """Wrap callbacks with this if they are used on asynchronous requests. + + Catches exceptions properly and closes this WebSocket if an exception + is uncaught. + """ + return self.ws_connection.async_callback(callback, *args, **kwargs) + + def _not_supported(self, *args, **kwargs): + raise Exception("Method not supported for Web Sockets") + + def on_connection_close(self): + self.ws_connection.client_terminated = True + self.on_close() + + def _set_client_terminated(self, value): + self.ws_connection.client_terminated = value + + client_terminated = property(lambda self: self.ws_connection.client_terminated, + _set_client_terminated) + + +for method in ["write", "redirect", "set_header", "send_error", "set_cookie", + "set_status", "flush", "finish"]: + setattr(WebSocketHandler, method, WebSocketHandler._not_supported) + + +class WebSocketProtocol(object): + """Base class for WebSocket protocol versions. + """ + def __init__(self, handler): + self.handler = handler + self.request = handler.request + self.stream = handler.stream + self.client_terminated = False def async_callback(self, callback, *args, **kwargs): """Wrap callbacks with this if they are used on asynchronous requests. - Catches exceptions properly and closes this Web Socket if an exception + Catches exceptions properly and closes this WebSocket if an exception is uncaught. """ if args or kwargs: @@ -174,59 +163,45 @@ class WebSocketHandler(tornado.web.RequestHandler): self.client_terminated = True self.stream.close() - def _receive_message(self): - self.stream.read_bytes(1, self._on_frame_type) - - def _on_frame_type(self, byte): - frame_type = ord(byte) - if frame_type == 0x00: - self.stream.read_until(b("\xff"), self._on_end_delimiter) - elif frame_type == 0xff: - self.stream.read_bytes(1, self._on_length_indicator) - else: - self._abort() - def _on_end_delimiter(self, frame): - if not self.client_terminated: - self.async_callback(self.on_message)( - frame[:-1].decode("utf-8", "replace")) - if not self.client_terminated: - self._receive_message() - - def _on_length_indicator(self, byte): - if ord(byte) != 0x00: - self._abort() - return - self.client_terminated = True - self.close() - - def on_connection_close(self): - self.client_terminated = True - self.on_close() - - def _not_supported(self, *args, **kwargs): - raise Exception("Method not supported for Web Sockets") - - -for method in ["write", "redirect", "set_header", "send_error", "set_cookie", - "set_status", "flush", "finish"]: - setattr(WebSocketHandler, method, WebSocketHandler._not_supported) - - -class WebSocketRequest(object): - """A single WebSocket request. +class WebSocketProtocol76(WebSocketProtocol): + """Implementation of the WebSockets protocol, version hixie-76. This class provides basic functionality to process WebSockets requests as specified in http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76 """ - def __init__(self, request): - self.request = request + def __init__(self, handler): + WebSocketProtocol.__init__(self, handler) self.challenge = None - self._handle_websocket_headers() + self._waiting = None + try: + self._handle_websocket_headers() + except ValueError: + logging.debug("Malformed WebSocket request received") + self._abort() + return + scheme = "wss" if self.request.protocol == "https" else "ws" + # Write the initial headers before attempting to read the challenge. + # This is necessary when using proxies (such as HAProxy), which + # need to see the Upgrade headers before passing through the + # non-HTTP traffic that follows. + self.stream.write(tornado.escape.utf8( + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Server: TornadoServer/%(version)s\r\n" + "Sec-WebSocket-Origin: %(origin)s\r\n" + "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n\r\n" % (dict( + version=tornado.version, + origin=self.request.headers["Origin"], + scheme=scheme, + host=self.request.host, + uri=self.request.uri)))) + self.stream.read_bytes(8, self._handle_challenge) def challenge_response(self, challenge): - """Generates the challange response that's needed in the handshake + """Generates the challenge response that's needed in the handshake The challenge parameter should be the raw bytes as sent from the client. @@ -240,6 +215,20 @@ class WebSocketRequest(object): raise ValueError("Invalid Keys/Challenge") return self._generate_challenge_response(part_1, part_2, challenge) + def _handle_challenge(self, challenge): + try: + challenge_response = self.challenge_response(challenge) + except ValueError: + logging.debug("Malformed key data in WebSocket request") + self._abort() + return + self._write_response(challenge_response) + + def _write_response(self, challenge): + self.stream.write(challenge) + self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs) + self._receive_message() + def _handle_websocket_headers(self): """Verifies all invariant- and required headers @@ -272,3 +261,48 @@ class WebSocketRequest(object): m.update(part_2) m.update(part_3) return m.digest() + + def _receive_message(self): + self.stream.read_bytes(1, self._on_frame_type) + + def _on_frame_type(self, byte): + frame_type = ord(byte) + if frame_type == 0x00: + self.stream.read_until(b("\xff"), self._on_end_delimiter) + elif frame_type == 0xff: + self.stream.read_bytes(1, self._on_length_indicator) + else: + self._abort() + + def _on_end_delimiter(self, frame): + if not self.client_terminated: + self.async_callback(self.handler.on_message)( + frame[:-1].decode("utf-8", "replace")) + if not self.client_terminated: + self._receive_message() + + def _on_length_indicator(self, byte): + if ord(byte) != 0x00: + self._abort() + return + self.client_terminated = True + self.close() + + def write_message(self, message): + """Sends the given message to the client of this Web Socket.""" + if isinstance(message, dict): + message = tornado.escape.json_encode(message) + if isinstance(message, unicode): + message = message.encode("utf-8") + assert isinstance(message, bytes_type) + self.stream.write(b("\x00") + message + b("\xff")) + + def close(self): + """Closes the WebSocket connection.""" + if self.client_terminated and self._waiting: + tornado.ioloop.IOLoop.instance().remove_timeout(self._waiting) + self.stream.close() + else: + self.stream.write("\xff\x00") + self._waiting = tornado.ioloop.IOLoop.instance().add_timeout( + time.time() + 5, self._abort)