]> git.ipfire.org Git - thirdparty/fastapi/sqlmodel.git/commitdiff
♻️ Refactor types to properly support Pydantic 2.7 (#913)
authorSebastián Ramírez <tiangolo@gmail.com>
Mon, 29 Apr 2024 22:11:02 +0000 (15:11 -0700)
committerGitHub <noreply@github.com>
Mon, 29 Apr 2024 22:11:02 +0000 (15:11 -0700)
sqlmodel/_compat.py
sqlmodel/main.py

index 072d2b0f58e357519134c7e333919e99e880f416..72ec8330fdd7849e11f31ef6f3f3748c73b22c4b 100644 (file)
@@ -18,11 +18,13 @@ from typing import (
     Union,
 )
 
-from pydantic import VERSION as PYDANTIC_VERSION
+from pydantic import VERSION as P_VERSION
 from pydantic import BaseModel
 from pydantic.fields import FieldInfo
 from typing_extensions import get_args, get_origin
 
+# Reassign variable to make it reexported for mypy
+PYDANTIC_VERSION = P_VERSION
 IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
 
 
index 9e8330d69d5db793f2360e2424d832c701e2a646..a16428b1921a628e194b0a1d4f9694da4088effa 100644 (file)
@@ -6,6 +6,7 @@ from decimal import Decimal
 from enum import Enum
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     AbstractSet,
     Any,
     Callable,
@@ -55,6 +56,7 @@ from typing_extensions import Literal, deprecated, get_origin
 
 from ._compat import (  # type: ignore[attr-defined]
     IS_PYDANTIC_V2,
+    PYDANTIC_VERSION,
     BaseConfig,
     ModelField,
     ModelMetaclass,
@@ -80,6 +82,12 @@ from ._compat import (  # type: ignore[attr-defined]
 )
 from .sql.sqltypes import GUID, AutoString
 
+if TYPE_CHECKING:
+    from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
+    from pydantic._internal._repr import Representation as Representation
+    from pydantic_core import PydanticUndefined as Undefined
+    from pydantic_core import PydanticUndefinedType as UndefinedType
+
 _T = TypeVar("_T")
 NoArgAnyCallable = Callable[[], Any]
 IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None]
@@ -764,13 +772,22 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
         mode: Union[Literal["json", "python"], str] = "python",
         include: IncEx = None,
         exclude: IncEx = None,
+        context: Union[Dict[str, Any], None] = None,
         by_alias: bool = False,
         exclude_unset: bool = False,
         exclude_defaults: bool = False,
         exclude_none: bool = False,
         round_trip: bool = False,
-        warnings: bool = True,
+        warnings: Union[bool, Literal["none", "warn", "error"]] = True,
+        serialize_as_any: bool = False,
     ) -> Dict[str, Any]:
+        if PYDANTIC_VERSION >= "2.7.0":
+            extra_kwargs: Dict[str, Any] = {
+                "context": context,
+                "serialize_as_any": serialize_as_any,
+            }
+        else:
+            extra_kwargs = {}
         if IS_PYDANTIC_V2:
             return super().model_dump(
                 mode=mode,
@@ -782,6 +799,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
                 exclude_none=exclude_none,
                 round_trip=round_trip,
                 warnings=warnings,
+                **extra_kwargs,
             )
         else:
             return super().dict(