From 681e51b38f0e20d528df1f28517d88aedb34984b Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sat, 29 Mar 2014 15:01:14 +0000 Subject: [PATCH] Add body_producer argument to httpclient.HTTPRequest. This allows for sending non-contiguous or asynchronously-produced request bodies, including chunked encoding when the content-length is not known in advance. --- tornado/http1connection.py | 61 ++++++++++++++++---------- tornado/httpclient.py | 22 +++++++++- tornado/httpserver.py | 2 +- tornado/simple_httpclient.py | 30 +++++++++++-- tornado/test/httpserver_test.py | 2 +- tornado/test/simple_httpclient_test.py | 44 +++++++++++++++++++ 6 files changed, 129 insertions(+), 32 deletions(-) diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 745e4772a..e46c04278 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -34,7 +34,9 @@ class HTTP1Connection(object): We parse HTTP headers and bodies, and execute the request callback until the HTTP conection is closed. """ - def __init__(self, stream, address, no_keep_alive=False, protocol=None): + def __init__(self, stream, address, is_client, + no_keep_alive=False, protocol=None): + self.is_client = is_client self.stream = stream self.address = address # Save the socket's address family now so we know how to @@ -83,7 +85,7 @@ class HTTP1Connection(object): if gzip: request_delegate = _GzipMessageDelegate(request_delegate) try: - ret = yield self._read_message(request_delegate, False) + ret = yield self._read_message(request_delegate) except iostream.StreamClosedError: self.close() return @@ -93,10 +95,10 @@ class HTTP1Connection(object): def read_response(self, delegate, method, use_gzip=False): if use_gzip: delegate = _GzipMessageDelegate(delegate) - return self._read_message(delegate, True, method=method) + return self._read_message(delegate, method=method) @gen.coroutine - def _read_message(self, delegate, is_client, method=None): + def _read_message(self, delegate, method=None): assert isinstance(delegate, httputil.HTTPMessageDelegate) self.message_delegate = delegate try: @@ -104,7 +106,7 @@ class HTTP1Connection(object): self._reading = True self._finish_future = Future() start_line, headers = self._parse_headers(header_data) - if is_client: + if self.is_client: start_line = httputil.parse_response_start_line(start_line) else: start_line = httputil.parse_request_start_line(start_line) @@ -120,7 +122,7 @@ class HTTP1Connection(object): # TODO: where else do we need to check for detach? raise gen.Return(False) skip_body = False - if is_client: + if self.is_client: if method == 'HEAD': skip_body = True code = start_line.code @@ -130,12 +132,12 @@ class HTTP1Connection(object): # TODO: client delegates will get headers_received twice # in the case of a 100-continue. Document or change? yield self._read_message(self.message_delegate, - is_client, method=method) + method=method) else: if headers.get("Expect") == "100-continue": self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") if not skip_body: - body_future = self._read_body(is_client, headers) + body_future = self._read_body(headers) if body_future is not None: yield body_future self._reading = False @@ -194,19 +196,29 @@ class HTTP1Connection(object): self.stream = None return stream - def write_headers(self, start_line, headers, chunk=None, callback=None): - self._chunking = ( - # TODO: should this use self._version or start_line.version? - self._version == 'HTTP/1.1' and - # 304 responses have no body (not even a zero-length body), and so - # should not have either Content-Length or Transfer-Encoding. - # headers. - start_line.code != 304 and - # No need to chunk the output if a Content-Length is specified. - 'Content-Length' not in headers and - # Applications are discouraged from touching Transfer-Encoding, - # but if they do, leave it alone. - 'Transfer-Encoding' not in headers) + def write_headers(self, start_line, headers, chunk=None, callback=None, + has_body=True): + if self.is_client: + # Client requests with a non-empty body must have either a + # Content-Length or a Transfer-Encoding. + self._chunking = ( + has_body and + 'Content-Length' not in headers and + 'Transfer-Encoding' not in headers) + else: + self._chunking = ( + has_body and + # TODO: should this use self._version or start_line.version? + self._version == 'HTTP/1.1' and + # 304 responses have no body (not even a zero-length body), and so + # should not have either Content-Length or Transfer-Encoding. + # headers. + start_line.code != 304 and + # No need to chunk the output if a Content-Length is specified. + 'Content-Length' not in headers and + # Applications are discouraged from touching Transfer-Encoding, + # but if they do, leave it alone. + 'Transfer-Encoding' not in headers) if self._chunking: headers['Transfer-Encoding'] = 'chunked' lines = [utf8("%s %s %s" % start_line)] @@ -293,7 +305,8 @@ class HTTP1Connection(object): # Turn Nagle's algorithm back on, leaving the stream in its # default state for the next request. self.stream.set_nodelay(False) - self._finish_future.set_result(None) + if self._finish_future is not None: + self._finish_future.set_result(None) def _parse_headers(self, data): data = native_str(data.decode('latin1')) @@ -307,7 +320,7 @@ class HTTP1Connection(object): data[eol:100]) return start_line, headers - def _read_body(self, is_client, headers): + def _read_body(self, headers): content_length = headers.get("Content-Length") if content_length: content_length = int(content_length) @@ -316,7 +329,7 @@ class HTTP1Connection(object): return self._read_fixed_body(content_length) if headers.get("Transfer-Encoding") == "chunked": return self._read_chunked_body() - if is_client: + if self.is_client: return self._read_body_until_close() return None diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 9b42d401a..ac3d92b3e 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -259,14 +259,23 @@ class HTTPRequest(object): proxy_password=None, allow_nonstandard_methods=None, validate_cert=None, ca_certs=None, allow_ipv6=None, - client_key=None, client_cert=None): + client_key=None, client_cert=None, body_producer=None): r"""All parameters except ``url`` are optional. :arg string url: URL to fetch :arg string method: HTTP method, e.g. "GET" or "POST" :arg headers: Additional HTTP headers to pass on the request - :arg body: HTTP body to pass on the request :type headers: `~tornado.httputil.HTTPHeaders` or `dict` + :arg body: HTTP request body as a string (byte or unicode; if unicode + the utf-8 encoding will be used) + :arg body_producer: Callable used for lazy/asynchronous request bodies. + TODO: document the interface. + Only one of ``body`` and ``body_producer`` may + be specified. ``body_producer`` is not supported on + ``curl_httpclient``. When using ``body_producer`` it is recommended + to pass a ``Content-Length`` in the headers as otherwise chunked + encoding will be used, and many servers do not support chunked + encoding on requests. :arg string auth_username: Username for HTTP authentication :arg string auth_password: Password for HTTP authentication :arg string auth_mode: Authentication mode; default is "basic". @@ -348,6 +357,7 @@ class HTTPRequest(object): self.url = url self.method = method self.body = body + self.body_producer = body_producer self.auth_username = auth_username self.auth_password = auth_password self.auth_mode = auth_mode @@ -388,6 +398,14 @@ class HTTPRequest(object): def body(self, value): self._body = utf8(value) + @property + def body_producer(self): + return self._body_producer + + @body_producer.setter + def body_producer(self, value): + self._body_producer = stack_context.wrap(value) + @property def streaming_callback(self): return self._streaming_callback diff --git a/tornado/httpserver.py b/tornado/httpserver.py index ff512ba1a..d83f54c46 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -146,7 +146,7 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate): **kwargs) def handle_stream(self, stream, address): - conn = HTTP1Connection(stream, address=address, + conn = HTTP1Connection(stream, address=address, is_client=False, no_keep_alive=self.no_keep_alive, protocol=self.protocol) conn.start_serving(self, gzip=self.gzip) diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index d6925a38e..4f562226b 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import absolute_import, division, print_function, with_statement +from tornado.concurrent import is_future from tornado.escape import utf8, _unicode from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy from tornado import httputil @@ -303,16 +304,20 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: if self.request.method in ("POST", "PATCH", "PUT"): - if self.request.body is None: + if (self.request.body is None and + self.request.body_producer is None): raise AssertionError( 'Body must not be empty for "%s" request' % self.request.method) else: - if self.request.body is not None: + if (self.request.body is not None or + self.request.body_producer is not None): raise AssertionError( 'Body must be empty for "%s" request' % self.request.method) if self.request.body is not None: + # When body_producer is used the caller is responsible for + # setting Content-Length (or else chunked encoding will be used). self.request.headers["Content-Length"] = str(len( self.request.body)) if (self.request.method == "POST" and @@ -324,13 +329,30 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): (('?' + self.parsed.query) if self.parsed.query else '')) self.stream.set_nodelay(True) self.connection = HTTP1Connection( - self.stream, self._sockaddr, + self.stream, self._sockaddr, is_client=True, no_keep_alive=True, protocol=self.parsed.scheme) start_line = httputil.RequestStartLine(self.request.method, req_path, 'HTTP/1.1') - self.connection.write_headers(start_line, self.request.headers) + self.connection.write_headers( + start_line, self.request.headers, + has_body=(self.request.body is not None or + self.request.body_producer is not None)) if self.request.body is not None: self.connection.write(self.request.body) + self.connection.finish() + elif self.request.body_producer is not None: + fut = self.request.body_producer(self.connection.write) + if is_future(fut): + def on_body_written(fut): + fut.result() + self.connection.finish() + self._read_response() + self.io_loop.add_future(fut, on_body_written) + return + self.connection.finish() + self._read_response() + + def _read_response(self): # Ensure that any exception raised in read_response ends up in our # stack context. self.io_loop.add_future( diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 82148298a..343a8e4c0 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -39,7 +39,7 @@ def read_stream_body(stream, callback): def finish(self): callback(b''.join(chunks)) - conn = HTTP1Connection(stream, None) + conn = HTTP1Connection(stream, None, is_client=True) conn.read_response(Delegate(), method='GET') diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index c10e7777c..0194f3145 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -107,6 +107,11 @@ class NoContentLengthHandler(RequestHandler): stream.close() +class EchoPostHandler(RequestHandler): + def post(self): + self.write(self.request.body) + + class SimpleHTTPClientTestMixin(object): def get_app(self): # callable objects to finish pending /trigger requests @@ -126,6 +131,7 @@ class SimpleHTTPClientTestMixin(object): url("/see_other_get", SeeOtherGetHandler), url("/host_echo", HostEchoHandler), url("/no_content_length", NoContentLengthHandler), + url("/echo_post", EchoPostHandler), ], gzip=True) def test_singleton(self): @@ -331,6 +337,44 @@ class SimpleHTTPClientTestMixin(object): response = self.fetch("/no_content_length") self.assertEquals(b"hello", response.body) + def sync_body_producer(self, write): + write(b'1234') + write(b'5678') + + @gen.coroutine + def async_body_producer(self, write): + # TODO: write should return a Future. + # wrap it in simple_httpclient or change http1connection? + yield gen.Task(write, b'1234') + yield gen.Task(IOLoop.current().add_callback) + yield gen.Task(write, b'5678') + + def test_sync_body_producer_chunked(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.sync_body_producer) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_sync_body_producer_content_length(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.sync_body_producer, + headers={'Content-Length': '8'}) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_async_body_producer_chunked(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.async_body_producer) + response.rethrow() + self.assertEqual(response.body, b"12345678") + + def test_async_body_producer_content_length(self): + response = self.fetch("/echo_post", method="POST", + body_producer=self.async_body_producer, + headers={'Content-Length': '8'}) + response.rethrow() + self.assertEqual(response.body, b"12345678") + class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase): def setUp(self): -- 2.47.2