]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add dataclasses callable and apply annotations more strictly
author: Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2023 21:06:23 +0000 (16:06 -0500)
committer: mike bayer <mike_mp@zzzcomputing.com>
Thu, 16 Feb 2023 00:09:14 +0000 (00:09 +0000)
Added new parameter ``dataclass_callable`` to both the
:class:`_orm.MappedAsDataclass` class as well as the
:meth:`_orm.registry.mapped_as_dataclass` method which allows an
alternative callable to Python ``dataclasses.dataclass`` to be used in
order to produce dataclasses. The use case here is to drop in Pydantic's
dataclass function instead. Adjustments have been made to the mixin support
added for :ticket:`9179` in version 2.0.1 so that the ``__annotations__``
collection of the mixin is rewritten to not include the
:class:`_orm.Mapped` container, in the same way as occurs with mapped
classes, so that the Pydantic dataclasses constructor is not exposed to
unknown types.

Fixes: #9266
Change-Id: Ia0fab6f20b93a5cb853799dcf1b70a0386837c14

doc/build/changelog/unreleased_20/9266.rst [new file with mode: 0644]
doc/build/orm/dataclasses.rst
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_dc_transforms.py

diff --git a/doc/build/changelog/unreleased_20/9266.rst b/doc/build/changelog/unreleased_20/9266.rst
new file mode 100644 (file)
index 0000000..4fd79dd
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: usecase, orm declarative
+    :tickets: 9266
+
+    Added new parameter ``dataclass_callable`` to both the
+    :class:`_orm.MappedAsDataclass` class as well as the
+    :meth:`_orm.registry.mapped_as_dataclass` method which allows an
+    alternative callable to Python ``dataclasses.dataclass`` to be used in
+    order to produce dataclasses. The use case here is to drop in Pydantic's
+    dataclass function instead. Adjustments have been made to the mixin support
+    added for :ticket:`9179` in version 2.0.1 so that the ``__annotations__``
+    collection of the mixin is rewritten to not include the
+    :class:`_orm.Mapped` container, in the same way as occurs with mapped
+    classes, so that the Pydantic dataclasses constructor is not exposed to
+    unknown types.
+
+    .. seealso::
+
+        :ref:`dataclasses_pydantic`
+
index e98c67e68942442437d344cafe0e64a87dfbd31a..4bb74cbb221eac241b81a7f4ce8ba530a6f3e420 100644 (file)
@@ -471,6 +471,54 @@ variable may be generated::
    even though the purpose of this attribute was only to allow legacy
    ORM typed mappings to continue to function.
 
+.. _dataclasses_pydantic:
+
+Integrating with Alternate Dataclass Providers such as Pydantic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+SQLAlchemy's :class:`_orm.MappedAsDataclass` class
+and :meth:`_orm.registry.mapped_as_dataclass` method call directly into
+the Python standard library ``dataclasses.dataclass`` class decorator, after
+the declarative mapping process has been applied to the class.  This
+function call may be swapped out for alternative dataclass providers,
+such as that of Pydantic, using the ``dataclass_callable`` parameter
+accepted by :class:`_orm.MappedAsDataclass` as a class keyword argument
+as well as by :meth:`_orm.registry.mapped_as_dataclass`::
+
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+    from sqlalchemy.orm import MappedAsDataclass
+    from sqlalchemy.orm import registry
+
+
+    class Base(
+        MappedAsDataclass,
+        DeclarativeBase,
+        dataclass_callable=pydantic.dataclasses.dataclass,
+    ):
+        pass
+
+
+    class User(Base):
+        __tablename__ = "user"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str]
+
+The above ``User`` class will be applied as a dataclass, using Pydantic's
+``pydantic.dataclasses.dataclass`` callable.  The process is available
+both for mapped classes as well as mixins that extend from
+:class:`_orm.MappedAsDataclass` or which have
+:meth:`_orm.registry.mapped_as_dataclass` applied directly.
+
+.. versionadded:: 2.0.4 Added the ``dataclass_callable`` class and method
+   parameters for :class:`_orm.MappedAsDataclass` and
+   :meth:`_orm.registry.mapped_as_dataclass`, and adjusted some of the
+   dataclass internals to accommodate more strict dataclass functions such as
+   that of Pydantic.
+
+
 .. _orm_declarative_dataclasses:
 
 Applying ORM Mappings to an existing dataclass
