From: Aaron Gibson Date: Wed, 17 Sep 2025 17:43:47 +0000 (-0700) Subject: Permit streaming_callback of AsyncHTTPClient to be a coroutine. (#3471) X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;p=thirdparty%2Ftornado.git Permit streaming_callback of AsyncHTTPClient to be a coroutine. (#3471) Co-authored-by: Aaron Gibson --- diff --git a/tornado/curl_httpclient.py b/tornado/curl_httpclient.py index eb3fa783..6d98b44b 100644 --- a/tornado/curl_httpclient.py +++ b/tornado/curl_httpclient.py @@ -22,8 +22,10 @@ import pycurl import re import threading import time +import inspect from io import BytesIO +from tornado import gen from tornado import httputil from tornado import ioloop @@ -368,6 +370,13 @@ class CurlAsyncHTTPClient(AsyncHTTPClient): ) if request.streaming_callback: + if gen.is_coroutine_function( + request.streaming_callback + ) or inspect.iscoroutinefunction(request.streaming_callback): + raise TypeError( + "'CurlAsyncHTTPClient' does not support async streaming_callbacks." + ) + def write_function(b: Union[bytes, bytearray]) -> int: assert request.streaming_callback is not None self.io_loop.add_callback(request.streaming_callback, b) diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 3a45ffd0..488fe6de 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -53,7 +53,7 @@ from tornado import gen, httputil from tornado.ioloop import IOLoop from tornado.util import Configurable -from typing import Type, Any, Union, Dict, Callable, Optional, cast +from typing import Type, Any, Union, Dict, Callable, Optional, Awaitable, cast class HTTPClient: @@ -372,7 +372,9 @@ class HTTPRequest: user_agent: Optional[str] = None, use_gzip: Optional[bool] = None, network_interface: Optional[str] = None, - streaming_callback: Optional[Callable[[bytes], None]] = None, + streaming_callback: Optional[ + Callable[[bytes], Optional[Awaitable[None]]] + ] = None, header_callback: Optional[Callable[[str], None]] = None, prepare_curl_callback: Optional[Callable[[Any], None]] = None, proxy_host: Optional[str] = None, diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index cc163761..5ed273db 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -33,7 +33,7 @@ import time from io import BytesIO import urllib.parse -from typing import Dict, Any, Callable, Optional, Type, Union +from typing import Dict, Any, Callable, Optional, Type, Union, Awaitable from types import TracebackType import typing @@ -687,14 +687,15 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): def _on_end_request(self) -> None: self.stream.close() - def data_received(self, chunk: bytes) -> None: + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: if self._should_follow_redirect(): # We're going to follow a redirect so just discard the body. - return + return None if self.request.streaming_callback is not None: - self.request.streaming_callback(chunk) + return self.request.streaming_callback(chunk) else: self.chunks.append(chunk) + return None if __name__ == "__main__": diff --git a/tornado/test/curl_httpclient_test.py b/tornado/test/curl_httpclient_test.py index ce3f68d7..bf87df68 100644 --- a/tornado/test/curl_httpclient_test.py +++ b/tornado/test/curl_httpclient_test.py @@ -5,6 +5,7 @@ from tornado.escape import utf8 from tornado.testing import AsyncHTTPTestCase from tornado.test import httpclient_test from tornado.web import Application, RequestHandler +from tornado import gen try: @@ -123,3 +124,19 @@ class CurlHTTPClientTestCase(AsyncHTTPTestCase): auth_password="barユ£", ) self.assertEqual(response.body, b"ok") + + def test_streaming_callback_not_permitted(self): + @gen.coroutine + def _recv_chunk(chunk): + yield gen.moment + + with self.assertRaises(TypeError): + self.fetch("/digest", streaming_callback=_recv_chunk) + + import asyncio + + async def _async_recv_chunk(chunk): + await asyncio.sleep(0) + + with self.assertRaises(TypeError): + self.fetch("/digest", streaming_callback=_async_recv_chunk) diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index a40435e8..9b21acad 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -539,6 +539,27 @@ class SimpleHTTPClientTestMixin(AsyncTestCase): num_start_lines = len([h for h in headers if h.startswith("HTTP/")]) self.assertEqual(num_start_lines, 1) + def test_streaming_callback_coroutine(self: typing.Any): + headers = [] # type: typing.List[str] + chunk_bytes = [] # type: typing.List[bytes] + + import asyncio + + async def _put_chunk(chunk): + await asyncio.sleep(0) + chunk_bytes.append(chunk) + + self.fetch( + "/chunk", + header_callback=headers.append, + streaming_callback=_put_chunk, + ) + chunks = list(map(to_unicode, chunk_bytes)) + self.assertEqual("".join(chunks), "asdfqwer") + # Make sure we only got one set of headers. + num_start_lines = len([h for h in headers if h.startswith("HTTP/")]) + self.assertEqual(num_start_lines, 1) + class SimpleHTTPClientTestCase(AsyncHTTPTestCase, SimpleHTTPClientTestMixin): def setUp(self):