]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug in query.join() which would occur
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Aug 2011 18:23:07 +0000 (14:23 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Aug 2011 18:23:07 +0000 (14:23 -0400)
    in a complex multiple-overlapping path scenario,
    where the same table could be joined to
    twice.  Thanks *much* to Dave Vitek
    for the excellent fix here.  [ticket:2247]

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/test_joins.py

diff --git a/CHANGES b/CHANGES
index d0d599b4c109a270b4bb53e3aad217973539b5a9..9b8dc256d0e5f527a324150d49c28424fc7edbc1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -15,6 +15,12 @@ CHANGES
      when the Session.is_active is True.
      [ticket:2241]
 
+  - Fixed bug in query.join() which would occur
+    in a complex multiple-overlapping path scenario,
+    where the same table could be joined to
+    twice.  Thanks *much* to Dave Vitek 
+    for the excellent fix here.  [ticket:2247]
+
 - sqlite
   - Ensured that the same ValueError is raised for
     illegal date/time/datetime string parsed from
index cf907b8795a4d0c0a59e3ba1ed1d2be6775464ca..437199ac81f18171033414db4e1de40a397e5506 100644 (file)
@@ -1507,6 +1507,18 @@ class Query(object):
                             outerjoin=True, create_aliases=aliased, 
                             from_joinpoint=from_joinpoint)
 
