]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
produce column copies up the whole hierarchy first
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 28 Jun 2022 22:55:19 +0000 (18:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Jun 2022 15:22:59 +0000 (11:22 -0400)
Fixed issue where a hierarchy of classes set up as an abstract or mixin
declarative classes could not declare standalone columns on a superclass
that would then be copied correctly to a :class:`_orm.declared_attr`
callable that wanted to make use of them on a descendant class.

Originally it looked like this would produce an ordering change,
however an adjustment to the flow for produce_column_copies
has avoided that for now.

Fixes: #8190
Change-Id: I4e2ee74edb110793eb42691c3e4a0e0535fba7e9

doc/build/changelog/unreleased_14/8190.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_mixin.py

diff --git a/doc/build/changelog/unreleased_14/8190.rst b/doc/build/changelog/unreleased_14/8190.rst
new file mode 100644 (file)
index 0000000..934e44c
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm, declarative
+    :tickets: 8190
+
+    Fixed issue where a hierarchy of classes set up as an abstract or mixin
+    declarative classes could not declare standalone columns on a superclass
+    that would then be copied correctly to a :class:`_orm.declared_attr`
+    callable that wanted to make use of them on a descendant class.
index 1366bedf2499b5663aea9d4677ee6257061d3484..62251fa2b458c68f943515dcd3b47507312dd311 100644 (file)
@@ -714,7 +714,14 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         attribute_is_overridden = self._cls_attr_override_checker(self.cls)
 
+        bases = []
+
         for base in cls.__mro__:
+            # collect bases and make sure standalone columns are copied
+            # to be the column they will ultimately be on the class,
+            # so that declared_attr functions use the right columns.
+            # need to do this all the way up the hierarchy first
+            # (see #8190)
 
             class_mapped = (
                 base is not cls
@@ -727,10 +734,34 @@ class _ClassScanMapperConfig(_MapperConfig):
             local_attributes_for_class = self._cls_attr_resolver(base)
 
             if not class_mapped and base is not cls:
-                self._produce_column_copies(
+                locally_collected_columns = self._produce_column_copies(
                     local_attributes_for_class,
                     attribute_is_overridden,
                 )
+            else:
+                locally_collected_columns = {}
+
+            bases.append(
+                (
+                    base,
+                    class_mapped,
+                    local_attributes_for_class,
+                    locally_collected_columns,
+                )
+            )
+
+        for (
+            base,
+            class_mapped,
+            local_attributes_for_class,
+            locally_collected_columns,
+        ) in bases:
+
+            # this transfer can also take place as we scan each name
+            # for finer-grained control of how collected_attributes is
+            # populated, as this is what impacts column ordering.
+            # however it's simpler to get it out of the way here.
+            collected_attributes.update(locally_collected_columns)
 
             for (
                 name,
@@ -738,6 +769,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                 annotation,
                 is_dataclass_field,
             ) in local_attributes_for_class():
+
                 if re.match(r"^__.+__$", name):
                     if name == "__mapper_args__":
                         check_decl = _check_declared_props_nocascade(
@@ -1096,10 +1128,10 @@ class _ClassScanMapperConfig(_MapperConfig):
             [], Iterable[Tuple[str, Any, Any, bool]]
         ],
         attribute_is_overridden: Callable[[str, Any], bool],
-    ) -> None:
+    ) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]:
         cls = self.cls
         dict_ = self.clsdict_view
-        collected_attributes = self.collected_attributes
+        locally_collected_attributes = {}
         column_copies = self.column_copies
         # copy mixin columns to the mapped class
 
@@ -1132,9 +1164,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                                 )
 
                     column_copies[obj] = copy_ = obj._copy()
-                    collected_attributes[name] = copy_
 
+                    locally_collected_attributes[name] = copy_
                     setattr(cls, name, copy_)
+        return locally_collected_attributes
 
     def _extract_mappable_attributes(self) -> None:
         cls = self.cls
index 36840b2d7a788aa1987c096a497b07f73b75207e..d509b6e992f615155f2738c50cd901dc2d1fd390 100644 (file)
@@ -2216,6 +2216,53 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL):
             "> :param_1",
         )
 
+    def test_multilevel_mixin_attr_refers_to_column_copies(self):
+        """test #8190.
+
+        This test is the same idea as test_mixin_attr_refers_to_column_copies
+        but tests the column copies from superclasses.
+
+        """
+        counter = mock.Mock()
+
+        class SomeOtherMixin:
+            status = Column(String)
+
+        class HasAddressCount(SomeOtherMixin):
+            id = Column(Integer, primary_key=True)
+
+            @declared_attr
+            def address_count(cls):
+                counter(cls.id)
+                counter(cls.status)
+                return column_property(
+                    select(func.count(Address.id))
+                    .where(Address.user_id == cls.id)
+                    .where(cls.status == "some status")
+                    .scalar_subquery()
+                )
+
+        class Address(Base):
+            __tablename__ = "address"
+            id = Column(Integer, primary_key=True)
+            user_id = Column(ForeignKey("user.id"))
+
+        class User(Base, HasAddressCount):
+            __tablename__ = "user"
+
+        eq_(counter.mock_calls, [mock.call(User.id), mock.call(User.status)])
+
+        sess = fixture_session()
+        self.assert_compile(
+            sess.query(User).having(User.address_count > 5),
+            "SELECT (SELECT count(address.id) AS count_1 FROM address "
+            'WHERE address.user_id = "user".id AND "user".status = :param_1) '
+            'AS anon_1, "user".id AS user_id, "user".status AS user_status '
+            'FROM "user" HAVING (SELECT count(address.id) AS count_1 '
+            'FROM address WHERE address.user_id = "user".id '
+            'AND "user".status = :param_1) > :param_2',
+        )
+
 
 class AbstractTest(DeclarativeTestBase):
     def test_abstract_boolean(self):