]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop 'Response(on_close=...)' from API (#1572)
authorTom Christie <tom@tomchristie.com>
Fri, 16 Apr 2021 09:03:37 +0000 (10:03 +0100)
committerGitHub <noreply@github.com>
Fri, 16 Apr 2021 09:03:37 +0000 (10:03 +0100)
httpx/_client.py
httpx/_models.py
tests/client/test_async_client.py
tests/client/test_client.py
tests/models/test_responses.py

index dde2fec03fe1a78d54058988d44293455ede828f..429382fa808097472f9bced75a5a96cafa0eb26f 100644 (file)
@@ -86,6 +86,52 @@ class ClientState(enum.Enum):
     CLOSED = 3
 
 
+class BoundSyncStream(SyncByteStream):
+    """
+    A byte stream that is bound to a given response instance, and that
+    ensures the `response.elapsed` is set once the response is closed.
+    """
+
+    def __init__(
+        self, stream: SyncByteStream, response: Response, timer: Timer
+    ) -> None:
+        self._stream = stream
+        self._response = response
+        self._timer = timer
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        for chunk in self._stream:
+            yield chunk
+
+    def close(self) -> None:
+        seconds = self._timer.sync_elapsed()
+        self._response.elapsed = datetime.timedelta(seconds=seconds)
+        self._stream.close()
+
+
+class BoundAsyncStream(AsyncByteStream):
+    """
+    An async byte stream that is bound to a given response instance, and that
+    ensures the `response.elapsed` is set once the response is closed.
+    """
+
+    def __init__(
+        self, stream: AsyncByteStream, response: Response, timer: Timer
+    ) -> None:
+        self._stream = stream
+        self._response = response
+        self._timer = timer
+
+    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+        async for chunk in self._stream:
+            yield chunk
+
+    async def aclose(self) -> None:
+        seconds = await self._timer.async_elapsed()
+        self._response.elapsed = datetime.timedelta(seconds=seconds)
+        await self._stream.aclose()
+
+
 class BaseClient:
     def __init__(
         self,
@@ -874,28 +920,29 @@ class Client(BaseClient):
         timer = Timer()
         timer.sync_start()
 
+        if not isinstance(request.stream, SyncByteStream):
+            raise RuntimeError(
+                "Attempted to send an async request with a sync Client instance."
+            )
+
         with request_context(request=request):
             (status_code, headers, stream, extensions) = transport.handle_request(
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
-                stream=request.stream,  # type: ignore
+                stream=request.stream,
                 extensions={"timeout": timeout.as_dict()},
             )
 
-        def on_close(response: Response) -> None:
-            response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
-            stream.close()
-
         response = Response(
             status_code,
             headers=headers,
             stream=stream,
             extensions=extensions,
             request=request,
-            on_close=on_close,
         )
 
+        response.stream = BoundSyncStream(stream, response=response, timer=timer)
         self.cookies.extract_cookies(response)
 
         status = f"{response.status_code} {response.reason_phrase}"
@@ -1512,6 +1559,11 @@ class AsyncClient(BaseClient):
         timer = Timer()
         await timer.async_start()
 
+        if not isinstance(request.stream, AsyncByteStream):
+            raise RuntimeError(
+                "Attempted to send an sync request with an AsyncClient instance."
+            )
+
         with request_context(request=request):
             (
                 status_code,
@@ -1522,23 +1574,19 @@ class AsyncClient(BaseClient):
                 request.method.encode(),
                 request.url.raw,
                 headers=request.headers.raw,
-                stream=request.stream,  # type: ignore
+                stream=request.stream,
                 extensions={"timeout": timeout.as_dict()},
             )
 
-        async def on_close(response: Response) -> None:
-            response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
-            await stream.aclose()
-
         response = Response(
             status_code,
             headers=headers,
             stream=stream,
             extensions=extensions,
             request=request,
-            on_close=on_close,
         )
 
+        response.stream = BoundAsyncStream(stream, response=response, timer=timer)
         self.cookies.extract_cookies(response)
 
         status = f"{response.status_code} {response.reason_phrase}"
index a3b6ff1f01119c4d39190d2cc8226d3406f865dd..bd1ef63891f07926a2de468e30e95456c32b6cdc 100644 (file)
@@ -908,7 +908,6 @@ class Response:
         request: Request = None,
         extensions: dict = None,
         history: typing.List["Response"] = None,
-        on_close: typing.Callable = None,
     ):
         self.status_code = status_code
         self.headers = Headers(headers)
@@ -923,7 +922,6 @@ class Response:
 
         self.extensions = {} if extensions is None else extensions
         self.history = [] if history is None else list(history)
-        self._on_close = on_close
 
         self.is_closed = False
         self.is_stream_consumed = False
@@ -1245,11 +1243,13 @@ class Response:
         Close the response and release the connection.
         Automatically called if the response body is read to completion.
         """
+        if not isinstance(self.stream, SyncByteStream):
+            raise RuntimeError("Attempted to call an sync close on an async stream.")
+
         if not self.is_closed:
             self.is_closed = True
-            if self._on_close is not None:
-                with request_context(request=self._request):
-                    self._on_close(self)
+            with request_context(request=self._request):
+                self.stream.close()
 
     async def aread(self) -> bytes:
         """
@@ -1341,11 +1341,13 @@ class Response:
         Close the response and release the connection.
         Automatically called if the response body is read to completion.
         """
+        if not isinstance(self.stream, AsyncByteStream):
+            raise RuntimeError("Attempted to call an async close on an sync stream.")
+
         if not self.is_closed:
             self.is_closed = True
-            if self._on_close is not None:
-                with request_context(request=self._request):
-                    await self._on_close(self)
+            with request_context(request=self._request):
+                await self.stream.aclose()
 
 
 class Cookies(MutableMapping):
index 6c121b5e882f7e528e58417c9c6d5f663e75ac75..0f83eddd7ffd78daaccada9008d3de687ac8bb2c 100644 (file)
@@ -94,10 +94,21 @@ async def test_stream_request(server):
         yield b"world!"
 
     async with httpx.AsyncClient() as client:
-        response = await client.request("POST", server.url, content=hello_world())
+        response = await client.post(server.url, content=hello_world())
     assert response.status_code == 200
 
 
+@pytest.mark.usefixtures("async_environment")
+async def test_cannot_stream_sync_request(server):
+    def hello_world():  # pragma: nocover
+        yield b"Hello, "
+        yield b"world!"
+
+    async with httpx.AsyncClient() as client:
+        with pytest.raises(RuntimeError):
+            await client.post(server.url, content=hello_world())
+
+
 @pytest.mark.usefixtures("async_environment")
 async def test_raise_for_status(server):
     async with httpx.AsyncClient() as client:
index 0538e960b2a569591fa26f363c9744280082001e..c31a1ae6df0ad871dedaac48ac10b172ff1c9a48 100644 (file)
@@ -114,6 +114,16 @@ def test_raw_iterator(server):
     assert body == b"Hello, world!"
 
 
+def test_cannot_stream_async_request(server):
+    async def hello_world():  # pragma: nocover
+        yield b"Hello, "
+        yield b"world!"
+
+    with httpx.Client() as client:
+        with pytest.raises(RuntimeError):
+            client.post(server.url, content=hello_world())
+
+
 def test_raise_for_status(server):
     with httpx.Client() as client:
         for status_code in (200, 400, 404, 500, 505):
index 793fad3b76e7c2a1790144bd421054e1408fdea4..10a3f1aac3873a23e17b63453d002365ba15d928 100644 (file)
@@ -382,6 +382,16 @@ def test_iter_raw_on_async():
         [part for part in response.iter_raw()]
 
 
+def test_close_on_async():
+    response = httpx.Response(
+        200,
+        content=async_streaming_body(),
+    )
+
+    with pytest.raises(RuntimeError):
+        response.close()
+
+
 def test_iter_raw_increments_updates_counter():
     response = httpx.Response(200, content=streaming_body())
 
@@ -430,6 +440,17 @@ async def test_aiter_raw_on_sync():
         [part async for part in response.aiter_raw()]
 
 
+@pytest.mark.asyncio
+async def test_aclose_on_sync():
+    response = httpx.Response(
+        200,
+        content=streaming_body(),
+    )
+
+    with pytest.raises(RuntimeError):
+        await response.aclose()
+
+
 @pytest.mark.asyncio
 async def test_aiter_raw_increments_updates_counter():
     response = httpx.Response(200, content=async_streaming_body())