def get_sqlalchemy_type(field: ModelField) -> Any:
- if issubclass(field.type_, str):
- if field.field_info.max_length:
- return AutoString(length=field.field_info.max_length)
- return AutoString
- if issubclass(field.type_, float):
- return Float
- if issubclass(field.type_, bool):
- return Boolean
- if issubclass(field.type_, int):
- return Integer
- if issubclass(field.type_, datetime):
- return DateTime
- if issubclass(field.type_, date):
- return Date
- if issubclass(field.type_, timedelta):
- return Interval
- if issubclass(field.type_, time):
- return Time
- if issubclass(field.type_, Enum):
- return sa_Enum(field.type_)
- if issubclass(field.type_, bytes):
- return LargeBinary
- if issubclass(field.type_, Decimal):
- return Numeric(
- precision=getattr(field.type_, "max_digits", None),
- scale=getattr(field.type_, "decimal_places", None),
- )
- if issubclass(field.type_, ipaddress.IPv4Address):
- return AutoString
- if issubclass(field.type_, ipaddress.IPv4Network):
- return AutoString
- if issubclass(field.type_, ipaddress.IPv6Address):
- return AutoString
- if issubclass(field.type_, ipaddress.IPv6Network):
- return AutoString
- if issubclass(field.type_, Path):
- return AutoString
- if issubclass(field.type_, uuid.UUID):
- return GUID
+ if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
+ if issubclass(field.type_, str):
+ if field.field_info.max_length:
+ return AutoString(length=field.field_info.max_length)
+ return AutoString
+ if issubclass(field.type_, float):
+ return Float
+ if issubclass(field.type_, bool):
+ return Boolean
+ if issubclass(field.type_, int):
+ return Integer
+ if issubclass(field.type_, datetime):
+ return DateTime
+ if issubclass(field.type_, date):
+ return Date
+ if issubclass(field.type_, timedelta):
+ return Interval
+ if issubclass(field.type_, time):
+ return Time
+ if issubclass(field.type_, Enum):
+ return sa_Enum(field.type_)
+ if issubclass(field.type_, bytes):
+ return LargeBinary
+ if issubclass(field.type_, Decimal):
+ return Numeric(
+ precision=getattr(field.type_, "max_digits", None),
+ scale=getattr(field.type_, "decimal_places", None),
+ )
+ if issubclass(field.type_, ipaddress.IPv4Address):
+ return AutoString
+ if issubclass(field.type_, ipaddress.IPv4Network):
+ return AutoString
+ if issubclass(field.type_, ipaddress.IPv6Address):
+ return AutoString
+ if issubclass(field.type_, ipaddress.IPv6Network):
+ return AutoString
+ if issubclass(field.type_, Path):
+ return AutoString
+ if issubclass(field.type_, uuid.UUID):
+ return GUID
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
--- /dev/null
+from typing import Any, Dict, List, Optional, Union
+
+import pytest
+from sqlmodel import Field, SQLModel
+
+
+def test_type_list_breaks() -> None:
+ with pytest.raises(ValueError):
+
+ class Hero(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ tags: List[str]
+
+
+def test_type_dict_breaks() -> None:
+ with pytest.raises(ValueError):
+
+ class Hero(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ tags: Dict[str, Any]
+
+
+def test_type_union_breaks() -> None:
+ with pytest.raises(ValueError):
+
+ class Hero(SQLModel, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+ tags: Union[int, str]