from anyio.abc import ObjectReceiveStream, ObjectSendStream
from starlette._utils import collapse_excgroups
-from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
-from starlette.responses import ContentStream, Response, StreamingResponse
+from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
+ self._wrapped_rcv_disconnected = True
return msg
# wrapped_rcv state 3: not yet consumed
raise NotImplementedError() # pragma: no cover
-class _StreamingResponse(StreamingResponse):
+class _StreamingResponse(Response):
def __init__(
self,
- content: ContentStream,
+ content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
- background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
- self._info = info
- super().__init__(content, status_code, headers, media_type, background)
+ self.info = info
+ self.body_iterator = content
+ self.status_code = status_code
+ self.media_type = media_type
+ self.init_headers(headers)
- async def stream_response(self, send: Send) -> None:
- if self._info:
- await send({"type": "http.response.debug", "info": self._info})
- return await super().stream_response(send)
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if self.info is not None:
+ await send({"type": "http.response.debug", "info": self.info})
+ await send(
+ {
+ "type": "http.response.start",
+ "status": self.status_code,
+ "headers": self.raw_headers,
+ }
+ )
+
+ async for chunk in self.body_iterator:
+ await send({"type": "http.response.body", "body": chunk, "more_body": True})
+
+ await send({"type": "http.response.body", "body": b"", "more_body": False})
from typing import (
Any,
AsyncGenerator,
+ AsyncIterator,
Generator,
)
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
-from starlette.requests import Request
+from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1438
- request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()
}
async def receive() -> Message:
- nonlocal request_body_sent
- if not request_body_sent:
- request_body_sent = True
- return {"type": "http.request", "body": b"", "more_body": False}
- # We simulate a client that disconnects immediately after receiving the response
- await response_complete.wait()
- return {"type": "http.disconnect"}
+ raise NotImplementedError("Should not be called!") # pragma: no cover
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
- request_body_sent = False
response_complete = anyio.Event()
events: list[str | Message] = []
}
async def receive() -> Message:
- nonlocal request_body_sent
- if not request_body_sent:
- request_body_sent = True
- return {"type": "http.request", "body": b"", "more_body": False}
- await response_complete.wait()
- return {"type": "http.disconnect"}
+ raise NotImplementedError("Should not be called!") # pragma: no cover
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
- request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()
}
async def receive() -> Message:
- nonlocal request_body_sent
- if not request_body_sent:
- request_body_sent = True
- return {"type": "http.request", "body": b"", "more_body": False}
- # We simulate a client that disconnects immediately after receiving the response
- await response_complete.wait()
- return {"type": "http.disconnect"}
+ raise NotImplementedError("Should not be called!") # pragma: no cover
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
yield {"type": "http.request", "body": b"1", "more_body": True}
yield {"type": "http.request", "body": b"2", "more_body": True}
yield {"type": "http.request", "body": b"3"}
- await anyio.sleep(float("inf"))
+ raise AssertionError( # pragma: no cover
+ "Should not be called, no need to poll for disconnect"
+ )
sent: list[Message] = []
resp.raise_for_status()
assert bodies == [b"Hello, World!-foo"]
+
+
+@pytest.mark.anyio
+async def test_multiple_middlewares_stacked_client_disconnected() -> None:
+ class MyMiddleware(BaseHTTPMiddleware):
+ def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
+ self.version = version
+ self.events = events
+ super().__init__(app)
+
+ async def dispatch(
+ self, request: Request, call_next: RequestResponseEndpoint
+ ) -> Response:
+ self.events.append(f"{self.version}:STARTED")
+ res = await call_next(request)
+ self.events.append(f"{self.version}:COMPLETED")
+ return res
+
+ async def sleepy(request: Request) -> Response:
+ try:
+ await request.body()
+ except ClientDisconnect:
+ pass
+ else: # pragma: no cover
+ raise AssertionError("Should have raised ClientDisconnect")
+ return Response(b"")
+
+ events: list[str] = []
+
+ app = Starlette(
+ routes=[Route("/", sleepy)],
+ middleware=[
+ Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
+ ],
+ )
+
+ scope = {
+ "type": "http",
+ "version": "3",
+ "method": "GET",
+ "path": "/",
+ }
+
+ async def receive() -> AsyncIterator[Message]:
+ yield {"type": "http.disconnect"}
+
+ sent: list[Message] = []
+
+ async def send(message: Message) -> None:
+ sent.append(message)
+
+ await app(scope, receive().__anext__, send)
+
+ assert events == [
+ "1:STARTED",
+ "2:STARTED",
+ "3:STARTED",
+ "4:STARTED",
+ "5:STARTED",
+ "6:STARTED",
+ "7:STARTED",
+ "8:STARTED",
+ "9:STARTED",
+ "10:STARTED",
+ "10:COMPLETED",
+ "9:COMPLETED",
+ "8:COMPLETED",
+ "7:COMPLETED",
+ "6:COMPLETED",
+ "5:COMPLETED",
+ "4:COMPLETED",
+ "3:COMPLETED",
+ "2:COMPLETED",
+ "1:COMPLETED",
+ ]
+
+ assert sent == [
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [(b"content-length", b"0")],
+ },
+ {"type": "http.response.body", "body": b"", "more_body": False},
+ ]
+
+
+@pytest.mark.anyio
+@pytest.mark.parametrize("send_body", [True, False])
+async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
+ async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
+ for _ in range(2):
+ msg = await receive()
+ while msg["type"] == "http.request":
+ msg = await receive()
+ assert msg["type"] == "http.disconnect"
+ await Response(b"good!")(scope, receive, send)
+
+ class MyMiddleware(BaseHTTPMiddleware):
+ async def dispatch(
+ self, request: Request, call_next: RequestResponseEndpoint
+ ) -> Response:
+ return await call_next(request)
+
+ app = MyMiddleware(app_poll_disconnect)
+
+ scope = {
+ "type": "http",
+ "version": "3",
+ "method": "GET",
+ "path": "/",
+ }
+
+ async def receive() -> AsyncIterator[Message]:
+ # the key here is that we only ever send 1 htt.disconnect message
+ if send_body:
+ yield {"type": "http.request", "body": b"hello", "more_body": True}
+ yield {"type": "http.request", "body": b"", "more_body": False}
+ yield {"type": "http.disconnect"}
+ raise AssertionError("Should not be called, would hang") # pragma: no cover
+
+ sent: list[Message] = []
+
+ async def send(message: Message) -> None:
+ sent.append(message)
+
+ await app(scope, receive().__anext__, send)
+
+ assert sent == [
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [(b"content-length", b"5")],
+ },
+ {"type": "http.response.body", "body": b"good!", "more_body": True},
+ {"type": "http.response.body", "body": b"", "more_body": False},
+ ]