]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fill-out dataclass-related attr resolution
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Jan 2021 22:59:35 +0000 (17:59 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Jan 2021 22:59:35 +0000 (17:59 -0500)
Fixed issue where mixin attribute rules were not taking
effect correctly for attributes pulled from dataclasses
using the approach added in #5745.

Fixes: #5876
Change-Id: I45099a42de1d9611791e72250fe0edc69bed684c

lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/orm/test_dataclasses_py3k.py

index db6d274c861794dd490d1bad4c0a7f1cfe5514e6..db7dfebe4a18adf573b5f7e029a9323b8483c562 100644 (file)
@@ -325,6 +325,94 @@ class _ClassScanMapperConfig(_MapperConfig):
             def before_configured():
                 self.cls.__declare_first__()
 
+    def _cls_attr_override_checker(self, cls):
+        """Produce a function that checks if a class has overridden an
+        attribute, taking SQLAlchemy-enabled dataclass fields into account.
+
+        """
+        sa_dataclass_metadata_key = _get_immediate_cls_attr(
+            cls, "__sa_dataclass_metadata_key__", None
+        )
+
+        if sa_dataclass_metadata_key is None:
+
+            def attribute_is_overridden(key, obj):
+                return getattr(cls, key) is not obj
+
+        else:
+
+            all_datacls_fields = {
+                f.name: f.metadata[sa_dataclass_metadata_key]
+                for f in util.dataclass_fields(cls)
+                if sa_dataclass_metadata_key in f.metadata
+            }
+            local_datacls_fields = {
+                f.name: f.metadata[sa_dataclass_metadata_key]
+                for f in util.local_dataclass_fields(cls)
+                if sa_dataclass_metadata_key in f.metadata
+            }
+
+            absent = object()
+
+            def attribute_is_overridden(key, obj):
+                # this function likely has some failure modes still if
+                # someone is doing a deep mixing of the same attribute
+                # name as plain Python attribute vs. dataclass field.
+
+                ret = local_datacls_fields.get(key, absent)
+
+                if ret is obj:
+                    return False
+                elif ret is not absent:
+                    return True
+
+                ret = getattr(cls, key, obj)
+
+                if ret is obj:
+                    return False
+                elif ret is not absent:
+                    return True
+
+                ret = all_datacls_fields.get(key, absent)
+
+                if ret is obj:
+                    return False
+                elif ret is not absent:
+                    return True
+
+                # can't find another attribute
+                return False
+
+        return attribute_is_overridden
+
+    def _cls_attr_resolver(self, cls):
+        """produce a function to iterate the "attributes" of a class,
+        adjusting for SQLAlchemy fields embedded in dataclass fields.
+
+        """
+        sa_dataclass_metadata_key = _get_immediate_cls_attr(
+            cls, "__sa_dataclass_metadata_key__", None
+        )
+
+        if sa_dataclass_metadata_key is None:
+
+            def local_attributes_for_class():
+                for name, obj in vars(cls).items():
+                    yield name, obj
+
+        else:
+
+            def local_attributes_for_class():
+                for name, obj in vars(cls).items():
+                    yield name, obj
+                for field in util.local_dataclass_fields(cls):
+                    if sa_dataclass_metadata_key in field.metadata:
+                        yield field.name, field.metadata[
+                            sa_dataclass_metadata_key
+                        ]
+
+        return local_attributes_for_class
+
     def _scan_attributes(self):
         cls = self.cls
         dict_ = self.dict_
@@ -333,9 +421,9 @@ class _ClassScanMapperConfig(_MapperConfig):
         table_args = inherited_table_args = None
         tablename = None
 
-        for base in cls.__mro__:
+        attribute_is_overridden = self._cls_attr_override_checker(self.cls)
 
-            sa_dataclass_metadata_key = None
+        for base in cls.__mro__:
 
             class_mapped = (
                 base is not cls
@@ -345,25 +433,14 @@ class _ClassScanMapperConfig(_MapperConfig):
                 )
             )
 
