]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
🐛 Fix support for types with `Optional[Annoated[x, f()]]`, e.g. `id: Optional[pydanti...
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 31 Aug 2024 09:38:19 +0000 (11:38 +0200)
committerGitHub <noreply@github.com>
Sat, 31 Aug 2024 09:38:19 +0000 (11:38 +0200)
sqlmodel/_compat.py
sqlmodel/main.py
tests/test_annotated_uuid.py [new file with mode: 0644]

index 4018d1bb396eb1d963b6951ed20317f09c1b6dcc..4e80cdc37418af423fe8e22a2d754c667e05e5d5 100644 (file)
@@ -21,7 +21,7 @@ from typing import (
 from pydantic import VERSION as P_VERSION
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
-from typing_extensions import get_args, get_origin
+from typing_extensions import Annotated, get_args, get_origin
 
 # Reassign variable to make it reexported for mypy
 PYDANTIC_VERSION = P_VERSION
@@ -177,16 +177,17 @@ if IS_PYDANTIC_V2:
             return False
         return False
 
-    def get_type_from_field(field: Any) -> Any:
-        type_: Any = field.annotation
+    def get_sa_type_from_type_annotation(annotation: Any) -> Any:
         # Resolve Optional fields
-        if type_ is None:
+        if annotation is None:
             raise ValueError("Missing field type")
-        origin = get_origin(type_)
+        origin = get_origin(annotation)
         if origin is None:
-            return type_
+            return annotation
+        elif origin is Annotated:
+            return get_sa_type_from_type_annotation(get_args(annotation)[0])
         if _is_union_type(origin):
-            bases = get_args(type_)
+            bases = get_args(annotation)
             if len(bases) > 2:
                 raise ValueError(
                     "Cannot have a (non-optional) union as a SQLAlchemy field"
@@ -197,9 +198,14 @@ if IS_PYDANTIC_V2:
                     "Cannot have a (non-optional) union as a SQLAlchemy field"
                 )
             # Optional unions are allowed
-            return bases[0] if bases[0] is not NoneType else bases[1]
+            use_type = bases[0] if bases[0] is not NoneType else bases[1]
+            return get_sa_type_from_type_annotation(use_type)
         return origin
 
+    def get_sa_type_from_field(field: Any) -> Any:
+        type_: Any = field.annotation
+        return get_sa_type_from_type_annotation(type_)
+
     def get_field_metadata(field: Any) -> Any:
         for meta in field.metadata:
             if isinstance(meta, (PydanticMetadata, MaxLen)):
@@ -444,7 +450,7 @@ else:
             )
         return field.allow_none  # type: ignore[no-any-return, attr-defined]
 
-    def get_type_from_field(field: Any) -> Any:
+    def get_sa_type_from_field(field: Any) -> Any:
         if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
             return field.type_
         raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
index d8fced51fafeda3a30d4e7dd0f9e44fbde4a5d06..1597e4e04f8a1d03fafbdc62aaaa88375cbac37b 100644 (file)
@@ -71,7 +71,7 @@ from ._compat import (  # type: ignore[attr-defined]
     get_field_metadata,
     get_model_fields,
     get_relationship_to,
-    get_type_from_field,
+    get_sa_type_from_field,
     init_pydantic_private_attrs,
     is_field_noneable,
     is_table_model_class,
@@ -649,7 +649,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
     if sa_type is not Undefined:
         return sa_type
 
-    type_ = get_type_from_field(field)
+    type_ = get_sa_type_from_field(field)
     metadata = get_field_metadata(field)
 
     # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
diff --git a/tests/test_annotated_uuid.py b/tests/test_annotated_uuid.py
new file mode 100644 (file)
index 0000000..b0e25ab
--- /dev/null
@@ -0,0 +1,26 @@
+import uuid
+from typing import Optional
+
+from sqlmodel import Field, Session, SQLModel, create_engine, select
+
+from tests.conftest import needs_pydanticv2
+
+
+@needs_pydanticv2
+def test_annotated_optional_types(clear_sqlmodel) -> None:
+    from pydantic import UUID4
+
+    class Hero(SQLModel, table=True):
+        # Pydantic UUID4 is: Annotated[UUID, UuidVersion(4)]
+        id: Optional[UUID4] = Field(default_factory=uuid.uuid4, primary_key=True)
+
+    engine = create_engine("sqlite:///:memory:")
+    SQLModel.metadata.create_all(engine)
+    with Session(engine) as db:
+        hero = Hero()
+        db.add(hero)
+        db.commit()
+        statement = select(Hero)
+        result = db.exec(statement).all()
+    assert len(result) == 1
+    assert isinstance(hero.id, uuid.UUID)