]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix support for functools wraps and partial combined, for async and regular functio...
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 4 Dec 2025 07:29:28 +0000 (23:29 -0800)
committerGitHub <noreply@github.com>
Thu, 4 Dec 2025 07:29:28 +0000 (08:29 +0100)
Co-authored-by: Yurii Motov <yurii.motov.monte@gmail.com>
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
tests/test_dependency_wrapped.py

index 2a4d9a01027be8aa3844030ef94cffd3605ba7a8..9b545e4e5cb43f6d4645de57bd4930a0335b119b 100644 (file)
@@ -15,6 +15,19 @@ else:  # pragma: no cover
     from asyncio import iscoroutinefunction
 
 
+def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any:
+    if call is None:
+        return call  # pragma: no cover
+    unwrapped = inspect.unwrap(_impartial(call))
+    return unwrapped
+
+
+def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
+    while isinstance(func, partial):
+        func = func.func
+    return func
+
+
 @dataclass
 class SecurityRequirement:
     security_scheme: SecurityBase
@@ -75,37 +88,82 @@ class Dependant:
                 return True
         return False
 
-    @cached_property
-    def _unwrapped_call(self) -> Any:
-        if self.call is None:
-            return self.call  # pragma: no cover
-        unwrapped = inspect.unwrap(self.call)
-        if isinstance(unwrapped, partial):
-            unwrapped = unwrapped.func
-        return unwrapped
-
     @cached_property
     def is_gen_callable(self) -> bool:
-        if inspect.isgeneratorfunction(self._unwrapped_call):
+        if self.call is None:
+            return False  # pragma: no cover
+        if inspect.isgeneratorfunction(
+            _impartial(self.call)
+        ) or inspect.isgeneratorfunction(_unwrapped_call(self.call)):
+            return True
+        dunder_call = getattr(_impartial(self.call), "__call__", None)  # noqa: B004
+        if dunder_call is None:
+            return False  # pragma: no cover
+        if inspect.isgeneratorfunction(
+            _impartial(dunder_call)
+        ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)):
             return True
-        dunder_call = getattr(self._unwrapped_call, "__call__", None)  # noqa: B004
-        return inspect.isgeneratorfunction(dunder_call)
+        dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None)  # noqa: B004
+        if dunder_unwrapped_call is None:
+            return False  # pragma: no cover
+        if inspect.isgeneratorfunction(
+            _impartial(dunder_unwrapped_call)
+        ) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)):
+            return True
+        return False
 
     @cached_property
     def is_async_gen_callable(self) -> bool:
-        if inspect.isasyncgenfunction(self._unwrapped_call):
+        if self.call is None:
+            return False  # pragma: no cover
+        if inspect.isasyncgenfunction(
+            _impartial(self.call)
+        ) or inspect.isasyncgenfunction(_unwrapped_call(self.call)):
+            return True
+        dunder_call = getattr(_impartial(self.call), "__call__", None)  # noqa: B004
+        if dunder_call is None:
+            return False  # pragma: no cover
+        if inspect.isasyncgenfunction(
+            _impartial(dunder_call)
+        ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)):
+            return True
+        dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None)  # noqa: B004
+        if dunder_unwrapped_call is None:
+            return False  # pragma: no cover
+        if inspect.isasyncgenfunction(
+            _impartial(dunder_unwrapped_call)
+        ) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)):
             return True
-        dunder_call = getattr(self._unwrapped_call, "__call__", None)  # noqa: B004
-        return inspect.isasyncgenfunction(dunder_call)
+        return False
 
     @cached_property
     def is_coroutine_callable(self) -> bool:
-        if inspect.isroutine(self._unwrapped_call):
-            return iscoroutinefunction(self._unwrapped_call)
-        if inspect.isclass(self._unwrapped_call):
-            return False
-        dunder_call = getattr(self._unwrapped_call, "__call__", None)  # noqa: B004
-        return iscoroutinefunction(dunder_call)
+        if self.call is None:
+            return False  # pragma: no cover
+        if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction(
+            _impartial(self.call)
+        ):
+            return True
+        if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction(
+            _unwrapped_call(self.call)
+        ):
+            return True
+        dunder_call = getattr(_impartial(self.call), "__call__", None)  # noqa: B004
+        if dunder_call is None:
+            return False  # pragma: no cover
+        if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction(
+            _unwrapped_call(dunder_call)
+        ):
+            return True
+        dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None)  # noqa: B004
+        if dunder_unwrapped_call is None:
+            return False  # pragma: no cover
+        if iscoroutinefunction(
+            _impartial(dunder_unwrapped_call)
+        ) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)):
+            return True
+        # if inspect.isclass(self.call): False, covered by default return
+        return False
 
     @cached_property
     def computed_scope(self) -> Union[str, None]:
