]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Implement dependency value cache per request (#292)
authorSebastián Ramírez <tiangolo@gmail.com>
Wed, 5 Jun 2019 17:00:54 +0000 (21:00 +0400)
committerGitHub <noreply@github.com>
Wed, 5 Jun 2019 17:00:54 +0000 (21:00 +0400)
* :sparkles: Add dependency cache, with support for disabling it

* :white_check_mark: Add tests for dependency cache

* :memo: Add docs about dependency value caching

docs/tutorial/dependencies/first-steps.md
docs/tutorial/dependencies/sub-dependencies.md
fastapi/dependencies/models.py
fastapi/dependencies/utils.py
fastapi/param_functions.py
fastapi/params.py
fastapi/routing.py
tests/test_dependency_cache.py [new file with mode: 0644]

index 7a19618a337988b795087156e93b0540da8caf82..601fa6245b11ab12d45e77e626fd7c579bb7b650 100644 (file)
@@ -17,14 +17,12 @@ This is very useful when you need to:
 
 All these, while minimizing code repetition.
 
-
 ## First Steps
 
 Let's see a very simple example. It will be so simple that it is not very useful, for now.
 
 But this way we can focus on how the **Dependency Injection** system works.
 
-
 ### Create a dependency, or "dependable"
 
 Let's first focus on the dependency.
@@ -151,7 +149,6 @@ The simplicity of the dependency injection system makes **FastAPI** compatible w
 * response data injection systems
 * etc.
 
-
 ## Simple and Powerful
 
 Although the hierarchical dependency injection system is very simple to define and use, it's still very powerful.
index 7f96674f3ab886f7e64b80b7678898249383ef72..e55dd14e4f41aa8ec1a46feb0fb87682aa6d18e3 100644 (file)
@@ -11,6 +11,7 @@ You could create a first dependency ("dependable") like:
 ```Python hl_lines="6 7"
 {!./src/dependencies/tutorial005.py!}
 ```
+
 It declares an optional query parameter `q` as a `str`, and then it just returns it.
 
 This is quite simple (not very useful), but will help us focus on how the sub-dependencies work.
@@ -43,6 +44,18 @@ Then we can use the dependency with:
 
     But **FastAPI** will know that it has to solve `query_extractor` first, to pass the results of that to `query_or_cookie_extractor` while calling it.
 
+## Using the same dependency multiple times
+
+If one of your dependencies is declared multiple times for the same *path operation*, for example, multiple dependencies have a common sub-dependency, **FastAPI** will know to call that sub-dependency only once per request.
+
+And it will save the returned value in a <abbr title="A utility/system to store computed/generated values, to re-use them instead of computing them again.">"cache"</abbr> and pass it to all the "dependants" that need it in that specific request, instead of calling the dependency multiple times for the same request.
+
+In an advanced scenario where you know you need the dependency to be called at every step (possibly multiple times) in the same request instead of using the "cached" value, you can set the parameter `use_cache=False` when using `Depends`:
+
+```Python hl_lines="1"
+async def needy_dependency(fresh_value: str = Depends(get_value, use_cache=False)):
+    return {"fresh_value": fresh_value}
+```
 
 ## Recap
 
@@ -54,7 +67,7 @@ But still, it is very powerful, and allows you to declare arbitrarily deeply nes
 
 !!! tip
     All this might not seem as useful with these simple examples.
-    
+
     But you will see how useful it is in the chapters about **security**.
 
     And you will also see the amounts of code it will save you.
index 33644d7641b2c4f13176199bd1027eb4928750c1..29fdd0e22af2c08d885c0aa1fc37e06bfd92356f 100644 (file)
@@ -30,6 +30,7 @@ class Dependant:
         background_tasks_param_name: str = None,
         security_scopes_param_name: str = None,
         security_scopes: List[str] = None,
+        use_cache: bool = True,
         path: str = None,
     ) -> None:
         self.path_params = path_params or []
@@ -46,5 +47,8 @@ class Dependant:
         self.security_scopes_param_name = security_scopes_param_name
         self.name = name
         self.call = call
+        self.use_cache = use_cache
         # Store the path to be able to re-generate a dependable from it in overrides
         self.path = path
+        # Save the cache key at creation to optimize performance
+        self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))
index 2a64172efd03aaf65f642b92c4d9de97a05ceffa..e79a9a6a0ef1cc1cf4ebde0d851d981847972650 100644 (file)
@@ -95,7 +95,11 @@ def get_sub_dependant(
             security_scheme=dependency, scopes=use_scopes
         )
     sub_dependant = get_dependant(
-        path=path, call=dependency, name=name, security_scopes=security_scopes
+        path=path,
+        call=dependency,
+        name=name,
+        security_scopes=security_scopes,
+        use_cache=depends.use_cache,
     )
     if security_requirement:
         sub_dependant.security_requirements.append(security_requirement)
