]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Fix exclude_unset and aliases in response model validation (#1074)
authorjuhovh-aiven <juhovh@aiven.io>
Fri, 27 Mar 2020 15:19:17 +0000 (02:19 +1100)
committerGitHub <noreply@github.com>
Fri, 27 Mar 2020 15:19:17 +0000 (16:19 +0100)
* Fix exclude_unset and aliases in response model validation.

* :sparkles: Use by_alias from param :shrug:

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/routing.py
tests/test_serialize_response_model.py [new file with mode: 0644]

index b36104869648a53520304f9d6cfdc0b3aa5531da..b90935e15ffe76ce3903f4b4bd66e842ab7a2d4c 100644 (file)
@@ -48,6 +48,28 @@ except ImportError:  # pragma: nocover
     from pydantic.fields import Field as ModelField  # type: ignore
 
 
+def _prepare_response_content(
+    res: Any, *, by_alias: bool = True, exclude_unset: bool
+) -> Any:
+    if isinstance(res, BaseModel):
+        if PYDANTIC_1:
+            return res.dict(by_alias=by_alias, exclude_unset=exclude_unset)
+        else:
+            return res.dict(
+                by_alias=by_alias, skip_defaults=exclude_unset
+            )  # pragma: nocover
+    elif isinstance(res, list):
+        return [
+            _prepare_response_content(item, exclude_unset=exclude_unset) for item in res
+        ]
+    elif isinstance(res, dict):
+        return {
+            k: _prepare_response_content(v, exclude_unset=exclude_unset)
+            for k, v in res.items()
+        }
+    return res
+
+
 async def serialize_response(
     *,
     field: ModelField = None,
@@ -60,13 +82,9 @@ async def serialize_response(
 ) -> Any:
     if field:
         errors = []
-        if exclude_unset and isinstance(response_content, BaseModel):
-            if PYDANTIC_1:
-                response_content = response_content.dict(exclude_unset=exclude_unset)
-            else:
-                response_content = response_content.dict(
-                    skip_defaults=exclude_unset
-                )  # pragma: nocover
+        response_content = _prepare_response_content(
+            response_content, by_alias=by_alias, exclude_unset=exclude_unset
+        )
         if is_coroutine:
             value, errors_ = field.validate(response_content, {}, loc=("response",))
         else:
diff --git a/tests/test_serialize_response_model.py b/tests/test_serialize_response_model.py
new file mode 100644 (file)
index 0000000..adb7fda
--- /dev/null
@@ -0,0 +1,154 @@
+from typing import Dict, List
+
+from fastapi import FastAPI
+from pydantic import BaseModel, Field
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class Item(BaseModel):
+    name: str = Field(..., alias="aliased_name")
+    price: float = None
+    owner_ids: List[int] = None
+
+
+@app.get("/items/valid", response_model=Item)
+def get_valid():
+    return Item(aliased_name="valid", price=1.0)
+
+
+@app.get("/items/coerce", response_model=Item)
+def get_coerce():
+    return Item(aliased_name="coerce", price="1.0")
+
+
+@app.get("/items/validlist", response_model=List[Item])
+def get_validlist():
+    return [
+        Item(aliased_name="foo"),
+        Item(aliased_name="bar", price=1.0),
+        Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]),
+    ]
+
+
+@app.get("/items/validdict", response_model=Dict[str, Item])
+def get_validdict():
+    return {
+        "k1": Item(aliased_name="foo"),
+        "k2": Item(aliased_name="bar", price=1.0),
+        "k3": Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]),
+    }
+
+
+@app.get(
+    "/items/valid-exclude-unset", response_model=Item, response_model_exclude_unset=True
+)
+def get_valid_exclude_unset():
+    return Item(aliased_name="valid", price=1.0)
+
+
+@app.get(
+    "/items/coerce-exclude-unset",
+    response_model=Item,
+    response_model_exclude_unset=True,
+)
+def get_coerce_exclude_unset():
+    return Item(aliased_name="coerce", price="1.0")
+
+
+@app.get(
+    "/items/validlist-exclude-unset",
+    response_model=List[Item],
+    response_model_exclude_unset=True,
+)
+def get_validlist_exclude_unset():
+    return [
+        Item(aliased_name="foo"),
+        Item(aliased_name="bar", price=1.0),
+        Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]),
+    ]
+
+
+@app.get(
+    "/items/validdict-exclude-unset",
+    response_model=Dict[str, Item],
+    response_model_exclude_unset=True,
+)
+def get_validdict_exclude_unset():
+    return {
+        "k1": Item(aliased_name="foo"),
+        "k2": Item(aliased_name="bar", price=1.0),
+        "k3": Item(aliased_name="baz", price=2.0, owner_ids=[1, 2, 3]),
+    }
+
+
+client = TestClient(app)
+
+
+def test_valid():
+    response = client.get("/items/valid")
+    response.raise_for_status()
+    assert response.json() == {"aliased_name": "valid", "price": 1.0, "owner_ids": None}
+
+
+def test_coerce():
+    response = client.get("/items/coerce")
+    response.raise_for_status()
+    assert response.json() == {
+        "aliased_name": "coerce",
+        "price": 1.0,
+        "owner_ids": None,
+    }
+
+
+def test_validlist():
+    response = client.get("/items/validlist")
+    response.raise_for_status()
+    assert response.json() == [
+        {"aliased_name": "foo", "price": None, "owner_ids": None},
+        {"aliased_name": "bar", "price": 1.0, "owner_ids": None},
+        {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    ]
+
+
+def test_validdict():
+    response = client.get("/items/validdict")
+    response.raise_for_status()
+    assert response.json() == {
+        "k1": {"aliased_name": "foo", "price": None, "owner_ids": None},
+        "k2": {"aliased_name": "bar", "price": 1.0, "owner_ids": None},
+        "k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    }
+
+
+def test_valid_exclude_unset():
+    response = client.get("/items/valid-exclude-unset")
+    response.raise_for_status()
+    assert response.json() == {"aliased_name": "valid", "price": 1.0}
+
+
+def test_coerce_exclude_unset():
+    response = client.get("/items/coerce-exclude-unset")
+    response.raise_for_status()
+    assert response.json() == {"aliased_name": "coerce", "price": 1.0}
+
+
+def test_validlist_exclude_unset():
+    response = client.get("/items/validlist-exclude-unset")
+    response.raise_for_status()
+    assert response.json() == [
+        {"aliased_name": "foo"},
+        {"aliased_name": "bar", "price": 1.0},
+        {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    ]
+
+
+def test_validdict_exclude_unset():
+    response = client.get("/items/validdict-exclude-unset")
+    response.raise_for_status()
+    assert response.json() == {
+        "k1": {"aliased_name": "foo"},
+        "k2": {"aliased_name": "bar", "price": 1.0},
+        "k3": {"aliased_name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    }