exclude: Set[str] = set(),
by_alias: bool = False,
include_none: bool = True,
+ custom_encoder: dict = {},
) -> Any:
if isinstance(obj, BaseModel):
- return jsonable_encoder(
- obj.dict(include=include, exclude=exclude, by_alias=by_alias),
- include_none=include_none,
- )
+ if not obj.Config.json_encoders:
+ return jsonable_encoder(
+ obj.dict(include=include, exclude=exclude, by_alias=by_alias),
+ include_none=include_none,
+ )
+ else:
+ return jsonable_encoder(
+ obj.dict(include=include, exclude=exclude, by_alias=by_alias),
+ include_none=include_none,
+ custom_encoder=obj.Config.json_encoders,
+ )
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, (str, int, float, type(None))):
if isinstance(obj, dict):
return {
jsonable_encoder(
- key, by_alias=by_alias, include_none=include_none
- ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none)
+ key,
+ by_alias=by_alias,
+ include_none=include_none,
+ custom_encoder=custom_encoder,
+ ): jsonable_encoder(
+ value,
+ by_alias=by_alias,
+ include_none=include_none,
+ custom_encoder=custom_encoder,
+ )
for key, value in obj.items()
if value is not None or include_none
}
exclude=exclude,
by_alias=by_alias,
include_none=include_none,
+ custom_encoder=custom_encoder,
)
for item in obj
]
errors = []
try:
- encoder = ENCODERS_BY_TYPE[type(obj)]
+ if custom_encoder and type(obj) in custom_encoder:
+ encoder = custom_encoder[type(obj)]
+ else:
+ encoder = ENCODERS_BY_TYPE[type(obj)]
return encoder(obj)
except KeyError as e:
errors.append(e)
--- /dev/null
+import json
+from datetime import datetime, timezone
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.testclient import TestClient
+
+
+class ModelWithDatetimeField(BaseModel):
+ dt_field: datetime
+
+ class Config:
+ json_encoders = {
+ datetime: lambda dt: dt.replace(
+ microsecond=0, tzinfo=timezone.utc
+ ).isoformat()
+ }
+
+
+app = FastAPI()
+model = ModelWithDatetimeField(dt_field=datetime.utcnow())
+
+
+@app.get("/model", response_model=ModelWithDatetimeField)
+def get_model():
+ return model
+
+
+client = TestClient(app)
+
+
+def test_dt():
+ with client:
+ response = client.get("/model")
+ assert json.loads(model.json()) == response.json()