-            if sa_dataclass_metadata_key is None:
-                sa_dataclass_metadata_key = _get_immediate_cls_attr(
-                    base, "__sa_dataclass_metadata_key__", None
-                )
-
-            def attributes_for_class(cls):
-                for name, obj in vars(cls).items():
-                    yield name, obj
-                if sa_dataclass_metadata_key:
-                    for field in util.dataclass_fields(cls):
-                        if sa_dataclass_metadata_key in field.metadata:
-                            yield field.name, field.metadata[
-                                sa_dataclass_metadata_key
-                            ]
+            local_attributes_for_class = self._cls_attr_resolver(base)
 
             if not class_mapped and base is not cls:
-                self._produce_column_copies(attributes_for_class, base)
+                self._produce_column_copies(
+                    local_attributes_for_class, attribute_is_overridden
+                )
 
-            for name, obj in attributes_for_class(base):
+            for name, obj in local_attributes_for_class():
                 if name == "__mapper_args__":
                     check_decl = _check_declared_props_nocascade(
                         obj, name, cls
@@ -471,6 +548,15 @@ class _ClassScanMapperConfig(_MapperConfig):
                     else:
                         self._warn_for_decl_attributes(base, name, obj)
                 elif name not in dict_ or dict_[name] is not obj:
+                    # here, we are definitely looking at the target class
+                    # and not a superclass.   this is currently a
+                    # dataclass-only path.  if the name is only
+                    # a dataclass field and isn't in local cls.__dict__,
+                    # put the object there.
+
+                    # assert that the dataclass-enabled resolver agrees
+                    # with what we are seeing
+                    assert not attribute_is_overridden(name, obj)
                     dict_[name] = obj
 
         if inherited_table_args and not tablename:
@@ -489,14 +575,17 @@ class _ClassScanMapperConfig(_MapperConfig):
                 % (key, cls)
             )
 
-    def _produce_column_copies(self, attributes_for_class, base):
+    def _produce_column_copies(
+        self, attributes_for_class, attribute_is_overridden
+    ):
         cls = self.cls
         dict_ = self.dict_
         column_copies = self.column_copies
         # copy mixin columns to the mapped class
-        for name, obj in attributes_for_class(base):
+
+        for name, obj in attributes_for_class():
             if isinstance(obj, Column):
-                if getattr(cls, name) is not obj:
+                if attribute_is_overridden(name, obj):
                     # if column has been overridden
                     # (like by the InstrumentedAttribute of the
                     # superclass), skip
index 62dd9040ee78a8193d7108e07e620bf49f83a58a..4b76e6d88f3b41074d28d98220a897f6f45391c5 100644 (file)
@@ -552,6 +552,7 @@ class DeclarativeMappedTest(MappedTest):
             metaclass=FindFixtureDeclarative,
             cls=DeclarativeBasic,
         )
+
         cls.DeclarativeBasic = _DeclBase
 
         # sets up cls.Basic which is helpful for things like composite
index 5f8788a6e8b8d95ca93ca6020af2f159bfa6ece9..2d86b8b633024459f4a52bba7b4b4ca7fc16610c 100644 (file)
@@ -66,6 +66,7 @@ from .compat import int_types  # noqa
 from .compat import iterbytes  # noqa
 from .compat import itertools_filter  # noqa
 from .compat import itertools_filterfalse  # noqa
+from .compat import local_dataclass_fields  # noqa
 from .compat import namedtuple  # noqa
 from .compat import next  # noqa
 from .compat import nullcontext  # noqa
index 1eed2c3afeb9ea73711e8ba863d9e4f4f6d726d4..5b7a3eb9fb71019bc768a4a4b0f83bfa267ca71d 100644 (file)
@@ -425,17 +425,37 @@ if py37:
     import dataclasses
 
     def dataclass_fields(cls):
+        """Return a sequence of all dataclasses.Field objects associated
+        with a class."""
+
         if dataclasses.is_dataclass(cls):
             return dataclasses.fields(cls)
         else:
             return []
 
