]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add support for subtypes of main types in jsonable_encoder
authorRoald Storm <RmStorm@users.noreply.github.com>
Wed, 8 Jan 2020 20:50:21 +0000 (21:50 +0100)
committerSebastián Ramírez <tiangolo@gmail.com>
Wed, 8 Jan 2020 20:50:21 +0000 (21:50 +0100)
fastapi/encoders.py
tests/test_inherited_custom_class.py [new file with mode: 0644]

index d75664debebe61907155abc178156ea165d591f5..cb94a91a73d8ed8300209824fac88d83b1520433 100644 (file)
@@ -1,6 +1,6 @@
 from enum import Enum
 from types import GeneratorType
-from typing import Any, Dict, List, Set, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Union
 
 from fastapi.logger import logger
 from fastapi.utils import PYDANTIC_1
@@ -11,6 +11,21 @@ SetIntStr = Set[Union[int, str]]
 DictIntStrAny = Dict[Union[int, str], Any]
 
 
+def generate_encoders_by_class_tuples(
+    type_encoder_map: Dict[Any, Callable]
+) -> Dict[Callable, Tuple]:
+    encoders_by_classes: Dict[Callable, List] = {}
+    for type_, encoder in type_encoder_map.items():
+        encoders_by_classes.setdefault(encoder, []).append(type_)
+    encoders_by_class_tuples: Dict[Callable, Tuple] = {}
+    for encoder, classes in encoders_by_classes.items():
+        encoders_by_class_tuples[encoder] = tuple(classes)
+    return encoders_by_class_tuples
+
+
+encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
+
+
 def jsonable_encoder(
     obj: Any,
     include: Union[SetIntStr, DictIntStrAny] = None,
@@ -106,24 +121,31 @@ def jsonable_encoder(
                 )
             )
         return encoded_list
+
+    if custom_encoder:
+        if type(obj) in custom_encoder:
+            return custom_encoder[type(obj)](obj)
+        else:
+            for encoder_type, encoder in custom_encoder.items():
+                if isinstance(obj, encoder_type):
+                    return encoder(obj)
+
+    if type(obj) in ENCODERS_BY_TYPE:
+        return ENCODERS_BY_TYPE[type(obj)](obj)
+    for encoder, classes_tuple in encoders_by_class_tuples.items():
+        if isinstance(obj, classes_tuple):
+            return encoder(obj)
+
     errors: List[Exception] = []
     try:
-        if custom_encoder and type(obj) in custom_encoder:
-            encoder = custom_encoder[type(obj)]
-        else:
-            encoder = ENCODERS_BY_TYPE[type(obj)]
-        return encoder(obj)
-    except KeyError as e:
+        data = dict(obj)
+    except Exception as e:
         errors.append(e)
         try:
-            data = dict(obj)
+            data = vars(obj)
         except Exception as e:
             errors.append(e)
-            try:
-                data = vars(obj)
-            except Exception as e:
-                errors.append(e)
-                raise ValueError(errors)
+            raise ValueError(errors)
     return jsonable_encoder(
         data,
         by_alias=by_alias,
diff --git a/tests/test_inherited_custom_class.py b/tests/test_inherited_custom_class.py
new file mode 100644 (file)
index 0000000..a9f6738
--- /dev/null
@@ -0,0 +1,73 @@
+import uuid
+
+import pytest
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class MyUuid:
+    def __init__(self, uuid_string: str):
+        self.uuid = uuid_string
+
+    def __str__(self):
+        return self.uuid
+
+    @property
+    def __class__(self):
+        return uuid.UUID
+
+    @property
+    def __dict__(self):
+        """Spoof a missing __dict__ by raising TypeError, this is how
+        asyncpg.pgroto.pgproto.UUID behaves"""
+        raise TypeError("vars() argument must have __dict__ attribute")
+
+
+@app.get("/fast_uuid")
+def return_fast_uuid():
+    # I don't want to import asyncpg for this test so I made my own UUID
+    # Import asyncpg and uncomment the two lines below for the actual bug
+
+    # from asyncpg.pgproto import pgproto
+    # asyncpg_uuid = pgproto.UUID("a10ff360-3b1e-4984-a26f-d3ab460bdb51")
+
+    asyncpg_uuid = MyUuid("a10ff360-3b1e-4984-a26f-d3ab460bdb51")
+    assert isinstance(asyncpg_uuid, uuid.UUID)
+    assert type(asyncpg_uuid) != uuid.UUID
+    with pytest.raises(TypeError):
+        vars(asyncpg_uuid)
+    return {"fast_uuid": asyncpg_uuid}
+
+
+class SomeCustomClass(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+        json_encoders = {uuid.UUID: str}
+
+    a_uuid: MyUuid
+
+
+@app.get("/get_custom_class")
+def return_some_user():
+    # Test that the fix also works for custom pydantic classes
+    return SomeCustomClass(a_uuid=MyUuid("b8799909-f914-42de-91bc-95c819218d01"))
+
+
+client = TestClient(app)
+
+
+def test_dt():
+    with client:
+        response_simple = client.get("/fast_uuid")
+        response_pydantic = client.get("/get_custom_class")
+
+    assert response_simple.json() == {
+        "fast_uuid": "a10ff360-3b1e-4984-a26f-d3ab460bdb51"
+    }
+
+    assert response_pydantic.json() == {
+        "a_uuid": "b8799909-f914-42de-91bc-95c819218d01"
+    }