]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Fix issue with middleware args passing (#2752)
authorYurii Karabas <1998uriyyo@gmail.com>
Thu, 14 Nov 2024 23:34:22 +0000 (00:34 +0100)
committerGitHub <noreply@github.com>
Thu, 14 Nov 2024 23:34:22 +0000 (17:34 -0600)
starlette/applications.py
starlette/middleware/__init__.py
starlette/routing.py
tests/middleware/test_base.py
tests/test_applications.py

index 0feae72e411e5a0cc97c7aead431cdb918245745..aae38f588daee19ce5bc60db044ef388be2fb570 100644 (file)
@@ -10,7 +10,7 @@ else:  # pragma: no cover
     from typing_extensions import ParamSpec
 
 from starlette.datastructures import State, URLPath
-from starlette.middleware import Middleware, _MiddlewareClass
+from starlette.middleware import Middleware, _MiddlewareFactory
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.errors import ServerErrorMiddleware
 from starlette.middleware.exceptions import ExceptionMiddleware
@@ -96,7 +96,7 @@ class Starlette:
 
         app = self.router
         for cls, args, kwargs in reversed(middleware):
-            app = cls(app=app, *args, **kwargs)
+            app = cls(app, *args, **kwargs)
         return app
 
     @property
@@ -123,7 +123,7 @@ class Starlette:
 
     def add_middleware(
         self,
-        middleware_class: type[_MiddlewareClass[P]],
+        middleware_class: _MiddlewareFactory[P],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> None:
index 8566aac08a6d1bce4d8fd3fdc7457f56f1852f11..8e0a54edbdf14563fe79873f1bc9152c3b9202f5 100644 (file)
@@ -8,21 +8,19 @@ if sys.version_info >= (3, 10):  # pragma: no cover
 else:  # pragma: no cover
     from typing_extensions import ParamSpec
 
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp
 
 P = ParamSpec("P")
 
 
-class _MiddlewareClass(Protocol[P]):
-    def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ...  # pragma: no cover
-
-    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ...  # pragma: no cover
+class _MiddlewareFactory(Protocol[P]):
+    def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ...  # pragma: no cover
 
 
 class Middleware:
     def __init__(
         self,
-        cls: type[_MiddlewareClass[P]],
+        cls: _MiddlewareFactory[P],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> None:
@@ -38,5 +36,6 @@ class Middleware:
         class_name = self.__class__.__name__
         args_strings = [f"{value!r}" for value in self.args]
         option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
-        args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
+        name = getattr(self.cls, "__name__", "")
+        args_repr = ", ".join([name] + args_strings + option_strings)
         return f"{class_name}({args_repr})"
index 1504ef50a016b973492de7a51aafa5f35d149a66..3b3c52968c88b3f24b0bcb4f31b2102fbeae5feb 100644 (file)
@@ -236,7 +236,7 @@ class Route(BaseRoute):
 
         if middleware is not None:
             for cls, args, kwargs in reversed(middleware):
-                self.app = cls(app=self.app, *args, **kwargs)
+                self.app = cls(self.app, *args, **kwargs)
 
         if methods is None:
             self.methods = None
@@ -328,7 +328,7 @@ class WebSocketRoute(BaseRoute):
 
         if middleware is not None:
             for cls, args, kwargs in reversed(middleware):
-                self.app = cls(app=self.app, *args, **kwargs)
+                self.app = cls(self.app, *args, **kwargs)
 
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
 
@@ -388,7 +388,7 @@ class Mount(BaseRoute):
         self.app = self._base_app
         if middleware is not None:
             for cls, args, kwargs in reversed(middleware):
-                self.app = cls(app=self.app, *args, **kwargs)
+                self.app = cls(self.app, *args, **kwargs)
         self.name = name
         self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
 
index 041cc7ce2c835cc41c6b74f997775423adc98907..fa0cba47929f8a4a74ca98baae5189adb650c752 100644 (file)
@@ -10,7 +10,7 @@ from anyio.abc import TaskStatus
 
 from starlette.applications import Starlette
 from starlette.background import BackgroundTask
-from starlette.middleware import Middleware, _MiddlewareClass
+from starlette.middleware import Middleware, _MiddlewareFactory
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.requests import ClientDisconnect, Request
 from starlette.responses import PlainTextResponse, Response, StreamingResponse
@@ -232,7 +232,7 @@ class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
 )
 def test_contextvars(
     test_client_factory: TestClientFactory,
-    middleware_cls: type[_MiddlewareClass[Any]],
+    middleware_cls: _MiddlewareFactory[Any],
 ) -> None:
     # this has to be an async endpoint because Starlette calls run_in_threadpool
     # on sync endpoints which has it's own set of peculiarities w.r.t propagating
index 0560444383ab81fb33d20554b00cfc965ae0800a..29c011a298c80198b5d86b4fd92dc48e85709386 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 from contextlib import asynccontextmanager
 from pathlib import Path
@@ -533,6 +535,48 @@ def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
     assert SimpleInitializableMiddleware.counter == 2
 
 
+def test_middleware_args(test_client_factory: TestClientFactory) -> None:
+    calls: list[str] = []
+
+    class MiddlewareWithArgs:
+        def __init__(self, app: ASGIApp, arg: str) -> None:
+            self.app = app
+            self.arg = arg
+
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+            calls.append(self.arg)
+            await self.app(scope, receive, send)
+
+    app = Starlette()
+    app.add_middleware(MiddlewareWithArgs, "foo")
+    app.add_middleware(MiddlewareWithArgs, "bar")
+
+    with test_client_factory(app):
+        pass
+
+    assert calls == ["bar", "foo"]
+
+
+def test_middleware_factory(test_client_factory: TestClientFactory) -> None:
+    calls: list[str] = []
+
+    def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp:
+        async def _app(scope: Scope, receive: Receive, send: Send) -> None:
+            calls.append(arg)
+            await app(scope, receive, send)
+
+        return _app
+
+    app = Starlette()
+    app.add_middleware(_middleware_factory, arg="foo")
+    app.add_middleware(_middleware_factory, arg="bar")
+
+    with test_client_factory(app):
+        pass
+
+    assert calls == ["bar", "foo"]
+
+
 def test_lifespan_app_subclass() -> None:
     # This test exists to make sure that subclasses of Starlette
     # (like FastAPI) are compatible with the types hints for Lifespan