]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
Add `strict` param to Field, add tests
authorYurii Motov <yurii.motov.monte@gmail.com>
Wed, 28 Jan 2026 12:05:23 +0000 (13:05 +0100)
committerYurii Motov <yurii.motov.monte@gmail.com>
Wed, 28 Jan 2026 17:01:56 +0000 (18:01 +0100)
sqlmodel/main.py
tests/test_pydantic/test_field.py

index 84478f24cf8ca9b6fecabc146aff45c1c5acf64f..f76ae0263699fc346a1ea40cafd90204d5443ed0 100644 (file)
@@ -3,6 +3,7 @@ from __future__ import annotations
 import builtins
 import ipaddress
 import uuid
+import warnings
 import weakref
 from collections.abc import Mapping, Sequence, Set
 from datetime import date, datetime, time, timedelta
@@ -228,6 +229,7 @@ def Field(
     max_length: Optional[int] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
+    strict: Optional[bool] = None,
     discriminator: Optional[str] = None,
     repr: bool = True,
     primary_key: Union[bool, UndefinedType] = Undefined,
@@ -271,6 +273,7 @@ def Field(
     max_length: Optional[int] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
+    strict: Optional[bool] = None,
     discriminator: Optional[str] = None,
     repr: bool = True,
     primary_key: Union[bool, UndefinedType] = Undefined,
@@ -323,6 +326,7 @@ def Field(
     max_length: Optional[int] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
+    strict: Optional[bool] = None,
     discriminator: Optional[str] = None,
     repr: bool = True,
     sa_column: Union[Column[Any], UndefinedType] = Undefined,
@@ -356,6 +360,7 @@ def Field(
     max_length: Optional[int] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
+    strict: Optional[bool] = None,
     discriminator: Optional[str] = None,
     repr: bool = True,
     primary_key: Union[bool, UndefinedType] = Undefined,
@@ -371,9 +376,16 @@ def Field(
     schema_extra: Optional[dict[str, Any]] = None,
 ) -> Any:
     current_schema_extra = schema_extra or {}
+
+    for param_name in ("strict",):
+        if param_name in current_schema_extra:
+            msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
+            warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
     # Extract possible alias settings from schema_extra so we can control precedence
     schema_validation_alias = current_schema_extra.pop("validation_alias", None)
     schema_serialization_alias = current_schema_extra.pop("serialization_alias", None)
+    current_strict = strict or current_schema_extra.pop("strict", None)
     field_info_kwargs = {
         "alias": alias,
         "title": title,
@@ -395,6 +407,7 @@ def Field(
         "max_length": max_length,
         "allow_mutation": allow_mutation,
         "regex": regex,
+        "strict": current_strict,
         "discriminator": discriminator,
         "repr": repr,
         "primary_key": primary_key,
index 140b02fd9b1845d4c515df03828f50e41584a7d5..c196061c6e20a7f3040c857f6084d277a1a9166e 100644 (file)
@@ -3,7 +3,7 @@ from typing import Literal, Optional, Union
 
 import pytest
 from pydantic import ValidationError
-from sqlmodel import Field, SQLModel
+from sqlmodel import Field, Session, SQLModel, create_engine
 
 
 def test_decimal():
@@ -54,3 +54,76 @@ def test_repr():
 
     instance = Model(id=123, foo="bar")
     assert "foo=" not in repr(instance)
+
+
+def test_strict_true():
+    class Model(SQLModel):
+        id: Optional[int] = Field(default=None, primary_key=True)
+        val: int
+        val_strict: int = Field(strict=True)
+
+    class ModelDB(Model, table=True):
+        pass
+
+    Model(val=123, val_strict=456)
+    Model(val="123", val_strict=456)
+
+    with pytest.raises(ValidationError):
+        Model(val=123, val_strict="456")
+
+    engine = create_engine("sqlite://", echo=True)
+
+    SQLModel.metadata.create_all(engine)
+
+    model = ModelDB(val=123, val_strict=456)
+    with Session(engine) as session:
+        session.add(model)
+        session.commit()
+        session.refresh(model)
+
+    assert model.val == 123
+    assert model.val_strict == 456
+
+
+def test_strict_table_model():
+    class Model(SQLModel, table=True):
+        id: Optional[int] = Field(default=None, primary_key=True)
+        val_strict: int = Field(strict=True)
+
+    engine = create_engine("sqlite://", echo=True)
+
+    SQLModel.metadata.create_all(engine)
+
+    model = Model(val_strict=456)
+    with Session(engine) as session:
+        session.add(model)
+        session.commit()
+        session.refresh(model)
+
+    assert model.val_strict == 456
+
+
+@pytest.mark.parametrize("strict", [None, False])
+def test_strict_false(strict: Optional[bool]):
+    class Model(SQLModel):
+        val: int = Field(strict=strict)
+
+    Model(val=123)
+    Model(val="123")
+
+
+def test_strict_via_schema_extra():  # Current workaround. Remove after some time
+    with pytest.warns(
+        DeprecationWarning,
+        match="Pass `strict` parameter directly to Field instead of passing it via `schema_extra`",
+    ):
+
+        class Model(SQLModel):
+            val: int
+            val_strict: int = Field(schema_extra={"strict": True})
+
+    Model(val=123, val_strict=456)
+    Model(val="123", val_strict=456)
+
+    with pytest.raises(ValidationError):
+        Model(val=123, val_strict="456")