]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Introduce streaming request body support for RequestHandler.
authorBen Darnell <ben@bendarnell.com>
Mon, 24 Mar 2014 01:31:33 +0000 (21:31 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 24 Mar 2014 01:31:33 +0000 (21:31 -0400)
tornado/test/web_test.py
tornado/web.py

index 81ee890f8e35a9b4114722c52d79b8e5b433a7cd..4dc55f34e785a32382e18b40881a745cf16946d5 100644 (file)
@@ -1,4 +1,5 @@
 from __future__ import absolute_import, division, print_function, with_statement
+from tornado.concurrent import Future
 from tornado import gen
 from tornado.escape import json_decode, utf8, to_unicode, recursive_unicode, native_str, to_basestring
 from tornado.httputil import format_timestamp
@@ -6,10 +7,10 @@ from tornado.iostream import IOStream
 from tornado.log import app_log, gen_log
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.template import DictLoader
-from tornado.testing import AsyncHTTPTestCase, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
 from tornado.test.util import unittest
 from tornado.util import u, bytes_type, ObjectDict, unicode_type
-from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature, create_signed_value, ErrorHandler, UIModule, MissingArgumentError
+from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature, create_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body
 
 import binascii
 import datetime
@@ -1800,3 +1801,52 @@ class HandlerByNameTest(WebTestCase):
         self.assertEqual(resp.body, b'hello')
         resp = self.fetch('/hello3')
         self.assertEqual(resp.body, b'hello')
+
+
+class StreamingRequestBodyTest(WebTestCase):
+    def get_handlers(self):
+        @stream_request_body
+        class StreamingBodyHandler(RequestHandler):
+            def initialize(self, test):
+                self.test = test
+
+            def prepare(self):
+                self.test.prepared.set_result(None)
+
+            def data_received(self, data):
+                self.test.data.set_result(data)
+
+            def get(self):
+                self.test.finished.set_result(None)
+                self.write({})
+
+        return [('/', StreamingBodyHandler, dict(test=self))]
+
+    @gen_test
+    def test_streaming_body(self):
+        self.prepared = Future()
+        self.data = Future()
+        self.finished = Future()
+
+        # Use a raw connection so we can control the sending of data.
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+        s.connect(("localhost", self.get_http_port()))
+        stream = IOStream(s, io_loop=self.io_loop)
+        stream.write(b"GET / HTTP/1.1\r\n")
+        stream.write(b"Connection: close\r\n")
+        stream.write(b"Transfer-Encoding: chunked\r\n\r\n")
+        yield self.prepared
+        stream.write(b"4\r\nasdf\r\n")
+        # Ensure the first chunk is received before we send the second.
+        data = yield self.data
+        self.assertEqual(data, b"asdf")
+        self.data = Future()
+        stream.write(b"4\r\nqwer\r\n")
+        data = yield self.data
+        self.assertEquals(data, b"qwer")
+        stream.write(b"0\r\n")
+        yield self.finished
+        data = yield gen.Task(stream.read_until_close)
+        # This would ideally use an HTTP1Connection to read the response.
+        self.assertTrue(data.endswith(b"{}"))
+        stream.close()
index 471707e358d436db8ef90a65d6e09181d0269813..60b854f5fe9ac9473bf79e0a54ca24cbc672d1b1 100644 (file)
@@ -1219,6 +1219,13 @@ class RequestHandler(object):
             if self._finished:
                 return
 
+            if _has_stream_request_body(self.__class__):
+                # In streaming mode request.body is a Future that signals
+                # the body has been completely received.  The Future has no
+                # result; the data has been passed to self.data_received
+                # instead.
+                yield self.request.body
+
             method = getattr(self, self.request.method.lower())
             result = method(*self.path_args, **self.path_kwargs)
             if is_future(result):
@@ -1368,6 +1375,36 @@ def asynchronous(method):
     return wrapper
 
 
+def stream_request_body(cls):
+    """Apply to `RequestHandler` subclasses to enable streaming body support.
+
+    This decorator implies the following changes:
+    * `.HTTPServerRequest.body` is undefined, and body arguments will not
+      be included in `RequestHandler.get_argument`.
+    * `RequestHandler.prepare` is called when the request headers have been
+      read instead of after the entire body has been read.
+    * The subclass must define a method ``data_received(self, data):``, which
+      will be called zero or more times as data is available.  Note that
+      if the request has an empty body, ``data_received`` may not be called.
+    * The regular HTTP method (``post``, ``put``, etc) will be called after
+      the entire body has been read.
+
+    There is a subtle interaction between ``data_received`` and asynchronous
+    ``prepare``: The first call to ``data_recieved`` may occur at any point
+    after the call to ``prepare`` has returned *or yielded*.
+    """
+    if not issubclass(cls, RequestHandler):
+        raise TypeError("expected subclass of RequestHandler, got %r", cls)
+    cls._stream_request_body = True
+    return cls
+
+
+def _has_stream_request_body(cls):
+    if not issubclass(cls, RequestHandler):
+        raise TypeError("expected subclass of RequestHandler, got %r", cls)
+    return getattr(cls, '_stream_request_body', False)
+
+
 def removeslash(method):
     """Use this decorator to remove trailing slashes from the request path.
 
@@ -1664,10 +1701,14 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
     def headers_received(self, start_line, headers):
         self.set_request(httputil.HTTPServerRequest(
             connection=self.connection, start_line=start_line, headers=headers))
+        if self.stream_request_body:
+            self.request.body = Future()
+            self.execute()
 
     def set_request(self, request):
         self.request = request
         self._find_handler()
+        self.stream_request_body = _has_stream_request_body(self.handler_class)
 
     def _find_handler(self):
         # Identify the handler to use as soon as we have the request.
@@ -1705,12 +1746,18 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
             self.handler_kwargs = dict(status_code=404)
 
     def data_received(self, data):
-        self.chunks.append(data)
+        if self.stream_request_body:
+            self.handler.data_received(data)
+        else:
+            self.chunks.append(data)
 
     def finish(self):
-        self.request.body = b''.join(self.chunks)
-        self.request._parse_body()
-        self.execute()
+        if self.stream_request_body:
+            self.request.body.set_result(None)
+        else:
+            self.request.body = b''.join(self.chunks)
+            self.request._parse_body()
+            self.execute()
 
     def execute(self):
         # If template cache is disabled (usually in the debug mode),
@@ -1723,7 +1770,7 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
         if not self.application.settings.get('static_hash_cache', True):
             StaticFileHandler.reset()
 
-        handler = self.handler_class(self.application, self.request,
+        self.handler = self.handler_class(self.application, self.request,
                                      **self.handler_kwargs)
         transforms = [t(self.request) for t in self.application.transforms]
         # Note that if an exception escapes handler._execute it will be
@@ -1731,9 +1778,7 @@ class _RequestDispatcher(httputil.HTTPMessageDelegate):
         # However, that shouldn't happen because _execute has a blanket
         # except handler, and we cannot easily access the IOLoop here to
         # call add_future.
-        handler._execute(transforms, *self.path_args, **self.path_kwargs)
-        return handler
-
+        self.handler._execute(transforms, *self.path_args, **self.path_kwargs)
 
 
 class HTTPError(Exception):