]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Accommodate for query._current_path in subq eager load join_depth
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Apr 2017 16:02:18 +0000 (12:02 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Apr 2017 16:02:18 +0000 (12:02 -0400)
Fixed bug in subquery eager loading where the "join_depth" parameter
for self-referential relationships would not be correctly honored,
loading all available levels deep rather than correctly counting
the specified number of levels for eager loading.

Change-Id: Ifa54085cbab3b41c2196f3ee519f485c63e4cb8d
Fixes: #3967
doc/build/changelog/changelog_12.rst
lib/sqlalchemy/orm/strategies.py
test/orm/test_subquery_relations.py

index 7b47149f8abfcd3d65de9022adb0b0cb51990068..57262ef9b517b5159ce0782ba011ee664e929913 100644 (file)
 .. changelog::
     :version: 1.2.0b1
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3967
+
+        Fixed bug in subquery eager loading where the "join_depth" parameter
+        for self-referential relationships would not be correctly honored,
+        loading all available levels deep rather than correctly counting
+        the specified number of levels for eager loading.
+
     .. change::
         :tags: bug, orm
 
index fdcb54953ae0bc481b0f8c7db51c70f34a825a4a..10131c80d3e8fca5adc352766627bbd37fb85a1a 100644 (file)
@@ -828,7 +828,11 @@ class SubqueryLoader(AbstractRelationshipLoader):
         # a cycle
         if not path.contains(context.attributes, "loader"):
             if self.join_depth:
-                if path.length / 2 > self.join_depth:
+                if (
+                    (context.query._current_path.length
+                     if context.query._current_path else 0) +
+                    path.length
+                ) / 2 > self.join_depth:
                     return
             elif subq_path.contains_mapper(self.mapper):
                 return
index 139628165a5b8cf2ca9c86fcb41783792a158ece..5d0aa13287de3bd4902b0c766d30e91d12e98e91 100644 (file)
@@ -517,6 +517,29 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         eq_(self.static.user_address_result,
             sess.query(User).order_by(User.id).all())
 
+    def test_cyclical_explicit_join_depth(self):
+        """A circular eager relationship breaks the cycle with a lazy loader"""
+
+        Address, addresses, users, User = (self.classes.Address,
+                                           self.tables.addresses,
+                                           self.tables.users,
+                                           self.classes.User)
+
+        mapper(Address, addresses)
+        mapper(User, users, properties=dict(
+            addresses=relationship(Address, lazy='subquery', join_depth=1,
+                                   backref=sa.orm.backref(
+                                       'user', lazy='subquery', join_depth=1),
+                                   order_by=Address.id)
+        ))
+        is_(sa.orm.class_mapper(User).get_property('addresses').lazy,
+            'subquery')
+        is_(sa.orm.class_mapper(Address).get_property('user').lazy, 'subquery')
+
+        sess = create_session()
+        eq_(self.static.user_address_result,
+            sess.query(User).order_by(User.id).all())
+
     def test_double(self):
         """Eager loading with two relationships simultaneously,
             from the same table, using aliases."""
@@ -1512,6 +1535,8 @@ class SelfReferentialTest(fixtures.MappedTest):
         n1.append(Node(data='n11'))
         n1.append(Node(data='n12'))
         n1.append(Node(data='n13'))
+        n1.children[0].append(Node(data='n111'))
+        n1.children[0].append(Node(data='n112'))
         n1.children[1].append(Node(data='n121'))
         n1.children[1].append(Node(data='n122'))
         n1.children[1].append(Node(data='n123'))
@@ -1521,14 +1546,22 @@ class SelfReferentialTest(fixtures.MappedTest):
 
         def go():
             allnodes = sess.query(Node).order_by(Node.data).all()
-            n12 = allnodes[2]
+
+            n11 = allnodes[1]
+            eq_(n11.data, 'n11')
+            eq_([
+                Node(data='n111'),
+                Node(data='n112'),
+            ], list(n11.children))
+
+            n12 = allnodes[4]
             eq_(n12.data, 'n12')
             eq_([
                 Node(data='n121'),
                 Node(data='n122'),
                 Node(data='n123')
             ], list(n12.children))
-        self.assert_sql_count(testing.db, go, 4)
+        self.assert_sql_count(testing.db, go, 2)
 
     def test_with_deferred(self):
         nodes = self.tables.nodes