From: James Maier Date: Mon, 9 Jan 2017 03:45:15 +0000 (-0500) Subject: WebSocket: disable RequestHandler methods by patching the instance X-Git-Tag: v4.5.0~12^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=70727e7708e56adb44c28747d4f179eaa47c0edd;p=thirdparty%2Ftornado.git WebSocket: disable RequestHandler methods by patching the instance --- diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index bcf5e1327..659b2f000 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function, with_statement +import functools import random import string import traceback @@ -59,13 +60,23 @@ class ErrorInOnMessageHandler(TestWebSocketHandler): class HeaderHandler(TestWebSocketHandler): def open(self): - try: - # In a websocket context, many RequestHandler methods - # raise RuntimeErrors. - self.set_status(503) - raise Exception("did not get expected exception") - except RuntimeError: - pass + methods_to_test = [ + functools.partial(self.write, 'This should not work'), + functools.partial(self.redirect, 'http://localhost/elsewhere'), + functools.partial(self.set_header, 'X-Test', ''), + functools.partial(self.set_cookie, 'Chocolate', 'Chip'), + functools.partial(self.set_status, 503), + self.flush, + self.finish, + ] + for method in methods_to_test: + try: + # In a websocket context, many RequestHandler methods + # raise RuntimeErrors. + method() + raise Exception("did not get expected exception") + except RuntimeError: + pass self.write_message(self.request.headers.get('X-Test', '')) diff --git a/tornado/websocket.py b/tornado/websocket.py index e55e7246b..625e43039 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -127,12 +127,12 @@ class WebSocketHandler(tornado.web.RequestHandler): to accept it before the websocket connection will succeed. """ def __init__(self, application, request, **kwargs): + super(WebSocketHandler, self).__init__(application, request, **kwargs) self.ws_connection = None self.close_code = None self.close_reason = None self.stream = None self._on_close_called = False - super(WebSocketHandler, self).__init__(application, request, **kwargs) @tornado.web.asynchronous def get(self, *args, **kwargs): @@ -405,19 +405,14 @@ class WebSocketHandler(tornado.web.RequestHandler): def _attach_stream(self): self.stream = self.request.connection.detach() self.stream.set_close_callback(self.on_connection_close) + # disable non-WS methods + for method in ["write", "redirect", "set_header", "set_cookie", + "set_status", "flush", "finish"]: + setattr(self, method, _raise_not_supported_for_websockets) -def _wrap_method(method): - def _disallow_for_websocket(self, *args, **kwargs): - if self.stream is None: - method(self, *args, **kwargs) - else: - raise RuntimeError("Method not supported for Web Sockets") - return _disallow_for_websocket -for method in ["write", "redirect", "set_header", "set_cookie", - "set_status", "flush", "finish"]: - setattr(WebSocketHandler, method, - _wrap_method(getattr(WebSocketHandler, method))) +def _raise_not_supported_for_websockets(*args, **kwargs): + raise RuntimeError("Method not supported for Web Sockets") class WebSocketProtocol(object):