From: David Danier Date: Mon, 23 Oct 2023 06:42:30 +0000 (+0200) Subject: ✨ Raise a more clear error when a type is not valid (#425) X-Git-Tag: 0.0.9~43 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=840fd08ab2f803d4e8fb67c7587a59621473c715;p=thirdparty%2Ffastapi%2Fsqlmodel.git ✨ Raise a more clear error when a type is not valid (#425) Co-authored-by: Sebastián Ramírez --- diff --git a/sqlmodel/main.py b/sqlmodel/main.py index caae8cf0..7dec60dd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -374,45 +374,46 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): def get_sqlalchemy_type(field: ModelField) -> Any: - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) - return AutoString - if issubclass(field.type_, float): - return Float - if issubclass(field.type_, bool): - return Boolean - if issubclass(field.type_, int): - return Integer - if issubclass(field.type_, datetime): - return DateTime - if issubclass(field.type_, date): - return Date - if issubclass(field.type_, timedelta): - return Interval - if issubclass(field.type_, time): - return Time - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, bytes): - return LargeBinary - if issubclass(field.type_, Decimal): - return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), - ) - if issubclass(field.type_, ipaddress.IPv4Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): - return AutoString - if issubclass(field.type_, Path): - return AutoString - if issubclass(field.type_, uuid.UUID): - return GUID + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + if issubclass(field.type_, str): + if field.field_info.max_length: + return AutoString(length=field.field_info.max_length) + return AutoString + if issubclass(field.type_, float): + return Float + if issubclass(field.type_, bool): + return Boolean + if issubclass(field.type_, int): + return Integer + if issubclass(field.type_, datetime): + return DateTime + if issubclass(field.type_, date): + return Date + if issubclass(field.type_, timedelta): + return Interval + if issubclass(field.type_, time): + return Time + if issubclass(field.type_, Enum): + return sa_Enum(field.type_) + if issubclass(field.type_, bytes): + return LargeBinary + if issubclass(field.type_, Decimal): + return Numeric( + precision=getattr(field.type_, "max_digits", None), + scale=getattr(field.type_, "decimal_places", None), + ) + if issubclass(field.type_, ipaddress.IPv4Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv4Network): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Network): + return AutoString + if issubclass(field.type_, Path): + return AutoString + if issubclass(field.type_, uuid.UUID): + return GUID raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/tests/test_sqlalchemy_type_errors.py b/tests/test_sqlalchemy_type_errors.py new file mode 100644 index 00000000..e211c46a --- /dev/null +++ b/tests/test_sqlalchemy_type_errors.py @@ -0,0 +1,28 @@ +from typing import Any, Dict, List, Optional, Union + +import pytest +from sqlmodel import Field, SQLModel + + +def test_type_list_breaks() -> None: + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + tags: List[str] + + +def test_type_dict_breaks() -> None: + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + tags: Dict[str, Any] + + +def test_type_union_breaks() -> None: + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + tags: Union[int, str]