[tool.poetry.dependencies]
python = "^3.7"
SQLAlchemy = ">=1.4.36,<2.0.0"
-pydantic = "^1.8.2"
+pydantic = "^1.9.0"
sqlalchemy2-stubs = {version = "*", allow-prereleases = true}
[tool.poetry.group.dev.dependencies]
lt: Optional[float] = None,
le: Optional[float] = None,
multiple_of: Optional[float] = None,
+ max_digits: Optional[int] = None,
+ decimal_places: Optional[int] = None,
min_items: Optional[int] = None,
max_items: Optional[int] = None,
+ unique_items: Optional[bool] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
+ discriminator: Optional[str] = None,
+ repr: bool = True,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
unique: bool = False,
lt=lt,
le=le,
multiple_of=multiple_of,
+ max_digits=max_digits,
+ decimal_places=decimal_places,
min_items=min_items,
max_items=max_items,
+ unique_items=unique_items,
min_length=min_length,
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
+ discriminator=discriminator,
+ repr=repr,
primary_key=primary_key,
foreign_key=foreign_key,
unique=unique,
def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
# Don't show SQLAlchemy private attributes
- return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
+ return [
+ (k, v)
+ for k, v in super().__repr_args__()
+ if not (isinstance(k, str) and k.startswith("_sa_"))
+ ]
# From Pydantic, override to enforce validation with dict
@classmethod
--- /dev/null
+from decimal import Decimal
+from typing import Optional, Union
+
+import pytest
+from pydantic import ValidationError
+from sqlmodel import Field, SQLModel
+from typing_extensions import Literal
+
+
+def test_decimal():
+ class Model(SQLModel):
+ dec: Decimal = Field(max_digits=4, decimal_places=2)
+
+ Model(dec=Decimal("3.14"))
+ Model(dec=Decimal("69.42"))
+
+ with pytest.raises(ValidationError):
+ Model(dec=Decimal("3.142"))
+ with pytest.raises(ValidationError):
+ Model(dec=Decimal("0.069"))
+ with pytest.raises(ValidationError):
+ Model(dec=Decimal("420"))
+
+
+def test_discriminator():
+ # Example adapted from
+ # [Pydantic docs](https://pydantic-docs.helpmanual.io/usage/types/#discriminated-unions-aka-tagged-unions):
+
+ class Cat(SQLModel):
+ pet_type: Literal["cat"]
+ meows: int
+
+ class Dog(SQLModel):
+ pet_type: Literal["dog"]
+ barks: float
+
+ class Lizard(SQLModel):
+ pet_type: Literal["reptile", "lizard"]
+ scales: bool
+
+ class Model(SQLModel):
+ pet: Union[Cat, Dog, Lizard] = Field(..., discriminator="pet_type")
+ n: int
+
+ Model(pet={"pet_type": "dog", "barks": 3.14}, n=1) # type: ignore[arg-type]
+
+ with pytest.raises(ValidationError):
+ Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type]
+
+
+def test_repr():
+ class Model(SQLModel):
+ id: Optional[int] = Field(primary_key=True)
+ foo: str = Field(repr=False)
+
+ instance = Model(id=123, foo="bar")
+ assert "foo=" not in repr(instance)