]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
⚡️ Improve performance in request body parsing with a cache for internal model fields...
authorSebastián Ramírez <tiangolo@gmail.com>
Wed, 11 Sep 2024 07:45:30 +0000 (09:45 +0200)
committerGitHub <noreply@github.com>
Wed, 11 Sep 2024 07:45:30 +0000 (09:45 +0200)
fastapi/_compat.py
fastapi/dependencies/utils.py
tests/test_compat.py

index f940d6597322325270a9a32944e312fd86a4c8f2..4b07b44fa582936607b9260ca19f99c6a350cadc 100644 (file)
@@ -2,6 +2,7 @@ from collections import deque
 from copy import copy
 from dataclasses import dataclass, is_dataclass
 from enum import Enum
+from functools import lru_cache
 from typing import (
     Any,
     Callable,
@@ -649,3 +650,8 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
         is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
         for sub_annotation in get_args(annotation)
     )
+
+
+@lru_cache
+def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
+    return get_model_fields(model)
index 6083b73195fc5d80d15234d769047f0283c30808..f18eace9d430e397a3dac5125ba8ccbea4cdec39 100644 (file)
@@ -32,8 +32,8 @@ from fastapi._compat import (
     evaluate_forwardref,
     field_annotation_is_scalar,
     get_annotation_from_field_info,
+    get_cached_model_fields,
     get_missing_field_error,
-    get_model_fields,
     is_bytes_field,
     is_bytes_sequence_field,
     is_scalar_field,
@@ -810,7 +810,7 @@ async def request_body_to_args(
     fields_to_extract: List[ModelField] = body_fields
 
     if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
-        fields_to_extract = get_model_fields(first_field.type_)
+        fields_to_extract = get_cached_model_fields(first_field.type_)
 
     if isinstance(received_body, FormData):
         body_to_process = await _extract_form_body(fields_to_extract, received_body)
index 270475bf3a42a74f4de7fce6bfd1d9a4d554c9bb..f4a3093c5ee4f7a5ce25d299a318cf4b67c7a297 100644 (file)
@@ -5,6 +5,7 @@ from fastapi._compat import (
     ModelField,
     Undefined,
     _get_model_config,
+    get_cached_model_fields,
     get_model_fields,
     is_bytes_sequence_annotation,
     is_scalar_field,
@@ -102,3 +103,18 @@ def test_is_pv1_scalar_field():
 
     fields = get_model_fields(Model)
     assert not is_scalar_field(fields[0])
+
+
+def test_get_model_fields_cached():
+    class Model(BaseModel):
+        foo: str
+
+    non_cached_fields = get_model_fields(Model)
+    non_cached_fields2 = get_model_fields(Model)
+    cached_fields = get_cached_model_fields(Model)
+    cached_fields2 = get_cached_model_fields(Model)
+    for f1, f2 in zip(cached_fields, cached_fields2):
+        assert f1 is f2
+
+    assert non_cached_fields is not non_cached_fields2
+    assert cached_fields is cached_fields2