From: Mike Bayer
Date: Tue, 31 Jan 2012 00:52:07 +0000 (-0500)
Subject: break out _save_obj(), _delete_obj(), _post_update() into a new module
X-Git-Tag: rel_0_7_6~74
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b40450d0a1389edd02366f284199ecbf7d566ff1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

break out _save_obj(), _delete_obj(), _post_update() into a new module
persistence.py - Mapper loses awareness of how to emit INSERT/UPDATE/DELETE;
persistence.py is used only by unitofwork.py.  Each method is then broken out
into a top-level function with almost no logic of its own, calling into
_organize_states_for_XYZ(), _collect_XYZ_commands(), and
_emit_XYZ_statements().
---
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index bf277664ad..42acb4928c 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1875,500 +1875,6 @@ class Mapper(object):
         self._memoized_values[key] = value = callable_()
         return value
 
-    def _post_update(self, states, uowtransaction, post_update_cols):
-        """Issue UPDATE statements on behalf of a relationship() which
-        specifies post_update.
-
-        """
-        cached_connections = util.PopulateDict(
-            lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache
-        ))
-
-        # if session has a connection callable,
-        # organize individual states with the connection
-        # to use for update
-        if uowtransaction.session.connection_callable:
-            connection_callable = \
-                    uowtransaction.session.connection_callable
-        else:
-            connection = uowtransaction.transaction.connection(self)
-            connection_callable = None
-
-        tups = []
-        for state in _sort_states(states):
-            if connection_callable:
-                conn = connection_callable(self, state.obj())
-            else:
-                conn = connection
-
-            mapper = _state_mapper(state)
-
-            tups.append((state, state.dict, mapper, conn))
-
-        table_to_mapper = self._sorted_tables
-
-        for table in table_to_mapper:
-            update = []
-
-            for state, state_dict, mapper, connection in tups:
-                if table not in mapper._pks_by_table:
-                    continue
-
-                pks = mapper._pks_by_table[table]
-                params = {}
-                hasdata = False
-
-                for col in mapper._cols_by_table[table]:
-                    if col in pks:
-                        params[col._label] = \
-                                mapper._get_state_attr_by_column(
-                                                state,
-                                                state_dict, col)
-                    elif col in post_update_cols:
-                        prop = mapper._columntoproperty[col]
-                        history = attributes.get_state_history(
-                                        state, prop.key,
-                                        attributes.PASSIVE_NO_INITIALIZE)
-                        if history.added:
-                            value = history.added[0]
-                            params[col.key] = value
-                            hasdata = True
-                if hasdata:
-                    update.append((state, state_dict, params, mapper,
-                                    connection))
-
-            if update:
-                mapper = table_to_mapper[table]
-
-                def update_stmt():
-                    clause = sql.and_()
-
-                    for col in mapper._pks_by_table[table]:
-                        clause.clauses.append(col == sql.bindparam(col._label,
-                                                        type_=col.type))
-
-                    return table.update(clause)
-
-                statement = self._memo(('post_update', table), update_stmt)
-
-                # execute each UPDATE in the order according to the original
-                # list of states to guarantee row access order, but
-                # also group them into common (connection, cols) sets
-                # to support executemany().
-                for key, grouper in groupby(
-                    update, lambda rec: (rec[4], rec[2].keys())
-                ):
-                    multiparams = [params for state, state_dict,
-                                            params, mapper, conn in grouper]
-                    cached_connections[connection].\
-                                        execute(statement, multiparams)
-
-    def _save_obj(self, states, uowtransaction, single=False):
-        """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
-        of objects.
- - This is called within the context of a UOWTransaction during a - flush operation, given a list of states to be flushed. The - base mapper in an inheritance hierarchy handles the inserts/ - updates for all descendant mappers. - - """ - - # if batch=false, call _save_obj separately for each object - if not single and not self.batch: - for state in _sort_states(states): - self._save_obj([state], - uowtransaction, - single=True) - return - - # if session has a connection callable, - # organize individual states with the connection - # to use for insert/update - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - - for state in _sort_states(states): - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - has_identity = bool(state.key) - mapper = _state_mapper(state) - instance_key = state.key or mapper._identity_key_from_state(state) - - row_switch = None - - # call before_XXX extensions - if not has_identity: - mapper.dispatch.before_insert(mapper, conn, state) - else: - mapper.dispatch.before_update(mapper, conn, state) - - # detect if we have a "pending" instance (i.e. has - # no instance_key attached to it), and another instance - # with the same identity key already exists as persistent. - # convert to an UPDATE if so. - if not has_identity and \ - instance_key in uowtransaction.session.identity_map: - instance = \ - uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise orm_exc.FlushError( - "New instance %s with identity key %s conflicts " - "with persistent instance %s" % - (state_str(state), instance_key, - state_str(existing))) - - self._log_debug( - "detected row switch for identity %s. " - "will update %s, remove %s from " - "transaction", instance_key, - state_str(state), state_str(existing)) - - # remove the "delete" flag from the existing element - uowtransaction.remove_state_actions(existing) - row_switch = existing - - tups.append( - (state, state.dict, mapper, conn, - has_identity, instance_key, row_switch) - ) - - # dictionary of connection->connection_with_cache_options. 
- cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - table_to_mapper = self._sorted_tables - - for table in table_to_mapper: - insert = [] - update = [] - - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in tups: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - - isinsert = not has_identity and not row_switch - - params = {} - value_params = {} - - if isinsert: - has_all_pks = True - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col.key] = \ - mapper.version_id_generator(None) - else: - # pull straight from the dict for - # pending objects - prop = mapper._columntoproperty[col] - value = state_dict.get(prop.key, None) - - if value is None: - if col in pks: - has_all_pks = False - elif col.default is None and \ - col.server_default is None: - params[col.key] = value - - elif isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value - - insert.append((state, state_dict, params, mapper, - connection, value_params, has_all_pks)) - else: - hasdata = hasnull = False - for col in mapper._cols_by_table[table]: - if col is mapper.version_id_col: - params[col._label] = \ - mapper._get_committed_state_attr_by_column( - row_switch or state, - row_switch and row_switch.dict - or state_dict, - col) - - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE - ) - if history.added: - params[col.key] = history.added[0] - hasdata = True - else: - params[col.key] = \ - mapper.version_id_generator( - params[col._label]) - - # HACK: check for history, in case the - # history is only - # in a different table than the one - # where the version_id_col is. 
- for prop in mapper._columntoproperty.\ - itervalues(): - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - hasdata = True - else: - prop = mapper._columntoproperty[col] - history = attributes.get_state_history( - state, prop.key, - attributes.PASSIVE_NO_INITIALIZE) - if history.added: - if isinstance(history.added[0], - sql.ClauseElement): - value_params[col] = history.added[0] - else: - value = history.added[0] - params[col.key] = value - - if col in pks: - if history.deleted and \ - not row_switch: - # if passive_updates and sync detected - # this was a pk->pk sync, use the new - # value to locate the row, since the - # DB would already have set this - if ("pk_cascaded", state, col) in \ - uowtransaction.\ - attributes: - value = history.added[0] - params[col._label] = value - else: - # use the old value to - # locate the row - value = history.deleted[0] - params[col._label] = value - hasdata = True - else: - # row switch logic can reach us here - # remove the pk from the update params - # so the update doesn't - # attempt to include the pk in the - # update statement - del params[col.key] - value = history.added[0] - params[col._label] = value - if value is None: - hasnull = True - else: - hasdata = True - elif col in pks: - value = state.manager[prop.key].\ - impl.get(state, state_dict) - if value is None: - hasnull = True - params[col._label] = value - if hasdata: - if hasnull: - raise sa_exc.FlushError( - "Can't update table " - "using NULL for primary " - "key value") - update.append((state, state_dict, params, mapper, - connection, value_params)) - - if update: - mapper = table_to_mapper[table] - - needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) - - def update_stmt(): - clause = sql.and_() - - for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, - type_=col.type)) - - if needs_version_id: - clause.clauses.append(mapper.version_id_col ==\ - sql.bindparam(mapper.version_id_col._label, - type_=col.type)) - - return table.update(clause) - - statement = self._memo(('update', table), update_stmt) - - rows = 0 - for state, state_dict, params, mapper, \ - connection, value_params in update: - - if value_params: - c = connection.execute( - statement.values(value_params), - params) - else: - c = cached_connections[connection].\ - execute(statement, params) - - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, - c.context.compiled_parameters[0], - value_params) - rows += c.rowcount - - if connection.dialect.supports_sane_rowcount: - if rows != len(update): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to update %d row(s); " - "%d were matched." % - (table.description, len(update), rows)) - - elif needs_version_id: - util.warn("Dialect %s does not support updated rowcount " - "- versioning cannot be verified." 
% - c.dialect.dialect_description, - stacklevel=12) - - if insert: - statement = self._memo(('insert', table), table.insert) - - for (connection, pkeys, hasvalue, has_all_pks), \ - records in groupby(insert, - lambda rec: (rec[4], - rec[2].keys(), - bool(rec[5]), - rec[6]) - ): - if has_all_pks and not hasvalue: - records = list(records) - multiparams = [rec[2] for rec in records] - c = cached_connections[connection].\ - execute(statement, multiparams) - - for (state, state_dict, params, mapper, - conn, value_params, has_all_pks), \ - last_inserted_params in \ - zip(records, c.context.compiled_parameters): - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - c.context.prefetch_cols, - c.context.postfetch_cols, - last_inserted_params, - value_params) - - else: - for state, state_dict, params, mapper, \ - connection, value_params, \ - has_all_pks in records: - - if value_params: - result = connection.execute( - statement.values(value_params), - params) - else: - result = cached_connections[connection].\ - execute(statement, params) - - primary_key = result.context.inserted_primary_key - - if primary_key is not None: - # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): - prop = mapper._columntoproperty[col] - if state_dict.get(prop.key) is None: - # TODO: would rather say: - #state_dict[prop.key] = pk - mapper._set_state_attr_by_column( - state, - state_dict, - col, pk) - - mapper._postfetch( - uowtransaction, - table, - state, - state_dict, - result.context.prefetch_cols, - result.context.postfetch_cols, - result.context.compiled_parameters[0], - value_params) - - - for state, state_dict, mapper, connection, has_identity, \ - instance_key, row_switch in tups: - - if mapper._readonly_props: - readonly = state.unmodified_intersection( - [p.key for p in mapper._readonly_props - if p.expire_on_flush or p.key not in state.dict] - ) - if readonly: - state.expire_attributes(state.dict, readonly) - - # if eager_defaults option is enabled, - # refresh whatever has been expired. - if self.eager_defaults and state.unloaded: - state.key = self._identity_key_from_state(state) - uowtransaction.session.query(self)._load_on_ident( - state.key, refresh_state=state, - only_load_props=state.unloaded) - - # call after_XXX extensions - if not has_identity: - mapper.dispatch.after_insert(mapper, connection, state) - else: - mapper.dispatch.after_update(mapper, connection, state) - - def _postfetch(self, uowtransaction, table, - state, dict_, prefetch_cols, postfetch_cols, - params, value_params): - """During a flush, expire attributes in need of newly - persisted database state.""" - - if self.version_id_col is not None: - prefetch_cols = list(prefetch_cols) + [self.version_id_col] - - for c in prefetch_cols: - if c.key in params and c in self._columntoproperty: - self._set_state_attr_by_column(state, dict_, c, params[c.key]) - - if postfetch_cols: - state.expire_attributes(state.dict, - [self._columntoproperty[c].key - for c in postfetch_cols if c in - self._columntoproperty] - ) - - # synchronize newly inserted ids from one table to the next - # TODO: this still goes a little too often. 
would be nice to - # have definitive list of "columns that changed" here - for m, equated_pairs in self._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - self.passive_updates) - @util.memoized_property def _table_to_equated(self): """memoized map of tables to collections of columns to be @@ -2386,128 +1892,6 @@ class Mapper(object): return result - def _delete_obj(self, states, uowtransaction): - """Issue ``DELETE`` statements for a list of objects. - - This is called within the context of a UOWTransaction during a - flush operation. - - """ - if uowtransaction.session.connection_callable: - connection_callable = \ - uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(self) - connection_callable = None - - tups = [] - cached_connections = util.PopulateDict( - lambda conn:conn.execution_options( - compiled_cache=self._compiled_cache - )) - - for state in _sort_states(states): - mapper = _state_mapper(state) - - if connection_callable: - conn = connection_callable(self, state.obj()) - else: - conn = connection - - mapper.dispatch.before_delete(mapper, conn, state) - - tups.append((state, - state.dict, - _state_mapper(state), - bool(state.key), - conn)) - - table_to_mapper = self._sorted_tables - - for table in reversed(table_to_mapper.keys()): - delete = util.defaultdict(list) - for state, state_dict, mapper, has_identity, connection in tups: - if not has_identity or table not in mapper._pks_by_table: - continue - - params = {} - delete[connection].append(params) - for col in mapper._pks_by_table[table]: - params[col.key] = \ - value = \ - mapper._get_state_attr_by_column( - state, state_dict, col) - if value is None: - raise sa_exc.FlushError( - "Can't delete from table " - "using NULL for primary " - "key value") - - if mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col): - params[mapper.version_id_col.key] = \ - mapper._get_committed_state_attr_by_column( - state, state_dict, - mapper.version_id_col) - - mapper = table_to_mapper[table] - need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) - - def delete_stmt(): - clause = sql.and_() - for col in mapper._pks_by_table[table]: - clause.clauses.append( - col == sql.bindparam(col.key, type_=col.type)) - - if need_version_id: - clause.clauses.append( - mapper.version_id_col == - sql.bindparam( - mapper.version_id_col.key, - type_=mapper.version_id_col.type - ) - ) - - return table.delete(clause) - - for connection, del_objects in delete.iteritems(): - statement = self._memo(('delete', table), delete_stmt) - rows = -1 - - connection = cached_connections[connection] - - if need_version_id and \ - not connection.dialect.supports_sane_multi_rowcount: - # TODO: need test coverage for this [ticket:1761] - if connection.dialect.supports_sane_rowcount: - rows = 0 - # execute deletes individually so that versioned - # rows can be verified - for params in del_objects: - c = connection.execute(statement, params) - rows += c.rowcount - else: - util.warn( - "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." 
% - connection.dialect.dialect_description, - stacklevel=12) - connection.execute(statement, del_objects) - else: - c = connection.execute(statement, del_objects) - if connection.dialect.supports_sane_multi_rowcount: - rows = c.rowcount - - if rows != -1 and rows != len(del_objects): - raise orm_exc.StaleDataError( - "DELETE statement on table '%s' expected to delete %d row(s); " - "%d were matched." % - (table.description, len(del_objects), c.rowcount) - ) - - for state, state_dict, mapper, has_identity, connection in tups: - mapper.dispatch.after_delete(mapper, connection, state) def _instance_processor(self, context, path, reduced_path, adapter, polymorphic_from=None, @@ -2517,6 +1901,12 @@ class Mapper(object): """Produce a mapper level row processor callable which processes rows into mapped instances.""" + # note that this method, most of which exists in a closure + # called _instance(), resists being broken out, as + # attempts to do so tend to add significant function + # call overhead. _instance() is the most + # performance-critical section in the whole ORM. + pk_cols = self.primary_key if polymorphic_from or refresh_state: @@ -2960,13 +2350,6 @@ def _event_on_resurrect(state): state, state.dict, col, val) -def _sort_states(states): - pending = set(states) - persistent = set(s for s in pending if s.key is not None) - pending.difference_update(persistent) - return sorted(pending, key=operator.attrgetter("insert_order")) + \ - sorted(persistent, key=lambda q:q.key[1]) - class _ColumnMapping(util.py25_dict): """Error reporting helper for mapper._columntoproperty.""" diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py new file mode 100644 index 0000000000..746395919a --- /dev/null +++ b/lib/sqlalchemy/orm/persistence.py @@ -0,0 +1,780 @@ +# orm/persistence.py +# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" + +import operator +from itertools import groupby + +from sqlalchemy import sql, util, exc as sa_exc +from sqlalchemy.orm import attributes, sync, \ + exc as orm_exc + +from sqlalchemy.orm.util import _state_mapper, state_str + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. 
+ + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_insert, states_to_update = _organize_states_for_save( + base_mapper, + states, + uowtransaction) + + cached_connections = _cached_connection_dict(base_mapper) + + for table, mapper in base_mapper._sorted_tables.iteritems(): + insert = _collect_insert_commands(base_mapper, uowtransaction, + table, states_to_insert) + + update = _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update) + + if update: + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + + if insert: + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + table, insert) + + _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update) + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. + + """ + cached_connections = _cached_connection_dict(base_mapper) + + states_to_update = _organize_states_for_post_update( + base_mapper, + states, uowtransaction) + + + for table, mapper in base_mapper._sorted_tables.iteritems(): + update = _collect_post_update_commands(base_mapper, uowtransaction, + table, states_to_update, + post_update_cols) + + if update: + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + + cached_connections = _cached_connection_dict(base_mapper) + + states_to_delete = _organize_states_for_delete( + base_mapper, + states, + uowtransaction) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(table_to_mapper.keys()): + delete = _collect_delete_commands(base_mapper, uowtransaction, + table, states_to_delete) + + mapper = table_to_mapper[table] + + _emit_delete_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, delete) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + states_to_insert = [] + states_to_update = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + has_identity = bool(state.key) + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. 
+ if not has_identity and \ + instance_key in uowtransaction.session.identity_map: + instance = \ + uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise orm_exc.FlushError( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" % + (state_str(state), instance_key, + state_str(existing))) + + base_mapper._log_debug( + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", instance_key, + state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if not has_identity and not row_switch: + states_to_insert.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + else: + states_to_update.append( + (state, dict_, mapper, connection, + has_identity, instance_key, row_switch) + ) + + return states_to_insert, states_to_update + +def _organize_states_for_post_update(base_mapper, states, + uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return list(_connections_for_states(base_mapper, uowtransaction, + states)) + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. + + """ + states_to_delete = [] + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, + states): + + mapper.dispatch.before_delete(mapper, connection, state) + + states_to_delete.append((state, dict_, mapper, + bool(state.key), connection)) + return states_to_delete + +def _collect_insert_commands(base_mapper, uowtransaction, table, + states_to_insert): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + insert = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + has_all_pks = True + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col.key] = mapper.version_id_generator(None) + else: + # pull straight from the dict for + # pending objects + prop = mapper._columntoproperty[col] + value = state_dict.get(prop.key, None) + + if value is None: + if col in pks: + has_all_pks = False + elif col.default is None and \ + col.server_default is None: + params[col.key] = value + + elif isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value + + insert.append((state, state_dict, params, mapper, + connection, value_params, has_all_pks)) + return insert + +def _collect_update_commands(base_mapper, uowtransaction, + table, states_to_update): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. 
Includes some tricky scenarios where the primary + key of an object might have been changed. + + """ + + update = [] + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_update: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + params = {} + value_params = {} + + hasdata = hasnull = False + for col in mapper._cols_by_table[table]: + if col is mapper.version_id_col: + params[col._label] = \ + mapper._get_committed_state_attr_by_column( + row_switch or state, + row_switch and row_switch.dict + or state_dict, + col) + + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + params[col.key] = history.added[0] + hasdata = True + else: + params[col.key] = mapper.version_id_generator( + params[col._label]) + + # HACK: check for history, in case the + # history is only + # in a different table than the one + # where the version_id_col is. + for prop in mapper._columntoproperty.itervalues(): + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + hasdata = True + else: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + if isinstance(history.added[0], + sql.ClauseElement): + value_params[col] = history.added[0] + else: + value = history.added[0] + params[col.key] = value + + if col in pks: + if history.deleted and \ + not row_switch: + # if passive_updates and sync detected + # this was a pk->pk sync, use the new + # value to locate the row, since the + # DB would already have set this + if ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + value = history.added[0] + params[col._label] = value + else: + # use the old value to + # locate the row + value = history.deleted[0] + params[col._label] = value + hasdata = True + else: + # row switch logic can reach us here + # remove the pk from the update params + # so the update doesn't + # attempt to include the pk in the + # update statement + del params[col.key] + value = history.added[0] + params[col._label] = value + if value is None: + hasnull = True + else: + hasdata = True + elif col in pks: + value = state.manager[prop.key].impl.get( + state, state_dict) + if value is None: + hasnull = True + params[col._label] = value + if hasdata: + if hasnull: + raise sa_exc.FlushError( + "Can't update table " + "using NULL for primary " + "key value") + update.append((state, state_dict, params, mapper, + connection, value_params)) + return update + + +def _collect_post_update_commands(base_mapper, uowtransaction, table, + states_to_update, post_update_cols): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. 
+ + """ + + update = [] + for state, state_dict, mapper, connection in states_to_update: + if table not in mapper._pks_by_table: + continue + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = \ + mapper._get_state_attr_by_column( + state, + state_dict, col) + elif col in post_update_cols: + prop = mapper._columntoproperty[col] + history = attributes.get_state_history( + state, prop.key, + attributes.PASSIVE_NO_INITIALIZE) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + update.append((state, state_dict, params, mapper, + connection)) + return update + +def _collect_delete_commands(base_mapper, uowtransaction, table, + states_to_delete): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + delete = util.defaultdict(list) + + for state, state_dict, mapper, has_identity, connection \ + in states_to_delete: + if not has_identity or table not in mapper._pks_by_table: + continue + + params = {} + delete[connection].append(params) + for col in mapper._pks_by_table[table]: + params[col.key] = \ + value = \ + mapper._get_state_attr_by_column( + state, state_dict, col) + if value is None: + raise sa_exc.FlushError( + "Can't delete from table " + "using NULL for primary " + "key value") + + if mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col): + params[mapper.version_id_col.key] = \ + mapper._get_committed_state_attr_by_column( + state, state_dict, + mapper.version_id_col) + return delete + + +def _emit_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + if needs_version_id: + clause.clauses.append(mapper.version_id_col ==\ + sql.bindparam(mapper.version_id_col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('update', table), update_stmt) + + rows = 0 + for state, state_dict, params, mapper, \ + connection, value_params in update: + + if value_params: + c = connection.execute( + statement.values(value_params), + params) + else: + c = cached_connections[connection].\ + execute(statement, params) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + c.context.compiled_parameters[0], + value_params) + rows += c.rowcount + + if connection.dialect.supports_sane_rowcount: + if rows != len(update): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(update), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." 
% + c.dialect.dialect_description, + stacklevel=12) + +def _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, table, insert): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + statement = base_mapper._memo(('insert', table), table.insert) + + for (connection, pkeys, hasvalue, has_all_pks), \ + records in groupby(insert, + lambda rec: (rec[4], + rec[2].keys(), + bool(rec[5]), + rec[6]) + ): + if has_all_pks and not hasvalue: + records = list(records) + multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ + execute(statement, multiparams) + + for (state, state_dict, params, mapper, + conn, value_params, has_all_pks), \ + last_inserted_params in \ + zip(records, c.context.compiled_parameters): + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c.context.prefetch_cols, + c.context.postfetch_cols, + last_inserted_params, + value_params) + + else: + for state, state_dict, params, mapper, \ + connection, value_params, \ + has_all_pks in records: + + if value_params: + result = connection.execute( + statement.values(value_params), + params) + else: + result = cached_connections[connection].\ + execute(statement, params) + + primary_key = result.context.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for pk, col in zip(primary_key, + mapper._pks_by_table[table]): + prop = mapper._columntoproperty[col] + if state_dict.get(prop.key) is None: + # TODO: would rather say: + #state_dict[prop.key] = pk + mapper._set_state_attr_by_column( + state, + state_dict, + col, pk) + + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + result.context.prefetch_cols, + result.context.postfetch_cols, + result.context.compiled_parameters[0], + value_params) + + + +def _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, mapper, table, update): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + def update_stmt(): + clause = sql.and_() + + for col in mapper._pks_by_table[table]: + clause.clauses.append(col == sql.bindparam(col._label, + type_=col.type)) + + return table.update(clause) + + statement = base_mapper._memo(('post_update', table), update_stmt) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). 
+ for key, grouper in groupby( + update, lambda rec: (rec[4], rec[2].keys()) + ): + multiparams = [params for state, state_dict, + params, mapper, conn in grouper] + cached_connections[conn].\ + execute(statement, multiparams) + + +def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, + mapper, table, delete): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = mapper.version_id_col is not None and \ + table.c.contains_column(mapper.version_id_col) + + def delete_stmt(): + clause = sql.and_() + for col in mapper._pks_by_table[table]: + clause.clauses.append( + col == sql.bindparam(col.key, type_=col.type)) + + if need_version_id: + clause.clauses.append( + mapper.version_id_col == + sql.bindparam( + mapper.version_id_col.key, + type_=mapper.version_id_col.type + ) + ) + + return table.delete(clause) + + for connection, del_objects in delete.iteritems(): + statement = base_mapper._memo(('delete', table), delete_stmt) + rows = -1 + + connection = cached_connections[connection] + + if need_version_id and \ + not connection.dialect.supports_sane_multi_rowcount: + # TODO: need test coverage for this [ticket:1761] + if connection.dialect.supports_sane_rowcount: + rows = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute(statement, params) + rows += c.rowcount + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." % + connection.dialect.dialect_description, + stacklevel=12) + connection.execute(statement, del_objects) + else: + c = connection.execute(statement, del_objects) + if connection.dialect.supports_sane_multi_rowcount: + rows = c.rowcount + + if rows != -1 and rows != len(del_objects): + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched." % + (table.description, len(del_objects), c.rowcount) + ) + +def _finalize_insert_update_commands(base_mapper, uowtransaction, + states_to_insert, states_to_update): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity, \ + instance_key, row_switch in states_to_insert + \ + states_to_update: + + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [p.key for p in mapper._readonly_props + if p.expire_on_flush or p.key not in state.dict] + ) + if readonly: + state.expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, + # refresh whatever has been expired. 
+        if base_mapper.eager_defaults and state.unloaded:
+            state.key = base_mapper._identity_key_from_state(state)
+            uowtransaction.session.query(base_mapper)._load_on_ident(
+                state.key, refresh_state=state,
+                only_load_props=state.unloaded)
+
+        # call after_XXX extensions
+        if not has_identity:
+            mapper.dispatch.after_insert(mapper, connection, state)
+        else:
+            mapper.dispatch.after_update(mapper, connection, state)
+
+def _postfetch(mapper, uowtransaction, table,
+                state, dict_, prefetch_cols, postfetch_cols,
+                params, value_params):
+    """Expire attributes in need of newly persisted database state,
+    after an INSERT or UPDATE statement has proceeded for that
+    state."""
+
+    if mapper.version_id_col is not None:
+        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+    for c in prefetch_cols:
+        if c.key in params and c in mapper._columntoproperty:
+            mapper._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+    if postfetch_cols:
+        state.expire_attributes(state.dict,
+                            [mapper._columntoproperty[c].key
+                            for c in postfetch_cols if c in
+                            mapper._columntoproperty]
+                        )
+
+    # synchronize newly inserted ids from one table to the next
+    # TODO: this still goes a little too often.  would be nice to
+    # have definitive list of "columns that changed" here
+    for m, equated_pairs in mapper._table_to_equated[table]:
+        sync.populate(state, m, state, m,
+                                equated_pairs,
+                                uowtransaction,
+                                mapper.passive_updates)
+
+def _connections_for_states(base_mapper, uowtransaction, states):
+    """Return an iterator of (state, state.dict, mapper, connection).
+
+    The states are sorted according to _sort_states, then paired
+    with the connection they should be using for the given
+    unit of work transaction.
+
+    """
+    # if session has a connection callable,
+    # organize individual states with the connection
+    # to use for update
+    if uowtransaction.session.connection_callable:
+        connection_callable = \
+                uowtransaction.session.connection_callable
+    else:
+        connection = uowtransaction.transaction.connection(
+                                base_mapper)
+        connection_callable = None
+
+    for state in _sort_states(states):
+        if connection_callable:
+            connection = connection_callable(base_mapper, state.obj())
+
+        mapper = _state_mapper(state)
+
+        yield state, state.dict, mapper, connection
+
+def _cached_connection_dict(base_mapper):
+    # dictionary of connection->connection_with_cache_options.
+    return util.PopulateDict(
+        lambda conn:conn.execution_options(
+        compiled_cache=base_mapper._compiled_cache
+    ))
+
+def _sort_states(states):
+    pending = set(states)
+    persistent = set(s for s in pending if s.key is not None)
+    pending.difference_update(persistent)
+    return sorted(pending, key=operator.attrgetter("insert_order")) + \
+        sorted(persistent, key=lambda q:q.key[1])
+
+
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index b016e81a0e..a20e871e4f 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -6,6 +6,7 @@
 
 """private module containing functions used for copying data
 between instances based on join conditions.
+
 """
 
 from sqlalchemy.orm import exc, util as mapperutil, attributes
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 3cd0f15eb1..8fc5f139d3 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -14,7 +14,7 @@ organizes them in order of dependency, and executes.
 from sqlalchemy import util, event
 from sqlalchemy.util import topological
-from sqlalchemy.orm import attributes, interfaces
+from sqlalchemy.orm import attributes, interfaces, persistence
 from sqlalchemy.orm import util as mapperutil
 
 session = util.importlater("sqlalchemy.orm", "session")
 
@@ -462,7 +462,7 @@ class IssuePostUpdate(PostSortRec):
         states, cols = uow.post_update_states[self.mapper]
         states = [s for s in states if uow.states[s][0] == self.isdelete]
 
-        self.mapper._post_update(states, uow, cols)
+        persistence.post_update(self.mapper, states, uow, cols)
 
 class SaveUpdateAll(PostSortRec):
     def __init__(self, uow, mapper):
@@ -470,7 +470,7 @@ class SaveUpdateAll(PostSortRec):
         assert mapper is mapper.base_mapper
 
     def execute(self, uow):
-        self.mapper._save_obj(
+        persistence.save_obj(self.mapper,
             uow.states_for_mapper_hierarchy(self.mapper, False, False),
             uow
         )
@@ -493,7 +493,7 @@ class DeleteAll(PostSortRec):
         assert mapper is mapper.base_mapper
 
     def execute(self, uow):
-        self.mapper._delete_obj(
+        persistence.delete_obj(self.mapper,
             uow.states_for_mapper_hierarchy(self.mapper, True, False),
             uow
         )
@@ -551,7 +551,7 @@ class SaveUpdateState(PostSortRec):
                 if r.__class__ is cls_ and r.mapper is mapper]
         recs.difference_update(our_recs)
 
-        mapper._save_obj(
+        persistence.save_obj(mapper,
             [self.state] + [r.state for r in our_recs], uow)
 
@@ -575,7 +575,7 @@ class DeleteState(PostSortRec):
                 r.mapper is mapper]
         recs.difference_update(our_recs)
         states = [self.state] + [r.state for r in our_recs]
-        mapper._delete_obj(
+        persistence.delete_obj(mapper,
             [s for s in states if uow.states[s][0]], uow)
 
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index 469bb626a5..e2ae823220 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import mapper, relationship, backref, \
     validates, aliased, defer, deferred, synonym, attributes, \
     column_property, composite, dynamic_loader, \
     comparable_property, Session
-from sqlalchemy.orm.mapper import _sort_states
+from sqlalchemy.orm.persistence import _sort_states
 from test.lib.testing import eq_, AssertsCompiledSQL
 from test.lib import fixtures
 from test.orm import _fixtures
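
The shape of the refactoring above is worth stating once in miniature: each
public entry point (save_obj, post_update, delete_obj) keeps almost no logic
of its own and delegates to an organize/collect/emit pipeline. The sketch
below is a standalone illustration of that structure only; every name in it
is an invented stand-in, not SQLAlchemy's actual API.

    def save_obj(states, connection):
        # top-level entry point: no real logic, only phase delegation
        to_insert, to_update = _organize_states_for_save(states)
        insert = _collect_insert_commands(to_insert)
        update = _collect_update_commands(to_update)
        _emit_statements(connection, insert, update)

    def _organize_states_for_save(states):
        # split states on whether they already have a primary key
        to_insert = [s for s in states if s.get("pk") is None]
        to_update = [s for s in states if s.get("pk") is not None]
        return to_insert, to_update

    def _collect_insert_commands(states):
        # decide the parameter sets, one per INSERT
        return [{"name": s["name"]} for s in states]

    def _collect_update_commands(states):
        # pair the changed values with the key used to locate the row
        return [({"name": s["name"]}, s["pk"]) for s in states]

    def _emit_statements(connection, insert, update):
        # the only phase that touches the "database"
        for params in insert:
            connection.append(("INSERT", params))
        for params, pk in update:
            connection.append(("UPDATE", params, pk))

    conn = []
    save_obj([{"pk": None, "name": "a"}, {"pk": 5, "name": "b"}], conn)
    print(conn)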
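
_sort_states() gives the flush a deterministic row-access order: pending
states (no identity key yet) sort by their insert order, persistent states
sort by primary key, and pending rows come first. A standalone sketch of the
same ordering, with plain dicts standing in for instance states:

    import operator

    def sort_states(states):
        # pending first, by insertion order; then persistent, by key
        pending = [s for s in states if s["key"] is None]
        persistent = [s for s in states if s["key"] is not None]
        return (sorted(pending, key=operator.itemgetter("insert_order")) +
                sorted(persistent, key=lambda s: s["key"]))

    states = [
        {"key": (2,), "insert_order": 0},
        {"key": None, "insert_order": 2},
        {"key": (1,), "insert_order": 1},
        {"key": None, "insert_order": 0},
    ]
    print(sort_states(states))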
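
_connections_for_states() is the single place where each state is paired
with the connection that will persist it. Ordinarily every state shares the
transaction's one connection, but a Session configured with a
connection_callable (as in horizontal sharding setups) routes each instance
to its own bind. A simplified stand-in, again with dicts in place of
instance states:

    def connections_for_states(states, default_conn, conn_callable=None):
        # yield (state, connection) pairs, routing per state if a
        # callable is configured
        for state in states:
            if conn_callable is not None:
                yield state, conn_callable(state)
            else:
                yield state, default_conn

    shard = {"eu": "conn-eu", "us": "conn-us"}
    states = [{"id": 1, "region": "eu"}, {"id": 2, "region": "us"}]
    for state, conn in connections_for_states(
            states, "conn-default", lambda s: shard[s["region"]]):
        print(state["id"], "->", conn)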
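
_cached_connection_dict() memoizes, per connection, a copy of that
connection carrying compiled_cache execution options, so each statement is
compiled once per table rather than once per row. The populate-on-miss dict
below is similar in spirit to sqlalchemy.util.PopulateDict; the creator
shown is a hypothetical stand-in:

    class PopulateDict(dict):
        """dict that fills missing keys by calling a creator function."""
        def __init__(self, creator):
            self.creator = creator

        def __missing__(self, key):
            self[key] = value = self.creator(key)
            return value

    # hypothetical creator: tag each connection with an options dict
    cached = PopulateDict(lambda conn: (conn, {"compiled_cache": {}}))
    print(cached["connection-A"])   # created on first access
    print(cached["connection-A"])   # returned from the cache afterwards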
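
The _emit_*() functions get their executemany() batching from
itertools.groupby: consecutive records sharing a connection and the same set
of parameter keys collapse into one multi-parameter execute. Since groupby
only merges adjacent records, the deterministic ordering from _sort_states()
is what makes the batches large. An illustrative run:

    from itertools import groupby

    records = [
        ("conn1", {"id": 1, "name": "a"}),
        ("conn1", {"id": 2, "name": "b"}),
        ("conn1", {"id": 3}),             # different key set: new batch
        ("conn2", {"id": 4, "name": "d"}),
    ]

    for (conn, keys), group in groupby(
            records, lambda rec: (rec[0], tuple(sorted(rec[1])))):
        multiparams = [params for _, params in group]
        print("execute on %s, columns %s: %d row(s)" %
              (conn, keys, len(multiparams)))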
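
The "row switch" detection in _organize_states_for_save() reduces to a small
decision table: a pending instance whose identity key collides with an
existing persistent instance is a flush error, unless that persistent
instance is already marked for deletion, in which case the pending DELETE is
cancelled and the INSERT is converted to an UPDATE of the existing row.
Paraphrased as standalone logic, not the real implementation:

    def classify(has_identity, key_collides, existing_is_deleted):
        if has_identity:
            return "UPDATE"
        if key_collides and not existing_is_deleted:
            raise ValueError(
                "new instance conflicts with persistent instance")
        if key_collides and existing_is_deleted:
            return "UPDATE (row switch: cancel the pending DELETE)"
        return "INSERT"

    print(classify(False, False, False))  # plain pending -> INSERT
    print(classify(False, True, True))    # row switch -> UPDATE
    print(classify(True, False, False))   # persistent -> UPDATE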
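
Finally, the versioning check in _emit_update_statements() and
_emit_delete_statements(): when a version id column participates in the
WHERE clause, the matched rowcount must equal the number of rows the flush
intended to touch, otherwise StaleDataError signals that a concurrent
transaction won the race. On dialects without reliable rowcounts the check
degrades to a warning. In miniature, under those assumptions:

    class StaleDataError(Exception):
        pass

    def verify_rowcount(table, expected, matched, sane_rowcount=True):
        # mirror the expected-vs-matched comparison made after each
        # versioned UPDATE/DELETE
        if not sane_rowcount:
            print("warning: dialect rowcount unreliable - "
                  "versioning cannot be verified")
            return
        if matched != expected:
            raise StaleDataError(
                "UPDATE statement on table '%s' expected to update %d "
                "row(s); %d were matched." % (table, expected, matched))

    verify_rowcount("users", 2, 2)       # ok
    try:
        verify_rowcount("users", 2, 1)   # concurrent modification
    except StaleDataError as err:
        print(err)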