from ..types import ASGIApp, Message, Receive, Scope, Send
_HTTPDispatchFlow = Union[
- AsyncGenerator[None, Response], AsyncGenerator[ASGIApp, Response]
+ AsyncGenerator[None, Response],
+ AsyncGenerator[Response, Response],
+ AsyncGenerator[Optional[Response], Response],
]
def __init__(
self,
app: ASGIApp,
- dispatch_func: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
+ dispatch: Optional[Callable[[Scope], _HTTPDispatchFlow]] = None,
) -> None:
self.app = app
- self.dispatch_func = self.dispatch if dispatch_func is None else dispatch_func
+ self.dispatch_func = self.dispatch if dispatch is None else dispatch
def dispatch(self, scope: Scope) -> _HTTPDispatchFlow:
raise NotImplementedError # pragma: no cover
await maybe_early_response(scope, receive, send)
return
- response_started = set[bool]()
+ response_started: set = set()
async def wrapped_send(message: Message) -> None:
if message["type"] == "http.response.start":
try:
response = await flow.athrow(exc)
+ except StopAsyncIteration:
+ response = None
except Exception:
# Exception was not handled, or they raised another one.
raise
)
await response(scope, receive, send)
+ return
if not response_started:
raise RuntimeError("No response returned.")
import contextvars
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional
import pytest
assert repr(middleware) == "Middleware(CustomMiddleware)"
-def test_fully_evaluated_response(test_client_factory):
- # Test for https://github.com/encode/starlette/issues/1022
+def test_early_response(test_client_factory):
class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> AsyncGenerator[Response, Response]:
- yield PlainTextResponse("Custom")
+ yield Response(status_code=401)
app = Starlette(middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
- response = client.get("/does_not_exist")
- assert response.text == "Custom"
+ response = client.get("/")
+ assert response.status_code == 401
+
+
+def test_too_many_yields(test_client_factory) -> None:
+ class BadMiddleware(HTTPMiddleware):
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ _ = yield
+ yield
+
+ app = Starlette(middleware=[Middleware(BadMiddleware)])
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError, match="should yield exactly once"):
+ client.get("/")
+
+
+def test_error_response(test_client_factory) -> None:
+ class Failed(Exception):
+ pass
+
+ async def failure(request):
+ raise Failed()
+
+ class ErrorMiddleware(HTTPMiddleware):
+ async def dispatch(
+ self, scope: Scope
+ ) -> AsyncGenerator[Optional[Response], Response]:
+ try:
+ yield None
+ except Failed:
+ yield Response("Failed", status_code=500)
+
+ app = Starlette(
+ routes=[Route("/fail", failure)],
+ middleware=[Middleware(ErrorMiddleware)],
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/fail")
+ assert response.text == "Failed"
+ assert response.status_code == 500
+
+
+def test_no_error_response(test_client_factory) -> None:
+ class Failed(Exception):
+ pass
+
+ async def index(request):
+ raise Failed()
+
+ class BadMiddleware(HTTPMiddleware):
+ async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
+ try:
+ yield
+ except Failed:
+ pass
+
+ app = Starlette(routes=[Route("/", index)], middleware=[Middleware(BadMiddleware)])
+
+ client = test_client_factory(app)
+ with pytest.raises(RuntimeError, match="no response was returned"):
+ client.get("/")
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
-class CustomMiddlewareWithoutBaseHTTPMiddleware:
+class PureASGICustomMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
assert ctxvar.get() == "set by endpoint"
-class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
+class HTTPCustomMiddleware(HTTPMiddleware):
async def dispatch(self, scope: Scope) -> AsyncGenerator[None, Response]:
ctxvar.set("set by middleware")
yield
@pytest.mark.parametrize(
"middleware_cls",
[
- CustomMiddlewareWithoutBaseHTTPMiddleware,
- CustomMiddlewareUsingHTTPMiddleware,
+ PureASGICustomMiddleware,
+ HTTPCustomMiddleware,
],
)
def test_contextvars(test_client_factory, middleware_cls: type):