]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Do not allow invalid combinations of field parameters for columns and relationships...
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 28 Oct 2023 13:55:23 +0000 (17:55 +0400)
committerGitHub <noreply@github.com>
Sat, 28 Oct 2023 13:55:23 +0000 (17:55 +0400)
* ♻️ Make sa_column exclusive, do not allow incompatible arguments, sa_column_args, primary_key, etc

* ✅ Add tests for new errors when incorrectly using sa_column

* ✅ Add tests for sa_column_args and sa_column_kwargs

* ♻️ Do not allow sa_relationship with sa_relationship_args or sa_relationship_kwargs

* ✅ Add tests for relationship errors

* ✅ Fix test for sa_column_args

sqlmodel/main.py
tests/test_field_sa_args_kwargs.py [new file with mode: 0644]
tests/test_field_sa_column.py [new file with mode: 0644]
tests/test_field_sa_relationship.py [new file with mode: 0644]

index 3015aa9fbdf9f0d368f22403ae21a1aeddabeb30..f48e388e137b453ca9d8f5967b81665530cddab0 100644 (file)
@@ -22,6 +22,7 @@ from typing import (
     TypeVar,
     Union,
     cast,
+    overload,
 )
 
 from pydantic import BaseConfig, BaseModel
@@ -87,6 +88,28 @@ class FieldInfo(PydanticFieldInfo):
                     "Passing sa_column_kwargs is not supported when "
                     "also passing a sa_column"
                 )
+            if primary_key is not Undefined:
+                raise RuntimeError(
+                    "Passing primary_key is not supported when "
+                    "also passing a sa_column"
+                )
+            if nullable is not Undefined:
+                raise RuntimeError(
+                    "Passing nullable is not supported when " "also passing a sa_column"
+                )
+            if foreign_key is not Undefined:
+                raise RuntimeError(
+                    "Passing foreign_key is not supported when "
+                    "also passing a sa_column"
+                )
+            if unique is not Undefined:
+                raise RuntimeError(
+                    "Passing unique is not supported when " "also passing a sa_column"
+                )
+            if index is not Undefined:
+                raise RuntimeError(
+                    "Passing index is not supported when " "also passing a sa_column"
+                )
         super().__init__(default=default, **kwargs)
         self.primary_key = primary_key
         self.nullable = nullable
@@ -126,6 +149,86 @@ class RelationshipInfo(Representation):
         self.sa_relationship_kwargs = sa_relationship_kwargs
 
 
