]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
🐛 Fix setting nullable property of Fields that don't accept `None` (#79)
authorEvangelos Anagnostopoulos <anagnostopoulos@workable.com>
Sat, 27 Aug 2022 22:18:57 +0000 (01:18 +0300)
committerGitHub <noreply@github.com>
Sat, 27 Aug 2022 22:18:57 +0000 (00:18 +0200)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
sqlmodel/main.py
tests/test_tutorial/test_create_db_and_table/test_tutorial001.py

index 9efdafeca36eea829cfcfb1d6ade19a740b5016b..d85976db474e53b8533c2394b6101e5bec325fc4 100644 (file)
@@ -25,6 +25,7 @@ from typing import (
 
 from pydantic import BaseConfig, BaseModel
 from pydantic.errors import ConfigError, DictError
+from pydantic.fields import SHAPE_SINGLETON
 from pydantic.fields import FieldInfo as PydanticFieldInfo
 from pydantic.fields import ModelField, Undefined, UndefinedType
 from pydantic.main import ModelMetaclass, validate_model
@@ -424,7 +425,6 @@ def get_column_from_field(field: ModelField) -> Column:  # type: ignore
         return sa_column
     sa_type = get_sqlachemy_type(field)
     primary_key = getattr(field.field_info, "primary_key", False)
-    nullable = not field.required
     index = getattr(field.field_info, "index", Undefined)
     if index is Undefined:
         index = False
@@ -432,6 +432,7 @@ def get_column_from_field(field: ModelField) -> Column:  # type: ignore
         field_nullable = getattr(field.field_info, "nullable")
         if field_nullable != Undefined:
             nullable = field_nullable
+    nullable = not primary_key and _is_field_nullable(field)
     args = []
     foreign_key = getattr(field.field_info, "foreign_key", None)
     if foreign_key:
@@ -646,3 +647,13 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
     @declared_attr  # type: ignore
     def __tablename__(cls) -> str:
         return cls.__name__.lower()
+
+
+def _is_field_nullable(field: ModelField) -> bool:
+    if not field.required:
+        # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
+        is_optional = field.allow_none and (
+            field.shape != SHAPE_SINGLETON or not field.sub_fields
+        )
+        return is_optional and field.default is None and field.default_factory is None
+    return False
index 591a51cc22a30c0edf17195ced11000458b2c3ea..b6a2e72628703c042d050724a1be29843a02199d 100644 (file)
@@ -9,7 +9,7 @@ def test_create_db_and_table(cov_tmp_path: Path):
     assert "BEGIN" in result.stdout
     assert 'PRAGMA main.table_info("hero")' in result.stdout
     assert "CREATE TABLE hero (" in result.stdout
-    assert "id INTEGER," in result.stdout
+    assert "id INTEGER NOT NULL," in result.stdout
     assert "name VARCHAR NOT NULL," in result.stdout
     assert "secret_name VARCHAR NOT NULL," in result.stdout
     assert "age INTEGER," in result.stdout