+import contextvars
+
import pytest
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.types import ASGIApp, Receive, Scope, Send
class CustomMiddleware(BaseHTTPMiddleware):
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"
+
+
+ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
+
+
+class CustomMiddlewareWithoutBaseHTTPMiddleware:
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ ctxvar.set("set by middleware")
+ await self.app(scope, receive, send)
+ assert ctxvar.get() == "set by endpoint"
+
+
+class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ ctxvar.set("set by middleware")
+ resp = await call_next(request)
+ assert ctxvar.get() == "set by endpoint"
+ return resp # pragma: no cover
+
+
+@pytest.mark.parametrize(
+ "middleware_cls",
+ [
+ CustomMiddlewareWithoutBaseHTTPMiddleware,
+ pytest.param(
+ CustomMiddlewareUsingBaseHTTPMiddleware,
+ marks=pytest.mark.xfail(
+ reason=(
+ "BaseHTTPMiddleware creates a TaskGroup which copies the context"
+ "and erases any changes to it made within the TaskGroup"
+ ),
+ raises=AssertionError,
+ ),
+ ),
+ ],
+)
+def test_contextvars(test_client_factory, middleware_cls: type):
+ # this has to be an async endpoint because Starlette calls run_in_threadpool
+ # on sync endpoints which has it's own set of peculiarities w.r.t propagating
+ # contextvars (it propagates them forwards but not backwards)
+ async def homepage(request):
+ assert ctxvar.get() == "set by middleware"
+ ctxvar.set("set by endpoint")
+ return PlainTextResponse("Homepage")
+
+ app = Starlette(
+ middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
+ )
+
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.status_code == 200, response.content