]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Prefer custom encoder over defaults if specified in `jsonable_encoder` (#4467)
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 23 Jan 2022 16:32:04 +0000 (17:32 +0100)
committerGitHub <noreply@github.com>
Sun, 23 Jan 2022 16:32:04 +0000 (17:32 +0100)
Co-authored-by: Vivek Sunder <sviveksunder@gmail.com>
fastapi/encoders.py
tests/test_jsonable_encoder.py

index 3f599c9faa04533ddf29af92a3de9da7caf2ac7f..4b7ffe313fa6bf841c0ca25de577cb0716607a6a 100644 (file)
@@ -34,9 +34,17 @@ def jsonable_encoder(
     exclude_unset: bool = False,
     exclude_defaults: bool = False,
     exclude_none: bool = False,
-    custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
+    custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None,
     sqlalchemy_safe: bool = True,
 ) -> Any:
+    custom_encoder = custom_encoder or {}
+    if custom_encoder:
+        if type(obj) in custom_encoder:
+            return custom_encoder[type(obj)](obj)
+        else:
+            for encoder_type, encoder_instance in custom_encoder.items():
+                if isinstance(obj, encoder_type):
+                    return encoder_instance(obj)
     if include is not None and not isinstance(include, (set, dict)):
         include = set(include)
     if exclude is not None and not isinstance(exclude, (set, dict)):
@@ -118,14 +126,6 @@ def jsonable_encoder(
             )
         return encoded_list
 
-    if custom_encoder:
-        if type(obj) in custom_encoder:
-            return custom_encoder[type(obj)](obj)
-        else:
-            for encoder_type, encoder in custom_encoder.items():
-                if isinstance(obj, encoder_type):
-                    return encoder(obj)
-
     if type(obj) in ENCODERS_BY_TYPE:
         return ENCODERS_BY_TYPE[type(obj)](obj)
     for encoder, classes_tuple in encoders_by_class_tuples.items():
index e2aa8adf8448a6d5f331c027e4afd278d912b7c6..fa82b5ea8358575e27e7909cdacdf1669867a00e 100644 (file)
@@ -161,6 +161,21 @@ def test_custom_encoders():
     assert encoded_instance["dt_field"] == instance.dt_field.isoformat()
 
 
+def test_custom_enum_encoders():
+    def custom_enum_encoder(v: Enum):
+        return v.value.lower()
+
+    class MyEnum(Enum):
+        ENUM_VAL_1 = "ENUM_VAL_1"
+
+    instance = MyEnum.ENUM_VAL_1
+
+    encoded_instance = jsonable_encoder(
+        instance, custom_encoder={MyEnum: custom_enum_encoder}
+    )
+    assert encoded_instance == custom_enum_encoder(instance)
+
+
 def test_encode_model_with_path(model_with_path):
     if isinstance(model_with_path.path, PureWindowsPath):
         expected = "\\foo\\bar"