From 5cdb942791b9aeb63d02680c712d1afc104606b0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 9 Dec 2007 23:27:04 +0000 Subject: [PATCH] - Query.select_from() now replaces all existing FROM criterion with the given argument; the previous behavior of constructing a list of FROM clauses was generally not useful as is required filter() calls to create join criterion, and new tables introduced within filter() already add themselves to the FROM clause. The new behavior allows not just joins from the main table, but select statements as well. Filter criterion, order bys, eager load clauses will be "aliased" against the given statement. --- CHANGES | 9 ++ lib/sqlalchemy/orm/query.py | 153 +++++++++++++++++++++---------- lib/sqlalchemy/orm/strategies.py | 19 +--- lib/sqlalchemy/sql/compiler.py | 2 +- lib/sqlalchemy/sql/expression.py | 16 +++- test/orm/query.py | 130 +++++++++++++++++++++++++- test/sql/select.py | 15 +++ 7 files changed, 272 insertions(+), 72 deletions(-) diff --git a/CHANGES b/CHANGES index aa0c33c4b0..e0df814694 100644 --- a/CHANGES +++ b/CHANGES @@ -66,6 +66,15 @@ CHANGES database's ON UPDATE CASCADE (required for DB's like Postgres) or issued directly by the ORM in the form of UPDATE statements, by setting the flag "passive_cascades=False". + + - Query.select_from() now replaces all existing FROM criterion with + the given argument; the previous behavior of constructing a list + of FROM clauses was generally not useful as is required + filter() calls to create join criterion, and new tables introduced + within filter() already add themselves to the FROM clause. The + new behavior allows not just joins from the main table, but select + statements as well. Filter criterion, order bys, eager load + clauses will be "aliased" against the given statement. - added "cascade delete" behavior to "dynamic" relations just like that of regular relations. if passive_deletes flag (also just added) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 89733d5e53..dbc62a47b9 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -43,7 +43,7 @@ class Query(object): self._joinpoint = self.mapper self._aliases = None self._alias_ids = {} - self._from_obj = [self.table] + self._from_obj = self.table self._populate_existing = False self._version_check = False self._autoflush = True @@ -54,13 +54,13 @@ class Query(object): self._only_load_props = None self._refresh_instance = None - def _no_criterion(self): + def _no_criterion(self, meth): q = self._clone() - if q._criterion or q._statement or q._from_obj != [self.table]: - warnings.warn(RuntimeWarning("Query.get() being called on a Query with existing criterion; criterion is being ignored.")) + if q._criterion or q._statement or q._from_obj is not self.table: + warnings.warn(RuntimeWarning("Query.%s() being called on a Query with existing criterion; criterion is being ignored." % meth)) - q._from_obj = [self.table] + q._from_obj = self.table q._alias_ids = {} q._joinpoint = self.mapper q._statement = q._aliases = q._criterion = None @@ -322,6 +322,8 @@ class Query(object): if self._aliases is not None: criterion = self._aliases.adapt_clause(criterion) + elif self._from_obj is not self.table: + criterion = sql_util.ClauseAdapter(self._from_obj).traverse(criterion) q = self._clone() if q._criterion is not None: @@ -338,18 +340,22 @@ class Query(object): return self.filter(sql.and_(*clauses)) + def _get_joinable_tables(self): + currenttables = [self._from_obj] + def visit_join(join): + currenttables.append(join.left) + currenttables.append(join.right) + visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False}) + return currenttables + def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): if start is None: start = self._joinpoint - clause = self._from_obj[-1] + clause = self._from_obj - currenttables = [clause] - class FindJoinedTables(visitors.NoColumnVisitor): - def visit_join(self, join): - currenttables.append(join.left) - currenttables.append(join.right) - FindJoinedTables().traverse(clause) + currenttables = self._get_joinable_tables() + adapt_criterion = self.table not in currenttables mapper = start alias = self._aliases @@ -366,9 +372,15 @@ class Query(object): prop.get_join(mapper, primary=False, secondary=True), alias ) - clause = clause.join(alias.secondary, alias.primaryjoin, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin) + crit = alias.primaryjoin + if adapt_criterion: + crit = sql_util.ClauseAdapter(clause).traverse(crit) + clause = clause.join(alias.secondary, crit, isouter=outerjoin).join(alias.alias, alias.secondaryjoin, isouter=outerjoin) else: - clause = clause.join(prop.secondary, prop.get_join(mapper, primary=True, secondary=False), isouter=outerjoin) + crit = prop.get_join(mapper, primary=True, secondary=False) + if adapt_criterion: + crit = sql_util.ClauseAdapter(clause).traverse(crit) + clause = clause.join(prop.secondary, crit, isouter=outerjoin) clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False), isouter=outerjoin) else: if create_aliases: @@ -377,9 +389,15 @@ class Query(object): None, alias ) - clause = clause.join(alias.alias, alias.primaryjoin, isouter=outerjoin) + crit = alias.primaryjoin + if adapt_criterion: + crit = sql_util.ClauseAdapter(clause).traverse(crit) + clause = clause.join(alias.alias, crit, isouter=outerjoin) else: - clause = clause.join(prop.select_table, prop.get_join(mapper), isouter=outerjoin) + crit = prop.get_join(mapper) + if adapt_criterion: + crit = sql_util.ClauseAdapter(clause).traverse(crit) + clause = clause.join(prop.select_table, crit, isouter=outerjoin) elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables: # TODO: this check is not strong enough for different paths to the same endpoint which # does not use secondary tables @@ -526,7 +544,7 @@ class Query(object): def _join(self, prop, id, outerjoin, aliased, from_joinpoint): (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased) q = self._clone() - q._from_obj = [clause] + q._from_obj = clause q._joinpoint = mapper q._aliases = aliases @@ -558,13 +576,23 @@ class Query(object): def select_from(self, from_obj): """Set the `from_obj` parameter of the query and return the newly - resulting ``Query``. - - `from_obj` is a list of one or more tables. + resulting ``Query``. This replaces the table which this Query selects + from with the given table. + + + `from_obj` is a single table or selectable. """ - new = self._clone() - new._from_obj = list(new._from_obj) + util.to_list(from_obj) + new = self._no_criterion('select_from') + if isinstance(from_obj, (tuple, list)): + util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.") + from_obj = from_obj[-1] + + if isinstance(from_obj, expression._SelectBaseMixin): + # alias SELECTs and unions + from_obj = from_obj.alias() + + new._from_obj = from_obj return new def __getitem__(self, item): @@ -638,7 +666,7 @@ class Query(object): if isinstance(statement, basestring): statement = sql.text(statement) - q = self._clone() + q = self._no_criterion('from_statement') q._statement = statement return q @@ -785,7 +813,7 @@ class Query(object): q = self if ident is not None: - q = q._no_criterion() + q = q._no_criterion('get') params = {} (_get_clause, _get_params) = self.select_mapper._get_clause q = q.filter(_get_clause) @@ -862,28 +890,33 @@ class Query(object): whereclause = self._criterion + from_obj = self._from_obj + currenttables = self._get_joinable_tables() + adapt_criterion = self.table not in currenttables + if whereclause is not None and (self.mapper is not self.select_mapper): # adapt the given WHERECLAUSE to adjust instances of this query's mapped # table to be that of our select_table, # which may be the "polymorphic" selectable used by our mapper. - whereclause = sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table])) + whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj])) # if extra entities, adapt the criterion to those as well for m in self._entities: if isinstance(m, type): m = mapper.class_mapper(m) if isinstance(m, mapper.Mapper): - table = m.select_table sql_util.ClauseAdapter(m.select_table).traverse(whereclause, stop_on=util.Set([m.select_table])) - from_obj = self._from_obj order_by = self._order_by if order_by is False: order_by = self.mapper.order_by if order_by is False: + order_by = [] if self.table.default_order_by() is not None: order_by = self.table.default_order_by() + if from_obj.default_order_by() is not None: + order_by = from_obj.default_order_by() try: for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] @@ -895,7 +928,7 @@ class Query(object): if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()])) - context.from_clauses = from_obj + context.from_clause = from_obj # give all the attached properties a chance to modify the query # TODO: doing this off the select_mapper. if its the polymorphic mapper, then @@ -916,7 +949,7 @@ class Query(object): if clauses is not None: m = clauses.aliased_column(m) context.secondary_columns.append(m) - + if self._eager_loaders and self._nestable(**self._select_args()): # eager loaders are present, and the SELECT has limiting criterion # produce a "wrapped" selectable. @@ -926,16 +959,19 @@ class Query(object): # locate all embedded Column clauses so they can be added to the # "inner" select statement where they'll be available to the enclosing # statement's "order by" - + cf = util.Set() - if order_by: order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] for o in order_by: cf.update(sql_util.find_columns(o)) - - s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args()) + if adapt_criterion: + context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns] + cf = [from_obj.corresponding_column(c, raiseerr=False) or c for c in cf] + + s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args()) + s3 = s2.alias() self._primary_adapter = mapperutil.create_row_adapter(s3, self.table) @@ -943,30 +979,44 @@ class Query(object): statement = sql.select([s3] + context.secondary_columns, for_update=for_update, use_labels=True) if context.eager_joins: - statement.append_from(sql_util.ClauseAdapter(s3).traverse(context.eager_joins), _copy_collection=False) - + eager_joins = sql_util.ClauseAdapter(s3).traverse(context.eager_joins) + statement.append_from(eager_joins, _copy_collection=False) + if order_by: - statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(util.to_list(order_by))) - + statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) + statement.append_order_by(*context.eager_order_by) else: + if adapt_criterion: + context.primary_columns = [from_obj.corresponding_column(c, raiseerr=False) or c for c in context.primary_columns] + self._primary_adapter = mapperutil.create_row_adapter(from_obj, self.table) + + if adapt_criterion or self._distinct: + if order_by: + order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] + + if adapt_criterion: + order_by = sql_util.ClauseAdapter(from_obj).copy_and_process(order_by) + + if self._distinct and order_by: + cf = util.Set() + for o in order_by: + cf.update(sql_util.find_columns(o)) + for c in cf: + context.primary_columns.append(c) + statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args()) + if context.eager_joins: + if adapt_criterion: + context.eager_joins = sql_util.ClauseAdapter(from_obj).traverse(context.eager_joins) statement.append_from(context.eager_joins, _copy_collection=False) if context.eager_order_by: + if adapt_criterion: + context.eager_order_by = sql_util.ClauseAdapter(from_obj).copy_and_process(context.eager_order_by) statement.append_order_by(*context.eager_order_by) - # for a DISTINCT query, you need the columns explicitly specified in order - # to use it in "order_by". ensure they are in the column criterion (particularly oid). - if self._distinct and order_by: - order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] - cf = util.Set() - for o in order_by: - cf.update(sql_util.find_columns(o)) - - [statement.append_column(c) for c in cf] - context.statement = statement return context @@ -1047,6 +1097,13 @@ class Query(object): if params is not None: q = q.params(params) return list(q) + + def _legacy_select_from(self, from_obj): + q = self._clone() + if len(from_obj) > 1: + raise exceptions.ArgumentError("Multiple-entry from_obj parameter no longer supported") + q._from_obj = from_obj[0] + return q def _legacy_select_kwargs(self, **kwargs): #pragma: no cover q = self @@ -1055,7 +1112,7 @@ class Query(object): if "group_by" in kwargs: q = q.group_by(kwargs['group_by']) if "from_obj" in kwargs: - q = q.select_from(kwargs['from_obj']) + q = q._legacy_select_from(kwargs['from_obj']) if "lockmode" in kwargs: q = q.with_lockmode(kwargs['lockmode']) if "distinct" in kwargs: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 9adf17f42a..d46271f9e2 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -484,24 +484,11 @@ class EagerLoader(AbstractRelationLoader): if context.eager_joins: towrap = context.eager_joins - elif isinstance(localparent.mapped_table, sql.Join): - towrap = localparent.mapped_table else: - # look for the mapper's selectable expressed within the current "from" criterion. - # this will locate the selectable inside of any containers it may be a part of (such - # as a join). if its inside of a join, we want to outer join on that join, not the - # selectable. - # TODO: slightly hacky way to get at all the froms - for fromclause in sql.select(from_obj=context.from_clauses).froms: - if fromclause is localparent.mapped_table: - towrap = fromclause - break - elif isinstance(fromclause, sql.Join): - if localparent.mapped_table in sql_util.find_tables(fromclause, include_aliases=True): - towrap = fromclause - break + if isinstance(context.from_clause, sql.Join): + towrap = context.from_clause else: - raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join onto, for mapped table %s" % (localparent.mapped_table)) + towrap = localparent.mapped_table # create AliasedClauses object to build up the eager query. this is cached after 1st creation. try: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3af8f97cab..30b4089d31 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -36,7 +36,7 @@ ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]') BIND_PARAMS = re.compile(r'(?