From: Mike Bayer Date: Wed, 20 Aug 2014 21:15:20 +0000 (-0400) Subject: - that's it, feature is finished, needs tests X-Git-Tag: rel_1_0_0b1~206 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=db70b6e79e263c137f4d282c9c600417636afa25;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - that's it, feature is finished, needs tests --- diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index c2750eeb32..aa10da9f42 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -15,7 +15,7 @@ in unitofwork.py. """ import operator -from itertools import groupby +from itertools import groupby, chain from .. import sql, util, exc as sa_exc, schema from . import attributes, sync, exc as orm_exc, evaluator from .base import state_str, _attr_as_key @@ -86,17 +86,16 @@ def _bulk_update(mapper, mappings, session_transaction, isstates): connection = session_transaction.connection(base_mapper) - value_params = {} - for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue - records = ( - (None, None, params, super_mapper, connection, value_params) - for - params in _collect_bulk_update_commands(mapper, table, mappings) - ) + records = _collect_update_commands(None, table, ( + (None, mapping, mapper, connection, + (mapping[mapper._version_id_prop.key] + if mapper._version_id_prop else None)) + for mapping in mappings + ), bulk=True) _emit_update_statements(base_mapper, None, cached_connections, @@ -158,17 +157,16 @@ def save_obj( _finalize_insert_update_commands( base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, False) - for state, state_dict, mapper, connection in states_to_insert - ) - ) - _finalize_insert_update_commands( - base_mapper, uowtransaction, - ( - (state, state_dict, mapper, connection, True) - for state, state_dict, mapper, connection, - update_version_id in states_to_update + chain( + ( + (state, state_dict, mapper, connection, False) + for state, state_dict, mapper, connection in states_to_insert + ), + ( + (state, state_dict, mapper, connection, True) + for state, state_dict, mapper, connection, + update_version_id in states_to_update + ) ) ) @@ -394,7 +392,9 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): has_all_defaults) -def _collect_update_commands(uowtransaction, table, states_to_update): +def _collect_update_commands( + uowtransaction, table, states_to_update, + bulk=False): """Identify sets of values to use in UPDATE statements for a list of states. @@ -414,23 +414,32 @@ def _collect_update_commands(uowtransaction, table, states_to_update): pks = mapper._pks_by_table[table] - params = {} value_params = {} propkey_to_col = mapper._propkey_to_col[table] - for propkey in set(propkey_to_col).intersection(state.committed_state): - value = state_dict[propkey] - col = propkey_to_col[propkey] - - if not state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey]): - if isinstance(value, sql.ClauseElement): - value_params[col] = value - else: - params[col.key] = value + if bulk: + params = dict( + (propkey_to_col[propkey].key, state_dict[propkey]) + for propkey in + set(propkey_to_col).intersection(state_dict) + ) + else: + params = {} + for propkey in set(propkey_to_col).intersection( + state.committed_state): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if not state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey]): + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value - if update_version_id is not None: + if update_version_id is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: col = mapper.version_id_col params[col._label] = update_version_id @@ -442,24 +451,33 @@ def _collect_update_commands(uowtransaction, table, states_to_update): if not (params or value_params): continue - pk_params = {} - for col in pks: - propkey = mapper._columntoproperty[col].key - history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF) - - if history.added: - if not history.deleted or \ - ("pk_cascaded", state, col) in \ - uowtransaction.attributes: - pk_params[col._label] = history.added[0] - params.pop(col.key, None) + if bulk: + pk_params = dict( + (propkey_to_col[propkey]._label, state_dict.get(propkey)) + for propkey in + set(propkey_to_col). + intersection(mapper._pk_keys_by_table[table]) + ) + else: + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF) + + if history.added: + if not history.deleted or \ + ("pk_cascaded", state, col) in \ + uowtransaction.attributes: + pk_params[col._label] = history.added[0] + params.pop(col.key, None) + else: + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + params[col.key] = history.added[0] else: - # else, use the old value to locate the row - pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] - else: - pk_params[col._label] = history.unchanged[0] + pk_params[col._label] = history.unchanged[0] if params or value_params: if None in pk_params.values(): @@ -471,44 +489,6 @@ def _collect_update_commands(uowtransaction, table, states_to_update): state, state_dict, params, mapper, connection, value_params) -def _collect_bulk_update_commands(mapper, table, mappings): - label_pks = mapper._pks_by_table[table] - if mapper.version_id_col is not None: - label_pks = label_pks.union([mapper.version_id_col]) - - to_translate = dict( - (propkey, col.key if col not in label_pks else col._label) - for propkey, col in mapper._propkey_to_col[table].items() - ) - - for mapping in mappings: - params = dict( - (to_translate[k], mapping[k]) for k in to_translate - if k in mapping and k not in mapper._primary_key_propkeys - ) - - if not params: - continue - - try: - params.update( - (to_translate[k], mapping[k]) for k in - mapper._primary_key_propkeys.intersection(to_translate) - ) - except KeyError as ke: - raise orm_exc.FlushError( - "Can't update table using NULL for primary " - "key attribute: %s" % ke) - - if mapper.version_id_generator is not False and \ - mapper.version_id_col is not None and \ - mapper.version_id_col.key not in params: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator( - params[mapper.version_id_col._label]) - - yield params - def _collect_post_update_commands(base_mapper, uowtransaction, table, states_to_update, post_update_cols): @@ -569,7 +549,7 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, "key value") if update_version_id is not None and \ - table.c.contains_column(mapper.version_id_col): + mapper.version_id_col in mapper._cols_by_table[table]: params[mapper.version_id_col.key] = update_version_id yield params, connection @@ -581,7 +561,7 @@ def _emit_update_statements(base_mapper, uowtransaction, by _collect_update_commands().""" needs_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def update_stmt(): clause = sql.and_() @@ -610,9 +590,9 @@ def _emit_update_statements(base_mapper, uowtransaction, records in groupby( update, lambda rec: ( - rec[4], - tuple(sorted(rec[2])), - bool(rec[5]))): + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]))): # whether or not we have "value" parameters rows = 0 records = list(records) @@ -692,12 +672,14 @@ def _emit_insert_statements(base_mapper, uowtransaction, statement = base_mapper._memo(('insert', table), table.insert) for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \ - records in groupby(insert, - lambda rec: (rec[4], - tuple(sorted(rec[2].keys())), - bool(rec[5]), - rec[6], rec[7]) - ): + records in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7])): if not bookkeeping or \ ( has_all_defaults @@ -785,7 +767,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # also group them into common (connection, cols) sets # to support executemany(). for key, grouper in groupby( - update, lambda rec: (rec[1], sorted(rec[0])) + update, lambda rec: ( + rec[1], # connection + set(rec[0]) # parameter keys + ) ): connection = key[0] multiparams = [params for params, conn in grouper] @@ -799,7 +784,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, by _collect_delete_commands().""" need_version_id = mapper.version_id_col is not None and \ - table.c.contains_column(mapper.version_id_col) + mapper.version_id_col in mapper._cols_by_table[table] def delete_stmt(): clause = sql.and_() @@ -821,12 +806,9 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, statement = base_mapper._memo(('delete', table), delete_stmt) for connection, recs in groupby( delete, - lambda rec: rec[1] + lambda rec: rec[1] # connection ): - del_objects = [ - params - for params, connection in recs - ] + del_objects = [params for params, connection in recs] connection = cached_connections[connection] @@ -931,7 +913,8 @@ def _postfetch(mapper, uowtransaction, table, postfetch_cols = result.context.postfetch_cols returning_cols = result.context.returning_cols - if mapper.version_id_col is not None: + if mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] if returning_cols: