From d910853f12afe7035c99f52d34baf1e9066a1914 Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: Wed, 28 Jan 2026 13:05:23 +0100 Subject: [PATCH] Add `strict` param to Field, add tests --- sqlmodel/main.py | 13 ++++++ tests/test_pydantic/test_field.py | 75 ++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 84478f24..f76ae026 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -3,6 +3,7 @@ from __future__ import annotations import builtins import ipaddress import uuid +import warnings import weakref from collections.abc import Mapping, Sequence, Set from datetime import date, datetime, time, timedelta @@ -228,6 +229,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + strict: Optional[bool] = None, discriminator: Optional[str] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, @@ -271,6 +273,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + strict: Optional[bool] = None, discriminator: Optional[str] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, @@ -323,6 +326,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + strict: Optional[bool] = None, discriminator: Optional[str] = None, repr: bool = True, sa_column: Union[Column[Any], UndefinedType] = Undefined, @@ -356,6 +360,7 @@ def Field( max_length: Optional[int] = None, allow_mutation: bool = True, regex: Optional[str] = None, + strict: Optional[bool] = None, discriminator: Optional[str] = None, repr: bool = True, primary_key: Union[bool, UndefinedType] = Undefined, @@ -371,9 +376,16 @@ def Field( schema_extra: Optional[dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} + + for param_name in ("strict",): + if param_name in current_schema_extra: + msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`" + warnings.warn(msg, DeprecationWarning, stacklevel=2) + # Extract possible alias settings from schema_extra so we can control precedence schema_validation_alias = current_schema_extra.pop("validation_alias", None) schema_serialization_alias = current_schema_extra.pop("serialization_alias", None) + current_strict = strict or current_schema_extra.pop("strict", None) field_info_kwargs = { "alias": alias, "title": title, @@ -395,6 +407,7 @@ def Field( "max_length": max_length, "allow_mutation": allow_mutation, "regex": regex, + "strict": current_strict, "discriminator": discriminator, "repr": repr, "primary_key": primary_key, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 140b02fd..c196061c 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -3,7 +3,7 @@ from typing import Literal, Optional, Union import pytest from pydantic import ValidationError -from sqlmodel import Field, SQLModel +from sqlmodel import Field, Session, SQLModel, create_engine def test_decimal(): @@ -54,3 +54,76 @@ def test_repr(): instance = Model(id=123, foo="bar") assert "foo=" not in repr(instance) + + +def test_strict_true(): + class Model(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + val: int + val_strict: int = Field(strict=True) + + class ModelDB(Model, table=True): + pass + + Model(val=123, val_strict=456) + Model(val="123", val_strict=456) + + with pytest.raises(ValidationError): + Model(val=123, val_strict="456") + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + model = ModelDB(val=123, val_strict=456) + with Session(engine) as session: + session.add(model) + session.commit() + session.refresh(model) + + assert model.val == 123 + assert model.val_strict == 456 + + +def test_strict_table_model(): + class Model(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + val_strict: int = Field(strict=True) + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + model = Model(val_strict=456) + with Session(engine) as session: + session.add(model) + session.commit() + session.refresh(model) + + assert model.val_strict == 456 + + +@pytest.mark.parametrize("strict", [None, False]) +def test_strict_false(strict: Optional[bool]): + class Model(SQLModel): + val: int = Field(strict=strict) + + Model(val=123) + Model(val="123") + + +def test_strict_via_schema_extra(): # Current workaround. Remove after some time + with pytest.warns( + DeprecationWarning, + match="Pass `strict` parameter directly to Field instead of passing it via `schema_extra`", + ): + + class Model(SQLModel): + val: int + val_strict: int = Field(schema_extra={"strict": True}) + + Model(val=123, val_strict=456) + Model(val="123", val_strict=456) + + with pytest.raises(ValidationError): + Model(val=123, val_strict="456") -- 2.47.3