]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
WebSocket: disable RequestHandler methods by patching the instance
authorJames Maier <James.Maier@viasat.com>
Mon, 9 Jan 2017 03:45:15 +0000 (22:45 -0500)
committerJames Maier <James.Maier@viasat.com>
Mon, 9 Jan 2017 03:45:29 +0000 (22:45 -0500)
tornado/test/websocket_test.py
tornado/websocket.py

index bcf5e1327deb13488aabf7795420a8b18d7516a2..659b2f00000324c701410c8ba61e82f524140452 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import absolute_import, division, print_function, with_statement
 
+import functools
 import random
 import string
 import traceback
@@ -59,13 +60,23 @@ class ErrorInOnMessageHandler(TestWebSocketHandler):
 
 class HeaderHandler(TestWebSocketHandler):
     def open(self):
-        try:
-            # In a websocket context, many RequestHandler methods
-            # raise RuntimeErrors.
-            self.set_status(503)
-            raise Exception("did not get expected exception")
-        except RuntimeError:
-            pass
+        methods_to_test = [
+            functools.partial(self.write, 'This should not work'),
+            functools.partial(self.redirect, 'http://localhost/elsewhere'),
+            functools.partial(self.set_header, 'X-Test', ''),
+            functools.partial(self.set_cookie, 'Chocolate', 'Chip'),
+            functools.partial(self.set_status, 503),
+            self.flush,
+            self.finish,
+        ]
+        for method in methods_to_test:
+            try:
+                # In a websocket context, many RequestHandler methods
+                # raise RuntimeErrors.
+                method()
+                raise Exception("did not get expected exception")
+            except RuntimeError:
+                pass
         self.write_message(self.request.headers.get('X-Test', ''))
 
 
index e55e7246bc31a14bf4d7f02d8c612f5e2bebb597..625e430393d20d6c66512a13e9b44fbcdcf94a42 100644 (file)
@@ -127,12 +127,12 @@ class WebSocketHandler(tornado.web.RequestHandler):
     to accept it before the websocket connection will succeed.
     """
     def __init__(self, application, request, **kwargs):
+        super(WebSocketHandler, self).__init__(application, request, **kwargs)
         self.ws_connection = None
         self.close_code = None
         self.close_reason = None
         self.stream = None
         self._on_close_called = False
-        super(WebSocketHandler, self).__init__(application, request, **kwargs)
 
     @tornado.web.asynchronous
     def get(self, *args, **kwargs):
@@ -405,19 +405,14 @@ class WebSocketHandler(tornado.web.RequestHandler):
     def _attach_stream(self):
         self.stream = self.request.connection.detach()
         self.stream.set_close_callback(self.on_connection_close)
+        # disable non-WS methods
+        for method in ["write", "redirect", "set_header", "set_cookie",
+                       "set_status", "flush", "finish"]:
+            setattr(self, method, _raise_not_supported_for_websockets)
 
 
-def _wrap_method(method):
-    def _disallow_for_websocket(self, *args, **kwargs):
-        if self.stream is None:
-            method(self, *args, **kwargs)
-        else:
-            raise RuntimeError("Method not supported for Web Sockets")
-    return _disallow_for_websocket
-for method in ["write", "redirect", "set_header", "set_cookie",
-               "set_status", "flush", "finish"]:
-    setattr(WebSocketHandler, method,
-            _wrap_method(getattr(WebSocketHandler, method)))
+def _raise_not_supported_for_websockets(*args, **kwargs):
+    raise RuntimeError("Method not supported for Web Sockets")
 
 
 class WebSocketProtocol(object):