from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
-from typing_extensions import get_args, get_origin
+from typing_extensions import Annotated, get_args, get_origin
# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
return False
return False
- def get_type_from_field(field: Any) -> Any:
- type_: Any = field.annotation
+ def get_sa_type_from_type_annotation(annotation: Any) -> Any:
# Resolve Optional fields
- if type_ is None:
+ if annotation is None:
raise ValueError("Missing field type")
- origin = get_origin(type_)
+ origin = get_origin(annotation)
if origin is None:
- return type_
+ return annotation
+ elif origin is Annotated:
+ return get_sa_type_from_type_annotation(get_args(annotation)[0])
if _is_union_type(origin):
- bases = get_args(type_)
+ bases = get_args(annotation)
if len(bases) > 2:
raise ValueError(
"Cannot have a (non-optional) union as a SQLAlchemy field"
"Cannot have a (non-optional) union as a SQLAlchemy field"
)
# Optional unions are allowed
- return bases[0] if bases[0] is not NoneType else bases[1]
+ use_type = bases[0] if bases[0] is not NoneType else bases[1]
+ return get_sa_type_from_type_annotation(use_type)
return origin
+ def get_sa_type_from_field(field: Any) -> Any:
+ type_: Any = field.annotation
+ return get_sa_type_from_type_annotation(type_)
+
def get_field_metadata(field: Any) -> Any:
for meta in field.metadata:
if isinstance(meta, (PydanticMetadata, MaxLen)):
)
return field.allow_none # type: ignore[no-any-return, attr-defined]
- def get_type_from_field(field: Any) -> Any:
+ def get_sa_type_from_field(field: Any) -> Any:
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
return field.type_
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
get_field_metadata,
get_model_fields,
get_relationship_to,
- get_type_from_field,
+ get_sa_type_from_field,
init_pydantic_private_attrs,
is_field_noneable,
is_table_model_class,
if sa_type is not Undefined:
return sa_type
- type_ = get_type_from_field(field)
+ type_ = get_sa_type_from_field(field)
metadata = get_field_metadata(field)
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
--- /dev/null
+import uuid
+from typing import Optional
+
+from sqlmodel import Field, Session, SQLModel, create_engine, select
+
+from tests.conftest import needs_pydanticv2
+
+
+@needs_pydanticv2
+def test_annotated_optional_types(clear_sqlmodel) -> None:
+ from pydantic import UUID4
+
+ class Hero(SQLModel, table=True):
+ # Pydantic UUID4 is: Annotated[UUID, UuidVersion(4)]
+ id: Optional[UUID4] = Field(default_factory=uuid.uuid4, primary_key=True)
+
+ engine = create_engine("sqlite:///:memory:")
+ SQLModel.metadata.create_all(engine)
+ with Session(engine) as db:
+ hero = Hero()
+ db.add(hero)
+ db.commit()
+ statement = select(Hero)
+ result = db.exec(statement).all()
+ assert len(result) == 1
+ assert isinstance(hero.id, uuid.UUID)