]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement declared_attr superclass assignment check for dataclasses
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Apr 2021 01:45:10 +0000 (21:45 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Apr 2021 15:02:48 +0000 (11:02 -0400)
Adjusted the declarative scan for dataclasses so that the inheritance
behavior of :func:`_orm.declared_attr` established on a mixin, when using
the new form of having it inside of a ``dataclasses.field()`` construct and
not actually a descriptor attribute on the class, correctly accommodates
the case when the target class to be mapped is a subclass of an existing
mapped class which has already mapped that :func:`_orm.declared_attr`, and
therefore should not be re-applied to this class.

Also, as changed in ed3f2c617239668d we now have an "is_dataclass"
boolean set as we iterate through attrs so we can remove this
from declared_attr.

Fixes: #6346
Change-Id: Iec75bdefd3bff7d8a9a157c8dd744ac14ff15ea8

doc/build/changelog/unreleased_14/6346.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
test/orm/test_dataclasses_py3k.py

diff --git a/doc/build/changelog/unreleased_14/6346.rst b/doc/build/changelog/unreleased_14/6346.rst
new file mode 100644 (file)
index 0000000..4ca26b7
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm, dataclasses
+    :tickets: 6346
+
+      Adjusted the declarative scan for dataclasses so that the inheritance
+      behavior of :func:`_orm.declared_attr` established on a mixin, when using
+      the new form of having it inside of a ``dataclasses.field()`` construct and
+      not actually a descriptor attribute on the class, correctly accommodates
+      the case when the target class to be mapped is a subclass of an existing
+      mapped class which has already mapped that :func:`_orm.declared_attr`, and
+      therefore should not be re-applied to this class.
+
index d9c464815b7c17e987a0251da8b4676eba0fb25d..4e2c3a88601690bee904b8e921950ea73e83019d 100644 (file)
@@ -205,11 +205,10 @@ class declared_attr(interfaces._MappedAttribute, property):
 
     """  # noqa E501
 
-    def __init__(self, fget, cascading=False, _is_dataclass=False):
+    def __init__(self, fget, cascading=False):
         super(declared_attr, self).__init__(fget)
         self.__doc__ = fget.__doc__
         self._cascading = cascading
-        self._is_dataclass = _is_dataclass
 
     def __get__(desc, self, cls):
         # the declared_attr needs to make use of a cache that exists
index b3444f26f442e133d13223a5555fd6c2232bcf34..f52827ad13df1072f267425a20ef8c038619bb9c 100644 (file)
@@ -536,11 +536,24 @@ class _ClassScanMapperConfig(_MapperConfig):
                             ] = ret = obj.__get__(obj, cls)
                             setattr(cls, name, ret)
                         else:
-                            if obj._is_dataclass:
-                                ret = obj.fget()
-                            else:
-
+                            if is_dataclass:
                                 # access attribute using normal class access
+                                # first, to see if it's been mapped on a
+                                # superclass.   note if the dataclasses.field()
+                                # has "default", this value can be anything.
+                                ret = getattr(cls, name, None)
+
+                                # so, if it's anything that's not ORM
+                                # mapped, assume we should invoke the
+                                # declared_attr
+                                if not isinstance(ret, InspectionAttr):
+                                    ret = obj.fget()
+                            else:
+                                # access attribute using normal class access.
+                                # if the declared attr already took place
+                                # on a superclass that is mapped, then
+                                # this is no longer a declared_attr, it will
+                                # be the InstrumentedAttribute
                                 ret = getattr(cls, name)
 
                             # correct for proxies created from hybrid_property
@@ -988,11 +1001,9 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
     decl_api = util.preloaded.orm_decl_api
     obj = field_metadata[sa_dataclass_metadata_key]
     if callable(obj) and not isinstance(obj, decl_api.declared_attr):
-        return decl_api.declared_attr(obj, _is_dataclass=True)
-    elif isinstance(obj, decl_api.declared_attr):
-        obj._is_dataclass = True
+        return decl_api.declared_attr(obj)
+    else:
         return obj
-    return obj
 
 
 class _DeferredMapperConfig(_ClassScanMapperConfig):
index 56091505c3ff40b11696474149bc92b977573dda..2debc3ddf421663dc2875b348b0ab3dc5f12c271 100644 (file)
@@ -29,9 +29,6 @@ except ImportError:
 class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
     __requires__ = ("dataclasses",)
 
-    run_setup_classes = "each"
-    run_setup_mappers = "each"
-
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -525,9 +522,6 @@ class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
 class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest):
     __requires__ = ("dataclasses",)
 
-    run_setup_classes = "each"
-    run_setup_mappers = "each"
-
     @classmethod
     def setup_classes(cls):
         declarative = cls.DeclarativeBasic.registry.mapped
@@ -554,6 +548,11 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest):
                 },
             )
 
+            has_a_default: str = dataclasses.field(
+                default="some default",
+                metadata={"sa": lambda: Column(String(50))},
+            )
+
         @declarative
         @dataclasses.dataclass
         class Widget(WidgetDC):
@@ -566,11 +565,35 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest):
                 default=None,
                 metadata={"sa": Column(String(30), nullable=False)},
             )
+
             __mapper_args__ = dict(
                 polymorphic_on="type",
                 polymorphic_identity="normal",
             )
 
+        @declarative
+        @dataclasses.dataclass
+        class SpecialWidget(Widget):
+            __tablename__ = "special_widgets"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            special_widget_id: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": Column(
+                        ForeignKey("widgets.widget_id"), primary_key=True
+                    )
+                },
+            )
+
+            magic: bool = dataclasses.field(
+                default=False, metadata={"sa": Column(Boolean)}
+            )
+
+            __mapper_args__ = dict(
+                polymorphic_identity="special",
+            )
+
         @dataclasses.dataclass
         class AccountDC:
 
@@ -631,9 +654,12 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest):
         cls.classes["Account"] = Account
         cls.classes["Widget"] = Widget
         cls.classes["User"] = User
+        cls.classes["SpecialWidget"] = SpecialWidget
 
     def test_setup(self):
-        Account, Widget, User = self.classes("Account", "Widget", "User")
+        Account, Widget, User, SpecialWidget = self.classes(
+            "Account", "Widget", "User", "SpecialWidget"
+        )
 
         assert "account_id" in Widget.__table__.c
         assert list(Widget.__table__.c.account_id.foreign_keys)[0].references(
@@ -641,11 +667,35 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest):
         )
         assert inspect(Account).relationships.widgets.mapper is inspect(Widget)
 
+        assert "account_id" not in SpecialWidget.__table__.c
+
+        assert "has_a_default" in Widget.__table__.c
+        assert "has_a_default" not in SpecialWidget.__table__.c
+
         assert "account_id" in User.__table__.c
         assert list(User.__table__.c.account_id.foreign_keys)[0].references(
             Account.__table__
         )
 
+    def test_asdict_and_astuple_special_widget(self):
+        SpecialWidget = self.classes.SpecialWidget
+        widget = SpecialWidget(magic=True)
+        eq_(
+            dataclasses.asdict(widget),
+            {
+                "widget_id": None,
+                "account_id": None,
+                "has_a_default": "some default",
+                "name": None,
+                "special_widget_id": None,
+                "magic": True,
+            },
+        )
+        eq_(
+            dataclasses.astuple(widget),
+            (None, None, "some default", None, None, True),
+        )
+
 
 class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest):
     __requires__ = ("dataclasses",)
@@ -678,6 +728,11 @@ class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest):
                 },
             )
 
+            has_a_default: str = dataclasses.field(
+                default="some default",
+                metadata={"sa": declared_attr(lambda: Column(String(50)))},
+            )
+
         @declarative
         @dataclasses.dataclass
         class Widget(WidgetDC):
@@ -695,6 +750,29 @@ class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest):
                 polymorphic_identity="normal",
             )
 
+        @declarative
+        @dataclasses.dataclass
+        class SpecialWidget(Widget):
+            __tablename__ = "special_widgets"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            special_widget_id: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": Column(
+                        ForeignKey("widgets.widget_id"), primary_key=True
+                    )
+                },
+            )
+
+            magic: bool = dataclasses.field(
+                default=False, metadata={"sa": Column(Boolean)}
+            )
+
+            __mapper_args__ = dict(
+                polymorphic_identity="special",
+            )
+
         @dataclasses.dataclass
         class AccountDC:
 
@@ -757,14 +835,12 @@ class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest):
         cls.classes["Account"] = Account
         cls.classes["Widget"] = Widget
         cls.classes["User"] = User
+        cls.classes["SpecialWidget"] = SpecialWidget
 
 
 class PropagationFromMixinTest(fixtures.TestBase):
     __requires__ = ("dataclasses",)
 
-    run_setup_classes = "each"
-    run_setup_mappers = "each"
-
     def test_propagate_w_plain_mixin_col(self, run_test):
         @dataclasses.dataclass
         class CommonMixin:
@@ -865,9 +941,6 @@ class PropagationFromMixinTest(fixtures.TestBase):
 class PropagationFromAbstractTest(fixtures.TestBase):
     __requires__ = ("dataclasses",)
 
-    run_setup_classes = "each"
-    run_setup_mappers = "each"
-
     def test_propagate_w_plain_mixin_col(self, run_test):
         @dataclasses.dataclass
         class BaseType: