From 53deb98918263cd0a89d1a0aeb73f7010d8907bf Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 15 Mar 2009 03:02:42 +0000 Subject: [PATCH] - Query.join() can now construct multiple FROM clauses, if needed. Such as, query(A, B).join(A.x).join(B.y) might say SELECT A.*, B.* FROM A JOIN X, B JOIN Y. Eager loading can also tack its joins onto those multiple FROM clauses. [ticket:1337] --- CHANGES | 6 ++++ lib/sqlalchemy/orm/query.py | 53 +++++++++++++++++--------------- lib/sqlalchemy/orm/strategies.py | 22 ++++++++----- lib/sqlalchemy/sql/util.py | 32 ++++++++++++++----- test/orm/eager_relations.py | 44 +++++++++++++++++++++++++- test/orm/inheritance/query.py | 19 ++++++++++++ 6 files changed, 136 insertions(+), 40 deletions(-) diff --git a/CHANGES b/CHANGES index 35af1e41b1..503031d411 100644 --- a/CHANGES +++ b/CHANGES @@ -19,6 +19,12 @@ CHANGES union(query1, query2), select([foo]).select_from(query), etc. + - Query.join() can now construct multiple FROM clauses, if + needed. Such as, query(A, B).join(A.x).join(B.y) + might say SELECT A.*, B.* FROM A JOIN X, B JOIN Y. + Eager loading can also tack its joins onto those + multiple FROM clauses. [ticket:1337] + - Fixed bug where column_prefix wasn't being checked before not mapping an attribute that already had class-level name present. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a8be43f5a1..9dba9edad4 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -83,7 +83,7 @@ class Query(object): self._current_path = () self._only_load_props = None self._refresh_state = None - self._from_obj = None + self._from_obj = () self._polymorphic_adapters = {} self._filter_aliases = None self._from_obj_alias = None @@ -135,11 +135,11 @@ class Query(object): if isinstance(from_obj, expression._SelectBaseMixin): from_obj = from_obj.alias() - self._from_obj = from_obj + self._from_obj = (from_obj,) equivs = self.__all_equivs() if isinstance(from_obj, expression.Alias): - self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs) + self._from_obj_alias = sql_util.ColumnAdapter(from_obj, equivs) def _get_polymorphic_adapter(self, entity, selectable): self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) @@ -258,12 +258,6 @@ class Query(object): self._entities = [entity] + self._entities[1:] return entity - def __mapper_zero_from_obj(self): - if self._from_obj: - return self._from_obj - else: - return self._entity_zero().selectable - def __all_equivs(self): equivs = {} for ent in self._mapper_entities: @@ -276,7 +270,8 @@ class Query(object): self._group_by: raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth) - self._statement = self._criterion = self._from_obj = None + self._from_obj = () + self._statement = self._criterion = None self._order_by = self._group_by = self._distinct = False self.__joined_tables = {} @@ -899,10 +894,8 @@ class Query(object): if not from_joinpoint: self.__reset_joinpoint() - # join from our from_obj. This is - # None unless select_from()/from_self() has been called. - clause = self._from_obj - + clause = replace_clause_index = None + # after the method completes, # the query's joinpoint will be set to this. right_entity = None @@ -963,10 +956,13 @@ class Query(object): elif not left_entity: left_entity = self._joinpoint_zero() - # if no initial left-hand clause is set, extract - # this from the left_entity or as a last - # resort from the onclause argument, if it's - # a PropComparator. + if not clause and self._from_obj: + mp, left_selectable, is_aliased_class = _entity_info(left_entity) + + replace_clause_index, clause = sql_util.find_join_source(self._from_obj, left_selectable) + if not clause: + clause = left_selectable + if not clause: for ent in self._entities: if ent.corresponds_to(left_entity): @@ -1067,7 +1063,7 @@ class Query(object): # construct the onclause. join_to_left = not is_aliased_class or \ onclause is prop or \ - clause is self._from_obj and self._from_obj_alias + self._from_obj_alias and clause is self._from_obj[0] # create the join clause = orm_join(clause, right_entity, onclause, isouter=outerjoin, join_to_left=join_to_left) @@ -1086,9 +1082,12 @@ class Query(object): ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns) ) - # loop finished. we're selecting from - # our final clause now - self._from_obj = clause + if replace_clause_index is not None: + l = list(self._from_obj) + l[replace_clause_index] = clause + self._from_obj = tuple(l) + else: + self._from_obj = self._from_obj + (clause,) # future joins with from_joinpoint=True join from our established right_entity. self._joinpoint = right_entity @@ -1115,7 +1114,13 @@ class Query(object): `from_obj` is a single table or selectable. """ + if isinstance(from_obj, (tuple, list)): + # from_obj is actually a list again as of 0.5.3. so this restriction here + # is somewhat artificial, but is still in place since select_from() implies aliasing all further + # criterion against what's placed here, and its less complex to only + # keep track of a single aliased FROM element being selected against. This could in theory be opened + # up again to more complexity. util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.") from_obj = from_obj[-1] if not isinstance(from_obj, expression.FromClause): @@ -1474,7 +1479,7 @@ class Query(object): entity.setup_context(self, context) if context.from_clause: - from_obj = [context.from_clause] + from_obj = list(context.from_clause) else: from_obj = context.froms @@ -1722,7 +1727,7 @@ class Query(object): eager_joins = context.eager_joins.values() if context.from_clause: - froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used + froms = list(context.from_clause) # "load from explicit FROMs" mode, i.e. when select_from() or join() is used else: froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index f46ffc44d8..dc64b283d6 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -660,18 +660,26 @@ class EagerLoader(AbstractRelationLoader): if entity in context.eager_joins: entity_key, default_towrap = entity, entity.selectable - elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable): - # if no from_clause, or a from_clause we can't join to, or a subquery is going to be generated, + + elif should_nest_selectable or not context.from_clause: + # if no from_clause, or a subquery is going to be generated, # store eager joins per _MappedEntity; Query._compile_context will # add them as separate selectables to the select(), or splice them together # after the subquery is generated entity_key, default_towrap = entity, entity.selectable else: - # otherwise, create a single eager join from the from clause. - # Query._compile_context will adapt as needed and append to the - # FROM clause of the select(). - entity_key, default_towrap = None, context.from_clause - + index, clause = sql_util.find_join_source(context.from_clause, entity.selectable) + if clause: + # join to an existing FROM clause on the query. + # key it to its list index in the eager_joins dict. + # Query._compile_context will adapt as needed and append to the + # FROM clause of the select(). + entity_key, default_towrap = index, clause + else: + # if no from_clause to join to, + # store eager joins per _MappedEntity + entity_key, default_towrap = entity, entity.selectable + towrap = context.eager_joins.setdefault(entity_key, default_towrap) # create AliasedClauses object to build up the eager query. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index b8ceabb741..a8de5c6352 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -21,15 +21,31 @@ def sort_tables(tables): visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) return topological.sort(tuples, tables) -def search(clause, target): - if not clause: - return False - for elem in visitors.iterate(clause, {'column_collections':False}): - if elem is target: - return True +def find_join_source(clauses, join_to): + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. returns + None, None if no match is found. + + e.g.:: + + clause1 = table1.join(table2) + clause2 = table4.join(table5) + + join_to = table2.join(table3) + + find_join_source([clause1, clause2], join_to) == clause1 + + """ + + selectables = list(expression._from_objects(join_to)) + for i, f in enumerate(clauses): + for s in selectables: + if f.is_derived_from(s): + return i, f else: - return False - + return None, None + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False): """locate Table objects within the given expression.""" diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index be1d0a955d..1876eca258 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -4,7 +4,7 @@ import testenv; testenv.configure_for_tests() from testlib import sa, testing from sqlalchemy.orm import eagerload, deferred, undefer from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func -from testlib.sa.orm import mapper, relation, create_session, lazyload +from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased from testlib.testing import eq_ from testlib.assertsql import CompiledSQL from orm import _base, _fixtures @@ -1219,6 +1219,48 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) self.assert_sql_count(testing.db, go, 1) + @testing.exclude('sqlite', '>', (0, 0, 0), "sqlite flat out blows it on the multiple JOINs") + @testing.resolve_artifact_names + def test_two_entities_with_joins(self): + sess = create_session() + + # two FROM clauses where there's a join on each one + def go(): + u1 = aliased(User) + o1 = aliased(Order) + eq_( + [ + ( + User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), + Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]), + User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), + Order(description=u'order 3', isopen=1, items=[Item(description=u'item 3'), Item(description=u'item 4'), Item(description=u'item 5')]) + ), + + ( + User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), + Order(description=u'order 2', isopen=0, items=[Item(description=u'item 1'), Item(description=u'item 2'), Item(description=u'item 3')]), + User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), + Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')]) + ), + + ( + User(addresses=[Address(email_address=u'fred@fred.com')], name=u'fred'), + Order(description=u'order 4', isopen=1, items=[Item(description=u'item 1'), Item(description=u'item 5')]), + User(addresses=[Address(email_address=u'jack@bean.com')], name=u'jack'), + Order(address_id=None, description=u'order 5', isopen=0, items=[Item(description=u'item 5')]) + ), + ], + sess.query(User, Order, u1, o1).\ + join((Order, User.orders)).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).\ + join((o1, u1.orders)).options(eagerload(u1.addresses), eagerload(o1.items)).filter(u1.id==7).\ + filter(Order.id