]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Support other yieldables for httpclient body_producer
authorBen Darnell <ben@bendarnell.com>
Sun, 4 Oct 2015 00:52:04 +0000 (20:52 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 4 Oct 2015 02:28:15 +0000 (22:28 -0400)
tornado/simple_httpclient.py
tornado/test/simple_httpclient_test.py

index 074d18b849a0684fdf99f036b5d4d1b559c7c19a..37b0bc27fdd453da15d2a8d11c8a5a1eb5f7e726 100644 (file)
@@ -1,8 +1,8 @@
 #!/usr/bin/env python
 from __future__ import absolute_import, division, print_function, with_statement
 
-from tornado.concurrent import is_future
 from tornado.escape import utf8, _unicode
+from tornado import gen
 from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
 from tornado import httputil
 from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
@@ -391,7 +391,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             self.connection.write(self.request.body)
         elif self.request.body_producer is not None:
             fut = self.request.body_producer(self.connection.write)
-            if is_future(fut):
+            if fut is not None:
+                fut = gen.convert_yielded(fut)
+
                 def on_body_written(fut):
                     fut.result()
                     self.connection.finish()
index d478071f318f9202e9bf5e421a3a9944229b0a24..b6687a2982d8fab158021b6ba7334655a28b30ab 100644 (file)
@@ -22,7 +22,7 @@ from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler, RedirectHandler
 from tornado.test import httpclient_test
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
-from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest
+from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port, unittest, skipBefore35, exec_test
 from tornado.web import RequestHandler, Application, asynchronous, url, stream_request_body
 
 
@@ -404,6 +404,33 @@ class SimpleHTTPClientTestMixin(object):
         response.rethrow()
         self.assertEqual(response.body, b"12345678")
 
+    @skipBefore35
+    def test_native_body_producer_chunked(self):
+        namespace = exec_test(globals(), locals(), """
+        async def body_producer(write):
+            await write(b'1234')
+            await gen.Task(IOLoop.current().add_callback)
+            await write(b'5678')
+        """)
+        response = self.fetch("/echo_post", method="POST",
+                              body_producer=namespace["body_producer"])
+        response.rethrow()
+        self.assertEqual(response.body, b"12345678")
+
+    @skipBefore35
+    def test_native_body_producer_content_length(self):
+        namespace = exec_test(globals(), locals(), """
+        async def body_producer(write):
+            await write(b'1234')
+            await gen.Task(IOLoop.current().add_callback)
+            await write(b'5678')
+        """)
+        response = self.fetch("/echo_post", method="POST",
+                              body_producer=namespace["body_producer"],
+                              headers={'Content-Length': '8'})
+        response.rethrow()
+        self.assertEqual(response.body, b"12345678")
+
     def test_100_continue(self):
         response = self.fetch("/echo_post", method="POST",
                               body=b"1234",