]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed a bug related to "nested" inner join eager loading, which
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Mar 2015 21:49:39 +0000 (17:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Mar 2015 21:49:39 +0000 (17:49 -0400)
exists in 0.9 as well but is more of a regression in 1.0 due to
:ticket:`3008` which turns on "nested" by default, such that
a joined eager load that travels across sibling paths from a common
ancestor using innerjoin=True will correctly splice each "innerjoin"
sibling into the appropriate part of the join, when a series of
inner/outer joins are mixed together.
fixes #3347

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/orm/strategies.py
test/orm/test_eager_relations.py

index 96dd1d5c08bdd9726231170594c25fd8cc20c386..b8d61de2e89ed21a8f94439d54af40e9ce4e73d0 100644 (file)
 .. changelog::
     :version: 1.0.0b5
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3347
+
+        Fixed a bug related to "nested" inner join eager loading, which
+        exists in 0.9 as well but is more of a regression in 1.0 due to
+        :ticket:`3008` which turns on "nested" by default, such that
+        a joined eager load that travels across sibling paths from a common
+        ancestor using innerjoin=True will correctly splice each "innerjoin"
+        sibling into the appropriate part of the join, when a series of
+        inner/outer joins are mixed together.
+
 .. changelog::
     :version: 1.0.0b4
     :released: March 29, 2015
index 0b2672d66249f377ad989b074796624488d0c607..9aae8e5c827219941c6e3ec00ae9f5a4dc96ad11 100644 (file)
@@ -1332,34 +1332,24 @@ class JoinedLoader(AbstractRelationshipLoader):
 
         assert clauses.aliased_class is not None
 
-        join_to_outer = innerjoin and isinstance(towrap, sql.Join) and \
-            towrap.isouter
-
-        if chained_from_outerjoin and \
-                join_to_outer and innerjoin != 'unnested':
-            inner = orm_util.join(
-                towrap.right,
-                clauses.aliased_class,
-                onclause,
-                isouter=False
-            )
+        attach_on_outside = (
+            not chained_from_outerjoin or
+            not innerjoin or innerjoin == 'unnested')
 
-            eagerjoin = orm_util.join(
-                towrap.left,
-                inner,
-                towrap.onclause,
-                isouter=True
-            )
-            eagerjoin._target_adapter = inner._target_adapter
-        else:
-            if chained_from_outerjoin:
-                innerjoin = False
+        if attach_on_outside:
+            # this is the "classic" eager join case.
             eagerjoin = orm_util.join(
                 towrap,
                 clauses.aliased_class,
                 onclause,
-                isouter=not innerjoin
+                isouter=not innerjoin or (
+                    chained_from_outerjoin and isinstance(towrap, sql.Join)
+                )
             )
+        else:
+            # all other cases are innerjoin=='nested' approach
+            eagerjoin = self._splice_nested_inner_join(
+                path, towrap, clauses, onclause)
         context.eager_joins[entity_key] = eagerjoin
 
         # send a hint to the Query as to where it may "splice" this join
@@ -1389,6 +1379,64 @@ class JoinedLoader(AbstractRelationshipLoader):
                     )
                 )
 
