from typing_extensions import ParamSpec
from starlette.datastructures import State, URLPath
-from starlette.middleware import Middleware, _MiddlewareClass
+from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
app = self.router
for cls, args, kwargs in reversed(middleware):
- app = cls(app=app, *args, **kwargs)
+ app = cls(app, *args, **kwargs)
return app
@property
def add_middleware(
self,
- middleware_class: type[_MiddlewareClass[P]],
+ middleware_class: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
else: # pragma: no cover
from typing_extensions import ParamSpec
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp
P = ParamSpec("P")
-class _MiddlewareClass(Protocol[P]):
- def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover
-
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover
+class _MiddlewareFactory(Protocol[P]):
+ def __call__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover
class Middleware:
def __init__(
self,
- cls: type[_MiddlewareClass[P]],
+ cls: _MiddlewareFactory[P],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
class_name = self.__class__.__name__
args_strings = [f"{value!r}" for value in self.args]
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
- args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
+ name = getattr(self.cls, "__name__", "")
+ args_repr = ", ".join([name] + args_strings + option_strings)
return f"{class_name}({args_repr})"
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
- self.app = cls(app=self.app, *args, **kwargs)
+ self.app = cls(self.app, *args, **kwargs)
if methods is None:
self.methods = None
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
- self.app = cls(app=self.app, *args, **kwargs)
+ self.app = cls(self.app, *args, **kwargs)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self.app = self._base_app
if middleware is not None:
for cls, args, kwargs in reversed(middleware):
- self.app = cls(app=self.app, *args, **kwargs)
+ self.app = cls(self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
from starlette.applications import Starlette
from starlette.background import BackgroundTask
-from starlette.middleware import Middleware, _MiddlewareClass
+from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
)
def test_contextvars(
test_client_factory: TestClientFactory,
- middleware_cls: type[_MiddlewareClass[Any]],
+ middleware_cls: _MiddlewareFactory[Any],
) -> None:
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
+from __future__ import annotations
+
import os
from contextlib import asynccontextmanager
from pathlib import Path
assert SimpleInitializableMiddleware.counter == 2
+def test_middleware_args(test_client_factory: TestClientFactory) -> None:
+ calls: list[str] = []
+
+ class MiddlewareWithArgs:
+ def __init__(self, app: ASGIApp, arg: str) -> None:
+ self.app = app
+ self.arg = arg
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ calls.append(self.arg)
+ await self.app(scope, receive, send)
+
+ app = Starlette()
+ app.add_middleware(MiddlewareWithArgs, "foo")
+ app.add_middleware(MiddlewareWithArgs, "bar")
+
+ with test_client_factory(app):
+ pass
+
+ assert calls == ["bar", "foo"]
+
+
+def test_middleware_factory(test_client_factory: TestClientFactory) -> None:
+ calls: list[str] = []
+
+ def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp:
+ async def _app(scope: Scope, receive: Receive, send: Send) -> None:
+ calls.append(arg)
+ await app(scope, receive, send)
+
+ return _app
+
+ app = Starlette()
+ app.add_middleware(_middleware_factory, arg="foo")
+ app.add_middleware(_middleware_factory, arg="bar")
+
+ with test_client_factory(app):
+ pass
+
+ assert calls == ["bar", "foo"]
+
+
def test_lifespan_app_subclass() -> None:
# This test exists to make sure that subclasses of Starlette
# (like FastAPI) are compatible with the types hints for Lifespan