]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement the hybi-10 version of the WebSocket protocol.
authorFlorian Diebold <flodiebold@gmail.com>
Wed, 13 Jul 2011 19:12:48 +0000 (21:12 +0200)
committerFlorian Diebold <flodiebold@gmail.com>
Wed, 13 Jul 2011 19:26:54 +0000 (21:26 +0200)
tornado/websocket.py

index 35bf3a6402bff6ef4c94573257929caa774137dc..00da861c4534484baf9f12294c56563a2f585e4c 100644 (file)
@@ -17,6 +17,7 @@ import hashlib
 import logging
 import struct
 import time
+import base64
 import tornado.escape
 import tornado.web
 
@@ -73,11 +74,20 @@ class WebSocketHandler(tornado.web.RequestHandler):
     def _execute(self, transforms, *args, **kwargs):
         self.open_args = args
         self.open_kwargs = kwargs
-        if ("Sec-WebSocket-Version" in self.request.headers and
-            self.request.headers["Sec-WebSocket-Version"] == "8"):
-            logging.error("WebSocket protocol 8 request!")
+
+        if self.request.headers.get("Sec-WebSocket-Version") == "8":
+            self.ws_connection = WebSocketProtocol8(self)
+            self.ws_connection.accept_connection()
+            
+        elif self.request.headers.get("Sec-WebSocket-Version"):
+            self.stream.write(tornado.escape.utf8(
+                "HTTP/1.1 426 Upgrade Required\r\n"
+                "Sec-WebSocket-Version: 8\r\n\r\n"))
+            self.stream.close()
+            
         else:
             self.ws_connection = WebSocketProtocol76(self)
+            self.ws_connection.accept_connection()
 
     def write_message(self, message):
         """Sends the given message to the client of this Web Socket."""
@@ -175,6 +185,8 @@ class WebSocketProtocol76(WebSocketProtocol):
         WebSocketProtocol.__init__(self, handler)
         self.challenge = None
         self._waiting = None
+
+    def accept_connection(self):
         try:
             self._handle_websocket_headers()
         except ValueError:
@@ -306,3 +318,176 @@ class WebSocketProtocol76(WebSocketProtocol):
             self.stream.write("\xff\x00")
             self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(
                                 time.time() + 5, self._abort)
