From: Mike Bayer Date: Sun, 4 Apr 2010 01:42:41 +0000 (-0400) Subject: cleanup and callcount reduction in mapper._save_obj, _delete_obj. X-Git-Tag: rel_0_6_0~64^2~1 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=dfab13e9ae147a7112a23db5abc4e4b7addc00fa;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git cleanup and callcount reduction in mapper._save_obj, _delete_obj. includes an untested fix for [ticket:1761] --- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8f0f2128bd..f83275b360 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1224,11 +1224,13 @@ class Mapper(object): try: if item_type == 'property': prop = iterator.next() - visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None)) + visitables.append((prop.cascade_iterator(type_, parent_state, + visited_instances, halt_on), 'mapper', None)) elif item_type == 'mapper': instance, instance_mapper, corresponding_state = iterator.next() yield (instance, instance_mapper) - visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state)) + visitables.append((instance_mapper._props.itervalues(), + 'property', corresponding_state)) except StopIteration: visitables.pop() @@ -1263,55 +1265,46 @@ class Mapper(object): # if batch=false, call _save_obj separately for each object if not single and not self.batch: for state in _sort_states(states): - self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) + self._save_obj([state], + uowtransaction, + postupdate=postupdate, + post_update_cols=post_update_cols, + single=True) return - + # if session has a connection callable, - # organize individual states with the connection to use for insert/update - tups = [] + # organize individual states with the connection + # to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: - connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - for state in _sort_states(states): - m = _state_mapper(state) - tups.append( - ( - state, - m, - connection_callable(self, state.obj()), - _state_has_identity(state), - state.key or m._identity_key_from_state(state) - ) - ) + connection_callable = \ + uowtransaction.mapper_flush_opts['connection_callable'] else: connection = uowtransaction.transaction.connection(self) - for state in _sort_states(states): - m = _state_mapper(state) - tups.append( - ( - state, - m, - connection, - _state_has_identity(state), - state.key or m._identity_key_from_state(state) - ) - ) + connection_callable = None - if not postupdate: - # call before_XXX extensions - for state, mapper, connection, has_identity, instance_key in tups: + tups = [] + for state in _sort_states(states): + conn = connection_callable and \ + connection_callable(self, state.obj()) or \ + connection + + has_identity = _state_has_identity(state) + mapper = _state_mapper(state) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + if not postupdate: + # call before_XXX extensions if not has_identity: if 'before_insert' in mapper.extension: - mapper.extension.before_insert(mapper, connection, state.obj()) + mapper.extension.before_insert(mapper, conn, state.obj()) else: if 'before_update' in mapper.extension: - mapper.extension.before_update(mapper, connection, state.obj()) + mapper.extension.before_update(mapper, conn, state.obj()) - row_switches = {} - if not postupdate: - for state, 
mapper, connection, has_identity, instance_key in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. + # and another instance with the same identity key already exists as persistent. + # convert to an UPDATE if so. if not has_identity and instance_key in uowtransaction.session.identity_map: instance = uowtransaction.session.identity_map[instance_key] existing = attributes.instance_state(instance) @@ -1320,28 +1313,42 @@ class Mapper(object): "New instance %s with identity key %s conflicts " "with persistent instance %s" % (state_str(state), instance_key, state_str(existing))) - + self._log_debug( - "detected row switch for identity %s. will update %s, remove %s from " - "transaction", instance_key, state_str(state), state_str(existing)) - + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + # remove the "delete" flag from the existing element uowtransaction.set_row_switch(existing) - row_switches[state] = existing - - table_to_mapper = self._sorted_tables + row_switch = existing + + tups.append( + (state, + mapper, + conn, + has_identity, + instance_key, + row_switch) + ) - for table in table_to_mapper.iterkeys(): + table_to_mapper = self._sorted_tables + + for table in table_to_mapper: insert = [] update = [] - for state, mapper, connection, has_identity, instance_key in tups: + for state, mapper, connection, has_identity, \ + instance_key, row_switch in tups: if table not in mapper._pks_by_table: continue pks = mapper._pks_by_table[table] - isinsert = not has_identity and not postupdate and state not in row_switches + isinsert = not has_identity and \ + not postupdate and \ + not row_switch params = {} value_params = {} @@ -1371,23 +1378,36 @@ class Mapper(object): value_params[col] = value else: params[col.key] = value - insert.append((state, params, mapper, connection, value_params)) + insert.append((state, params, mapper, + connection, value_params)) else: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: - params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col) - params[col.key] = mapper.version_id_generator(params[col._label]) + params[col._label] = \ + mapper._get_state_attr_by_column( + row_switch or state, + col) + params[col.key] = \ + mapper.version_id_generator(params[col._label]) + + # HACK: check for history, in case the history is only + # in a different table than the one where the version_id_col + # is. 
for prop in mapper._columntoproperty.itervalues(): - history = attributes.get_state_history(state, prop.key, passive=True) + history = attributes.get_state_history( + state, prop.key, passive=True) if history.added: hasdata = True elif mapper.polymorphic_on is not None and \ - mapper.polymorphic_on.shares_lineage(col) and col not in pks: + mapper.polymorphic_on.shares_lineage(col) and \ + col not in pks: pass else: - if post_update_cols is not None and col not in post_update_cols: + if post_update_cols is not None and \ + col not in post_update_cols: if col in pks: - params[col._label] = mapper._get_state_attr_by_column(state, col) + params[col._label] = \ + mapper._get_state_attr_by_column(state, col) continue prop = mapper._columntoproperty[col] @@ -1424,27 +1444,32 @@ class Mapper(object): elif col in pks: params[col._label] = mapper._get_state_attr_by_column(state, col) if hasdata: - update.append((state, params, mapper, connection, value_params)) + update.append((state, params, mapper, + connection, value_params)) if update: mapper = table_to_mapper[table] clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) + clause.clauses.append( + col == + sql.bindparam(col._label, type_=col.type) + ) - if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): - + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + if needs_version_id: clause.clauses.append(mapper.version_id_col ==\ sql.bindparam(mapper.version_id_col._label, type_=col.type)) statement = table.update(clause) - + rows = 0 for state, params, mapper, connection, value_params in update: c = connection.execute(statement.values(value_params), params) - mapper._postfetch(uowtransaction, connection, table, + mapper._postfetch(uowtransaction, table, state, c, c.last_updated_params(), value_params) rows += c.rowcount @@ -1452,13 +1477,15 @@ class Mapper(object): if connection.dialect.supports_sane_rowcount: if rows != len(update): raise orm_exc.ConcurrentModificationError( - "Updated rowcount %d does not match number of objects updated %d" % + "Updated rowcount %d does not match number " + "of objects updated %d" % (rows, len(update))) - - elif mapper.version_id_col is not None: + + elif needs_version_id: util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." % c.dialect.dialect_description, - stacklevel=12) + "- versioning cannot be verified." % + c.dialect.dialect_description, + stacklevel=12) if insert: statement = table.insert() @@ -1473,12 +1500,12 @@ class Mapper(object): len(primary_key) > i: mapper._set_state_attr_by_column(state, col, primary_key[i]) - mapper._postfetch(uowtransaction, connection, table, + mapper._postfetch(uowtransaction, table, state, c, c.last_inserted_params(), value_params) - if not postupdate: - for state, mapper, connection, has_identity, instance_key in tups: + for state, mapper, connection, has_identity, \ + instance_key, row_switch in tups: # expire readonly attributes readonly = state.unmodified.intersection( @@ -1488,8 +1515,8 @@ class Mapper(object): if readonly: _expire_state(state, state.dict, readonly) - # if specified, eagerly refresh whatever has - # been expired. + # if eager_defaults option is enabled, + # refresh whatever has been expired. 
if self.eager_defaults and state.unloaded: state.key = self._identity_key_from_state(state) uowtransaction.session.query(self)._get( @@ -1504,7 +1531,7 @@ class Mapper(object): if 'after_update' in mapper.extension: mapper.extension.after_update(mapper, connection, state.obj()) - def _postfetch(self, uowtransaction, connection, table, + def _postfetch(self, uowtransaction, table, state, resultproxy, params, value_params): """Expire attributes in need of newly persisted database state.""" @@ -1523,23 +1550,37 @@ class Mapper(object): if c.key in params and c in self._columntoproperty: self._set_state_attr_by_column(state, c, params[c.key]) - deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] - - if deferred_props: - _expire_state(state, state.dict, deferred_props) + if postfetch_cols: + _expire_state(state, state.dict, + [self._columntoproperty[c].key + for c in postfetch_cols] + ) # synchronize newly inserted ids from one table to the next # TODO: this still goes a little too often. would be nice to # have definitive list of "columns that changed" here - cols = set(table.c) - for m in self.iterate_to_root(): - if m._inherits_equated_pairs and \ - cols.intersection([l for l, r in m._inherits_equated_pairs]): - sync.populate(state, m, state, m, - m._inherits_equated_pairs, - uowtransaction, - self.passive_updates) - + for m, equated_pairs in self._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + self.passive_updates) + + @util.memoized_property + def _table_to_equated(self): + """memoized map of tables to collections of columns to be + synchronized upwards to the base mapper.""" + + result = util.defaultdict(list) + + for table in self._sorted_tables: + cols = set(table.c) + for m in self.iterate_to_root(): + if m._inherits_equated_pairs and \ + cols.intersection([l for l, r in m._inherits_equated_pairs]): + result[table].append((m, m._inherits_equated_pairs)) + + return result + def _delete_obj(self, states, uowtransaction): """Issue ``DELETE`` statements for a list of objects. 
@@ -1548,50 +1589,95 @@ class Mapper(object): """ if 'connection_callable' in uowtransaction.mapper_flush_opts: - connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)] + connection_callable = \ + uowtransaction.mapper_flush_opts['connection_callable'] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)] - - for state, mapper, connection in tups: + connection_callable = None + + tups = [] + for state in _sort_states(states): + mapper = _state_mapper(state) + + conn = connection_callable and \ + connection_callable(self, state.obj()) or \ + connection + if 'before_delete' in mapper.extension: - mapper.extension.before_delete(mapper, connection, state.obj()) + mapper.extension.before_delete(mapper, conn, state.obj()) + + tups.append((state, + _state_mapper(state), + _state_has_identity(state), + conn)) table_to_mapper = self._sorted_tables for table in reversed(table_to_mapper.keys()): - delete = {} - for state, mapper, connection in tups: - if table not in mapper._pks_by_table: + delete = util.defaultdict(list) + for state, mapper, has_identity, connection in tups: + if not has_identity or table not in mapper._pks_by_table: continue params = {} - if not _state_has_identity(state): - continue - else: - delete.setdefault(connection, []).append(params) + delete[connection].append(params) for col in mapper._pks_by_table[table]: params[col.key] = mapper._get_state_attr_by_column(state, col) - if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_state_attr_by_column(state, mapper.version_id_col) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) - if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + if need_version_id: clause.clauses.append( mapper.version_id_col == - sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + statement = table.delete(clause) - c = connection.execute(statement, del_objects) - if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): - raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match " - "number of objects deleted %d" % (c.rowcount, len(del_objects))) + rows = -1 + + if need_version_id and \ + not connection.dialect.supports_sane_multi_rowcount: + # TODO: need test coverage for this [ticket:1761] + if connection.dialect.supports_sane_rowcount: + rows = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows += c.rowcount + else: + util.warn("Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." 
% + c.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + c = connection.execute(statement, del_objects) + if connection.dialect.supports_sane_multi_rowcount: + rows = c.rowcount + + if rows != -1 and rows != len(del_objects): + raise orm_exc.ConcurrentModificationError( + "Deleted rowcount %d does not match " + "number of objects deleted %d" % + (c.rowcount, len(del_objects)) + ) - for state, mapper, connection in tups: + for state, mapper, has_identity, connection in tups: if 'after_delete' in mapper.extension: mapper.extension.after_delete(mapper, connection, state.obj())
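
Note (not part of the patch): the [ticket:1761] change above is easier to read outside of diff context. The sketch below restates the new rowcount-verification logic in _delete_obj as a standalone function; the name execute_verified_deletes is illustrative only, and connection, statement and del_objects stand in for the objects _delete_obj already has in hand at that point. It assumes the 0.6-era APIs visible in the diff (Connection.execute, dialect.supports_sane_rowcount / supports_sane_multi_rowcount, orm.exc.ConcurrentModificationError).

    from sqlalchemy.orm.exc import ConcurrentModificationError

    def execute_verified_deletes(connection, statement, del_objects,
                                 need_version_id):
        # -1 means "rowcount could not be verified"
        rows = -1

        if need_version_id and \
                not connection.dialect.supports_sane_multi_rowcount:
            if connection.dialect.supports_sane_rowcount:
                # execute deletes individually so that versioned
                # rows can be verified one at a time
                rows = 0
                for params in del_objects:
                    c = connection.execute(statement, params)
                    rows += c.rowcount
            else:
                # no reliable rowcount at all; the patch warns here and
                # issues the bulk delete unverified
                connection.execute(statement, del_objects)
        else:
            c = connection.execute(statement, del_objects)
            if connection.dialect.supports_sane_multi_rowcount:
                rows = c.rowcount

        if rows != -1 and rows != len(del_objects):
            raise ConcurrentModificationError(
                "Deleted rowcount %d does not match "
                "number of objects deleted %d" % (rows, len(del_objects)))

The -1 sentinel keeps the ConcurrentModificationError check in one place: it only fires on a path that was actually able to produce a trustworthy rowcount, which is the behavior the commit message flags as untested.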
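Note (not part of the patch): the "callcount reduction" half of the commit comes from the new _table_to_equated memoized property, which replaces the per-row iterate_to_root() walk in _postfetch() with a map built once per mapper. The sketch below shows only the caching mechanism; memoized_property mirrors the behavior of SQLAlchemy's util.memoized_property, and FakeMapper and its data are placeholders, not real mapper internals.

    class memoized_property(object):
        """Evaluate the decorated method once, then cache it on the instance."""

        def __init__(self, fget):
            self.fget = fget
            self.__name__ = fget.__name__

        def __get__(self, obj, cls):
            if obj is None:
                return self
            # store the computed value under the same name; because this is a
            # non-data descriptor, later lookups hit obj.__dict__ directly
            obj.__dict__[self.__name__] = result = self.fget(obj)
            return result

    class FakeMapper(object):
        def __init__(self, pairs_by_table):
            # stand-in for the real computation over _sorted_tables /
            # _inherits_equated_pairs
            self._pairs_by_table = pairs_by_table

        @memoized_property
        def _table_to_equated(self):
            print "building table -> equated pairs map"   # runs only once
            return dict(self._pairs_by_table)

    m = FakeMapper({'person': [('base_mapper', [('person.id', 'engineer.id')])]})
    m._table_to_equated    # computes and caches the map
    m._table_to_equated    # second access comes straight from m.__dict__

Once the property has been read, every subsequent _postfetch() call pays only a dictionary lookup per table, which is where the reduced call counts come from.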