from .._compat import aclosing
from ..datastructures import MutableHeaders
+from ..requests import HTTPConnection
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send
-_HTTPDispatchFlow = Union[
+_DispatchFlow = Union[
AsyncGenerator[None, Response],
AsyncGenerator[Response, Response],
AsyncGenerator[Optional[Response], Response],
def __init__(
self,
app: ASGIApp,
- dispatch: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
+ dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None,
) -> None:
+ if dispatch is None:
+ dispatch = self.dispatch
+
self.app = app
- self.dispatch_func = self.dispatch if dispatch is None else dispatch
+ self._dispatch_func = dispatch
- def dispatch(self, scope: Scope) -> _HTTPDispatchFlow:
+ def dispatch(self, conn: HTTPConnection) -> _DispatchFlow:
raise NotImplementedError # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return
- async with aclosing(self.dispatch(scope)) as flow:
+ conn = HTTPConnection(scope)
+
+ async with aclosing(self._dispatch_func(conn)) as flow:
# Kick the flow until the first `yield`.
# Might respond early before we call into the app.
maybe_early_response = await flow.__anext__()
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.http import HTTPMiddleware
+from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
class CustomMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
response = yield
response.headers["Custom-Header"] = "Example"
expected_value2 = "bar"
class aMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
- scope["state_foo"] = expected_value1
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[None, Response]:
+ conn.state.foo = expected_value1
yield
class bMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
- scope["state_bar"] = expected_value2
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[None, Response]:
+ conn.state.bar = expected_value2
response = yield
- response.headers["X-State-Foo"] = scope["state_foo"]
+ response.headers["X-State-Foo"] = conn.state.foo
class cMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[None, Response]:
response = yield
- response.headers["X-State-Bar"] = scope["state_bar"]
+ response.headers["X-State-Bar"] = conn.state.bar
def homepage(request):
return PlainTextResponse("OK")
assert response.headers["X-State-Bar"] == expected_value2
-def test_app_middleware_argument(test_client_factory):
+def test_dispatch_argument(test_client_factory):
def homepage(request):
return PlainTextResponse("Homepage")
+ async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+ response = yield
+ response.headers["Custom-Header"] = "Example"
+
app = Starlette(
- routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
+ routes=[Route("/", homepage)],
+ middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
)
client = test_client_factory(app)
def test_early_response(test_client_factory):
class CustomMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[Response, Response]:
yield Response(status_code=401)
app = Starlette(middleware=[Middleware(CustomMiddleware)])
def test_too_many_yields(test_client_factory) -> None:
class BadMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[None, Response]:
_ = yield
yield
class ErrorMiddleware(HTTPMiddleware):
async def dispatch(
- self, scope: Scope
+ self, conn: HTTPConnection
) -> AsyncGenerator[Optional[Response], Response]:
try:
yield None
raise Failed()
class BadMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[None, Response]:
try:
yield
except Failed:
class HTTPCustomMiddleware(HTTPMiddleware):
- async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
ctxvar.set("set by middleware")
yield
assert ctxvar.get() == "set by endpoint"