]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix issue with unbaking subqueries
authorMark Hahnenberg <mark@nylas.com>
Tue, 12 Jul 2016 18:07:52 +0000 (14:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Jul 2016 18:09:14 +0000 (14:09 -0400)
Fix improper capture of a loop variable inside a lambda during unbaking
of subquery eager loaders, which would cause the incorrect query
to be invoked.

Fixes: #3743
Change-Id: I995110deb8ee2dae8540486729e1ae64578d28fc
Pull-request: https://github.com/zzzeek/sqlalchemy/pull/290

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/ext/baked.py
test/ext/test_baked.py

index f75956bcf39182f52cb392781782b6bc98cb5e2b..c977c7366d0e67d9d28375cc5d8f179406717d60 100644 (file)
 .. changelog::
     :version: 1.0.15
 
+    .. change::
+        :tags: bug, ext
+        :tickets: 3743
+
+        Fixed bug in ``sqlalchemy.ext.baked`` where the unbaking of a
+        subquery eager loader query would fail due to a variable scoping
+        issue, when multiple subquery loaders were involved.  Pull request
+        courtesy Mark Hahnenberg.
+
 .. changelog::
     :version: 1.0.14
     :released: July 6, 2016
index bfdc1e1a0f00291ba8f398a9ef4d08d007f4d937..3ca94925eba69875f7586f7bfbcabb10b5645559 100644 (file)
@@ -194,7 +194,8 @@ class BakedQuery(object):
 
         """
         for k, cache_key, query in context.attributes["baked_queries"]:
-            bk = BakedQuery(self._bakery, lambda sess: query.with_session(sess))
+            bk = BakedQuery(self._bakery,
+                            lambda sess, q=query: q.with_session(sess))
             bk._cache_key = cache_key
             context.attributes[k] = bk.for_session(session).params(**params)
 
index 8bfa58403e22bea9439ed6426308ef9db884903e..4250e363b12f2f67c0d0bc53d0065975ce9f52f8 100644 (file)
@@ -305,12 +305,16 @@ class ResultTest(BakedTest):
     def setup_mappers(cls):
         User = cls.classes.User
         Address = cls.classes.Address
+        Order = cls.classes.Order
 
         mapper(User, cls.tables.users, properties={
             "addresses": relationship(
-                Address, order_by=cls.tables.addresses.c.id)
+                Address, order_by=cls.tables.addresses.c.id),
+            "orders": relationship(
+                Order, order_by=cls.tables.orders.c.id)
         })
         mapper(Address, cls.tables.addresses)
+        mapper(Order, cls.tables.orders)
 
     def test_cachekeys_on_constructor(self):
         User = self.classes.User
@@ -551,24 +555,29 @@ class ResultTest(BakedTest):
     def test_subquery_eagerloading(self):
         User = self.classes.User
         Address = self.classes.Address
+        Order = self.classes.Order
 
-        base_bq = self.bakery(
-            lambda s: s.query(User))
+        # Override the default bakery for one with a smaller size. This used to
+        # trigger a bug when unbaking subqueries.
+        self.bakery = baked.bakery(size=3)
+        base_bq = self.bakery(lambda s: s.query(User))
 
-        base_bq += lambda q: q.options(subqueryload(User.addresses))
+        base_bq += lambda q: q.options(subqueryload(User.addresses),
+                                       subqueryload(User.orders))
         base_bq += lambda q: q.order_by(User.id)
 
         assert_result = [
-            User(id=7, addresses=[
-                Address(id=1, email_address='jack@bean.com')]),
+            User(id=7,
+                addresses=[Address(id=1, email_address='jack@bean.com')],
+                orders=[Order(id=1), Order(id=3), Order(id=5)]),
             User(id=8, addresses=[
                 Address(id=2, email_address='ed@wood.com'),
                 Address(id=3, email_address='ed@bettyboop.com'),
                 Address(id=4, email_address='ed@lala.com'),
             ]),
-            User(id=9, addresses=[
-                Address(id=5)
-            ]),
+            User(id=9,
+                addresses=[Address(id=5)], 
+                orders=[Order(id=2), Order(id=4)]),
             User(id=10, addresses=[])
         ]
 
@@ -603,18 +612,18 @@ class ResultTest(BakedTest):
                         def go():
                             result = bq(sess).all()
                             eq_(assert_result[1:2], result)
-                        self.assert_sql_count(testing.db, go, 2)
+                        self.assert_sql_count(testing.db, go, 3)
                 else:
                     if cond1:
                         def go():
                             result = bq(sess).all()
                             eq_(assert_result[0:1], result)
-                        self.assert_sql_count(testing.db, go, 2)
+                        self.assert_sql_count(testing.db, go, 3)
                     else:
                         def go():
                             result = bq(sess).all()
                             eq_(assert_result[1:3], result)
-                        self.assert_sql_count(testing.db, go, 2)
+                        self.assert_sql_count(testing.db, go, 3)
 
                 sess.close()