+    def _splice_nested_inner_join(
+            self, path, join_obj, clauses, onclause, splicing=False):
+
+        if not splicing:
+            # first call is always handed a join object
+            # from the outside
+            assert isinstance(join_obj, sql.Join)
+        elif isinstance(join_obj, sql.selectable.FromGrouping):
+            return self._splice_nested_inner_join(
+                path, join_obj.element, clauses, onclause, True
+            )
+        elif not isinstance(join_obj, sql.Join):
+            if join_obj.is_derived_from(path[-2].selectable):
+                return orm_util.join(
+                    join_obj, clauses.aliased_class,
+                    onclause, isouter=False
+                )
+            else:
+                # only here if splicing == True
+                return None
+
+        target_join = self._splice_nested_inner_join(
+            path, join_obj.right, clauses, onclause, True)
+        if target_join is None:
+            right_splice = False
+            target_join = self._splice_nested_inner_join(
+                path, join_obj.left, clauses, onclause, True)
+            if target_join is None:
+                # should only return None when recursively called,
+                # e.g. splicing==True
+                assert splicing, \
+                    "assertion failed attempting to produce joined eager loads"
+                return None
+        else:
+            right_splice = True
+
+        if right_splice:
+            # for a right splice, attempt to flatten out
+            # a JOIN b JOIN c JOIN .. to avoid needless
+            # parenthesis nesting
+            if not join_obj.isouter and not target_join.isouter:
+                eagerjoin = orm_util.join(
+                    join_obj.left, target_join.left,
+                    join_obj.onclause, isouter=False,
+                ).join(target_join.right,
+                       target_join.onclause, isouter=False)
+            else:
+                eagerjoin = orm_util.join(
+                    join_obj.left, target_join,
+                    join_obj.onclause, isouter=join_obj.isouter)
+        else:
+            eagerjoin = orm_util.join(
+                target_join, join_obj.right,
+                join_obj.onclause, isouter=join_obj.isouter)
+
+        eagerjoin._target_adapter = target_join._target_adapter
+        return eagerjoin
+
     def _create_eager_adapter(self, context, result, adapter, path, loadopt):
         user_defined_adapter = self._init_user_defined_eager_proc(
             loadopt, context) if loadopt else False
index 3688773c256895b0f26d8caa614e22478ced44a9..d701cdbfcfcac84fea374dd4ffecdcd184e3bc44 100644 (file)
@@ -1699,6 +1699,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             "ON users.id = addresses_1.user_id"
         )
 
+
     def test_catch_the_right_target(self):
         # test eager join chaining to the "nested" join on the left,
         # a new feature as of [ticket:2369]
