]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Fix Pydantic field clone logic with validators (#899)
authorAndy Smith <apsmith2367@gmail.com>
Tue, 4 Feb 2020 03:03:51 +0000 (22:03 -0500)
committerGitHub <noreply@github.com>
Tue, 4 Feb 2020 03:03:51 +0000 (04:03 +0100)
fastapi/utils.py
tests/test_filter_pydantic_sub_model.py

index 6a0c1bfd7e19b2ec30bb5574b22f226cc9553888..e7d3891f4bb70e93ed233164dada1e9223d7765d 100644 (file)
@@ -93,12 +93,9 @@ def create_cloned_field(field: ModelField) -> ModelField:
     use_type = original_type
     if lenient_issubclass(original_type, BaseModel):
         original_type = cast(Type[BaseModel], original_type)
-        use_type = create_model(
-            original_type.__name__, __config__=original_type.__config__
-        )
+        use_type = create_model(original_type.__name__, __base__=original_type)
         for f in original_type.__fields__.values():
             use_type.__fields__[f.name] = create_cloned_field(f)
-        use_type.__validators__ = original_type.__validators__
     if PYDANTIC_1:
         new_field = ModelField(
             name=field.name,
index aef6350401beb123957fed3b11a2f45fcca1e7e2..1f7d1deed24ea879ee4ebad6c90b2c7317e6021a 100644 (file)
@@ -1,5 +1,6 @@
+import pytest
 from fastapi import Depends, FastAPI
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError, validator
 from starlette.testclient import TestClient
 
 app = FastAPI()
@@ -18,14 +19,20 @@ class ModelA(BaseModel):
     description: str = None
     model_b: ModelB
 
+    @validator("name")
+    def lower_username(cls, name: str, values):
+        if not name.endswith("A"):
+            raise ValueError("name must end in A")
+        return name
+
 
 async def get_model_c() -> ModelC:
     return ModelC(username="test-user", password="test-password")
 
 
-@app.get("/model", response_model=ModelA)
-async def get_model_a(model_c=Depends(get_model_c)):
-    return {"name": "model-a-name", "description": "model-a-desc", "model_b": model_c}
+@app.get("/model/{name}", response_model=ModelA)
+async def get_model_a(name: str, model_c=Depends(get_model_c)):
+    return {"name": name, "description": "model-a-desc", "model_b": model_c}
 
 
 client = TestClient(app)
@@ -35,10 +42,18 @@ openapi_schema = {
     "openapi": "3.0.2",
     "info": {"title": "FastAPI", "version": "0.1.0"},
     "paths": {
-        "/model": {
+        "/model/{name}": {
             "get": {
                 "summary": "Get Model A",
-                "operationId": "get_model_a_model_get",
+                "operationId": "get_model_a_model__name__get",
+                "parameters": [
+                    {
+                        "required": True,
+                        "schema": {"title": "Name", "type": "string"},
+                        "name": "name",
+                        "in": "path",
+                    }
+                ],
                 "responses": {
                     "200": {
                         "description": "Successful Response",
@@ -47,13 +62,34 @@ openapi_schema = {
                                 "schema": {"$ref": "#/components/schemas/ModelA"}
                             }
                         },
-                    }
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
                 },
             }
         }
     },
     "components": {
         "schemas": {
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
             "ModelA": {
                 "title": "ModelA",
                 "required": ["name", "model_b"],
@@ -70,6 +106,20 @@ openapi_schema = {
                 "type": "object",
                 "properties": {"username": {"title": "Username", "type": "string"}},
             },
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
         }
     },
 }
@@ -82,10 +132,22 @@ def test_openapi_schema():
 
 
 def test_filter_sub_model():
-    response = client.get("/model")
+    response = client.get("/model/modelA")
     assert response.status_code == 200
     assert response.json() == {
-        "name": "model-a-name",
+        "name": "modelA",
         "description": "model-a-desc",
         "model_b": {"username": "test-user"},
     }
+
+
+def test_validator_is_cloned():
+    with pytest.raises(ValidationError) as err:
+        client.get("/model/modelX")
+    assert err.value.errors() == [
+        {
+            "loc": ("response", "name"),
+            "msg": "name must end in A",
+            "type": "value_error",
+        }
+    ]