From: vovanec Date: Mon, 12 Jan 2015 16:06:44 +0000 (-0800) Subject: Addressed comments from Ben X-Git-Tag: v4.1.0b1~6^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=433badc08ec4b5f529ed8dfc20f21a19edf21e85;p=thirdparty%2Ftornado.git Addressed comments from Ben Addressed comments from Ben. Added get_websocket_protocol() to WebSocketClientConnection class as well. --- diff --git a/tornado/websocket.py b/tornado/websocket.py index 1d08f1067..9868f8975 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -171,12 +171,8 @@ class WebSocketHandler(tornado.web.RequestHandler): self.stream = self.request.connection.detach() self.stream.set_close_callback(self.on_connection_close) - protocol_subclass = self.get_websocket_protocol_subclass( - self.request.headers.get("Sec-WebSocket-Version")) - - if protocol_subclass: - self.ws_connection = protocol_subclass( - self, compression_options=self.get_compression_options()) + self.ws_connection = self.get_websocket_protocol() + if self.ws_connection: self.ws_connection.accept_connection() else: if not self.stream.closed(): @@ -185,20 +181,6 @@ class WebSocketHandler(tornado.web.RequestHandler): "Sec-WebSocket-Version: 8\r\n\r\n")) self.stream.close() - def get_websocket_protocol_subclass(self, web_socket_version): - """Returns WebSocketProtocol subclass for specific WebSocket version. - ``web_socket_version`` argument is a protocol version string passed in - "Sec-WebSocket-Version" header. - - This method can be overridden in subclasses to add support for - custom protocol implementations. - - .. versionadded:: 4.1 - """ - - if web_socket_version in ("7", "8", "13"): - return WebSocketProtocol13 - def write_message(self, message, binary=False): """Sends the given message to the client of this Web Socket. @@ -378,6 +360,13 @@ class WebSocketHandler(tornado.web.RequestHandler): # we can close the connection more gracefully. self.stream.close() + def get_websocket_protocol(self): + websocket_version = self.request.headers.get("Sec-WebSocket-Version") + if websocket_version in ("7", "8", "13"): + return WebSocketProtocol13( + self, compression_options=self.get_compression_options()) + + def _wrap_method(method): def _disallow_for_websocket(self, *args, **kwargs): if self.stream is None: @@ -871,6 +860,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def __init__(self, io_loop, request, compression_options=None): self.compression_options = compression_options self.connect_future = TracebackFuture() + self.protocol = None self.read_future = None self.read_queue = collections.deque() self.key = base64.b64encode(os.urandom(16)) @@ -935,9 +925,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): start_line, headers) self.headers = headers - self.protocol = WebSocketProtocol13( - self, mask_outgoing=True, - compression_options=self.compression_options) + self.protocol = self.get_websocket_protocol() self.protocol._process_server_headers(self.key, self.headers) self.protocol._receive_frame() @@ -987,6 +975,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): def on_pong(self, data): pass + def get_websocket_protocol(self): + return WebSocketProtocol13(self, mask_outgoing=True, + compression_options=self.compression_options) + def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None, compression_options=None):