]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
Merge branch 'main' into main
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 29 Oct 2023 07:55:17 +0000 (11:55 +0400)
committerGitHub <noreply@github.com>
Sun, 29 Oct 2023 07:55:17 +0000 (11:55 +0400)
1  2 
sqlmodel/main.py

index 46f3f0ee289f11f009c02595dcbd3d43cfd5e018,f48e388e137b453ca9d8f5967b81665530cddab0..8801730bca51dd1c30492af3dab8bc2c4d20495f
@@@ -143,12 -178,92 +180,93 @@@ def Field
      max_length: Optional[int] = None,
      allow_mutation: bool = True,
      regex: Optional[str] = None,
-     primary_key: bool = False,
-     foreign_key: Optional[Any] = None,
-     unique: bool = False,
+     discriminator: Optional[str] = None,
+     repr: bool = True,
+     primary_key: Union[bool, UndefinedType] = Undefined,
+     foreign_key: Any = Undefined,
+     unique: Union[bool, UndefinedType] = Undefined,
+     nullable: Union[bool, UndefinedType] = Undefined,
+     index: Union[bool, UndefinedType] = Undefined,
+     sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
+     sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+     schema_extra: Optional[Dict[str, Any]] = None,
+ ) -> Any:
+     ...
+ @overload
+ def Field(
+     default: Any = Undefined,
+     *,
+     default_factory: Optional[NoArgAnyCallable] = None,
+     alias: Optional[str] = None,
+     title: Optional[str] = None,
+     description: Optional[str] = None,
+     exclude: Union[
+         AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+     ] = None,
+     include: Union[
+         AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+     ] = None,
+     const: Optional[bool] = None,
+     gt: Optional[float] = None,
+     ge: Optional[float] = None,
+     lt: Optional[float] = None,
+     le: Optional[float] = None,
+     multiple_of: Optional[float] = None,
+     max_digits: Optional[int] = None,
+     decimal_places: Optional[int] = None,
+     min_items: Optional[int] = None,
+     max_items: Optional[int] = None,
+     unique_items: Optional[bool] = None,
+     min_length: Optional[int] = None,
+     max_length: Optional[int] = None,
+     allow_mutation: bool = True,
+     regex: Optional[str] = None,
+     discriminator: Optional[str] = None,
+     repr: bool = True,
+     sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
+     schema_extra: Optional[Dict[str, Any]] = None,
+ ) -> Any:
+     ...
+ def Field(
+     default: Any = Undefined,
+     *,
+     default_factory: Optional[NoArgAnyCallable] = None,
+     alias: Optional[str] = None,
+     title: Optional[str] = None,
+     description: Optional[str] = None,
+     exclude: Union[
+         AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+     ] = None,
+     include: Union[
+         AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+     ] = None,
+     const: Optional[bool] = None,
+     gt: Optional[float] = None,
+     ge: Optional[float] = None,
+     lt: Optional[float] = None,
+     le: Optional[float] = None,
+     multiple_of: Optional[float] = None,
+     max_digits: Optional[int] = None,
+     decimal_places: Optional[int] = None,
+     min_items: Optional[int] = None,
+     max_items: Optional[int] = None,
+     unique_items: Optional[bool] = None,
+     min_length: Optional[int] = None,
+     max_length: Optional[int] = None,
+     allow_mutation: bool = True,
+     regex: Optional[str] = None,
+     discriminator: Optional[str] = None,
+     repr: bool = True,
+     primary_key: Union[bool, UndefinedType] = Undefined,
+     foreign_key: Any = Undefined,
+     unique: Union[bool, UndefinedType] = Undefined,
      nullable: Union[bool, UndefinedType] = Undefined,
      index: Union[bool, UndefinedType] = Undefined,
 +    sa_type: Type[Any] = Undefined,
      sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
      sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
      sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
@@@ -376,48 -515,47 +519,50 @@@ class SQLModelMetaclass(ModelMetaclass
  
  
  def get_sqlalchemy_type(field: ModelField) -> Any:
-     if issubclass(field.type_, str):
-         if field.field_info.max_length:
-             return AutoString(length=field.field_info.max_length)
-         return AutoString
-     if issubclass(field.type_, float):
-         return Float
-     if issubclass(field.type_, bool):
-         return Boolean
-     if issubclass(field.type_, int):
-         return Integer
-     if issubclass(field.type_, datetime):
-         return DateTime
-     if issubclass(field.type_, date):
-         return Date
-     if issubclass(field.type_, timedelta):
-         return Interval
-     if issubclass(field.type_, time):
-         return Time
-     if issubclass(field.type_, Enum):
-         return sa_Enum(field.type_)
-     if issubclass(field.type_, bytes):
-         return LargeBinary
-     if issubclass(field.type_, Decimal):
-         return Numeric(
-             precision=getattr(field.type_, "max_digits", None),
-             scale=getattr(field.type_, "decimal_places", None),
-         )
-     if issubclass(field.type_, ipaddress.IPv4Address):
-         return AutoString
-     if issubclass(field.type_, ipaddress.IPv4Network):
-         return AutoString
-     if issubclass(field.type_, ipaddress.IPv6Address):
-         return AutoString
-     if issubclass(field.type_, ipaddress.IPv6Network):
-         return AutoString
-     if issubclass(field.type_, Path):
-         return AutoString
-     if issubclass(field.type_, uuid.UUID):
-         return GUID
 +    if hasattr(field.field_info, "sa_type"):
 +        if not issubclass(type(field.field_info.sa_type), type(Undefined)):
 +            return field.field_info.sa_type
+     if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
+         # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
+         if issubclass(field.type_, Enum):
+             return sa_Enum(field.type_)
+         if issubclass(field.type_, str):
+             if field.field_info.max_length:
+                 return AutoString(length=field.field_info.max_length)
+             return AutoString
+         if issubclass(field.type_, float):
+             return Float
+         if issubclass(field.type_, bool):
+             return Boolean
+         if issubclass(field.type_, int):
+             return Integer
+         if issubclass(field.type_, datetime):
+             return DateTime
+         if issubclass(field.type_, date):
+             return Date
+         if issubclass(field.type_, timedelta):
+             return Interval
+         if issubclass(field.type_, time):
+             return Time
+         if issubclass(field.type_, bytes):
+             return LargeBinary
+         if issubclass(field.type_, Decimal):
+             return Numeric(
+                 precision=getattr(field.type_, "max_digits", None),
+                 scale=getattr(field.type_, "decimal_places", None),
+             )
+         if issubclass(field.type_, ipaddress.IPv4Address):
+             return AutoString
+         if issubclass(field.type_, ipaddress.IPv4Network):
+             return AutoString
+         if issubclass(field.type_, ipaddress.IPv6Address):
+             return AutoString
+         if issubclass(field.type_, ipaddress.IPv6Network):
+             return AutoString
+         if issubclass(field.type_, Path):
+             return AutoString
+         if issubclass(field.type_, uuid.UUID):
+             return GUID
      raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")