+    def local_dataclass_fields(cls):
+        """Return a sequence of all dataclasses.Field objects associated with
+        a class, excluding those that originate from a superclass."""
+
+        if dataclasses.is_dataclass(cls):
+            super_fields = set()
+            for sup in cls.__bases__:
+                super_fields.update(dataclass_fields(sup))
+            return [
+                f for f in dataclasses.fields(cls) if f not in super_fields
+            ]
+        else:
+            return []
+
 
 else:
 
     def dataclass_fields(cls):
         return []
 
+    def local_dataclass_fields(cls):
+        return []
+
 
 def raise_from_cause(exception, exc_info=None):
     r"""legacy.  use raise\_()"""
index a4b5e4c83d5f0410bec218285d734627c09fca57..51c76c125a4760191c67bd234e1b43ab4eae92b3 100644 (file)
@@ -6,12 +6,16 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import registry as declarative_registry
+from sqlalchemy.orm import registry
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
@@ -171,14 +175,14 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         assert Widget("Foo") != Widget("Bar")
         assert Widget("Foo") != SpecialWidget("Foo")
 
-    def test_asdict_and_astuple(self):
+    def test_asdict_and_astuple_widget(self):
         Widget = self.classes.Widget
-        SpecialWidget = self.classes.SpecialWidget
-
         widget = Widget("Foo")
         eq_(dataclasses.asdict(widget), {"name": "Foo"})
         eq_(dataclasses.astuple(widget), ("Foo",))
 
+    def test_asdict_and_astuple_special_widget(self):
+        SpecialWidget = self.classes.SpecialWidget
         widget = SpecialWidget("Bar", magic=True)
         eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
         eq_(dataclasses.astuple(widget), ("Bar", True))
@@ -187,11 +191,11 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         Account = self.classes.Account
         account = self.data_fixture()
 
-        with Session(testing.db) as session:
+        with fixture_session() as session:
             session.add(account)
             session.commit()
 
-        with Session(testing.db) as session:
+        with fixture_session() as session:
             a = session.query(Account).get(42)
             self.check_data_fixture(a)
 