@@ -2006,6 +2007,257 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         ])
 
 
+class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+    __dialect__ = 'default'
+    __backend__ = True  # exercise hardcore join nesting on backends
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('a', metadata,
+              Column('id', Integer, primary_key=True)
+              )
+
+        Table('b', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('a_id', Integer, ForeignKey('a.id')),
+              Column('value', String(10)),
+              )
+        Table('c1', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('b_id', Integer, ForeignKey('b.id')),
+              Column('value', String(10)),
+              )
+        Table('c2', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('b_id', Integer, ForeignKey('b.id')),
+              Column('value', String(10)),
+              )
+        Table('d1', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('c1_id', Integer, ForeignKey('c1.id')),
+              Column('value', String(10)),
+              )
+        Table('d2', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('c2_id', Integer, ForeignKey('c2.id')),
+              Column('value', String(10)),
+              )
+        Table('e1', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('d1_id', Integer, ForeignKey('d1.id')),
+              Column('value', String(10)),
+              )
+
+    @classmethod
+    def setup_classes(cls):
+
+        class A(cls.Comparable):
+            pass
+
+        class B(cls.Comparable):
+            pass
+
+        class C1(cls.Comparable):
+            pass
+
+        class C2(cls.Comparable):
+            pass
+
+        class D1(cls.Comparable):
+            pass
+
+        class D2(cls.Comparable):
+            pass
+
+        class E1(cls.Comparable):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        A, B, C1, C2, D1, D2, E1 = (
+            cls.classes.A, cls.classes.B, cls.classes.C1,
+            cls.classes.C2, cls.classes.D1, cls.classes.D2, cls.classes.E1)
+        mapper(A, cls.tables.a, properties={
+            'bs': relationship(B)
+        })
+        mapper(B, cls.tables.b, properties={
+            'c1s': relationship(C1, order_by=cls.tables.c1.c.id),
+            'c2s': relationship(C2, order_by=cls.tables.c2.c.id)
+        })
+        mapper(C1, cls.tables.c1, properties={
+            'd1s': relationship(D1, order_by=cls.tables.d1.c.id)
+        })
+        mapper(C2, cls.tables.c2, properties={
+            'd2s': relationship(D2, order_by=cls.tables.d2.c.id)
+        })
+        mapper(D1, cls.tables.d1, properties={
+            'e1s': relationship(E1, order_by=cls.tables.e1.c.id)
+        })
+        mapper(D2, cls.tables.d2)
+        mapper(E1, cls.tables.e1)
+
+    @classmethod
+    def _fixture_data(cls):
+        A, B, C1, C2, D1, D2, E1 = (
+            cls.classes.A, cls.classes.B, cls.classes.C1,
+            cls.classes.C2, cls.classes.D1, cls.classes.D2, cls.classes.E1)
+        return [
+            A(id=1, bs=[
+                B(
+                    id=1,
+                    c1s=[C1(
+                        id=1, value='C11',
+                        d1s=[
+                            D1(id=1, e1s=[E1(id=1)]), D1(id=2, e1s=[E1(id=2)])
+                        ]
+                    )
+                    ],
+                    c2s=[C2(id=1, value='C21', d2s=[D2(id=3)]),
+                         C2(id=2, value='C22', d2s=[D2(id=4)])]
+                ),
+                B(
+                    id=2,
+                    c1s=[
+                        C1(
+                            id=4, value='C14',
+                            d1s=[D1(
+                                id=3, e1s=[
+                                    E1(id=3, value='E13'),
+                                    E1(id=4, value="E14")
+                                ]),
+                                D1(id=4, e1s=[E1(id=5)])
+                            ]
+                        )
+                    ],
+                    c2s=[C2(id=4, value='C24', d2s=[])]
+                ),
+            ]),
+            A(id=2, bs=[
+                B(
+                    id=3,
+                    c1s=[
+                        C1(
+                            id=8,
+                            d1s=[D1(id=5, value='D15', e1s=[E1(id=6)])]
+                        )
+                    ],
+                    c2s=[C2(id=8, d2s=[D2(id=6, value='D26')])]
+                )
+            ])
+        ]
+
+    @classmethod
+    def insert_data(cls):
+        s = Session(testing.db)
+        s.add_all(cls._fixture_data())
+        s.commit()
+
+    def _assert_result(self, query):
+        eq_(
+            query.all(),
+            self._fixture_data()
+        )
+
+    def test_nested_innerjoin_propagation_multiple_paths_one(self):
+        A, B, C1, C2 = (
+            self.classes.A, self.classes.B, self.classes.C1,
+            self.classes.C2)
+
+        s = Session()
+
+        q = s.query(A).options(
+            joinedload(A.bs, innerjoin=False).
+            joinedload(B.c1s, innerjoin=True).
+            joinedload(C1.d1s, innerjoin=True),
+            defaultload(A.bs).joinedload(B.c2s, innerjoin=True).
+            joinedload(C2.d2s, innerjoin=False)
+        )
+        self.assert_compile(
+            q,
+            "SELECT a.id AS a_id, d1_1.id AS d1_1_id, "
+            "d1_1.c1_id AS d1_1_c1_id, d1_1.value AS d1_1_value, "
+            "c1_1.id AS c1_1_id, c1_1.b_id AS c1_1_b_id, "
+            "c1_1.value AS c1_1_value, d2_1.id AS d2_1_id, "
+            "d2_1.c2_id AS d2_1_c2_id, d2_1.value AS d2_1_value, "
+            "c2_1.id AS c2_1_id, c2_1.b_id AS c2_1_b_id, "
+            "c2_1.value AS c2_1_value, b_1.id AS b_1_id, "
+            "b_1.a_id AS b_1_a_id, b_1.value AS b_1_value "
+            "FROM a "
+            "LEFT OUTER JOIN "
+            "(b AS b_1 JOIN c2 AS c2_1 ON b_1.id = c2_1.b_id "
+            "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id "
+            "JOIN d1 AS d1_1 ON c1_1.id = d1_1.c1_id) ON a.id = b_1.a_id "
+            "LEFT OUTER JOIN d2 AS d2_1 ON c2_1.id = d2_1.c2_id "
+            "ORDER BY c1_1.id, d1_1.id, c2_1.id, d2_1.id"
+        )
+        self._assert_result(q)
+
+    def test_nested_innerjoin_propagation_multiple_paths_two(self):
+        # test #3447
+        A = self.classes.A
+
+        s = Session()
+
+        q = s.query(A).options(
+            joinedload('bs'),
+            joinedload('bs.c2s', innerjoin=True),
+            joinedload('bs.c1s', innerjoin=True),
+            joinedload('bs.c1s.d1s')
+        )
+        self.assert_compile(
+            q,
+            "SELECT a.id AS a_id, d1_1.id AS d1_1_id, "
+            "d1_1.c1_id AS d1_1_c1_id, d1_1.value AS d1_1_value, "
+            "c1_1.id AS c1_1_id, c1_1.b_id AS c1_1_b_id, "
+            "c1_1.value AS c1_1_value, c2_1.id AS c2_1_id, "
+            "c2_1.b_id AS c2_1_b_id, c2_1.value AS c2_1_value, "
+            "b_1.id AS b_1_id, b_1.a_id AS b_1_a_id, "
+            "b_1.value AS b_1_value "
+            "FROM a LEFT OUTER JOIN "
+            "(b AS b_1 JOIN c2 AS c2_1 ON b_1.id = c2_1.b_id "
+            "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id) ON a.id = b_1.a_id "
+            "LEFT OUTER JOIN d1 AS d1_1 ON c1_1.id = d1_1.c1_id "
+            "ORDER BY c1_1.id, d1_1.id, c2_1.id"
+        )
+        self._assert_result(q)
+
+    def test_multiple_splice_points(self):
+        A = self.classes.A
+
+        s = Session()
+
+        q = s.query(A).options(
+            joinedload('bs', innerjoin=False),
+            joinedload('bs.c1s', innerjoin=True),
+            joinedload('bs.c2s', innerjoin=True),
+            joinedload('bs.c1s.d1s', innerjoin=False),
+            joinedload('bs.c2s.d2s'),
+            joinedload('bs.c1s.d1s.e1s', innerjoin=True)
+        )
+
+        self.assert_compile(
+            q,
+            "SELECT a.id AS a_id, e1_1.id AS e1_1_id, "
+            "e1_1.d1_id AS e1_1_d1_id, e1_1.value AS e1_1_value, "
+            "d1_1.id AS d1_1_id, d1_1.c1_id AS d1_1_c1_id, "
+            "d1_1.value AS d1_1_value, c1_1.id AS c1_1_id, "
+            "c1_1.b_id AS c1_1_b_id, c1_1.value AS c1_1_value, "
+            "d2_1.id AS d2_1_id, d2_1.c2_id AS d2_1_c2_id, "
+            "d2_1.value AS d2_1_value, c2_1.id AS c2_1_id, "
+            "c2_1.b_id AS c2_1_b_id, c2_1.value AS c2_1_value, "
+            "b_1.id AS b_1_id, b_1.a_id AS b_1_a_id, b_1.value AS b_1_value "
+            "FROM a LEFT OUTER JOIN "
+            "(b AS b_1 JOIN c2 AS c2_1 ON b_1.id = c2_1.b_id "
+            "JOIN c1 AS c1_1 ON b_1.id = c1_1.b_id) ON a.id = b_1.a_id "
+            "LEFT OUTER JOIN ("
+            "d1 AS d1_1 JOIN e1 AS e1_1 ON d1_1.id = e1_1.d1_id) "
+            "ON c1_1.id = d1_1.c1_id "
+            "LEFT OUTER JOIN d2 AS d2_1 ON c2_1.id = d2_1.c2_id "
+            "ORDER BY c1_1.id, d1_1.id, e1_1.id, c2_1.id, d2_1.id"
+        )
+        self._assert_result(q)
+
+
 class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
     """test #2188"""