git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add additional contextual path info when splicing eager joins
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2024 14:56:26 +0000 (10:56 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2024 21:26:28 +0000 (17:26 -0400)
Fixed very old issue involving the :paramref:`_orm.joinedload.innerjoin`
parameter where making use of this parameter mixed into a query that also
included joined eager loads along a self-referential or other cyclical
relationship, along with complicating factors like inner joins added for
secondary tables and such, would have the chance of splicing a particular
inner join to the wrong part of the query.  Additional state has been added
to the internal method that does this splice to make a better decision as
to where splicing should proceed.

Fixes: #11449
Change-Id: Ie8f0e8d9bb7958baac33c7c2231e4afae15cf5b1
(cherry picked from commit c4c57237b76f3992a62c6eb5c23fd4e1919f1e4a)

doc/build/changelog/unreleased_20/11449.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/test_eager_relations.py

diff --git a/doc/build/changelog/unreleased_20/11449.rst b/doc/build/changelog/unreleased_20/11449.rst
new file mode 100644 (file)
index 0000000..f7974cf
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11449
+
+    Fixed very old issue involving the :paramref:`_orm.joinedload.innerjoin`
+    parameter where making use of this parameter mixed into a query that also
+    included joined eager loads along a self-referential or other cyclical
+    relationship, along with complicating factors like inner joins added for
+    secondary tables and such, would have the chance of splicing a particular
+    inner join to the wrong part of the query.  Additional state has been added
+    to the internal method that does this splice to make a better decision as
+    to where splicing should proceed.
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 20c3b9cc6b01369e00671a38678023c7089a91bb..00c6fcb6c1aa12ec89aee02350aa5597c8b66269 100644 (file)
@@ -2506,7 +2506,7 @@ class JoinedLoader(AbstractRelationshipLoader):
                 or query_entity.entity_zero.represents_outer_join
                 or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
                 _left_memo=self.parent,
-                _right_memo=self.mapper,
+                _right_memo=path[self.mapper],
                 _extra_criteria=extra_join_criteria,
             )
         else:
@@ -2546,7 +2546,14 @@ class JoinedLoader(AbstractRelationshipLoader):
             )
 
     def _splice_nested_inner_join(
-        self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
+        self,
+        path,
+        join_obj,
+        clauses,
+        onclause,
+        extra_criteria,
+        splicing=False,
+        detected_existing_path=None,
     ):
         # recursive fn to splice a nested join into an existing one.
         # splicing=False means this is the outermost call, and it
@@ -2568,13 +2575,23 @@ class JoinedLoader(AbstractRelationshipLoader):
             )
         elif not isinstance(join_obj, orm_util._ORMJoin):
             if path[-2].isa(splicing):
+
+                if detected_existing_path:
+                    # TODO: refine this into a more efficient method
+                    if not detected_existing_path.contains_mapper(splicing):
+                        return None
+                    elif path_registry.PathRegistry.coerce(
+                        detected_existing_path[len(path) :]
+                    ).contains_mapper(splicing):
+                        return None
+
                 return orm_util._ORMJoin(
                     join_obj,
                     clauses.aliased_insp,
                     onclause,
                     isouter=False,
                     _left_memo=splicing,
-                    _right_memo=path[-1].mapper,
+                    _right_memo=path[path[-1].mapper],
                     _extra_criteria=extra_criteria,
                 )
             else:
@@ -2586,7 +2603,12 @@ class JoinedLoader(AbstractRelationshipLoader):
             clauses,
             onclause,
             extra_criteria,
-            join_obj._right_memo,
+            # NOTE: this is the one place _right_memo is consumed
+            splicing=(
+                join_obj._right_memo[-1].mapper
+                if join_obj._right_memo is not None
+                else None
+            ),
         )
         if target_join is None:
             right_splice = False
@@ -2597,7 +2619,9 @@ class JoinedLoader(AbstractRelationshipLoader):
                 onclause,
                 extra_criteria,
                 join_obj._left_memo,
+                detected_existing_path=join_obj._right_memo,
             )
+
             if target_join is None:
                 # should only return None when recursively called,
                 # e.g. splicing refers to a from obj
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 0ab3536dddc34a44da4c7afc1fb124ca240cb8aa..9835f824470247ba6b906d6fe4119275b1455f54 100644 (file)
@@ -1943,7 +1943,7 @@ class _ORMJoin(expression.Join):
             self.onclause,
             isouter=self.isouter,
             _left_memo=self._left_memo,
