From: Mike Bayer Date: Sat, 23 Jun 2012 16:39:46 +0000 (-0400) Subject: - simplify setup_entity and related calls X-Git-Tag: rel_0_8_0b1~358 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=df62f4501ee1ec37113477eb6a97068cc07faf5d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - simplify setup_entity and related calls - break _compile_context() into three methods --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 987a77ba96..156dd6128f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -129,7 +129,8 @@ class Query(object): for entity in ent.entities: if entity not in d: ext_info = _extended_entity_info(entity) - if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic: + if not ext_info.is_aliased_class and \ + ext_info.mapper.with_polymorphic: if ext_info.mapper.mapped_table not in \ self._polymorphic_adapters: self._mapper_loads_polymorphically_with(ext_info.mapper, @@ -144,10 +145,11 @@ class Query(object): else: aliased_adapter = None - d[entity] = (ext_info.mapper, aliased_adapter, ext_info.selectable, - ext_info.is_aliased_class, ext_info.with_polymorphic_mappers, - ext_info.with_polymorphic_discriminator) - ent.setup_entity(entity, *d[entity]) + d[entity] = ( + ext_info, + aliased_adapter + ) + ent.setup_entity(*d[entity]) def _mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: @@ -2536,7 +2538,8 @@ class Query(object): to count, to skip the usage of a subquery or otherwise control of the FROM clause, or to use other aggregate functions, - use :attr:`~sqlalchemy.sql.expression.func` expressions in conjunction + use :attr:`~sqlalchemy.sql.expression.func` + expressions in conjunction with :meth:`~.Session.query`, i.e.:: from sqlalchemy import func @@ -2605,7 +2608,8 @@ class Query(object): """ #TODO: cascades need handling. - delete_op = persistence.BulkDelete.factory(self, synchronize_session) + delete_op = persistence.BulkDelete.factory( + self, synchronize_session) delete_op.exec_() return delete_op.rowcount @@ -2661,30 +2665,34 @@ class Query(object): # fk assignments #TODO: cascades need handling. - update_op = persistence.BulkUpdate.factory(self, synchronize_session, values) + update_op = persistence.BulkUpdate.factory( + self, synchronize_session, values) update_op.exec_() return update_op.rowcount + _lockmode_lookup = { + 'read': 'read', + 'read_nowait': 'read_nowait', + 'update': True, + 'update_nowait': 'nowait', + None: False + } + def _compile_context(self, labels=True): context = QueryContext(self) if context.statement is not None: return context + context.labels = labels + if self._lockmode: try: - for_update = {'read': 'read', - 'read_nowait': 'read_nowait', - 'update': True, - 'update_nowait': 'nowait', - None: False}[self._lockmode] + context.for_update = self._lockmode_lookup[self._lockmode] except KeyError: raise sa_exc.ArgumentError( - "Unknown lockmode %r" % self._lockmode) - else: - for_update = False - + "Unknown lockmode %r" % self._lockmode) for entity in self._entities: entity.setup_context(self, context) @@ -2697,11 +2705,11 @@ class Query(object): if context.from_clause: # "load from explicit FROMs" mode, # i.e. when select_from() or join() is used - froms = list(context.from_clause) + context.froms = list(context.from_clause) else: # "load from discrete FROMs" mode, # i.e. when each _MappedEntity has its own FROM - froms = context.froms + context.froms = context.froms if self._enable_single_crit: self._adjust_for_single_inheritance(context) @@ -2718,128 +2726,133 @@ class Query(object): "SELECT from.") if context.multi_row_eager_loaders and self._should_nest_selectable: - # for eager joins present and LIMIT/OFFSET/DISTINCT, - # wrap the query inside a select, - # then append eager joins onto that - - if context.order_by: - order_by_col_expr = list( - chain(*[ - sql_util.unwrap_order_by(o) - for o in context.order_by - ]) - ) - else: - context.order_by = None - order_by_col_expr = [] + context.statement = self._compound_eager_statement(context) + else: + context.statement = self._simple_statement(context) + return context - inner = sql.select( - context.primary_columns + order_by_col_expr, - context.whereclause, - from_obj=froms, - use_labels=labels, - # TODO: this order_by is only needed if - # LIMIT/OFFSET is present in self._select_args, - # else the application on the outside is enough - order_by=context.order_by, - **self._select_args - ) + def _compound_eager_statement(self, context): + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if context.order_by: + order_by_col_expr = list( + chain(*[ + sql_util.unwrap_order_by(o) + for o in context.order_by + ]) + ) + else: + context.order_by = None + order_by_col_expr = [] + + inner = sql.select( + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=context.froms, + use_labels=context.labels, + # TODO: this order_by is only needed if + # LIMIT/OFFSET is present in self._select_args, + # else the application on the outside is enough + order_by=context.order_by, + **self._select_args + ) - for hint in self._with_hints: - inner = inner.with_hint(*hint) + for hint in self._with_hints: + inner = inner.with_hint(*hint) - if self._correlate: - inner = inner.correlate(*self._correlate) + if self._correlate: + inner = inner.correlate(*self._correlate) - inner = inner.alias() + inner = inner.alias() - equivs = self.__all_equivs() + equivs = self.__all_equivs() - context.adapter = sql_util.ColumnAdapter(inner, equivs) + context.adapter = sql_util.ColumnAdapter(inner, equivs) - statement = sql.select( - [inner] + context.secondary_columns, - for_update=for_update, - use_labels=labels) + statement = sql.select( + [inner] + context.secondary_columns, + for_update=context.for_update, + use_labels=context.labels) - from_clause = inner - for eager_join in eager_joins: - # EagerLoader places a 'stop_on' attribute on the join, - # giving us a marker as to where the "splice point" of - # the join should be - from_clause = sql_util.splice_joins( - from_clause, - eager_join, eager_join.stop_on) + from_clause = inner + for eager_join in context.eager_joins.values(): + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of + # the join should be + from_clause = sql_util.splice_joins( + from_clause, + eager_join, eager_join.stop_on) - statement.append_from(from_clause) + statement.append_from(from_clause) - if context.order_by: - statement.append_order_by( - *context.adapter.copy_and_process( - context.order_by - ) + if context.order_by: + statement.append_order_by( + *context.adapter.copy_and_process( + context.order_by ) + ) - statement.append_order_by(*context.eager_order_by) - else: - if not context.order_by: - context.order_by = None - - if self._distinct and context.order_by: - order_by_col_expr = list( - chain(*[ - sql_util.unwrap_order_by(o) - for o in context.order_by - ]) - ) - context.primary_columns += order_by_col_expr - - froms += tuple(context.eager_joins.values()) - - statement = sql.select( - context.primary_columns + - context.secondary_columns, - context.whereclause, - from_obj=froms, - use_labels=labels, - for_update=for_update, - order_by=context.order_by, - **self._select_args - ) + statement.append_order_by(*context.eager_order_by) + return statement + + def _simple_statement(self, context): + if not context.order_by: + context.order_by = None + + if self._distinct and context.order_by: + order_by_col_expr = list( + chain(*[ + sql_util.unwrap_order_by(o) + for o in context.order_by + ]) + ) + context.primary_columns += order_by_col_expr + + context.froms += tuple(context.eager_joins.values()) - for hint in self._with_hints: - statement = statement.with_hint(*hint) + statement = sql.select( + context.primary_columns + + context.secondary_columns, + context.whereclause, + from_obj=context.froms, + use_labels=context.labels, + for_update=context.for_update, + order_by=context.order_by, + **self._select_args + ) - if self._correlate: - statement = statement.correlate(*self._correlate) + for hint in self._with_hints: + statement = statement.with_hint(*hint) - if context.eager_order_by: - statement.append_order_by(*context.eager_order_by) + if self._correlate: + statement = statement.correlate(*self._correlate) - context.statement = statement + if context.eager_order_by: + statement.append_order_by(*context.eager_order_by) + return statement - return context def _adjust_for_single_inheritance(self, context): """Apply single-table-inheritance filtering. - - For all distinct single-table-inheritance mappers represented in the - columns clause of this query, add criterion to the WHERE clause of the - given QueryContext such that only the appropriate subtypes are - selected from the total results. - + + For all distinct single-table-inheritance mappers represented in + the columns clause of this query, add criterion to the WHERE + clause of the given QueryContext such that only the appropriate + subtypes are selected from the total results. + """ - for entity, (mapper, adapter, s, i, w, d) in \ - self._mapper_adapter_map.iteritems(): - if entity in self._join_entities: + for (ext_info, adapter) in self._mapper_adapter_map.values(): + if ext_info.entity in self._join_entities: continue - single_crit = mapper._single_table_criterion + single_crit = ext_info.mapper._single_table_criterion if single_crit is not None: if adapter: single_crit = adapter.traverse(single_crit) single_crit = self._adapt_clause(single_crit, False, False) - context.whereclause = sql.and_( - context.whereclause, single_crit) + context.whereclause = sql.and_(context.whereclause, + single_crit) def __str__(self): return str(self._compile_context().statement) @@ -2873,21 +2886,19 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.expr = entity - def setup_entity(self, entity, mapper, aliased_adapter, - from_obj, is_aliased_class, - with_polymorphic, - with_polymorphic_discriminator): - self.mapper = mapper + def setup_entity(self, ext_info, aliased_adapter): + self.mapper = ext_info.mapper self.aliased_adapter = aliased_adapter - self.selectable = from_obj - self.is_aliased_class = is_aliased_class - self._with_polymorphic = with_polymorphic - self._polymorphic_discriminator = with_polymorphic_discriminator - if is_aliased_class: - self.entity_zero = entity + self.selectable = ext_info.selectable + self.is_aliased_class = ext_info.is_aliased_class + self._with_polymorphic = ext_info.with_polymorphic_mappers + self._polymorphic_discriminator = \ + ext_info.with_polymorphic_discriminator + if ext_info.is_aliased_class: + self.entity_zero = ext_info.entity self._label_name = self.entity_zero._sa_label_name else: - self.entity_zero = mapper + self.entity_zero = self.mapper self._label_name = self.mapper.class_.__name__ self.path = self.entity_zero._sa_path_registry @@ -3132,12 +3143,10 @@ class _ColumnEntity(_QueryEntity): c.entity_zero = self.entity_zero c.entities = self.entities - def setup_entity(self, entity, mapper, adapter, from_obj, - is_aliased_class, with_polymorphic, - with_polymorphic_discriminator): + def setup_entity(self, ext_info, aliased_adapter): if 'selectable' not in self.__dict__: - self.selectable = from_obj - self.froms.add(from_obj) + self.selectable = ext_info.selectable + self.froms.add(ext_info.selectable) def corresponds_to(self, entity): if self.entity_zero is None: @@ -3178,6 +3187,7 @@ class QueryContext(object): multi_row_eager_loaders = False adapter = None froms = () + for_update = False def __init__(self, query):