--- /dev/null
+from functools import partial
+from typing import Generator, Optional
+from ..types import ASGIApp, Scope, Receive, Send, Message
+from ..responses import Response
+from ..datastructures import MutableHeaders
+
+
+class HTTPMiddleware:
+ DispatchFlow = Generator[Optional[Response], Response, None]
+
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ def dispatch(self, scope: Scope) -> DispatchFlow:
+ raise NotImplementedError
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ flow = self.dispatch(scope)
+
+ try:
+ # Run until first `yield` to allow modifying the connection scope.
+ # Middleware may return a response before we call into the underlying app.
+ early_response = flow.__next__()
+ except StopIteration:
+ raise RuntimeError("dispatch() did not run")
+
+ if early_response is not None:
+ await early_response(scope, receive, send)
+ return
+
+ response_started = set[bool]()
+
+ wrapped_send = partial(
+ self._send,
+ flow=flow,
+ response_started=response_started,
+ send=send,
+ )
+
+ try:
+ await self.app(scope, receive, wrapped_send)
+ except Exception as exc:
+ try:
+ response = flow.throw(exc)
+ except Exception:
+ # Exception was not handled, or they raised another one.
+ raise
+
+ if response is None:
+ raise RuntimeError(
+ f"dispatch() handled exception {exc!r}, "
+ "but no response application was returned"
+ )
+
+ await response(scope, receive, send)
+
+ if not response_started:
+ raise RuntimeError("No response returned.")
+
+ async def _send(
+ self, message: Message, *, flow: DispatchFlow, response_started: set, send: Send
+ ) -> None:
+ if message["type"] == "http.response.start":
+ response_started.add(True)
+
+ response = Response(status_code=message["status"])
+ response.raw_headers.clear()
+
+ try:
+ flow.send(response)
+ except StopIteration as exc:
+ if exc.value is not None:
+ raise RuntimeError("swapping responses it not supported")
+ else:
+ raise RuntimeError("dispatch() should yield exactly once")
+
+ headers = MutableHeaders(raw=message["headers"])
+ headers.update(response.headers)
+
+ await send(message)
--- /dev/null
+import contextvars
+
+import pytest
+
+from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.http import HTTPMiddleware
+from starlette.responses import PlainTextResponse, StreamingResponse
+from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.types import ASGIApp, Receive, Scope, Send
+
+
+class CustomMiddleware(HTTPMiddleware):
+ def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ response = yield None
+ response.headers["Custom-Header"] = "Example"
+
+
+def homepage(request):
+ return PlainTextResponse("Homepage")
+
+
+def exc(request):
+ raise Exception("Exc")
+
+
+def exc_stream(request):
+ return StreamingResponse(_generate_faulty_stream())
+
+
+def _generate_faulty_stream():
+ yield b"Ok"
+ raise Exception("Faulty Stream")
+
+
+class NoResponse:
+ def __init__(self, scope, receive, send):
+ pass
+
+ def __await__(self):
+ return self.dispatch().__await__()
+
+ async def dispatch(self):
+ pass
+
+
+async def websocket_endpoint(session):
+ await session.accept()
+ await session.send_text("Hello, world!")
+ await session.close()
+
+
+app = Starlette(
+ routes=[
+ Route("/", endpoint=homepage),
+ Route("/exc", endpoint=exc),
+ Route("/exc-stream", endpoint=exc_stream),
+ Route("/no-response", endpoint=NoResponse),
+ WebSocketRoute("/ws", endpoint=websocket_endpoint),
+ ],
+ middleware=[Middleware(CustomMiddleware)],
+)
+
+
+def test_custom_middleware(test_client_factory):
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.headers["Custom-Header"] == "Example"
+
+ with pytest.raises(Exception) as ctx:
+ response = client.get("/exc")
+ assert str(ctx.value) == "Exc"
+
+ with pytest.raises(Exception) as ctx:
+ response = client.get("/exc-stream")
+ assert str(ctx.value) == "Faulty Stream"
+
+ with pytest.raises(RuntimeError):
+ response = client.get("/no-response")
+
+ with client.websocket_connect("/ws") as session:
+ text = session.receive_text()
+ assert text == "Hello, world!"
+
+
+def test_state_data_across_multiple_middlewares(test_client_factory):
+ expected_value1 = "foo"
+ expected_value2 = "bar"
+
+ class aMiddleware(HTTPMiddleware):
+ def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ scope["state_foo"] = expected_value1
+ yield None
+
+ class bMiddleware(HTTPMiddleware):
+ def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ scope["state_bar"] = expected_value2
+ response = yield None
+ response.headers["X-State-Foo"] = scope["state_foo"]
+
+ class cMiddleware(HTTPMiddleware):
+ def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ response = yield None
+ response.headers["X-State-Bar"] = scope["state_bar"]
+
+ def homepage(request):
+ return PlainTextResponse("OK")
+
+ app = Starlette(
+ routes=[Route("/", homepage)],
+ middleware=[
+ Middleware(aMiddleware),
+ Middleware(bMiddleware),
+ Middleware(cMiddleware),
+ ],
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.text == "OK"
+ assert response.headers["X-State-Foo"] == expected_value1
+ assert response.headers["X-State-Bar"] == expected_value2
+
+
+def test_app_middleware_argument(test_client_factory):
+ def homepage(request):
+ return PlainTextResponse("Homepage")
+
+ app = Starlette(
+ routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.headers["Custom-Header"] == "Example"
+
+
+def test_middleware_repr():
+ middleware = Middleware(CustomMiddleware)
+ assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+
+def test_fully_evaluated_response(test_client_factory):
+ # Test for https://github.com/encode/starlette/issues/1022
+ class CustomMiddleware(HTTPMiddleware):
+ def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ yield PlainTextResponse("Custom")
+
+ app = Starlette(middleware=[Middleware(CustomMiddleware)])
+
+ client = test_client_factory(app)
+ response = client.get("/does_not_exist")
+ assert response.text == "Custom"
+
+
+def test_exception_on_mounted_apps(test_client_factory):
+ sub_app = Starlette(routes=[Route("/", exc)])
+ app = Starlette(routes=[Mount("/sub", app=sub_app)])
+
+ client = test_client_factory(app)
+ with pytest.raises(Exception) as ctx:
+ client.get("/sub/")
+ assert str(ctx.value) == "Exc"
+
+
+ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
+
+
+class CustomMiddlewareWithoutBaseHTTPMiddleware:
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ ctxvar.set("set by middleware")
+ await self.app(scope, receive, send)
+ assert ctxvar.get() == "set by endpoint"
+
+
+class CustomMiddlewareUsingHTTPMiddleware(HTTPMiddleware):
+ def dispatch(self, scope) -> HTTPMiddleware.DispatchFlow:
+ ctxvar.set("set by middleware")
+ yield None
+ assert ctxvar.get() == "set by endpoint"
+
+
+@pytest.mark.parametrize(
+ "middleware_cls",
+ [
+ CustomMiddlewareWithoutBaseHTTPMiddleware,
+ CustomMiddlewareUsingHTTPMiddleware,
+ ],
+)
+def test_contextvars(test_client_factory, middleware_cls: type):
+ # this has to be an async endpoint because Starlette calls run_in_threadpool
+ # on sync endpoints which has it's own set of peculiarities w.r.t propagating
+ # contextvars (it propagates them forwards but not backwards)
+ async def homepage(request):
+ assert ctxvar.get() == "set by middleware"
+ ctxvar.set("set by endpoint")
+ return PlainTextResponse("Homepage")
+
+ app = Starlette(
+ middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.status_code == 200, response.content