index d02012b86b927d7e43b294b59e66ac67193ae049..f332d296460e347f625fc6496c7029df3db23a71 100644 (file)
@@ -593,6 +593,9 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
         unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
         match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
         kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        dataclass_callable: Union[
+            _NoArg, Callable[..., Type[Any]]
+        ] = _NoArg.NO_ARG,
     ) -> None:
         apply_dc_transforms: _DataclassArguments = {
             "init": init,
@@ -602,6 +605,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
             "unsafe_hash": unsafe_hash,
             "match_args": match_args,
             "kw_only": kw_only,
+            "dataclass_callable": dataclass_callable,
         }
 
         current_transforms: _DataclassArguments
@@ -623,8 +627,11 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
         super().__init_subclass__()
 
         if not _is_mapped_class(cls):
+            new_anno = (
+                _ClassScanMapperConfig._update_annotations_for_non_mapped_class
+            )(cls)
             _ClassScanMapperConfig._apply_dataclasses_to_any_class(
-                current_transforms, cls
+                current_transforms, cls, new_anno
             )
 
 
@@ -1569,6 +1576,7 @@ class registry:
         unsafe_hash: Union[_NoArg, bool] = ...,
         match_args: Union[_NoArg, bool] = ...,
         kw_only: Union[_NoArg, bool] = ...,
+        dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ...,
     ) -> Callable[[Type[_O]], Type[_O]]:
         ...
 
