]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Remove manual contextvar copy logic (#1421)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Mon, 31 Jan 2022 23:18:13 +0000 (00:18 +0100)
committerGitHub <noreply@github.com>
Mon, 31 Jan 2022 23:18:13 +0000 (00:18 +0100)
* Remove manual contextvar copy logic

* Add test

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
setup.py
starlette/concurrency.py
tests/test_concurrency.py

index 2d1b7da1411c29fa29c7fdd9c13292cdbf68ffb9..697a561a61e4a4c9d71d4da611cb784dec96044a 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,7 @@ setup(
     package_data={"starlette": ["py.typed"]},
     include_package_data=True,
     install_requires=[
-        "anyio>=3.0.0,<4",
+        "anyio>=3.4.0,<4",
         "typing_extensions; python_version < '3.10'",
         "contextlib2 >= 21.6.0; python_version < '3.7'",
     ],
index 78602077a65bebfd58f947072502514acde28258..ac2ce6eb735994a1017634a81a053480fce47586 100644 (file)
@@ -9,11 +9,6 @@ if sys.version_info >= (3, 10):  # pragma: no cover
 else:  # pragma: no cover
     from typing_extensions import ParamSpec
 
-try:
-    import contextvars  # Python 3.7+ only or via contextvars backport.
-except ImportError:  # pragma: no cover
-    contextvars = None  # type: ignore
-
 
 T = typing.TypeVar("T")
 P = ParamSpec("P")
@@ -33,13 +28,7 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
 async def run_in_threadpool(
     func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
 ) -> T:
-    if contextvars is not None:  # pragma: no cover
-        # Ensure we run in the same context
-        child = functools.partial(func, *args, **kwargs)
-        context = contextvars.copy_context()
-        func = context.run  # type: ignore[assignment]
-        args = (child,)  # type: ignore[assignment]
-    elif kwargs:  # pragma: no cover
+    if kwargs:  # pragma: no cover
         # run_sync doesn't accept 'kwargs', so bind them in here
         func = functools.partial(func, **kwargs)
     return await anyio.to_thread.run_sync(func, *args)
index cc5eba974fb4b0cb3c38e7017a0faf611c3b82fe..22b9da0e8356816e126789090cff4ee47dbf5b3e 100644 (file)
@@ -1,7 +1,13 @@
+from contextvars import ContextVar
+
 import anyio
 import pytest
 
+from starlette.applications import Starlette
 from starlette.concurrency import run_until_first_complete
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.routing import Route
 
 
 @pytest.mark.anyio
@@ -20,3 +26,17 @@ async def test_run_until_first_complete():
     await run_until_first_complete((task1, {}), (task2, {}))
     assert task1_finished.is_set()
     assert not task2_finished.is_set()
+
+
+def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None:
+    ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
+    ctxvar.set(b"data")
+
+    def endpoint(request: Request) -> Response:
+        return Response(ctxvar.get())
+
+    app = Starlette(routes=[Route("/", endpoint)])
+    client = test_client_factory(app)
+
+    resp = client.get("/")
+    assert resp.content == b"data"