]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Refactor WebSocket support to prepare for multiple protocol versions.
authorFlorian Diebold <flodiebold@gmail.com>
Sat, 9 Jul 2011 20:33:21 +0000 (22:33 +0200)
committerFlorian Diebold <flodiebold@gmail.com>
Wed, 13 Jul 2011 19:26:54 +0000 (21:26 +0200)
All protocol-specific functions are moved to the former
WebSocketRequest class, which is renamed to WebSocketProtocol76. The
WebSocketHandler chooses the right WebSocketProtocol implementation by
looking at the request headers.

tornado/websocket.py

index 3fcde801f1e9ffa1d2bbd3d62fb2893b9aadecf6..35bf3a6402bff6ef4c94573257929caa774137dc 100644 (file)
@@ -69,59 +69,19 @@ class WebSocketHandler(tornado.web.RequestHandler):
         tornado.web.RequestHandler.__init__(self, application, request,
                                             **kwargs)
         self.stream = request.connection.stream
-        self.client_terminated = False
-        self._waiting = None
 
     def _execute(self, transforms, *args, **kwargs):
         self.open_args = args
         self.open_kwargs = kwargs
-        try:
-            self.ws_request = WebSocketRequest(self.request)
-        except ValueError:
-            logging.debug("Malformed WebSocket request received")
-            self._abort()
-            return
-        scheme = "wss" if self.request.protocol == "https" else "ws"
-        # Write the initial headers before attempting to read the challenge.
-        # This is necessary when using proxies (such as HAProxy), which
-        # need to see the Upgrade headers before passing through the
-        # non-HTTP traffic that follows.
-        self.stream.write(tornado.escape.utf8(
-            "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
-            "Upgrade: WebSocket\r\n"
-            "Connection: Upgrade\r\n"
-            "Server: TornadoServer/%(version)s\r\n"
-            "Sec-WebSocket-Origin: %(origin)s\r\n"
-            "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n\r\n" % (dict(
-                    version=tornado.version,
-                    origin=self.request.headers["Origin"],
-                    scheme=scheme,
-                    host=self.request.host,
-                    uri=self.request.uri))))
-        self.stream.read_bytes(8, self._handle_challenge)
-
-    def _handle_challenge(self, challenge):
-        try:
-            challenge_response = self.ws_request.challenge_response(challenge)
-        except ValueError:
-            logging.debug("Malformed key data in WebSocket request")
-            self._abort()
-            return
-        self._write_response(challenge_response)
-
-    def _write_response(self, challenge):
-        self.stream.write(challenge)
-        self.async_callback(self.open)(*self.open_args, **self.open_kwargs)
-        self._receive_message()
+        if ("Sec-WebSocket-Version" in self.request.headers and
+            self.request.headers["Sec-WebSocket-Version"] == "8"):
+            logging.error("WebSocket protocol 8 request!")
+        else:
+            self.ws_connection = WebSocketProtocol76(self)
 
     def write_message(self, message):
         """Sends the given message to the client of this Web Socket."""
-        if isinstance(message, dict):
-            message = tornado.escape.json_encode(message)
-        if isinstance(message, unicode):
-            message = message.encode("utf-8")
-        assert isinstance(message, bytes_type)
-        self.stream.write(b("\x00") + message + b("\xff"))
+        self.ws_connection.write_message(message)
 
     def open(self, *args, **kwargs):
         """Invoked when a new WebSocket is opened."""
@@ -138,24 +98,53 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """Invoked when the WebSocket is closed."""
         pass
 
-
     def close(self):
         """Closes this Web Socket.
 
         Once the close handshake is successful the socket will be closed.
         """
-        if self.client_terminated and self._waiting:
-            tornado.ioloop.IOLoop.instance().remove_timeout(self._waiting)
-            self.stream.close()
-        else:
-            self.stream.write("\xff\x00")
-            self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(
-                                time.time() + 5, self._abort)
+        self.ws_connection.close()
+
+    def async_callback(self, callback, *args, **kwargs):
+        """Wrap callbacks with this if they are used on asynchronous requests.
+
+        Catches exceptions properly and closes this WebSocket if an exception
+        is uncaught.
+        """
+        return self.ws_connection.async_callback(callback, *args, **kwargs)
+
+    def _not_supported(self, *args, **kwargs):
+        raise Exception("Method not supported for Web Sockets")
+
+    def on_connection_close(self):
+        self.ws_connection.client_terminated = True
+        self.on_close()
+
+    def _set_client_terminated(self, value):
+        self.ws_connection.client_terminated = value
+
+    client_terminated = property(lambda self: self.ws_connection.client_terminated,
+                                 _set_client_terminated)
+
+
+for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
+               "set_status", "flush", "finish"]:
+    setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
+
+
+class WebSocketProtocol(object):
+    """Base class for WebSocket protocol versions.
+    """
+    def __init__(self, handler):
+        self.handler = handler
+        self.request = handler.request
+        self.stream = handler.stream
+        self.client_terminated = False
 
     def async_callback(self, callback, *args, **kwargs):
         """Wrap callbacks with this if they are used on asynchronous requests.
 
-        Catches exceptions properly and closes this Web Socket if an exception
+        Catches exceptions properly and closes this WebSocket if an exception
         is uncaught.
         """
         if args or kwargs:
