]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:recycle: Refactor jsonable_encoder and test it
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 22 Dec 2018 13:15:04 +0000 (17:15 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Sat, 22 Dec 2018 13:15:04 +0000 (17:15 +0400)
with nested arbitrary classes

fastapi/encoders.py
tests/test_jsonable_encoder.py [new file with mode: 0644]

index 25ef6dc124573c222cf6ab8a1abc046bafe90b0d..6b6e5d0f7f278a847e56e90e64d50e6769ef6d61 100644 (file)
@@ -3,7 +3,7 @@ from types import GeneratorType
 from typing import Any, Set
 
 from pydantic import BaseModel
-from pydantic.json import pydantic_encoder
+from pydantic.json import ENCODERS_BY_TYPE
 
 
 def jsonable_encoder(
@@ -12,64 +12,11 @@ def jsonable_encoder(
     exclude: Set[str] = set(),
     by_alias: bool = False,
     include_none: bool = True,
-    root_encoder: bool = True,
-) -> Any:
-    errors = []
-    try:
-        return known_data_encoder(
-            obj,
-            include=include,
-            exclude=exclude,
-            by_alias=by_alias,
-            include_none=include_none,
-        )
-    except Exception as e:
-        if not root_encoder:
-            raise e
-        errors.append(e)
-    try:
-        data = dict(obj)
-        return jsonable_encoder(
-            data,
-            include=include,
-            exclude=exclude,
-            by_alias=by_alias,
-            include_none=include_none,
-            root_encoder=False,
-        )
-    except Exception as e:
-        if not root_encoder:
-            raise e
-        errors.append(e)
-    try:
-        data = vars(obj)
-        return jsonable_encoder(
-            data,
-            include=include,
-            exclude=exclude,
-            by_alias=by_alias,
-            include_none=include_none,
-            root_encoder=False,
-        )
-    except Exception as e:
-        if not root_encoder:
-            raise e
-        errors.append(e)
-        raise ValueError(errors)
-
-
-def known_data_encoder(
-    obj: Any,
-    include: Set[str] = None,
-    exclude: Set[str] = set(),
-    by_alias: bool = False,
-    include_none: bool = True,
 ) -> Any:
     if isinstance(obj, BaseModel):
         return jsonable_encoder(
             obj.dict(include=include, exclude=exclude, by_alias=by_alias),
             include_none=include_none,
-            root_encoder=False,
         )
     if isinstance(obj, Enum):
         return obj.value
@@ -78,10 +25,8 @@ def known_data_encoder(
     if isinstance(obj, dict):
         return {
             jsonable_encoder(
-                key, by_alias=by_alias, include_none=include_none, root_encoder=False
-            ): jsonable_encoder(
-                value, by_alias=by_alias, include_none=include_none, root_encoder=False
-            )
+                key, by_alias=by_alias, include_none=include_none
+            ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none)
             for key, value in obj.items()
             if value is not None or include_none
         }
@@ -93,8 +38,22 @@ def known_data_encoder(
                 exclude=exclude,
                 by_alias=by_alias,
                 include_none=include_none,
-                root_encoder=False,
             )
             for item in obj
         ]
-    return pydantic_encoder(obj)
+    errors = []
+    try:
+        encoder = ENCODERS_BY_TYPE[type(obj)]
+        return encoder(obj)
+    except KeyError as e:
+        errors.append(e)
+        try:
+            data = dict(obj)
+        except Exception as e:
+            errors.append(e)
+            try:
+                data = vars(obj)
+            except Exception as e:
+                errors.append(e)
+                raise ValueError(errors)
+    return jsonable_encoder(data, by_alias=by_alias, include_none=include_none)
diff --git a/tests/test_jsonable_encoder.py b/tests/test_jsonable_encoder.py
new file mode 100644 (file)
index 0000000..9108df9
--- /dev/null
@@ -0,0 +1,50 @@
+import pytest
+from fastapi.encoders import jsonable_encoder
+
+
+class Person:
+    def __init__(self, name: str):
+        self.name = name
+
+
+class Pet:
+    def __init__(self, owner: Person, name: str):
+        self.owner = owner
+        self.name = name
+
+
+class DictablePerson(Person):
+    def __iter__(self):
+        return ((k, v) for k, v in self.__dict__.items())
+
+
+class DictablePet(Pet):
+    def __iter__(self):
+        return ((k, v) for k, v in self.__dict__.items())
+
+
+class Unserializable:
+    def __iter__(self):
+        raise NotImplementedError()
+
+    @property
+    def __dict__(self):
+        raise NotImplementedError()
+
+
+def test_encode_class():
+    person = Person(name="Foo")
+    pet = Pet(owner=person, name="Firulais")
+    assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}}
+
+
+def test_encode_dictable():
+    person = DictablePerson(name="Foo")
+    pet = DictablePet(owner=person, name="Firulais")
+    assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}}
+
+
+def test_encode_unsupported():
+    unserializable = Unserializable()
+    with pytest.raises(ValueError):
+        jsonable_encoder(unserializable)