-            _right_memo=other._left_memo,
+            _right_memo=None,
         )
 
         return _ORMJoin(
diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py
index 2e762c2d3cb2462762209c1b7e6ea544dd3ba0a1..bc3d8f10c2c609dbdef091e77d93d59fd018c786 100644 (file)
@@ -41,6 +41,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.assertsql import RegexSQL
 from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -3696,6 +3697,179 @@ class InnerJoinSplicingWSecondaryTest(
         self._assert_result(q)
 
 
+class InnerJoinSplicingWSecondarySelfRefTest(
+    fixtures.MappedTest, testing.AssertsCompiledSQL
+):
+    """test for issue 11449"""
+
+    __dialect__ = "default"
+    __backend__ = True  # exercise hardcore join nesting on backends
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "kind",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )
+
+        Table(
+            "node",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column(
+                "common_node_id", Integer, ForeignKey("node.id"), nullable=True
+            ),
+            Column("kind_id", Integer, ForeignKey("kind.id"), nullable=False),
+        )
+        Table(
+            "node_group",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )
+        Table(
+            "node_group_node",
+            metadata,
+            Column(
+                "node_group_id",
+                Integer,
+                ForeignKey("node_group.id"),
+                primary_key=True,
+            ),
+            Column(
+                "node_id", Integer, ForeignKey("node.id"), primary_key=True
+            ),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Kind(cls.Comparable):
+            pass
+
+        class Node(cls.Comparable):
+            pass
+
+        class NodeGroup(cls.Comparable):
+            pass
+
+        class NodeGroupNode(cls.Comparable):
+            pass
+
+    @classmethod
+    def insert_data(cls, connection):
+        kind = cls.tables.kind
+        connection.execute(
+            kind.insert(), [{"id": 1, "name": "a"}, {"id": 2, "name": "c"}]
+        )
+        node = cls.tables.node
+        connection.execute(
+            node.insert(),
+            {"id": 1, "name": "nc", "kind_id": 2},
+        )
+
+        connection.execute(
+            node.insert(),
+            {"id": 2, "name": "na", "kind_id": 1, "common_node_id": 1},
+        )
+
+        node_group = cls.tables.node_group
+        node_group_node = cls.tables.node_group_node
+
+        connection.execute(node_group.insert(), {"id": 1, "name": "group"})
+        connection.execute(
+            node_group_node.insert(),
+            {"id": 1, "node_group_id": 1, "node_id": 2},
+        )
+        connection.commit()
+
+    @testing.fixture(params=["common_nodes,kind", "kind,common_nodes"])
+    def node_fixture(self, request):
+        Kind, Node, NodeGroup, NodeGroupNode = self.classes(
+            "Kind", "Node", "NodeGroup", "NodeGroupNode"
+        )
+        kind, node, node_group, node_group_node = self.tables(
+            "kind", "node", "node_group", "node_group_node"
+        )
+        self.mapper_registry.map_imperatively(Kind, kind)
+
+        if request.param == "common_nodes,kind":
+            self.mapper_registry.map_imperatively(
+                Node,
+                node,
+                properties=dict(
+                    common_node=relationship(
+                        "Node",
+                        remote_side=[node.c.id],
+                    ),
+                    kind=relationship(Kind, innerjoin=True, lazy="joined"),
+                ),
+            )
+        elif request.param == "kind,common_nodes":
+            self.mapper_registry.map_imperatively(
+                Node,
+                node,
+                properties=dict(
+                    kind=relationship(Kind, innerjoin=True, lazy="joined"),
+                    common_node=relationship(
+                        "Node",
+                        remote_side=[node.c.id],
+                    ),
+                ),
+            )
+
+        self.mapper_registry.map_imperatively(
+            NodeGroup,
+            node_group,
+            properties=dict(
+                nodes=relationship(Node, secondary="node_group_node")
+            ),
+        )
+        self.mapper_registry.map_imperatively(NodeGroupNode, node_group_node)
+
+    def test_select(self, node_fixture):
+        Kind, Node, NodeGroup, NodeGroupNode = self.classes(
+            "Kind", "Node", "NodeGroup", "NodeGroupNode"
+        )
+
+        session = fixture_session()
+        with self.sql_execution_asserter(testing.db) as asserter:
+            group = (
+                session.scalars(
+                    select(NodeGroup)
+                    .where(NodeGroup.name == "group")
+                    .options(
+                        joinedload(NodeGroup.nodes).joinedload(
+                            Node.common_node
+                        )
+                    )
+                )
+                .unique()
+                .one_or_none()
+            )
+
+            eq_(group.nodes[0].common_node.kind.name, "c")
+            eq_(group.nodes[0].kind.name, "a")
+
+        asserter.assert_(
+            RegexSQL(
+                r"SELECT .* FROM node_group "
+                r"LEFT OUTER JOIN \(node_group_node AS node_group_node_1 "
+                r"JOIN node AS node_2 "
+                r"ON node_2.id = node_group_node_1.node_id "
+                r"JOIN kind AS kind_\d ON kind_\d.id = node_2.kind_id\) "
+                r"ON node_group.id = node_group_node_1.node_group_id "
+                r"LEFT OUTER JOIN "
+                r"\(node AS node_1 JOIN kind AS kind_\d "
+                r"ON kind_\d.id = node_1.kind_id\) "
+                r"ON node_1.id = node_2.common_node_id "
+                r"WHERE node_group.name = :name_5"
+            )
+        )
+
+
 class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
     """test #2188"""