]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Allow using dependables with `functools.partial()` (#9753)
authorLie Ryan <lie.1296@gmail.com>
Tue, 2 Dec 2025 20:58:30 +0000 (07:58 +1100)
committerGitHub <noreply@github.com>
Tue, 2 Dec 2025 20:58:30 +0000 (20:58 +0000)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Motov Yurii <109919500+YuriiMotov@users.noreply.github.com>
Co-authored-by: Yurii Motov <yurii.motov.monte@gmail.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/dependencies/models.py
tests/test_dependency_partial.py [new file with mode: 0644]

index 13486dd1890bc96a58102740d8a4636fd085466a..2a4d9a01027be8aa3844030ef94cffd3605ba7a8 100644 (file)
@@ -1,7 +1,7 @@
 import inspect
 import sys
 from dataclasses import dataclass, field
-from functools import cached_property
+from functools import cached_property, partial
 from typing import Any, Callable, List, Optional, Sequence, Union
 
 from fastapi._compat import ModelField
@@ -79,7 +79,10 @@ class Dependant:
     def _unwrapped_call(self) -> Any:
         if self.call is None:
             return self.call  # pragma: no cover
-        return inspect.unwrap(self.call)
+        unwrapped = inspect.unwrap(self.call)
+        if isinstance(unwrapped, partial):
+            unwrapped = unwrapped.func
+        return unwrapped
 
     @cached_property
     def is_gen_callable(self) -> bool:
diff --git a/tests/test_dependency_partial.py b/tests/test_dependency_partial.py
new file mode 100644 (file)
index 0000000..61a7623
--- /dev/null
@@ -0,0 +1,251 @@
+from functools import partial
+from typing import AsyncGenerator, Generator
+
+import pytest
+from fastapi import Depends, FastAPI
+from fastapi.testclient import TestClient
+from typing_extensions import Annotated
+
+app = FastAPI()
+
+
+def function_dependency(value: str) -> str:
+    return value
+
+
+async def async_function_dependency(value: str) -> str:
+    return value
+
+
+def gen_dependency(value: str) -> Generator[str, None, None]:
+    yield value
+
+
+async def async_gen_dependency(value: str) -> AsyncGenerator[str, None]:
+    yield value
+
+
+class CallableDependency:
+    def __call__(self, value: str) -> str:
+        return value
+
+
+class CallableGenDependency:
+    def __call__(self, value: str) -> Generator[str, None, None]:
+        yield value
+
+
+class AsyncCallableDependency:
+    async def __call__(self, value: str) -> str:
+        return value
+
+
+class AsyncCallableGenDependency:
+    async def __call__(self, value: str) -> AsyncGenerator[str, None]:
+        yield value
+
+
+class MethodsDependency:
+    def synchronous(self, value: str) -> str:
+        return value
+
+    async def asynchronous(self, value: str) -> str:
+        return value
+
+    def synchronous_gen(self, value: str) -> Generator[str, None, None]:
+        yield value
+
+    async def asynchronous_gen(self, value: str) -> AsyncGenerator[str, None]:
+        yield value
+
+
+callable_dependency = CallableDependency()
+callable_gen_dependency = CallableGenDependency()
+async_callable_dependency = AsyncCallableDependency()
+async_callable_gen_dependency = AsyncCallableGenDependency()
+methods_dependency = MethodsDependency()
+
+
+@app.get("/partial-function-dependency")
+async def get_partial_function_dependency(
+    value: Annotated[
+        str, Depends(partial(function_dependency, "partial-function-dependency"))
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-async-function-dependency")
+async def get_partial_async_function_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(async_function_dependency, "partial-async-function-dependency")
+        ),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-gen-dependency")
+async def get_partial_gen_dependency(
+    value: Annotated[str, Depends(partial(gen_dependency, "partial-gen-dependency"))],
+) -> str:
+    return value
+
+
+@app.get("/partial-async-gen-dependency")
+async def get_partial_async_gen_dependency(
+    value: Annotated[
+        str, Depends(partial(async_gen_dependency, "partial-async-gen-dependency"))
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-callable-dependency")
+async def get_partial_callable_dependency(
+    value: Annotated[
+        str, Depends(partial(callable_dependency, "partial-callable-dependency"))
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-callable-gen-dependency")
+async def get_partial_callable_gen_dependency(
+    value: Annotated[
+        str,
+        Depends(partial(callable_gen_dependency, "partial-callable-gen-dependency")),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-async-callable-dependency")
+async def get_partial_async_callable_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(async_callable_dependency, "partial-async-callable-dependency")
+        ),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-async-callable-gen-dependency")
+async def get_partial_async_callable_gen_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(
+                async_callable_gen_dependency, "partial-async-callable-gen-dependency"
+            )
+        ),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-synchronous-method-dependency")
+async def get_partial_synchronous_method_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(
+                methods_dependency.synchronous, "partial-synchronous-method-dependency"
+            )
+        ),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-synchronous-method-gen-dependency")
+async def get_partial_synchronous_method_gen_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(
+                methods_dependency.synchronous_gen,
+                "partial-synchronous-method-gen-dependency",
+            )
+        ),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-asynchronous-method-dependency")
+async def get_partial_asynchronous_method_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(
+                methods_dependency.asynchronous,
+                "partial-asynchronous-method-dependency",
+            )
+        ),
+    ],
+) -> str:
+    return value
+
+
+@app.get("/partial-asynchronous-method-gen-dependency")
+async def get_partial_asynchronous_method_gen_dependency(
+    value: Annotated[
+        str,
+        Depends(
+            partial(
+                methods_dependency.asynchronous_gen,
+                "partial-asynchronous-method-gen-dependency",
+            )
+        ),
+    ],
+) -> str:
+    return value
+
+
+client = TestClient(app)
+
+
+@pytest.mark.parametrize(
+    "route,value",
+    [
+        ("/partial-function-dependency", "partial-function-dependency"),
+        (
+            "/partial-async-function-dependency",
+            "partial-async-function-dependency",
+        ),
+        ("/partial-gen-dependency", "partial-gen-dependency"),
+        ("/partial-async-gen-dependency", "partial-async-gen-dependency"),
+        ("/partial-callable-dependency", "partial-callable-dependency"),
+        ("/partial-callable-gen-dependency", "partial-callable-gen-dependency"),
+        ("/partial-async-callable-dependency", "partial-async-callable-dependency"),
+        (
+            "/partial-async-callable-gen-dependency",
+            "partial-async-callable-gen-dependency",
+        ),
+        (
+            "/partial-synchronous-method-dependency",
+            "partial-synchronous-method-dependency",
+        ),
+        (
+            "/partial-synchronous-method-gen-dependency",
+            "partial-synchronous-method-gen-dependency",
+        ),
+        (
+            "/partial-asynchronous-method-dependency",
+            "partial-asynchronous-method-dependency",
+        ),
+        (
+            "/partial-asynchronous-method-gen-dependency",
+            "partial-asynchronous-method-gen-dependency",
+        ),
+    ],
+)
+def test_dependency_types_with_partial(route: str, value: str) -> None:
+    response = client.get(route)
+    assert response.status_code == 200, response.text
+    assert response.json() == value