]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add body_producer argument to httpclient.HTTPRequest.
authorBen Darnell <ben@bendarnell.com>
Sat, 29 Mar 2014 15:01:14 +0000 (15:01 +0000)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Mar 2014 15:26:23 +0000 (15:26 +0000)
This allows for sending non-contiguous or asynchronously-produced
request bodies, including chunked encoding when the content-length
is not known in advance.

tornado/http1connection.py
tornado/httpclient.py
tornado/httpserver.py
tornado/simple_httpclient.py
tornado/test/httpserver_test.py
tornado/test/simple_httpclient_test.py

index 745e4772acdd24bbecdbb488eec780d75dcca776..e46c04278601e51a7921bd67473695b892bf1705 100644 (file)
@@ -34,7 +34,9 @@ class HTTP1Connection(object):
     We parse HTTP headers and bodies, and execute the request callback
     until the HTTP conection is closed.
     """
-    def __init__(self, stream, address, no_keep_alive=False, protocol=None):
+    def __init__(self, stream, address, is_client,
+                 no_keep_alive=False, protocol=None):
+        self.is_client = is_client
         self.stream = stream
         self.address = address
         # Save the socket's address family now so we know how to
@@ -83,7 +85,7 @@ class HTTP1Connection(object):
             if gzip:
                 request_delegate = _GzipMessageDelegate(request_delegate)
             try:
-                ret = yield self._read_message(request_delegate, False)
+                ret = yield self._read_message(request_delegate)
             except iostream.StreamClosedError:
                 self.close()
                 return
@@ -93,10 +95,10 @@ class HTTP1Connection(object):
     def read_response(self, delegate, method, use_gzip=False):
         if use_gzip:
             delegate = _GzipMessageDelegate(delegate)
-        return self._read_message(delegate, True, method=method)
+        return self._read_message(delegate, method=method)
 
     @gen.coroutine
-    def _read_message(self, delegate, is_client, method=None):
+    def _read_message(self, delegate, method=None):
         assert isinstance(delegate, httputil.HTTPMessageDelegate)
         self.message_delegate = delegate
         try:
@@ -104,7 +106,7 @@ class HTTP1Connection(object):
             self._reading = True
             self._finish_future = Future()
             start_line, headers = self._parse_headers(header_data)
-            if is_client:
+            if self.is_client:
                 start_line = httputil.parse_response_start_line(start_line)
             else:
                 start_line = httputil.parse_request_start_line(start_line)
@@ -120,7 +122,7 @@ class HTTP1Connection(object):
                 # TODO: where else do we need to check for detach?
                 raise gen.Return(False)
             skip_body = False
-            if is_client:
+            if self.is_client:
                 if method == 'HEAD':
                     skip_body = True
                 code = start_line.code
@@ -130,12 +132,12 @@ class HTTP1Connection(object):
                     # TODO: client delegates will get headers_received twice
                     # in the case of a 100-continue.  Document or change?
                     yield self._read_message(self.message_delegate,
-                                             is_client, method=method)
+                                             method=method)
             else:
                 if headers.get("Expect") == "100-continue":
                     self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
             if not skip_body:
-                body_future = self._read_body(is_client, headers)
+                body_future = self._read_body(headers)
                 if body_future is not None:
                     yield body_future
             self._reading = False
@@ -194,19 +196,29 @@ class HTTP1Connection(object):
         self.stream = None
         return stream
 
-    def write_headers(self, start_line, headers, chunk=None, callback=None):
-        self._chunking = (
-            # TODO: should this use self._version or start_line.version?
-            self._version == 'HTTP/1.1' and
-            # 304 responses have no body (not even a zero-length body), and so
-            # should not have either Content-Length or Transfer-Encoding.
-            # headers.
-            start_line.code != 304 and
-            # No need to chunk the output if a Content-Length is specified.
-            'Content-Length' not in headers and
-            # Applications are discouraged from touching Transfer-Encoding,
-            # but if they do, leave it alone.
-            'Transfer-Encoding' not in headers)
+    def write_headers(self, start_line, headers, chunk=None, callback=None,
+                      has_body=True):
+        if self.is_client:
+            # Client requests with a non-empty body must have either a
+            # Content-Length or a Transfer-Encoding.
+            self._chunking = (
+                has_body and
+                'Content-Length' not in headers and
+                'Transfer-Encoding' not in headers)
+        else:
+            self._chunking = (
+                has_body and
+                # TODO: should this use self._version or start_line.version?
+                self._version == 'HTTP/1.1' and
+                # 304 responses have no body (not even a zero-length body), and so
+                # should not have either Content-Length or Transfer-Encoding.
+                # headers.
+                start_line.code != 304 and
+                # No need to chunk the output if a Content-Length is specified.
+                'Content-Length' not in headers and
+                # Applications are discouraged from touching Transfer-Encoding,
+                # but if they do, leave it alone.
+                'Transfer-Encoding' not in headers)
         if self._chunking:
             headers['Transfer-Encoding'] = 'chunked'
         lines = [utf8("%s %s %s" % start_line)]
@@ -293,7 +305,8 @@ class HTTP1Connection(object):
         # Turn Nagle's algorithm back on, leaving the stream in its
         # default state for the next request.
         self.stream.set_nodelay(False)
-        self._finish_future.set_result(None)
+        if self._finish_future is not None:
+            self._finish_future.set_result(None)
 
     def _parse_headers(self, data):
         data = native_str(data.decode('latin1'))
@@ -307,7 +320,7 @@ class HTTP1Connection(object):
                                                 data[eol:100])
         return start_line, headers
 
-    def _read_body(self, is_client, headers):
+    def _read_body(self, headers):
         content_length = headers.get("Content-Length")
         if content_length:
             content_length = int(content_length)
@@ -316,7 +329,7 @@ class HTTP1Connection(object):
             return self._read_fixed_body(content_length)
         if headers.get("Transfer-Encoding") == "chunked":
             return self._read_chunked_body()
-        if is_client:
+        if self.is_client:
             return self._read_body_until_close()
         return None
 
index 9b42d401ad5fb465be55ec0ce535d97d73a375bd..ac3d92b3e91f29fb43cf208e947894271711f27a 100644 (file)
@@ -259,14 +259,23 @@ class HTTPRequest(object):
                  proxy_password=None, allow_nonstandard_methods=None,
                  validate_cert=None, ca_certs=None,
                  allow_ipv6=None,
-                 client_key=None, client_cert=None):
+                 client_key=None, client_cert=None, body_producer=None):
         r"""All parameters except ``url`` are optional.
 
         :arg string url: URL to fetch
         :arg string method: HTTP method, e.g. "GET" or "POST"
         :arg headers: Additional HTTP headers to pass on the request
