from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send
+# This type hint not exposed, as it exists mostly for our own documentation purposes.
+# End users should use one of these type hints explicitly when overriding '.dispatch()'.
_DispatchFlow = Union[
# Default case:
- # response = yield
+ # response = yield
AsyncGenerator[None, Response],
# Early response and/or error handling:
- # if condition:
- # yield Response(...)
- # return
- # try:
- # response = yield None
- # except Exception:
- # yield Response(...)
- # else:
- # ...
+ # if condition:
+ # yield Response(...)
+ # return
+ # try:
+ # response = yield None
+ # except Exception:
+ # yield Response(...)
+ # else:
+ # ...
AsyncGenerator[Optional[Response], Response],
]
app: ASGIApp,
dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None,
) -> None:
- if dispatch is None:
- dispatch = self.dispatch
-
self.app = app
- self._dispatch_func = dispatch
+ self._dispatch_func = self.dispatch if dispatch is None else dispatch
- def dispatch(self, conn: HTTPConnection) -> _DispatchFlow:
- raise NotImplementedError # pragma: no cover
+ def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow:
+ raise NotImplementedError(
+ "No dispatch implementation was given. "
+ "Either pass 'dispatch=...' to HTTPMiddleware, "
+ "or subclass HTTPMiddleware and override the 'dispatch()' method."
+ )
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
maybe_early_response = await flow.__anext__()
if maybe_early_response is not None:
+ try:
+ await flow.__anext__()
+ except StopAsyncIteration:
+ pass
+ else:
+ raise RuntimeError("dispatch() should yield exactly once")
+
await maybe_early_response(scope, receive, send)
return
- response_started: set = set()
+ response_started = False
async def wrapped_send(message: Message) -> None:
+ nonlocal response_started
+
if message["type"] == "http.response.start":
- response_started.add(True)
+ response_started = True
response = Response(status_code=message["status"])
response.raw_headers.clear()
headers = MutableHeaders(raw=message["headers"])
headers.update(response.headers)
+ message["headers"] = headers.raw
await send(message)
await response(scope, receive, send)
return
-
- if not response_started:
- raise RuntimeError("No response returned.")
-import contextvars
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, Callable, Iterator, Optional
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.http import HTTPMiddleware
-from starlette.requests import HTTPConnection
+from starlette.requests import HTTPConnection, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.testclient import TestClient
+from starlette.types import ASGIApp
+from starlette.websockets import WebSocket
-class CustomMiddleware(HTTPMiddleware):
- async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
- response = yield
- response.headers["Custom-Header"] = "Example"
-
-
-def homepage(request):
+async def homepage(request: Request) -> Response:
return PlainTextResponse("Homepage")
-def exc(request):
+async def exc(request: Request) -> Response:
raise Exception("Exc")
-def exc_stream(request):
+async def exc_stream(request: Request) -> Response:
return StreamingResponse(_generate_faulty_stream())
-def _generate_faulty_stream():
+def _generate_faulty_stream() -> Iterator[bytes]:
yield b"Ok"
raise Exception("Faulty Stream")
-class NoResponse:
- def __init__(self, scope, receive, send):
- pass
-
- def __await__(self):
- return self.dispatch().__await__()
-
- async def dispatch(self):
- pass
-
-
-async def websocket_endpoint(session):
+async def websocket_endpoint(session: WebSocket) -> None:
await session.accept()
await session.send_text("Hello, world!")
await session.close()
+class CustomMiddleware(HTTPMiddleware):
+ async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+ response = yield
+ response.headers["Custom-Header"] = "Example"
+
+
app = Starlette(
routes=[
Route("/", endpoint=homepage),
Route("/exc", endpoint=exc),
Route("/exc-stream", endpoint=exc_stream),
- Route("/no-response", endpoint=NoResponse),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
],
middleware=[Middleware(CustomMiddleware)],
)
-def test_custom_middleware(test_client_factory):
+def test_custom_middleware(
+ test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"
- with pytest.raises(RuntimeError):
- response = client.get("/no-response")
-
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
-def test_state_data_across_multiple_middlewares(test_client_factory):
+def test_state_data_across_multiple_middlewares(
+ test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
+ async def homepage(request: Request) -> Response:
+ return PlainTextResponse("OK")
+
expected_value1 = "foo"
expected_value2 = "bar"
- class aMiddleware(HTTPMiddleware):
- async def dispatch(
- self, conn: HTTPConnection
- ) -> AsyncGenerator[None, Response]:
- conn.state.foo = expected_value1
- yield
-
- class bMiddleware(HTTPMiddleware):
- async def dispatch(
- self, conn: HTTPConnection
- ) -> AsyncGenerator[None, Response]:
- conn.state.bar = expected_value2
- response = yield
- response.headers["X-State-Foo"] = conn.state.foo
+ async def middleware_a(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+ conn.state.foo = expected_value1
+ yield
- class cMiddleware(HTTPMiddleware):
- async def dispatch(
- self, conn: HTTPConnection
- ) -> AsyncGenerator[None, Response]:
- response = yield
- response.headers["X-State-Bar"] = conn.state.bar
+ async def middleware_b(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+ conn.state.bar = expected_value2
+ response = yield
+ response.headers["X-State-Foo"] = conn.state.foo
- def homepage(request):
- return PlainTextResponse("OK")
+ async def middleware_c(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
+ response = yield
+ response.headers["X-State-Bar"] = conn.state.bar
app = Starlette(
routes=[Route("/", homepage)],
middleware=[
- Middleware(aMiddleware),
- Middleware(bMiddleware),
- Middleware(cMiddleware),
+ Middleware(HTTPMiddleware, dispatch=middleware_a),
+ Middleware(HTTPMiddleware, dispatch=middleware_b),
+ Middleware(HTTPMiddleware, dispatch=middleware_c),
],
)
assert response.headers["X-State-Bar"] == expected_value2
-def test_dispatch_argument(test_client_factory):
- def homepage(request):
- return PlainTextResponse("Homepage")
-
- async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
- response = yield
- response.headers["Custom-Header"] = "Example"
+def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
+ class CustomMiddleware(HTTPMiddleware):
+ async def dispatch(
+ self, conn: HTTPConnection
+ ) -> AsyncGenerator[None, Response]:
+ _ = yield
+ yield
- app = Starlette(
- routes=[Route("/", homepage)],
- middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)],
- )
+ app = Starlette(middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
- response = client.get("/")
- assert response.headers["Custom-Header"] == "Example"
-
-
-def test_middleware_repr():
- middleware = Middleware(CustomMiddleware)
- assert repr(middleware) == "Middleware(CustomMiddleware)"
+ with pytest.raises(RuntimeError, match="should yield exactly once"):
+ client.get("/")
-def test_early_response(test_client_factory):
- async def index(request):
+def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
+ async def index(request: Request) -> Response:
return PlainTextResponse("Hello, world!")
class CustomMiddleware(HTTPMiddleware):
assert response.status_code == 401
-def test_too_many_yields(test_client_factory) -> None:
+def test_early_response_too_many_yields(
+ test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
- ) -> AsyncGenerator[None, Response]:
- _ = yield
- yield
+ ) -> AsyncGenerator[Optional[Response], Response]:
+ yield Response()
+ yield None
app = Starlette(middleware=[Middleware(CustomMiddleware)])
client.get("/")
-def test_error_response(test_client_factory) -> None:
+def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
class Failed(Exception):
pass
- async def failure(request):
+ async def failure(request: Request) -> Response:
raise Failed()
class CustomMiddleware(HTTPMiddleware):
assert response.status_code == 500
-def test_no_error_response(test_client_factory) -> None:
+def test_error_handling_must_send_response(
+ test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
class Failed(Exception):
pass
- async def index(request):
+ async def index(request: Request) -> Response:
raise Failed()
class CustomMiddleware(HTTPMiddleware):
try:
yield
except Failed:
- pass
+ pass # `yield <response>` expected
app = Starlette(
routes=[Route("/", index)],
client.get("/")
-ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
-
-
-class PureASGICustomMiddleware:
- def __init__(self, app: ASGIApp) -> None:
- self.app = app
-
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- ctxvar.set("set by middleware")
- await self.app(scope, receive, send)
- assert ctxvar.get() == "set by endpoint"
-
-
-class HTTPCustomMiddleware(HTTPMiddleware):
- async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
- ctxvar.set("set by middleware")
- yield
- assert ctxvar.get() == "set by endpoint"
-
-
-@pytest.mark.parametrize(
- "middleware_cls",
- [
- PureASGICustomMiddleware,
- HTTPCustomMiddleware,
- ],
-)
-def test_contextvars(test_client_factory, middleware_cls: type):
- async def homepage(request):
- assert ctxvar.get() == "set by middleware"
- ctxvar.set("set by endpoint")
- return PlainTextResponse("Homepage")
-
- app = Starlette(
- middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
- )
-
+def test_no_dispatch_given(
+ test_client_factory: Callable[[ASGIApp], TestClient]
+) -> None:
+ app = Starlette(middleware=[Middleware(HTTPMiddleware)])
client = test_client_factory(app)
- response = client.get("/")
- assert response.status_code == 200, response.content
+ with pytest.raises(NotImplementedError, match="No dispatch implementation"):
+ client.get("/")