@@ -111,6 +115,7 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
         cookie_params=dependant.cookie_params.copy(),
         body_params=dependant.body_params.copy(),
         security_schemes=dependant.security_requirements.copy(),
+        use_cache=dependant.use_cache,
         path=dependant.path,
     )
     for sub_dependant in dependant.dependencies:
@@ -148,12 +153,17 @@ def is_scalar_sequence_field(field: Field) -> bool:
 
 
 def get_dependant(
-    *, path: str, call: Callable, name: str = None, security_scopes: List[str] = None
+    *,
+    path: str,
+    call: Callable,
+    name: str = None,
+    security_scopes: List[str] = None,
+    use_cache: bool = True,
 ) -> Dependant:
     path_param_names = get_path_param_names(path)
     endpoint_signature = inspect.signature(call)
     signature_params = endpoint_signature.parameters
-    dependant = Dependant(call=call, name=name, path=path)
+    dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
     for param_name, param in signature_params.items():
         if isinstance(param.default, params.Depends):
             sub_dependant = get_param_sub_dependant(
@@ -286,18 +296,29 @@ async def solve_dependencies(
     body: Dict[str, Any] = None,
     background_tasks: BackgroundTasks = None,
     dependency_overrides_provider: Any = None,
-) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks]]:
+    dependency_cache: Dict[Tuple[Callable, Tuple[str]], Any] = None,
+) -> Tuple[
+    Dict[str, Any],
+    List[ErrorWrapper],
+    Optional[BackgroundTasks],
+    Dict[Tuple[Callable, Tuple[str]], Any],
+]:
     values: Dict[str, Any] = {}
     errors: List[ErrorWrapper] = []
+    dependency_cache = dependency_cache or {}
     sub_dependant: Dependant
     for sub_dependant in dependant.dependencies:
-        call: Callable = sub_dependant.call  # type: ignore
+        sub_dependant.call = cast(Callable, sub_dependant.call)
+        sub_dependant.cache_key = cast(
+            Tuple[Callable, Tuple[str]], sub_dependant.cache_key
+        )
+        call = sub_dependant.call
         use_sub_dependant = sub_dependant
         if (
             dependency_overrides_provider
             and dependency_overrides_provider.dependency_overrides
         ):
-            original_call: Callable = sub_dependant.call  # type: ignore
+            original_call = sub_dependant.call
             call = getattr(
                 dependency_overrides_provider, "dependency_overrides", {}
             ).get(original_call, original_call)
@@ -309,22 +330,28 @@ async def solve_dependencies(
                 security_scopes=sub_dependant.security_scopes,
             )
 
