From a8fa47d9b7ba9c70410a436ef387b9afd29fd312 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Tue, 14 Jun 2022 22:29:22 +0200 Subject: [PATCH] Experiment a high-level HTTPMiddleware --- starlette/middleware/http.py | 84 ++++++++++++++ tests/middleware/test_http.py | 208 ++++++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 starlette/middleware/http.py create mode 100644 tests/middleware/test_http.py diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py new file mode 100644 index 00000000..61571d54 --- /dev/null +++ b/starlette/middleware/http.py @@ -0,0 +1,84 @@ +from functools import partial +from typing import Generator, Optional +from ..types import ASGIApp, Scope, Receive, Send, Message +from ..responses import Response +from ..datastructures import MutableHeaders + + +class HTTPMiddleware: + DispatchFlow = Generator[Optional[Response], Response, None] + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def dispatch(self, scope: Scope) -> DispatchFlow: + raise NotImplementedError + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + flow = self.dispatch(scope) + + try: + # Run until first `yield` to allow modifying the connection scope. + # Middleware may return a response before we call into the underlying app. + early_response = flow.__next__() + except StopIteration: + raise RuntimeError("dispatch() did not run") + + if early_response is not None: + await early_response(scope, receive, send) + return + + response_started = set[bool]() + + wrapped_send = partial( + self._send, + flow=flow, + response_started=response_started, + send=send, + ) + + try: + await self.app(scope, receive, wrapped_send) + except Exception as exc: + try: + response = flow.throw(exc) + except Exception: + # Exception was not handled, or they raised another one. + raise + + if response is None: + raise RuntimeError( + f"dispatch() handled exception {exc!r}, " + "but no response application was returned" + ) + + await response(scope, receive, send) + + if not response_started: + raise RuntimeError("No response returned.") + + async def _send( + self, message: Message, *, flow: DispatchFlow, response_started: set, send: Send + ) -> None: + if message["type"] == "http.response.start": + response_started.add(True) + + response = Response(status_code=message["status"]) + response.raw_headers.clear() + + try: + flow.send(response) + except StopIteration as exc: + if exc.value is not None: + raise RuntimeError("swapping responses it not supported") + else: + raise RuntimeError("dispatch() should yield exactly once") + + headers = MutableHeaders(raw=message["headers"]) + headers.update(response.headers) + + await send(message) diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py new file mode 100644 index 00000000..af951891 --- /dev/null +++ b/tests/middleware/test_http.py @@ -0,0 +1,208 @@ +import contextvars + +import pytest + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.http import HTTPMiddleware +from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.routing import Mount, Route, WebSocketRoute +from starlette.types import ASGIApp, Receive, Scope, Send + + +class CustomMiddleware(HTTPMiddleware): + def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + response = yield None + response.headers["Custom-Header"] = "Example" + + +def homepage(request): + return PlainTextResponse("Homepage") + + +def exc(request): + raise Exception("Exc") + + +def exc_stream(request): + return StreamingResponse(_generate_faulty_stream()) + + +def _generate_faulty_stream(): + yield b"Ok" + raise Exception("Faulty Stream") + + +class NoResponse: + def __init__(self, scope, receive, send): + pass + + def __await__(self): + return self.dispatch().__await__() + + async def dispatch(self): + pass + + +async def websocket_endpoint(session): + await session.accept() + await session.send_text("Hello, world!") + await session.close() + + +app = Starlette( + routes=[ + Route("/", endpoint=homepage), + Route("/exc", endpoint=exc), + Route("/exc-stream", endpoint=exc_stream), + Route("/no-response", endpoint=NoResponse), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ], + middleware=[Middleware(CustomMiddleware)], +) + + +def test_custom_middleware(test_client_factory): + client = test_client_factory(app) + response = client.get("/") + assert response.headers["Custom-Header"] == "Example" + + with pytest.raises(Exception) as ctx: + response = client.get("/exc") + assert str(ctx.value) == "Exc" + + with pytest.raises(Exception) as ctx: + response = client.get("/exc-stream") + assert str(ctx.value) == "Faulty Stream" + + with pytest.raises(RuntimeError): + response = client.get("/no-response") + + with client.websocket_connect("/ws") as session: + text = session.receive_text() + assert text == "Hello, world!" + + +def test_state_data_across_multiple_middlewares(test_client_factory): + expected_value1 = "foo" + expected_value2 = "bar" + + class aMiddleware(HTTPMiddleware): + def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + scope["state_foo"] = expected_value1 + yield None + + class bMiddleware(HTTPMiddleware): + def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + scope["state_bar"] = expected_value2 + response = yield None + response.headers["X-State-Foo"] = scope["state_foo"] + + class cMiddleware(HTTPMiddleware): + def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + response = yield None + response.headers["X-State-Bar"] = scope["state_bar"] + + def homepage(request): + return PlainTextResponse("OK") + + app = Starlette( + routes=[Route("/", homepage)], + middleware=[ + Middleware(aMiddleware), + Middleware(bMiddleware), + Middleware(cMiddleware), + ], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "OK" + assert response.headers["X-State-Foo"] == expected_value1 + assert response.headers["X-State-Bar"] == expected_value2 + + +def test_app_middleware_argument(test_client_factory): + def homepage(request): + return PlainTextResponse("Homepage") + + app = Starlette( + routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.headers["Custom-Header"] == "Example" + + +def test_middleware_repr(): + middleware = Middleware(CustomMiddleware) + assert repr(middleware) == "Middleware(CustomMiddleware)" + + +def test_fully_evaluated_response(test_client_factory): + # Test for https://github.com/encode/starlette/issues/1022 + class CustomMiddleware(HTTPMiddleware): + def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + yield PlainTextResponse("Custom") + + app = Starlette(middleware=[Middleware(CustomMiddleware)]) + + client = test_client_factory(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" + + +def test_exception_on_mounted_apps(test_client_factory): + sub_app = Starlette(routes=[Route("/", exc)]) + app = Starlette(routes=[Mount("/sub", app=sub_app)]) + + client = test_client_factory(app) + with pytest.raises(Exception) as ctx: + client.get("/sub/") + assert str(ctx.value) == "Exc" + + +ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + + +class CustomMiddlewareWithoutBaseHTTPMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ctxvar.set("set by middleware") + await self.app(scope, receive, send) + assert ctxvar.get() == "set by endpoint" + + +class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware): + def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow: + ctxvar.set("set by middleware") + yield None + assert ctxvar.get() == "set by endpoint" + + +@pytest.mark.parametrize( + "middleware_cls", + [ + CustomMiddlewareWithoutBaseHTTPMiddleware, + CustomMiddlewareUsingHTTPMiddleware, + ], +) +def test_contextvars(test_client_factory, middleware_cls: type): + # this has to be an async endpoint because Starlette calls run_in_threadpool + # on sync endpoints which has it's own set of peculiarities w.r.t propagating + # contextvars (it propagates them forwards but not backwards) + async def homepage(request): + assert ctxvar.get() == "set by middleware" + ctxvar.set("set by endpoint") + return PlainTextResponse("Homepage") + + app = Starlette( + middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 200, response.content -- 2.47.3