-        :arg body: HTTP body to pass on the request
         :type headers: `~tornado.httputil.HTTPHeaders` or `dict`
+        :arg body: HTTP request body as a string (byte or unicode; if unicode
+           the utf-8 encoding will be used)
+        :arg body_producer: Callable used for lazy/asynchronous request bodies.
+           TODO: document the interface.
+           Only one of ``body`` and ``body_producer`` may
+           be specified.  ``body_producer`` is not supported on
+           ``curl_httpclient``.  When using ``body_producer`` it is recommended
+           to pass a ``Content-Length`` in the headers as otherwise chunked
+           encoding will be used, and many servers do not support chunked
+           encoding on requests.
         :arg string auth_username: Username for HTTP authentication
         :arg string auth_password: Password for HTTP authentication
         :arg string auth_mode: Authentication mode; default is "basic".
@@ -348,6 +357,7 @@ class HTTPRequest(object):
         self.url = url
         self.method = method
         self.body = body
+        self.body_producer = body_producer
         self.auth_username = auth_username
         self.auth_password = auth_password
         self.auth_mode = auth_mode
@@ -388,6 +398,14 @@ class HTTPRequest(object):
     def body(self, value):
         self._body = utf8(value)
 
+    @property
+    def body_producer(self):
+        return self._body_producer
+
+    @body_producer.setter
+    def body_producer(self, value):
+        self._body_producer = stack_context.wrap(value)
+
     @property
     def streaming_callback(self):
         return self._streaming_callback
index ff512ba1a634a00358c99bd974c2edde7a67c436..d83f54c468134a3b0a146f54435f42bc5ffe4661 100644 (file)
@@ -146,7 +146,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
                            **kwargs)
 
     def handle_stream(self, stream, address):
