]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Addressed comments from Ben 1286/head
authorvovanec <vovanec@gmail.com>
Mon, 12 Jan 2015 16:06:44 +0000 (08:06 -0800)
committervovanec <vovanec@gmail.com>
Mon, 12 Jan 2015 16:06:44 +0000 (08:06 -0800)
Addressed comments from Ben. Added get_websocket_protocol() to WebSocketClientConnection class as well.

tornado/websocket.py

index 1d08f1067f4c0ca54c51746a0324c098953c461a..9868f8975ddd755969f34140572e27c147764d5a 100644 (file)
@@ -171,12 +171,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
         self.stream = self.request.connection.detach()
         self.stream.set_close_callback(self.on_connection_close)
 
-        protocol_subclass = self.get_websocket_protocol_subclass(
-            self.request.headers.get("Sec-WebSocket-Version"))
-
-        if protocol_subclass:
-            self.ws_connection = protocol_subclass(
-                self, compression_options=self.get_compression_options())
+        self.ws_connection = self.get_websocket_protocol()
+        if self.ws_connection:
             self.ws_connection.accept_connection()
         else:
             if not self.stream.closed():
@@ -185,20 +181,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
                     "Sec-WebSocket-Version: 8\r\n\r\n"))
                 self.stream.close()
 
-    def get_websocket_protocol_subclass(self, web_socket_version):
-        """Returns WebSocketProtocol subclass for specific WebSocket version.
-        ``web_socket_version`` argument is a protocol version string passed in
-        "Sec-WebSocket-Version" header.
-
-        This method can be overridden in subclasses to add support for
-        custom protocol implementations.
-
-        .. versionadded:: 4.1
-        """
-
-        if web_socket_version in ("7", "8", "13"):
-            return WebSocketProtocol13
-
     def write_message(self, message, binary=False):
         """Sends the given message to the client of this Web Socket.
 
@@ -378,6 +360,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
             # we can close the connection more gracefully.
             self.stream.close()
 
+    def get_websocket_protocol(self):
+        websocket_version = self.request.headers.get("Sec-WebSocket-Version")
+        if websocket_version in ("7", "8", "13"):
+            return WebSocketProtocol13(
+                self, compression_options=self.get_compression_options())
+
+
 def _wrap_method(method):
     def _disallow_for_websocket(self, *args, **kwargs):
         if self.stream is None:
@@ -871,6 +860,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     def __init__(self, io_loop, request, compression_options=None):
         self.compression_options = compression_options
         self.connect_future = TracebackFuture()
+        self.protocol = None
         self.read_future = None
         self.read_queue = collections.deque()
         self.key = base64.b64encode(os.urandom(16))
@@ -935,9 +925,7 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
                 start_line, headers)
 
         self.headers = headers
-        self.protocol = WebSocketProtocol13(
-            self, mask_outgoing=True,
-            compression_options=self.compression_options)
+        self.protocol = self.get_websocket_protocol()
         self.protocol._process_server_headers(self.key, self.headers)
         self.protocol._receive_frame()
 
@@ -987,6 +975,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
     def on_pong(self, data):
         pass
 
+    def get_websocket_protocol(self):
+        return WebSocketProtocol13(self, mask_outgoing=True,
+                                   compression_options=self.compression_options)
+
 
 def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None,
                       compression_options=None):