From: Sebastián Ramírez Date: Sun, 29 Oct 2023 07:55:17 +0000 (+0400) Subject: Merge branch 'main' into main X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=a1833901f59a233bc44b28e8df80a72d987b8c7b;p=thirdparty%2Ffastapi%2Fsqlmodel.git Merge branch 'main' into main --- a1833901f59a233bc44b28e8df80a72d987b8c7b diff --cc sqlmodel/main.py index 46f3f0ee,f48e388e..8801730b --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@@ -143,12 -178,92 +180,93 @@@ def Field max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, - primary_key: bool = False, - foreign_key: Optional[Any] = None, - unique: bool = False, + discriminator: Optional[str] = None, + repr: bool = True, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, + ) -> Any: + ... + + + @overload + def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + schema_extra: Optional[Dict[str, Any]] = None, + ) -> Any: + ... + + + def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, + sa_type: Type[Any] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, @@@ -376,48 -515,47 +519,50 @@@ class SQLModelMetaclass(ModelMetaclass def get_sqlalchemy_type(field: ModelField) -> Any: + if hasattr(field.field_info, "sa_type"): + if not issubclass(type(field.field_info.sa_type), type(Undefined)): + return field.field_info.sa_type - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) - return AutoString - if issubclass(field.type_, float): - return Float - if issubclass(field.type_, bool): - return Boolean - if issubclass(field.type_, int): - return Integer - if issubclass(field.type_, datetime): - return DateTime - if issubclass(field.type_, date): - return Date - if issubclass(field.type_, timedelta): - return Interval - if issubclass(field.type_, time): - return Time - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, bytes): - return LargeBinary - if issubclass(field.type_, Decimal): - return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), - ) - if issubclass(field.type_, ipaddress.IPv4Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): - return AutoString - if issubclass(field.type_, Path): - return AutoString - if issubclass(field.type_, uuid.UUID): - return GUID + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI + if issubclass(field.type_, Enum): + return sa_Enum(field.type_) + if issubclass(field.type_, str): + if field.field_info.max_length: + return AutoString(length=field.field_info.max_length) + return AutoString + if issubclass(field.type_, float): + return Float + if issubclass(field.type_, bool): + return Boolean + if issubclass(field.type_, int): + return Integer + if issubclass(field.type_, datetime): + return DateTime + if issubclass(field.type_, date): + return Date + if issubclass(field.type_, timedelta): + return Interval + if issubclass(field.type_, time): + return Time + if issubclass(field.type_, bytes): + return LargeBinary + if issubclass(field.type_, Decimal): + return Numeric( + precision=getattr(field.type_, "max_digits", None), + scale=getattr(field.type_, "decimal_places", None), + ) + if issubclass(field.type_, ipaddress.IPv4Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv4Network): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Network): + return AutoString + if issubclass(field.type_, Path): + return AutoString + if issubclass(field.type_, uuid.UUID): + return GUID raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")