+
+
+class WebSocketProtocol8(WebSocketProtocol):
+    """Implementation of the WebSocket protocol, version 8 (draft version 10).
+
+    Compare
+    http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
+    """
+    def __init__(self, handler):
+        WebSocketProtocol.__init__(self, handler)
+        self._final_frame = False
+        self._frame_opcode = None
+        self._frame_mask = None
+        self._frame_length = None
+        self._fragmented_message_buffer = None
+        self._fragmented_message_opcode = None
+        self._started_closing_handshake = False
+
+    def accept_connection(self):
+        try:
+            self._handle_websocket_headers()
+            self._accept_connection()
+        except ValueError:
+            logging.debug("Malformed WebSocket request received")
+            self._abort()
+            return
+    
+    def _handle_websocket_headers(self):
+        """Verifies all invariant- and required headers
+
+        If a header is missing or have an incorrect value ValueError will be
+        raised
+        """
+        headers = self.request.headers
+        fields = ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version")
+        connection = map(lambda s: s.strip().lower(), headers.get("Connection", '').split(","))
+        if (self.request.method != "GET" or
+            headers.get("Upgrade", '').lower() != "websocket" or
+            "upgrade" not in connection or
+            not all(map(lambda f: self.request.headers.get(f), fields))):
+            raise ValueError("Missing/Invalid WebSocket headers")
+
+    def _challenge_response(self):
+        sha1 = hashlib.sha1()
+        sha1.update(self.request.headers.get("Sec-Websocket-Key"))
+        sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") # Magic value
+        return base64.b64encode(sha1.digest())
+
+    def _accept_connection(self):
+        self.stream.write(tornado.escape.utf8(
+            "HTTP/1.1 101 Switching Protocols\r\n"
+            "Upgrade: websocket\r\n"
+            "Connection: Upgrade\r\n"
+            "Sec-WebSocket-Accept: %s\r\n\r\n" % self._challenge_response()))
+
+        self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
+        self._receive_frame()
+
+    def _write_frame(self, fin, opcode, data):
+        if fin:
+            finbit = 0b10000000
+        else:
+            finbit = 0
+        frame = struct.pack("B", finbit | opcode)
+        l = len(data)
+        if l <= 126:
+            frame += struct.pack("B", l)
+        elif l <= 0xFFFF:
+            frame += struct.pack("!BH", 126, l)
+        else:
+            frame += struct.pack("!BQ", 127, l)
+        frame += data
+        self.stream.write(frame)
+
+    def write_message(self, message, binary=False):
+        """Sends the given message to the client of this Web Socket."""
+        if isinstance(message, dict):
+            message = tornado.escape.json_encode(message)
+        if isinstance(message, unicode):
+            message = message.encode("utf-8")
+        assert isinstance(message, bytes_type)
+        if not binary:
+            opcode = 0x1
+        else:
+            opcode = 0x2
+        self._write_frame(True, opcode, message)
+
+    def _receive_frame(self):
+        self.stream.read_bytes(2, self._on_frame_start)
+
+    def _on_frame_start(self, data):
+        header, payloadlen = struct.unpack("BB", data)
+        self._final_frame = header & 0b10000000
+        self._frame_opcode = header & 0b1111
+        if not (payloadlen & 0b10000000):
+            # Unmasked frame -> abort connection
+            self._abort()
+        payloadlen = payloadlen & 0b1111111
+        if payloadlen < 126:
+            self._frame_length = payloadlen
+            self.stream.read_bytes(4, self._on_masking_key)
+        elif payloadlen == 126:
+            self.stream.read_bytes(2, self._on_frame_length_16)
+        elif payloadlen == 127:
+            self.stream.read_bytes(8, self._on_frame_length_64)
+
+    def _on_frame_length_16(self, data):
+        self._frame_length = struct.unpack("!H", data)[0];
+        self.stream.read_bytes(4, self._on_masking_key);
+        
+    def _on_frame_length_64(self, data):
+        self._frame_length = struct.unpack("!Q", data)[0];
+        self.stream.read_bytes(4, self._on_masking_key);
+
+    def _on_masking_key(self, data):
+        self._frame_mask = bytearray(data)
+        self.stream.read_bytes(self._frame_length, self._on_frame_data)
+
+    def _on_frame_data(self, data):
+        unmasked = bytearray(data)
+        for i in xrange(len(data)):
+            unmasked[i] = unmasked[i] ^ self._frame_mask[i % 4]
+
+        if not self._final_frame:
+            if self._fragmented_message_buffer:
+                self._fragmented_message_buffer += unmasked
+            else:
+                self._fragmented_message_opcode = self._frame_opcode
+                self._fragmented_message_buffer = unmasked
+        else:
+            if self._frame_opcode == 0:
+                unmasked = self._fragmented_message_buffer + unmasked
+                opcode = self._fragmented_message_opcode
+                self._fragmented_message_buffer = None
+            else:
+                opcode = self._frame_opcode
+
+            self._handle_message(opcode, bytes_type(unmasked))
+
+        if not self.client_terminated:
+            self._receive_frame()
+        
+
+    def _handle_message(self, opcode, data):
+        if self.client_terminated: return
+        
+        if opcode == 0x1:
+            # UTF-8 data
+            self.async_callback(self.handler.on_message)(data.decode("utf-8", "replace"))
+        elif opcode == 0x2:
+            # Binary data
+            self.async_callback(self.handler.on_message)(data)
+        elif opcode == 0x8:
+            # Close
+            self.client_terminated = True
+            if not self._started_closing_handshake:
+                self._write_frame(True, 0x8, b(""))
+            self.stream.close()
+        elif opcode == 0x9:
+            # Ping
+            self._write_frame(True, 0xA, b(""))
+        elif opcode == 0xA:
+            # Pong
+            pass
+        else:
+            self._abort()
+        
+    def close(self):
+        """Closes the WebSocket connection."""
+        self._write_frame(True, 0x8, b(""))
+        self._started_closing_handshake = True
+        self._waiting = tornado.ioloop.IOLoop.instance().add_timeout(time.time() + 5, self._abort)
+