]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Add support for passing a custom SQLAlchemy type to `Field()` with `sa_type` (...
authorMaruo.S <raspi-maru2004@outlook.jp>
Sun, 29 Oct 2023 08:10:39 +0000 (17:10 +0900)
committerGitHub <noreply@github.com>
Sun, 29 Oct 2023 08:10:39 +0000 (12:10 +0400)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/main.py
tests/test_field_sa_column.py

index f48e388e137b453ca9d8f5967b81665530cddab0..2b69dd2a75929e3dd24bc0c866c654c837e257e1 100644 (file)
@@ -74,6 +74,7 @@ class FieldInfo(PydanticFieldInfo):
         foreign_key = kwargs.pop("foreign_key", Undefined)
         unique = kwargs.pop("unique", False)
         index = kwargs.pop("index", Undefined)
+        sa_type = kwargs.pop("sa_type", Undefined)
         sa_column = kwargs.pop("sa_column", Undefined)
         sa_column_args = kwargs.pop("sa_column_args", Undefined)
         sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
@@ -104,11 +105,15 @@ class FieldInfo(PydanticFieldInfo):
                 )
             if unique is not Undefined:
                 raise RuntimeError(
-                    "Passing unique is not supported when " "also passing a sa_column"
+                    "Passing unique is not supported when also passing a sa_column"
                 )
             if index is not Undefined:
                 raise RuntimeError(
-                    "Passing index is not supported when " "also passing a sa_column"
+                    "Passing index is not supported when also passing a sa_column"
+                )
+            if sa_type is not Undefined:
+                raise RuntimeError(
+                    "Passing sa_type is not supported when also passing a sa_column"
                 )
         super().__init__(default=default, **kwargs)
         self.primary_key = primary_key
@@ -116,6 +121,7 @@ class FieldInfo(PydanticFieldInfo):
         self.foreign_key = foreign_key
         self.unique = unique
         self.index = index
+        self.sa_type = sa_type
         self.sa_column = sa_column
         self.sa_column_args = sa_column_args
         self.sa_column_kwargs = sa_column_kwargs
@@ -185,6 +191,7 @@ def Field(
     unique: Union[bool, UndefinedType] = Undefined,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
+    sa_type: Union[Type[Any], UndefinedType] = Undefined,
     sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
     sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
     schema_extra: Optional[Dict[str, Any]] = None,
@@ -264,6 +271,7 @@ def Field(
     unique: Union[bool, UndefinedType] = Undefined,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
+    sa_type: Union[Type[Any], UndefinedType] = Undefined,
     sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
     sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
     sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
@@ -300,6 +308,7 @@ def Field(
         unique=unique,
         nullable=nullable,
         index=index,
+        sa_type=sa_type,
         sa_column=sa_column,
         sa_column_args=sa_column_args,
         sa_column_kwargs=sa_column_kwargs,
@@ -515,6 +524,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
 
 
 def get_sqlalchemy_type(field: ModelField) -> Any:
+    sa_type = getattr(field.field_info, "sa_type", Undefined)  # noqa: B009
+    if sa_type is not Undefined:
+        return sa_type
     if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
         # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
         if issubclass(field.type_, Enum):
index 51cfdfa7973c656ed54fa3b6e799432e85adc8c5..7384f1fabcf93f1b9e9856a678630255307b4f76 100644 (file)
@@ -39,6 +39,17 @@ def test_sa_column_no_sa_kargs() -> None:
             )
 
 
+def test_sa_column_no_type() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                sa_type=Integer,
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
 def test_sa_column_no_primary_key() -> None:
     with pytest.raises(RuntimeError):