]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add support and tests for Pydantic dataclasses in response_model (#454)
authordconathan <dconathan@gmail.com>
Fri, 30 Aug 2019 23:12:15 +0000 (18:12 -0500)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 30 Aug 2019 23:12:15 +0000 (18:12 -0500)
fastapi/utils.py
tests/test_serialize_response.py
tests/test_serialize_response_dataclass.py [new file with mode: 0644]
tests/test_validate_response.py [new file with mode: 0644]
tests/test_validate_response_dataclass.py [new file with mode: 0644]

index de0260615ee009ded042b288441140d887c68da0..17a16b52276872ba6e8d80e74ca65107e44575bd 100644 (file)
@@ -1,4 +1,5 @@
 import re
+from dataclasses import is_dataclass
 from typing import Any, Dict, List, Sequence, Set, Type, cast
 
 from fastapi import routing
@@ -52,6 +53,8 @@ def get_path_param_names(path: str) -> Set[str]:
 
 def create_cloned_field(field: Field) -> Field:
     original_type = field.type_
+    if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
+        original_type = original_type.__pydantic_model__  # type: ignore
     use_type = original_type
     if lenient_issubclass(original_type, BaseModel):
         original_type = cast(Type[BaseModel], original_type)
index c0382b899438a8298fe32651662f6b06fa7e69bb..5fff871f0f0ca0fd4e42b17462f2de0b315f30fe 100644 (file)
@@ -1,8 +1,7 @@
 from typing import List
 
-import pytest
 from fastapi import FastAPI
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel
 from starlette.testclient import TestClient
 
 app = FastAPI()
@@ -14,38 +13,45 @@ class Item(BaseModel):
     owner_ids: List[int] = None
 
 
-@app.get("/items/invalid", response_model=Item)
-def get_invalid():
-    return {"name": "invalid", "price": "foo"}
+@app.get("/items/valid", response_model=Item)
+def get_valid():
+    return {"name": "valid", "price": 1.0}
 
 
-@app.get("/items/innerinvalid", response_model=Item)
-def get_innerinvalid():
-    return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
+@app.get("/items/coerce", response_model=Item)
+def get_coerce():
+    return {"name": "coerce", "price": "1.0"}
 
 
-@app.get("/items/invalidlist", response_model=List[Item])
-def get_invalidlist():
+@app.get("/items/validlist", response_model=List[Item])
+def get_validlist():
     return [
         {"name": "foo"},
-        {"name": "bar", "price": "bar"},
-        {"name": "baz", "price": "baz"},
+        {"name": "bar", "price": 1.0},
+        {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
     ]
 
 
 client = TestClient(app)
 
 
-def test_invalid():
-    with pytest.raises(ValidationError):
-        client.get("/items/invalid")
+def test_valid():
+    response = client.get("/items/valid")
+    response.raise_for_status()
+    assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
 
 
-def test_double_invalid():
-    with pytest.raises(ValidationError):
-        client.get("/items/innerinvalid")
+def test_coerce():
+    response = client.get("/items/coerce")
+    response.raise_for_status()
+    assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
 
 
-def test_invalid_list():
-    with pytest.raises(ValidationError):
-        client.get("/items/invalidlist")
+def test_validlist():
+    response = client.get("/items/validlist")
+    response.raise_for_status()
+    assert response.json() == [
+        {"name": "foo", "price": None, "owner_ids": None},
+        {"name": "bar", "price": 1.0, "owner_ids": None},
+        {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    ]
diff --git a/tests/test_serialize_response_dataclass.py b/tests/test_serialize_response_dataclass.py
new file mode 100644 (file)
index 0000000..ee701f9
--- /dev/null
@@ -0,0 +1,58 @@
+from typing import List
+
+from fastapi import FastAPI
+from pydantic.dataclasses import dataclass
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+@dataclass
+class Item:
+    name: str
+    price: float = None
+    owner_ids: List[int] = None
+
+
+@app.get("/items/valid", response_model=Item)
+def get_valid():
+    return {"name": "valid", "price": 1.0}
+
+
+@app.get("/items/coerce", response_model=Item)
+def get_coerce():
+    return {"name": "coerce", "price": "1.0"}
+
+
+@app.get("/items/validlist", response_model=List[Item])
+def get_validlist():
+    return [
+        {"name": "foo"},
+        {"name": "bar", "price": 1.0},
+        {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    ]
+
+
+client = TestClient(app)
+
+
+def test_valid():
+    response = client.get("/items/valid")
+    response.raise_for_status()
+    assert response.json() == {"name": "valid", "price": 1.0, "owner_ids": None}
+
+
+def test_coerce():
+    response = client.get("/items/coerce")
+    response.raise_for_status()
+    assert response.json() == {"name": "coerce", "price": 1.0, "owner_ids": None}
+
+
+def test_validlist():
+    response = client.get("/items/validlist")
+    response.raise_for_status()
+    assert response.json() == [
+        {"name": "foo", "price": None, "owner_ids": None},
+        {"name": "bar", "price": 1.0, "owner_ids": None},
+        {"name": "baz", "price": 2.0, "owner_ids": [1, 2, 3]},
+    ]
diff --git a/tests/test_validate_response.py b/tests/test_validate_response.py
new file mode 100644 (file)
index 0000000..c0382b8
--- /dev/null
@@ -0,0 +1,51 @@
+from typing import List
+
+import pytest
+from fastapi import FastAPI
+from pydantic import BaseModel, ValidationError
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class Item(BaseModel):
+    name: str
+    price: float = None
+    owner_ids: List[int] = None
+
+
+@app.get("/items/invalid", response_model=Item)
+def get_invalid():
+    return {"name": "invalid", "price": "foo"}
+
+
+@app.get("/items/innerinvalid", response_model=Item)
+def get_innerinvalid():
+    return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
+
+
+@app.get("/items/invalidlist", response_model=List[Item])
+def get_invalidlist():
+    return [
+        {"name": "foo"},
+        {"name": "bar", "price": "bar"},
+        {"name": "baz", "price": "baz"},
+    ]
+
+
+client = TestClient(app)
+
+
+def test_invalid():
+    with pytest.raises(ValidationError):
+        client.get("/items/invalid")
+
+
+def test_double_invalid():
+    with pytest.raises(ValidationError):
+        client.get("/items/innerinvalid")
+
+
+def test_invalid_list():
+    with pytest.raises(ValidationError):
+        client.get("/items/invalidlist")
diff --git a/tests/test_validate_response_dataclass.py b/tests/test_validate_response_dataclass.py
new file mode 100644 (file)
index 0000000..4a06641
--- /dev/null
@@ -0,0 +1,53 @@
+from typing import List
+
+import pytest
+from fastapi import FastAPI
+from pydantic import ValidationError
+from pydantic.dataclasses import dataclass
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+@dataclass
+class Item:
+    name: str
+    price: float = None
+    owner_ids: List[int] = None
+
+
+@app.get("/items/invalid", response_model=Item)
+def get_invalid():
+    return {"name": "invalid", "price": "foo"}
+
+
+@app.get("/items/innerinvalid", response_model=Item)
+def get_innerinvalid():
+    return {"name": "double invalid", "price": "foo", "owner_ids": ["foo", "bar"]}
+
+
+@app.get("/items/invalidlist", response_model=List[Item])
+def get_invalidlist():
+    return [
+        {"name": "foo"},
+        {"name": "bar", "price": "bar"},
+        {"name": "baz", "price": "baz"},
+    ]
+
+
+client = TestClient(app)
+
+
+def test_invalid():
+    with pytest.raises(ValidationError):
+        client.get("/items/invalid")
+
+
+def test_double_invalid():
+    with pytest.raises(ValidationError):
+        client.get("/items/innerinvalid")
+
+
+def test_invalid_list():
+    with pytest.raises(ValidationError):
+        client.get("/items/invalidlist")