]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
Add `union_mode` param to Field, add tests
authorYurii Motov <yurii.motov.monte@gmail.com>
Wed, 28 Jan 2026 16:07:35 +0000 (17:07 +0100)
committerYurii Motov <yurii.motov.monte@gmail.com>
Wed, 28 Jan 2026 16:07:35 +0000 (17:07 +0100)
sqlmodel/main.py
tests/test_pydantic/test_field.py

index 2d17c2cac595cb40c37f43697dcd4560ffbbd9b9..fe9eb4e6f715ad0b3fee9187f03ba447bc437813 100644 (file)
@@ -228,6 +228,7 @@ def Field(
     unique_items: Optional[bool] = None,
     min_length: Optional[int] = None,
     max_length: Optional[int] = None,
+    union_mode: Optional[Literal["smart", "left_to_right"]] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
@@ -273,6 +274,7 @@ def Field(
     unique_items: Optional[bool] = None,
     min_length: Optional[int] = None,
     max_length: Optional[int] = None,
+    union_mode: Optional[Literal["smart", "left_to_right"]] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
@@ -327,6 +329,7 @@ def Field(
     unique_items: Optional[bool] = None,
     min_length: Optional[int] = None,
     max_length: Optional[int] = None,
+    union_mode: Optional[Literal["smart", "left_to_right"]] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
@@ -362,6 +365,7 @@ def Field(
     unique_items: Optional[bool] = None,
     min_length: Optional[int] = None,
     max_length: Optional[int] = None,
+    union_mode: Optional[Literal["smart", "left_to_right"]] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
     discriminator: Optional[str] = None,
@@ -384,6 +388,7 @@ def Field(
     for param_name in (
         "coerce_numbers_to_str",
         "validate_default",
+        "union_mode",
     ):
         if param_name in current_schema_extra:
             msg = f"Pass `{param_name}` parameter directly to Field instead of passing it via `schema_extra`"
@@ -444,6 +449,10 @@ def Field(
         serialization_alias or schema_serialization_alias or alias
     )
 
+    current_union_mode = union_mode or current_schema_extra.pop("union_mode", None)
+    if current_union_mode is not None:
+        field_info_kwargs["union_mode"] = current_union_mode
+
     field_info = FieldInfo(
         default,
         default_factory=default_factory,
index 267c37093b22f1fd63f350d797c9a8881f4b8e7f..0635ce73dbe5d43969a70bf5157c6583744f1a8c 100644 (file)
@@ -144,3 +144,54 @@ def test_validate_default_via_schema_extra():  # Current workaround. Remove afte
             val: int = Field(default="123", schema_extra={"validate_default": True})
 
     assert Model.model_validate({}).val == 123
+
+
+@pytest.mark.parametrize("union_mode", [None, "smart"])
+def test_union_mode_smart(union_mode: Optional[Literal["smart"]]):
+    class Model(SQLModel):
+        val: Union[float, int] = Field(union_mode=union_mode)
+
+    a = Model.model_validate({"val": 123})
+    assert isinstance(a.val, int)  # float is first, but int is more precise
+
+    b = Model.model_validate({"val": 123.0})
+    assert isinstance(b.val, float)
+
+    c = Model.model_validate({"val": 123.1})
+    assert isinstance(c.val, float)
+
+
+def test_union_mode_left_to_right():
+    class Model(SQLModel):
+        val: Union[float, int] = Field(union_mode="left_to_right")
+
+    a = Model.model_validate({"val": 123})
+    assert isinstance(a.val, float)
+
+    b = Model.model_validate({"val": 123.0})
+    assert isinstance(b.val, float)
+
+    c = Model.model_validate({"val": 123.1})
+    assert isinstance(c.val, float)
+
+
+def test_union_mode_via_schema_extra():  # Current workaround. Remove after some time
+    with pytest.warns(
+        UserWarning,
+        match=(
+            "Pass `union_mode` parameter directly to Field instead of passing "
+            "it via `schema_extra`"
+        ),
+    ):
+
+        class Model(SQLModel):
+            val: Union[float, int] = Field(schema_extra={"union_mode": "smart"})
+
+    a = Model.model_validate({"val": 123})
+    assert isinstance(a.val, int)  # float is first, but int is more precise
+
+    b = Model.model_validate({"val": 123.0})
+    assert isinstance(b.val, float)
+
+    c = Model.model_validate({"val": 123.1})
+    assert isinstance(c.val, float)