@@ -1583,6 +1591,9 @@ class registry:
         unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
         match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
         kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        dataclass_callable: Union[
+            _NoArg, Callable[..., Type[Any]]
+        ] = _NoArg.NO_ARG,
     ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
         """Class decorator that will apply the Declarative mapping process
         to a given class, and additionally convert the class to be a
@@ -1608,6 +1619,7 @@ class registry:
                 "unsafe_hash": unsafe_hash,
                 "match_args": match_args,
                 "kw_only": kw_only,
+                "dataclass_callable": dataclass_callable,
             }
             _as_declarative(self, cls, cls.__dict__)
             return cls
index aeed9b4395160b8985ee0409e90c1194a9d63ff9..f0be55b8923da1c40a67bac391318c5b183132cf 100644 (file)
@@ -126,6 +126,7 @@ class _DataclassArguments(TypedDict):
     unsafe_hash: Union[_NoArg, bool]
     match_args: Union[_NoArg, bool]
     kw_only: Union[_NoArg, bool]
+    dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]]
 
 
 def _declared_mapping_info(
@@ -1099,26 +1100,81 @@ class _ClassScanMapperConfig(_MapperConfig):
         for k, v in defaults.items():
             setattr(self.cls, k, v)
 
-        self.cls.__annotations__ = annotations
-
         self._apply_dataclasses_to_any_class(
-            dataclass_setup_arguments, self.cls
+            dataclass_setup_arguments, self.cls, annotations
         )
 
+    @classmethod
+    def _update_annotations_for_non_mapped_class(
+        cls, klass: Type[_O]
+    ) -> Mapping[str, _AnnotationScanType]:
+        cls_annotations = util.get_annotations(klass)
+
+        new_anno = {}
+        for name, annotation in cls_annotations.items():
+            if _is_mapped_annotation(annotation, klass, klass):
+
+                extracted = _extract_mapped_subtype(
+                    annotation,
+                    klass,
+                    klass.__module__,
+                    name,
+                    type(None),
+                    required=False,
+                    is_dataclass_field=False,
+                    expect_mapped=False,
+                )
+                if extracted:
+                    inner, _ = extracted
+                    new_anno[name] = inner
+            else:
+                new_anno[name] = annotation
+        return new_anno
+
     @classmethod
     def _apply_dataclasses_to_any_class(
-        cls, dataclass_setup_arguments: _DataclassArguments, klass: Type[_O]
+        cls,
+        dataclass_setup_arguments: _DataclassArguments,
+        klass: Type[_O],
+        use_annotations: Mapping[str, _AnnotationScanType],
     ) -> None:
         cls._assert_dc_arguments(dataclass_setup_arguments)
 
-        dataclasses.dataclass(
-            klass,
-            **{
-                k: v
-                for k, v in dataclass_setup_arguments.items()
-                if v is not _NoArg.NO_ARG
-            },
-        )
+        dataclass_callable = dataclass_setup_arguments["dataclass_callable"]
+        if dataclass_callable is _NoArg.NO_ARG:
+            dataclass_callable = dataclasses.dataclass
+
+        restored: Optional[Any]
+
+        if use_annotations:
+            # apply constructed annotations that should look "normal" to a
+            # dataclasses callable, based on the fields present.  This
+            # means remove the Mapped[] container and ensure all Field
+            # entries have an annotation
+            restored = getattr(klass, "__annotations__", None)
+            klass.__annotations__ = cast("Dict[str, Any]", use_annotations)
+        else:
+            restored = None
+
+        try:
+            dataclass_callable(
+                klass,
+                **{
+                    k: v
+                    for k, v in dataclass_setup_arguments.items()
+                    if v is not _NoArg.NO_ARG and k != "dataclass_callable"
+                },
+            )
+        finally:
+            # restore original annotations outside of the dataclasses
+            # process; for mixins and __abstract__ superclasses, SQLAlchemy
+            # Declarative will need to see the Mapped[] container inside the
+            # annotations in order to map subclasses
+            if use_annotations:
+                if restored is None:
+                    del klass.__annotations__
+                else:
+                    klass.__annotations__ = restored
 
     @classmethod
     def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None:
@@ -1130,6 +1186,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             "unsafe_hash",
             "kw_only",
             "match_args",
+            "dataclass_callable",
         }
         disallowed_args = set(arguments).difference(allowed)
         if disallowed_args:
index 9b7c72778b3a0cde4f2812018b9f278fd1d07090..5abcaa46efc1816391ee7396f92cca31d3fdd722 100644 (file)
@@ -37,6 +37,7 @@ from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import MappedColumn
+from sqlalchemy.orm import registry
 from sqlalchemy.orm import registry as _RegistryType
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
@@ -51,6 +52,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing import Variation
 from sqlalchemy.util import compat
 
 
@@ -261,6 +263,45 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         a3 = A("data")
         eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
 
+    @testing.variation("dc_type", ["decorator", "superclass"])
+    def test_dataclass_fn(self, dc_type: Variation):
+        annotations = {}
+
+        def dc_callable(kls, **kw) -> Type[Any]:
+            annotations[kls] = kls.__annotations__
+            return dataclasses.dataclass(kls, **kw)  # type: ignore
+
+        if dc_type.decorator:
+            reg = registry()
+
+            @reg.mapped_as_dataclass(dataclass_callable=dc_callable)
+            class MappedClass:
+                __tablename__ = "mapped_class"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                name: Mapped[str]
+
+            eq_(annotations, {MappedClass: {"id": int, "name": str}})
+
+        elif dc_type.superclass:
+
+            class Base(DeclarativeBase):
+                pass
+
+            class Mixin(MappedAsDataclass, dataclass_callable=dc_callable):
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+            class MappedClass(Mixin, Base):
+                __tablename__ = "mapped_class"
+                name: Mapped[str]
+
+            eq_(
+                annotations,
+                {Mixin: {"id": int}, MappedClass: {"id": int, "name": str}},
+            )
+        else:
+            dc_type.fail()
+
     def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]):
         class A(dc_decl_base):
             __tablename__ = "a"
@@ -978,10 +1019,16 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
         eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)")
 
     def test_abstract_is_dc(self):
+        collected_annotations = {}
+
+        def check_args(cls, **kw):
+            collected_annotations[cls] = cls.__annotations__
+            return dataclasses.dataclass(cls, **kw)
+
         class Parent(DeclarativeBase):
             a: int
 
-        class Mixin(MappedAsDataclass, Parent):
+        class Mixin(MappedAsDataclass, Parent, dataclass_callable=check_args):
             __abstract__ = True
             b: int
 
@@ -989,6 +1036,42 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
             __tablename__ = "child"
             c: Mapped[int] = mapped_column(primary_key=True)
 
+        eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}})
+        eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)")
+
+    @testing.variation("check_annotations", [True, False])
+    def test_abstract_is_dc_w_mapped(self, check_annotations):
+        if check_annotations:
+            collected_annotations = {}
+
+            def check_args(cls, **kw):
+                collected_annotations[cls] = cls.__annotations__
+                return dataclasses.dataclass(cls, **kw)
+
+            class_kw = {"dataclass_callable": check_args}
+        else:
+            class_kw = {}
+
+        class Parent(DeclarativeBase):
+            a: int
+
+        class Mixin(MappedAsDataclass, Parent, **class_kw):
+            __abstract__ = True
+            b: Mapped[int] = mapped_column()
+
+        class Child(Mixin):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        if check_annotations:
+            # note: current dataclasses process adds Field() object to Child
+            # based on attributes which include those from Mixin.  This means
+            # the annotations of Child are also augmented while we do
+            # dataclasses collection.
+            eq_(
+                collected_annotations,
+                {Mixin: {"b": int}, Child: {"b": int, "c": int}},
+            )
         eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)")
 
     def test_mixin_and_base_is_dc(self):
@@ -1023,14 +1106,36 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
         "dataclass_scope",
         ["on_base", "on_mixin", "on_base_class", "on_sub_class"],
     )
-    def test_mixin_w_inheritance(self, dataclass_scope):
+    @testing.variation(
+        "test_alternative_callable",
+        [True, False],
+    )
+    def test_mixin_w_inheritance(
+        self, dataclass_scope, test_alternative_callable
+    ):
         """test #9226"""
 
