]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Handle wrapped dependencies (#9555)
authorMatthew Martin <phy1729@gmail.com>
Tue, 2 Dec 2025 13:34:19 +0000 (07:34 -0600)
committerGitHub <noreply@github.com>
Tue, 2 Dec 2025 13:34:19 +0000 (14:34 +0100)
Co-authored-by: Motov Yurii <109919500+YuriiMotov@users.noreply.github.com>
Co-authored-by: Yurii Motov <yurii.motov.monte@gmail.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/dependencies/models.py
tests/test_dependency_wrapped.py [new file with mode: 0644]

index fbb666a7daea25ab7689d26dc9ae3c6234a02855..13486dd1890bc96a58102740d8a4636fd085466a 100644 (file)
@@ -75,27 +75,33 @@ class Dependant:
                 return True
         return False
 
+    @cached_property
+    def _unwrapped_call(self) -> Any:
+        if self.call is None:
+            return self.call  # pragma: no cover
+        return inspect.unwrap(self.call)
+
     @cached_property
     def is_gen_callable(self) -> bool:
-        if inspect.isgeneratorfunction(self.call):
+        if inspect.isgeneratorfunction(self._unwrapped_call):
             return True
-        dunder_call = getattr(self.call, "__call__", None)  # noqa: B004
+        dunder_call = getattr(self._unwrapped_call, "__call__", None)  # noqa: B004
         return inspect.isgeneratorfunction(dunder_call)
 
     @cached_property
     def is_async_gen_callable(self) -> bool:
-        if inspect.isasyncgenfunction(self.call):
+        if inspect.isasyncgenfunction(self._unwrapped_call):
             return True
-        dunder_call = getattr(self.call, "__call__", None)  # noqa: B004
+        dunder_call = getattr(self._unwrapped_call, "__call__", None)  # noqa: B004
         return inspect.isasyncgenfunction(dunder_call)
 
     @cached_property
     def is_coroutine_callable(self) -> bool:
-        if inspect.isroutine(self.call):
-            return iscoroutinefunction(self.call)
-        if inspect.isclass(self.call):
+        if inspect.isroutine(self._unwrapped_call):
+            return iscoroutinefunction(self._unwrapped_call)
+        if inspect.isclass(self._unwrapped_call):
             return False
-        dunder_call = getattr(self.call, "__call__", None)  # noqa: B004
+        dunder_call = getattr(self._unwrapped_call, "__call__", None)  # noqa: B004
         return iscoroutinefunction(dunder_call)
 
     @cached_property
diff --git a/tests/test_dependency_wrapped.py b/tests/test_dependency_wrapped.py
new file mode 100644 (file)
index 0000000..f581ccb
--- /dev/null
@@ -0,0 +1,77 @@
+from functools import wraps
+from typing import AsyncGenerator, Generator
+
+import pytest
+from fastapi import Depends, FastAPI
+from fastapi.testclient import TestClient
+
+
+def noop_wrap(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+app = FastAPI()
+
+
+@noop_wrap
+def wrapped_dependency() -> bool:
+    return True
+
+
+@noop_wrap
+def wrapped_gen_dependency() -> Generator[bool, None, None]:
+    yield True
+
+
+@noop_wrap
+async def async_wrapped_dependency() -> bool:
+    return True
+
+
+@noop_wrap
+async def async_wrapped_gen_dependency() -> AsyncGenerator[bool, None]:
+    yield True
+
+
+@app.get("/wrapped-dependency/")
+async def get_wrapped_dependency(value: bool = Depends(wrapped_dependency)):
+    return value
+
+
+@app.get("/wrapped-gen-dependency/")
+async def get_wrapped_gen_dependency(value: bool = Depends(wrapped_gen_dependency)):
+    return value
+
+
+@app.get("/async-wrapped-dependency/")
+async def get_async_wrapped_dependency(value: bool = Depends(async_wrapped_dependency)):
+    return value
+
+
+@app.get("/async-wrapped-gen-dependency/")
+async def get_async_wrapped_gen_dependency(
+    value: bool = Depends(async_wrapped_gen_dependency),
+):
+    return value
+
+
+client = TestClient(app)
+
+
+@pytest.mark.parametrize(
+    "route",
+    [
+        "/wrapped-dependency",
+        "/wrapped-gen-dependency",
+        "/async-wrapped-dependency",
+        "/async-wrapped-gen-dependency",
+    ],
+)
+def test_class_dependency(route):
+    response = client.get(route)
+    assert response.status_code == 200, response.text
+    assert response.json() is True