)
-def create_cloned_field(field: ModelField) -> ModelField:
+def create_cloned_field(
+ field: ModelField, *, cloned_types: Dict[Type[BaseModel], Type[BaseModel]] = None,
+) -> ModelField:
+ # _cloned_types has already cloned types, to support recursive models
+ if cloned_types is None:
+ cloned_types = dict()
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__ # type: ignore
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
- use_type = create_model(original_type.__name__, __base__=original_type)
- for f in original_type.__fields__.values():
- use_type.__fields__[f.name] = create_cloned_field(f)
-
+ use_type = cloned_types.get(original_type)
+ if use_type is None:
+ use_type = create_model(original_type.__name__, __base__=original_type)
+ cloned_types[original_type] = use_type
+ for f in original_type.__fields__.values():
+ use_type.__fields__[f.name] = create_cloned_field(
+ f, cloned_types=cloned_types
+ )
new_field = create_response_field(name=field.name, type_=use_type)
new_field.has_alias = field.has_alias
new_field.alias = field.alias
new_field.validate_always = field.validate_always
if field.sub_fields:
new_field.sub_fields = [
- create_cloned_field(sub_field) for sub_field in field.sub_fields
+ create_cloned_field(sub_field, cloned_types=cloned_types)
+ for sub_field in field.sub_fields
]
if field.key_field:
- new_field.key_field = create_cloned_field(field.key_field)
+ new_field.key_field = create_cloned_field(
+ field.key_field, cloned_types=cloned_types
+ )
new_field.validators = field.validators
if PYDANTIC_1:
new_field.pre_validators = field.pre_validators
--- /dev/null
+from typing import List
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class RecursiveItem(BaseModel):
+ sub_items: List["RecursiveItem"] = []
+ name: str
+
+
+RecursiveItem.update_forward_refs()
+
+
+class RecursiveSubitemInSubmodel(BaseModel):
+ sub_items2: List["RecursiveItemViaSubmodel"] = []
+ name: str
+
+
+class RecursiveItemViaSubmodel(BaseModel):
+ sub_items1: List[RecursiveSubitemInSubmodel] = []
+ name: str
+
+
+RecursiveSubitemInSubmodel.update_forward_refs()
+
+
+@app.get("/items/recursive", response_model=RecursiveItem)
+def get_recursive():
+ return {"name": "item", "sub_items": [{"name": "subitem", "sub_items": []}]}
+
+
+@app.get("/items/recursive-submodel", response_model=RecursiveItemViaSubmodel)
+def get_recursive_submodel():
+ return {
+ "name": "item",
+ "sub_items1": [
+ {
+ "name": "subitem",
+ "sub_items2": [
+ {
+ "name": "subsubitem",
+ "sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
+ }
+ ],
+ }
+ ],
+ }
+
+
+client = TestClient(app)
+
+
+def test_recursive():
+ response = client.get("/items/recursive")
+ assert response.status_code == 200
+ assert response.json() == {
+ "sub_items": [{"name": "subitem", "sub_items": []}],
+ "name": "item",
+ }
+
+ response = client.get("/items/recursive-submodel")
+ assert response.status_code == 200
+ assert response.json() == {
+ "name": "item",
+ "sub_items1": [
+ {
+ "name": "subitem",
+ "sub_items2": [
+ {
+ "name": "subsubitem",
+ "sub_items1": [{"name": "subsubsubitem", "sub_items2": []}],
+ }
+ ],
+ }
+ ],
+ }