]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add cython-based speedup for websocket mask function.
authorBen Darnell <ben@bendarnell.com>
Fri, 25 Oct 2013 19:46:13 +0000 (15:46 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 27 Oct 2013 02:13:14 +0000 (22:13 -0400)
This optimization is currently activated only if Cython is present
when Tornado is installed.

MANIFEST.in
setup.py
tornado/speedups.pyx [new file with mode: 0644]
tornado/test/websocket_test.py
tornado/websocket.py
tox.ini

index b710aac98090ca7e5bdd14941ba73d666bec0ba3..ea526699f9a9e4ee782c1973ca08905b1efcf67b 100644 (file)
@@ -1,4 +1,5 @@
 recursive-include demos *.py *.yaml *.html *.css *.js *.xml *.sql README
+include tornado/speedups.pyx
 include tornado/ca-certificates.crt
 include tornado/test/README
 include tornado/test/csv_translations/fr_FR.csv
index 7e21325d31e5a365a553e63938d0b318426af782..1d019fc7cc57c51d897b039ee46c263ef75ef132 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,11 @@ try:
 except ImportError:
     pass
 
+try:
+    from Cython.Build import cythonize
+except ImportError:
+    cythonize = None
+
 kwargs = {}
 
 version = "3.2.dev2"
@@ -30,10 +35,16 @@ version = "3.2.dev2"
 with open('README.rst') as f:
     long_description = f.read()
 
+if cythonize is not None:
+    extensions = cythonize('tornado/speedups.pyx')
+else:
+    extensions = []
+
 distutils.core.setup(
     name="tornado",
     version=version,
     packages = ["tornado", "tornado.test", "tornado.platform"],
+    ext_modules = extensions,
     package_data = {
         "tornado": ["ca-certificates.crt"],
         # data files need to be listed both here (which determines what gets
diff --git a/tornado/speedups.pyx b/tornado/speedups.pyx
new file mode 100644 (file)
index 0000000..4634445
--- /dev/null
@@ -0,0 +1,16 @@
+# -*- python -*-
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+
+def websocket_mask(bytes mask_bytes, bytes data_bytes):
+    cdef size_t data_len = len(data_bytes)
+    cdef char* data = data_bytes
+    cdef char* mask = mask_bytes
+    cdef size_t i
+    cdef char* buf = <char*> PyMem_Malloc(data_len)
+    try:
+        for i in xrange(data_len):
+            buf[i] = data[i] ^ mask[i % 4]
+        # Is there a zero-copy equivalent of this?
+        return <bytes>(buf[:data_len])
+    finally:
+        PyMem_Free(buf)
index da9d780b1140f1a686cb93cd4d51167c5cb3687a..66c48e5763000537c5747d5a2e8bd32d45980c8c 100644 (file)
@@ -2,8 +2,14 @@ from tornado.concurrent import Future
 from tornado.httpclient import HTTPError, HTTPRequest
 from tornado.log import gen_log
 from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
+from tornado.test.util import unittest
 from tornado.web import Application, RequestHandler
-from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError
+from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError, _websocket_mask_python
+
+try:
+    from tornado import speedups
+except ImportError:
+    speedups = None
 
 class TestWebSocketHandler(WebSocketHandler):
     """Base class for testing handlers that exposes the on_close event.
@@ -110,3 +116,22 @@ class WebSocketTest(AsyncHTTPTestCase):
         self.assertEqual(response, 'hello')
         ws.close()
         yield self.close_future
+
+
+class MaskFunctionMixin(object):
+    # Subclasses should define self.mask(mask, data)
+    def test_mask(self):
+        self.assertEqual(self.mask(b'abcd', b''), b'')
+        self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
+        self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
+        self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
+
+
+class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
+    def mask(self, mask, data):
+        return _websocket_mask_python(mask, data)
+
+@unittest.skipIf(speedups is None, "tornado.speedups module not present")
+class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
+    def mask(self, mask, data):
+        return speedups.websocket_mask(mask, data)
index 75b1d8f23072b4a7ed0a704150452958c493935b..bdfa779c787b32fde8b79394395f08060a40ba4a 100644 (file)
@@ -586,7 +586,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             frame += struct.pack("!BQ", 127 | mask_bit, l)
         if self.mask_outgoing:
             mask = os.urandom(4)
-            data = mask + self._apply_mask(mask, data)
+            data = mask + _websocket_mask(mask, data)
         frame += data
         self.stream.write(frame)
 
@@ -671,21 +671,8 @@ class WebSocketProtocol13(WebSocketProtocol):
         except StreamClosedError:
             self._abort()
 
-    def _apply_mask(self, mask, data):
-        mask = array.array("B", mask)
-        unmasked = array.array("B", data)
-        for i in xrange(len(data)):
-            unmasked[i] = unmasked[i] ^ mask[i % 4]
-        if hasattr(unmasked, 'tobytes'):
-            # tostring was deprecated in py32.  It hasn't been removed,
-            # but since we turn on deprecation warnings in our tests
-            # we need to use the right one.
-            return unmasked.tobytes()
-        else:
-            return unmasked.tostring()
-
     def _on_masked_frame_data(self, data):
-        self._on_frame_data(self._apply_mask(self._frame_mask, data))
+        self._on_frame_data(_websocket_mask(self._frame_mask, data))
 
     def _on_frame_data(self, data):
         if self._frame_opcode_is_control:
@@ -882,3 +869,29 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
     if callback is not None:
         io_loop.add_future(conn.connect_future, callback)
     return conn.connect_future
+
+def _websocket_mask_python(mask, data):
+    """Websocket masking function.
+
+    `mask` is a `bytes` object of length 4; `data` is a `bytes` object of any length.
+    Returns a `bytes` object of the same length as `data` with the mask applied
+    as specified in section 5.3 of RFC 6455.
+
+    This pure-python implementation may be replaced by an optimized version when available.
+    """
+    mask = array.array("B", mask)
+    unmasked = array.array("B", data)
+    for i in xrange(len(data)):
+        unmasked[i] = unmasked[i] ^ mask[i % 4]
+    if hasattr(unmasked, 'tobytes'):
+        # tostring was deprecated in py32.  It hasn't been removed,
+        # but since we turn on deprecation warnings in our tests
+        # we need to use the right one.
+        return unmasked.tobytes()
+    else:
+        return unmasked.tostring()
+
+try:
+    from tornado.speedups import websocket_mask as _websocket_mask
+except ImportError:
+    _websocket_mask = _websocket_mask_python
diff --git a/tox.ini b/tox.ini
index 77918ef08ddc4ebffd3c2fea9a147d448f9cfef0..2d9afaffd0ab417806814144fb34e3027287405d 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -30,6 +30,7 @@ deps = unittest2
 [testenv:py26-full]
 basepython = python2.6
 deps =
+     Cython
      futures
      mock
      pycurl
@@ -39,6 +40,7 @@ deps =
 [testenv:py27-full]
 basepython = python2.7
 deps =
+     Cython
      futures
      mock
      pycurl
@@ -148,6 +150,7 @@ commands = python -m tornado.test.runtests --locale=zh_TW {posargs:}
 # there.
 basepython = pypy
 deps =
+     Cython
      futures
      mock
 
@@ -168,6 +171,7 @@ setenv = LANG=en_US.utf-8
 [testenv:py32-full]
 basepython = python3.2
 deps =
+     Cython
      mock
 
 [testenv:py33]