from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
+T = typing.TypeVar("T")
class BaseHTTPMiddleware:
await self.app(scope, receive, send)
return
+ response_sent = anyio.Event()
+
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
+ async def receive_or_disconnect() -> Message:
+ if response_sent.is_set():
+ return {"type": "http.disconnect"}
+
+ async with anyio.create_task_group() as task_group:
+
+ async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
+ result = await func()
+ task_group.cancel_scope.cancel()
+ return result
+
+ task_group.start_soon(wrap, response_sent.wait)
+ message = await wrap(request.receive)
+
+ if response_sent.is_set():
+ return {"type": "http.disconnect"}
+
+ return message
+
+ async def close_recv_stream_on_response_sent() -> None:
+ await response_sent.wait()
+ recv_stream.close()
+
+ async def send_no_error(message: Message) -> None:
+ try:
+ await send_stream.send(message)
+ except anyio.BrokenResourceError:
+ # recv_stream has been closed, i.e. response_sent has been set.
+ return
+
async def coro() -> None:
nonlocal app_exc
async with send_stream:
try:
- await self.app(scope, request.receive, send_stream.send)
+ await self.app(scope, receive_or_disconnect, send_no_error)
except Exception as exc:
app_exc = exc
+ task_group.start_soon(close_recv_stream_on_response_sent)
task_group.start_soon(coro)
try:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
- task_group.cancel_scope.cancel()
+ response_sent.set()
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
import contextvars
+from contextlib import AsyncExitStack
+import anyio
import pytest
from starlette.applications import Starlette
+from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content
+
+
+@pytest.mark.anyio
+async def test_run_background_tasks_even_if_client_disconnects():
+ # test for https://github.com/encode/starlette/issues/1438
+ request_body_sent = False
+ response_complete = anyio.Event()
+ background_task_run = anyio.Event()
+
+ async def sleep_and_set():
+ # small delay to give BaseHTTPMiddleware a chance to cancel us
+ # this is required to make the test fail prior to fixing the issue
+ # so do not be surprised if you remove it and the test still passes
+ await anyio.sleep(0.1)
+ background_task_run.set()
+
+ async def endpoint_with_background_task(_):
+ return PlainTextResponse(background=BackgroundTask(sleep_and_set))
+
+ async def passthrough(request, call_next):
+ return await call_next(request)
+
+ app = Starlette(
+ middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
+ routes=[Route("/", endpoint_with_background_task)],
+ )
+
+ scope = {
+ "type": "http",
+ "version": "3",
+ "method": "GET",
+ "path": "/",
+ }
+
+ async def receive():
+ nonlocal request_body_sent
+ if not request_body_sent:
+ request_body_sent = True
+ return {"type": "http.request", "body": b"", "more_body": False}
+ # We simulate a client that disconnects immediately after receiving the response
+ await response_complete.wait()
+ return {"type": "http.disconnect"}
+
+ async def send(message):
+ if message["type"] == "http.response.body":
+ if not message.get("more_body", False):
+ response_complete.set()
+
+ await app(scope, receive, send)
+
+ assert background_task_run.is_set()
+
+
+@pytest.mark.anyio
+async def test_run_context_manager_exit_even_if_client_disconnects():
+ # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
+ request_body_sent = False
+ response_complete = anyio.Event()
+ context_manager_exited = anyio.Event()
+
+ async def sleep_and_set():
+ # small delay to give BaseHTTPMiddleware a chance to cancel us
+ # this is required to make the test fail prior to fixing the issue
+ # so do not be surprised if you remove it and the test still passes
+ await anyio.sleep(0.1)
+ context_manager_exited.set()
+
+ class ContextManagerMiddleware:
+ def __init__(self, app):
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send):
+ async with AsyncExitStack() as stack:
+ stack.push_async_callback(sleep_and_set)
+ await self.app(scope, receive, send)
+
+ async def simple_endpoint(_):
+ return PlainTextResponse(background=BackgroundTask(sleep_and_set))
+
+ async def passthrough(request, call_next):
+ return await call_next(request)
+
+ app = Starlette(
+ middleware=[
+ Middleware(BaseHTTPMiddleware, dispatch=passthrough),
+ Middleware(ContextManagerMiddleware),
+ ],
+ routes=[Route("/", simple_endpoint)],
+ )
+
+ scope = {
+ "type": "http",
+ "version": "3",
+ "method": "GET",
+ "path": "/",
+ }
+
+ async def receive():
+ nonlocal request_body_sent
+ if not request_body_sent:
+ request_body_sent = True
+ return {"type": "http.request", "body": b"", "more_body": False}
+ # We simulate a client that disconnects immediately after receiving the response
+ await response_complete.wait()
+ return {"type": "http.disconnect"}
+
+ async def send(message):
+ if message["type"] == "http.response.body":
+ if not message.get("more_body", False):
+ response_complete.set()
+
+ await app(scope, receive, send)
+
+ assert context_manager_exited.is_set()
+
+
+def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
+ class DiscardingMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ await call_next(request)
+ return PlainTextResponse("Custom")
+
+ async def downstream_app(scope, receive, send):
+ await send(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [
+ (b"content-type", b"text/plain"),
+ ],
+ }
+ )
+ async with anyio.create_task_group() as task_group:
+
+ async def cancel_on_disconnect():
+ while True:
+ message = await receive()
+ if message["type"] == "http.disconnect":
+ task_group.cancel_scope.cancel()
+ break
+
+ task_group.start_soon(cancel_on_disconnect)
+
+ # A timeout is set for 0.1 second in order to ensure that
+ # cancel_on_disconnect is scheduled by the event loop
+ with anyio.move_on_after(0.1):
+ while True:
+ await send(
+ {
+ "type": "http.response.body",
+ "body": b"chunk ",
+ "more_body": True,
+ }
+ )
+
+ pytest.fail(
+ "http.disconnect should have been received and canceled the scope"
+ ) # pragma: no cover
+
+ app = DiscardingMiddleware(downstream_app)
+
+ client = test_client_factory(app)
+ response = client.get("/does_not_exist")
+ assert response.text == "Custom"
+
+
+def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
+ class DiscardingMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ await call_next(request)
+ return PlainTextResponse("Custom")
+
+ async def downstream_app(scope, receive, send):
+ await send(
+ {
+ "type": "http.response.start",
+ "status": 200,
+ "headers": [
+ (b"content-type", b"text/plain"),
+ ],
+ }
+ )
+ await send(
+ {
+ "type": "http.response.body",
+ "body": b"first chunk, ",
+ "more_body": True,
+ }
+ )
+ await send(
+ {
+ "type": "http.response.body",
+ "body": b"second chunk",
+ "more_body": True,
+ }
+ )
+ message = await receive()
+ assert message["type"] == "http.disconnect"
+
+ app = DiscardingMiddleware(downstream_app)
+
+ client = test_client_factory(app)
+ response = client.get("/does_not_exist")
+ assert response.text == "Custom"