package_data={"starlette": ["py.typed"]},
include_package_data=True,
install_requires=[
- "anyio>=3.0.0,<4",
+ "anyio>=3.4.0,<4",
"typing_extensions; python_version < '3.10'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
],
else: # pragma: no cover
from typing_extensions import ParamSpec
-try:
- import contextvars # Python 3.7+ only or via contextvars backport.
-except ImportError: # pragma: no cover
- contextvars = None # type: ignore
-
T = typing.TypeVar("T")
P = ParamSpec("P")
async def run_in_threadpool(
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
- if contextvars is not None: # pragma: no cover
- # Ensure we run in the same context
- child = functools.partial(func, *args, **kwargs)
- context = contextvars.copy_context()
- func = context.run # type: ignore[assignment]
- args = (child,) # type: ignore[assignment]
- elif kwargs: # pragma: no cover
+ if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)
+from contextvars import ContextVar
+
import anyio
import pytest
+from starlette.applications import Starlette
from starlette.concurrency import run_until_first_complete
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.routing import Route
@pytest.mark.anyio
await run_until_first_complete((task1, {}), (task2, {}))
assert task1_finished.is_set()
assert not task2_finished.is_set()
+
+
+def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None:
+ ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
+ ctxvar.set(b"data")
+
+ def endpoint(request: Request) -> Response:
+ return Response(ctxvar.get())
+
+ app = Starlette(routes=[Route("/", endpoint)])
+ client = test_client_factory(app)
+
+ resp = client.get("/")
+ assert resp.content == b"data"