import uuid
import weakref
from collections.abc import Mapping, Sequence, Set
+from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
self.sa_relationship_kwargs = sa_relationship_kwargs
+@dataclass
+class FieldInfoMetadata:
+ primary_key: Union[bool, UndefinedType] = Undefined
+ nullable: Union[bool, UndefinedType] = Undefined
+ foreign_key: Any = Undefined
+ ondelete: Union[OnDeleteType, UndefinedType] = Undefined
+ unique: Union[bool, UndefinedType] = Undefined
+ index: Union[bool, UndefinedType] = Undefined
+ sa_type: Union[type[Any], UndefinedType] = Undefined
+ sa_column: Union[Column[Any], UndefinedType] = Undefined
+ sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined
+ sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined
+
+
+def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]:
+ metadata_items = getattr(field_info, "metadata", None)
+ if metadata_items:
+ for meta in metadata_items:
+ if isinstance(meta, FieldInfoMetadata):
+ return meta
+ return None
+
+
+def _get_sqlmodel_field_value(
+ field_info: Any, attribute: str, default: Any = Undefined
+) -> Any:
+ metadata = _get_sqlmodel_field_metadata(field_info)
+ if metadata is not None and hasattr(metadata, attribute):
+ return getattr(metadata, attribute)
+ return getattr(field_info, attribute, default)
+
+
# include sa_type, sa_column_args, sa_column_kwargs
@overload
def Field(
default_factory=default_factory,
**field_info_kwargs,
)
+ field_metadata = FieldInfoMetadata(
+ primary_key=primary_key,
+ nullable=nullable,
+ foreign_key=foreign_key,
+ ondelete=ondelete,
+ unique=unique,
+ index=index,
+ sa_type=sa_type,
+ sa_column=sa_column,
+ sa_column_args=sa_column_args,
+ sa_column_kwargs=sa_column_kwargs,
+ )
+ if hasattr(field_info, "metadata"):
+ field_info.metadata.append(field_metadata)
return field_info
def get_sqlalchemy_type(field: Any) -> Any:
field_info = field
- sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009
+ sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined) # noqa: B009
if sa_type is not Undefined:
return sa_type
def get_column_from_field(field: Any) -> Column: # type: ignore
field_info = field
- sa_column = getattr(field_info, "sa_column", Undefined)
+ sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlalchemy_type(field)
- primary_key = getattr(field_info, "primary_key", Undefined)
+ primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
if primary_key is Undefined:
primary_key = False
- index = getattr(field_info, "index", Undefined)
+ index = _get_sqlmodel_field_value(field_info, "index", Undefined)
if index is Undefined:
index = False
nullable = not primary_key and is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
- field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
+ field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
if field_nullable is not Undefined:
assert not isinstance(field_nullable, UndefinedType)
nullable = field_nullable
args = []
- foreign_key = getattr(field_info, "foreign_key", Undefined)
+ foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
if foreign_key is Undefined:
foreign_key = None
- unique = getattr(field_info, "unique", Undefined)
+ unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
if unique is Undefined:
unique = False
if foreign_key:
- if field_info.ondelete == "SET NULL" and not nullable:
+ ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
+ if ondelete_value is Undefined:
+ ondelete_value = None
+ if ondelete_value == "SET NULL" and not nullable:
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
assert isinstance(foreign_key, str)
- ondelete = getattr(field_info, "ondelete", Undefined)
- if ondelete is Undefined:
- ondelete = None
- assert isinstance(ondelete, (str, type(None))) # for typing
- args.append(ForeignKey(foreign_key, ondelete=ondelete))
+ assert isinstance(ondelete_value, (str, type(None))) # for typing
+ args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
sa_default = field_info.default
if sa_default is not Undefined:
kwargs["default"] = sa_default
- sa_column_args = getattr(field_info, "sa_column_args", Undefined)
+ sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence[Any], sa_column_args)))
- sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
+ sa_column_kwargs = _get_sqlmodel_field_value(
+ field_info, "sa_column_kwargs", Undefined
+ )
if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs) # type: ignore
--- /dev/null
+from __future__ import annotations
+
+from typing import Annotated, Optional
+
+from sqlmodel import Field, Session, SQLModel, create_engine, select
+
+
+def test_model_with_future_annotations(clear_sqlmodel):
+ class Hero(SQLModel, table=True):
+ id: Annotated[Optional[int], Field(primary_key=True)] = None
+ name: str
+ secret_name: str
+ age: Optional[int] = None
+
+ hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
+
+ engine = create_engine("sqlite://")
+ SQLModel.metadata.create_all(engine)
+
+ with Session(engine) as session:
+ session.add(hero)
+ session.commit()
+ session.refresh(hero)
+
+ assert hero.id is not None
+ assert hero.name == "Deadpond"
+ assert hero.secret_name == "Dive Wilson"
+ assert hero.age == 25
+
+ with Session(engine) as session:
+ heroes = session.exec(select(Hero)).all()
+ assert len(heroes) == 1
+ assert heroes[0].name == "Deadpond"
+
+
+def test_model_with_string_annotations(clear_sqlmodel):
+ class Team(SQLModel, table=True):
+ id: Annotated[Optional[int], Field(primary_key=True)] = None
+ name: str
+
+ class Player(SQLModel, table=True):
+ id: Annotated[Optional[int], Field(primary_key=True)] = None
+ name: str
+ team_id: Annotated[Optional[int], Field(foreign_key="team.id")] = None
+
+ engine = create_engine("sqlite://")
+ SQLModel.metadata.create_all(engine)
+
+ team = Team(name="Champions")
+ player = Player(name="Alice", team_id=None)
+
+ with Session(engine) as session:
+ session.add(team)
+ session.commit()
+ session.refresh(team)
+
+ player.team_id = team.id
+ session.add(player)
+ session.commit()
+ session.refresh(player)
+
+ assert team.id is not None
+ assert player.team_id == team.id
-from typing import Optional\r
+from typing import Annotated, Optional\r
\r
import pytest\r
from sqlalchemy.exc import IntegrityError\r
# The next statement should not raise an AttributeError\r
assert hero_rusty_man.team\r
assert hero_rusty_man.team.name == "Preventers"\r
+\r
+\r
+def test_composite_primary_key(clear_sqlmodel):\r
+ class UserPermission(SQLModel, table=True):\r
+ user_id: int = Field(primary_key=True)\r
+ resource_id: int = Field(primary_key=True)\r
+ permission: str\r
+\r
+ engine = create_engine("sqlite://")\r
+ SQLModel.metadata.create_all(engine)\r
+\r
+ pk_column_names = {column.name for column in UserPermission.__table__.primary_key}\r
+ assert pk_column_names == {"user_id", "resource_id"}\r
+\r
+ with Session(engine) as session:\r
+ perm1 = UserPermission(user_id=1, resource_id=1, permission="read")\r
+ perm2 = UserPermission(user_id=1, resource_id=2, permission="write")\r
+ session.add(perm1)\r
+ session.add(perm2)\r
+ session.commit()\r
+\r
+ with pytest.raises(IntegrityError):\r
+ with Session(engine) as session:\r
+ perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")\r
+ session.add(perm3)\r
+ session.commit()\r
+\r
+\r
+def test_composite_primary_key_and_validator(clear_sqlmodel):\r
+ from pydantic import AfterValidator\r
+\r
+ def validate_resource_id(value: int) -> int:\r
+ if value < 1:\r
+ raise ValueError("Resource ID must be positive")\r
+ return value\r
+\r
+ class UserPermission(SQLModel, table=True):\r
+ user_id: int = Field(primary_key=True)\r
+ resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(\r
+ primary_key=True\r
+ )\r
+ permission: str\r
+\r
+ engine = create_engine("sqlite://")\r
+ SQLModel.metadata.create_all(engine)\r
+\r
+ pk_column_names = {column.name for column in UserPermission.__table__.primary_key}\r
+ assert pk_column_names == {"user_id", "resource_id"}\r
+\r
+ with Session(engine) as session:\r
+ perm1 = UserPermission(user_id=1, resource_id=1, permission="read")\r
+ perm2 = UserPermission(user_id=1, resource_id=2, permission="write")\r
+ session.add(perm1)\r
+ session.add(perm2)\r
+ session.commit()\r
+\r
+ with pytest.raises(IntegrityError):\r
+ with Session(engine) as session:\r
+ perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")\r
+ session.add(perm3)\r
+ session.commit()\r
+\r
+\r
+def test_foreign_key_ondelete_with_annotated(clear_sqlmodel):\r
+ from pydantic import AfterValidator\r
+\r
+ def ensure_positive(value: int) -> int:\r
+ if value < 0:\r
+ raise ValueError("Team ID must be positive")\r
+ return value\r
+\r
+ class Team(SQLModel, table=True):\r
+ id: int = Field(primary_key=True)\r
+ name: str\r
+\r
+ class Hero(SQLModel, table=True):\r
+ id: int = Field(primary_key=True)\r
+ team_id: Annotated[int, AfterValidator(ensure_positive)] = Field(\r
+ foreign_key="team.id",\r
+ ondelete="CASCADE",\r
+ )\r
+ name: str\r
+\r
+ engine = create_engine("sqlite://")\r
+ SQLModel.metadata.create_all(engine)\r
+\r
+ team_id_column = Hero.__table__.c.team_id # type: ignore[attr-defined]\r
+ foreign_keys = list(team_id_column.foreign_keys)\r
+ assert len(foreign_keys) == 1\r
+ assert foreign_keys[0].ondelete == "CASCADE"\r
+ assert team_id_column.nullable is False\r