From: Mike Bayer Date: Fri, 10 Feb 2023 21:06:23 +0000 (-0500) Subject: add dataclasses callable and apply annotations more strictly X-Git-Tag: rel_2_0_4~7^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=18fd19e60d55b35408d94b892e0a2051bcb7ec88;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add dataclasses callable and apply annotations more strictly Added new parameter ``dataclasses_callable`` to both the :class:`_orm.MappedAsDataclass` class as well as the :meth:`_orm.registry.mapped_as_dataclass` method which allows an alternative callable to Python ``dataclasses.dataclass`` to be used in order to produce dataclasses. The use case here is to drop in Pydantic's dataclass function instead. Adjustments have been made to the mixin support added for :ticket:`9179` in version 2.0.1 so that the ``__annotations__`` collection of the mixin is rewritten to not include the :class:`_orm.Mapped` container, in the same way as occurs with mapped classes, so that the Pydantic dataclasses constructor is not exposed to unknown types. Fixes: #9266 Change-Id: Ia0fab6f20b93a5cb853799dcf1b70a0386837c14 --- diff --git a/doc/build/changelog/unreleased_20/9266.rst b/doc/build/changelog/unreleased_20/9266.rst new file mode 100644 index 0000000000..4fd79dd380 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9266.rst @@ -0,0 +1,20 @@ +.. change:: + :tags: usecase, orm declarative + :tickets: 9266 + + Added new parameter ``dataclasses_callable`` to both the + :class:`_orm.MappedAsDataclass` class as well as the + :meth:`_orm.registry.mapped_as_dataclass` method which allows an + alternative callable to Python ``dataclasses.dataclass`` to be used in + order to produce dataclasses. The use case here is to drop in Pydantic's + dataclass function instead. Adjustments have been made to the mixin support + added for :ticket:`9179` in version 2.0.1 so that the ``__annotations__`` + collection of the mixin is rewritten to not include the + :class:`_orm.Mapped` container, in the same way as occurs with mapped + classes, so that the Pydantic dataclasses constructor is not exposed to + unknown types. + + .. seealso:: + + :ref:`dataclasses_pydantic` + diff --git a/doc/build/orm/dataclasses.rst b/doc/build/orm/dataclasses.rst index e98c67e689..4bb74cbb22 100644 --- a/doc/build/orm/dataclasses.rst +++ b/doc/build/orm/dataclasses.rst @@ -471,6 +471,54 @@ variable may be generated:: even though the purpose of this attribute was only to allow legacy ORM typed mappings to continue to function. +.. _dataclasses_pydantic: + +Integrating with Alternate Dataclass Providers such as Pydantic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQLAlchemy's :class:`_orm.MappedAsDataclass` class +and :meth:`_orm.registry.mapped_as_dataclass` method call directly into +the Python standard library ``dataclasses.dataclass`` class decorator, after +the declarative mapping process has been applied to the class. This +function call may be swapped out for alternateive dataclasses providers, +such as that of Pydantic, using the ``dataclass_callable`` parameter +accepted by :class:`_orm.MappedAsDataclass` as a class keyword argument +as well as by :meth:`_orm.registry.mapped_as_dataclass`:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import MappedAsDataclass + from sqlalchemy.orm import registry + + + class Base( + MappedAsDataclass, + DeclarativeBase, + dataclass_callable=pydantic.dataclasses.dataclass, + ): + pass + + + class User(Base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + +The above ``User`` class will be applied as a dataclass, using Pydantic's +``pydantic.dataclasses.dataclasses`` callable. The process is available +both for mapped classes as well as mixins that extend from +:class:`_orm.MappedAsDataclass` or which have +:meth:`_orm.registry.mapped_as_dataclass` applied directly. + +.. versionadded:: 2.0.4 Added the ``dataclass_callable`` class and method + parameters for :class:`_orm.MappedAsDataclass` and + :meth:`_orm.registry.mapped_as_dataclass`, and adjusted some of the + dataclass internals to accommodate more strict dataclass functions such as + that of Pydantic. + + .. _orm_declarative_dataclasses: Applying ORM Mappings to an existing dataclass diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index d02012b86b..f332d29646 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -593,6 +593,9 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, ) -> None: apply_dc_transforms: _DataclassArguments = { "init": init, @@ -602,6 +605,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): "unsafe_hash": unsafe_hash, "match_args": match_args, "kw_only": kw_only, + "dataclass_callable": dataclass_callable, } current_transforms: _DataclassArguments @@ -623,8 +627,11 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): super().__init_subclass__() if not _is_mapped_class(cls): + new_anno = ( + _ClassScanMapperConfig._update_annotations_for_non_mapped_class + )(cls) _ClassScanMapperConfig._apply_dataclasses_to_any_class( - current_transforms, cls + current_transforms, cls, new_anno ) @@ -1569,6 +1576,7 @@ class registry: unsafe_hash: Union[_NoArg, bool] = ..., match_args: Union[_NoArg, bool] = ..., kw_only: Union[_NoArg, bool] = ..., + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., ) -> Callable[[Type[_O]], Type[_O]]: ... @@ -1583,6 +1591,9 @@ class registry: unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: """Class decorator that will apply the Declarative mapping process to a given class, and additionally convert the class to be a @@ -1608,6 +1619,7 @@ class registry: "unsafe_hash": unsafe_hash, "match_args": match_args, "kw_only": kw_only, + "dataclass_callable": dataclass_callable, } _as_declarative(self, cls, cls.__dict__) return cls diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index aeed9b4395..f0be55b892 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -126,6 +126,7 @@ class _DataclassArguments(TypedDict): unsafe_hash: Union[_NoArg, bool] match_args: Union[_NoArg, bool] kw_only: Union[_NoArg, bool] + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] def _declared_mapping_info( @@ -1099,26 +1100,81 @@ class _ClassScanMapperConfig(_MapperConfig): for k, v in defaults.items(): setattr(self.cls, k, v) - self.cls.__annotations__ = annotations - self._apply_dataclasses_to_any_class( - dataclass_setup_arguments, self.cls + dataclass_setup_arguments, self.cls, annotations ) + @classmethod + def _update_annotations_for_non_mapped_class( + cls, klass: Type[_O] + ) -> Mapping[str, _AnnotationScanType]: + cls_annotations = util.get_annotations(klass) + + new_anno = {} + for name, annotation in cls_annotations.items(): + if _is_mapped_annotation(annotation, klass, klass): + + extracted = _extract_mapped_subtype( + annotation, + klass, + klass.__module__, + name, + type(None), + required=False, + is_dataclass_field=False, + expect_mapped=False, + ) + if extracted: + inner, _ = extracted + new_anno[name] = inner + else: + new_anno[name] = annotation + return new_anno + @classmethod def _apply_dataclasses_to_any_class( - cls, dataclass_setup_arguments: _DataclassArguments, klass: Type[_O] + cls, + dataclass_setup_arguments: _DataclassArguments, + klass: Type[_O], + use_annotations: Mapping[str, _AnnotationScanType], ) -> None: cls._assert_dc_arguments(dataclass_setup_arguments) - dataclasses.dataclass( - klass, - **{ - k: v - for k, v in dataclass_setup_arguments.items() - if v is not _NoArg.NO_ARG - }, - ) + dataclass_callable = dataclass_setup_arguments["dataclass_callable"] + if dataclass_callable is _NoArg.NO_ARG: + dataclass_callable = dataclasses.dataclass + + restored: Optional[Any] + + if use_annotations: + # apply constructed annotations that should look "normal" to a + # dataclasses callable, based on the fields present. This + # means remove the Mapped[] container and ensure all Field + # entries have an annotation + restored = getattr(klass, "__annotations__", None) + klass.__annotations__ = cast("Dict[str, Any]", use_annotations) + else: + restored = None + + try: + dataclass_callable( + klass, + **{ + k: v + for k, v in dataclass_setup_arguments.items() + if v is not _NoArg.NO_ARG and k != "dataclass_callable" + }, + ) + finally: + # restore original annotations outside of the dataclasses + # process; for mixins and __abstract__ superclasses, SQLAlchemy + # Declarative will need to see the Mapped[] container inside the + # annotations in order to map subclasses + if use_annotations: + if restored is None: + del klass.__annotations__ + else: + klass.__annotations__ = restored @classmethod def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None: @@ -1130,6 +1186,7 @@ class _ClassScanMapperConfig(_MapperConfig): "unsafe_hash", "kw_only", "match_args", + "dataclass_callable", } disallowed_args = set(arguments).difference(allowed) if disallowed_args: diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index 9b7c72778b..5abcaa46ef 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -37,6 +37,7 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import MappedColumn +from sqlalchemy.orm import registry from sqlalchemy.orm import registry as _RegistryType from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -51,6 +52,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ +from sqlalchemy.testing import Variation from sqlalchemy.util import compat @@ -261,6 +263,45 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + @testing.variation("dc_type", ["decorator", "superclass"]) + def test_dataclass_fn(self, dc_type: Variation): + annotations = {} + + def dc_callable(kls, **kw) -> Type[Any]: + annotations[kls] = kls.__annotations__ + return dataclasses.dataclass(kls, **kw) # type: ignore + + if dc_type.decorator: + reg = registry() + + @reg.mapped_as_dataclass(dataclass_callable=dc_callable) + class MappedClass: + __tablename__ = "mapped_class" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + eq_(annotations, {MappedClass: {"id": int, "name": str}}) + + elif dc_type.superclass: + + class Base(DeclarativeBase): + pass + + class Mixin(MappedAsDataclass, dataclass_callable=dc_callable): + id: Mapped[int] = mapped_column(primary_key=True) + + class MappedClass(Mixin, Base): + __tablename__ = "mapped_class" + name: Mapped[str] + + eq_( + annotations, + {Mixin: {"id": int}, MappedClass: {"id": int, "name": str}}, + ) + else: + dc_type.fail() + def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]): class A(dc_decl_base): __tablename__ = "a" @@ -978,10 +1019,16 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)") def test_abstract_is_dc(self): + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = cls.__annotations__ + return dataclasses.dataclass(cls, **kw) + class Parent(DeclarativeBase): a: int - class Mixin(MappedAsDataclass, Parent): + class Mixin(MappedAsDataclass, Parent, dataclass_callable=check_args): __abstract__ = True b: int @@ -989,6 +1036,42 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): __tablename__ = "child" c: Mapped[int] = mapped_column(primary_key=True) + eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}}) + eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + + @testing.variation("check_annotations", [True, False]) + def test_abstract_is_dc_w_mapped(self, check_annotations): + if check_annotations: + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = cls.__annotations__ + return dataclasses.dataclass(cls, **kw) + + class_kw = {"dataclass_callable": check_args} + else: + class_kw = {} + + class Parent(DeclarativeBase): + a: int + + class Mixin(MappedAsDataclass, Parent, **class_kw): + __abstract__ = True + b: Mapped[int] = mapped_column() + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + if check_annotations: + # note: current dataclasses process adds Field() object to Child + # based on attributes which include those from Mixin. This means + # the annotations of Child are also augmented while we do + # dataclasses collection. + eq_( + collected_annotations, + {Mixin: {"b": int}, Child: {"b": int, "c": int}}, + ) eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") def test_mixin_and_base_is_dc(self): @@ -1023,14 +1106,36 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): "dataclass_scope", ["on_base", "on_mixin", "on_base_class", "on_sub_class"], ) - def test_mixin_w_inheritance(self, dataclass_scope): + @testing.variation( + "test_alternative_callable", + [True, False], + ) + def test_mixin_w_inheritance( + self, dataclass_scope, test_alternative_callable + ): """test #9226""" + expected_annotations = {} + + if test_alternative_callable: + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = getattr( + cls, "__annotations__", {} + ) + return dataclasses.dataclass(cls, **kw) + + klass_kw = {"dataclass_callable": check_args} + else: + klass_kw = {} + if dataclass_scope.on_base: - class Base(DeclarativeBase, MappedAsDataclass): + class Base(MappedAsDataclass, DeclarativeBase, **klass_kw): pass + expected_annotations[Base] = {} else: class Base(DeclarativeBase): @@ -1038,7 +1143,7 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): if dataclass_scope.on_mixin: - class Mixin(MappedAsDataclass): + class Mixin(MappedAsDataclass, **klass_kw): @declared_attr.directive @classmethod def __tablename__(cls) -> str: @@ -1061,6 +1166,7 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): init=False, ) + expected_annotations[Mixin] = {} else: class Mixin: @@ -1100,7 +1206,7 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): if dataclass_scope.on_base_class: - class Book(Mixin, MappedAsDataclass, Base): + class Book(Mixin, MappedAsDataclass, Base, **klass_kw): id: Mapped[int] = mapped_column( Integer, primary_key=True, @@ -1120,9 +1226,12 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): primary_key=True, ) + if MappedAsDataclass in Book.__mro__: + expected_annotations[Book] = {"id": int, "polymorphic_type": str} + if dataclass_scope.on_sub_class: - class Novel(MappedAsDataclass, Book): + class Novel(MappedAsDataclass, Book, **klass_kw): id: Mapped[int] = mapped_column( # noqa: A001 ForeignKey("book.id"), primary_key=True, @@ -1140,6 +1249,11 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): ) description: Mapped[Optional[str]] + expected_annotations[Novel] = {"id": int, "description": Optional[str]} + + if test_alternative_callable: + eq_(collected_annotations, expected_annotations) + n1 = Novel("the description") eq_(n1.description, "the description") @@ -1210,7 +1324,6 @@ class DataclassArgsTest(fixtures.TestBase): x: Mapped[int] = mapped_expr_constructor def _assert_cls(self, cls, dc_arguments): - if dc_arguments["init"]: def create(data, x): @@ -1335,7 +1448,6 @@ class DataclassArgsTest(fixtures.TestBase): eq_(a3.x, 7) def _assert_not_init(self, cls, create, dc_arguments): - with expect_raises(TypeError): cls("Some data", 5) @@ -1579,18 +1691,20 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): @testing.fixture def model(self): - def go(use_mixin, use_inherits, mad_setup): - + def go(use_mixin, use_inherits, mad_setup, dataclass_kw): if use_mixin: - if mad_setup == "dc, mad": - class BaseEntity(DeclarativeBase, MappedAsDataclass): + class BaseEntity( + DeclarativeBase, MappedAsDataclass, **dataclass_kw + ): pass elif mad_setup == "mad, dc": - class BaseEntity(MappedAsDataclass, DeclarativeBase): + class BaseEntity( + MappedAsDataclass, DeclarativeBase, **dataclass_kw + ): pass elif mad_setup == "subclass": @@ -1605,7 +1719,9 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if mad_setup == "subclass": - class A(IdMixin, MappedAsDataclass, BaseEntity): + class A( + IdMixin, MappedAsDataclass, BaseEntity, **dataclass_kw + ): __mapper_args__ = { "polymorphic_on": "type", "polymorphic_identity": "a", @@ -1628,17 +1744,20 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): data: Mapped[str] = mapped_column(String, init=False) else: - if mad_setup == "dc, mad": - class BaseEntity(DeclarativeBase, MappedAsDataclass): + class BaseEntity( + DeclarativeBase, MappedAsDataclass, **dataclass_kw + ): id: Mapped[int] = mapped_column( primary_key=True, init=False ) elif mad_setup == "mad, dc": - class BaseEntity(MappedAsDataclass, DeclarativeBase): + class BaseEntity( + MappedAsDataclass, DeclarativeBase, **dataclass_kw + ): id: Mapped[int] = mapped_column( primary_key=True, init=False ) @@ -1652,7 +1771,7 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if mad_setup == "subclass": - class A(MappedAsDataclass, BaseEntity): + class A(MappedAsDataclass, BaseEntity, **dataclass_kw): __mapper_args__ = { "polymorphic_on": "type", "polymorphic_identity": "a", @@ -1698,6 +1817,7 @@ class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): use_inherits=use_inherits == "inherits", use_mixin=use_mixin == "mixin", mad_setup=mad_setup, + dataclass_kw={}, ) obj = target_cls()