]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add `*args` to `Middleware` and improve its type hints (#2381)
authorPaweł Rubin <pawelrubin19@gmail.com>
Wed, 20 Dec 2023 21:09:17 +0000 (22:09 +0100)
committerGitHub <noreply@github.com>
Wed, 20 Dec 2023 21:09:17 +0000 (22:09 +0100)
Co-authored-by: Paweł Rubin <pawel.rubin@ocado.com>
starlette/applications.py
starlette/middleware/__init__.py
starlette/routing.py
tests/middleware/test_base.py
tests/middleware/test_middleware.py
tests/test_applications.py
tests/test_authentication.py

index 554a25e651e99649a7a12a84ed7c1ce4c2bee1ae..3e1086d98b2f301226039ca02ea15b6af8f60312 100644 (file)
@@ -3,8 +3,10 @@ from __future__ import annotations
 import typing
 import warnings
 
+from typing_extensions import ParamSpec
+
 from starlette.datastructures import State, URLPath
-from starlette.middleware import Middleware
+from starlette.middleware import Middleware, _MiddlewareClass
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.errors import ServerErrorMiddleware
 from starlette.middleware.exceptions import ExceptionMiddleware
@@ -15,6 +17,7 @@ from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope,
 from starlette.websockets import WebSocket
 
 AppType = typing.TypeVar("AppType", bound="Starlette")
+P = ParamSpec("P")
 
 
 class Starlette:
@@ -98,8 +101,8 @@ class Starlette:
         )
 
         app = self.router
-        for cls, options in reversed(middleware):
-            app = cls(app=app, **options)
+        for cls, args, kwargs in reversed(middleware):
+            app = cls(app=app, *args, **kwargs)
         return app
 
     @property
@@ -124,10 +127,15 @@ class Starlette:
     def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
         self.router.host(host, app=app, name=name)  # pragma: no cover
 
-    def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
+    def add_middleware(
+        self,
+        middleware_class: typing.Type[_MiddlewareClass[P]],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> None:
         if self.middleware_stack is not None:  # pragma: no cover
             raise RuntimeError("Cannot add middleware after an application has started")
-        self.user_middleware.insert(0, Middleware(middleware_class, **options))
+        self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
 
     def add_exception_handler(
         self,
index 05bd57f04068a01838097d551b789efe8726374f..880e301ebd5dcd4a07efe2dc5bca4d5ee9907768 100644 (file)
@@ -1,17 +1,38 @@
-import typing
+from typing import Any, Iterator, Protocol, Type
+
+from typing_extensions import ParamSpec
+
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+P = ParamSpec("P")
+
+
+class _MiddlewareClass(Protocol[P]):
+    def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
+        ...  # pragma: no cover
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        ...  # pragma: no cover
 
 
 class Middleware:
-    def __init__(self, cls: type, **options: typing.Any) -> None:
+    def __init__(
+        self,
+        cls: Type[_MiddlewareClass[P]],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> None:
         self.cls = cls
-        self.options = options
+        self.args = args
+        self.kwargs = kwargs
 
-    def __iter__(self) -> typing.Iterator[typing.Any]:
-        as_tuple = (self.cls, self.options)
+    def __iter__(self) -> Iterator[Any]:
+        as_tuple = (self.cls, self.args, self.kwargs)
         return iter(as_tuple)
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
-        option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
-        args_repr = ", ".join([self.cls.__name__] + option_strings)
+        args_strings = [f"{value!r}" for value in self.args]
+        option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
+        args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
         return f"{class_name}({args_repr})"
index 9a21349579dfc76544d987ab3b0e97c7f68c4c41..c8c854d2c86350696cafbdc5f5afc33ea7487236 100644 (file)
@@ -238,8 +238,8 @@ class Route(BaseRoute):
             self.app = endpoint
 
         if middleware is not None:
-            for cls, options in reversed(middleware):
-                self.app = cls(app=self.app, **options)
+            for cls, args, kwargs in reversed(middleware):
+                self.app = cls(app=self.app, *args, **kwargs)
 
         if methods is None:
             self.methods = None
@@ -335,8 +335,8 @@ class WebSocketRoute(BaseRoute):
             self.app = endpoint
 
         if middleware is not None:
-            for cls, options in reversed(middleware):
-                self.app = cls(app=self.app, **options)
+            for cls, args, kwargs in reversed(middleware):
+                self.app = cls(app=self.app, *args, **kwargs)
 
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
 
@@ -404,8 +404,8 @@ class Mount(BaseRoute):
             self._base_app = Router(routes=routes)
         self.app = self._base_app
         if middleware is not None:
-            for cls, options in reversed(middleware):
-                self.app = cls(app=self.app, **options)
+            for cls, args, kwargs in reversed(middleware):
+                self.app = cls(app=self.app, *args, **kwargs)
         self.name = name
         self.path_regex, self.path_format, self.param_convertors = compile_path(
             self.path + "/{path:path}"
@@ -672,8 +672,8 @@ class Router:
 
         self.middleware_stack = self.app
         if middleware:
-            for cls, options in reversed(middleware):
-                self.middleware_stack = cls(self.middleware_stack, **options)
+            for cls, args, kwargs in reversed(middleware):
+                self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
 
     async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] == "websocket":
index 650f4aee18349f2af73d5c1f6c9ab91f62837ab7..4d51f34bf9dfb74e5609a1f03f6c9ebe27aa8807 100644 (file)
@@ -1,13 +1,13 @@
 import contextvars
 from contextlib import AsyncExitStack
-from typing import AsyncGenerator, Awaitable, Callable, List, Union
+from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union
 
 import anyio
 import pytest
 
 from starlette.applications import Starlette
 from starlette.background import BackgroundTask
-from starlette.middleware import Middleware
+from starlette.middleware import Middleware, _MiddlewareClass
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.requests import Request
 from starlette.responses import PlainTextResponse, Response, StreamingResponse
@@ -196,7 +196,7 @@ class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
         ),
     ],
 )
