]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Permit streaming_callback of AsyncHTTPClient to be a coroutine. (#3471) master
authorAaron Gibson <aaron@headspin.io>
Wed, 17 Sep 2025 17:43:47 +0000 (10:43 -0700)
committerGitHub <noreply@github.com>
Wed, 17 Sep 2025 17:43:47 +0000 (13:43 -0400)
Co-authored-by: Aaron Gibson <eulersidcrisis@yahoo.com>
tornado/curl_httpclient.py
tornado/httpclient.py
tornado/simple_httpclient.py
tornado/test/curl_httpclient_test.py
tornado/test/simple_httpclient_test.py

index eb3fa7836fac39a9256a863108c74023fd90a7c7..6d98b44b74f072783fd57a4f1f5410470d4cd96d 100644 (file)
@@ -22,8 +22,10 @@ import pycurl
 import re
 import threading
 import time
 import re
 import threading
 import time
+import inspect
 from io import BytesIO
 
 from io import BytesIO
 
+from tornado import gen
 from tornado import httputil
 from tornado import ioloop
 
 from tornado import httputil
 from tornado import ioloop
 
@@ -368,6 +370,13 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
         )
         if request.streaming_callback:
 
         )
         if request.streaming_callback:
 
+            if gen.is_coroutine_function(
+                request.streaming_callback
+            ) or inspect.iscoroutinefunction(request.streaming_callback):
+                raise TypeError(
+                    "'CurlAsyncHTTPClient' does not support async streaming_callbacks."
+                )
+
             def write_function(b: Union[bytes, bytearray]) -> int:
                 assert request.streaming_callback is not None
                 self.io_loop.add_callback(request.streaming_callback, b)
             def write_function(b: Union[bytes, bytearray]) -> int:
                 assert request.streaming_callback is not None
                 self.io_loop.add_callback(request.streaming_callback, b)
index 3a45ffd0415b087e267385067678f87bb117728f..488fe6de0b0832c49483e0e89caec18ff5832b14 100644 (file)
@@ -53,7 +53,7 @@ from tornado import gen, httputil
 from tornado.ioloop import IOLoop
 from tornado.util import Configurable
 
 from tornado.ioloop import IOLoop
 from tornado.util import Configurable
 
-from typing import Type, Any, Union, Dict, Callable, Optional, cast
+from typing import Type, Any, Union, Dict, Callable, Optional, Awaitable, cast
 
 
 class HTTPClient:
 
 
 class HTTPClient:
@@ -372,7 +372,9 @@ class HTTPRequest:
         user_agent: Optional[str] = None,
         use_gzip: Optional[bool] = None,
         network_interface: Optional[str] = None,
         user_agent: Optional[str] = None,
         use_gzip: Optional[bool] = None,
         network_interface: Optional[str] = None,
-        streaming_callback: Optional[Callable[[bytes], None]] = None,
+        streaming_callback: Optional[
+            Callable[[bytes], Optional[Awaitable[None]]]
+        ] = None,
         header_callback: Optional[Callable[[str], None]] = None,
         prepare_curl_callback: Optional[Callable[[Any], None]] = None,
         proxy_host: Optional[str] = None,
         header_callback: Optional[Callable[[str], None]] = None,
         prepare_curl_callback: Optional[Callable[[Any], None]] = None,
         proxy_host: Optional[str] = None,
index cc1637613350dd9e9a812b1da50aa2fef02ad286..5ed273db3e2065ed2c1f63a575dc4d5641ce7ec6 100644 (file)
@@ -33,7 +33,7 @@ import time
 from io import BytesIO
 import urllib.parse
 
 from io import BytesIO
 import urllib.parse
 
-from typing import Dict, Any, Callable, Optional, Type, Union
+from typing import Dict, Any, Callable, Optional, Type, Union, Awaitable
 from types import TracebackType
 import typing
 
 from types import TracebackType
 import typing
 
@@ -687,14 +687,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
     def _on_end_request(self) -> None:
         self.stream.close()
 
     def _on_end_request(self) -> None:
         self.stream.close()
 
-    def data_received(self, chunk: bytes) -> None:
+    def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
         if self._should_follow_redirect():
             # We're going to follow a redirect so just discard the body.
         if self._should_follow_redirect():
             # We're going to follow a redirect so just discard the body.
-            return
+            return None
         if self.request.streaming_callback is not None:
         if self.request.streaming_callback is not None:
-            self.request.streaming_callback(chunk)
+            return self.request.streaming_callback(chunk)
         else:
             self.chunks.append(chunk)
         else:
             self.chunks.append(chunk)
+            return None
 
 
 if __name__ == "__main__":
 
 
 if __name__ == "__main__":
index ce3f68d7f779ea227aa4f590d37874d6f9dc1a16..bf87df68221ce9b149e2eb8ef970eb5d806d37fa 100644 (file)
@@ -5,6 +5,7 @@ from tornado.escape import utf8
 from tornado.testing import AsyncHTTPTestCase
 from tornado.test import httpclient_test
 from tornado.web import Application, RequestHandler
 from tornado.testing import AsyncHTTPTestCase
 from tornado.test import httpclient_test
 from tornado.web import Application, RequestHandler
+from tornado import gen
 
 
 try:
 
 
 try:
@@ -123,3 +124,19 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase):
             auth_password="barユ£",
         )
         self.assertEqual(response.body, b"ok")
             auth_password="barユ£",
         )
         self.assertEqual(response.body, b"ok")
+
+    def test_streaming_callback_not_permitted(self):
+        @gen.coroutine
+        def _recv_chunk(chunk):
+            yield gen.moment
+
+        with self.assertRaises(TypeError):
+            self.fetch("/digest", streaming_callback=_recv_chunk)
+
+        import asyncio
+
+        async def _async_recv_chunk(chunk):
+            await asyncio.sleep(0)
+
+        with self.assertRaises(TypeError):
+            self.fetch("/digest", streaming_callback=_async_recv_chunk)
index a40435e812d280d7cbcc1c97909af6c8b38d868c..9b21acada6c752ab994b51145c097f81decf7d44 100644 (file)
@@ -539,6 +539,27 @@ class SimpleHTTPClientTestMixin(AsyncTestCase):
         num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
         self.assertEqual(num_start_lines, 1)
 
         num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
         self.assertEqual(num_start_lines, 1)
 
+    def test_streaming_callback_coroutine(self: typing.Any):
+        headers = []  # type: typing.List[str]
+        chunk_bytes = []  # type: typing.List[bytes]
+
+        import asyncio
+
+        async def _put_chunk(chunk):
+            await asyncio.sleep(0)
+            chunk_bytes.append(chunk)
+
+        self.fetch(
+            "/chunk",
+            header_callback=headers.append,
+            streaming_callback=_put_chunk,
+        )
+        chunks = list(map(to_unicode, chunk_bytes))
+        self.assertEqual("".join(chunks), "asdfqwer")
+        # Make sure we only got one set of headers.
+        num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
+        self.assertEqual(num_start_lines, 1)
+
 
 class SimpleHTTPClientTestCase(AsyncHTTPTestCase, SimpleHTTPClientTestMixin):
     def setUp(self):
 
 class SimpleHTTPClientTestCase(AsyncHTTPTestCase, SimpleHTTPClientTestMixin):
     def setUp(self):