]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
trust user PK argument as given; don't reduce
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Mar 2022 00:14:04 +0000 (20:14 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 Mar 2022 16:12:34 +0000 (12:12 -0400)
Fixed issue where the :class:`_orm.Mapper` would reduce a user-defined
:paramref:`_orm.Mapper.primary_key` argument too aggressively, in the case
of mapping to a ``UNION`` where for some of the SELECT entries, two columns
are essentially equivalent, but in another, they are not, such as in a
recursive CTE. The logic here has been changed to accept a given
user-defined PK as given, where columns will be related to the mapped
selectable but no longer "reduced" as this heuristic can't accommodate for
all situations.

Fixes: #7842
Change-Id: Ie46f0a3d42cae0501641fa213da0a9d5ca26c3ad

doc/build/changelog/unreleased_14/7842.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
test/orm/test_mapper.py

diff --git a/doc/build/changelog/unreleased_14/7842.rst b/doc/build/changelog/unreleased_14/7842.rst
new file mode 100644 (file)
index 0000000..c165ed4
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 7842
+
+    Fixed issue where the :class:`_orm.Mapper` would reduce a user-defined
+    :paramref:`_orm.Mapper.primary_key` argument too aggressively, in the case
+    of mapping to a ``UNION`` where for some of the SELECT entries, two columns
+    are essentially equivalent, but in another, they are not, such as in a
+    recursive CTE. The logic here has been changed to accept a given
+    user-defined PK as given, where columns will be related to the mapped
+    selectable but no longer "reduced" as this heuristic can't accommodate for
+    all situations.
index 011e7d2efc2b641f56e49081d13e872df9c965e4..7d1fc76436bbf4b6dbcfbbaab3591d73492ee0b7 100644 (file)
@@ -1368,17 +1368,17 @@ class Mapper(
             # that of the inheriting (unless concrete or explicit)
             self.primary_key = self.inherits.primary_key
         else:
-            # determine primary key from argument or persist_selectable pks -
-            # reduce to the minimal set of columns
+            # determine primary key from argument or persist_selectable pks
             if self._primary_key_argument:
-                primary_key = sql_util.reduce_columns(
-                    [
-                        self.persist_selectable.corresponding_column(c)
-                        for c in self._primary_key_argument
-                    ],
-                    ignore_nonexistent_tables=True,
-                )
+                primary_key = [
+                    self.persist_selectable.corresponding_column(c)
+                    for c in self._primary_key_argument
+                ]
             else:
+                # if heuristically determined PKs, reduce to the minimal set
+                # of columns by eliminating FK->PK pairs for a multi-table
+                # expression.   May over-reduce for some kinds of UNIONs
+                # / CTEs; use explicit PK argument for these special cases
                 primary_key = sql_util.reduce_columns(
                     self._pks_by_table[self.persist_selectable],
                     ignore_nonexistent_tables=True,
index 1fad974b92934cae73013fa3ab4b462bf716881e..980c82fbe2aff6824264f7316da9184326a820f8 100644 (file)
@@ -5,6 +5,7 @@ import sqlalchemy as sa
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import String
@@ -41,6 +42,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import ComparableEntity
 from sqlalchemy.testing.fixtures import ComparableMixin
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -1403,6 +1405,65 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             ],
         )
 
+    @testing.requires.ctes
+    def test_mapping_to_union_dont_overlimit_pk(self, registry, connection):
+        """test #7842"""
+        Base = registry.generate_base()
+
+        class Node(Base):
+            __tablename__ = "cte_nodes"
+
+            id = Column(Integer, primary_key=True)
+            parent = Column(Integer, ForeignKey("cte_nodes.id"))
+
+            #  so we dont have to deal with NULLS FIRST
+            sort_key = Column(Integer)
+
+        class NodeRel(ComparableEntity, Base):
+            table = select(
+                Node.id, Node.parent, Node.sort_key, literal(0).label("depth")
+            ).cte(recursive=True)
+            __table__ = table.union_all(
+                select(
+                    Node.id,
+                    table.c.parent,
+                    table.c.sort_key,
+                    table.c.depth + literal(1),
+                )
+                .select_from(Node)
+                .join(table, Node.parent == table.c.id)
+            )
+
+            __mapper_args__ = {
+                "primary_key": (__table__.c.id, __table__.c.parent)
+            }
+
+        nt = NodeRel.__table__
+
+        eq_(NodeRel.__mapper__.primary_key, (nt.c.id, nt.c.parent))
+
+        registry.metadata.create_all(connection)
+        with Session(connection) as session:
+            n1, n2, n3, n4 = (
+                Node(id=1, sort_key=1),
+                Node(id=2, parent=1, sort_key=2),
+                Node(id=3, parent=2, sort_key=3),
+                Node(id=4, parent=3, sort_key=4),
+            )
+            session.add_all([n1, n2, n3, n4])
+            session.commit()
+
+            q_rel = select(NodeRel).filter_by(id=4).order_by(NodeRel.sort_key)
+            eq_(
+                session.scalars(q_rel).all(),
+                [
+                    NodeRel(id=4, parent=None),
+                    NodeRel(id=4, parent=1),
+                    NodeRel(id=4, parent=2),
+                    NodeRel(id=4, parent=3),
+                ],
+            )
+
     def test_scalar_pk_arg(self):
         users, Keyword, items, Item, User, keywords = (
             self.tables.users,