From d3261cab591051759ad966eac8f61e6ee39ebaa1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 23 Oct 2023 13:22:44 +0400 Subject: [PATCH] =?utf8?q?=F0=9F=90=9B=20Fix=20enum=20type=20checks=20orde?= =?utf8?q?ring=20in=20`get=5Fsqlalchemy=5Ftype`=20(#669)?= MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Co-authored-by: Pierre Cheynier --- sqlmodel/main.py | 5 +++-- tests/test_enums.py | 44 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d5a73024..a32be42c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -384,6 +384,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: ModelField) -> Any: if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI + if issubclass(field.type_, Enum): + return sa_Enum(field.type_) if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) @@ -402,8 +405,6 @@ def get_sqlalchemy_type(field: ModelField) -> Any: return Interval if issubclass(field.type_, time): return Time - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) if issubclass(field.type_, bytes): return LargeBinary if issubclass(field.type_, Decimal): diff --git a/tests/test_enums.py b/tests/test_enums.py index aeec6456..194bdefe 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -14,12 +14,12 @@ Associated issues: """ -class MyEnum1(enum.Enum): +class MyEnum1(str, enum.Enum): A = "A" B = "B" -class MyEnum2(enum.Enum): +class MyEnum2(str, enum.Enum): C = "C" D = "D" @@ -70,3 +70,43 @@ def test_sqlite_ddl_sql(capsys): captured = capsys.readouterr() assert "enum_field VARCHAR(1) NOT NULL" in captured.out assert "CREATE TYPE" not in captured.out + + +def test_json_schema_flat_model(): + assert FlatModel.schema() == { + "title": "FlatModel", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "string", "format": "uuid"}, + "enum_field": {"$ref": "#/definitions/MyEnum1"}, + }, + "required": ["id", "enum_field"], + "definitions": { + "MyEnum1": { + "title": "MyEnum1", + "description": "An enumeration.", + "enum": ["A", "B"], + "type": "string", + } + }, + } + + +def test_json_schema_inherit_model(): + assert InheritModel.schema() == { + "title": "InheritModel", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "string", "format": "uuid"}, + "enum_field": {"$ref": "#/definitions/MyEnum2"}, + }, + "required": ["id", "enum_field"], + "definitions": { + "MyEnum2": { + "title": "MyEnum2", + "description": "An enumeration.", + "enum": ["C", "D"], + "type": "string", + } + }, + } -- 2.47.2