]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Raise a more clear error when a type is not valid (#425)
authorDavid Danier <david.danier@gmail.com>
Mon, 23 Oct 2023 06:42:30 +0000 (08:42 +0200)
committerGitHub <noreply@github.com>
Mon, 23 Oct 2023 06:42:30 +0000 (10:42 +0400)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/main.py
tests/test_sqlalchemy_type_errors.py [new file with mode: 0644]

index caae8cf08db606b5ce946f69e6008c9a3b944b15..7dec60ddace5d10f6e0a6c1b11f024eaaabb3315 100644 (file)
@@ -374,45 +374,46 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
 
 
 def get_sqlalchemy_type(field: ModelField) -> Any:
-    if issubclass(field.type_, str):
-        if field.field_info.max_length:
-            return AutoString(length=field.field_info.max_length)
-        return AutoString
-    if issubclass(field.type_, float):
-        return Float
-    if issubclass(field.type_, bool):
-        return Boolean
-    if issubclass(field.type_, int):
-        return Integer
-    if issubclass(field.type_, datetime):
-        return DateTime
-    if issubclass(field.type_, date):
-        return Date
-    if issubclass(field.type_, timedelta):
-        return Interval
-    if issubclass(field.type_, time):
-        return Time
-    if issubclass(field.type_, Enum):
-        return sa_Enum(field.type_)
-    if issubclass(field.type_, bytes):
-        return LargeBinary
-    if issubclass(field.type_, Decimal):
-        return Numeric(
-            precision=getattr(field.type_, "max_digits", None),
-            scale=getattr(field.type_, "decimal_places", None),
-        )
-    if issubclass(field.type_, ipaddress.IPv4Address):
-        return AutoString
-    if issubclass(field.type_, ipaddress.IPv4Network):
-        return AutoString
-    if issubclass(field.type_, ipaddress.IPv6Address):
-        return AutoString
-    if issubclass(field.type_, ipaddress.IPv6Network):
-        return AutoString
-    if issubclass(field.type_, Path):
-        return AutoString
-    if issubclass(field.type_, uuid.UUID):
-        return GUID
+    if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
+        if issubclass(field.type_, str):
+            if field.field_info.max_length:
+                return AutoString(length=field.field_info.max_length)
+            return AutoString
+        if issubclass(field.type_, float):
+            return Float
+        if issubclass(field.type_, bool):
+            return Boolean
+        if issubclass(field.type_, int):
+            return Integer
+        if issubclass(field.type_, datetime):
+            return DateTime
+        if issubclass(field.type_, date):
+            return Date
+        if issubclass(field.type_, timedelta):
+            return Interval
+        if issubclass(field.type_, time):
+            return Time
+        if issubclass(field.type_, Enum):
+            return sa_Enum(field.type_)
+        if issubclass(field.type_, bytes):
+            return LargeBinary
+        if issubclass(field.type_, Decimal):
+            return Numeric(
+                precision=getattr(field.type_, "max_digits", None),
+                scale=getattr(field.type_, "decimal_places", None),
+            )
+        if issubclass(field.type_, ipaddress.IPv4Address):
+            return AutoString
+        if issubclass(field.type_, ipaddress.IPv4Network):
+            return AutoString
+        if issubclass(field.type_, ipaddress.IPv6Address):
+            return AutoString
+        if issubclass(field.type_, ipaddress.IPv6Network):
+            return AutoString
+        if issubclass(field.type_, Path):
+            return AutoString
+        if issubclass(field.type_, uuid.UUID):
+            return GUID
     raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
 
 
diff --git a/tests/test_sqlalchemy_type_errors.py b/tests/test_sqlalchemy_type_errors.py
new file mode 100644 (file)
index 0000000..e211c46
--- /dev/null
@@ -0,0 +1,28 @@
+from typing import Any, Dict, List, Optional, Union
+
+import pytest
+from sqlmodel import Field, SQLModel
+
+
+def test_type_list_breaks() -> None:
+    with pytest.raises(ValueError):
+
+        class Hero(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            tags: List[str]
+
+
+def test_type_dict_breaks() -> None:
+    with pytest.raises(ValueError):
+
+        class Hero(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            tags: Dict[str, Any]
+
+
+def test_type_union_breaks() -> None:
+    with pytest.raises(ValueError):
+
+        class Hero(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            tags: Union[int, str]