+        expected_annotations = {}
+
+        if test_alternative_callable:
+            collected_annotations = {}
+
+            def check_args(cls, **kw):
+                collected_annotations[cls] = getattr(
+                    cls, "__annotations__", {}
+                )
+                return dataclasses.dataclass(cls, **kw)
+
+            klass_kw = {"dataclass_callable": check_args}
+        else:
+            klass_kw = {}
+
         if dataclass_scope.on_base:
 
-            class Base(DeclarativeBase, MappedAsDataclass):
+            class Base(MappedAsDataclass, DeclarativeBase, **klass_kw):
                 pass
 
+            expected_annotations[Base] = {}
         else:
 
             class Base(DeclarativeBase):
@@ -1038,7 +1143,7 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
 
         if dataclass_scope.on_mixin:
 
-            class Mixin(MappedAsDataclass):
+            class Mixin(MappedAsDataclass, **klass_kw):
                 @declared_attr.directive
                 @classmethod
                 def __tablename__(cls) -> str:
@@ -1061,6 +1166,7 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
                         init=False,
                     )
 
+            expected_annotations[Mixin] = {}
         else:
 
             class Mixin:
@@ -1100,7 +1206,7 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
 
         if dataclass_scope.on_base_class:
 
-            class Book(Mixin, MappedAsDataclass, Base):
+            class Book(Mixin, MappedAsDataclass, Base, **klass_kw):
                 id: Mapped[int] = mapped_column(
                     Integer,
                     primary_key=True,
@@ -1120,9 +1226,12 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
                         primary_key=True,
                     )
 