index 1a493a9fd05edf9efe423c0014dab770e7fb772f..91348c8ea1a7c17e55e37312dc8667b4fc7e0db2 100644 (file)
@@ -548,10 +548,10 @@ async def _solve_generator(
     *, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
 ) -> Any:
     assert dependant.call
-    if dependant.is_gen_callable:
-        cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
-    elif dependant.is_async_gen_callable:
+    if dependant.is_async_gen_callable:
         cm = asynccontextmanager(dependant.call)(**sub_values)
+    elif dependant.is_gen_callable:
+        cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
     return await stack.enter_async_context(cm)
 
 
index f581ccba4cd726733218ba05354d5ab4d27e89a9..08356712d6932d40c10968d5ed9df61bc394bc0c 100644 (file)
@@ -1,10 +1,18 @@
+import inspect
+import sys
 from functools import wraps
 from typing import AsyncGenerator, Generator
 
 import pytest
 from fastapi import Depends, FastAPI
+from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool
 from fastapi.testclient import TestClient
 
+if sys.version_info >= (3, 13):  # pragma: no cover
+    from inspect import iscoroutinefunction
+else:  # pragma: no cover
+    from asyncio import iscoroutinefunction
+
 
 def noop_wrap(func):
     @wraps(func)
@@ -14,8 +22,163 @@ def noop_wrap(func):
     return wrapper
 
 
+def noop_wrap_async(func):
+    if inspect.isgeneratorfunction(func):
+
+        @wraps(func)
+        async def gen_wrapper(*args, **kwargs):
+            async for item in iterate_in_threadpool(func(*args, **kwargs)):
+                yield item
+
+        return gen_wrapper
+
+    elif inspect.isasyncgenfunction(func):
+
+        @wraps(func)
+        async def async_gen_wrapper(*args, **kwargs):
+            async for item in func(*args, **kwargs):
+                yield item
+
+        return async_gen_wrapper
+
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        if inspect.isroutine(func) and iscoroutinefunction(func):
+            return await func(*args, **kwargs)
+        if inspect.isclass(func):
+            return await run_in_threadpool(func, *args, **kwargs)
+        dunder_call = getattr(func, "__call__", None)  # noqa: B004
+        if iscoroutinefunction(dunder_call):
+            return await dunder_call(*args, **kwargs)
+        return await run_in_threadpool(func, *args, **kwargs)
+
+    return wrapper
+
+
+class ClassInstanceDep:
+    def __call__(self):
+        return True
+
+
+class_instance_dep = ClassInstanceDep()
+wrapped_class_instance_dep = noop_wrap(class_instance_dep)
+wrapped_class_instance_dep_async_wrapper = noop_wrap_async(class_instance_dep)
+
+
+class ClassInstanceGenDep:
+    def __call__(self):
+        yield True
+
+
+class_instance_gen_dep = ClassInstanceGenDep()
+wrapped_class_instance_gen_dep = noop_wrap(class_instance_gen_dep)
+
+
+class ClassInstanceWrappedDep:
+    @noop_wrap
+    def __call__(self):
+        return True
+
+
+class_instance_wrapped_dep = ClassInstanceWrappedDep()
+
+
+class ClassInstanceWrappedAsyncDep:
+    @noop_wrap_async
+    def __call__(self):
+        return True
+
+
+class_instance_wrapped_async_dep = ClassInstanceWrappedAsyncDep()
+
+
+class ClassInstanceWrappedGenDep:
+    @noop_wrap
+    def __call__(self):
+        yield True
+
+
+class_instance_wrapped_gen_dep = ClassInstanceWrappedGenDep()
+
+
+class ClassInstanceWrappedAsyncGenDep:
+    @noop_wrap_async
+    def __call__(self):
+        yield True
+
+
+class_instance_wrapped_async_gen_dep = ClassInstanceWrappedAsyncGenDep()
+
+
+class ClassDep:
+    def __init__(self):
+        self.value = True
+
+
+wrapped_class_dep = noop_wrap(ClassDep)
+wrapped_class_dep_async_wrapper = noop_wrap_async(ClassDep)
+
+
+class ClassInstanceAsyncDep:
+    async def __call__(self):
+        return True
+
+
+class_instance_async_dep = ClassInstanceAsyncDep()
+wrapped_class_instance_async_dep = noop_wrap(class_instance_async_dep)
+wrapped_class_instance_async_dep_async_wrapper = noop_wrap_async(
+    class_instance_async_dep
+)
+
+
+class ClassInstanceAsyncGenDep:
+    async def __call__(self):
+        yield True
+
+
+class_instance_async_gen_dep = ClassInstanceAsyncGenDep()
+wrapped_class_instance_async_gen_dep = noop_wrap(class_instance_async_gen_dep)
+
+
+class ClassInstanceAsyncWrappedDep:
+    @noop_wrap
+    async def __call__(self):
+        return True
+
+
+class_instance_async_wrapped_dep = ClassInstanceAsyncWrappedDep()
+
+
+class ClassInstanceAsyncWrappedAsyncDep:
+    @noop_wrap_async
+    async def __call__(self):
+        return True
+
+
+class_instance_async_wrapped_async_dep = ClassInstanceAsyncWrappedAsyncDep()
+
+
+class ClassInstanceAsyncWrappedGenDep:
+    @noop_wrap
+    async def __call__(self):
+        yield True
+
+
+class_instance_async_wrapped_gen_dep = ClassInstanceAsyncWrappedGenDep()
+
+
+class ClassInstanceAsyncWrappedGenAsyncDep:
+    @noop_wrap_async
+    async def __call__(self):
+        yield True
+
+
+class_instance_async_wrapped_gen_async_dep = ClassInstanceAsyncWrappedGenAsyncDep()
+
 app = FastAPI()
 
