]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
🐛 Fix support for `Annotated` fields with Pydantic 2.12+ (#1607)
authorVictor Mota <vimota@gmail.com>
Sun, 1 Feb 2026 18:14:55 +0000 (13:14 -0500)
committerGitHub <noreply@github.com>
Sun, 1 Feb 2026 18:14:55 +0000 (19:14 +0100)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: svlandeg <svlandeg@github.com>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/main.py
tests/test_field_sa_column.py
tests/test_future_annotations.py [new file with mode: 0644]
tests/test_main.py

index 84478f24cf8ca9b6fecabc146aff45c1c5acf64f..b0e88fd04ec428203098ad9befd8f6cf33028a9c 100644 (file)
@@ -5,6 +5,7 @@ import ipaddress
 import uuid
 import weakref
 from collections.abc import Mapping, Sequence, Set
+from dataclasses import dataclass
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal
 from enum import Enum
@@ -200,6 +201,38 @@ class RelationshipInfo(Representation):
         self.sa_relationship_kwargs = sa_relationship_kwargs
 
 
+@dataclass
+class FieldInfoMetadata:
+    primary_key: Union[bool, UndefinedType] = Undefined
+    nullable: Union[bool, UndefinedType] = Undefined
+    foreign_key: Any = Undefined
+    ondelete: Union[OnDeleteType, UndefinedType] = Undefined
+    unique: Union[bool, UndefinedType] = Undefined
+    index: Union[bool, UndefinedType] = Undefined
+    sa_type: Union[type[Any], UndefinedType] = Undefined
+    sa_column: Union[Column[Any], UndefinedType] = Undefined
+    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined
+    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined
+
+
+def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]:
+    metadata_items = getattr(field_info, "metadata", None)
+    if metadata_items:
+        for meta in metadata_items:
+            if isinstance(meta, FieldInfoMetadata):
+                return meta
+    return None
+
+
+def _get_sqlmodel_field_value(
+    field_info: Any, attribute: str, default: Any = Undefined
+) -> Any:
+    metadata = _get_sqlmodel_field_metadata(field_info)
+    if metadata is not None and hasattr(metadata, attribute):
+        return getattr(metadata, attribute)
+    return getattr(field_info, attribute, default)
+
+
 # include sa_type, sa_column_args, sa_column_kwargs
 @overload
 def Field(
@@ -423,6 +456,20 @@ def Field(
         default_factory=default_factory,
         **field_info_kwargs,
     )
+    field_metadata = FieldInfoMetadata(
+        primary_key=primary_key,
+        nullable=nullable,
+        foreign_key=foreign_key,
+        ondelete=ondelete,
+        unique=unique,
+        index=index,
+        sa_type=sa_type,
+        sa_column=sa_column,
+        sa_column_args=sa_column_args,
+        sa_column_kwargs=sa_column_kwargs,
+    )
+    if hasattr(field_info, "metadata"):
+        field_info.metadata.append(field_metadata)
     return field_info
 
 
@@ -637,7 +684,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
 
 def get_sqlalchemy_type(field: Any) -> Any:
     field_info = field
-    sa_type = getattr(field_info, "sa_type", Undefined)  # noqa: B009
+    sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined)  # noqa: B009
     if sa_type is not Undefined:
         return sa_type
 
@@ -691,39 +738,39 @@ def get_sqlalchemy_type(field: Any) -> Any:
 
 def get_column_from_field(field: Any) -> Column:  # type: ignore
     field_info = field
-    sa_column = getattr(field_info, "sa_column", Undefined)
+    sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
     if isinstance(sa_column, Column):
         return sa_column
     sa_type = get_sqlalchemy_type(field)
-    primary_key = getattr(field_info, "primary_key", Undefined)
+    primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
     if primary_key is Undefined:
         primary_key = False
-    index = getattr(field_info, "index", Undefined)
+    index = _get_sqlmodel_field_value(field_info, "index", Undefined)
     if index is Undefined:
         index = False
     nullable = not primary_key and is_field_noneable(field)
     # Override derived nullability if the nullable property is set explicitly
     # on the field
-    field_nullable = getattr(field_info, "nullable", Undefined)  # noqa: B009
+    field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
     if field_nullable is not Undefined:
         assert not isinstance(field_nullable, UndefinedType)
         nullable = field_nullable
     args = []