-def test_contextvars(test_client_factory, middleware_cls: type):
+def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
     # this has to be an async endpoint because Starlette calls run_in_threadpool
     # on sync endpoints which has it's own set of peculiarities w.r.t propagating
     # contextvars (it propagates them forwards but not backwards)
index f4d7a32f0b9e0fff510f662d0d110d7633943166..c6cf1fa1c8440882c3f4fbdb2a18bcd185d57239 100644 (file)
@@ -1,10 +1,22 @@
 from starlette.middleware import Middleware
+from starlette.types import ASGIApp, Receive, Scope, Send
 
 
-class CustomMiddleware:
-    pass
+class CustomMiddleware:  # pragma: no cover
+    def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
+        self.app = app
+        self.foo = foo
+        self.bar = bar
 
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        await self.app(scope, receive, send)
 
-def test_middleware_repr():
-    middleware = Middleware(CustomMiddleware)
-    assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+def test_middleware_repr() -> None:
+    middleware = Middleware(CustomMiddleware, "foo", bar=123)
+    assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"
+
+
+def test_middleware_iter() -> None:
+    cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
+    assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})
index e30ec929568dcd251487d560d4c0f86edec02ba1..6d0118b535a16de36c9dda8a6ad6e5ae5b2dc991 100644 (file)
@@ -1,6 +1,6 @@
 import os
 from contextlib import asynccontextmanager
-from typing import Any, AsyncIterator, Callable
+from typing import AsyncIterator, Callable
 
 import anyio
 import httpx
@@ -15,7 +15,7 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
 from starlette.responses import JSONResponse, PlainTextResponse
 from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
 from starlette.staticfiles import StaticFiles
-from starlette.types import ASGIApp
+from starlette.types import ASGIApp, Receive, Scope, Send
 from starlette.websockets import WebSocket
 
 
@@ -499,8 +499,8 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
         def __init__(self, app: ASGIApp):
             self.app = app
 
-        async def __call__(self, *args: Any):
-            await self.app(*args)
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+            await self.app(scope, receive, send)
 
     class SimpleInitializableMiddleware:
         counter = 0
@@ -509,8 +509,8 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
             self.app = app
             SimpleInitializableMiddleware.counter += 1
 
-        async def __call__(self, *args: Any):
-            await self.app(*args)
+        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+            await self.app(scope, receive, send)
 
     def get_app() -> ASGIApp:
         app = Starlette()
index af0beafd024c32799bb2fa16e8778dd5c07a9340..150482a1b6a7d187b5892ec56252415483024397 100644 (file)
@@ -15,7 +15,7 @@ from starlette.authentication import (
 from starlette.endpoints import HTTPEndpoint
 from starlette.middleware import Middleware
 from starlette.middleware.authentication import AuthenticationMiddleware
-from starlette.requests import Request
+from starlette.requests import HTTPConnection
 from starlette.responses import JSONResponse
 from starlette.routing import Route, WebSocketRoute
 from starlette.websockets import WebSocketDisconnect
@@ -327,7 +327,7 @@ def test_authentication_redirect(test_client_factory):
         assert response.json() == {"authenticated": True, "user": "tomchristie"}
 
 
-def on_auth_error(request: Request, exc: Exception):
+def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
     return JSONResponse({"error": str(exc)}, status_code=401)