]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Relax restrictions on HTTP methods in WebSocketHandler.
authorBen Darnell <ben@bendarnell.com>
Mon, 16 Jun 2014 03:35:02 +0000 (23:35 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 16 Jun 2014 03:35:02 +0000 (23:35 -0400)
Methods like set_status are now disallowed once the websocket handshake
has begun, but may be used before then.  This applies to application
overrides of prepare() and to WebSocketHandler.get's internal error
handling.

Closes #1065.

tornado/test/websocket_test.py
tornado/websocket.py

index fd0b08ca969477ba2cf0d4bee60a37971a7ec603..7b3c34ceef5196768df2d4c318022aab24f2e984 100644 (file)
@@ -47,6 +47,13 @@ class EchoHandler(TestWebSocketHandler):
 
 class HeaderHandler(TestWebSocketHandler):
     def open(self):
+        try:
+            # In a websocket context, many RequestHandler methods
+            # raise RuntimeErrors.
+            self.set_status(503)
+            raise Exception("did not get expected exception")
+        except RuntimeError:
+            pass
         self.write_message(self.request.headers.get('X-Test', ''))
 
 
@@ -71,6 +78,11 @@ class WebSocketTest(AsyncHTTPTestCase):
              dict(close_future=self.close_future)),
         ])
 
+    def test_http_request(self):
+        # WS server, HTTP client.
+        response = self.fetch('/echo')
+        self.assertEqual(response.code, 400)
+
     @gen_test
     def test_websocket_gen(self):
         ws = yield websocket_connect(
index 2704c26c1a9d924c3e84d1337853ec9143496944..19196b88b1109ee7ed631e901e22b6b6b530b37e 100644 (file)
@@ -115,22 +115,17 @@ class WebSocketHandler(tornado.web.RequestHandler):
         self.ws_connection = None
         self.close_code = None
         self.close_reason = None
+        self.stream = None
 
     @tornado.web.asynchronous
     def get(self, *args, **kwargs):
         self.open_args = args
         self.open_kwargs = kwargs
 
-        self.stream = self.request.connection.detach()
-        self.stream.set_close_callback(self.on_connection_close)
-
         # Upgrade header should be present and should be equal to WebSocket
         if self.request.headers.get("Upgrade", "").lower() != 'websocket':
-            self.stream.write(tornado.escape.utf8(
-                "HTTP/1.1 400 Bad Request\r\n\r\n"
-                "Can \"Upgrade\" only to \"WebSocket\"."
-            ))
-            self.stream.close()
+            self.set_status(400)
+            self.finish("Can \"Upgrade\" only to \"WebSocket\".")
             return
 
         # Connection header should be upgrade. Some proxy servers/load balancers
@@ -138,11 +133,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
         headers = self.request.headers
         connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
         if 'upgrade' not in connection:
-            self.stream.write(tornado.escape.utf8(
-                "HTTP/1.1 400 Bad Request\r\n\r\n"
-                "\"Connection\" must be \"Upgrade\"."
-            ))
-            self.stream.close()
+            self.set_status(400)
+            self.finish("\"Connection\" must be \"Upgrade\".")
             return
 
         # Handle WebSocket Origin naming convention differences
@@ -159,12 +151,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
         # according to check_origin. When the origin is None, we assume it
         # did not come from a browser and that it can be passed on.
         if origin is not None and not self.check_origin(origin):
-            self.stream.write(tornado.escape.utf8(
-                "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
-            ))
-            self.stream.close()
+            self.set_status(403)
+            self.finish("Cross origin websockets not allowed")
             return
 
+        self.stream = self.request.connection.detach()
+        self.stream.set_close_callback(self.on_connection_close)
+
         if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
             self.ws_connection = WebSocketProtocol13(self)
             self.ws_connection.accept_connection()
@@ -346,9 +339,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
         """
         return "wss" if self.request.protocol == "https" else "ws"
 
-    def _not_supported(self, *args, **kwargs):
-        raise Exception("Method not supported for Web Sockets")
-
     def on_connection_close(self):
         if self.ws_connection:
             self.ws_connection.on_connection_close()
@@ -356,9 +346,17 @@ class WebSocketHandler(tornado.web.RequestHandler):
             self.on_close()
 
 
+def _wrap_method(method):
+    def _disallow_for_websocket(self, *args, **kwargs):
+        if self.stream is None:
+            method(self, *args, **kwargs)
+        else:
+            raise RuntimeError("Method not supported for Web Sockets")
+    return _disallow_for_websocket
 for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
                "set_status", "flush", "finish"]:
-    setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
+    setattr(WebSocketHandler, method,
+            _wrap_method(getattr(WebSocketHandler, method)))
 
 
 class WebSocketProtocol(object):