]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Additional checks for WebSocket protocol handshake.
authorSerge S. Koval <serge.koval@gmail.com>
Thu, 8 Dec 2011 12:42:37 +0000 (14:42 +0200)
committerSerge S. Koval <serge.koval@gmail.com>
Thu, 8 Dec 2011 12:42:37 +0000 (14:42 +0200)
tornado/websocket.py

index ecc80ac436d863faa3db7336c1d70abae0547a2e..0d51d8203abe5fb14ee560d9ab822c576e2c26d7 100644 (file)
@@ -35,7 +35,7 @@ class WebSocketHandler(tornado.web.RequestHandler):
     http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
     The older protocol versions specified at
     http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
-    and 
+    and
     http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76.
     are also supported.
 
@@ -82,19 +82,46 @@ class WebSocketHandler(tornado.web.RequestHandler):
         self.open_args = args
         self.open_kwargs = kwargs
 
+        # Websocket requires GET method
+        if self.request.method != 'GET':
+            self.stream.write(tornado.escape.utf8(
+                "HTTP/1.1 405 Method Not Allowed\r\n\r\n"
+            ))
+            self.stream.close()
+            return
+
+        # Upgrade header should be present and should be equal to WebSocket
+        if self.request.headers.get("Upgrade", "").lower() != 'websocket':
+            self.stream.write(tornado.escape.utf8(
+                "HTTP/1.1 400 Bad Request\r\n\r\n"
+                "Can \"Upgrade\" only to \"WebSocket\"."
+            ))
+            self.stream.close()
+            return
+
+        # Connection header should be upgrade. Some proxy servers/load balancers
+        # might mess with it.
+        if self.request.headers.get("Connection", "").lower() != 'upgrade':
+            self.stream.write(tornado.escape.utf8(
+                "HTTP/1.1 400 Bad Request\r\n\r\n"
+                "\"Connection\" must be \"Upgrade\"."
+            ))
+            self.stream.close()
+            return
+
         # The difference between version 8 and 13 is that in 8 the
         # client sends a "Sec-Websocket-Origin" header and in 13 it's
         # simply "Origin".
         if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
             self.ws_connection = WebSocketProtocol8(self)
             self.ws_connection.accept_connection()
-            
+
         elif self.request.headers.get("Sec-WebSocket-Version"):
             self.stream.write(tornado.escape.utf8(
                 "HTTP/1.1 426 Upgrade Required\r\n"
                 "Sec-WebSocket-Version: 8\r\n\r\n"))
             self.stream.close()
-            
+
         else:
             self.ws_connection = WebSocketProtocol76(self)
             self.ws_connection.accept_connection()
@@ -355,7 +382,7 @@ class WebSocketProtocol8(WebSocketProtocol):
             logging.debug("Malformed WebSocket request received")
             self._abort()
             return
-    
+
     def _handle_websocket_headers(self):
         """Verifies all invariant- and required headers
 
@@ -439,7 +466,7 @@ class WebSocketProtocol8(WebSocketProtocol):
     def _on_frame_length_16(self, data):
         self._frame_length = struct.unpack("!H", data)[0];
         self.stream.read_bytes(4, self._on_masking_key);
-        
+
     def _on_frame_length_64(self, data):
         self._frame_length = struct.unpack("!Q", data)[0];
         self.stream.read_bytes(4, self._on_masking_key);
@@ -471,11 +498,11 @@ class WebSocketProtocol8(WebSocketProtocol):
 
         if not self.client_terminated:
             self._receive_frame()
-        
+
 
     def _handle_message(self, opcode, data):
         if self.client_terminated: return
-        
+
         if opcode == 0x1:
             # UTF-8 data
             self.async_callback(self.handler.on_message)(data.decode("utf-8", "replace"))
@@ -496,7 +523,7 @@ class WebSocketProtocol8(WebSocketProtocol):
             pass
         else:
             self._abort()
-        
+
     def close(self):
         """Closes the WebSocket connection."""
         self._write_frame(True, 0x8, b(""))