]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Add support for all `Field` parameters from Pydantic `1.9.0` and above, make Pydant...
authorDaniil Fajnberg <60156134+daniil-berg@users.noreply.github.com>
Thu, 26 Oct 2023 10:18:05 +0000 (12:18 +0200)
committerGitHub <noreply@github.com>
Thu, 26 Oct 2023 10:18:05 +0000 (14:18 +0400)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
pyproject.toml
sqlmodel/main.py
tests/test_pydantic/__init__.py [new file with mode: 0644]
tests/test_pydantic/test_field.py [new file with mode: 0644]

index c7956daaa9ef40cdf23cdab0e354b5186de66843..181064e4eaae9a9b0fbe48d207101928d6c8b3f7 100644 (file)
@@ -32,7 +32,7 @@ classifiers = [
 [tool.poetry.dependencies]
 python = "^3.7"
 SQLAlchemy = ">=1.4.36,<2.0.0"
-pydantic = "^1.8.2"
+pydantic = "^1.9.0"
 sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
 
 [tool.poetry.group.dev.dependencies]
index 07e600e4d45b1ce262d2acaea23a565fc6b9f8d8..3015aa9fbdf9f0d368f22403ae21a1aeddabeb30 100644 (file)
@@ -145,12 +145,17 @@ def Field(
     lt: Optional[float] = None,
     le: Optional[float] = None,
     multiple_of: Optional[float] = None,
+    max_digits: Optional[int] = None,
+    decimal_places: Optional[int] = None,
     min_items: Optional[int] = None,
     max_items: Optional[int] = None,
+    unique_items: Optional[bool] = None,
     min_length: Optional[int] = None,
     max_length: Optional[int] = None,
     allow_mutation: bool = True,
     regex: Optional[str] = None,
+    discriminator: Optional[str] = None,
+    repr: bool = True,
     primary_key: bool = False,
     foreign_key: Optional[Any] = None,
     unique: bool = False,
@@ -176,12 +181,17 @@ def Field(
         lt=lt,
         le=le,
         multiple_of=multiple_of,
+        max_digits=max_digits,
+        decimal_places=decimal_places,
         min_items=min_items,
         max_items=max_items,
+        unique_items=unique_items,
         min_length=min_length,
         max_length=max_length,
         allow_mutation=allow_mutation,
         regex=regex,
+        discriminator=discriminator,
+        repr=repr,
         primary_key=primary_key,
         foreign_key=foreign_key,
         unique=unique,
@@ -587,7 +597,11 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
 
     def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
         # Don't show SQLAlchemy private attributes
-        return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
+        return [
+            (k, v)
+            for k, v in super().__repr_args__()
+            if not (isinstance(k, str) and k.startswith("_sa_"))
+        ]
 
     # From Pydantic, override to enforce validation with dict
     @classmethod
diff --git a/tests/test_pydantic/__init__.py b/tests/test_pydantic/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py
new file mode 100644 (file)
index 0000000..9d7bc77
--- /dev/null
@@ -0,0 +1,57 @@
+from decimal import Decimal
+from typing import Optional, Union
+
+import pytest
+from pydantic import ValidationError
+from sqlmodel import Field, SQLModel
+from typing_extensions import Literal
+
+
+def test_decimal():
+    class Model(SQLModel):
+        dec: Decimal = Field(max_digits=4, decimal_places=2)
+
+    Model(dec=Decimal("3.14"))
+    Model(dec=Decimal("69.42"))
+
+    with pytest.raises(ValidationError):
+        Model(dec=Decimal("3.142"))
+    with pytest.raises(ValidationError):
+        Model(dec=Decimal("0.069"))
+    with pytest.raises(ValidationError):
+        Model(dec=Decimal("420"))
+
+
+def test_discriminator():
+    # Example adapted from
+    # [Pydantic docs](https://pydantic-docs.helpmanual.io/usage/types/#discriminated-unions-aka-tagged-unions):
+
+    class Cat(SQLModel):
+        pet_type: Literal["cat"]
+        meows: int
+
+    class Dog(SQLModel):
+        pet_type: Literal["dog"]
+        barks: float
+
+    class Lizard(SQLModel):
+        pet_type: Literal["reptile", "lizard"]
+        scales: bool
+
+    class Model(SQLModel):
+        pet: Union[Cat, Dog, Lizard] = Field(..., discriminator="pet_type")
+        n: int
+
+    Model(pet={"pet_type": "dog", "barks": 3.14}, n=1)  # type: ignore[arg-type]
+
+    with pytest.raises(ValidationError):
+        Model(pet={"pet_type": "dog"}, n=1)  # type: ignore[arg-type]
+
+
+def test_repr():
+    class Model(SQLModel):
+        id: Optional[int] = Field(primary_key=True)
+        foo: str = Field(repr=False)
+
+    instance = Model(id=123, foo="bar")
+    assert "foo=" not in repr(instance)