]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add compression_level and mem_level to WebSocketHandler.get_compression_options().
authorAndreas Røsdal <andreas.rosdal@gmail.com>
Wed, 8 Mar 2017 19:31:00 +0000 (20:31 +0100)
committerAndreas Røsdal <andreas.rosdal@gmail.com>
Wed, 8 Mar 2017 19:31:00 +0000 (20:31 +0100)
tornado/websocket.py

index 754fca5cdc1540141e6578b51fb0d7afe056c7e5..48486d5bd94fdfc1388770ae01a86dff9931103e 100644 (file)
@@ -249,11 +249,21 @@ class WebSocketHandler(tornado.web.RequestHandler):
         If this method returns None (the default), compression will
         be disabled.  If it returns a dict (even an empty one), it
         will be enabled.  The contents of the dict may be used to
-        control the memory and CPU usage of the compression,
-        but no such options are currently implemented.
+        control the following compression options:
+
+        ``compression_level`` specifies the compression level. 
+
+        ``mem_level`` specifies the amount of memory used for the internal compression state.
+         
+         These parameters are documented in details here: https://docs.python.org/3.6/library/zlib.html#zlib.compressobj
 
         .. versionadded:: 4.1
+
+        .. versionchanged:: 4.5
+
+           Added ``compression_level`` and ``mem_level``.
         """
+        #TODO: Add wbits option.
         return None
 
     def open(self, *args, **kwargs):
@@ -474,7 +484,7 @@ class WebSocketProtocol(object):
 
 
 class _PerMessageDeflateCompressor(object):
-    def __init__(self, persistent, max_wbits):
+    def __init__(self, persistent, max_wbits, compression_options=None):
         if max_wbits is None:
             max_wbits = zlib.MAX_WBITS
         # There is no symbolic constant for the minimum wbits value.
@@ -482,17 +492,27 @@ class _PerMessageDeflateCompressor(object):
             raise ValueError("Invalid max_wbits value %r; allowed range 8-%d",
                              max_wbits, zlib.MAX_WBITS)
         self._max_wbits = max_wbits
+
+        if compression_options is None or not 'compression_level' in compression_options:
+            self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL
+        else:
+            self._compression_level = compression_options['compression_level']
+
+        if compression_options is None or not 'mem_level' in compression_options:
+            self._mem_level = 8
+        else:
+            self._mem_level = compression_options['mem_level']
+
         if persistent:
-            self._compressor = self._create_compressor()
+            self._compressor = self._create_compressor(self._compression_level, self._mem_level)
         else:
             self._compressor = None
 
-    def _create_compressor(self):
-        return zlib.compressobj(tornado.web.GZipContentEncoding.GZIP_LEVEL,
-                                zlib.DEFLATED, -self._max_wbits)
+    def _create_compressor(self, compression_level, mem_level):
+        return zlib.compressobj(compression_level, zlib.DEFLATED, -self._max_wbits, mem_level)
 
     def compress(self, data):
-        compressor = self._compressor or self._create_compressor()
+        compressor = self._compressor or self._create_compressor(self._compression_level, self._mem_level)
         data = (compressor.compress(data) +
                 compressor.flush(zlib.Z_SYNC_FLUSH))
         assert data.endswith(b'\x00\x00\xff\xff')
@@ -616,7 +636,7 @@ class WebSocketProtocol13(WebSocketProtocol):
                     self._compression_options is not None):
                 # TODO: negotiate parameters if compression_options
                 # specifies limits.
-                self._create_compressors('server', ext[1])
+                self._create_compressors('server', ext[1], self._compression_options)
                 if ('client_max_window_bits' in ext[1] and
                         ext[1]['client_max_window_bits'] is None):
                     # Don't echo an offered client_max_window_bits
@@ -682,7 +702,7 @@ class WebSocketProtocol13(WebSocketProtocol):
             options['max_wbits'] = int(wbits_header)
         return options
 
-    def _create_compressors(self, side, agreed_parameters):
+    def _create_compressors(self, side, agreed_parameters, compression_options=None):
         # TODO: handle invalid parameters gracefully
         allowed_keys = set(['server_no_context_takeover',
                             'client_no_context_takeover',
@@ -693,7 +713,7 @@ class WebSocketProtocol13(WebSocketProtocol):
                 raise ValueError("unsupported compression parameter %r" % key)
         other_side = 'client' if (side == 'server') else 'server'
         self._compressor = _PerMessageDeflateCompressor(
-            **self._get_compressor_options(side, agreed_parameters))
+            **self._get_compressor_options(side, agreed_parameters), compression_options=compression_options)
         self._decompressor = _PerMessageDeflateDecompressor(
             **self._get_compressor_options(other_side, agreed_parameters))