]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
♻️ Refactor checking for sa_type
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 29 Oct 2023 08:00:37 +0000 (12:00 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Sun, 29 Oct 2023 08:00:37 +0000 (12:00 +0400)
sqlmodel/main.py

index 8801730bca51dd1c30492af3dab8bc2c4d20495f..266bb6ccd1415e0d41e6920b0e9869c93708a774 100644 (file)
@@ -105,11 +105,15 @@ class FieldInfo(PydanticFieldInfo):
                 )
             if unique is not Undefined:
                 raise RuntimeError(
-                    "Passing unique is not supported when " "also passing a sa_column"
+                    "Passing unique is not supported when also passing a sa_column"
                 )
             if index is not Undefined:
                 raise RuntimeError(
-                    "Passing index is not supported when " "also passing a sa_column"
+                    "Passing index is not supported when also passing a sa_column"
+                )
+            if sa_type is not Undefined:
+                raise RuntimeError(
+                    "Passing sa_type is not supported when also passing a sa_column"
                 )
         super().__init__(default=default, **kwargs)
         self.primary_key = primary_key
@@ -187,6 +191,7 @@ def Field(
     unique: Union[bool, UndefinedType] = Undefined,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
+    sa_type: Union[Type[Any], UndefinedType] = Undefined,
     sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
     sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
     schema_extra: Optional[Dict[str, Any]] = None,
@@ -266,7 +271,7 @@ def Field(
     unique: Union[bool, UndefinedType] = Undefined,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
-    sa_type: Type[Any] = Undefined,
+    sa_type: Union[Type[Any], UndefinedType] = Undefined,
     sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
     sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
     sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
@@ -519,9 +524,9 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
 
 
 def get_sqlalchemy_type(field: ModelField) -> Any:
-    if hasattr(field.field_info, "sa_type"):
-        if not issubclass(type(field.field_info.sa_type), type(Undefined)):
-            return field.field_info.sa_type
+    sa_type = getattr(field.field_info, "sa_type")  # noqa: B009
+    if sa_type is not Undefined:
+        return sa_type
     if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
         # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
         if issubclass(field.type_, Enum):