+from contextlib import aclosing
from functools import partial
-from typing import Generator, Optional
-from ..types import ASGIApp, Scope, Receive, Send, Message
-from ..responses import Response
+from typing import AsyncGenerator, Callable, Optional
+
from ..datastructures import MutableHeaders
+from ..responses import Response
+from ..types import ASGIApp, Message, Receive, Scope, Send
+HTTPDispatchFlow = AsyncGenerator[Optional[Response], Response]
-class HTTPMiddleware:
- DispatchFlow = Generator[Optional[Response], Response, None]
- def __init__(self, app: ASGIApp) -> None:
+class HTTPMiddleware:
+ def __init__(
+ self,
+ app: ASGIApp,
+ dispatch_func: Optional[Callable[[Scope], HTTPDispatchFlow]] = None,
+ ) -> None:
self.app = app
+ self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func
- def dispatch(self, scope: Scope) -> DispatchFlow:
- raise NotImplementedError
+ def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
+ raise NotImplementedError # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
- flow = self.dispatch(scope)
-
- try:
- # Run until first `yield` to allow modifying the connection scope.
- # Middleware may return a response before we call into the underlying app.
- early_response = flow.__next__()
- except StopIteration:
- raise RuntimeError("dispatch() did not run")
+ async with aclosing(self.dispatch(scope)) as flow:
+ # Kick the flow until the first `yield`.
+ # Might respond early before we call into the app.
+ early_response = await flow.__anext__()
- if early_response is not None:
- await early_response(scope, receive, send)
- return
+ if early_response is not None:
+ await early_response(scope, receive, send)
+ return
- response_started = set[bool]()
+ response_started = set[bool]()
- wrapped_send = partial(
- self._send,
- flow=flow,
- response_started=response_started,
- send=send,
- )
+ wrapped_send = partial(
+ self._send,
+ flow=flow,
+ response_started=response_started,
+ send=send,
+ )
- try:
- await self.app(scope, receive, wrapped_send)
- except Exception as exc:
try:
- response = flow.throw(exc)
- except Exception:
- # Exception was not handled, or they raised another one.
- raise
+ await self.app(scope, receive, wrapped_send)
+ except Exception as exc:
+ if response_started:
+ raise
+
+ try:
+ response = await flow.athrow(exc)
+ except Exception:
+ # Exception was not handled, or they raised another one.
+ raise
- if response is None:
- raise RuntimeError(
- f"dispatch() handled exception {exc!r}, "
- "but no response application was returned"
- )
+ if response is None:
+ raise RuntimeError(
+ f"dispatch() handled exception {exc!r}, "
+ "but no response was returned"
+ )
- await response(scope, receive, send)
+ await response(scope, receive, send)
- if not response_started:
- raise RuntimeError("No response returned.")
+ if not response_started:
+ raise RuntimeError("No response returned.")
async def _send(
- self, message: Message, *, flow: DispatchFlow, response_started: set, send: Send
+ self,
+ message: Message,
+ *,
+ flow: HTTPDispatchFlow,
+ response_started: set,
+ send: Send,
) -> None:
if message["type"] == "http.response.start":
response_started.add(True)
response.raw_headers.clear()
try:
- flow.send(response)
- except StopIteration as exc:
- if exc.value is not None:
- raise RuntimeError("swapping responses it not supported")
+ await flow.asend(response)
+ except StopAsyncIteration:
+ pass
else:
raise RuntimeError("dispatch() should yield exactly once")
from starlette.applications import Starlette
from starlette.middleware import Middleware
-from starlette.middleware.http import HTTPMiddleware
+from starlette.middleware.http import HTTPDispatchFlow, HTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
class CustomMiddleware(HTTPMiddleware):
- def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
response = yield None
response.headers["Custom-Header"] = "Example"
expected_value2 = "bar"
class aMiddleware(HTTPMiddleware):
- def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
scope["state_foo"] = expected_value1
yield None
class bMiddleware(HTTPMiddleware):
- def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
scope["state_bar"] = expected_value2
response = yield None
response.headers["X-State-Foo"] = scope["state_foo"]
class cMiddleware(HTTPMiddleware):
- def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
response = yield None
response.headers["X-State-Bar"] = scope["state_bar"]
def test_fully_evaluated_response(test_client_factory):
# Test for https://github.com/encode/starlette/issues/1022
class CustomMiddleware(HTTPMiddleware):
- def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
yield PlainTextResponse("Custom")
app = Starlette(middleware=[Middleware(CustomMiddleware)])
class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
- def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ async def dispatch(self, scope: Scope) -> HTTPDispatchFlow:
ctxvar.set("set by middleware")
yield None
assert ctxvar.get() == "set by endpoint"