From: Mike Bayer Date: Sat, 6 Aug 2011 18:23:07 +0000 (-0400) Subject: - Fixed bug in query.join() which would occur X-Git-Tag: rel_0_7_3~96 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b9a4eacfa3675f2eb9a141499d83cf532c263e01;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed bug in query.join() which would occur in a complex multiple-overlapping path scenario, where the same table could be joined to twice. Thanks *much* to Dave Vitek for the excellent fix here. [ticket:2247] --- diff --git a/CHANGES b/CHANGES index d0d599b4c1..9b8dc256d0 100644 --- a/CHANGES +++ b/CHANGES @@ -15,6 +15,12 @@ CHANGES when the Session.is_active is True. [ticket:2241] + - Fixed bug in query.join() which would occur + in a complex multiple-overlapping path scenario, + where the same table could be joined to + twice. Thanks *much* to Dave Vitek + for the excellent fix here. [ticket:2247] + - sqlite - Ensured that the same ValueError is raised for illegal date/time/datetime string parsed from diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index cf907b8795..437199ac81 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1507,6 +1507,18 @@ class Query(object): outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) + def _update_joinpoint(self, jp): + self._joinpoint = jp + # copy backwards to the root of the _joinpath + # dict, so that no existing dict in the path is mutated + while 'prev' in jp: + f, prev = jp['prev'] + prev = prev.copy() + prev[f] = jp + jp['prev'] = (f, prev) + jp = prev + self._joinpath = jp + @_generative(_no_statement_condition, _no_limit_offset) def _join(self, keys, outerjoin, create_aliases, from_joinpoint): """consumes arguments from join() or outerjoin(), places them into a @@ -1586,11 +1598,18 @@ class Query(object): if not create_aliases: # check for this path already present. # don't render in that case. - if (left_entity, right_entity, prop.key) in \ - self._joinpoint: - self._joinpoint = \ - self._joinpoint[ - (left_entity, right_entity, prop.key)] + edge = (left_entity, right_entity, prop.key) + if edge in self._joinpoint: + # The child's prev reference might be stale -- + # it could point to a parent older than the + # current joinpoint. If this is the case, + # then we need to update it and then fix the + # tree's spine with _update_joinpoint. Copy + # and then mutate the child, which might be + # shared by a different query object. + jp = self._joinpoint[edge].copy() + jp['prev'] = (edge, self._joinpoint) + self._update_joinpoint(jp) continue elif onclause is not None and right_entity is None: @@ -1661,23 +1680,10 @@ class Query(object): # if joining on a MapperProperty path, # track the path to prevent redundant joins if not create_aliases and prop: - - self._joinpoint = jp = { + self._update_joinpoint({ '_joinpoint_entity':right, 'prev':((left, right, prop.key), self._joinpoint) - } - - # copy backwards to the root of the _joinpath - # dict, so that no existing dict in the path is mutated - while 'prev' in jp: - f, prev = jp['prev'] - prev = prev.copy() - prev[f] = jp - jp['prev'] = (f, prev) - jp = prev - - self._joinpath = jp - + }) else: self._joinpoint = { '_joinpoint_entity':right diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index b2f03cad95..7e47c55bf8 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -352,6 +352,7 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL): class JoinTest(QueryTest, AssertsCompiledSQL): + __dialect__ = 'default' def test_single_name(self): User = self.classes.User @@ -362,7 +363,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(User).join("orders"), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN orders ON users.id = orders.user_id" - , use_default_dialect = True ) assert_raises( @@ -375,7 +375,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name FROM users " "JOIN orders ON users.id = orders.user_id JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id JOIN items ON items.id = order_items_1.item_id" - , use_default_dialect=True ) # test overlapping paths. User->orders is used by both joins, but rendered once. @@ -385,7 +384,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "ON users.id = orders.user_id JOIN order_items AS order_items_1 ON orders.id = " "order_items_1.order_id JOIN items ON items.id = order_items_1.item_id JOIN addresses " "ON addresses.id = orders.address_id" - , use_default_dialect=True ) def test_multi_tuple_form(self): @@ -411,7 +409,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(User).join((Order, User.id==Order.user_id)), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN orders ON users.id = orders.user_id", - use_default_dialect=True ) self.assert_compile( @@ -423,7 +420,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN order_items AS order_items_1 ON orders.id = " "order_items_1.order_id JOIN items ON items.id = " "order_items_1.item_id", - use_default_dialect=True ) # the old "backwards" form @@ -431,7 +427,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(User).join(("orders", Order)), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN orders ON users.id = orders.user_id", - use_default_dialect=True ) def test_single_prop(self): @@ -445,14 +440,12 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(User).join(User.orders), "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN orders ON users.id = orders.user_id" - , use_default_dialect=True ) self.assert_compile( sess.query(User).join(Order.user), "SELECT users.id AS users_id, users.name AS users_name " "FROM orders JOIN users ON users.id = orders.user_id" - , use_default_dialect=True ) oalias1 = aliased(Order) @@ -462,7 +455,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(User).join(oalias1.user), "SELECT users.id AS users_id, users.name AS users_name " "FROM orders AS orders_1 JOIN users ON users.id = orders_1.user_id" - , use_default_dialect=True ) # another nonsensical query. (from [ticket:1537]). @@ -472,7 +464,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name " "FROM orders AS orders_1 JOIN users ON users.id = orders_1.user_id, " "orders AS orders_2 JOIN users ON users.id = orders_2.user_id" - , use_default_dialect=True ) self.assert_compile( @@ -480,7 +471,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name FROM users " "JOIN orders ON users.id = orders.user_id JOIN order_items AS order_items_1 " "ON orders.id = order_items_1.order_id JOIN items ON items.id = order_items_1.item_id" - , use_default_dialect=True ) ualias = aliased(User) @@ -488,7 +478,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(ualias).join(ualias.orders), "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " "FROM users AS users_1 JOIN orders ON users_1.id = orders.user_id" - , use_default_dialect=True ) # this query is somewhat nonsensical. the old system didn't render a correct @@ -502,7 +491,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN orders ON users.id = orders.user_id, " "orders AS orders_1 JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id " "JOIN items ON items.id = order_items_1.item_id" - , use_default_dialect=True ) # same as before using an aliased() for User as well @@ -513,7 +501,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN orders ON users_1.id = orders.user_id, " "orders AS orders_1 JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id " "JOIN items ON items.id = order_items_1.item_id" - , use_default_dialect=True ) self.assert_compile( @@ -522,7 +509,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "FROM (SELECT users.id AS users_id, users.name AS users_name " "FROM users " "WHERE users.name = :name_1) AS anon_1 JOIN orders ON anon_1.users_id = orders.user_id" - , use_default_dialect=True ) self.assert_compile( @@ -530,7 +516,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name " "FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id " "WHERE addresses_1.email_address = :email_address_1" - , use_default_dialect=True ) self.assert_compile( @@ -540,7 +525,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id " "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " "WHERE items_1.id = :id_1" - , use_default_dialect=True ) # test #1 for [ticket:1706] @@ -553,7 +537,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "users_1_name FROM users AS users_1 JOIN orders AS orders_1 " "ON users_1.id = orders_1.user_id JOIN addresses ON users_1.id " "= addresses.user_id" - , use_default_dialect=True ) # test #2 for [ticket:1706] @@ -566,7 +549,6 @@ class JoinTest(QueryTest, AssertsCompiledSQL): "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name FROM users " "AS users_1 JOIN addresses ON users_1.id = addresses.user_id JOIN users AS users_2 " "ON users_2.id = addresses.user_id JOIN orders ON users_1.id = orders.user_id" - , use_default_dialect=True ) def test_overlapping_paths(self): @@ -578,6 +560,27 @@ class JoinTest(QueryTest, AssertsCompiledSQL): filter_by(id=3).join('orders','address', aliased=aliased).filter_by(id=1).all() assert [User(id=7, name='jack')] == result + def test_overlapping_paths_multilevel(self): + User = self.classes.User + + s = Session() + q = s.query(User).\ + join('orders').\ + join('addresses').\ + join('orders', 'items').\ + join('addresses', 'dingaling') + self.assert_compile( + q, + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users JOIN orders ON users.id = orders.user_id " + "JOIN addresses ON users.id = addresses.user_id " + "JOIN order_items AS order_items_1 ON orders.id = " + "order_items_1.order_id " + "JOIN items ON items.id = order_items_1.item_id " + "JOIN dingalings ON addresses.id = dingalings.address_id" + + ) + def test_overlapping_paths_outerjoin(self): User = self.classes.User @@ -640,11 +643,8 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess.query(User.id, literal_column('foo')).join(Order.user), "SELECT users.id AS users_id, foo FROM " "orders JOIN users ON users.id = orders.user_id" - , use_default_dialect=True ) - - def test_backwards_join(self): User, Address = self.classes.User, self.classes.Address