From: Mike Bayer Date: Wed, 12 Oct 2011 19:15:28 +0000 (-0400) Subject: - Improved query.join() such that the "left" side X-Git-Tag: rel_0_7_3~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=df02cc0854068a93aa3a49a312a91780de236f5e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Improved query.join() such that the "left" side can more flexibly be a non-ORM selectable, such as a subquery. A selectable placed in select_from() will now be used as the left side, favored over implicit usage of a mapped entity. If the join still fails based on lack of foreign keys, the error message includes this detail. Thanks to brianrhude on IRC for the test case. [ticket:2298] --- diff --git a/CHANGES b/CHANGES index 7d3dd04ec1..d072eba6ec 100644 --- a/CHANGES +++ b/CHANGES @@ -15,6 +15,17 @@ CHANGES fixes [ticket:2279]. Also in 0.6.9. - orm + - Improved query.join() such that the "left" side + can more flexibly be a non-ORM selectable, + such as a subquery. A selectable placed + in select_from() will now be used as the left + side, favored over implicit usage + of a mapped entity. + If the join still fails based on lack of + foreign keys, the error message includes + this detail. Thanks to brianrhude + on IRC for the test case. [ticket:2298] + - Added after_soft_rollback() Session event. This event fires unconditionally whenever rollback() is called, regardless of if an actual DBAPI diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6ef6c3b573..88532d9096 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1626,7 +1626,10 @@ class Query(object): """append a JOIN to the query's from clause.""" if left is None: - left = self._joinpoint_zero() + if self._from_obj: + left = self._from_obj[0] + elif self._entities: + left = self._entities[0].entity_zero_or_selectable if left is right and \ not create_aliases: @@ -1742,10 +1745,15 @@ class Query(object): sql_util.clause_is_present(left_selectable, clause): join_to_left = False - clause = orm_join(clause, + try: + clause = orm_join(clause, right, onclause, isouter=outerjoin, join_to_left=join_to_left) + except sa_exc.ArgumentError, ae: + raise sa_exc.InvalidRequestError( + "Could not find a FROM clause to join from. " + "Tried joining to %s, but got: %s" % (right, ae)) self._from_obj = \ self._from_obj[:replace_clause_index] + \ @@ -1760,6 +1768,8 @@ class Query(object): break else: clause = left + elif left_selectable is not None: + clause = left_selectable else: clause = None @@ -1767,8 +1777,13 @@ class Query(object): raise sa_exc.InvalidRequestError( "Could not find a FROM clause to join from") - clause = orm_join(clause, right, onclause, + try: + clause = orm_join(clause, right, onclause, isouter=outerjoin, join_to_left=join_to_left) + except sa_exc.ArgumentError, ae: + raise sa_exc.InvalidRequestError( + "Could not find a FROM clause to join from. " + "Tried joining to %s, but got: %s" % (right, ae)) self._from_obj = self._from_obj + (clause,) @@ -2890,6 +2905,10 @@ class _MapperEntity(_QueryEntity): def type(self): return self.mapper.class_ + @property + def entity_zero_or_selectable(self): + return self.entity_zero + def corresponds_to(self, entity): if _is_aliased_class(entity) or self.is_aliased_class: return entity is self.path_entity @@ -3055,7 +3074,7 @@ class _ColumnEntity(_QueryEntity): # of FROMs for the overall expression - this helps # subqueries which were built from ORM constructs from # leaking out their entities into the main select construct - actual_froms = set(column._from_objects) + self.actual_froms = actual_froms = set(column._from_objects) self.entities = util.OrderedSet( elem._annotations['parententity'] @@ -3069,6 +3088,15 @@ class _ColumnEntity(_QueryEntity): else: self.entity_zero = None + @property + def entity_zero_or_selectable(self): + if self.entity_zero: + return self.entity_zero + elif self.actual_froms: + return list(self.actual_froms)[0] + else: + return None + @property def type(self): return self.column.type diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index e2fb55129d..c8fc0af79b 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -1610,6 +1610,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): class SelectFromTest(QueryTest, AssertsCompiledSQL): run_setup_mappers = None + __dialect__ = 'default' def test_replace_with_select(self): users, Address, addresses, User = (self.tables.users, @@ -1676,7 +1677,6 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name AS users_name FROM " "users JOIN (SELECT users.id AS id, users.name AS name FROM " "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 ON users.id > anon_1.id", - use_default_dialect=True ) self.assert_compile( @@ -1684,18 +1684,22 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name FROM " "users AS users_1, (SELECT users.id AS id, users.name AS name FROM " "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 WHERE users_1.id > anon_1.id", - use_default_dialect=True ) - # these two are essentially saying, "join ualias to ualias", so an - # error is raised. join() deals with entities, not what's in - # select_from(). - assert_raises(sa_exc.InvalidRequestError, - sess.query(ualias).select_from(sel).join, ualias, ualias.id>sel.c.id + self.assert_compile( + sess.query(ualias).select_from(sel).join(ualias, ualias.id>sel.c.id), + "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " + "FROM (SELECT users.id AS id, users.name AS name " + "FROM users WHERE users.id IN (:id_1, :id_2)) AS anon_1 " + "JOIN users AS users_1 ON users_1.id > anon_1.id" ) - assert_raises(sa_exc.InvalidRequestError, - sess.query(ualias).select_from(sel).join, ualias, ualias.id>User.id + self.assert_compile( + sess.query(ualias).select_from(sel).join(ualias, ualias.id>User.id), + "SELECT users_1.id AS users_1_id, users_1.name AS users_1_name " + "FROM (SELECT users.id AS id, users.name AS name FROM " + "users WHERE users.id IN (:id_1, :id_2)) AS anon_1 " + "JOIN users AS users_1 ON anon_1.id < users_1.id" ) salias = aliased(User, sel) @@ -1704,7 +1708,6 @@ class SelectFromTest(QueryTest, AssertsCompiledSQL): "SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name FROM " "(SELECT users.id AS id, users.name AS name FROM users WHERE users.id " "IN (:id_1, :id_2)) AS anon_1 JOIN users AS users_1 ON users_1.id > anon_1.id", - use_default_dialect=True ) diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py index ccd42c10bd..db7c78cdd6 100644 --- a/test/orm/test_joins.py +++ b/test/orm/test_joins.py @@ -881,7 +881,11 @@ class JoinTest(QueryTest, AssertsCompiledSQL): sess = create_session() - assert_raises_message(sa.exc.InvalidRequestError, "Could not find a FROM clause to join from", sess.query(users).join, addresses) + self.assert_compile( + sess.query(users).join(addresses), + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users JOIN addresses ON users.id = addresses.user_id" + ) def test_orderby_arg_bug(self): @@ -1261,13 +1265,19 @@ class JoinTest(QueryTest, AssertsCompiledSQL): assert_raises_message( sa_exc.InvalidRequestError, - "Could not find a FROM", + "Could not find a FROM clause to join from. Tried joining " + "to .*?, but got: " + "Can't find any foreign key relationships " + "between 'users' and 'users'.", sess.query(users.c.id).join, User ) assert_raises_message( sa_exc.InvalidRequestError, - "Could not find a FROM", + "Could not find a FROM clause to join from. Tried joining " + "to .*?, but got: " + "Can't find any foreign key relationships " + "between 'users' and 'users'.", sess.query(users.c.id).select_from(users).join, User ) @@ -1322,6 +1332,151 @@ class JoinTest(QueryTest, AssertsCompiledSQL): use_default_dialect=True ) +class JoinFromSelectableTest(fixtures.MappedTest, AssertsCompiledSQL): + __dialect__ = 'default' + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + Table('table1', metadata, + Column('id', Integer, primary_key=True) + ) + Table('table2', metadata, + Column('id', Integer, primary_key=True), + Column('t1_id', Integer) + ) + + @classmethod + def setup_classes(cls): + table1, table2 = cls.tables.table1, cls.tables.table2 + class T1(cls.Comparable): + pass + + class T2(cls.Comparable): + pass + + mapper(T1, table1) + mapper(T2, table2) + + def test_select_mapped_to_mapped_explicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + self.assert_compile( + sess.query(subq.c.count, T1.id).select_from(subq).join(T1, subq.c.t1_id==T1.id), + "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id " + "FROM (SELECT table2.t1_id AS t1_id, " + "count(table2.id) AS count FROM table2 " + "GROUP BY table2.t1_id) AS anon_1 JOIN table1 ON anon_1.t1_id = table1.id" + ) + + def test_select_mapped_to_mapped_implicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + self.assert_compile( + sess.query(subq.c.count, T1.id).join(T1, subq.c.t1_id==T1.id), + "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id " + "FROM (SELECT table2.t1_id AS t1_id, " + "count(table2.id) AS count FROM table2 " + "GROUP BY table2.t1_id) AS anon_1 JOIN table1 ON anon_1.t1_id = table1.id" + ) + + def test_select_mapped_to_select_explicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + self.assert_compile( + sess.query(subq.c.count, T1.id).select_from(T1).join(subq, subq.c.t1_id==T1.id), + "SELECT anon_1.count AS anon_1_count, table1.id AS table1_id " + "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, " + "count(table2.id) AS count FROM table2 GROUP BY table2.t1_id) " + "AS anon_1 ON anon_1.t1_id = table1.id" + ) + + def test_select_mapped_to_select_implicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + assert_raises_message( + sa_exc.InvalidRequestError, + r"Can't construct a join from ", + sess.query(subq.c.count, T1.id).join, subq, subq.c.t1_id==T1.id, + ) + + def test_mapped_select_to_mapped_implicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + # this query is wrong, but verifying behavior stays the same + # (or improves, like an error message) + self.assert_compile( + sess.query(T1.id, subq.c.count).join(T1, subq.c.t1_id==T1.id), + "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count FROM " + "(SELECT table2.t1_id AS t1_id, count(table2.id) AS count FROM " + "table2 GROUP BY table2.t1_id) AS anon_1, table1 JOIN table1 " + "ON anon_1.t1_id = table1.id" + ) + + def test_mapped_select_to_mapped_explicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + self.assert_compile( + sess.query(T1.id, subq.c.count).select_from(subq).join(T1, subq.c.t1_id==T1.id), + "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " + "FROM (SELECT table2.t1_id AS t1_id, count(table2.id) AS count " + "FROM table2 GROUP BY table2.t1_id) AS anon_1 JOIN table1 " + "ON anon_1.t1_id = table1.id" + ) + + def test_mapped_select_to_select_explicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + self.assert_compile( + sess.query(T1.id, subq.c.count).select_from(T1).join(subq, subq.c.t1_id==T1.id), + "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " + "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, count(table2.id) AS count " + "FROM table2 GROUP BY table2.t1_id) AS anon_1 " + "ON anon_1.t1_id = table1.id" + ) + + def test_mapped_select_to_select_implicit_left(self): + T1, T2 = self.classes.T1, self.classes.T2 + + sess = Session() + subq = sess.query(T2.t1_id, func.count(T2.id).label('count')).\ + group_by(T2.t1_id).subquery() + + self.assert_compile( + sess.query(T1.id, subq.c.count).join(subq, subq.c.t1_id==T1.id), + "SELECT table1.id AS table1_id, anon_1.count AS anon_1_count " + "FROM table1 JOIN (SELECT table2.t1_id AS t1_id, count(table2.id) AS count " + "FROM table2 GROUP BY table2.t1_id) AS anon_1 " + "ON anon_1.t1_id = table1.id" + ) class MultiplePathTest(fixtures.MappedTest, AssertsCompiledSQL): @classmethod