]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
🐛 Fix Enum handling in SQLAlchemy (#165)
authorChris White <chriswhite199@gmail.com>
Sat, 27 Aug 2022 22:48:44 +0000 (15:48 -0700)
committerGitHub <noreply@github.com>
Sat, 27 Aug 2022 22:48:44 +0000 (00:48 +0200)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/main.py
tests/test_enums.py [new file with mode: 0644]

index d85976db474e53b8533c2394b6101e5bec325fc4..86e28b33334e46778f2762b9cd9145b362a6c7d7 100644 (file)
@@ -31,18 +31,9 @@ from pydantic.fields import ModelField, Undefined, UndefinedType
 from pydantic.main import ModelMetaclass, validate_model
 from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations
 from pydantic.utils import ROOT_KEY, Representation
-from sqlalchemy import (
-    Boolean,
-    Column,
-    Date,
-    DateTime,
-    Float,
-    ForeignKey,
-    Integer,
-    Interval,
-    Numeric,
-    inspect,
-)
+from sqlalchemy import Boolean, Column, Date, DateTime
+from sqlalchemy import Enum as sa_Enum
+from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
 from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
 from sqlalchemy.orm.attributes import set_attribute
 from sqlalchemy.orm.decl_api import DeclarativeMeta
@@ -396,7 +387,7 @@ def get_sqlachemy_type(field: ModelField) -> Any:
     if issubclass(field.type_, time):
         return Time
     if issubclass(field.type_, Enum):
-        return Enum
+        return sa_Enum(field.type_)
     if issubclass(field.type_, bytes):
         return LargeBinary
     if issubclass(field.type_, Decimal):
diff --git a/tests/test_enums.py b/tests/test_enums.py
new file mode 100644 (file)
index 0000000..aeec645
--- /dev/null
@@ -0,0 +1,72 @@
+import enum
+import uuid
+
+from sqlalchemy import create_mock_engine
+from sqlalchemy.sql.type_api import TypeEngine
+from sqlmodel import Field, SQLModel
+
+"""
+Tests related to Enums
+
+Associated issues:
+* https://github.com/tiangolo/sqlmodel/issues/96
+* https://github.com/tiangolo/sqlmodel/issues/164
+"""
+
+
+class MyEnum1(enum.Enum):
+    A = "A"
+    B = "B"
+
+
+class MyEnum2(enum.Enum):
+    C = "C"
+    D = "D"
+
+
+class BaseModel(SQLModel):
+    id: uuid.UUID = Field(primary_key=True)
+    enum_field: MyEnum2
+
+
+class FlatModel(SQLModel, table=True):
+    id: uuid.UUID = Field(primary_key=True)
+    enum_field: MyEnum1
+
+
+class InheritModel(BaseModel, table=True):
+    pass
+
+
+def pg_dump(sql: TypeEngine, *args, **kwargs):
+    dialect = sql.compile(dialect=postgres_engine.dialect)
+    sql_str = str(dialect).rstrip()
+    if sql_str:
+        print(sql_str + ";")
+
+
+def sqlite_dump(sql: TypeEngine, *args, **kwargs):
+    dialect = sql.compile(dialect=sqlite_engine.dialect)
+    sql_str = str(dialect).rstrip()
+    if sql_str:
+        print(sql_str + ";")
+
+
+postgres_engine = create_mock_engine("postgresql://", pg_dump)
+sqlite_engine = create_mock_engine("sqlite://", sqlite_dump)
+
+
+def test_postgres_ddl_sql(capsys):
+    SQLModel.metadata.create_all(bind=postgres_engine, checkfirst=False)
+
+    captured = capsys.readouterr()
+    assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out
+    assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out
+
+
+def test_sqlite_ddl_sql(capsys):
+    SQLModel.metadata.create_all(bind=sqlite_engine, checkfirst=False)
+
+    captured = capsys.readouterr()
+    assert "enum_field VARCHAR(1) NOT NULL" in captured.out
+    assert "CREATE TYPE" not in captured.out