]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add Mount(..., middleware=[...]) (#1649)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Wed, 21 Sep 2022 19:02:41 +0000 (14:02 -0500)
committerGitHub <noreply@github.com>
Wed, 21 Sep 2022 19:02:41 +0000 (14:02 -0500)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Florimond Manca <florimond.manca@protonmail.com>
Co-authored-by: Aber <me@abersheeran.com>
docs/middleware.md
starlette/routing.py
tests/test_routing.py

index 5d6a32f3a8c3c1a4ec57e77363131ca90a570e23..e42ada493364989c92cf29a21e96ac40e32859cd 100644 (file)
@@ -686,6 +686,41 @@ to use the `middleware=<List of Middleware instances>` style, as it will:
 * Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`.
 * Preserves the top-level `app` instance.
 
+## Applying middleware to `Mount`s
+
+Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application:
+
+```python
+from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.gzip import GZipMiddleware
+from starlette.routing import Mount, Route
+
+
+routes = [
+    Mount(
+        "/",
+        routes=[
+            Route(
+                "/example",
+                endpoint=...,
+            )
+        ],
+        middleware=[Middleware(GZipMiddleware)]
+    )
+]
+
+app = Starlette(routes=routes)
+```
+
+Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is.
+This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses.
+If you do want to apply the middleware logic to error responses only on some routes you have a couple of options:
+
+* Add an `ExceptionMiddleware` onto the `Mount`
+* Add a `try/except` block to your middleware and return an error response from there
+* Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting.
+
 ## Third party middleware
 
 #### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)
index 1aa2cdb6de626656820cb2952f686c427f1aacbd..2c6965be0f30876c72050409e50395b3e9fd321c 100644 (file)
@@ -14,6 +14,7 @@ from starlette.concurrency import run_in_threadpool
 from starlette.convertors import CONVERTOR_TYPES, Convertor
 from starlette.datastructures import URL, Headers, URLPath
 from starlette.exceptions import HTTPException
+from starlette.middleware import Middleware
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse, RedirectResponse
 from starlette.types import ASGIApp, Receive, Scope, Send
@@ -348,6 +349,8 @@ class Mount(BaseRoute):
         app: typing.Optional[ASGIApp] = None,
         routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
         name: typing.Optional[str] = None,
+        *,
+        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
     ) -> None:
         assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
         assert (
@@ -355,9 +358,13 @@ class Mount(BaseRoute):
         ), "Either 'app=...', or 'routes=' must be specified"
         self.path = path.rstrip("/")
         if app is not None:
-            self.app: ASGIApp = app
+            self._base_app: ASGIApp = app
         else:
-            self.app = Router(routes=routes)
+            self._base_app = Router(routes=routes)
+        self.app = self._base_app
+        if middleware is not None:
+            for cls, options in reversed(middleware):
+                self.app = cls(app=self.app, **options)
         self.name = name
         self.path_regex, self.path_format, self.param_convertors = compile_path(
             self.path + "/{path:path}"
@@ -365,7 +372,7 @@ class Mount(BaseRoute):
 
     @property
     def routes(self) -> typing.List[BaseRoute]:
-        return getattr(self.app, "routes", [])
+        return getattr(self._base_app, "routes", [])
 
     def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
         if scope["type"] in ("http", "websocket"):
index e3b1e412a5304a7199a7a9fdda930874ddd91f4d..750f324960cef7a812f4879a7227084a8f92570c 100644 (file)
@@ -5,8 +5,13 @@ import uuid
 import pytest
 
 from starlette.applications import Starlette
+from starlette.exceptions import HTTPException
+from starlette.middleware import Middleware
+from starlette.requests import Request
 from starlette.responses import JSONResponse, PlainTextResponse, Response
 from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
+from starlette.testclient import TestClient
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 from starlette.websockets import WebSocket, WebSocketDisconnect
 
 
@@ -768,6 +773,115 @@ def test_route_name(endpoint: typing.Callable, expected_name: str):
     assert Route(path="/", endpoint=endpoint).name == expected_name
 
 
+class AddHeadersMiddleware:
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        scope["add_headers_middleware"] = True
+
+        async def modified_send(msg: Message) -> None:
+            if msg["type"] == "http.response.start":
+                msg["headers"].append((b"X-Test", b"Set by middleware"))
+            await send(msg)
+
+        await self.app(scope, receive, modified_send)
+
+
+def assert_middleware_header_route(request: Request) -> Response:
+    assert request.scope["add_headers_middleware"] is True
+    return Response()
+
+
+mounted_routes_with_middleware = Starlette(
+    routes=[
+        Mount(
+            "/http",
+            routes=[
+                Route(
+                    "/",
+                    endpoint=assert_middleware_header_route,
+                    methods=["GET"],
+                    name="route",
+                ),
+            ],
+            middleware=[Middleware(AddHeadersMiddleware)],
+        ),
+        Route("/home", homepage),
+    ]
+)
+
+
+mounted_app_with_middleware = Starlette(
+    routes=[
+        Mount(
+            "/http",
+            app=Route(
+                "/",
+                endpoint=assert_middleware_header_route,
+                methods=["GET"],
+                name="route",
+            ),
+            middleware=[Middleware(AddHeadersMiddleware)],
+        ),
+        Route("/home", homepage),
+    ]
+)
+
+
+@pytest.mark.parametrize(
+    "app",
+    [
+        mounted_routes_with_middleware,
+        mounted_app_with_middleware,
+    ],
+)
+def test_mount_middleware(
+    test_client_factory: typing.Callable[..., TestClient],
+    app: Starlette,
+) -> None:
+    test_client = test_client_factory(app)
+
+    response = test_client.get("/home")
+    assert response.status_code == 200
+    assert "X-Test" not in response.headers
+
+    response = test_client.get("/http")
+    assert response.status_code == 200
+    assert response.headers["X-Test"] == "Set by middleware"
+
+
+def test_mount_routes_with_middleware_url_path_for() -> None:
+    """Checks that url_path_for still works with mounted routes with Middleware"""
+    assert mounted_routes_with_middleware.url_path_for("route") == "/http/"
+
+
+def test_mount_asgi_app_with_middleware_url_path_for() -> None:
+    """Mounted ASGI apps do not work with url path for,
+    middleware does not change this
+    """
+    with pytest.raises(NoMatchFound):
+        mounted_app_with_middleware.url_path_for("route")
+
+
+def test_add_route_to_app_after_mount(
+    test_client_factory: typing.Callable[..., TestClient],
+) -> None:
+    """Checks that Mount will pick up routes
+    added to the underlying app after it is mounted
+    """
+    inner_app = Router()
+    app = Mount("/http", app=inner_app)
+    inner_app.add_route(
+        "/inner",
+        endpoint=homepage,
+        methods=["GET"],
+    )
+    client = test_client_factory(app)
+    response = client.get("/http/inner")
+    assert response.status_code == 200
+
+
 def test_exception_on_mounted_apps(test_client_factory):
     def exc(request):
         raise Exception("Exc")
@@ -779,3 +893,62 @@ def test_exception_on_mounted_apps(test_client_factory):
     with pytest.raises(Exception) as ctx:
         client.get("/sub/")
     assert str(ctx.value) == "Exc"
+
+
+def test_mounted_middleware_does_not_catch_exception(
+    test_client_factory: typing.Callable[..., TestClient],
+) -> None:
+    # https://github.com/encode/starlette/pull/1649#discussion_r960236107
+    def exc(request: Request) -> Response:
+        raise HTTPException(status_code=403, detail="auth")
+
+    class NamedMiddleware:
+        def __init__(self, app: ASGIApp, name: str) -> None:
+            self.app = app
+            self.name = name
+
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+            async def modified_send(msg: Message) -> None:
+                if msg["type"] == "http.response.start":
+                    msg["headers"].append((f"X-{self.name}".encode(), b"true"))
+                await send(msg)
+
+            await self.app(scope, receive, modified_send)
+
+    app = Starlette(
+        routes=[
+            Mount(
+                "/mount",
+                routes=[
+                    Route("/err", exc),
+                    Route("/home", homepage),
+                ],
+                middleware=[Middleware(NamedMiddleware, name="Mounted")],
+            ),
+            Route("/err", exc),
+            Route("/home", homepage),
+        ],
+        middleware=[Middleware(NamedMiddleware, name="Outer")],
+    )
+
+    client = test_client_factory(app)
+
+    resp = client.get("/home")
+    assert resp.status_code == 200, resp.content
+    assert "X-Outer" in resp.headers
+
+    resp = client.get("/err")
+    assert resp.status_code == 403, resp.content
+    assert "X-Outer" in resp.headers
+
+    resp = client.get("/mount/home")
+    assert resp.status_code == 200, resp.content
+    assert "X-Mounted" in resp.headers
+
+    # this is the "surprising" behavior bit
+    # the middleware on the mount never runs because there
+    # is nothing to catch the HTTPException
+    # since Mount middlweare is not wrapped by ExceptionMiddleware
+    resp = client.get("/mount/err")
+    assert resp.status_code == 403, resp.content
+    assert "X-Mounted" not in resp.headers