+        if MappedAsDataclass in Book.__mro__:
+            expected_annotations[Book] = {"id": int, "polymorphic_type": str}
+
         if dataclass_scope.on_sub_class:
 
-            class Novel(MappedAsDataclass, Book):
+            class Novel(MappedAsDataclass, Book, **klass_kw):
                 id: Mapped[int] = mapped_column(  # noqa: A001
                     ForeignKey("book.id"),
                     primary_key=True,
@@ -1140,6 +1249,11 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase):
                 )
                 description: Mapped[Optional[str]]
 
+        expected_annotations[Novel] = {"id": int, "description": Optional[str]}
+
+        if test_alternative_callable:
+            eq_(collected_annotations, expected_annotations)
+
         n1 = Novel("the description")
         eq_(n1.description, "the description")
 
@@ -1210,7 +1324,6 @@ class DataclassArgsTest(fixtures.TestBase):
                 x: Mapped[int] = mapped_expr_constructor
 
     def _assert_cls(self, cls, dc_arguments):
-
         if dc_arguments["init"]:
 
             def create(data, x):
@@ -1335,7 +1448,6 @@ class DataclassArgsTest(fixtures.TestBase):
         eq_(a3.x, 7)
 
     def _assert_not_init(self, cls, create, dc_arguments):
-
         with expect_raises(TypeError):
             cls("Some data", 5)
 
@@ -1579,18 +1691,20 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
     @testing.fixture
     def model(self):
-        def go(use_mixin, use_inherits, mad_setup):
-
+        def go(use_mixin, use_inherits, mad_setup, dataclass_kw):
             if use_mixin:
-
                 if mad_setup == "dc, mad":
 
-                    class BaseEntity(DeclarativeBase, MappedAsDataclass):
+                    class BaseEntity(
+                        DeclarativeBase, MappedAsDataclass, **dataclass_kw
+                    ):
                         pass
 
                 elif mad_setup == "mad, dc":
 
-                    class BaseEntity(MappedAsDataclass, DeclarativeBase):
+                    class BaseEntity(
+                        MappedAsDataclass, DeclarativeBase, **dataclass_kw
+                    ):
                         pass
 
                 elif mad_setup == "subclass":
@@ -1605,7 +1719,9 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
                 if mad_setup == "subclass":
 
-                    class A(IdMixin, MappedAsDataclass, BaseEntity):
+                    class A(
+                        IdMixin, MappedAsDataclass, BaseEntity, **dataclass_kw
+                    ):
                         __mapper_args__ = {
                             "polymorphic_on": "type",
                             "polymorphic_identity": "a",
@@ -1628,17 +1744,20 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                         data: Mapped[str] = mapped_column(String, init=False)
 
             else:
-
                 if mad_setup == "dc, mad":
 
-                    class BaseEntity(DeclarativeBase, MappedAsDataclass):
+                    class BaseEntity(
+                        DeclarativeBase, MappedAsDataclass, **dataclass_kw
+                    ):
                         id: Mapped[int] = mapped_column(
                             primary_key=True, init=False
                         )
 
                 elif mad_setup == "mad, dc":
 
-                    class BaseEntity(MappedAsDataclass, DeclarativeBase):
+                    class BaseEntity(
+                        MappedAsDataclass, DeclarativeBase, **dataclass_kw
+                    ):
                         id: Mapped[int] = mapped_column(
                             primary_key=True, init=False
                         )
@@ -1652,7 +1771,7 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
                 if mad_setup == "subclass":
 
-                    class A(MappedAsDataclass, BaseEntity):
+                    class A(MappedAsDataclass, BaseEntity, **dataclass_kw):
                         __mapper_args__ = {
                             "polymorphic_on": "type",
                             "polymorphic_identity": "a",
@@ -1698,6 +1817,7 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             use_inherits=use_inherits == "inherits",
             use_mixin=use_mixin == "mixin",
             mad_setup=mad_setup,
+            dataclass_kw={},
         )
 
         obj = target_cls()