]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Check already cloned fields in create_cloned_field to support recursive models...
authorvoegtlel <5764745+voegtlel@users.noreply.github.com>
Sun, 29 Mar 2020 17:26:29 +0000 (19:26 +0200)
committerGitHub <noreply@github.com>
Sun, 29 Mar 2020 17:26:29 +0000 (19:26 +0200)
* FIX: #894
Include recursion check for create_cloned_field.
Added test for recursive model.

* :recycle: Refactor and format create_cloned_field()

Co-authored-by: Lukas Voegtle <lukas.voegtle@sick.de>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/utils.py
tests/test_validate_response_recursive.py [new file with mode: 0644]

index f24f280739d19be2a872f04b4a6e32093e511d84..154dd9aa180e3e9e0a1dbf247fb10d1a8d559581 100644 (file)
@@ -131,17 +131,26 @@ def create_response_field(
         )
 
 
-def create_cloned_field(field: ModelField) -> ModelField:
+def create_cloned_field(
+    field: ModelField, *, cloned_types: Dict[Type[BaseModel], Type[BaseModel]] = None,
+) -> ModelField:
+    # _cloned_types has already cloned types, to support recursive models
+    if cloned_types is None:
+        cloned_types = dict()
     original_type = field.type_
     if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
         original_type = original_type.__pydantic_model__  # type: ignore
     use_type = original_type
     if lenient_issubclass(original_type, BaseModel):
         original_type = cast(Type[BaseModel], original_type)
-        use_type = create_model(original_type.__name__, __base__=original_type)
-        for f in original_type.__fields__.values():
-            use_type.__fields__[f.name] = create_cloned_field(f)
-
+        use_type = cloned_types.get(original_type)
+        if use_type is None:
+            use_type = create_model(original_type.__name__, __base__=original_type)
+            cloned_types[original_type] = use_type
+            for f in original_type.__fields__.values():
+                use_type.__fields__[f.name] = create_cloned_field(
+                    f, cloned_types=cloned_types
+                )
     new_field = create_response_field(name=field.name, type_=use_type)
     new_field.has_alias = field.has_alias
     new_field.alias = field.alias
@@ -157,10 +166,13 @@ def create_cloned_field(field: ModelField) -> ModelField:
     new_field.validate_always = field.validate_always
     if field.sub_fields:
         new_field.sub_fields = [
-            create_cloned_field(sub_field) for sub_field in field.sub_fields
+            create_cloned_field(sub_field, cloned_types=cloned_types)
+            for sub_field in field.sub_fields
         ]
     if field.key_field:
-        new_field.key_field = create_cloned_field(field.key_field)
+        new_field.key_field = create_cloned_field(
+            field.key_field, cloned_types=cloned_types
+        )
     new_field.validators = field.validators
     if PYDANTIC_1:
         new_field.pre_validators = field.pre_validators
diff --git a/tests/test_validate_response_recursive.py b/tests/test_validate_response_recursive.py
new file mode 100644 (file)
index 0000000..8b77ed1
--- /dev/null
@@ -0,0 +1,80 @@
+from typing import List
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class RecursiveItem(BaseModel):
+    sub_items: List["RecursiveItem"] = []
+    name: str
+
+
+RecursiveItem.update_forward_refs()
+
+
+class RecursiveSubitemInSubmodel(BaseModel):
+    sub_items2: List["RecursiveItemViaSubmodel"] = []
+    name: str
+
+
+class RecursiveItemViaSubmodel(BaseModel):
+    sub_items1: List[RecursiveSubitemInSubmodel] = []
+    name: str
+
+
+RecursiveSubitemInSubmodel.update_forward_refs()
+
+
+@app.get("/items/recursive", response_model=RecursiveItem)
+def get_recursive():
+    return {"name": "item", "sub_items": [{"name": "subitem", "sub_items": []}]}
+
+
+@app.get("/items/recursive-submodel", response_model=RecursiveItemViaSubmodel)
+def get_recursive_submodel():
+    return {
+        "name": "item",
+        "sub_items1": [
+            {
+                "name": "subitem",
+                "sub_items2": [
+                    {
+                        "name": "subsubitem",
+                        "sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
+                    }
+                ],
+            }
+        ],
+    }
+
+
+client = TestClient(app)
+
+
+def test_recursive():
+    response = client.get("/items/recursive")
+    assert response.status_code == 200
+    assert response.json() == {
+        "sub_items": [{"name": "subitem", "sub_items": []}],
+        "name": "item",
+    }
+
+    response = client.get("/items/recursive-submodel")
+    assert response.status_code == 200
+    assert response.json() == {
+        "name": "item",
+        "sub_items1": [
+            {
+                "name": "subitem",
+                "sub_items2": [
+                    {
+                        "name": "subsubitem",
+                        "sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
+                    }
+                ],
+            }
+        ],
+    }