]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add middleware per `Router` (#2351)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Fri, 1 Dec 2023 12:58:30 +0000 (13:58 +0100)
committerGitHub <noreply@github.com>
Fri, 1 Dec 2023 12:58:30 +0000 (12:58 +0000)
docs/middleware.md
starlette/routing.py
tests/test_routing.py

index 92ac5886ac33917f623347232e790407f9446ce7..6fcd8ba2e329b3950b78403ae2d79cf691eea267 100644 (file)
@@ -29,8 +29,7 @@ middleware = [
 app = Starlette(routes=routes, middleware=middleware)
 ```
 
-Every Starlette application automatically includes two pieces of middleware
-by default:
+Every Starlette application automatically includes two pieces of middleware by default:
 
 * `ServerErrorMiddleware` - Ensures that application exceptions may return a custom 500 page, or display an application traceback in DEBUG mode. This is *always* the outermost middleware layer.
 * `ExceptionMiddleware` - Adds exception handlers, so that particular types of expected exception cases can be associated with handler functions. For example raising `HTTPException(status_code=404)` within an endpoint will end up rendering a custom 404 page.
@@ -106,7 +105,7 @@ The following arguments are supported:
 * `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session.
 * `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`.
 * `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`.
-* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute). 
+* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute).
 
 
 ## HTTPSRedirectMiddleware
@@ -573,7 +572,7 @@ import time
 class MonitoringMiddleware:
     def __init__(self, app):
         self.app = app
-    
+
     async def __call__(self, scope, receive, send):
         start = time.time()
         try:
@@ -690,9 +689,9 @@ to use the `middleware=<List of Middleware instances>` style, as it will:
 * Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`.
 * Preserves the top-level `app` instance.
 
-## Applying middleware to `Mount`s
+## Applying middleware to groups of routes
 
-Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application:
+Middleware can also be added to `Mount` instances, which allows you to apply middleware to a group of routes or a sub-application:
 
 ```python
 from starlette.applications import Starlette
@@ -725,6 +724,43 @@ If you do want to apply the middleware logic to error responses only on some rou
 * Add a `try/except` block to your middleware and return an error response from there
 * Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting.
 
+The `Route`/`WebSocket` class also accepts a `middleware` argument, which allows you to apply middleware to a single route:
+
+```python
+from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.gzip import GZipMiddleware
+from starlette.routing import Route
+
+
+routes = [
+    Route(
+        "/example",
+        endpoint=...,
+        middleware=[Middleware(GZipMiddleware)]
+    )
+]
+
+app = Starlette(routes=routes)
+```
+
+You can also apply middleware to the `Router` class, which allows you to apply middleware to a group of routes:
+
+```python
+from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.gzip import GZipMiddleware
+from starlette.routing import Route, Router
+
+
+routes = [
+    Route("/example", endpoint=...),
+    Route("/another", endpoint=...),
+]
+
+router = Router(routes=routes, middleware=[Middleware(GZipMiddleware)])
+```
+
 ## Third party middleware
 
 #### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)
index b167f3513b1d9b24f1c2d6431b6dfb3f5df5a4c3..9a21349579dfc76544d987ab3b0e97c7f68c4c41 100644 (file)
@@ -623,6 +623,8 @@ class Router:
         # the generic to Lifespan[AppType] is the type of the top level application
         # which the router cannot know statically, so we use typing.Any
         lifespan: typing.Optional[Lifespan[typing.Any]] = None,
+        *,
+        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
     ) -> None:
         self.routes = [] if routes is None else list(routes)
         self.redirect_slashes = redirect_slashes
@@ -668,6 +670,11 @@ class Router:
         else:
             self.lifespan_context = lifespan
 
+        self.middleware_stack = self.app
+        if middleware:
+            for cls, options in reversed(middleware):
+                self.middleware_stack = cls(self.middleware_stack, **options)
+
     async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] == "websocket":
             websocket_close = WebSocketClose()
@@ -744,6 +751,9 @@ class Router:
         """
         The main entry point to the Router class.
         """
+        await self.middleware_stack(scope, receive, send)
+
+    async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
         assert scope["type"] in ("http", "websocket", "lifespan")
 
         if "router" not in scope:
index 1e99e335bb6a1dec1f7e946e974ce76787cd8f1e..5c12ebabe7a8971b9ea46b67ca5b98a64e37c1e7 100644 (file)
@@ -328,6 +328,26 @@ def test_router_add_websocket_route(client):
         assert text == "Hello, test!"
 
 
+def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]):
+    class CustomMiddleware:
+        def __init__(self, app: ASGIApp) -> None:
+            self.app = app
+
+        async def __call__(self, scope: Scope, receive: Receive, send: Send):
+            response = PlainTextResponse("OK")
+            await response(scope, receive, send)
+
+    app = Router(
+        routes=[Route("/", homepage)],
+        middleware=[Middleware(CustomMiddleware)],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.status_code == 200
+    assert response.text == "OK"
+
+
 def http_endpoint(request):
     url = request.url_for("http_endpoint")
     return Response(f"URL: {url}", media_type="text/plain")