-        sub_values, sub_errors, background_tasks = await solve_dependencies(
+        sub_values, sub_errors, background_tasks, sub_dependency_cache = await solve_dependencies(
             request=request,
             dependant=use_sub_dependant,
             body=body,
             background_tasks=background_tasks,
             dependency_overrides_provider=dependency_overrides_provider,
+            dependency_cache=dependency_cache,
         )
+        dependency_cache.update(sub_dependency_cache)
         if sub_errors:
             errors.extend(sub_errors)
             continue
-        if is_coroutine_callable(call):
+        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
+            solved = dependency_cache[sub_dependant.cache_key]
+        elif is_coroutine_callable(call):
             solved = await call(**sub_values)
         else:
             solved = await run_in_threadpool(call, **sub_values)
-        if use_sub_dependant.name is not None:
-            values[use_sub_dependant.name] = solved
+        if sub_dependant.name is not None:
+            values[sub_dependant.name] = solved
+        if sub_dependant.cache_key not in dependency_cache:
+            dependency_cache[sub_dependant.cache_key] = solved
     path_values, path_errors = request_params_to_args(
         dependant.path_params, request.path_params
     )
@@ -360,7 +387,7 @@ async def solve_dependencies(
         values[dependant.security_scopes_param_name] = SecurityScopes(
             scopes=dependant.security_scopes
         )
-    return values, errors, background_tasks
+    return values, errors, background_tasks, dependency_cache
 
 
 def request_params_to_args(
index 92c83ba9a88229d2520bee835e976b9e5b2f584f..abd95609c129dd8ab4317f70784f241bc215985f 100644 (file)
@@ -238,11 +238,13 @@ def File(  # noqa: N802
     )
 
 
-def Depends(dependency: Callable = None) -> Any:  # noqa: N802
-    return params.Depends(dependency=dependency)
+def Depends(  # noqa: N802
+    dependency: Callable = None, *, use_cache: bool = True
+) -> Any:
+    return params.Depends(dependency=dependency, use_cache=use_cache)
 
 
 def Security(  # noqa: N802
-    dependency: Callable = None, scopes: Sequence[str] = None
+    dependency: Callable = None, *, scopes: Sequence[str] = None, use_cache: bool = True
 ) -> Any:
-    return params.Security(dependency=dependency, scopes=scopes)
+    return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)
index 3d9afec786c57b92990f11fffd890e0374565c0d..0541a3695839c4e49a6fe08a1c0b045deff796e7 100644 (file)
@@ -308,11 +308,18 @@ class File(Form):
 
 
 class Depends:
-    def __init__(self, dependency: Callable = None):
+    def __init__(self, dependency: Callable = None, *, use_cache: bool = True):
         self.dependency = dependency
+        self.use_cache = use_cache
 
 
 class Security(Depends):
-    def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
+    def __init__(
+        self,
+        dependency: Callable = None,
+        *,
+        scopes: Sequence[str] = None,
+        use_cache: bool = True,
+    ):
+        super().__init__(dependency=dependency, use_cache=use_cache)
         self.scopes = scopes or []
-        super().__init__(dependency=dependency)
index 8526d8c0455e9bc449c67d4c5a1c8d059ba3c1e0..4ae8bb586d50dc0ca77860940561816838c9ee44 100644 (file)
@@ -102,7 +102,7 @@ def get_app(
             raise HTTPException(
                 status_code=400, detail="There was an error parsing the body"
             ) from e
-        values, errors, background_tasks = await solve_dependencies(
+        values, errors, background_tasks, _ = await solve_dependencies(
             request=request,
             dependant=dependant,
             body=body,
@@ -141,7 +141,7 @@ def get_websocket_app(
     dependant: Dependant, dependency_overrides_provider: Any = None
 ) -> Callable:
     async def app(websocket: WebSocket) -> None:
-        values, errors, _ = await solve_dependencies(
+        values, errors, _, _2 = await solve_dependencies(
             request=websocket,
             dependant=dependant,
             dependency_overrides_provider=dependency_overrides_provider,
diff --git a/tests/test_dependency_cache.py b/tests/test_dependency_cache.py
new file mode 100644 (file)
index 0000000..e9d027b
--- /dev/null
@@ -0,0 +1,68 @@
+from fastapi import Depends, FastAPI
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+counter_holder = {"counter": 0}
+
+
+async def dep_counter():
+    counter_holder["counter"] += 1
+    return counter_holder["counter"]
+
+
+async def super_dep(count: int = Depends(dep_counter)):
+    return count
+
+
+@app.get("/counter/")
+async def get_counter(count: int = Depends(dep_counter)):
+    return {"counter": count}
+
+
+@app.get("/sub-counter/")
+async def get_sub_counter(
+    subcount: int = Depends(super_dep), count: int = Depends(dep_counter)
+):
+    return {"counter": count, "subcounter": subcount}
+
+
+@app.get("/sub-counter-no-cache/")
+async def get_sub_counter_no_cache(
+    subcount: int = Depends(super_dep),
+    count: int = Depends(dep_counter, use_cache=False),
+):
+    return {"counter": count, "subcounter": subcount}
+
+
+client = TestClient(app)
+
+
+def test_normal_counter():
+    counter_holder["counter"] = 0
+    response = client.get("/counter/")
+    assert response.status_code == 200
+    assert response.json() == {"counter": 1}
+    response = client.get("/counter/")
+    assert response.status_code == 200
+    assert response.json() == {"counter": 2}
+
+
+def test_sub_counter():
+    counter_holder["counter"] = 0
+    response = client.get("/sub-counter/")
+    assert response.status_code == 200
+    assert response.json() == {"counter": 1, "subcounter": 1}
+    response = client.get("/sub-counter/")
+    assert response.status_code == 200
+    assert response.json() == {"counter": 2, "subcounter": 2}
+
+
+def test_sub_counter_no_cache():
+    counter_holder["counter"] = 0
+    response = client.get("/sub-counter-no-cache/")
+    assert response.status_code == 200
+    assert response.json() == {"counter": 2, "subcounter": 1}
+    response = client.get("/sub-counter-no-cache/")
+    assert response.status_code == 200
+    assert response.json() == {"counter": 4, "subcounter": 3}