]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add support for gzipped request bodies to HTTPServer.
authorBen Darnell <ben@bendarnell.com>
Mon, 3 Mar 2014 03:42:11 +0000 (22:42 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 3 Mar 2014 03:42:11 +0000 (22:42 -0500)
tornado/http1connection.py
tornado/httpserver.py
tornado/httputil.py
tornado/test/httpserver_test.py
tornado/test/web_test.py
tornado/wsgi.py

index cc14379fd2f7f2911db64f4ecd1d5636a3e46374..e9d6c9a954308ba87b54f280632b14a0a4766829 100644 (file)
@@ -51,16 +51,19 @@ class HTTP1Connection(object):
         self.stream.set_close_callback(self._on_connection_close)
         self._finish_future = None
 
-    def start_serving(self, delegate):
+    def start_serving(self, delegate, gzip=False):
         assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
         # Register the future on the IOLoop so its errors get logged.
-        self.stream.io_loop.add_future(self._server_request_loop(delegate),
-                                       lambda f: f.result())
+        self.stream.io_loop.add_future(
+            self._server_request_loop(delegate, gzip=gzip),
+            lambda f: f.result())
 
     @gen.coroutine
-    def _server_request_loop(self, delegate):
+    def _server_request_loop(self, delegate, gzip=False):
         while True:
             request_delegate = delegate.start_request(self)
+            if gzip:
+                request_delegate = _GzipMessageDelegate(request_delegate)
             try:
                 ret = yield self._read_message(request_delegate, False)
             except iostream.StreamClosedError:
@@ -262,6 +265,12 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
     def headers_received(self, start_line, headers):
         if headers.get("Content-Encoding") == "gzip":
             self._decompressor = GzipDecompressor()
+            # Downstream delegates will only see uncompressed data,
+            # so rename the content-encoding header.
+            # (but note that curl_httpclient doesn't do this).
+            headers.add("X-Consumed-Content-Encoding",
+                        headers["Content-Encoding"])
+            del headers["Content-Encoding"]
         return self._delegate.headers_received(start_line, headers)
 
     def data_received(self, chunk):
index 399f7617bb2c619298995b38aae690db7882b804..84b67cd56e91e2b293cf210a900aef6b62efe5db 100644 (file)
@@ -135,11 +135,13 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
 
     """
     def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
-                 xheaders=False, ssl_options=None, protocol=None, **kwargs):
+                 xheaders=False, ssl_options=None, protocol=None, gzip=False,
+                 **kwargs):
         self.request_callback = request_callback
         self.no_keep_alive = no_keep_alive
         self.xheaders = xheaders
         self.protocol = protocol
+        self.gzip = gzip
         TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
                            **kwargs)
 
@@ -147,7 +149,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
         conn = HTTP1Connection(stream, address=address,
                                no_keep_alive=self.no_keep_alive,
                                protocol=self.protocol)
-        conn.start_serving(self)
+        conn.start_serving(self, gzip=self.gzip)
 
     def start_request(self, connection):
         return _ServerRequestAdapter(self, connection)
@@ -203,7 +205,8 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
         if self.request.method in ("POST", "PATCH", "PUT"):
             httputil.parse_body_arguments(
                 self.request.headers.get("Content-Type", ""), self.request.body,
-                self.request.body_arguments, self.request.files)
+                self.request.body_arguments, self.request.files,
+                self.request.headers)
 
             for k, v in self.request.body_arguments.items():
                 self.request.arguments.setdefault(k, []).extend(v)
index 5e5361a9238bb28c2ead3d55bf7653c72cb53590..80cd08ebe794919de93a01b281a5cc240e20b6f4 100644 (file)
@@ -531,7 +531,7 @@ def _int_or_none(val):
     return int(val)
 
 
-def parse_body_arguments(content_type, body, arguments, files):
+def parse_body_arguments(content_type, body, arguments, files, headers=None):
     """Parses a form request body.
 
     Supports ``application/x-www-form-urlencoded`` and
@@ -540,6 +540,10 @@ def parse_body_arguments(content_type, body, arguments, files):
     and ``files`` parameters are dictionaries that will be updated
     with the parsed contents.
     """
+    if headers and 'Content-Encoding' in headers:
+        gen_log.warning("Unsupported Content-Encoding: %s",
+                        headers['Content-Encoding'])
+        return
     if content_type.startswith("application/x-www-form-urlencoded"):
         try:
             uri_arguments = parse_qs_bytes(native_str(body), keep_blank_values=True)
index 53de055d268b3656b7b504fb853c00da36eca28f..e6f174efa4e6b98d7f7e6b8d87723acc5454c3f6 100644 (file)
@@ -8,7 +8,7 @@ from tornado.http1connection import HTTP1Connection
 from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders, HTTPMessageDelegate
 from tornado.iostream import IOStream
-from tornado.log import gen_log
+from tornado.log import app_log, gen_log
 from tornado.netutil import ssl_options_to_context
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
 from tornado.test.util import unittest
@@ -16,6 +16,7 @@ from tornado.util import u, bytes_type
 from tornado.web import Application, RequestHandler, asynchronous
 from contextlib import closing
 import datetime
+import gzip
 import os
 import shutil
 import socket
@@ -23,6 +24,11 @@ import ssl
 import sys
 import tempfile
 
+try:
+    from io import BytesIO  # python 3
+except ImportError:
+    from cStringIO import StringIO as BytesIO  # python 2
+
 
 class HandlerBaseTestCase(AsyncHTTPTestCase):
     def get_app(self):
@@ -674,3 +680,38 @@ class KeepAliveTest(AsyncHTTPTestCase):
         self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
         self.read_headers()
         self.close()
+
+
+class GzipBaseTest(object):
+    def get_app(self):
+        return Application([('/', EchoHandler)])
+
+    def post_gzip(self, body):
+        bytesio = BytesIO()
+        gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
+        gzip_file.write(utf8(body))
+        gzip_file.close()
+        compressed_body = bytesio.getvalue()
+        return self.fetch('/', method='POST', body=compressed_body,
+                          headers={'Content-Encoding': 'gzip'})
+
+    def test_uncompressed(self):
+        response = self.fetch('/', method='POST', body='foo=bar')
+        self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
+
+class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
+    def get_httpserver_options(self):
+        return dict(gzip=True)
+
+    def test_gzip(self):
+        response = self.post_gzip('foo=bar')
+        self.assertEquals(json_decode(response.body), {u('foo'): [u('bar')]})
+
+class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
+    def test_gzip_unsupported(self):
+        # Gzip support is opt-in; without it the server fails to parse
+        # the body (but parsing form bodies is currently just a log message,
+        # not a fatal error).
+        with ExpectLog(gen_log, "Unsupported Content-Encoding"):
+            response = self.post_gzip('foo=bar')
+        self.assertEquals(json_decode(response.body), {})
index 5029cd526039341542cf0a0013978a44ae18f70b..0e741dc6a6695efa3c792cfbca8fe0bf747f7547 100644 (file)
@@ -1343,7 +1343,13 @@ class GzipTestCase(SimpleHandlerTestCase):
 
     def test_gzip(self):
         response = self.fetch('/')
-        self.assertEqual(response.headers['Content-Encoding'], 'gzip')
+        # simple_httpclient renames the content-encoding header;
+        # curl_httpclient doesn't.
+        self.assertEqual(
+            response.headers.get(
+                'Content-Encoding',
+                response.headers.get('X-Consumed-Content-Encoding')),
+            'gzip')
         self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
 
     def test_gzip_not_requested(self):
index 62423259c53e8473616615677491d7cab8d6f70a..a803e714d1eb153edf69f43834f4969563c4eb0b 100644 (file)
@@ -175,7 +175,8 @@ class HTTPRequest(object):
         # Parse request body
         self.files = {}
         httputil.parse_body_arguments(self.headers.get("Content-Type", ""),
-                                      self.body, self.body_arguments, self.files)
+                                      self.body, self.body_arguments,
+                                      self.files, self.headers)
 
         for k, v in self.body_arguments.items():
             self.arguments.setdefault(k, []).extend(v)