import sys
from typing import AsyncGenerator, ContextManager, TypeVar
+import anyio
+from anyio import CapacityLimiter
from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa
from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa
from starlette.concurrency import ( # noqa
async def contextmanager_in_threadpool(
cm: ContextManager[_T],
) -> AsyncGenerator[_T, None]:
+ # blocking __exit__ from running waiting on a free thread
+ # can create race conditions/deadlocks if the context manager itself
+ # has it's own internal pool (e.g. a database connection pool)
+ # to avoid this we let __exit__ run without a capacity limit
+ # since we're creating a new limiter for each call, any non-zero limit
+ # works (1 is arbitrary)
+ exit_limiter = CapacityLimiter(1)
try:
yield await run_in_threadpool(cm.__enter__)
except Exception as e:
- ok: bool = await run_in_threadpool(cm.__exit__, type(e), e, None)
+ ok = bool(
+ await anyio.to_thread.run_sync(
+ cm.__exit__, type(e), e, None, limiter=exit_limiter
+ )
+ )
if not ok:
raise e
else:
- await run_in_threadpool(cm.__exit__, None, None, None)
+ await anyio.to_thread.run_sync(
+ cm.__exit__, None, None, None, limiter=exit_limiter
+ )