git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add additional contextual path info when splicing eager joins
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Jun 2024 14:56:26 +0000 (10:56 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Jun 2024 21:06:07 +0000 (17:06 -0400)
Fixed very old issue involving the :paramref:`_orm.joinedload.innerjoin`
parameter where making use of this parameter mixed into a query that also
included joined eager loads along a self-referential or other cyclical
relationship, along with complicating factors like inner joins added for
secondary tables and such, would have the chance of splicing a particular
inner join to the wrong part of the query.  Additional state has been added
to the internal method that does this splice to make a better decision as
to where splicing should proceed.

Fixes: #11449
Change-Id: Ie8f0e8d9bb7958baac33c7c2231e4afae15cf5b1

doc/build/changelog/unreleased_20/11449.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/test_eager_relations.py

diff --git a/doc/build/changelog/unreleased_20/11449.rst b/doc/build/changelog/unreleased_20/11449.rst
new file mode 100644 (file)
index 0000000..f7974cf
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11449
+
+    Fixed very old issue involving the :paramref:`_orm.joinedload.innerjoin`
+    parameter where making use of this parameter mixed into a query that also
+    included joined eager loads along a self-referential or other cyclical
+    relationship, along with complicating factors like inner joins added for
+    secondary tables and such, would have the chance of splicing a particular
+    inner join to the wrong part of the query.  Additional state has been added
+    to the internal method that does this splice to make a better decision as
+    to where splicing should proceed.
index 20c3b9cc6b01369e00671a38678023c7089a91bb..e5eff56f3bfbfc928c8e1b140b8a32f00c86d895 100644 (file)
@@ -16,8 +16,10 @@ import collections
 import itertools
 from typing import Any
 from typing import Dict
+from typing import Optional
 from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import Union
 
 from . import attributes
 from . import exc as orm_exc
@@ -57,8 +59,10 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..sql.selectable import Select
+from ..util.typing import Literal
 
 if TYPE_CHECKING:
+    from .mapper import Mapper
     from .relationships import RelationshipProperty
     from ..sql.elements import ColumnElement
 
@@ -2506,13 +2510,13 @@ class JoinedLoader(AbstractRelationshipLoader):
                 or query_entity.entity_zero.represents_outer_join
                 or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
                 _left_memo=self.parent,
-                _right_memo=self.mapper,
+                _right_memo=path[self.mapper],
                 _extra_criteria=extra_join_criteria,
             )
         else:
             # all other cases are innerjoin=='nested' approach
             eagerjoin = self._splice_nested_inner_join(
-                path, towrap, clauses, onclause, extra_join_criteria
+                path, path[-2], towrap, clauses, onclause, extra_join_criteria
             )
 
         compile_state.eager_joins[query_entity_key] = eagerjoin
@@ -2546,93 +2550,176 @@ class JoinedLoader(AbstractRelationshipLoader):
             )
 
     def _splice_nested_inner_join(
-        self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
+        self,
+        path,
+        entity_we_want_to_splice_onto,
+        join_obj,
+        clauses,
+        onclause,
+        extra_criteria,
+        entity_inside_join_structure: Union[
+            Mapper, None, Literal[False]
+        ] = False,
+        detected_existing_path: Optional[path_registry.PathRegistry] = None,
     ):
         # recursive fn to splice a nested join into an existing one.
-        # splicing=False means this is the outermost call, and it
-        # should return a value.  splicing=<from object> is the recursive
-        # form, where it can return None to indicate the end of the recursion
+        # entity_inside_join_structure=False means this is the outermost call,
+        # and it should return a value.  entity_inside_join_structure=<mapper>
+        # indicates we've descended into a join and are looking at a FROM
+        # clause representing this mapper; if this is not
+        # entity_we_want_to_splice_onto then return None to end the recursive
+        # branch
+
+        assert entity_we_want_to_splice_onto is path[-2]
 
-        if splicing is False:
-            # first call is always handed a join object
-            # from the outside
+        if entity_inside_join_structure is False:
             assert isinstance(join_obj, orm_util._ORMJoin)