+# Sync wrapper
+
 
 @noop_wrap
 def wrapped_dependency() -> bool:
@@ -59,16 +222,225 @@ async def get_async_wrapped_gen_dependency(
     return value
 
 
+@app.get("/wrapped-class-instance-dependency/")
+async def get_wrapped_class_instance_dependency(
+    value: bool = Depends(wrapped_class_instance_dep),
+):
+    return value
+
+
+@app.get("/wrapped-class-instance-async-dependency/")
+async def get_wrapped_class_instance_async_dependency(
+    value: bool = Depends(wrapped_class_instance_async_dep),
+):
+    return value
+
+
+@app.get("/wrapped-class-instance-gen-dependency/")
+async def get_wrapped_class_instance_gen_dependency(
+    value: bool = Depends(wrapped_class_instance_gen_dep),
+):
+    return value
+
+
+@app.get("/wrapped-class-instance-async-gen-dependency/")
+async def get_wrapped_class_instance_async_gen_dependency(
+    value: bool = Depends(wrapped_class_instance_async_gen_dep),
+):
+    return value
+
+
+@app.get("/class-instance-wrapped-dependency/")
+async def get_class_instance_wrapped_dependency(
+    value: bool = Depends(class_instance_wrapped_dep),
+):
+    return value
+
+
+@app.get("/class-instance-wrapped-async-dependency/")
+async def get_class_instance_wrapped_async_dependency(
+    value: bool = Depends(class_instance_wrapped_async_dep),
+):
+    return value
+
+
+@app.get("/class-instance-async-wrapped-dependency/")
+async def get_class_instance_async_wrapped_dependency(
+    value: bool = Depends(class_instance_async_wrapped_dep),
+):
+    return value
+
+
+@app.get("/class-instance-async-wrapped-async-dependency/")
+async def get_class_instance_async_wrapped_async_dependency(
+    value: bool = Depends(class_instance_async_wrapped_async_dep),
+):
+    return value
+
+
+@app.get("/class-instance-wrapped-gen-dependency/")
+async def get_class_instance_wrapped_gen_dependency(
+    value: bool = Depends(class_instance_wrapped_gen_dep),
+):
+    return value
+
+
+@app.get("/class-instance-wrapped-async-gen-dependency/")
+async def get_class_instance_wrapped_async_gen_dependency(
+    value: bool = Depends(class_instance_wrapped_async_gen_dep),
+):
+    return value
+
+
+@app.get("/class-instance-async-wrapped-gen-dependency/")
+async def get_class_instance_async_wrapped_gen_dependency(
+    value: bool = Depends(class_instance_async_wrapped_gen_dep),
+):
+    return value
+
+
+@app.get("/class-instance-async-wrapped-gen-async-dependency/")
+async def get_class_instance_async_wrapped_gen_async_dependency(
+    value: bool = Depends(class_instance_async_wrapped_gen_async_dep),
+):
+    return value
+
+
+@app.get("/wrapped-class-dependency/")
+async def get_wrapped_class_dependency(value: ClassDep = Depends(wrapped_class_dep)):
+    return value.value
+
+
+@app.get("/wrapped-endpoint/")
+@noop_wrap
+def get_wrapped_endpoint():
+    return True
+
+
+@app.get("/async-wrapped-endpoint/")
+@noop_wrap
+async def get_async_wrapped_endpoint():
+    return True
+
+
+# Async wrapper
+
+
+@noop_wrap_async
+def wrapped_dependency_async_wrapper() -> bool:
+    return True
+
+
+@noop_wrap_async
+def wrapped_gen_dependency_async_wrapper() -> Generator[bool, None, None]:
+    yield True
+
+
+@noop_wrap_async
+async def async_wrapped_dependency_async_wrapper() -> bool:
+    return True
+
+
+@noop_wrap_async
+async def async_wrapped_gen_dependency_async_wrapper() -> AsyncGenerator[bool, None]:
+    yield True
+
+
+@app.get("/wrapped-dependency-async-wrapper/")
+async def get_wrapped_dependency_async_wrapper(
+    value: bool = Depends(wrapped_dependency_async_wrapper),
+):
+    return value
+
+
+@app.get("/wrapped-gen-dependency-async-wrapper/")
+async def get_wrapped_gen_dependency_async_wrapper(
+    value: bool = Depends(wrapped_gen_dependency_async_wrapper),
+):
+    return value
+
+
+@app.get("/async-wrapped-dependency-async-wrapper/")
+async def get_async_wrapped_dependency_async_wrapper(
+    value: bool = Depends(async_wrapped_dependency_async_wrapper),
+):
+    return value
+
+
+@app.get("/async-wrapped-gen-dependency-async-wrapper/")
+async def get_async_wrapped_gen_dependency_async_wrapper(
+    value: bool = Depends(async_wrapped_gen_dependency_async_wrapper),
+):
+    return value
+
+
+@app.get("/wrapped-class-instance-dependency-async-wrapper/")
+async def get_wrapped_class_instance_dependency_async_wrapper(
+    value: bool = Depends(wrapped_class_instance_dep_async_wrapper),
+):
+    return value
+
+
+@app.get("/wrapped-class-instance-async-dependency-async-wrapper/")
+async def get_wrapped_class_instance_async_dependency_async_wrapper(
+    value: bool = Depends(wrapped_class_instance_async_dep_async_wrapper),
+):
+    return value
+
+
+@app.get("/wrapped-class-dependency-async-wrapper/")
+async def get_wrapped_class_dependency_async_wrapper(
+    value: ClassDep = Depends(wrapped_class_dep_async_wrapper),
+):
+    return value.value
+
+
+@app.get("/wrapped-endpoint-async-wrapper/")
+@noop_wrap_async
+def get_wrapped_endpoint_async_wrapper():
+    return True
+
+
+@app.get("/async-wrapped-endpoint-async-wrapper/")
+@noop_wrap_async
+async def get_async_wrapped_endpoint_async_wrapper():
+    return True
+
+
 client = TestClient(app)
 
 
 @pytest.mark.parametrize(
     "route",
     [
-        "/wrapped-dependency",
-        "/wrapped-gen-dependency",
-        "/async-wrapped-dependency",
-        "/async-wrapped-gen-dependency",
+        "/wrapped-dependency/",
+        "/wrapped-gen-dependency/",
+        "/async-wrapped-dependency/",
+        "/async-wrapped-gen-dependency/",
+        "/wrapped-class-instance-dependency/",
+        "/wrapped-class-instance-async-dependency/",
+        "/wrapped-class-instance-gen-dependency/",
+        "/wrapped-class-instance-async-gen-dependency/",
+        "/class-instance-wrapped-dependency/",
+        "/class-instance-wrapped-async-dependency/",
+        "/class-instance-async-wrapped-dependency/",
+        "/class-instance-async-wrapped-async-dependency/",
+        "/class-instance-wrapped-gen-dependency/",
+        "/class-instance-wrapped-async-gen-dependency/",
+        "/class-instance-async-wrapped-gen-dependency/",
+        "/class-instance-async-wrapped-gen-async-dependency/",
+        "/wrapped-class-dependency/",
+        "/wrapped-endpoint/",
+        "/async-wrapped-endpoint/",
+        "/wrapped-dependency-async-wrapper/",
+        "/wrapped-gen-dependency-async-wrapper/",
+        "/async-wrapped-dependency-async-wrapper/",
+        "/async-wrapped-gen-dependency-async-wrapper/",
+        "/wrapped-class-instance-dependency-async-wrapper/",
+        "/wrapped-class-instance-async-dependency-async-wrapper/",
+        "/wrapped-class-dependency-async-wrapper/",
+        "/wrapped-endpoint-async-wrapper/",
+        "/async-wrapped-endpoint-async-wrapper/",
     ],
 )
 def test_class_dependency(route):