import typing
import warnings
+from typing_extensions import ParamSpec
+
from starlette.datastructures import State, URLPath
-from starlette.middleware import Middleware
+from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.websockets import WebSocket
AppType = typing.TypeVar("AppType", bound="Starlette")
+P = ParamSpec("P")
class Starlette:
)
app = self.router
- for cls, options in reversed(middleware):
- app = cls(app=app, **options)
+ for cls, args, kwargs in reversed(middleware):
+ app = cls(app=app, *args, **kwargs)
return app
@property
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover
- def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
+ def add_middleware(
+ self,
+ middleware_class: typing.Type[_MiddlewareClass[P]],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
- self.user_middleware.insert(0, Middleware(middleware_class, **options))
+ self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
def add_exception_handler(
self,
-import typing
+from typing import Any, Iterator, Protocol, Type
+
+from typing_extensions import ParamSpec
+
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+P = ParamSpec("P")
+
+
+class _MiddlewareClass(Protocol[P]):
+ def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
+ ... # pragma: no cover
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ ... # pragma: no cover
class Middleware:
- def __init__(self, cls: type, **options: typing.Any) -> None:
+ def __init__(
+ self,
+ cls: Type[_MiddlewareClass[P]],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> None:
self.cls = cls
- self.options = options
+ self.args = args
+ self.kwargs = kwargs
- def __iter__(self) -> typing.Iterator[typing.Any]:
- as_tuple = (self.cls, self.options)
+ def __iter__(self) -> Iterator[Any]:
+ as_tuple = (self.cls, self.args, self.kwargs)
return iter(as_tuple)
def __repr__(self) -> str:
class_name = self.__class__.__name__
- option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
- args_repr = ", ".join([self.cls.__name__] + option_strings)
+ args_strings = [f"{value!r}" for value in self.args]
+ option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
+ args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
return f"{class_name}({args_repr})"
self.app = endpoint
if middleware is not None:
- for cls, options in reversed(middleware):
- self.app = cls(app=self.app, **options)
+ for cls, args, kwargs in reversed(middleware):
+ self.app = cls(app=self.app, *args, **kwargs)
if methods is None:
self.methods = None
self.app = endpoint
if middleware is not None:
- for cls, options in reversed(middleware):
- self.app = cls(app=self.app, **options)
+ for cls, args, kwargs in reversed(middleware):
+ self.app = cls(app=self.app, *args, **kwargs)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
self._base_app = Router(routes=routes)
self.app = self._base_app
if middleware is not None:
- for cls, options in reversed(middleware):
- self.app = cls(app=self.app, **options)
+ for cls, args, kwargs in reversed(middleware):
+ self.app = cls(app=self.app, *args, **kwargs)
self.name = name
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
self.middleware_stack = self.app
if middleware:
- for cls, options in reversed(middleware):
- self.middleware_stack = cls(self.middleware_stack, **options)
+ for cls, args, kwargs in reversed(middleware):
+ self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
import contextvars
from contextlib import AsyncExitStack
-from typing import AsyncGenerator, Awaitable, Callable, List, Union
+from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union
import anyio
import pytest
from starlette.applications import Starlette
from starlette.background import BackgroundTask
-from starlette.middleware import Middleware
+from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
),
],
)
-def test_contextvars(test_client_factory, middleware_cls: type):
+def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
from starlette.middleware import Middleware
+from starlette.types import ASGIApp, Receive, Scope, Send
-class CustomMiddleware:
- pass
+class CustomMiddleware: # pragma: no cover
+ def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
+ self.app = app
+ self.foo = foo
+ self.bar = bar
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ await self.app(scope, receive, send)
-def test_middleware_repr():
- middleware = Middleware(CustomMiddleware)
- assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+def test_middleware_repr() -> None:
+ middleware = Middleware(CustomMiddleware, "foo", bar=123)
+ assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"
+
+
+def test_middleware_iter() -> None:
+ cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
+ assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})
import os
from contextlib import asynccontextmanager
-from typing import Any, AsyncIterator, Callable
+from typing import AsyncIterator, Callable
import anyio
import httpx
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
-from starlette.types import ASGIApp
+from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket
def __init__(self, app: ASGIApp):
self.app = app
- async def __call__(self, *args: Any):
- await self.app(*args)
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ await self.app(scope, receive, send)
class SimpleInitializableMiddleware:
counter = 0
self.app = app
SimpleInitializableMiddleware.counter += 1
- async def __call__(self, *args: Any):
- await self.app(*args)
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ await self.app(scope, receive, send)
def get_app() -> ASGIApp:
app = Starlette()
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
from starlette.responses import JSONResponse
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocketDisconnect
assert response.json() == {"authenticated": True, "user": "tomchristie"}
-def on_auth_error(request: Request, exc: Exception):
+def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
return JSONResponse({"error": str(exc)}, status_code=401)