From: Marcelo Trylesinski Date: Fri, 1 Dec 2023 12:58:30 +0000 (+0100) Subject: Add middleware per `Router` (#2351) X-Git-Tag: 0.33.0~2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=811dacc06291c869298ff9f6e26b23a6309d08e2;p=thirdparty%2Fstarlette.git Add middleware per `Router` (#2351) --- diff --git a/docs/middleware.md b/docs/middleware.md index 92ac5886..6fcd8ba2 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -29,8 +29,7 @@ middleware = [ app = Starlette(routes=routes, middleware=middleware) ``` -Every Starlette application automatically includes two pieces of middleware -by default: +Every Starlette application automatically includes two pieces of middleware by default: * `ServerErrorMiddleware` - Ensures that application exceptions may return a custom 500 page, or display an application traceback in DEBUG mode. This is *always* the outermost middleware layer. * `ExceptionMiddleware` - Adds exception handlers, so that particular types of expected exception cases can be associated with handler functions. For example raising `HTTPException(status_code=404)` within an endpoint will end up rendering a custom 404 page. @@ -106,7 +105,7 @@ The following arguments are supported: * `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session. * `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`. * `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`. -* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute). +* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute). ## HTTPSRedirectMiddleware @@ -573,7 +572,7 @@ import time class MonitoringMiddleware: def __init__(self, app): self.app = app - + async def __call__(self, scope, receive, send): start = time.time() try: @@ -690,9 +689,9 @@ to use the `middleware=` style, as it will: * Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`. * Preserves the top-level `app` instance. -## Applying middleware to `Mount`s +## Applying middleware to groups of routes -Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application: +Middleware can also be added to `Mount` instances, which allows you to apply middleware to a group of routes or a sub-application: ```python from starlette.applications import Starlette @@ -725,6 +724,43 @@ If you do want to apply the middleware logic to error responses only on some rou * Add a `try/except` block to your middleware and return an error response from there * Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting. +The `Route`/`WebSocket` class also accepts a `middleware` argument, which allows you to apply middleware to a single route: + +```python +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.gzip import GZipMiddleware +from starlette.routing import Route + + +routes = [ + Route( + "/example", + endpoint=..., + middleware=[Middleware(GZipMiddleware)] + ) +] + +app = Starlette(routes=routes) +``` + +You can also apply middleware to the `Router` class, which allows you to apply middleware to a group of routes: + +```python +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.gzip import GZipMiddleware +from starlette.routing import Route, Router + + +routes = [ + Route("/example", endpoint=...), + Route("/another", endpoint=...), +] + +router = Router(routes=routes, middleware=[Middleware(GZipMiddleware)]) +``` + ## Third party middleware #### [asgi-auth-github](https://github.com/simonw/asgi-auth-github) diff --git a/starlette/routing.py b/starlette/routing.py index b167f351..9a213495 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -623,6 +623,8 @@ class Router: # the generic to Lifespan[AppType] is the type of the top level application # which the router cannot know statically, so we use typing.Any lifespan: typing.Optional[Lifespan[typing.Any]] = None, + *, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -668,6 +670,11 @@ class Router: else: self.lifespan_context = lifespan + self.middleware_stack = self.app + if middleware: + for cls, options in reversed(middleware): + self.middleware_stack = cls(self.middleware_stack, **options) + async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": websocket_close = WebSocketClose() @@ -744,6 +751,9 @@ class Router: """ The main entry point to the Router class. """ + await self.middleware_stack(scope, receive, send) + + async def app(self, scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] in ("http", "websocket", "lifespan") if "router" not in scope: diff --git a/tests/test_routing.py b/tests/test_routing.py index 1e99e335..5c12ebab 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -328,6 +328,26 @@ def test_router_add_websocket_route(client): assert text == "Hello, test!" +def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]): + class CustomMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + response = PlainTextResponse("OK") + await response(scope, receive, send) + + app = Router( + routes=[Route("/", homepage)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 200 + assert response.text == "OK" + + def http_endpoint(request): url = request.url_for("http_endpoint") return Response(f"URL: {url}", media_type="text/plain")