CLOSED = 3
+class BoundSyncStream(SyncByteStream):
+ """
+ A byte stream that is bound to a given response instance, and that
+ ensures the `response.elapsed` is set once the response is closed.
+ """
+
+ def __init__(
+ self, stream: SyncByteStream, response: Response, timer: Timer
+ ) -> None:
+ self._stream = stream
+ self._response = response
+ self._timer = timer
+
+ def __iter__(self) -> typing.Iterator[bytes]:
+ for chunk in self._stream:
+ yield chunk
+
+ def close(self) -> None:
+ seconds = self._timer.sync_elapsed()
+ self._response.elapsed = datetime.timedelta(seconds=seconds)
+ self._stream.close()
+
+
+class BoundAsyncStream(AsyncByteStream):
+ """
+ An async byte stream that is bound to a given response instance, and that
+ ensures the `response.elapsed` is set once the response is closed.
+ """
+
+ def __init__(
+ self, stream: AsyncByteStream, response: Response, timer: Timer
+ ) -> None:
+ self._stream = stream
+ self._response = response
+ self._timer = timer
+
+ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
+ async for chunk in self._stream:
+ yield chunk
+
+ async def aclose(self) -> None:
+ seconds = await self._timer.async_elapsed()
+ self._response.elapsed = datetime.timedelta(seconds=seconds)
+ await self._stream.aclose()
+
+
class BaseClient:
def __init__(
self,
timer = Timer()
timer.sync_start()
+ if not isinstance(request.stream, SyncByteStream):
+ raise RuntimeError(
+ "Attempted to send an async request with a sync Client instance."
+ )
+
with request_context(request=request):
(status_code, headers, stream, extensions) = transport.handle_request(
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
- stream=request.stream, # type: ignore
+ stream=request.stream,
extensions={"timeout": timeout.as_dict()},
)
- def on_close(response: Response) -> None:
- response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
- stream.close()
-
response = Response(
status_code,
headers=headers,
stream=stream,
extensions=extensions,
request=request,
- on_close=on_close,
)
+ response.stream = BoundSyncStream(stream, response=response, timer=timer)
self.cookies.extract_cookies(response)
status = f"{response.status_code} {response.reason_phrase}"
timer = Timer()
await timer.async_start()
+ if not isinstance(request.stream, AsyncByteStream):
+ raise RuntimeError(
+ "Attempted to send an sync request with an AsyncClient instance."
+ )
+
with request_context(request=request):
(
status_code,
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
- stream=request.stream, # type: ignore
+ stream=request.stream,
extensions={"timeout": timeout.as_dict()},
)
- async def on_close(response: Response) -> None:
- response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
- await stream.aclose()
-
response = Response(
status_code,
headers=headers,
stream=stream,
extensions=extensions,
request=request,
- on_close=on_close,
)
+ response.stream = BoundAsyncStream(stream, response=response, timer=timer)
self.cookies.extract_cookies(response)
status = f"{response.status_code} {response.reason_phrase}"
request: Request = None,
extensions: dict = None,
history: typing.List["Response"] = None,
- on_close: typing.Callable = None,
):
self.status_code = status_code
self.headers = Headers(headers)
self.extensions = {} if extensions is None else extensions
self.history = [] if history is None else list(history)
- self._on_close = on_close
self.is_closed = False
self.is_stream_consumed = False
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
+ if not isinstance(self.stream, SyncByteStream):
+ raise RuntimeError("Attempted to call an sync close on an async stream.")
+
if not self.is_closed:
self.is_closed = True
- if self._on_close is not None:
- with request_context(request=self._request):
- self._on_close(self)
+ with request_context(request=self._request):
+ self.stream.close()
async def aread(self) -> bytes:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
+ if not isinstance(self.stream, AsyncByteStream):
+ raise RuntimeError("Attempted to call an async close on an sync stream.")
+
if not self.is_closed:
self.is_closed = True
- if self._on_close is not None:
- with request_context(request=self._request):
- await self._on_close(self)
+ with request_context(request=self._request):
+ await self.stream.aclose()
class Cookies(MutableMapping):
yield b"world!"
async with httpx.AsyncClient() as client:
- response = await client.request("POST", server.url, content=hello_world())
+ response = await client.post(server.url, content=hello_world())
assert response.status_code == 200
+@pytest.mark.usefixtures("async_environment")
+async def test_cannot_stream_sync_request(server):
+ def hello_world(): # pragma: nocover
+ yield b"Hello, "
+ yield b"world!"
+
+ async with httpx.AsyncClient() as client:
+ with pytest.raises(RuntimeError):
+ await client.post(server.url, content=hello_world())
+
+
@pytest.mark.usefixtures("async_environment")
async def test_raise_for_status(server):
async with httpx.AsyncClient() as client:
assert body == b"Hello, world!"
+def test_cannot_stream_async_request(server):
+ async def hello_world(): # pragma: nocover
+ yield b"Hello, "
+ yield b"world!"
+
+ with httpx.Client() as client:
+ with pytest.raises(RuntimeError):
+ client.post(server.url, content=hello_world())
+
+
def test_raise_for_status(server):
with httpx.Client() as client:
for status_code in (200, 400, 404, 500, 505):
[part for part in response.iter_raw()]
+def test_close_on_async():
+ response = httpx.Response(
+ 200,
+ content=async_streaming_body(),
+ )
+
+ with pytest.raises(RuntimeError):
+ response.close()
+
+
def test_iter_raw_increments_updates_counter():
response = httpx.Response(200, content=streaming_body())
[part async for part in response.aiter_raw()]
+@pytest.mark.asyncio
+async def test_aclose_on_sync():
+ response = httpx.Response(
+ 200,
+ content=streaming_body(),
+ )
+
+ with pytest.raises(RuntimeError):
+ await response.aclose()
+
+
@pytest.mark.asyncio
async def test_aiter_raw_increments_updates_counter():
response = httpx.Response(200, content=async_streaming_body())