]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
Add `validate_default` param to Field, add tests
authorYurii Motov <yurii.motov.monte@gmail.com>
Wed, 28 Jan 2026 15:22:58 +0000 (16:22 +0100)
committerYurii Motov <yurii.motov.monte@gmail.com>
Wed, 28 Jan 2026 15:37:53 +0000 (16:37 +0100)
sqlmodel/main.py
tests/test_pydantic/test_field.py

index 82a957103049617afcaefb5450fb11f0cdee23d2..2d17c2cac595cb40c37f43697dcd4560ffbbd9b9 100644 (file)
@@ -231,6 +231,7 @@ def Field(
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
+    validate_default: Optional[bool] = None,
     repr: bool = True,
     primary_key: Union[bool, UndefinedType] = Undefined,
     foreign_key: Any = Undefined,
@@ -275,6 +276,7 @@ def Field(
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
+    validate_default: Optional[bool] = None,
     repr: bool = True,
     primary_key: Union[bool, UndefinedType] = Undefined,
     foreign_key: str,
@@ -328,6 +330,7 @@ def Field(
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
+    validate_default: Optional[bool] = None,
     repr: bool = True,
     sa_column: Union[Column[Any], UndefinedType] = Undefined,
     schema_extra: Optional[dict[str, Any]] = None,
@@ -362,6 +365,7 @@ def Field(
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
+    validate_default: Optional[bool] = None,
     repr: bool = True,
     primary_key: Union[bool, UndefinedType] = Undefined,
     foreign_key: Any = Undefined,
@@ -377,7 +381,10 @@ def Field(
 ) -> Any:
     current_schema_extra = schema_extra or {}
 
-    for param_name in ("coerce_numbers_to_str",):
+    for param_name in (
+        "coerce_numbers_to_str",
+        "validate_default",
+    ):
         if param_name in current_schema_extra:
             msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
             warnings.warn(msg, UserWarning, stacklevel=2)
@@ -388,6 +395,9 @@ def Field(
     current_coerce_numbers_to_str = coerce_numbers_to_str or current_schema_extra.pop(
         "coerce_numbers_to_str", None
     )
+    current_validate_default = validate_default or current_schema_extra.pop(
+        "validate_default", None
+    )
     field_info_kwargs = {
         "alias": alias,
         "title": title,
@@ -396,6 +406,7 @@ def Field(
         "include": include,
         "const": const,
         "coerce_numbers_to_str": current_coerce_numbers_to_str,
+        "validate_default": current_validate_default,
         "gt": gt,
         "ge": ge,
         "lt": lt,
index 5484f3f197dcb34e9c54da6b16fc809a4e938fa8..267c37093b22f1fd63f350d797c9a8881f4b8e7f 100644 (file)
@@ -3,7 +3,7 @@ from typing import Literal, Optional, Union
 
 import pytest
 from pydantic import ValidationError
-from sqlmodel import Field, SQLModel
+from sqlmodel import Field, Session, SQLModel, create_engine
 
 
 def test_decimal():
@@ -87,3 +87,60 @@ def test_coerce_numbers_to_str_via_schema_extra():  # Current workaround. Remove
 
     assert Model.model_validate({"val": 123}).val == "123"
     assert Model.model_validate({"val": 45.67}).val == "45.67"
+
+
+def test_validate_default_true():
+    class Model(SQLModel):
+        val: int = Field(default="123", validate_default=True)
+
+    assert Model.model_validate({}).val == 123
+
+    class Model2(SQLModel):
+        val: int = Field(default=None, validate_default=True)
+
+    with pytest.raises(ValidationError):
+        Model2.model_validate({})
+
+
+def test_validate_default_table_model():
+    class Model(SQLModel):
+        id: Optional[int] = Field(default=None, primary_key=True)
+        val: int = Field(default="123", validate_default=True)
+
+    class ModelDB(Model, table=True):
+        pass
+
+    engine = create_engine("sqlite://", echo=True)
+
+    SQLModel.metadata.create_all(engine)
+
+    model = ModelDB()
+    with Session(engine) as session:
+        session.add(model)
+        session.commit()
+        session.refresh(model)
+
+    assert model.val == 123
+
+
+@pytest.mark.parametrize("validate_default", [None, False])
+def test_validate_default_false(validate_default: Optional[bool]):
+    class Model3(SQLModel):
+        val: int = Field(default="123", validate_default=validate_default)
+
+    assert Model3().val == "123"
+
+
+def test_validate_default_via_schema_extra():  # Current workaround. Remove after some time
+    with pytest.warns(
+        UserWarning,
+        match=(
+            "Pass `validate_default` parameter directly to Field instead of passing "
+            "it via `schema_extra`"
+        ),
+    ):
+
+        class Model(SQLModel):
+            val: int = Field(default="123", schema_extra={"validate_default": True})
+
+    assert Model.model_validate({}).val == 123