@@ -373,14 +377,229 @@ class FieldEmbeddedDeclarativeDataclassesTest(
     def define_tables(cls, metadata):
         pass
 
-    def test_asdict_and_astuple(self):
+    def test_asdict_and_astuple_widget(self):
         Widget = self.classes.Widget
-        SpecialWidget = self.classes.SpecialWidget
 
         widget = Widget("Foo")
         eq_(dataclasses.asdict(widget), {"name": "Foo"})
         eq_(dataclasses.astuple(widget), ("Foo",))
 
+    def test_asdict_and_astuple_special_widget(self):
+        SpecialWidget = self.classes.SpecialWidget
         widget = SpecialWidget("Bar", magic=True)
         eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
         eq_(dataclasses.astuple(widget), ("Bar", True))
+
+
+class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
+    __requires__ = ("dataclasses",)
+
+    @classmethod
+    def setup_classes(cls):
+        declarative = cls.DeclarativeBasic.registry.mapped
+
+        @dataclasses.dataclass
+        class SurrogateWidgetPK:
+
+            __sa_dataclass_metadata_key__ = "sa"
+
+            widget_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class Widget(SurrogateWidgetPK):
+            __tablename__ = "widgets"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            account_id = Column(
+                Integer,
+                ForeignKey("accounts.account_id"),
+                nullable=False,
+            )
+            type = Column(String(30), nullable=False)
+
+            name: Optional[str] = dataclasses.field(
+                default=None,
+                metadata={"sa": Column(String(30), nullable=False)},
+            )
+            __mapper_args__ = dict(
+                polymorphic_on="type",
+                polymorphic_identity="normal",
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class SpecialWidget(Widget):
+            __sa_dataclass_metadata_key__ = "sa"
+
+            magic: bool = dataclasses.field(
+                default=False, metadata={"sa": Column(Boolean)}
+            )
+
+            __mapper_args__ = dict(
+                polymorphic_identity="special",
+            )
+
+        @dataclasses.dataclass
+        class SurrogateAccountPK:
+
+            __sa_dataclass_metadata_key__ = "sa"
+
+            account_id = Column(
+                "we_dont_want_to_use_this", Integer, primary_key=True
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class Account(SurrogateAccountPK):
+            __tablename__ = "accounts"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            account_id: int = dataclasses.field(
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+            widgets: List[Widget] = dataclasses.field(
+                default_factory=list, metadata={"sa": relationship("Widget")}
+            )
+            widget_count: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": Column("widget_count", Integer, nullable=False)
+                },
+            )
+
+            def __post_init__(self):
+                self.widget_count = len(self.widgets)
+
+            def add_widget(self, widget: Widget):
+                self.widgets.append(widget)
+                self.widget_count += 1
+
+        cls.classes.Account = Account
+        cls.classes.Widget = Widget
+        cls.classes.SpecialWidget = SpecialWidget
+
+    def check_widget_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        (
+            id_,
+            name,
+        ) = dataclasses.fields(obj)
+        eq_(name.name, "name")
+        eq_(id_.name, "widget_id")
+
+    def check_special_widget_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        id_, name, magic = dataclasses.fields(obj)
+        eq_(id_.name, "widget_id")
+        eq_(name.name, "name")
+        eq_(magic.name, "magic")
+
+    def test_asdict_and_astuple_widget(self):
+        Widget = self.classes.Widget
+
+        widget = Widget("Foo")
+        eq_(dataclasses.asdict(widget), {"name": "Foo", "widget_id": None})
+        eq_(
+            dataclasses.astuple(widget),
+            (
+                None,
+                "Foo",
+            ),
+        )
+
+    def test_asdict_and_astuple_special_widget(self):
+        SpecialWidget = self.classes.SpecialWidget
+        widget = SpecialWidget("Bar", magic=True)
+        eq_(
+            dataclasses.asdict(widget),
+            {"name": "Bar", "magic": True, "widget_id": None},
+        )
+        eq_(dataclasses.astuple(widget), (None, "Bar", True))
+
+
+class PropagationBlockTest(fixtures.TestBase):
+    __requires__ = ("dataclasses",)
+
+    run_setup_classes = "each"
+    run_setup_mappers = "each"
+
+    def test_propagate_w_plain_mixin_col(self, run_test):
+        @dataclasses.dataclass
+        class CommonMixin:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            @declared_attr
+            def __tablename__(cls):
+                return cls.__name__.lower()
+
+            __table_args__ = {"mysql_engine": "InnoDB"}
+            timestamp = Column(Integer)
+
+        run_test(CommonMixin)
+
+    def test_propagate_w_field_mixin_col(self, run_test):
+        @dataclasses.dataclass
+        class CommonMixin:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            @declared_attr
+            def __tablename__(cls):
+                return cls.__name__.lower()
+
+            __table_args__ = {"mysql_engine": "InnoDB"}
+
+            timestamp: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, nullable=False)},
+            )
+
+        run_test(CommonMixin)
+
+    @testing.fixture()
+    def run_test(self):
+        def go(CommonMixin):
+            declarative = registry().mapped
+
+            @declarative
+            @dataclasses.dataclass
+            class BaseType(CommonMixin):
+
+                discriminator = Column("type", String(50))
+                __mapper_args__ = dict(polymorphic_on=discriminator)
+                id = Column(Integer, primary_key=True)
+                value = Column(Integer())
+
+            @declarative
+            @dataclasses.dataclass
+            class Single(BaseType):
+
+                __tablename__ = None
+                __mapper_args__ = dict(polymorphic_identity="type1")
+
+            @declarative
+            @dataclasses.dataclass
+            class Joined(BaseType):
+
+                __mapper_args__ = dict(polymorphic_identity="type2")
+                id = Column(
+                    Integer, ForeignKey("basetype.id"), primary_key=True
+                )
+
+            eq_(BaseType.__table__.name, "basetype")
+            eq_(
+                list(BaseType.__table__.c.keys()),
+                ["timestamp", "type", "id", "value"],
+            )
+            eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"})
+            assert Single.__table__ is BaseType.__table__
+            eq_(Joined.__table__.name, "joined")
+            eq_(list(Joined.__table__.c.keys()), ["id"])
+            eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"})
+
+        yield go
+
+        clear_mappers()