+    def _update_joinpoint(self, jp):
+        self._joinpoint = jp
+        # copy backwards to the root of the _joinpath
+        # dict, so that no existing dict in the path is mutated
+        while 'prev' in jp:
+            f, prev = jp['prev']
+            prev = prev.copy()
+            prev[f] = jp
+            jp['prev'] = (f, prev)
+            jp = prev
+        self._joinpath = jp
+
     @_generative(_no_statement_condition, _no_limit_offset)
     def _join(self, keys, outerjoin, create_aliases, from_joinpoint):
         """consumes arguments from join() or outerjoin(), places them into a
@@ -1586,11 +1598,18 @@ class Query(object):
                 if not create_aliases:
                     # check for this path already present.
                     # don't render in that case.
-                    if (left_entity, right_entity, prop.key) in \
-                                    self._joinpoint:
-                        self._joinpoint = \
-                                    self._joinpoint[
-                                    (left_entity, right_entity, prop.key)]
+                    edge = (left_entity, right_entity, prop.key)
+                    if edge in self._joinpoint:
+                        # The child's prev reference might be stale --
+                        # it could point to a parent older than the
+                        # current joinpoint.  If this is the case,
+                        # then we need to update it and then fix the
+                        # tree's spine with _update_joinpoint.  Copy
+                        # and then mutate the child, which might be
+                        # shared by a different query object.
+                        jp = self._joinpoint[edge].copy()
+                        jp['prev'] = (edge, self._joinpoint)
+                        self._update_joinpoint(jp)
                         continue
 
             elif onclause is not None and right_entity is None:
@@ -1661,23 +1680,10 @@ class Query(object):
         # if joining on a MapperProperty path,
         # track the path to prevent redundant joins
         if not create_aliases and prop:
-
-            self._joinpoint = jp = {
+            self._update_joinpoint({
                 '_joinpoint_entity':right,
                 'prev':((left, right, prop.key), self._joinpoint)
-            }
-
-            # copy backwards to the root of the _joinpath
-            # dict, so that no existing dict in the path is mutated
-            while 'prev' in jp:
-                f, prev = jp['prev']
-                prev = prev.copy()
-                prev[f] = jp
-                jp['prev'] = (f, prev)
-                jp = prev
-
-            self._joinpath = jp
-
+            })
         else:
             self._joinpoint = {
                 '_joinpoint_entity':right
index b2f03cad9551e64b733e856d20cef485645c58af..7e47c55bf86b55babc2c7514cfbcc5afd73eee9e 100644 (file)
@@ -352,6 +352,7 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
 
 
 class JoinTest(QueryTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
 
     def test_single_name(self):
         User = self.classes.User
@@ -362,7 +363,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(User).join("orders"),
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM users JOIN orders ON users.id = orders.user_id"
-            , use_default_dialect = True
         )
 
         assert_raises(
@@ -375,7 +375,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "SELECT users.id AS users_id, users.name AS users_name FROM users "
             "JOIN orders ON users.id = orders.user_id JOIN order_items AS order_items_1 "
             "ON orders.id = order_items_1.order_id JOIN items ON items.id = order_items_1.item_id"
-            , use_default_dialect=True
         )
 
         # test overlapping paths.   User->orders is used by both joins, but rendered once.
@@ -385,7 +384,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "ON users.id = orders.user_id JOIN order_items AS order_items_1 ON orders.id = "
             "order_items_1.order_id JOIN items ON items.id = order_items_1.item_id JOIN addresses "
             "ON addresses.id = orders.address_id"
-            , use_default_dialect=True
         )
 
     def test_multi_tuple_form(self):
@@ -411,7 +409,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(User).join((Order, User.id==Order.user_id)),
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM users JOIN orders ON users.id = orders.user_id",
-            use_default_dialect=True
         )
 
         self.assert_compile(
@@ -423,7 +420,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "JOIN order_items AS order_items_1 ON orders.id = "
             "order_items_1.order_id JOIN items ON items.id = "
             "order_items_1.item_id",
-            use_default_dialect=True
         )
 
         # the old "backwards" form
@@ -431,7 +427,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(User).join(("orders", Order)),
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM users JOIN orders ON users.id = orders.user_id",
-            use_default_dialect=True
         )
 
     def test_single_prop(self):
@@ -445,14 +440,12 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(User).join(User.orders),
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM users JOIN orders ON users.id = orders.user_id"
-            , use_default_dialect=True
         )
 
         self.assert_compile(
             sess.query(User).join(Order.user),
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM orders JOIN users ON users.id = orders.user_id"
-            , use_default_dialect=True
         )
 
         oalias1 = aliased(Order)
@@ -462,7 +455,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(User).join(oalias1.user),
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM orders AS orders_1 JOIN users ON users.id = orders_1.user_id"
-            , use_default_dialect=True
         )
 
         # another nonsensical query.  (from [ticket:1537]).
@@ -472,7 +464,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM orders AS orders_1 JOIN users ON users.id = orders_1.user_id, "
             "orders AS orders_2 JOIN users ON users.id = orders_2.user_id"
-            , use_default_dialect=True
         )
 
         self.assert_compile(
@@ -480,7 +471,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "SELECT users.id AS users_id, users.name AS users_name FROM users "
             "JOIN orders ON users.id = orders.user_id JOIN order_items AS order_items_1 "
             "ON orders.id = order_items_1.order_id JOIN items ON items.id = order_items_1.item_id"
-            , use_default_dialect=True
         )
 
         ualias = aliased(User)
@@ -488,7 +478,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(ualias).join(ualias.orders),
             "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name "
             "FROM users AS users_1 JOIN orders ON users_1.id = orders.user_id"
-            , use_default_dialect=True
         )
 
         # this query is somewhat nonsensical.  the old system didn't render a correct
@@ -502,7 +491,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "JOIN orders ON users.id = orders.user_id, "
             "orders AS orders_1 JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id "
             "JOIN items ON items.id = order_items_1.item_id"
-            , use_default_dialect=True
         )
 
         # same as before using an aliased() for User as well
@@ -513,7 +501,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "JOIN orders ON users_1.id = orders.user_id, "
             "orders AS orders_1 JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id "
             "JOIN items ON items.id = order_items_1.item_id"
-            , use_default_dialect=True
         )
 
         self.assert_compile(
@@ -522,7 +509,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "FROM (SELECT users.id AS users_id, users.name AS users_name "
             "FROM users "
             "WHERE users.name = :name_1) AS anon_1 JOIN orders ON anon_1.users_id = orders.user_id"
-            , use_default_dialect=True
         )
 
         self.assert_compile(
@@ -530,7 +516,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "SELECT users.id AS users_id, users.name AS users_name "
             "FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id "
             "WHERE addresses_1.email_address = :email_address_1"
-            , use_default_dialect=True
         )
 
         self.assert_compile(
@@ -540,7 +525,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id "
             "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
             "WHERE items_1.id = :id_1"
-            , use_default_dialect=True
         )
 
         # test #1 for [ticket:1706]
@@ -553,7 +537,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "users_1_name FROM users AS users_1 JOIN orders AS orders_1 "
             "ON users_1.id = orders_1.user_id JOIN addresses ON users_1.id "
             "= addresses.user_id"
-            , use_default_dialect=True
         )
 
         # test #2 for [ticket:1706]
@@ -566,7 +549,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name FROM users "
             "AS users_1 JOIN addresses ON users_1.id = addresses.user_id JOIN users AS users_2 "
             "ON users_2.id = addresses.user_id JOIN orders ON users_1.id = orders.user_id"
-            , use_default_dialect=True
         )
 
     def test_overlapping_paths(self):
@@ -578,6 +560,27 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
                     filter_by(id=3).join('orders','address', aliased=aliased).filter_by(id=1).all()
             assert [User(id=7, name='jack')] == result
 
+    def test_overlapping_paths_multilevel(self):
+        User = self.classes.User
+
+        s = Session()
+        q = s.query(User).\
+                    join('orders').\
+                    join('addresses').\
+                    join('orders', 'items').\
+                    join('addresses', 'dingaling')
+        self.assert_compile(
+            q,
+            "SELECT users.id AS users_id, users.name AS users_name "
+            "FROM users JOIN orders ON users.id = orders.user_id "
+            "JOIN addresses ON users.id = addresses.user_id "
+            "JOIN order_items AS order_items_1 ON orders.id = "
+            "order_items_1.order_id "
+            "JOIN items ON items.id = order_items_1.item_id "
+            "JOIN dingalings ON addresses.id = dingalings.address_id"
+
+        )
+
     def test_overlapping_paths_outerjoin(self):
         User = self.classes.User
 
@@ -640,11 +643,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             sess.query(User.id, literal_column('foo')).join(Order.user),
             "SELECT users.id AS users_id, foo FROM "
             "orders JOIN users ON users.id = orders.user_id"
-            , use_default_dialect=True
         )
 
-
-
     def test_backwards_join(self):
         User, Address = self.classes.User, self.classes.Address