@@ -174,59 +163,45 @@ class WebSocketHandler(tornado.web.RequestHandler):
         self.client_terminated = True
         self.stream.close()
 
-    def _receive_message(self):
-        self.stream.read_bytes(1, self._on_frame_type)
-
-    def _on_frame_type(self, byte):
-        frame_type = ord(byte)
-        if frame_type == 0x00:
-            self.stream.read_until(b("\xff"), self._on_end_delimiter)
-        elif frame_type == 0xff:
-            self.stream.read_bytes(1, self._on_length_indicator)
-        else:
-            self._abort()
 
-    def _on_end_delimiter(self, frame):
-        if not self.client_terminated:
-            self.async_callback(self.on_message)(
-                    frame[:-1].decode("utf-8", "replace"))
-        if not self.client_terminated:
-            self._receive_message()
-
-    def _on_length_indicator(self, byte):
-        if ord(byte) != 0x00:
-            self._abort()
-            return
-        self.client_terminated = True
-        self.close()
-
-    def on_connection_close(self):
-        self.client_terminated = True
-        self.on_close()
-
-    def _not_supported(self, *args, **kwargs):
-        raise Exception("Method not supported for Web Sockets")
-
-
-for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
-               "set_status", "flush", "finish"]:
-    setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
-
-
-class WebSocketRequest(object):
-    """A single WebSocket request.
+class WebSocketProtocol76(WebSocketProtocol):
+    """Implementation of the WebSockets protocol, version hixie-76.
 
     This class provides basic functionality to process WebSockets requests as
     specified in
     http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
     """
-    def __init__(self, request):
-        self.request = request
+    def __init__(self, handler):
+        WebSocketProtocol.__init__(self, handler)
         self.challenge = None
-        self._handle_websocket_headers()
+        self._waiting = None
+        try:
+            self._handle_websocket_headers()
+        except ValueError:
+            logging.debug("Malformed WebSocket request received")
+            self._abort()
+            return
+        scheme = "wss" if self.request.protocol == "https" else "ws"
+        # Write the initial headers before attempting to read the challenge.
+        # This is necessary when using proxies (such as HAProxy), which
+        # need to see the Upgrade headers before passing through the
+        # non-HTTP traffic that follows.
+        self.stream.write(tornado.escape.utf8(
+            "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+            "Upgrade: WebSocket\r\n"
+            "Connection: Upgrade\r\n"
+            "Server: TornadoServer/%(version)s\r\n"
+            "Sec-WebSocket-Origin: %(origin)s\r\n"
+            "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n\r\n" % (dict(
+                    version=tornado.version,
+                    origin=self.request.headers["Origin"],
+                    scheme=scheme,
+                    host=self.request.host,
+                    uri=self.request.uri))))
+        self.stream.read_bytes(8, self._handle_challenge)
 
     def challenge_response(self, challenge):
-        """Generates the challange response that's needed in the handshake
+        """Generates the challenge response that's needed in the handshake
 
         The challenge parameter should be the raw bytes as sent from the
         client.
@@ -240,6 +215,20 @@ class WebSocketRequest(object):
             raise ValueError("Invalid Keys/Challenge")
         return self._generate_challenge_response(part_1, part_2, challenge)
 
+    def _handle_challenge(self, challenge):
+        try:
+            challenge_response = self.challenge_response(challenge)
+        except ValueError:
+            logging.debug("Malformed key data in WebSocket request")
+            self._abort()
+            return
+        self._write_response(challenge_response)
+
+    def _write_response(self, challenge):
+        self.stream.write(challenge)
+        self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
+        self._receive_message()
+
     def _handle_websocket_headers(self):
         """Verifies all invariant- and required headers
 
@@ -272,3 +261,48 @@ class WebSocketRequest(object):
         m.update(part_2)
         m.update(part_3)
         return m.digest()
+
+    def _receive_message(self):
+        self.stream.read_bytes(1, self._on_frame_type)
+
+    def _on_frame_type(self, byte):
+        frame_type = ord(byte)
+        if frame_type == 0x00:
+            self.stream.read_until(b("\xff"), self._on_end_delimiter)
+        elif frame_type == 0xff:
+            self.stream.read_bytes(1, self._on_length_indicator)
+        else:
+            self._abort()
+
+    def _on_end_delimiter(self, frame):
+        if not self.client_terminated:
+            self.async_callback(self.handler.on_message)(
+                    frame[:-1].decode("utf-8", "replace"))
+        if not self.client_terminated:
+            self._receive_message()
+
+    def _on_length_indicator(self, byte):
+        if ord(byte) != 0x00:
+            self._abort()
+            return
+        self.client_terminated = True
+        self.close()
+
+    def write_message(self, message):
+        """Sends the given message to the client of this Web Socket."""
+        if isinstance(message, dict):
+            message = tornado.escape.json_encode(message)
+        if isinstance(message, unicode):
+            message = message.encode("utf-8")
+        assert isinstance(message, bytes_type)
+        self.stream.write(b("\x00") + message + b("\xff"))
+
+    def close(self):
+        """Closes the WebSocket connection."""
+        if self.client_terminated and self._waiting:
+            tornado.ioloop.IOLoop.instance().remove_timeout(self._waiting)
+            self.stream.close()
+        else:
+            self.stream.write("\xff\x00")
+            self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(
+                                time.time() + 5, self._abort)