path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
- if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
+ if is_gen_callable(call) or is_async_gen_callable(call):
check_dependency_contextmanagers()
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
for param_name, param in signature_params.items():
def is_coroutine_callable(call: Callable) -> bool:
if inspect.isroutine(call):
- return asyncio.iscoroutinefunction(call)
+ return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
call = getattr(call, "__call__", None)
- return asyncio.iscoroutinefunction(call)
+ return inspect.iscoroutinefunction(call)
+
+
+def is_async_gen_callable(call: Callable) -> bool:
+ if inspect.isasyncgenfunction(call):
+ return True
+ call = getattr(call, "__call__", None)
+ return inspect.isasyncgenfunction(call)
+
+
+def is_gen_callable(call: Callable) -> bool:
+ if inspect.isgeneratorfunction(call):
+ return True
+ call = getattr(call, "__call__", None)
+ return inspect.isgeneratorfunction(call)
async def solve_generator(
*, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
- if inspect.isgeneratorfunction(call):
+ if is_gen_callable(call):
cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
- elif inspect.isasyncgenfunction(call):
+ elif is_async_gen_callable(call):
+ if not inspect.isasyncgenfunction(call):
+ # asynccontextmanager from the async_generator backfill pre python3.7
+ # does not support callables that are not functions or methods.
+ # See https://github.com/python-trio/async_generator/issues/32
+ #
+ # Expand the callable class into its __call__ method before decorating it.
+ # This approach will work on newer python versions as well.
+ call = getattr(call, "__call__", None)
cm = asynccontextmanager(call)(**sub_values)
return await stack.enter_async_context(cm)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
- elif inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
+ elif is_gen_callable(call) or is_async_gen_callable(call):
stack = request.scope.get("fastapi_astack")
if stack is None:
raise RuntimeError(
+from typing import AsyncGenerator, Generator
+
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
return value
+class CallableGenDependency:
+ def __call__(self, value: str) -> Generator[str, None, None]:
+ yield value
+
+
class AsyncCallableDependency:
async def __call__(self, value: str) -> str:
return value
+class AsyncCallableGenDependency:
+ async def __call__(self, value: str) -> AsyncGenerator[str, None]:
+ yield value
+
+
class MethodsDependency:
def synchronous(self, value: str) -> str:
return value
async def asynchronous(self, value: str) -> str:
return value
+ def synchronous_gen(self, value: str) -> Generator[str, None, None]:
+ yield value
+
+ async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
+ yield value
+
callable_dependency = CallableDependency()
+callable_gen_dependency = CallableGenDependency()
async_callable_dependency = AsyncCallableDependency()
+async_callable_gen_dependency = AsyncCallableGenDependency()
methods_dependency = MethodsDependency()
return value
+@app.get("/callable-gen-dependency")
+async def get_callable_gen_dependency(value: str = Depends(callable_gen_dependency)):
+ return value
+
+
@app.get("/async-callable-dependency")
async def get_callable_dependency(value: str = Depends(async_callable_dependency)):
return value
+@app.get("/async-callable-gen-dependency")
+async def get_callable_gen_dependency(
+ value: str = Depends(async_callable_gen_dependency),
+):
+ return value
+
+
@app.get("/synchronous-method-dependency")
async def get_synchronous_method_dependency(
value: str = Depends(methods_dependency.synchronous),
return value
+@app.get("/synchronous-method-gen-dependency")
+async def get_synchronous_method_gen_dependency(
+ value: str = Depends(methods_dependency.synchronous_gen),
+):
+ return value
+
+
@app.get("/asynchronous-method-dependency")
async def get_asynchronous_method_dependency(
value: str = Depends(methods_dependency.asynchronous),
return value
+@app.get("/asynchronous-method-gen-dependency")
+async def get_asynchronous_method_gen_dependency(
+ value: str = Depends(methods_dependency.asynchronous_gen),
+):
+ return value
+
+
client = TestClient(app)
"route,value",
[
("/callable-dependency", "callable-dependency"),
+ ("/callable-gen-dependency", "callable-gen-dependency"),
("/async-callable-dependency", "async-callable-dependency"),
+ ("/async-callable-gen-dependency", "async-callable-gen-dependency"),
("/synchronous-method-dependency", "synchronous-method-dependency"),
+ ("/synchronous-method-gen-dependency", "synchronous-method-gen-dependency"),
("/asynchronous-method-dependency", "asynchronous-method-dependency"),
+ ("/asynchronous-method-gen-dependency", "asynchronous-method-gen-dependency"),
],
)
def test_class_dependency(route, value):