]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix include/exclude for dicts in `jsonable_encoder` (#2016)
authorRubikoid <dimazotoff8@gmail.com>
Sun, 4 Jul 2021 18:53:40 +0000 (21:53 +0300)
committerGitHub <noreply@github.com>
Sun, 4 Jul 2021 18:53:40 +0000 (20:53 +0200)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/encoders.py
tests/test_response_model_include_exclude.py [new file with mode: 0644]

index 6a2a75dda629013e2f13ccd0a8b38e1dced77ea7..51cab419d4944b74f8b1ad602b799dcc19720201 100644 (file)
@@ -36,9 +36,9 @@ def jsonable_encoder(
     custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
     sqlalchemy_safe: bool = True,
 ) -> Any:
-    if include is not None and not isinstance(include, set):
+    if include is not None and not isinstance(include, (set, dict)):
         include = set(include)
-    if exclude is not None and not isinstance(exclude, set):
+    if exclude is not None and not isinstance(exclude, (set, dict)):
         exclude = set(exclude)
     if isinstance(obj, BaseModel):
         encoder = getattr(obj.__config__, "json_encoders", {})
diff --git a/tests/test_response_model_include_exclude.py b/tests/test_response_model_include_exclude.py
new file mode 100644 (file)
index 0000000..533f810
--- /dev/null
@@ -0,0 +1,174 @@
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+
+
+class Test(BaseModel):
+    foo: str
+    bar: str
+
+
+class Test2(BaseModel):
+    test: Test
+    baz: str
+
+
+class Test3(BaseModel):
+    name: str
+    age: int
+    test2: Test2
+
+
+app = FastAPI()
+
+
+@app.get(
+    "/simple_include",
+    response_model=Test2,
+    response_model_include={"baz": ..., "test": {"foo"}},
+)
+def simple_include():
+    return Test2(
+        test=Test(foo="simple_include test foo", bar="simple_include test bar"),
+        baz="simple_include test2 baz",
+    )
+
+
+@app.get(
+    "/simple_include_dict",
+    response_model=Test2,
+    response_model_include={"baz": ..., "test": {"foo"}},
+)
+def simple_include_dict():
+    return {
+        "test": {
+            "foo": "simple_include_dict test foo",
+            "bar": "simple_include_dict test bar",
+        },
+        "baz": "simple_include_dict test2 baz",
+    }
+
+
+@app.get(
+    "/simple_exclude",
+    response_model=Test2,
+    response_model_exclude={"test": {"bar"}},
+)
+def simple_exclude():
+    return Test2(
+        test=Test(foo="simple_exclude test foo", bar="simple_exclude test bar"),
+        baz="simple_exclude test2 baz",
+    )
+
+
+@app.get(
+    "/simple_exclude_dict",
+    response_model=Test2,
+    response_model_exclude={"test": {"bar"}},
+)
+def simple_exclude_dict():
+    return {
+        "test": {
+            "foo": "simple_exclude_dict test foo",
+            "bar": "simple_exclude_dict test bar",
+        },
+        "baz": "simple_exclude_dict test2 baz",
+    }
+
+
+@app.get(
+    "/mixed",
+    response_model=Test3,
+    response_model_include={"test2", "name"},
+    response_model_exclude={"test2": {"baz"}},
+)
+def mixed():
+    return Test3(
+        name="mixed test3 name",
+        age=3,
+        test2=Test2(
+            test=Test(foo="mixed test foo", bar="mixed test bar"), baz="mixed test2 baz"
+        ),
+    )
+
+
+@app.get(
+    "/mixed_dict",
+    response_model=Test3,
+    response_model_include={"test2", "name"},
+    response_model_exclude={"test2": {"baz"}},
+)
+def mixed_dict():
+    return {
+        "name": "mixed_dict test3 name",
+        "age": 3,
+        "test2": {
+            "test": {"foo": "mixed_dict test foo", "bar": "mixed_dict test bar"},
+            "baz": "mixed_dict test2 baz",
+        },
+    }
+
+
+client = TestClient(app)
+
+
+def test_nested_include_simple():
+    response = client.get("/simple_include")
+
+    assert response.status_code == 200, response.text
+
+    assert response.json() == {
+        "baz": "simple_include test2 baz",
+        "test": {"foo": "simple_include test foo"},
+    }
+
+
+def test_nested_include_simple_dict():
+    response = client.get("/simple_include_dict")
+
+    assert response.status_code == 200, response.text
+
+    assert response.json() == {
+        "baz": "simple_include_dict test2 baz",
+        "test": {"foo": "simple_include_dict test foo"},
+    }
+
+
+def test_nested_exclude_simple():
+    response = client.get("/simple_exclude")
+    assert response.status_code == 200, response.text
+    assert response.json() == {
+        "baz": "simple_exclude test2 baz",
+        "test": {"foo": "simple_exclude test foo"},
+    }
+
+
+def test_nested_exclude_simple_dict():
+    response = client.get("/simple_exclude_dict")
+    assert response.status_code == 200, response.text
+    assert response.json() == {
+        "baz": "simple_exclude_dict test2 baz",
+        "test": {"foo": "simple_exclude_dict test foo"},
+    }
+
+
+def test_nested_include_mixed():
+    response = client.get("/mixed")
+    assert response.status_code == 200, response.text
+    assert response.json() == {
+        "name": "mixed test3 name",
+        "test2": {
+            "test": {"foo": "mixed test foo", "bar": "mixed test bar"},
+        },
+    }
+
+
+def test_nested_include_mixed_dict():
+    response = client.get("/mixed_dict")
+    assert response.status_code == 200, response.text
+    assert response.json() == {
+        "name": "mixed_dict test3 name",
+        "test2": {
+            "test": {"foo": "mixed_dict test foo", "bar": "mixed_dict test bar"},
+        },
+    }