]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Experiment a high-level HTTPMiddleware
authorflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 20:29:22 +0000 (22:29 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 20:29:22 +0000 (22:29 +0200)
starlette/middleware/http.py [new file with mode: 0644]
tests/middleware/test_http.py [new file with mode: 0644]

diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py
new file mode 100644 (file)
index 0000000..61571d5
--- /dev/null
@@ -0,0 +1,84 @@
+from functools import partial
+from typing import Generator, Optional
+from ..types import ASGIApp, Scope, Receive, Send, Message
+from ..responses import Response
+from ..datastructures import MutableHeaders
+
+
+class HTTPMiddleware:
+    DispatchFlow = Generator[Optional[Response], Response, None]
+
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    def dispatch(self, scope: Scope) -> DispatchFlow:
+        raise NotImplementedError
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        flow = self.dispatch(scope)
+
+        try:
+            # Run until first `yield` to allow modifying the connection scope.
+            # Middleware may return a response before we call into the underlying app.
+            early_response = flow.__next__()
+        except StopIteration:
+            raise RuntimeError("dispatch() did not run")
+
+        if early_response is not None:
+            await early_response(scope, receive, send)
+            return
+
+        response_started = set[bool]()
+
+        wrapped_send = partial(
+            self._send,
+            flow=flow,
+            response_started=response_started,
+            send=send,
+        )
+
+        try:
+            await self.app(scope, receive, wrapped_send)
+        except Exception as exc:
+            try:
+                response = flow.throw(exc)
+            except Exception:
+                # Exception was not handled, or they raised another one.
+                raise
+
+            if response is None:
+                raise RuntimeError(
+                    f"dispatch() handled exception {exc!r}, "
+                    "but no response application was returned"
+                )
+
+            await response(scope, receive, send)
+
+        if not response_started:
+            raise RuntimeError("No response returned.")
+
+    async def _send(
+        self, message: Message, *, flow: DispatchFlow, response_started: set, send: Send
+    ) -> None:
+        if message["type"] == "http.response.start":
+            response_started.add(True)
+
+            response = Response(status_code=message["status"])
+            response.raw_headers.clear()
+
+            try:
+                flow.send(response)
+            except StopIteration as exc:
+                if exc.value is not None:
+                    raise RuntimeError("swapping responses it not supported")
+            else:
+                raise RuntimeError("dispatch() should yield exactly once")
+
+            headers = MutableHeaders(raw=message["headers"])
+            headers.update(response.headers)
+
+        await send(message)
diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py
new file mode 100644 (file)
index 0000000..af95189
--- /dev/null
@@ -0,0 +1,208 @@
+import contextvars
+
+import pytest
+
+from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.http import HTTPMiddleware
+from starlette.responses import PlainTextResponse, StreamingResponse
+from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class CustomMiddleware(HTTPMiddleware):
+    def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+        response = yield None
+        response.headers["Custom-Header"] = "Example"
+
+
+def homepage(request):
+    return PlainTextResponse("Homepage")
+
+
+def exc(request):
+    raise Exception("Exc")
+
+
+def exc_stream(request):
+    return StreamingResponse(_generate_faulty_stream())
+
+
+def _generate_faulty_stream():
+    yield b"Ok"
+    raise Exception("Faulty Stream")
+
+
+class NoResponse:
+    def __init__(self, scope, receive, send):
+        pass
+
+    def __await__(self):
+        return self.dispatch().__await__()
+
+    async def dispatch(self):
+        pass
+
+
+async def websocket_endpoint(session):
+    await session.accept()
+    await session.send_text("Hello, world!")
+    await session.close()
+
+
+app = Starlette(
+    routes=[
+        Route("/", endpoint=homepage),
+        Route("/exc", endpoint=exc),
+        Route("/exc-stream", endpoint=exc_stream),
+        Route("/no-response", endpoint=NoResponse),
+        WebSocketRoute("/ws", endpoint=websocket_endpoint),
+    ],
+    middleware=[Middleware(CustomMiddleware)],
+)
+
+
+def test_custom_middleware(test_client_factory):
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.headers["Custom-Header"] == "Example"
+
+    with pytest.raises(Exception) as ctx:
+        response = client.get("/exc")
+    assert str(ctx.value) == "Exc"
+
+    with pytest.raises(Exception) as ctx:
+        response = client.get("/exc-stream")
+    assert str(ctx.value) == "Faulty Stream"
+
+    with pytest.raises(RuntimeError):
+        response = client.get("/no-response")
+
+    with client.websocket_connect("/ws") as session:
+        text = session.receive_text()
+        assert text == "Hello, world!"
+
+
+def test_state_data_across_multiple_middlewares(test_client_factory):
+    expected_value1 = "foo"
+    expected_value2 = "bar"
+
+    class aMiddleware(HTTPMiddleware):
+        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+            scope["state_foo"] = expected_value1
+            yield None
+
+    class bMiddleware(HTTPMiddleware):
+        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+            scope["state_bar"] = expected_value2
+            response = yield None
+            response.headers["X-State-Foo"] = scope["state_foo"]
+
+    class cMiddleware(HTTPMiddleware):
+        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+            response = yield None
+            response.headers["X-State-Bar"] = scope["state_bar"]
+
+    def homepage(request):
+        return PlainTextResponse("OK")
+
+    app = Starlette(
+        routes=[Route("/", homepage)],
+        middleware=[
+            Middleware(aMiddleware),
+            Middleware(bMiddleware),
+            Middleware(cMiddleware),
+        ],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.text == "OK"
+    assert response.headers["X-State-Foo"] == expected_value1
+    assert response.headers["X-State-Bar"] == expected_value2
+
+
+def test_app_middleware_argument(test_client_factory):
+    def homepage(request):
+        return PlainTextResponse("Homepage")
+
+    app = Starlette(
+        routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.headers["Custom-Header"] == "Example"
+
+
+def test_middleware_repr():
+    middleware = Middleware(CustomMiddleware)
+    assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+
+def test_fully_evaluated_response(test_client_factory):
+    # Test for https://github.com/encode/starlette/issues/1022
+    class CustomMiddleware(HTTPMiddleware):
+        def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+            yield PlainTextResponse("Custom")
+
+    app = Starlette(middleware=[Middleware(CustomMiddleware)])
+
+    client = test_client_factory(app)
+    response = client.get("/does_not_exist")
+    assert response.text == "Custom"
+
+
+def test_exception_on_mounted_apps(test_client_factory):
+    sub_app = Starlette(routes=[Route("/", exc)])
+    app = Starlette(routes=[Mount("/sub", app=sub_app)])
+
+    client = test_client_factory(app)
+    with pytest.raises(Exception) as ctx:
+        client.get("/sub/")
+    assert str(ctx.value) == "Exc"
+
+
+ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
+
+
+class CustomMiddlewareWithoutBaseHTTPMiddleware:
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        ctxvar.set("set by middleware")
+        await self.app(scope, receive, send)
+        assert ctxvar.get() == "set by endpoint"
+
+
+class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
+    def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+        ctxvar.set("set by middleware")
+        yield None
+        assert ctxvar.get() == "set by endpoint"
+
+
+@pytest.mark.parametrize(
+    "middleware_cls",
+    [
+        CustomMiddlewareWithoutBaseHTTPMiddleware,
+        CustomMiddlewareUsingHTTPMiddleware,
+    ],
+)
+def test_contextvars(test_client_factory, middleware_cls: type):
+    # this has to be an async endpoint because Starlette calls run_in_threadpool
+    # on sync endpoints which has it's own set of peculiarities w.r.t propagating
+    # contextvars (it propagates them forwards but not backwards)
+    async def homepage(request):
+        assert ctxvar.get() == "set by middleware"
+        ctxvar.set("set by endpoint")
+        return PlainTextResponse("Homepage")
+
+    app = Starlette(
+        middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.status_code == 200, response.content