]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Replace task cancellation in `BaseHTTPMiddleware` with `http.disconnect`+`recv_stream...
authorJean Hominal <jhominal@gmail.com>
Sat, 24 Sep 2022 05:29:08 +0000 (07:29 +0200)
committerGitHub <noreply@github.com>
Sat, 24 Sep 2022 05:29:08 +0000 (07:29 +0200)
* replace BaseMiddleware cancellation after request send with closing recv_stream + http.disconnect in receive

fixes #1438

* Add no cover pragma on pytest.fail in tests/middleware/test_base.py

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
* make http_disconnect_while_sending test more robust in the face of scheduling issues

* Fix issue with running middleware context manager

Reported in https://github.com/encode/starlette/issues/1678#issuecomment-1172916042

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/middleware/base.py
tests/middleware/test_base.py

index 49a5e3e2d7a0a96cd0f7e83f6b9f527d074d4280..586c9870dc1be3bd026091fa731a0e2f8460e83c 100644 (file)
@@ -4,12 +4,13 @@ import anyio
 
 from starlette.requests import Request
 from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
 DispatchFunction = typing.Callable[
     [Request, RequestResponseEndpoint], typing.Awaitable[Response]
 ]
+T = typing.TypeVar("T")
 
 
 class BaseHTTPMiddleware:
@@ -24,19 +25,52 @@ class BaseHTTPMiddleware:
             await self.app(scope, receive, send)
             return
 
+        response_sent = anyio.Event()
+
         async def call_next(request: Request) -> Response:
             app_exc: typing.Optional[Exception] = None
             send_stream, recv_stream = anyio.create_memory_object_stream()
 
+            async def receive_or_disconnect() -> Message:
+                if response_sent.is_set():
+                    return {"type": "http.disconnect"}
+
+                async with anyio.create_task_group() as task_group:
+
+                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
+                        result = await func()
+                        task_group.cancel_scope.cancel()
+                        return result
+
+                    task_group.start_soon(wrap, response_sent.wait)
+                    message = await wrap(request.receive)
+
+                if response_sent.is_set():
+                    return {"type": "http.disconnect"}
+
+                return message
+
+            async def close_recv_stream_on_response_sent() -> None:
+                await response_sent.wait()
+                recv_stream.close()
+
+            async def send_no_error(message: Message) -> None:
+                try:
+                    await send_stream.send(message)
+                except anyio.BrokenResourceError:
+                    # recv_stream has been closed, i.e. response_sent has been set.
+                    return
+
             async def coro() -> None:
                 nonlocal app_exc
 
                 async with send_stream:
                     try:
-                        await self.app(scope, request.receive, send_stream.send)
+                        await self.app(scope, receive_or_disconnect, send_no_error)
                     except Exception as exc:
                         app_exc = exc
 
+            task_group.start_soon(close_recv_stream_on_response_sent)
             task_group.start_soon(coro)
 
             try:
@@ -71,7 +105,7 @@ class BaseHTTPMiddleware:
             request = Request(scope, receive=receive)
             response = await self.dispatch_func(request, call_next)
             await response(scope, receive, send)
-            task_group.cancel_scope.cancel()
+            response_sent.set()
 
     async def dispatch(
         self, request: Request, call_next: RequestResponseEndpoint
index 976d77b86074b7106c3d9f9c18eded9e12ecdc25..ed0734bd389063a1d3d6117d677c505ebc0a76ad 100644 (file)
@@ -1,8 +1,11 @@
 import contextvars
+from contextlib import AsyncExitStack
 
+import anyio
 import pytest
 
 from starlette.applications import Starlette
+from starlette.background import BackgroundTask
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import PlainTextResponse, StreamingResponse
@@ -206,3 +209,207 @@ def test_contextvars(test_client_factory, middleware_cls: type):
     client = test_client_factory(app)
     response = client.get("/")
     assert response.status_code == 200, response.content
+
+
+@pytest.mark.anyio
+async def test_run_background_tasks_even_if_client_disconnects():
+    # test for https://github.com/encode/starlette/issues/1438
+    request_body_sent = False
+    response_complete = anyio.Event()
+    background_task_run = anyio.Event()
+
+    async def sleep_and_set():
+        # small delay to give BaseHTTPMiddleware a chance to cancel us
+        # this is required to make the test fail prior to fixing the issue
+        # so do not be surprised if you remove it and the test still passes
+        await anyio.sleep(0.1)
+        background_task_run.set()
+
+    async def endpoint_with_background_task(_):
+        return PlainTextResponse(background=BackgroundTask(sleep_and_set))
+
+    async def passthrough(request, call_next):
+        return await call_next(request)
+
+    app = Starlette(
+        middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
+        routes=[Route("/", endpoint_with_background_task)],
+    )
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+    }
+
+    async def receive():
+        nonlocal request_body_sent
+        if not request_body_sent:
+            request_body_sent = True
+            return {"type": "http.request", "body": b"", "more_body": False}
+        # We simulate a client that disconnects immediately after receiving the response
+        await response_complete.wait()
+        return {"type": "http.disconnect"}
+
+    async def send(message):
+        if message["type"] == "http.response.body":
+            if not message.get("more_body", False):
+                response_complete.set()
+
+    await app(scope, receive, send)
+
+    assert background_task_run.is_set()
+
+
+@pytest.mark.anyio
+async def test_run_context_manager_exit_even_if_client_disconnects():
+    # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
+    request_body_sent = False
+    response_complete = anyio.Event()
+    context_manager_exited = anyio.Event()
+
+    async def sleep_and_set():
+        # small delay to give BaseHTTPMiddleware a chance to cancel us
+        # this is required to make the test fail prior to fixing the issue
+        # so do not be surprised if you remove it and the test still passes
+        await anyio.sleep(0.1)
+        context_manager_exited.set()
+
+    class ContextManagerMiddleware:
+        def __init__(self, app):
+            self.app = app
+
+        async def __call__(self, scope: Scope, receive: Receive, send: Send):
+            async with AsyncExitStack() as stack:
+                stack.push_async_callback(sleep_and_set)
+                await self.app(scope, receive, send)
+
+    async def simple_endpoint(_):
+        return PlainTextResponse(background=BackgroundTask(sleep_and_set))
+
+    async def passthrough(request, call_next):
+        return await call_next(request)
+
+    app = Starlette(
+        middleware=[
+            Middleware(BaseHTTPMiddleware, dispatch=passthrough),
+            Middleware(ContextManagerMiddleware),
+        ],
+        routes=[Route("/", simple_endpoint)],
+    )
+
+    scope = {
+        "type": "http",
+        "version": "3",
+        "method": "GET",
+        "path": "/",
+    }
+
+    async def receive():
+        nonlocal request_body_sent
+        if not request_body_sent:
+            request_body_sent = True
+            return {"type": "http.request", "body": b"", "more_body": False}
+        # We simulate a client that disconnects immediately after receiving the response
+        await response_complete.wait()
+        return {"type": "http.disconnect"}
+
+    async def send(message):
+        if message["type"] == "http.response.body":
+            if not message.get("more_body", False):
+                response_complete.set()
+
+    await app(scope, receive, send)
+
+    assert context_manager_exited.is_set()
+
+
+def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory):
+    class DiscardingMiddleware(BaseHTTPMiddleware):
+        async def dispatch(self, request, call_next):
+            await call_next(request)
+            return PlainTextResponse("Custom")
+
+    async def downstream_app(scope, receive, send):
+        await send(
+            {
+                "type": "http.response.start",
+                "status": 200,
+                "headers": [
+                    (b"content-type", b"text/plain"),
+                ],
+            }
+        )
+        async with anyio.create_task_group() as task_group:
+
+            async def cancel_on_disconnect():
+                while True:
+                    message = await receive()
+                    if message["type"] == "http.disconnect":
+                        task_group.cancel_scope.cancel()
+                        break
+
+            task_group.start_soon(cancel_on_disconnect)
+
+            # A timeout is set for 0.1 second in order to ensure that
+            # cancel_on_disconnect is scheduled by the event loop
+            with anyio.move_on_after(0.1):
+                while True:
+                    await send(
+                        {
+                            "type": "http.response.body",
+                            "body": b"chunk ",
+                            "more_body": True,
+                        }
+                    )
+
+            pytest.fail(
+                "http.disconnect should have been received and canceled the scope"
+            )  # pragma: no cover
+
+    app = DiscardingMiddleware(downstream_app)
+
+    client = test_client_factory(app)
+    response = client.get("/does_not_exist")
+    assert response.text == "Custom"
+
+
+def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory):
+    class DiscardingMiddleware(BaseHTTPMiddleware):
+        async def dispatch(self, request, call_next):
+            await call_next(request)
+            return PlainTextResponse("Custom")
+
+    async def downstream_app(scope, receive, send):
+        await send(
+            {
+                "type": "http.response.start",
+                "status": 200,
+                "headers": [
+                    (b"content-type", b"text/plain"),
+                ],
+            }
+        )
+        await send(
+            {
+                "type": "http.response.body",
+                "body": b"first chunk, ",
+                "more_body": True,
+            }
+        )
+        await send(
+            {
+                "type": "http.response.body",
+                "body": b"second chunk",
+                "more_body": True,
+            }
+        )
+        message = await receive()
+        assert message["type"] == "http.disconnect"
+
+    app = DiscardingMiddleware(downstream_app)
+
+    client = test_client_factory(app)
+    response = client.get("/does_not_exist")
+    assert response.text == "Custom"