]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix callable class generator dependencies (#1365)
authorMicah Rosales <2433663+mrosales@users.noreply.github.com>
Fri, 12 Jun 2020 20:57:59 +0000 (15:57 -0500)
committerGitHub <noreply@github.com>
Fri, 12 Jun 2020 20:57:59 +0000 (22:57 +0200)
* Fix callable class generator dependencies

* workaround to support asynccontextmanager backfill for pre python3.7

Co-authored-by: Micah Rosales <mrosales@users.noreply.github.com>
fastapi/dependencies/utils.py
tests/test_dependency_class.py

index 1a660f5d355faa200be273cd9be4286344a0be11..3ff7d3356c0ab078a2b948cbe93824bbf5f651e4 100644 (file)
@@ -274,7 +274,7 @@ def get_dependant(
     path_param_names = get_path_param_names(path)
     endpoint_signature = get_typed_signature(call)
     signature_params = endpoint_signature.parameters
-    if inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
+    if is_gen_callable(call) or is_async_gen_callable(call):
         check_dependency_contextmanagers()
     dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
     for param_name, param in signature_params.items():
@@ -412,19 +412,41 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
 
 def is_coroutine_callable(call: Callable) -> bool:
     if inspect.isroutine(call):
-        return asyncio.iscoroutinefunction(call)
+        return inspect.iscoroutinefunction(call)
     if inspect.isclass(call):
         return False
     call = getattr(call, "__call__", None)
-    return asyncio.iscoroutinefunction(call)
+    return inspect.iscoroutinefunction(call)
+
+
+def is_async_gen_callable(call: Callable) -> bool:
+    if inspect.isasyncgenfunction(call):
+        return True
+    call = getattr(call, "__call__", None)
+    return inspect.isasyncgenfunction(call)
+
+
+def is_gen_callable(call: Callable) -> bool:
+    if inspect.isgeneratorfunction(call):
+        return True
+    call = getattr(call, "__call__", None)
+    return inspect.isgeneratorfunction(call)
 
 
 async def solve_generator(
     *, call: Callable, stack: AsyncExitStack, sub_values: Dict[str, Any]
 ) -> Any:
-    if inspect.isgeneratorfunction(call):
+    if is_gen_callable(call):
         cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
-    elif inspect.isasyncgenfunction(call):
+    elif is_async_gen_callable(call):
+        if not inspect.isasyncgenfunction(call):
+            # asynccontextmanager from the async_generator backfill pre python3.7
+            # does not support callables that are not functions or methods.
+            # See https://github.com/python-trio/async_generator/issues/32
+            #
+            # Expand the callable class into its __call__ method before decorating it.
+            # This approach will work on newer python versions as well.
+            call = getattr(call, "__call__", None)
         cm = asynccontextmanager(call)(**sub_values)
     return await stack.enter_async_context(cm)
 
@@ -505,7 +527,7 @@ async def solve_dependencies(
             continue
         if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
             solved = dependency_cache[sub_dependant.cache_key]
-        elif inspect.isgeneratorfunction(call) or inspect.isasyncgenfunction(call):
+        elif is_gen_callable(call) or is_async_gen_callable(call):
             stack = request.scope.get("fastapi_astack")
             if stack is None:
                 raise RuntimeError(
index ba2e3cfcfe38169c64af57d195ca52be7ee374bc..bfe777f52705aadbc623d725b36ec61092309fa9 100644 (file)
@@ -1,3 +1,5 @@
+from typing import AsyncGenerator, Generator
+
 import pytest
 from fastapi import Depends, FastAPI
 from fastapi.testclient import TestClient
@@ -10,11 +12,21 @@ class CallableDependency:
         return value
 
 
+class CallableGenDependency:
+    def __call__(self, value: str) -> Generator[str, None, None]:
+        yield value
+
+
 class AsyncCallableDependency:
     async def __call__(self, value: str) -> str:
         return value
 
 
+class AsyncCallableGenDependency:
+    async def __call__(self, value: str) -> AsyncGenerator[str, None]:
+        yield value
+
+
 class MethodsDependency:
     def synchronous(self, value: str) -> str:
         return value
@@ -22,9 +34,17 @@ class MethodsDependency:
     async def asynchronous(self, value: str) -> str:
         return value
 
+    def synchronous_gen(self, value: str) -> Generator[str, None, None]:
+        yield value
+
+    async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
+        yield value
+
 
 callable_dependency = CallableDependency()
+callable_gen_dependency = CallableGenDependency()
 async_callable_dependency = AsyncCallableDependency()
+async_callable_gen_dependency = AsyncCallableGenDependency()
 methods_dependency = MethodsDependency()
 
 
@@ -33,11 +53,23 @@ async def get_callable_dependency(value: str = Depends(callable_dependency)):
     return value
 
 
+@app.get("/callable-gen-dependency")
+async def get_callable_gen_dependency(value: str = Depends(callable_gen_dependency)):
+    return value
+
+
 @app.get("/async-callable-dependency")
 async def get_callable_dependency(value: str = Depends(async_callable_dependency)):
     return value
 
 
+@app.get("/async-callable-gen-dependency")
+async def get_callable_gen_dependency(
+    value: str = Depends(async_callable_gen_dependency),
+):
+    return value
+
+
 @app.get("/synchronous-method-dependency")
 async def get_synchronous_method_dependency(
     value: str = Depends(methods_dependency.synchronous),
@@ -45,6 +77,13 @@ async def get_synchronous_method_dependency(
     return value
 
 
+@app.get("/synchronous-method-gen-dependency")
+async def get_synchronous_method_gen_dependency(
+    value: str = Depends(methods_dependency.synchronous_gen),
+):
+    return value
+
+
 @app.get("/asynchronous-method-dependency")
 async def get_asynchronous_method_dependency(
     value: str = Depends(methods_dependency.asynchronous),
@@ -52,6 +91,13 @@ async def get_asynchronous_method_dependency(
     return value
 
 
+@app.get("/asynchronous-method-gen-dependency")
+async def get_asynchronous_method_gen_dependency(
+    value: str = Depends(methods_dependency.asynchronous_gen),
+):
+    return value
+
+
 client = TestClient(app)
 
 
@@ -59,9 +105,13 @@ client = TestClient(app)
     "route,value",
     [
         ("/callable-dependency", "callable-dependency"),
+        ("/callable-gen-dependency", "callable-gen-dependency"),
         ("/async-callable-dependency", "async-callable-dependency"),
+        ("/async-callable-gen-dependency", "async-callable-gen-dependency"),
         ("/synchronous-method-dependency", "synchronous-method-dependency"),
+        ("/synchronous-method-gen-dependency", "synchronous-method-gen-dependency"),
         ("/asynchronous-method-dependency", "asynchronous-method-dependency"),
+        ("/asynchronous-method-gen-dependency", "asynchronous-method-gen-dependency"),
     ],
 )
 def test_class_dependency(route, value):