From: Yurii Karabas <1998uriyyo@gmail.com> Date: Thu, 14 Nov 2024 23:34:22 +0000 (+0100) Subject: Fix issue with middleware args passing (#2752) X-Git-Tag: 0.41.3~3 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=427a8dcf357597df27b2509b1ac436caf7708300;p=thirdparty%2Fstarlette.git Fix issue with middleware args passing (#2752) --- diff --git a/starlette/applications.py b/starlette/applications.py index 0feae72e..aae38f58 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -10,7 +10,7 @@ else: # pragma: no cover from typing_extensions import ParamSpec from starlette.datastructures import State, URLPath -from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware @@ -96,7 +96,7 @@ class Starlette: app = self.router for cls, args, kwargs in reversed(middleware): - app = cls(app=app, *args, **kwargs) + app = cls(app, *args, **kwargs) return app @property @@ -123,7 +123,7 @@ class Starlette: def add_middleware( self, - middleware_class: type[_MiddlewareClass[P]], + middleware_class: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 8566aac0..8e0a54ed 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -8,21 +8,19 @@ if sys.version_info >= (3, 10): # pragma: no cover else: # pragma: no cover from typing_extensions import ParamSpec -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp P = ParamSpec("P") -class _MiddlewareClass(Protocol[P]): - def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover +class _MiddlewareFactory(Protocol[P]): + def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover class Middleware: def __init__( self, - cls: type[_MiddlewareClass[P]], + cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -38,5 +36,6 @@ class Middleware: class_name = self.__class__.__name__ args_strings = [f"{value!r}" for value in self.args] option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] - args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings) + name = getattr(self.cls, "__name__", "") + args_repr = ", ".join([name] + args_strings + option_strings) return f"{class_name}({args_repr})" diff --git a/starlette/routing.py b/starlette/routing.py index 1504ef50..3b3c5296 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -236,7 +236,7 @@ class Route(BaseRoute): if middleware is not None: for cls, args, kwargs in reversed(middleware): - self.app = cls(app=self.app, *args, **kwargs) + self.app = cls(self.app, *args, **kwargs) if methods is None: self.methods = None @@ -328,7 +328,7 @@ class WebSocketRoute(BaseRoute): if middleware is not None: for cls, args, kwargs in reversed(middleware): - self.app = cls(app=self.app, *args, **kwargs) + self.app = cls(self.app, *args, **kwargs) self.path_regex, self.path_format, self.param_convertors = compile_path(path) @@ -388,7 +388,7 @@ class Mount(BaseRoute): self.app = self._base_app if middleware is not None: for cls, args, kwargs in reversed(middleware): - self.app = cls(app=self.app, *args, **kwargs) + self.app = cls(self.app, *args, **kwargs) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}") diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 041cc7ce..fa0cba47 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -10,7 +10,7 @@ from anyio.abc import TaskStatus from starlette.applications import Starlette from starlette.background import BackgroundTask -from starlette.middleware import Middleware, _MiddlewareClass +from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request from starlette.responses import PlainTextResponse, Response, StreamingResponse @@ -232,7 +232,7 @@ class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): ) def test_contextvars( test_client_factory: TestClientFactory, - middleware_cls: type[_MiddlewareClass[Any]], + middleware_cls: _MiddlewareFactory[Any], ) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating diff --git a/tests/test_applications.py b/tests/test_applications.py index 05604443..29c011a2 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from contextlib import asynccontextmanager from pathlib import Path @@ -533,6 +535,48 @@ def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None: assert SimpleInitializableMiddleware.counter == 2 +def test_middleware_args(test_client_factory: TestClientFactory) -> None: + calls: list[str] = [] + + class MiddlewareWithArgs: + def __init__(self, app: ASGIApp, arg: str) -> None: + self.app = app + self.arg = arg + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + calls.append(self.arg) + await self.app(scope, receive, send) + + app = Starlette() + app.add_middleware(MiddlewareWithArgs, "foo") + app.add_middleware(MiddlewareWithArgs, "bar") + + with test_client_factory(app): + pass + + assert calls == ["bar", "foo"] + + +def test_middleware_factory(test_client_factory: TestClientFactory) -> None: + calls: list[str] = [] + + def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp: + async def _app(scope: Scope, receive: Receive, send: Send) -> None: + calls.append(arg) + await app(scope, receive, send) + + return _app + + app = Starlette() + app.add_middleware(_middleware_factory, arg="foo") + app.add_middleware(_middleware_factory, arg="bar") + + with test_client_factory(app): + pass + + assert calls == ["bar", "foo"] + + def test_lifespan_app_subclass() -> None: # This test exists to make sure that subclasses of Starlette # (like FastAPI) are compatible with the types hints for Lifespan