From: Yurii Motov Date: Wed, 28 Jan 2026 16:07:35 +0000 (+0100) Subject: Add `union_mode` param to Field, add tests X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b612f286e7693e96aae21855ba00dcd28697b154;p=thirdparty%2Ffastapi%2Fsqlmodel.git Add `union_mode` param to Field, add tests --- diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2d17c2ca..fe9eb4e6 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -228,6 +228,7 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, @@ -273,6 +274,7 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, @@ -327,6 +329,7 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, @@ -362,6 +365,7 @@ def Field( unique_items: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, + union_mode: Optional[Literal["smart", "left_to_right"]] = None, allow_mutation: bool = True, regex: Optional[str] = None, discriminator: Optional[str] = None, @@ -384,6 +388,7 @@ def Field( for param_name in ( "coerce_numbers_to_str", "validate_default", + "union_mode", ): if param_name in current_schema_extra: msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`" @@ -444,6 +449,10 @@ def Field( serialization_alias or schema_serialization_alias or alias ) + current_union_mode = union_mode or current_schema_extra.pop("union_mode", None) + if current_union_mode is not None: + field_info_kwargs["union_mode"] = current_union_mode + field_info = FieldInfo( default, default_factory=default_factory, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 267c3709..0635ce73 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -144,3 +144,54 @@ def test_validate_default_via_schema_extra(): # Current workaround. Remove afte val: int = Field(default="123", schema_extra={"validate_default": True}) assert Model.model_validate({}).val == 123 + + +@pytest.mark.parametrize("union_mode", [None, "smart"]) +def test_union_mode_smart(union_mode: Optional[Literal["smart"]]): + class Model(SQLModel): + val: Union[float, int] = Field(union_mode=union_mode) + + a = Model.model_validate({"val": 123}) + assert isinstance(a.val, int) # float is first, but int is more precise + + b = Model.model_validate({"val": 123.0}) + assert isinstance(b.val, float) + + c = Model.model_validate({"val": 123.1}) + assert isinstance(c.val, float) + + +def test_union_mode_left_to_right(): + class Model(SQLModel): + val: Union[float, int] = Field(union_mode="left_to_right") + + a = Model.model_validate({"val": 123}) + assert isinstance(a.val, float) + + b = Model.model_validate({"val": 123.0}) + assert isinstance(b.val, float) + + c = Model.model_validate({"val": 123.1}) + assert isinstance(c.val, float) + + +def test_union_mode_via_schema_extra(): # Current workaround. Remove after some time + with pytest.warns( + UserWarning, + match=( + "Pass `union_mode` parameter directly to Field instead of passing " + "it via `schema_extra`" + ), + ): + + class Model(SQLModel): + val: Union[float, int] = Field(schema_extra={"union_mode": "smart"}) + + a = Model.model_validate({"val": 123}) + assert isinstance(a.val, int) # float is first, but int is more precise + + b = Model.model_validate({"val": 123.0}) + assert isinstance(b.val, float) + + c = Model.model_validate({"val": 123.1}) + assert isinstance(c.val, float)