-        conn = HTTP1Connection(stream, address=address,
+        conn = HTTP1Connection(stream, address=address, is_client=False,
                                no_keep_alive=self.no_keep_alive,
                                protocol=self.protocol)
         conn.start_serving(self, gzip=self.gzip)
index d6925a38edd109edd2bfd48e444c4e4900c34956..4f562226b45251eebbbb7680104f0d0565b34d06 100644 (file)
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 from __future__ import absolute_import, division, print_function, with_statement
 
+from tornado.concurrent import is_future
 from tornado.escape import utf8, _unicode
 from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
 from tornado import httputil
@@ -303,16 +304,20 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             self.request.headers["User-Agent"] = self.request.user_agent
         if not self.request.allow_nonstandard_methods:
             if self.request.method in ("POST", "PATCH", "PUT"):
-                if self.request.body is None:
+                if (self.request.body is None and
+                    self.request.body_producer is None):
                     raise AssertionError(
                         'Body must not be empty for "%s" request'
                         % self.request.method)
             else:
-                if self.request.body is not None:
+                if (self.request.body is not None or
+                    self.request.body_producer is not None):
                     raise AssertionError(
                         'Body must be empty for "%s" request'
                         % self.request.method)
         if self.request.body is not None:
+            # When body_producer is used the caller is responsible for
+            # setting Content-Length (or else chunked encoding will be used).
             self.request.headers["Content-Length"] = str(len(
                 self.request.body))
         if (self.request.method == "POST" and
@@ -324,13 +329,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                    (('?' + self.parsed.query) if self.parsed.query else ''))
         self.stream.set_nodelay(True)
         self.connection = HTTP1Connection(
-            self.stream, self._sockaddr,
+            self.stream, self._sockaddr, is_client=True,
             no_keep_alive=True, protocol=self.parsed.scheme)
         start_line = httputil.RequestStartLine(self.request.method,
                                                req_path, 'HTTP/1.1')
-        self.connection.write_headers(start_line, self.request.headers)
+        self.connection.write_headers(
+            start_line, self.request.headers,
+            has_body=(self.request.body is not None or
+                      self.request.body_producer is not None))
         if self.request.body is not None:
             self.connection.write(self.request.body)
+            self.connection.finish()
+        elif self.request.body_producer is not None:
+            fut = self.request.body_producer(self.connection.write)
+            if is_future(fut):
+                def on_body_written(fut):
+                    fut.result()
+                    self.connection.finish()
+                    self._read_response()
+                self.io_loop.add_future(fut, on_body_written)
+                return
+            self.connection.finish()
+        self._read_response()
+
+    def _read_response(self):
         # Ensure that any exception raised in read_response ends up in our
         # stack context.
         self.io_loop.add_future(
index 82148298a87d949577f2a1f36918674670fd09c6..343a8e4c072e55cadb27cfd17cda5fad88b4ed0a 100644 (file)
@@ -39,7 +39,7 @@ def read_stream_body(stream, callback):
 
         def finish(self):
             callback(b''.join(chunks))
-    conn = HTTP1Connection(stream, None)
+    conn = HTTP1Connection(stream, None, is_client=True)
     conn.read_response(Delegate(), method='GET')
 
 
index c10e7777ccbce42da4c908ccf53a95939166a6ec..0194f31451d52213e19a1a931281b06c28587c58 100644 (file)
@@ -107,6 +107,11 @@ class NoContentLengthHandler(RequestHandler):
         stream.close()
 
 
+class EchoPostHandler(RequestHandler):
+    def post(self):
+        self.write(self.request.body)
+
+
 class SimpleHTTPClientTestMixin(object):
     def get_app(self):
         # callable objects to finish pending /trigger requests
@@ -126,6 +131,7 @@ class SimpleHTTPClientTestMixin(object):
             url("/see_other_get", SeeOtherGetHandler),
             url("/host_echo", HostEchoHandler),
             url("/no_content_length", NoContentLengthHandler),
+            url("/echo_post", EchoPostHandler),
         ], gzip=True)
 
     def test_singleton(self):
@@ -331,6 +337,44 @@ class SimpleHTTPClientTestMixin(object):
         response = self.fetch("/no_content_length")
         self.assertEquals(b"hello", response.body)
 
+    def sync_body_producer(self, write):
+        write(b'1234')
+        write(b'5678')
+
+    @gen.coroutine
+    def async_body_producer(self, write):
+        # TODO: write should return a Future.
+        # wrap it in simple_httpclient or change http1connection?
+        yield gen.Task(write, b'1234')
+        yield gen.Task(IOLoop.current().add_callback)
+        yield gen.Task(write, b'5678')
+
+    def test_sync_body_producer_chunked(self):
+        response = self.fetch("/echo_post", method="POST",
+                              body_producer=self.sync_body_producer)
+        response.rethrow()
+        self.assertEqual(response.body, b"12345678")
+
+    def test_sync_body_producer_content_length(self):
+        response = self.fetch("/echo_post", method="POST",
+                              body_producer=self.sync_body_producer,
+                              headers={'Content-Length': '8'})
+        response.rethrow()
+        self.assertEqual(response.body, b"12345678")
+
+    def test_async_body_producer_chunked(self):
+        response = self.fetch("/echo_post", method="POST",
+                              body_producer=self.async_body_producer)
+        response.rethrow()
+        self.assertEqual(response.body, b"12345678")
+
+    def test_async_body_producer_content_length(self):
+        response = self.fetch("/echo_post", method="POST",
+                              body_producer=self.async_body_producer,
+                              headers={'Content-Length': '8'})
+        response.rethrow()
+        self.assertEqual(response.body, b"12345678")
+
 
 class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
     def setUp(self):