-    foreign_key = getattr(field_info, "foreign_key", Undefined)
+    foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
     if foreign_key is Undefined:
         foreign_key = None
-    unique = getattr(field_info, "unique", Undefined)
+    unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
     if unique is Undefined:
         unique = False
     if foreign_key:
-        if field_info.ondelete == "SET NULL" and not nullable:
+        ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
+        if ondelete_value is Undefined:
+            ondelete_value = None
+        if ondelete_value == "SET NULL" and not nullable:
             raise RuntimeError('ondelete="SET NULL" requires nullable=True')
         assert isinstance(foreign_key, str)
-        ondelete = getattr(field_info, "ondelete", Undefined)
-        if ondelete is Undefined:
-            ondelete = None
-        assert isinstance(ondelete, (str, type(None)))  # for typing
-        args.append(ForeignKey(foreign_key, ondelete=ondelete))
+        assert isinstance(ondelete_value, (str, type(None)))  # for typing
+        args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
     kwargs = {
         "primary_key": primary_key,
         "nullable": nullable,
@@ -737,10 +784,12 @@ def get_column_from_field(field: Any) -> Column:  # type: ignore
         sa_default = field_info.default
     if sa_default is not Undefined:
         kwargs["default"] = sa_default
-    sa_column_args = getattr(field_info, "sa_column_args", Undefined)
+    sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
     if sa_column_args is not Undefined:
         args.extend(list(cast(Sequence[Any], sa_column_args)))
-    sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
+    sa_column_kwargs = _get_sqlmodel_field_value(
+        field_info, "sa_column_kwargs", Undefined
+    )
     if sa_column_kwargs is not Undefined:
         kwargs.update(cast(dict[Any, Any], sa_column_kwargs))
     return Column(sa_type, *args, **kwargs)  # type: ignore
index e2ccc6d7efd2ed9d3a5f0d189d646e4a2ea2c314..48001aecec495bedf7f7afcf80462877b58ecb03 100644 (file)
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Annotated, Optional
 
 import pytest
 from sqlalchemy import Column, Integer, String
@@ -17,6 +17,17 @@ def test_sa_column_takes_precedence() -> None:
     assert isinstance(Item.id.type, String)  # type: ignore
 
 
+def test_sa_column_with_annotated_metadata() -> None:
+    class Item(SQLModel, table=True):
+        id: Annotated[Optional[int], "meta"] = Field(
+            default=None,
+            sa_column=Column(String, primary_key=True, nullable=False),
+        )
+
+    assert Item.id.nullable is False  # type: ignore
+    assert isinstance(Item.id.type, String)  # type: ignore
+
+
 def test_sa_column_no_sa_args() -> None:
     with pytest.raises(RuntimeError):
 
diff --git a/tests/test_future_annotations.py b/tests/test_future_annotations.py
new file mode 100644 (file)
index 0000000..6caf0fa
--- /dev/null
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import Annotated, Optional
+
+from sqlmodel import Field, Session, SQLModel, create_engine, select
+
+
+def test_model_with_future_annotations(clear_sqlmodel):
+    class Hero(SQLModel, table=True):
+        id: Annotated[Optional[int], Field(primary_key=True)] = None
+        name: str
+        secret_name: str
+        age: Optional[int] = None
+
+    hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
+
+    engine = create_engine("sqlite://")
+    SQLModel.metadata.create_all(engine)
+
+    with Session(engine) as session:
+        session.add(hero)
+        session.commit()
+        session.refresh(hero)
+
+        assert hero.id is not None
+        assert hero.name == "Deadpond"
+        assert hero.secret_name == "Dive Wilson"
+        assert hero.age == 25
+
+    with Session(engine) as session:
+        heroes = session.exec(select(Hero)).all()
+        assert len(heroes) == 1
+        assert heroes[0].name == "Deadpond"
+
+
+def test_model_with_string_annotations(clear_sqlmodel):
+    class Team(SQLModel, table=True):
+        id: Annotated[Optional[int], Field(primary_key=True)] = None
+        name: str
+
+    class Player(SQLModel, table=True):
+        id: Annotated[Optional[int], Field(primary_key=True)] = None
+        name: str
+        team_id: Annotated[Optional[int], Field(foreign_key="team.id")] = None
+
+    engine = create_engine("sqlite://")
+    SQLModel.metadata.create_all(engine)
+
+    team = Team(name="Champions")
+    player = Player(name="Alice", team_id=None)
+
+    with Session(engine) as session:
+        session.add(team)
+        session.commit()
+        session.refresh(team)
+
+        player.team_id = team.id
+        session.add(player)
+        session.commit()
+        session.refresh(player)
+
+        assert team.id is not None
+        assert player.team_id == team.id
index c1508d181fb12c0b87cb8d77feae22baf4fc26a6..2a862d14b5e672bb7fccdccb7283d3d114b4daba 100644 (file)
@@ -1,4 +1,4 @@
-from typing import Optional\r
+from typing import Annotated, Optional\r
 \r
 import pytest\r
 from sqlalchemy.exc import IntegrityError\r
@@ -125,3 +125,94 @@ def test_sa_relationship_property(clear_sqlmodel):
         # The next statement should not raise an AttributeError\r
         assert hero_rusty_man.team\r
         assert hero_rusty_man.team.name == "Preventers"\r
+\r
+\r
+def test_composite_primary_key(clear_sqlmodel):\r
+    class UserPermission(SQLModel, table=True):\r
+        user_id: int = Field(primary_key=True)\r
+        resource_id: int = Field(primary_key=True)\r
+        permission: str\r
+\r
+    engine = create_engine("sqlite://")\r
+    SQLModel.metadata.create_all(engine)\r
+\r
+    pk_column_names = {column.name for column in UserPermission.__table__.primary_key}\r
+    assert pk_column_names == {"user_id", "resource_id"}\r
+\r
+    with Session(engine) as session:\r
+        perm1 = UserPermission(user_id=1, resource_id=1, permission="read")\r
+        perm2 = UserPermission(user_id=1, resource_id=2, permission="write")\r
+        session.add(perm1)\r
+        session.add(perm2)\r
+        session.commit()\r
+\r
+    with pytest.raises(IntegrityError):\r
+        with Session(engine) as session:\r
+            perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")\r
+            session.add(perm3)\r
+            session.commit()\r
+\r
+\r
+def test_composite_primary_key_and_validator(clear_sqlmodel):\r
+    from pydantic import AfterValidator\r
+\r
+    def validate_resource_id(value: int) -> int:\r
+        if value < 1:\r
+            raise ValueError("Resource ID must be positive")\r
+        return value\r
+\r
+    class UserPermission(SQLModel, table=True):\r
+        user_id: int = Field(primary_key=True)\r
+        resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(\r
+            primary_key=True\r
+        )\r
+        permission: str\r
+\r
+    engine = create_engine("sqlite://")\r
+    SQLModel.metadata.create_all(engine)\r
+\r
+    pk_column_names = {column.name for column in UserPermission.__table__.primary_key}\r
+    assert pk_column_names == {"user_id", "resource_id"}\r
+\r
+    with Session(engine) as session:\r
+        perm1 = UserPermission(user_id=1, resource_id=1, permission="read")\r
+        perm2 = UserPermission(user_id=1, resource_id=2, permission="write")\r
+        session.add(perm1)\r
+        session.add(perm2)\r
+        session.commit()\r
+\r
+    with pytest.raises(IntegrityError):\r
+        with Session(engine) as session:\r
+            perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")\r
+            session.add(perm3)\r
+            session.commit()\r
+\r
+\r
+def test_foreign_key_ondelete_with_annotated(clear_sqlmodel):\r
+    from pydantic import AfterValidator\r
+\r
+    def ensure_positive(value: int) -> int:\r
+        if value < 0:\r
+            raise ValueError("Team ID must be positive")\r
+        return value\r
+\r
+    class Team(SQLModel, table=True):\r
+        id: int = Field(primary_key=True)\r
+        name: str\r
+\r
+    class Hero(SQLModel, table=True):\r
+        id: int = Field(primary_key=True)\r
+        team_id: Annotated[int, AfterValidator(ensure_positive)] = Field(\r
+            foreign_key="team.id",\r
+            ondelete="CASCADE",\r
+        )\r
+        name: str\r
+\r
+    engine = create_engine("sqlite://")\r
+    SQLModel.metadata.create_all(engine)\r
+\r
+    team_id_column = Hero.__table__.c.team_id  # type: ignore[attr-defined]\r
+    foreign_keys = list(team_id_column.foreign_keys)\r
+    assert len(foreign_keys) == 1\r
+    assert foreign_keys[0].ondelete == "CASCADE"\r
+    assert team_id_column.nullable is False\r