]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add middleware per `Route`/`WebSocketRoute` (#2349)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Fri, 1 Dec 2023 10:35:14 +0000 (11:35 +0100)
committerGitHub <noreply@github.com>
Fri, 1 Dec 2023 10:35:14 +0000 (03:35 -0700)
starlette/routing.py
starlette/testclient.py
tests/test_routing.py

index 38bb3517e9b7e6f68ed49f0c8f71c9cb466c793e..cb1b1d9199085fb8e8a0c4c60b3eac2305721a3e 100644 (file)
@@ -217,6 +217,7 @@ class Route(BaseRoute):
         methods: typing.Optional[typing.List[str]] = None,
         name: typing.Optional[str] = None,
         include_in_schema: bool = True,
+        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
     ) -> None:
         assert path.startswith("/"), "Routed paths must start with '/'"
         self.path = path
@@ -236,6 +237,10 @@ class Route(BaseRoute):
             # Endpoint is a class. Treat it as ASGI.
             self.app = endpoint
 
+        if middleware is not None:
+            for cls, options in reversed(middleware):
+                self.app = cls(app=self.app, **options)
+
         if methods is None:
             self.methods = None
         else:
@@ -309,6 +314,7 @@ class WebSocketRoute(BaseRoute):
         endpoint: typing.Callable[..., typing.Any],
         *,
         name: typing.Optional[str] = None,
+        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
     ) -> None:
         assert path.startswith("/"), "Routed paths must start with '/'"
         self.path = path
@@ -325,6 +331,10 @@ class WebSocketRoute(BaseRoute):
             # Endpoint is a class. Treat it as ASGI.
             self.app = endpoint
 
+        if middleware is not None:
+            for cls, options in reversed(middleware):
+                self.app = cls(app=self.app, **options)
+
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
 
     def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
index a0046f2ff6feaa12b9119a4870f7133230087e65..4a750cd8cd8de139e425ee9c13355e44a0c9fb54 100644 (file)
@@ -710,7 +710,7 @@ class TestClient(httpx.Client):
 
     def websocket_connect(
         self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
-    ) -> typing.Any:
+    ) -> "WebSocketTestSession":
         url = urljoin("ws://testserver", url)
         headers = kwargs.get("headers", {})
         headers.setdefault("connection", "upgrade")
index 7159a4bfcc74165faa8b30eac9d7d0da6476a657..7644f5a24766737213e1dcafae4413f73687aae5 100644 (file)
@@ -919,6 +919,18 @@ def assert_middleware_header_route(request: Request) -> Response:
     return Response()
 
 
+route_with_middleware = Starlette(
+    routes=[
+        Route(
+            "/http",
+            endpoint=assert_middleware_header_route,
+            methods=["GET"],
+            middleware=[Middleware(AddHeadersMiddleware)],
+        ),
+        Route("/home", homepage),
+    ]
+)
+
 mounted_routes_with_middleware = Starlette(
     routes=[
         Mount(
@@ -960,9 +972,10 @@ mounted_app_with_middleware = Starlette(
     [
         mounted_routes_with_middleware,
         mounted_app_with_middleware,
+        route_with_middleware,
     ],
 )
-def test_mount_middleware(
+def test_base_route_middleware(
     test_client_factory: typing.Callable[..., TestClient],
     app: Starlette,
 ) -> None:
@@ -1076,6 +1089,44 @@ def test_mounted_middleware_does_not_catch_exception(
     assert "X-Mounted" in resp.headers
 
 
+def test_websocket_route_middleware(
+    test_client_factory: typing.Callable[..., TestClient]
+):
+    async def websocket_endpoint(session: WebSocket):
+        await session.accept()
+        await session.send_text("Hello, world!")
+        await session.close()
+
+    class WebsocketMiddleware:
+        def __init__(self, app: ASGIApp) -> None:
+            self.app = app
+
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+            async def modified_send(msg: Message) -> None:
+                if msg["type"] == "websocket.accept":
+                    msg["headers"].append((b"X-Test", b"Set by middleware"))
+                await send(msg)
+
+            await self.app(scope, receive, modified_send)
+
+    app = Starlette(
+        routes=[
+            WebSocketRoute(
+                "/ws",
+                endpoint=websocket_endpoint,
+                middleware=[Middleware(WebsocketMiddleware)],
+            )
+        ]
+    )
+
+    client = test_client_factory(app)
+
+    with client.websocket_connect("/ws") as websocket:
+        text = websocket.receive_text()
+        assert text == "Hello, world!"
+        assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]
+
+
 def test_route_repr() -> None:
     route = Route("/welcome", endpoint=homepage)
     assert (