-        elif isinstance(join_obj, sql.selectable.FromGrouping):
+
+        if isinstance(join_obj, sql.selectable.FromGrouping):
+            # FromGrouping - continue descending into the structure
             return self._splice_nested_inner_join(
                 path,
+                entity_we_want_to_splice_onto,
                 join_obj.element,
                 clauses,
                 onclause,
                 extra_criteria,
-                splicing,
+                entity_inside_join_structure,
             )
-        elif not isinstance(join_obj, orm_util._ORMJoin):
-            if path[-2].isa(splicing):
-                return orm_util._ORMJoin(
-                    join_obj,
-                    clauses.aliased_insp,
-                    onclause,
-                    isouter=False,
-                    _left_memo=splicing,
-                    _right_memo=path[-1].mapper,
-                    _extra_criteria=extra_criteria,
-                )
-            else:
-                return None
+        elif isinstance(join_obj, orm_util._ORMJoin):
+            # _ORMJoin - continue descending into the structure
 
-        target_join = self._splice_nested_inner_join(
-            path,
-            join_obj.right,
-            clauses,
-            onclause,
-            extra_criteria,
-            join_obj._right_memo,
-        )
-        if target_join is None:
-            right_splice = False
+            join_right_path = join_obj._right_memo
+
+            # see if right side of join is viable
             target_join = self._splice_nested_inner_join(
                 path,
-                join_obj.left,
+                entity_we_want_to_splice_onto,
+                join_obj.right,
                 clauses,
                 onclause,
                 extra_criteria,
-                join_obj._left_memo,
+                entity_inside_join_structure=(
+                    join_right_path[-1].mapper
+                    if join_right_path is not None
+                    else None
+                ),
             )
-            if target_join is None:
-                # should only return None when recursively called,
-                # e.g. splicing refers to a from obj
-                assert (
-                    splicing is not False
-                ), "assertion failed attempting to produce joined eager loads"
-                return None
-        else:
-            right_splice = True
-
-        if right_splice:
-            # for a right splice, attempt to flatten out
-            # a JOIN b JOIN c JOIN .. to avoid needless
-            # parenthesis nesting
-            if not join_obj.isouter and not target_join.isouter:
-                eagerjoin = join_obj._splice_into_center(target_join)
+
+            if target_join is not None:
+                # for a right splice, attempt to flatten out
+                # a JOIN b JOIN c JOIN .. to avoid needless
+                # parenthesis nesting
+                if not join_obj.isouter and not target_join.isouter:
+                    eagerjoin = join_obj._splice_into_center(target_join)
+                else:
+                    eagerjoin = orm_util._ORMJoin(
+                        join_obj.left,
+                        target_join,
+                        join_obj.onclause,
+                        isouter=join_obj.isouter,
+                        _left_memo=join_obj._left_memo,
+                    )
+
+                eagerjoin._target_adapter = target_join._target_adapter
+                return eagerjoin
+
             else:
-                eagerjoin = orm_util._ORMJoin(
+                # see if left side of join is viable
+                target_join = self._splice_nested_inner_join(
+                    path,
+                    entity_we_want_to_splice_onto,
                     join_obj.left,
-                    target_join,
-                    join_obj.onclause,
-                    isouter=join_obj.isouter,
-                    _left_memo=join_obj._left_memo,
+                    clauses,
+                    onclause,
+                    extra_criteria,
+                    entity_inside_join_structure=join_obj._left_memo,
+                    detected_existing_path=join_right_path,
                 )
-        else:
-            eagerjoin = orm_util._ORMJoin(
-                target_join,
-                join_obj.right,
-                join_obj.onclause,
-                isouter=join_obj.isouter,
-                _right_memo=join_obj._right_memo,
-            )
 
-        eagerjoin._target_adapter = target_join._target_adapter
-        return eagerjoin
+                if target_join is not None:
+                    eagerjoin = orm_util._ORMJoin(
+                        target_join,
+                        join_obj.right,
+                        join_obj.onclause,
+                        isouter=join_obj.isouter,
+                        _right_memo=join_obj._right_memo,
+                    )
+                    eagerjoin._target_adapter = target_join._target_adapter
+                    return eagerjoin
+
+            # neither side viable, return None, or fail if this was the top
+            # most call
+            if entity_inside_join_structure is False:
+                assert (
+                    False
+                ), "assertion failed attempting to produce joined eager loads"
+            return None
+
+        # reached an endpoint (e.g. a table that's mapped, or an alias of that
+        # table).  determine if we can use this endpoint to splice onto
+
+        # is this the entity we want to splice onto in the first place?
+        if not entity_we_want_to_splice_onto.isa(entity_inside_join_structure):
+            return None
+
+        # path check.  if we know the path how this join endpoint got here,
+        # lets look at our path we are satisfying and see if we're in the
+        # wrong place.  This is specifically for when our entity may
+        # appear more than once in the path, issue #11449
+        if detected_existing_path:
+            # this assertion is currently based on how this call is made,
+            # where given a join_obj, the call will have these parameters as
+            # entity_inside_join_structure=join_obj._left_memo
+            # and entity_inside_join_structure=join_obj._right_memo.mapper
+            assert detected_existing_path[-3] is entity_inside_join_structure
+
+            # from that, see if the path we are targeting matches the
+            # "existing" path of this join all the way up to the midpoint
+            # of this join object (e.g. the relationship).
+            # if not, then this is not our target
+            #
+            # a test condition where this test is false looks like:
+            #
+            # desired splice:         Node->kind->Kind
+            # path of desired splice: NodeGroup->nodes->Node->kind
+            # path we've located:     NodeGroup->nodes->Node->common_node->Node
+            #
+            # above, because we want to splice kind->Kind onto
+            # NodeGroup->nodes->Node, this is not our path because it actually
+            # goes more steps than we want into self-referential
+            # ->common_node->Node
+            #
+            # a test condition where this test is true looks like:
+            #
+            # desired splice:         B->c2s->C2
+            # path of desired splice: A->bs->B->c2s
+            # path we've located:     A->bs->B->c1s->C1
+            #
+            # above, we want to splice c2s->C2 onto B, and the located path
+            # shows that the join ends with B->c1s->C1.  so we will
+            # add another join onto that, which would create a "branch" that
+            # we might represent in a pseudopath as:
+            #
+            # B->c1s->C1
+            #  ->c2s->C2
+            #
+            # i.e. A JOIN B ON <bs> JOIN C1 ON <c1s>
+            #                       JOIN C2 ON <c2s>
+            #
+
+            if detected_existing_path[0:-2] != path.path[0:-1]:
+                return None
+
+        return orm_util._ORMJoin(
+            join_obj,
+            clauses.aliased_insp,
+            onclause,
+            isouter=False,
+            _left_memo=entity_inside_join_structure,
+            _right_memo=path[path[-1].mapper],
+            _extra_criteria=extra_criteria,
+        )
 
     def _create_eager_adapter(self, context, result, adapter, path, loadopt):
         compile_state = context.compile_state
index d1dbf22639dbb02f76bbc2577a072ef44376353f..1e4d3713975a4ec7de8ea73ce2ad249f214d3af2 100644 (file)
@@ -1945,7 +1945,7 @@ class _ORMJoin(expression.Join):
             self.onclause,
             isouter=self.isouter,
             _left_memo=self._left_memo,
-            _right_memo=other._left_memo,
+            _right_memo=None,
         )
 
         return _ORMJoin(
index 2e762c2d3cb2462762209c1b7e6ea544dd3ba0a1..bc3d8f10c2c609dbdef091e77d93d59fd018c786 100644 (file)
@@ -41,6 +41,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.assertsql import RegexSQL
 from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -3696,6 +3697,179 @@ class InnerJoinSplicingWSecondaryTest(
         self._assert_result(q)
 
 
+class InnerJoinSplicingWSecondarySelfRefTest(
+    fixtures.MappedTest, testing.AssertsCompiledSQL
+):
+    """test for issue 11449"""
+
+    __dialect__ = "default"
+    __backend__ = True  # exercise hardcore join nesting on backends
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "kind",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )
+
+        Table(
+            "node",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column(
+                "common_node_id", Integer, ForeignKey("node.id"), nullable=True
+            ),
+            Column("kind_id", Integer, ForeignKey("kind.id"), nullable=False),
+        )
+        Table(
+            "node_group",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )
+        Table(
+            "node_group_node",
+            metadata,
+            Column(
+                "node_group_id",
+                Integer,
+                ForeignKey("node_group.id"),
+                primary_key=True,
+            ),
+            Column(
+                "node_id", Integer, ForeignKey("node.id"), primary_key=True
+            ),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Kind(cls.Comparable):
+            pass
+
+        class Node(cls.Comparable):
+            pass
+
+        class NodeGroup(cls.Comparable):
+            pass
+
+        class NodeGroupNode(cls.Comparable):
+            pass
+
+    @classmethod
+    def insert_data(cls, connection):
+        kind = cls.tables.kind
+        connection.execute(
+            kind.insert(), [{"id": 1, "name": "a"}, {"id": 2, "name": "c"}]
+        )
+        node = cls.tables.node
+        connection.execute(
+            node.insert(),
+            {"id": 1, "name": "nc", "kind_id": 2},
+        )
+
+        connection.execute(
+            node.insert(),
+            {"id": 2, "name": "na", "kind_id": 1, "common_node_id": 1},
+        )
+
+        node_group = cls.tables.node_group
+        node_group_node = cls.tables.node_group_node
+
+        connection.execute(node_group.insert(), {"id": 1, "name": "group"})
+        connection.execute(
+            node_group_node.insert(),
+            {"id": 1, "node_group_id": 1, "node_id": 2},
+        )
+        connection.commit()
+
+    @testing.fixture(params=["common_nodes,kind", "kind,common_nodes"])
+    def node_fixture(self, request):
+        Kind, Node, NodeGroup, NodeGroupNode = self.classes(
+            "Kind", "Node", "NodeGroup", "NodeGroupNode"
+        )
+        kind, node, node_group, node_group_node = self.tables(
+            "kind", "node", "node_group", "node_group_node"
+        )
+        self.mapper_registry.map_imperatively(Kind, kind)
+
+        if request.param == "common_nodes,kind":
+            self.mapper_registry.map_imperatively(
+                Node,
+                node,
+                properties=dict(
+                    common_node=relationship(
+                        "Node",
+                        remote_side=[node.c.id],
+                    ),
+                    kind=relationship(Kind, innerjoin=True, lazy="joined"),
+                ),
+            )
+        elif request.param == "kind,common_nodes":
+            self.mapper_registry.map_imperatively(
+                Node,
+                node,
+                properties=dict(
+                    kind=relationship(Kind, innerjoin=True, lazy="joined"),
+                    common_node=relationship(
+                        "Node",
+                        remote_side=[node.c.id],
+                    ),
+                ),
+            )
+
+        self.mapper_registry.map_imperatively(
+            NodeGroup,
+            node_group,
+            properties=dict(
+                nodes=relationship(Node, secondary="node_group_node")
+            ),
+        )
+        self.mapper_registry.map_imperatively(NodeGroupNode, node_group_node)
+
+    def test_select(self, node_fixture):
+        Kind, Node, NodeGroup, NodeGroupNode = self.classes(
+            "Kind", "Node", "NodeGroup", "NodeGroupNode"
+        )
+
+        session = fixture_session()
+        with self.sql_execution_asserter(testing.db) as asserter:
+            group = (
+                session.scalars(
+                    select(NodeGroup)
+                    .where(NodeGroup.name == "group")
+                    .options(
+                        joinedload(NodeGroup.nodes).joinedload(
+                            Node.common_node
+                        )
+                    )
+                )
+                .unique()
+                .one_or_none()
+            )
+
+            eq_(group.nodes[0].common_node.kind.name, "c")
+            eq_(group.nodes[0].kind.name, "a")
+
+        asserter.assert_(
+            RegexSQL(
+                r"SELECT .* FROM node_group "
+                r"LEFT OUTER JOIN \(node_group_node AS node_group_node_1 "
+                r"JOIN node AS node_2 "
+                r"ON node_2.id = node_group_node_1.node_id "
+                r"JOIN kind AS kind_\d ON kind_\d.id = node_2.kind_id\) "
+                r"ON node_group.id = node_group_node_1.node_group_id "
+                r"LEFT OUTER JOIN "
+                r"\(node AS node_1 JOIN kind AS kind_\d "
+                r"ON kind_\d.id = node_1.kind_id\) "
+                r"ON node_1.id = node_2.common_node_id "
+                r"WHERE node_group.name = :name_5"
+            )
+        )
+
+
 class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
     """test #2188"""