From 9a74a282d0cb5a924322b9ad4b07e6196b55612a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 4 Jun 2024 10:56:26 -0400 Subject: [PATCH] add additional contextual path info when splicing eager joins Fixed very old issue involving the :paramref:`_orm.joinedload.innerjoin` parameter where making use of this parameter mixed into a query that also included joined eager loads along a self-referential or other cyclical relationship, along with complicating factors like inner joins added for secondary tables and such, would have the chance of splicing a particular inner join to the wrong part of the query. Additional state has been added to the internal method that does this splice to make a better decision as to where splicing should proceed. Fixes: #11449 Change-Id: Ie8f0e8d9bb7958baac33c7c2231e4afae15cf5b1 (cherry picked from commit c4c57237b76f3992a62c6eb5c23fd4e1919f1e4a) --- doc/build/changelog/unreleased_20/11449.rst | 12 ++ lib/sqlalchemy/orm/strategies.py | 32 +++- lib/sqlalchemy/orm/util.py | 2 +- test/orm/test_eager_relations.py | 174 ++++++++++++++++++++ 4 files changed, 215 insertions(+), 5 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11449.rst diff --git a/doc/build/changelog/unreleased_20/11449.rst b/doc/build/changelog/unreleased_20/11449.rst new file mode 100644 index 0000000000..f7974cfd76 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11449.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 11449 + + Fixed very old issue involving the :paramref:`_orm.joinedload.innerjoin` + parameter where making use of this parameter mixed into a query that also + included joined eager loads along a self-referential or other cyclical + relationship, along with complicating factors like inner joins added for + secondary tables and such, would have the chance of splicing a particular + inner join to the wrong part of the query. Additional state has been added + to the internal method that does this splice to make a better decision as + to where splicing should proceed. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 20c3b9cc6b..00c6fcb6c1 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2506,7 +2506,7 @@ class JoinedLoader(AbstractRelationshipLoader): or query_entity.entity_zero.represents_outer_join or (chained_from_outerjoin and isinstance(towrap, sql.Join)), _left_memo=self.parent, - _right_memo=self.mapper, + _right_memo=path[self.mapper], _extra_criteria=extra_join_criteria, ) else: @@ -2546,7 +2546,14 @@ class JoinedLoader(AbstractRelationshipLoader): ) def _splice_nested_inner_join( - self, path, join_obj, clauses, onclause, extra_criteria, splicing=False + self, + path, + join_obj, + clauses, + onclause, + extra_criteria, + splicing=False, + detected_existing_path=None, ): # recursive fn to splice a nested join into an existing one. # splicing=False means this is the outermost call, and it @@ -2568,13 +2575,23 @@ class JoinedLoader(AbstractRelationshipLoader): ) elif not isinstance(join_obj, orm_util._ORMJoin): if path[-2].isa(splicing): + + if detected_existing_path: + # TODO: refine this into a more efficient method + if not detected_existing_path.contains_mapper(splicing): + return None + elif path_registry.PathRegistry.coerce( + detected_existing_path[len(path) :] + ).contains_mapper(splicing): + return None + return orm_util._ORMJoin( join_obj, clauses.aliased_insp, onclause, isouter=False, _left_memo=splicing, - _right_memo=path[-1].mapper, + _right_memo=path[path[-1].mapper], _extra_criteria=extra_criteria, ) else: @@ -2586,7 +2603,12 @@ class JoinedLoader(AbstractRelationshipLoader): clauses, onclause, extra_criteria, - join_obj._right_memo, + # NOTE: this is the one place _right_memo is consumed + splicing=( + join_obj._right_memo[-1].mapper + if join_obj._right_memo is not None + else None + ), ) if target_join is None: right_splice = False @@ -2597,7 +2619,9 @@ class JoinedLoader(AbstractRelationshipLoader): onclause, extra_criteria, join_obj._left_memo, + detected_existing_path=join_obj._right_memo, ) + if target_join is None: # should only return None when recursively called, # e.g. splicing refers to a from obj diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0ab3536ddd..9835f82447 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1943,7 +1943,7 @@ class _ORMJoin(expression.Join): self.onclause, isouter=self.isouter, _left_memo=self._left_memo, - _right_memo=other._left_memo, + _right_memo=None, ) return _ORMJoin( diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 2e762c2d3c..bc3d8f10c2 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -41,6 +41,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -3696,6 +3697,179 @@ class InnerJoinSplicingWSecondaryTest( self._assert_result(q) +class InnerJoinSplicingWSecondarySelfRefTest( + fixtures.MappedTest, testing.AssertsCompiledSQL +): + """test for issue 11449""" + + __dialect__ = "default" + __backend__ = True # exercise hardcore join nesting on backends + + @classmethod + def define_tables(cls, metadata): + Table( + "kind", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + + Table( + "node", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column( + "common_node_id", Integer, ForeignKey("node.id"), nullable=True + ), + Column("kind_id", Integer, ForeignKey("kind.id"), nullable=False), + ) + Table( + "node_group", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + Table( + "node_group_node", + metadata, + Column( + "node_group_id", + Integer, + ForeignKey("node_group.id"), + primary_key=True, + ), + Column( + "node_id", Integer, ForeignKey("node.id"), primary_key=True + ), + ) + + @classmethod + def setup_classes(cls): + class Kind(cls.Comparable): + pass + + class Node(cls.Comparable): + pass + + class NodeGroup(cls.Comparable): + pass + + class NodeGroupNode(cls.Comparable): + pass + + @classmethod + def insert_data(cls, connection): + kind = cls.tables.kind + connection.execute( + kind.insert(), [{"id": 1, "name": "a"}, {"id": 2, "name": "c"}] + ) + node = cls.tables.node + connection.execute( + node.insert(), + {"id": 1, "name": "nc", "kind_id": 2}, + ) + + connection.execute( + node.insert(), + {"id": 2, "name": "na", "kind_id": 1, "common_node_id": 1}, + ) + + node_group = cls.tables.node_group + node_group_node = cls.tables.node_group_node + + connection.execute(node_group.insert(), {"id": 1, "name": "group"}) + connection.execute( + node_group_node.insert(), + {"id": 1, "node_group_id": 1, "node_id": 2}, + ) + connection.commit() + + @testing.fixture(params=["common_nodes,kind", "kind,common_nodes"]) + def node_fixture(self, request): + Kind, Node, NodeGroup, NodeGroupNode = self.classes( + "Kind", "Node", "NodeGroup", "NodeGroupNode" + ) + kind, node, node_group, node_group_node = self.tables( + "kind", "node", "node_group", "node_group_node" + ) + self.mapper_registry.map_imperatively(Kind, kind) + + if request.param == "common_nodes,kind": + self.mapper_registry.map_imperatively( + Node, + node, + properties=dict( + common_node=relationship( + "Node", + remote_side=[node.c.id], + ), + kind=relationship(Kind, innerjoin=True, lazy="joined"), + ), + ) + elif request.param == "kind,common_nodes": + self.mapper_registry.map_imperatively( + Node, + node, + properties=dict( + kind=relationship(Kind, innerjoin=True, lazy="joined"), + common_node=relationship( + "Node", + remote_side=[node.c.id], + ), + ), + ) + + self.mapper_registry.map_imperatively( + NodeGroup, + node_group, + properties=dict( + nodes=relationship(Node, secondary="node_group_node") + ), + ) + self.mapper_registry.map_imperatively(NodeGroupNode, node_group_node) + + def test_select(self, node_fixture): + Kind, Node, NodeGroup, NodeGroupNode = self.classes( + "Kind", "Node", "NodeGroup", "NodeGroupNode" + ) + + session = fixture_session() + with self.sql_execution_asserter(testing.db) as asserter: + group = ( + session.scalars( + select(NodeGroup) + .where(NodeGroup.name == "group") + .options( + joinedload(NodeGroup.nodes).joinedload( + Node.common_node + ) + ) + ) + .unique() + .one_or_none() + ) + + eq_(group.nodes[0].common_node.kind.name, "c") + eq_(group.nodes[0].kind.name, "a") + + asserter.assert_( + RegexSQL( + r"SELECT .* FROM node_group " + r"LEFT OUTER JOIN \(node_group_node AS node_group_node_1 " + r"JOIN node AS node_2 " + r"ON node_2.id = node_group_node_1.node_id " + r"JOIN kind AS kind_\d ON kind_\d.id = node_2.kind_id\) " + r"ON node_group.id = node_group_node_1.node_group_id " + r"LEFT OUTER JOIN " + r"\(node AS node_1 JOIN kind AS kind_\d " + r"ON kind_\d.id = node_1.kind_id\) " + r"ON node_1.id = node_2.common_node_id " + r"WHERE node_group.name = :name_5" + ) + ) + + class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): """test #2188""" -- 2.47.2