From: Marcelo Trylesinski Date: Mon, 31 Jan 2022 23:18:13 +0000 (+0100) Subject: Remove manual contextvar copy logic (#1421) X-Git-Tag: 0.19.0~48 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=565898b3fdce24f85ec279fe998832da1a0ee664;p=thirdparty%2Fstarlette.git Remove manual contextvar copy logic (#1421) * Remove manual contextvar copy logic * Add test Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- diff --git a/setup.py b/setup.py index 2d1b7da1..697a561a 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ setup( package_data={"starlette": ["py.typed"]}, include_package_data=True, install_requires=[ - "anyio>=3.0.0,<4", + "anyio>=3.4.0,<4", "typing_extensions; python_version < '3.10'", "contextlib2 >= 21.6.0; python_version < '3.7'", ], diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 78602077..ac2ce6eb 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -9,11 +9,6 @@ if sys.version_info >= (3, 10): # pragma: no cover else: # pragma: no cover from typing_extensions import ParamSpec -try: - import contextvars # Python 3.7+ only or via contextvars backport. -except ImportError: # pragma: no cover - contextvars = None # type: ignore - T = typing.TypeVar("T") P = ParamSpec("P") @@ -33,13 +28,7 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) - async def run_in_threadpool( func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs ) -> T: - if contextvars is not None: # pragma: no cover - # Ensure we run in the same context - child = functools.partial(func, *args, **kwargs) - context = contextvars.copy_context() - func = context.run # type: ignore[assignment] - args = (child,) # type: ignore[assignment] - elif kwargs: # pragma: no cover + if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) return await anyio.to_thread.run_sync(func, *args) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index cc5eba97..22b9da0e 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,7 +1,13 @@ +from contextvars import ContextVar + import anyio import pytest +from starlette.applications import Starlette from starlette.concurrency import run_until_first_complete +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route @pytest.mark.anyio @@ -20,3 +26,17 @@ async def test_run_until_first_complete(): await run_until_first_complete((task1, {}), (task2, {})) assert task1_finished.is_set() assert not task2_finished.is_set() + + +def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None: + ctxvar: ContextVar[bytes] = ContextVar("ctxvar") + ctxvar.set(b"data") + + def endpoint(request: Request) -> Response: + return Response(ctxvar.get()) + + app = Starlette(routes=[Route("/", endpoint)]) + client = test_client_factory(app) + + resp = client.get("/") + assert resp.content == b"data"