]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Document interaction of BaseHTTPMiddleware and contextvars (#1525)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Sun, 24 Apr 2022 04:48:50 +0000 (23:48 -0500)
committerGitHub <noreply@github.com>
Sun, 24 Apr 2022 04:48:50 +0000 (06:48 +0200)
* test: document behavior of ContextVars with BaseHTTPMiddleware

* lint & fix

* add pragma

* Update test_base.py

* Update tests/middleware/test_base.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
* fix typo

* try to make comment clearer

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
tests/middleware/test_base.py

index 04da3a961a9bd57d1fd7aa5fa27a93b694a2bce7..0d023ddd18fe06672f2e36fc6bece5953ec4c348 100644 (file)
@@ -1,3 +1,5 @@
+import contextvars
+
 import pytest
 
 from starlette.applications import Starlette
@@ -5,6 +7,7 @@ from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import PlainTextResponse, StreamingResponse
 from starlette.routing import Mount, Route, WebSocketRoute
+from starlette.types import ASGIApp, Receive, Scope, Send
 
 
 class CustomMiddleware(BaseHTTPMiddleware):
@@ -163,3 +166,58 @@ def test_exception_on_mounted_apps(test_client_factory):
     with pytest.raises(Exception) as ctx:
         client.get("/sub/")
     assert str(ctx.value) == "Exc"
+
+
+ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
+
+
+class CustomMiddlewareWithoutBaseHTTPMiddleware:
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        ctxvar.set("set by middleware")
+        await self.app(scope, receive, send)
+        assert ctxvar.get() == "set by endpoint"
+
+
+class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request, call_next):
+        ctxvar.set("set by middleware")
+        resp = await call_next(request)
+        assert ctxvar.get() == "set by endpoint"
+        return resp  # pragma: no cover
+
+
+@pytest.mark.parametrize(
+    "middleware_cls",
+    [
+        CustomMiddlewareWithoutBaseHTTPMiddleware,
+        pytest.param(
+            CustomMiddlewareUsingBaseHTTPMiddleware,
+            marks=pytest.mark.xfail(
+                reason=(
+                    "BaseHTTPMiddleware creates a TaskGroup which copies the context"
+                    "and erases any changes to it made within the TaskGroup"
+                ),
+                raises=AssertionError,
+            ),
+        ),
+    ],
+)
+def test_contextvars(test_client_factory, middleware_cls: type):
+    # this has to be an async endpoint because Starlette calls run_in_threadpool
+    # on sync endpoints which has it's own set of peculiarities w.r.t propagating
+    # contextvars (it propagates them forwards but not backwards)
+    async def homepage(request):
+        assert ctxvar.get() == "set by middleware"
+        ctxvar.set("set by endpoint")
+        return PlainTextResponse("Homepage")
+
+    app = Starlette(
+        middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.status_code == 200, response.content