]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
MappedAsDataclass applies @dataclasses.dataclass unconditionally
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Jan 2023 18:28:42 +0000 (13:28 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Jan 2023 20:40:20 +0000 (15:40 -0500)
When using the :class:`.MappedAsDataclass` superclass, all classes within
the hierarchy that are subclasses of this class will now be run through the
``@dataclasses.dataclass`` function whether or not they are actually
mapped, so that non-ORM fields declared on non-mapped classes within the
hierarchy will be used when mapped subclasses are turned into dataclasses.
This behavior applies both to intermediary classes mapped with
``__abstract__ = True`` as well as to the user-defined declarative base
itself, assuming :class:`.MappedAsDataclass` is present as a superclass for
these classes.

This allows non-mapped attributes such as ``InitVar`` declarations on
superclasses to be used, without the need to run the
``@dataclasses.dataclass`` decorator explicitly on each non-mapped class.
The new behavior is considered as correct as this is what the :pep:`681`
implementation expects when using a superclass to indicate dataclass
behavior.

Fixes: #9179
Change-Id: Ia01fa9806a27f7c1121bf7eaddf2847cf6dc5313

doc/build/changelog/unreleased_20/9179.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_dc_transforms.py

diff --git a/doc/build/changelog/unreleased_20/9179.rst b/doc/build/changelog/unreleased_20/9179.rst
new file mode 100644 (file)
index 0000000..489812c
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9179
+
+    When using the :class:`.MappedAsDataclass` superclass, all classes within
+    the hierarchy that are subclasses of this class will now be run through the
+    ``@dataclasses.dataclass`` function whether or not they are actually
+    mapped, so that non-ORM fields declared on non-mapped classes within the
+    hierarchy will be used when mapped subclasses are turned into dataclasses.
+    This behavior applies both to intermediary classes mapped with
+    ``__abstract__ = True`` as well as to the user-defined declarative base
+    itself, assuming :class:`.MappedAsDataclass` is present as a superclass for
+    these classes.
+
+    This allows non-mapped attributes such as ``InitVar`` declarations on
+    superclasses to be used, without the need to run the
+    ``@dataclasses.dataclass`` decorator explicitly on each non-mapped class.
+    The new behavior is considered as correct as this is what the :pep:`681`
+    implementation expects when using a superclass to indicate dataclass
+    behavior.
index ecd81c7ed50be1bd7c2f3e322f21805bd0fef1b4..c0bc3700819b00f74ad225aa6815c219ada9ae31 100644 (file)
@@ -45,6 +45,7 @@ from ._orm_constructors import relationship
 from ._orm_constructors import synonym
 from .attributes import InstrumentedAttribute
 from .base import _inspect_mapped_class
+from .base import _is_mapped_class
 from .base import Mapped
 from .decl_base import _add_attribute
 from .decl_base import _as_declarative
@@ -596,20 +597,29 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
             "kw_only": kw_only,
         }
 
+        current_transforms: _DataclassArguments
+
         if hasattr(cls, "_sa_apply_dc_transforms"):
             current = cls._sa_apply_dc_transforms  # type: ignore[attr-defined]
 
             _ClassScanMapperConfig._assert_dc_arguments(current)
 
-            cls._sa_apply_dc_transforms = {
+            cls._sa_apply_dc_transforms = current_transforms = {  # type: ignore  # noqa: E501
                 k: current.get(k, _NoArg.NO_ARG) if v is _NoArg.NO_ARG else v
                 for k, v in apply_dc_transforms.items()
             }
         else:
-            cls._sa_apply_dc_transforms = apply_dc_transforms
+            cls._sa_apply_dc_transforms = (
+                current_transforms
+            ) = apply_dc_transforms
 
         super().__init_subclass__()
 
+        if not _is_mapped_class(cls):
+            _ClassScanMapperConfig._apply_dataclasses_to_any_class(
+                current_transforms, cls
+            )
+
 
 class DeclarativeBase(
     inspection.Inspectable[InstanceState[Any]],
index 9e8b02359735ff0f9524c2cdfcc6ee825d6e3ef4..0462a894560be714fa7d58710332b17b03d31dd8 100644 (file)
@@ -1078,10 +1078,18 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         self.cls.__annotations__ = annotations
 
-        self._assert_dc_arguments(dataclass_setup_arguments)
+        self._apply_dataclasses_to_any_class(
+            dataclass_setup_arguments, self.cls
+        )
+
+    @classmethod
+    def _apply_dataclasses_to_any_class(
+        cls, dataclass_setup_arguments: _DataclassArguments, klass: Type[_O]
+    ) -> None:
+        cls._assert_dc_arguments(dataclass_setup_arguments)
 
         dataclasses.dataclass(
-            self.cls,
+            klass,
             **{
                 k: v
                 for k, v in dataclass_setup_arguments.items()
index 5f35d7a01cfeb26ec8083b1ade4e4e7a92bdb453..63450f4a170f0e510717f807aff26055023f7ce7 100644 (file)
@@ -823,6 +823,167 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase):
                 )
 
 
+class DataclassesForNonMappedClassesTest(fixtures.TestBase):
+    """test for cases added in #9179"""
+
+    def test_base_is_dc(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: int
+
+        class Child(Parent):
+            __tablename__ = "child"
+            b: Mapped[int] = mapped_column(primary_key=True)
+
+        eq_regex(repr(Child(5, 6)), r".*\.Child\(a=5, b=6\)")
+
+    def test_base_is_dc_plus_options(self):
+        class Parent(MappedAsDataclass, DeclarativeBase, unsafe_hash=True):
+            a: int
+
+        class Child(Parent, repr=False):
+            __tablename__ = "child"
+            b: Mapped[int] = mapped_column(primary_key=True)
+
+        c1 = Child(5, 6)
+        eq_(hash(c1), hash(Child(5, 6)))
+
+        # still reprs, because base has a repr, but b not included
+        eq_regex(repr(c1), r".*\.Child\(a=5\)")
+
+    def test_base_is_dc_init_var(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: InitVar[int]
+
+        class Child(Parent):
+            __tablename__ = "child"
+            b: Mapped[int] = mapped_column(primary_key=True)
+
+        c1 = Child(a=5, b=6)
+        eq_regex(repr(c1), r".*\.Child\(b=6\)")
+
+    def test_base_is_dc_field(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: int = dataclasses.field(default=10)
+
+        class Child(Parent):
+            __tablename__ = "child"
+            b: Mapped[int] = mapped_column(primary_key=True, default=7)
+
+        c1 = Child(a=5, b=6)
+        eq_regex(repr(c1), r".*\.Child\(a=5, b=6\)")
+
+        c1 = Child(b=6)
+        eq_regex(repr(c1), r".*\.Child\(a=10, b=6\)")
+
+        c1 = Child()
+        eq_regex(repr(c1), r".*\.Child\(a=10, b=7\)")
+
+    def test_abstract_and_base_is_dc(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: int
+
+        class Mixin(Parent):
+            __abstract__ = True
+            b: int
+
+        class Child(Mixin):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6, c=7\)")
+
+    def test_abstract_and_base_is_dc_plus_options(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: int
+
+        class Mixin(Parent, unsafe_hash=True):
+            __abstract__ = True
+            b: int
+
+        class Child(Mixin, repr=False):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        eq_(hash(Child(5, 6, 7)), hash(Child(5, 6, 7)))
+
+        eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6\)")
+
+    def test_abstract_and_base_is_dc_init_var(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: InitVar[int]
+
+        class Mixin(Parent):
+            __abstract__ = True
+            b: InitVar[int]
+
+        class Child(Mixin):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        c1 = Child(a=5, b=6, c=7)
+        eq_regex(repr(c1), r".*\.Child\(c=7\)")
+
+    def test_abstract_and_base_is_dc_field(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: int = dataclasses.field(default=10)
+
+        class Mixin(Parent):
+            __abstract__ = True
+            b: int = dataclasses.field(default=7)
+
+        class Child(Mixin):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True, default=9)
+
+        c1 = Child(b=6, c=7)
+        eq_regex(repr(c1), r".*\.Child\(a=10, b=6, c=7\)")
+
+        c1 = Child()
+        eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)")
+
+    def test_abstract_is_dc(self):
+        class Parent(DeclarativeBase):
+            a: int
+
+        class Mixin(MappedAsDataclass, Parent):
+            __abstract__ = True
+            b: int
+
+        class Child(Mixin):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)")
+
+    def test_mixin_and_base_is_dc(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: int
+
+        @dataclasses.dataclass
+        class Mixin:
+            b: int
+
+        class Child(Mixin, Parent):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6, c=7\)")
+
+    def test_mixin_and_base_is_dc_init_var(self):
+        class Parent(MappedAsDataclass, DeclarativeBase):
+            a: InitVar[int]
+
+        @dataclasses.dataclass
+        class Mixin:
+            b: InitVar[int]
+
+        class Child(Mixin, Parent):
+            __tablename__ = "child"
+            c: Mapped[int] = mapped_column(primary_key=True)
+
+        eq_regex(repr(Child(a=5, b=6, c=7)), r".*\.Child\(c=7\)")
+
+
 class DataclassArgsTest(fixtures.TestBase):
     dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
     if compat.py310:
@@ -986,12 +1147,17 @@ class DataclassArgsTest(fixtures.TestBase):
             create("g", 10) >= create("b", 7)
 
     def _assert_repr(self, cls, create, dc_arguments):
+        assert "__repr__" in cls.__dict__
         a1 = create("some data", 12)
         eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)")
 
     def _assert_not_repr(self, cls, create, dc_arguments):
-        a1 = create("some data", 12)
-        eq_regex(repr(a1), r"<.*A object at 0x.*>")
+        assert "__repr__" not in cls.__dict__
+
+        # if a superclass has __repr__, then we still get repr.
+        # so can't test this
+        # a1 = create("some data", 12)
+        # eq_regex(repr(a1), r"<.*A object at 0x.*>")
 
     def _assert_init(self, cls, create, dc_arguments):
         if not dc_arguments.get("kw_only", False):