]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Switch to dispatch(conn) instead of dispatch(scope), coverage
authorflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 22:24:41 +0000 (00:24 +0200)
committerflorimondmanca <florimond.manca@protonmail.com>
Tue, 14 Jun 2022 22:24:41 +0000 (00:24 +0200)
starlette/middleware/http.py
tests/middleware/test_http.py

index ce4a7373bcce629ac56c66932e0c1a7c90aa9654..a2365eca70960875e6250e66573182f2bd451854 100644 (file)
@@ -2,10 +2,11 @@ from typing import AsyncGenerator, Callable, Optional, Union
 
 from .._compat import aclosing
 from ..datastructures import MutableHeaders
+from ..requests import HTTPConnection
 from ..responses import Response
 from ..types import ASGIApp, Message, Receive, Scope, Send
 
-_HTTPDispatchFlow = Union[
+_DispatchFlow = Union[
     AsyncGenerator[None, Response],
     AsyncGenerator[Response, Response],
     AsyncGenerator[Optional[Response], Response],
@@ -16,12 +17,15 @@ class HTTPMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        dispatch: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
+        dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None,
     ) -> None:
+        if dispatch is None:
+            dispatch = self.dispatch
+
         self.app = app
-        self.dispatch_func = self.dispatch if dispatch is None else dispatch
+        self._dispatch_func = dispatch
 
-    def dispatch(self, scope: Scope) -> _HTTPDispatchFlow:
+    def dispatch(self, conn: HTTPConnection) -> _DispatchFlow:
         raise NotImplementedError  # pragma: no cover
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -29,7 +33,9 @@ class HTTPMiddleware:
             await self.app(scope, receive, send)
             return
 
-        async with aclosing(self.dispatch(scope)) as flow:
+        conn = HTTPConnection(scope)
+
+        async with aclosing(self._dispatch_func(conn)) as flow:
             # Kick the flow until the first `yield`.
             # Might respond early before we call into the app.
             maybe_early_response = await flow.__anext__()
index a36235c5c1b504c350d66a5140087dbd2ca39943..b65d5a6d51db58d7b8bd80972fe2c02090210205 100644 (file)
@@ -6,13 +6,14 @@ import pytest
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.http import HTTPMiddleware
+from starlette.requests import HTTPConnection
 from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
 from starlette.types import ASGIApp, Receive, Scope, Send
 
 
 class CustomMiddleware(HTTPMiddleware):
-    async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+    async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
         response = yield
         response.headers["Custom-Header"] = "Example"
 
@@ -89,20 +90,26 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
     expected_value2 = "bar"
 
     class aMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
-            scope["state_foo"] = expected_value1
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[None, Response]:
+            conn.state.foo = expected_value1
             yield
 
     class bMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
-            scope["state_bar"] = expected_value2
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[None, Response]:
+            conn.state.bar = expected_value2
             response = yield
-            response.headers["X-State-Foo"] = scope["state_foo"]
+            response.headers["X-State-Foo"] = conn.state.foo
 
     class cMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[None, Response]:
             response = yield
-            response.headers["X-State-Bar"] = scope["state_bar"]
+            response.headers["X-State-Bar"] = conn.state.bar
 
     def homepage(request):
         return PlainTextResponse("OK")
@@ -123,12 +130,17 @@ def test_state_data_across_multiple_middlewares(test_client_factory):
     assert response.headers["X-State-Bar"] == expected_value2
 
 
-def test_app_middleware_argument(test_client_factory):
+def test_dispatch_argument(test_client_factory):
     def homepage(request):
         return PlainTextResponse("Homepage")
 
+    async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+        response = yield
+        response.headers["Custom-Header"] = "Example"
+
     app = Starlette(
-        routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
+        routes=[Route("/", homepage)],
+        middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
     )
 
     client = test_client_factory(app)
@@ -143,7 +155,9 @@ def test_middleware_repr():
 
 def test_early_response(test_client_factory):
     class CustomMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[Response, Response]:
             yield Response(status_code=401)
 
     app = Starlette(middleware=[Middleware(CustomMiddleware)])
@@ -155,7 +169,9 @@ def test_early_response(test_client_factory):
 
 def test_too_many_yields(test_client_factory) -> None:
     class BadMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[None, Response]:
             _ = yield
             yield
 
@@ -175,7 +191,7 @@ def test_error_response(test_client_factory) -> None:
 
     class ErrorMiddleware(HTTPMiddleware):
         async def dispatch(
-            self, scope: Scope
+            self, conn: HTTPConnection
         ) -> AsyncGenerator[Optional[Response], Response]:
             try:
                 yield None
@@ -201,7 +217,9 @@ def test_no_error_response(test_client_factory) -> None:
         raise Failed()
 
     class BadMiddleware(HTTPMiddleware):
-        async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+        async def dispatch(
+            self, conn: HTTPConnection
+        ) -> AsyncGenerator[None, Response]:
             try:
                 yield
             except Failed:
@@ -228,7 +246,7 @@ class PureASGICustomMiddleware:
 
 
 class HTTPCustomMiddleware(HTTPMiddleware):
-    async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+    async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
         ctxvar.set("set by middleware")
         yield
         assert ctxvar.get() == "set by endpoint"