From: florimondmanca Date: Tue, 14 Jun 2022 22:24:41 +0000 (+0200) Subject: Switch to dispatch(conn) instead of dispatch(scope), coverage X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=4c48d3dcde6357c2315ce18ba27b8bf1a19a8fa7;p=thirdparty%2Fstarlette.git Switch to dispatch(conn) instead of dispatch(scope), coverage --- diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index ce4a7373..a2365eca 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -2,10 +2,11 @@ from typing import AsyncGenerator, Callable, Optional, Union from .._compat import aclosing from ..datastructures import MutableHeaders +from ..requests import HTTPConnection from ..responses import Response from ..types import ASGIApp, Message, Receive, Scope, Send -_HTTPDispatchFlow = Union[ +_DispatchFlow = Union[ AsyncGenerator[None, Response], AsyncGenerator[Response, Response], AsyncGenerator[Optional[Response], Response], @@ -16,12 +17,15 @@ class HTTPMiddleware: def __init__( self, app: ASGIApp, - dispatch: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None, + dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None, ) -> None: + if dispatch is None: + dispatch = self.dispatch + self.app = app - self.dispatch_func = self.dispatch if dispatch is None else dispatch + self._dispatch_func = dispatch - def dispatch(self, scope: Scope) -> _HTTPDispatchFlow: + def dispatch(self, conn: HTTPConnection) -> _DispatchFlow: raise NotImplementedError # pragma: no cover async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -29,7 +33,9 @@ class HTTPMiddleware: await self.app(scope, receive, send) return - async with aclosing(self.dispatch(scope)) as flow: + conn = HTTPConnection(scope) + + async with aclosing(self._dispatch_func(conn)) as flow: # Kick the flow until the first `yield`. # Might respond early before we call into the app. maybe_early_response = await flow.__anext__() diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index a36235c5..b65d5a6d 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -6,13 +6,14 @@ import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.http import HTTPMiddleware +from starlette.requests import HTTPConnection from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.types import ASGIApp, Receive, Scope, Send class CustomMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: response = yield response.headers["Custom-Header"] = "Example" @@ -89,20 +90,26 @@ def test_state_data_across_multiple_middlewares(test_client_factory): expected_value2 = "bar" class aMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: - scope["state_foo"] = expected_value1 + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: + conn.state.foo = expected_value1 yield class bMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: - scope["state_bar"] = expected_value2 + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: + conn.state.bar = expected_value2 response = yield - response.headers["X-State-Foo"] = scope["state_foo"] + response.headers["X-State-Foo"] = conn.state.foo class cMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: response = yield - response.headers["X-State-Bar"] = scope["state_bar"] + response.headers["X-State-Bar"] = conn.state.bar def homepage(request): return PlainTextResponse("OK") @@ -123,12 +130,17 @@ def test_state_data_across_multiple_middlewares(test_client_factory): assert response.headers["X-State-Bar"] == expected_value2 -def test_app_middleware_argument(test_client_factory): +def test_dispatch_argument(test_client_factory): def homepage(request): return PlainTextResponse("Homepage") + async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["Custom-Header"] = "Example" + app = Starlette( - routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] + routes=[Route("/", homepage)], + middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)], ) client = test_client_factory(app) @@ -143,7 +155,9 @@ def test_middleware_repr(): def test_early_response(test_client_factory): class CustomMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]: + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Response, Response]: yield Response(status_code=401) app = Starlette(middleware=[Middleware(CustomMiddleware)]) @@ -155,7 +169,9 @@ def test_early_response(test_client_factory): def test_too_many_yields(test_client_factory) -> None: class BadMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: _ = yield yield @@ -175,7 +191,7 @@ def test_error_response(test_client_factory) -> None: class ErrorMiddleware(HTTPMiddleware): async def dispatch( - self, scope: Scope + self, conn: HTTPConnection ) -> AsyncGenerator[Optional[Response], Response]: try: yield None @@ -201,7 +217,9 @@ def test_no_error_response(test_client_factory) -> None: raise Failed() class BadMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: try: yield except Failed: @@ -228,7 +246,7 @@ class PureASGICustomMiddleware: class HTTPCustomMiddleware(HTTPMiddleware): - async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]: + async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: ctxvar.set("set by middleware") yield assert ctxvar.get() == "set by endpoint"