]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a simple websocket client and unittest
authorBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 05:03:45 +0000 (00:03 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 17 Feb 2013 05:03:45 +0000 (00:03 -0500)
tornado/simple_httpclient.py
tornado/test/runtests.py
tornado/test/websocket_test.py [new file with mode: 0644]
tornado/websocket.py

index 7827a7bfa06b07ae2d2f529193e026117c9be63c..bbd6b70b874079c38a247648a6b005ce75896aca 100644 (file)
@@ -321,19 +321,22 @@ class _HTTPConnection(object):
                 message = str(self.stream.error)
             raise HTTPError(599, message)
 
+    def _handle_1xx(self, code):
+        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
+
     def _on_headers(self, data):
         data = native_str(data.decode("latin1"))
         first_line, _, header_data = data.partition("\n")
         match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
         assert match
         code = int(match.group(1))
+        self.headers = HTTPHeaders.parse(header_data)
         if 100 <= code < 200:
-            self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
+            self._handle_1xx(code)
             return
         else:
             self.code = code
             self.reason = match.group(2)
-        self.headers = HTTPHeaders.parse(header_data)
 
         if "Content-Length" in self.headers:
             if "," in self.headers["Content-Length"]:
index e3e3bb818629ffeabf628bdbbb53758f560adf25..24717114eec7331500b066b5476627829b43052d 100644 (file)
@@ -36,6 +36,7 @@ TEST_MODULES = [
     'tornado.test.twisted_test',
     'tornado.test.util_test',
     'tornado.test.web_test',
+    'tornado.test.websocket_test',
     'tornado.test.wsgi_test',
 ]
 
diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py
new file mode 100644 (file)
index 0000000..ca10b4a
--- /dev/null
@@ -0,0 +1,32 @@
+from tornado.testing import AsyncHTTPTestCase, gen_test
+from tornado.web import Application
+from tornado.websocket import WebSocketHandler, WebSocketConnect
+
+class EchoHandler(WebSocketHandler):
+    def on_message(self, message):
+        self.write_message(message, isinstance(message, bytes))
+
+class WebSocketTest(AsyncHTTPTestCase):
+    def get_app(self):
+        return Application([
+                ('/echo', EchoHandler),
+                ])
+
+    @gen_test
+    def test_websocket_gen(self):
+        ws = yield WebSocketConnect(
+            'ws://localhost:%d/echo' % self.get_http_port(),
+            io_loop=self.io_loop)
+        ws.write_message('hello')
+        response = yield ws.read_message()
+        self.assertEqual(response, 'hello')
+
+    def test_websocket_callbacks(self):
+        WebSocketConnect(
+            'ws://localhost:%d/echo' % self.get_http_port(),
+            io_loop=self.io_loop, callback=self.stop)
+        ws = self.wait().result()
+        ws.write_message('hello')
+        ws.read_message(self.stop)
+        response = self.wait().result()
+        self.assertEqual(response, 'hello')
index e08d17ed50807a0be904726d8098ea7f08fb5155..beffc9705969e47181da7ca4f3511a85a24652ef 100644 (file)
@@ -21,15 +21,26 @@ from __future__ import absolute_import, division, print_function, with_statement
 # Author: Jacob Kristhammar, 2010
 
 import array
+import base64
+import collections
 import functools
 import hashlib
+import logging
+import os
+import re
+import socket
 import struct
 import time
-import base64
 import tornado.escape
 import tornado.web
 
+from tornado.concurrent import Future, return_future
+from tornado.escape import utf8, to_unicode, native_str
+from tornado.httputil import HTTPHeaders
+from tornado.ioloop import IOLoop
+from tornado.iostream import IOStream, SSLIOStream
 from tornado.log import gen_log, app_log
+from tornado import simple_httpclient
 from tornado.util import bytes_type
 
 try:
@@ -37,6 +48,11 @@ try:
 except NameError:
     xrange = range  # py3
 
+try:
+    import urlparse  # py2
+except ImportError:
+    import urllib.parse as urlparse  # py3
+
 class WebSocketHandler(tornado.web.RequestHandler):
     """Subclass this class to create a basic WebSocket handler.
 
@@ -458,10 +474,12 @@ class WebSocketProtocol13(WebSocketProtocol):
     This class supports versions 7 and 8 of the protocol in addition to the
     final version 13.
     """
-    def __init__(self, handler):
+    def __init__(self, handler, mask_outgoing=False):
         WebSocketProtocol.__init__(self, handler)
+        self.mask_outgoing = mask_outgoing
         self._final_frame = False
         self._frame_opcode = None
+        self._masked_frame = None
         self._frame_mask = None
         self._frame_length = None
         self._fragmented_message_buffer = None
@@ -473,7 +491,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             self._handle_websocket_headers()
             self._accept_connection()
         except ValueError:
-            gen_log.debug("Malformed WebSocket request received")
+            gen_log.debug("Malformed WebSocket request received", exc_info=True)
             self._abort()
             return
 
@@ -487,12 +505,20 @@ class WebSocketProtocol13(WebSocketProtocol):
         if not all(map(lambda f: self.request.headers.get(f), fields)):
             raise ValueError("Missing/Invalid WebSocket headers")
 
-    def _challenge_response(self):
+    @staticmethod
+    def compute_accept_value(key):
+        """Computes the value for the Sec-WebSocket-Accept header,
+        given the value for Sec-WebSocket-Key.
+        """
         sha1 = hashlib.sha1()
-        sha1.update(tornado.escape.utf8(
-            self.request.headers.get("Sec-Websocket-Key")))
+        sha1.update(utf8(key))
         sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")  # Magic value
-        return tornado.escape.native_str(base64.b64encode(sha1.digest()))
+        return native_str(base64.b64encode(sha1.digest()))
+
+    def _challenge_response(self):
+        return WebSocketProtocol13.compute_accept_value(
+            self.request.headers.get("Sec-Websocket-Key"))
+
 
     def _accept_connection(self):
         subprotocol_header = ''
@@ -522,12 +548,19 @@ class WebSocketProtocol13(WebSocketProtocol):
             finbit = 0
         frame = struct.pack("B", finbit | opcode)
         l = len(data)
+        if self.mask_outgoing:
+            mask_bit = 0x80
+        else:
+            mask_bit = 0
         if l < 126:
-            frame += struct.pack("B", l)
+            frame += struct.pack("B", l | mask_bit)
         elif l <= 0xFFFF:
-            frame += struct.pack("!BH", 126, l)
+            frame += struct.pack("!BH", 126 | mask_bit, l)
         else:
-            frame += struct.pack("!BQ", 127, l)
+            frame += struct.pack("!BQ", 127 | mask_bit, l)
+        if self.mask_outgoing:
+            mask = os.urandom(4)
+            data = mask + self._apply_mask(mask, data)
         frame += data
         self.stream.write(frame)
 
@@ -559,10 +592,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             # client is using as-yet-undefined extensions; abort
             self._abort()
             return
-        if not (payloadlen & 0x80):
-            # Unmasked frame -> abort connection
-            self._abort()
-            return
+        self._masked_frame = bool(payloadlen & 0x80)
         payloadlen = payloadlen & 0x7f
         if self._frame_opcode_is_control and payloadlen >= 126:
             # control frames must have payload < 126
@@ -570,7 +600,10 @@ class WebSocketProtocol13(WebSocketProtocol):
             return
         if payloadlen < 126:
             self._frame_length = payloadlen
-            self.stream.read_bytes(4, self._on_masking_key)
+            if self._masked_frame:
+                self.stream.read_bytes(4, self._on_masking_key)
+            else:
+                self.stream.read_bytes(self._frame_length, self._on_frame_data)
         elif payloadlen == 126:
             self.stream.read_bytes(2, self._on_frame_length_16)
         elif payloadlen == 127:
@@ -578,21 +611,39 @@ class WebSocketProtocol13(WebSocketProtocol):
 
     def _on_frame_length_16(self, data):
         self._frame_length = struct.unpack("!H", data)[0]
-        self.stream.read_bytes(4, self._on_masking_key)
+        if self._masked_frame:
+            self.stream.read_bytes(4, self._on_masking_key)
+        else:
+            self.stream.read_bytes(self._frame_length, self._on_frame_data)
 
     def _on_frame_length_64(self, data):
         self._frame_length = struct.unpack("!Q", data)[0]
-        self.stream.read_bytes(4, self._on_masking_key)
+        if self._masked_frame:
+            self.stream.read_bytes(4, self._on_masking_key)
+        else:
+            self.stream.read_bytes(self._frame_length, self._on_frame_data)
 
     def _on_masking_key(self, data):
-        self._frame_mask = array.array("B", data)
-        self.stream.read_bytes(self._frame_length, self._on_frame_data)
+        self._frame_mask = data
+        self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
 
-    def _on_frame_data(self, data):
+    def _apply_mask(self, mask, data):
+        mask = array.array("B", mask)
         unmasked = array.array("B", data)
         for i in xrange(len(data)):
-            unmasked[i] = unmasked[i] ^ self._frame_mask[i % 4]
+            unmasked[i] = unmasked[i] ^ mask[i % 4]
+        if hasattr(unmasked, 'tobytes'):
+            # tostring was deprecated in py32.  It hasn't been removed,
+            # but since we turn on deprecation warnings in our tests
+            # we need to use the right one.
+            return unmasked.tobytes()
+        else:
+            return unmasked.tostring()
+
+    def _on_masked_frame_data(self, data):
+        self._on_frame_data(self._apply_mask(self._frame_mask, data))
 
+    def _on_frame_data(self, data):
         if self._frame_opcode_is_control:
             # control frames may be interleaved with a series of fragmented
             # data frames, so control frames must not interact with
@@ -607,10 +658,10 @@ class WebSocketProtocol13(WebSocketProtocol):
                 # nothing to continue
                 self._abort()
                 return
-            self._fragmented_message_buffer += unmasked
+            self._fragmented_message_buffer += data
             if self._final_frame:
                 opcode = self._fragmented_message_opcode
-                unmasked = self._fragmented_message_buffer
+                data = self._fragmented_message_buffer
                 self._fragmented_message_buffer = None
         else:  # start of new data message
             if self._fragmented_message_buffer is not None:
@@ -621,10 +672,10 @@ class WebSocketProtocol13(WebSocketProtocol):
                 opcode = self._frame_opcode
             else:
                 self._fragmented_message_opcode = self._frame_opcode
-                self._fragmented_message_buffer = unmasked
+                self._fragmented_message_buffer = data
 
         if self._final_frame:
-            self._handle_message(opcode, unmasked.tostring())
+            self._handle_message(opcode, data)
 
         if not self.client_terminated:
             self._receive_frame()
@@ -673,3 +724,85 @@ class WebSocketProtocol13(WebSocketProtocol):
             # otherwise just close the connection.
             self._waiting = self.stream.io_loop.add_timeout(
                 self.stream.io_loop.time() + 5, self._abort)
+
+
+class _WebSocketClientConnection(simple_httpclient._HTTPConnection):
+    def __init__(self, io_loop, client, request):
+        self.connect_future = Future()
+        self.read_future = None
+        self.read_queue = collections.deque()
+        self.key = base64.b64encode(os.urandom(16))
+
+        scheme, sep, rest = request.url.partition(':')
+        scheme = {'ws': 'http', 'wss': 'https'}[scheme]
+        request.url = scheme + sep + rest
+        request.headers.update({
+                'Upgrade': 'websocket',
+                'Connection': 'Upgrade',
+                'Sec-WebSocket-Key': self.key,
+                'Sec-WebSocket-Version': '13',
+                })
+
+        super(_WebSocketClientConnection, self).__init__(
+            io_loop, client, request, lambda: None, lambda response: None,
+            104857600)
+
+    def _on_close(self):
+        self.on_message(None)
+
+
+    def _handle_1xx(self, code):
+        assert code == 101
+        assert self.headers['Upgrade'].lower() == 'websocket'
+        assert self.headers['Connection'].lower() == 'upgrade'
+        accept = WebSocketProtocol13.compute_accept_value(self.key)
+        assert self.headers['Sec-Websocket-Accept'] == accept
+
+        self.protocol = WebSocketProtocol13(self, mask_outgoing=True)
+        self.protocol._receive_frame()
+
+        if self._timeout is not None:
+            self.io_loop.remove_timeout(self._timeout)
+            self._timeout = None
+
+        self.connect_future.set_result(self)
+
+    def write_message(self, message, binary=False):
+        self.protocol.write_message(message, binary)
+
+    def read_message(self, callback=None):
+        assert self.read_future is None
+        future = Future()
+        if self.read_queue:
+            future.set_result(self.read_queue.popleft())
+        else:
+            self.read_future = future
+        if callback is not None:
+            self.io_loop.add_future(future, callback)
+        return future
+
+    def on_message(self, message):
+        if self.read_future is not None:
+            self.read_future.set_result(message)
+            self.read_future = None
+        else:
+            self.read_queue.append(message)
+
+    def on_pong(self, data):
+        pass
+
+
+def WebSocketConnect(url, io_loop=None, callback=None):
+    if io_loop is None:
+        io_loop = IOLoop.instance()
+    request = simple_httpclient.HTTPRequest(url)
+    request = simple_httpclient._RequestProxy(
+        request, simple_httpclient.HTTPRequest._DEFAULTS)
+    from tornado.util import ObjectDict
+    from tornado.netutil import Resolver
+    # TODO: refactor _HTTPConnection's client parameter
+    client = ObjectDict(resolver=Resolver(io_loop), hostname_mapping=None)
+    conn = _WebSocketClientConnection(io_loop, client, request)
+    if callback is not None:
+        io_loop.add_future(conn.connect_future, callback)
+    return conn.connect_future