+@overload
+def Field(
+    default: Any = Undefined,
+    *,
+    default_factory: Optional[NoArgAnyCallable] = None,
+    alias: Optional[str] = None,
+    title: Optional[str] = None,
+    description: Optional[str] = None,
+    exclude: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    include: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    const: Optional[bool] = None,
+    gt: Optional[float] = None,
+    ge: Optional[float] = None,
+    lt: Optional[float] = None,
+    le: Optional[float] = None,
+    multiple_of: Optional[float] = None,
+    max_digits: Optional[int] = None,
+    decimal_places: Optional[int] = None,
+    min_items: Optional[int] = None,
+    max_items: Optional[int] = None,
+    unique_items: Optional[bool] = None,
+    min_length: Optional[int] = None,
+    max_length: Optional[int] = None,
+    allow_mutation: bool = True,
+    regex: Optional[str] = None,
+    discriminator: Optional[str] = None,
+    repr: bool = True,
+    primary_key: Union[bool, UndefinedType] = Undefined,
+    foreign_key: Any = Undefined,
+    unique: Union[bool, UndefinedType] = Undefined,
+    nullable: Union[bool, UndefinedType] = Undefined,
+    index: Union[bool, UndefinedType] = Undefined,
+    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
+    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
+    schema_extra: Optional[Dict[str, Any]] = None,
+) -> Any:
+    ...
+
+
+@overload
+def Field(
+    default: Any = Undefined,
+    *,
+    default_factory: Optional[NoArgAnyCallable] = None,
+    alias: Optional[str] = None,
+    title: Optional[str] = None,
+    description: Optional[str] = None,
+    exclude: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    include: Union[
+        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
+    ] = None,
+    const: Optional[bool] = None,
+    gt: Optional[float] = None,
+    ge: Optional[float] = None,
+    lt: Optional[float] = None,
+    le: Optional[float] = None,
+    multiple_of: Optional[float] = None,
+    max_digits: Optional[int] = None,
+    decimal_places: Optional[int] = None,
+    min_items: Optional[int] = None,
+    max_items: Optional[int] = None,
+    unique_items: Optional[bool] = None,
+    min_length: Optional[int] = None,
+    max_length: Optional[int] = None,
+    allow_mutation: bool = True,
+    regex: Optional[str] = None,
+    discriminator: Optional[str] = None,
+    repr: bool = True,
+    sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
+    schema_extra: Optional[Dict[str, Any]] = None,
+) -> Any:
+    ...
+
+
 def Field(
     default: Any = Undefined,
     *,
@@ -156,9 +259,9 @@ def Field(
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
     repr: bool = True,
-    primary_key: bool = False,
-    foreign_key: Optional[Any] = None,
-    unique: bool = False,
+    primary_key: Union[bool, UndefinedType] = Undefined,
+    foreign_key: Any = Undefined,
+    unique: Union[bool, UndefinedType] = Undefined,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
     sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
@@ -206,6 +309,27 @@ def Field(
     return field_info
 
 
+@overload
+def Relationship(
+    *,
+    back_populates: Optional[str] = None,
+    link_model: Optional[Any] = None,
+    sa_relationship_args: Optional[Sequence[Any]] = None,
+    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
+) -> Any:
+    ...
+
+
+@overload
+def Relationship(
+    *,
+    back_populates: Optional[str] = None,
+    link_model: Optional[Any] = None,
+    sa_relationship: Optional[RelationshipProperty] = None,  # type: ignore
+) -> Any:
+    ...
+
+
 def Relationship(
     *,
     back_populates: Optional[str] = None,
@@ -440,21 +564,28 @@ def get_column_from_field(field: ModelField) -> Column:  # type: ignore
     if isinstance(sa_column, Column):
         return sa_column
     sa_type = get_sqlalchemy_type(field)
-    primary_key = getattr(field.field_info, "primary_key", False)
+    primary_key = getattr(field.field_info, "primary_key", Undefined)
+    if primary_key is Undefined:
+        primary_key = False
     index = getattr(field.field_info, "index", Undefined)
     if index is Undefined:
         index = False
     nullable = not primary_key and _is_field_noneable(field)
     # Override derived nullability if the nullable property is set explicitly
     # on the field
-    if hasattr(field.field_info, "nullable"):
-        field_nullable = getattr(field.field_info, "nullable")  # noqa: B009
-        if field_nullable != Undefined:
-            nullable = field_nullable
+    field_nullable = getattr(field.field_info, "nullable", Undefined)  # noqa: B009
+    if field_nullable != Undefined:
+        assert not isinstance(field_nullable, UndefinedType)
+        nullable = field_nullable
     args = []
-    foreign_key = getattr(field.field_info, "foreign_key", None)
-    unique = getattr(field.field_info, "unique", False)
+    foreign_key = getattr(field.field_info, "foreign_key", Undefined)
+    if foreign_key is Undefined:
+        foreign_key = None
+    unique = getattr(field.field_info, "unique", Undefined)
+    if unique is Undefined:
+        unique = False
     if foreign_key:
+        assert isinstance(foreign_key, str)
         args.append(ForeignKey(foreign_key))
     kwargs = {
         "primary_key": primary_key,
diff --git a/tests/test_field_sa_args_kwargs.py b/tests/test_field_sa_args_kwargs.py
new file mode 100644 (file)
index 0000000..94a1a13
--- /dev/null
@@ -0,0 +1,39 @@
+from typing import Optional
+
+from sqlalchemy import ForeignKey
+from sqlmodel import Field, SQLModel, create_engine
+
+
+def test_sa_column_args(clear_sqlmodel, caplog) -> None:
+    class Team(SQLModel, table=True):
+        id: Optional[int] = Field(default=None, primary_key=True)
+        name: str
+
+    class Hero(SQLModel, table=True):
+        id: Optional[int] = Field(default=None, primary_key=True)
+        team_id: Optional[int] = Field(
+            default=None,
+            sa_column_args=[ForeignKey("team.id")],
+        )
+
+    engine = create_engine("sqlite://", echo=True)
+    SQLModel.metadata.create_all(engine)
+    create_table_log = [
+        message for message in caplog.messages if "CREATE TABLE hero" in message
+    ][0]
+    assert "FOREIGN KEY(team_id) REFERENCES team (id)" in create_table_log
+
+
+def test_sa_column_kargs(clear_sqlmodel, caplog) -> None:
+    class Item(SQLModel, table=True):
+        id: Optional[int] = Field(
+            default=None,
+            sa_column_kwargs={"primary_key": True},
+        )
+
+    engine = create_engine("sqlite://", echo=True)
+    SQLModel.metadata.create_all(engine)
+    create_table_log = [
+        message for message in caplog.messages if "CREATE TABLE item" in message
+    ][0]
+    assert "PRIMARY KEY (id)" in create_table_log
diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py
new file mode 100644 (file)
index 0000000..51cfdfa
--- /dev/null
@@ -0,0 +1,99 @@
+from typing import Optional
+
+import pytest
+from sqlalchemy import Column, Integer, String
+from sqlmodel import Field, SQLModel
+
+
+def test_sa_column_takes_precedence() -> None:
+    class Item(SQLModel, table=True):
+        id: Optional[int] = Field(
+            default=None,
+            sa_column=Column(String, primary_key=True, nullable=False),
+        )
+
+    # It would have been nullable with no sa_column
+    assert Item.id.nullable is False  # type: ignore
+    assert isinstance(Item.id.type, String)  # type: ignore
+
+
+def test_sa_column_no_sa_args() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                sa_column_args=[Integer],
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
+def test_sa_column_no_sa_kargs() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                sa_column_kwargs={"primary_key": True},
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
+def test_sa_column_no_primary_key() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                primary_key=True,
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
+def test_sa_column_no_nullable() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                nullable=True,
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
+def test_sa_column_no_foreign_key() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Team(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            name: str
+
+        class Hero(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            team_id: Optional[int] = Field(
+                default=None,
+                foreign_key="team.id",
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
+def test_sa_column_no_unique() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                unique=True,
+                sa_column=Column(Integer, primary_key=True),
+            )
+
+
+def test_sa_column_no_index() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Item(SQLModel, table=True):
+            id: Optional[int] = Field(
+                default=None,
+                index=True,
+                sa_column=Column(Integer, primary_key=True),
+            )
diff --git a/tests/test_field_sa_relationship.py b/tests/test_field_sa_relationship.py
new file mode 100644 (file)
index 0000000..7606fd8
--- /dev/null
@@ -0,0 +1,53 @@
+from typing import List, Optional
+
+import pytest
+from sqlalchemy.orm import relationship
+from sqlmodel import Field, Relationship, SQLModel
+
+
+def test_sa_relationship_no_args() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Team(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            name: str = Field(index=True)
+            headquarters: str
+
+            heroes: List["Hero"] = Relationship(
+                back_populates="team",
+                sa_relationship_args=["Hero"],
+                sa_relationship=relationship("Hero", back_populates="team"),
+            )
+
+        class Hero(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            name: str = Field(index=True)
+            secret_name: str
+            age: Optional[int] = Field(default=None, index=True)
+
+            team_id: Optional[int] = Field(default=None, foreign_key="team.id")
+            team: Optional[Team] = Relationship(back_populates="heroes")
+
+
+def test_sa_relationship_no_kwargs() -> None:
+    with pytest.raises(RuntimeError):
+
+        class Team(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            name: str = Field(index=True)
+            headquarters: str
+
+            heroes: List["Hero"] = Relationship(
+                back_populates="team",
+                sa_relationship_kwargs={"lazy": "selectin"},
+                sa_relationship=relationship("Hero", back_populates="team"),
+            )
+
+        class Hero(SQLModel, table=True):
+            id: Optional[int] = Field(default=None, primary_key=True)
+            name: str = Field(index=True)
+            secret_name: str
+            age: Optional[int] = Field(default=None, index=True)
+
+            team_id: Optional[int] = Field(default=None, foreign_key="team.id")
+            team: Optional[Team] = Relationship(back_populates="heroes")