]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
✨ Allow setting `unique` in `Field()` for a column (#83)
authorRaphael Gibson <42935757+raphaelgibson@users.noreply.github.com>
Sat, 27 Aug 2022 23:49:29 +0000 (20:49 -0300)
committerGitHub <noreply@github.com>
Sat, 27 Aug 2022 23:49:29 +0000 (01:49 +0200)
Co-authored-by: Raphael Gibson <raphael.araujo@estantemagica.com.br>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/main.py
tests/test_main.py [new file with mode: 0644]

index bdfe6dfc1c9361f8ff3acfba7fc0bd64c4d87223..7c79edd2e3cf4235348efb79fdd4f303fc87530d 100644 (file)
@@ -61,6 +61,7 @@ class FieldInfo(PydanticFieldInfo):
         primary_key = kwargs.pop("primary_key", False)
         nullable = kwargs.pop("nullable", Undefined)
         foreign_key = kwargs.pop("foreign_key", Undefined)
+        unique = kwargs.pop("unique", False)
         index = kwargs.pop("index", Undefined)
         sa_column = kwargs.pop("sa_column", Undefined)
         sa_column_args = kwargs.pop("sa_column_args", Undefined)
@@ -80,6 +81,7 @@ class FieldInfo(PydanticFieldInfo):
         self.primary_key = primary_key
         self.nullable = nullable
         self.foreign_key = foreign_key
+        self.unique = unique
         self.index = index
         self.sa_column = sa_column
         self.sa_column_args = sa_column_args
@@ -141,6 +143,7 @@ def Field(
     regex: Optional[str] = None,
     primary_key: bool = False,
     foreign_key: Optional[Any] = None,
+    unique: bool = False,
     nullable: Union[bool, UndefinedType] = Undefined,
     index: Union[bool, UndefinedType] = Undefined,
     sa_column: Union[Column, UndefinedType] = Undefined,  # type: ignore
@@ -171,6 +174,7 @@ def Field(
         regex=regex,
         primary_key=primary_key,
         foreign_key=foreign_key,
+        unique=unique,
         nullable=nullable,
         index=index,
         sa_column=sa_column,
@@ -426,12 +430,14 @@ def get_column_from_field(field: ModelField) -> Column:  # type: ignore
     nullable = not primary_key and _is_field_nullable(field)
     args = []
     foreign_key = getattr(field.field_info, "foreign_key", None)
+    unique = getattr(field.field_info, "unique", False)
     if foreign_key:
         args.append(ForeignKey(foreign_key))
     kwargs = {
         "primary_key": primary_key,
         "nullable": nullable,
         "index": index,
+        "unique": unique,
     }
     sa_default = Undefined
     if field.field_info.default_factory:
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644 (file)
index 0000000..22c6232
--- /dev/null
@@ -0,0 +1,93 @@
+from typing import Optional\r
+\r
+import pytest\r
+from sqlalchemy.exc import IntegrityError\r
+from sqlmodel import Field, Session, SQLModel, create_engine\r
+\r
+\r
+def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel):\r
+    class Hero(SQLModel, table=True):\r
+        id: Optional[int] = Field(default=None, primary_key=True)\r
+        name: str\r
+        secret_name: str\r
+        age: Optional[int] = None\r
+\r
+    hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")\r
+    hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")\r
+\r
+    engine = create_engine("sqlite://")\r
+\r
+    SQLModel.metadata.create_all(engine)\r
+\r
+    with Session(engine) as session:\r
+        session.add(hero_1)\r
+        session.commit()\r
+        session.refresh(hero_1)\r
+\r
+    with Session(engine) as session:\r
+        session.add(hero_2)\r
+        session.commit()\r
+        session.refresh(hero_2)\r
+\r
+    with Session(engine) as session:\r
+        heroes = session.query(Hero).all()\r
+        assert len(heroes) == 2\r
+        assert heroes[0].name == heroes[1].name\r
+\r
+\r
+def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel):\r
+    class Hero(SQLModel, table=True):\r
+        id: Optional[int] = Field(default=None, primary_key=True)\r
+        name: str\r
+        secret_name: str = Field(unique=False)\r
+        age: Optional[int] = None\r
+\r
+    hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")\r
+    hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")\r
+\r
+    engine = create_engine("sqlite://")\r
+\r
+    SQLModel.metadata.create_all(engine)\r
+\r
+    with Session(engine) as session:\r
+        session.add(hero_1)\r
+        session.commit()\r
+        session.refresh(hero_1)\r
+\r
+    with Session(engine) as session:\r
+        session.add(hero_2)\r
+        session.commit()\r
+        session.refresh(hero_2)\r
+\r
+    with Session(engine) as session:\r
+        heroes = session.query(Hero).all()\r
+        assert len(heroes) == 2\r
+        assert heroes[0].name == heroes[1].name\r
+\r
+\r
+def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(\r
+    clear_sqlmodel,\r
+):\r
+    class Hero(SQLModel, table=True):\r
+        id: Optional[int] = Field(default=None, primary_key=True)\r
+        name: str\r
+        secret_name: str = Field(unique=True)\r
+        age: Optional[int] = None\r
+\r
+    hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")\r
+    hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson")\r
+\r
+    engine = create_engine("sqlite://")\r
+\r
+    SQLModel.metadata.create_all(engine)\r
+\r
+    with Session(engine) as session:\r
+        session.add(hero_1)\r
+        session.commit()\r
+        session.refresh(hero_1)\r
+\r
+    with pytest.raises(IntegrityError):\r
+        with Session(engine) as session:\r
+            session.add(hero_2)\r
+            session.commit()\r
+            session.refresh(hero_2)\r