"""
import operator
-from itertools import groupby
+from itertools import groupby, chain
from .. import sql, util, exc as sa_exc, schema
from . import attributes, sync, exc as orm_exc, evaluator
from .base import state_str, _attr_as_key
connection = session_transaction.connection(base_mapper)
- value_params = {}
-
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper):
continue
- records = (
- (None, None, params, super_mapper, connection, value_params)
- for
- params in _collect_bulk_update_commands(mapper, table, mappings)
- )
+ records = _collect_update_commands(None, table, (
+ (None, mapping, mapper, connection,
+ (mapping[mapper._version_id_prop.key]
+ if mapper._version_id_prop else None))
+ for mapping in mappings
+ ), bulk=True)
_emit_update_statements(base_mapper, None,
cached_connections,
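# A hedged usage sketch, for orientation: the bulk path above consumes plain
# dictionaries keyed by mapped attribute name instead of ORM state, which is
# the shape Session.bulk_update_mappings() supplies.  The User class, engine
# and data below are illustrative assumptions, not taken from this module.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Base = declarative_base()


class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

session = Session(engine)
session.add_all([User(id=1, name='u1'), User(id=2, name='u2')])
session.commit()

# each mapping must carry the primary key value; the remaining keys become
# the SET clause of the UPDATE statements emitted for the matching rows
session.bulk_update_mappings(
    User,
    [{'id': 1, 'name': 'u1 updated'}, {'id': 2, 'name': 'u2 updated'}])
session.commit()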
_finalize_insert_update_commands(
base_mapper, uowtransaction,
- (
- (state, state_dict, mapper, connection, False)
- for state, state_dict, mapper, connection in states_to_insert
- )
- )
- _finalize_insert_update_commands(
- base_mapper, uowtransaction,
- (
- (state, state_dict, mapper, connection, True)
- for state, state_dict, mapper, connection,
- update_version_id in states_to_update
+ chain(
+ (
+ (state, state_dict, mapper, connection, False)
+ for state, state_dict, mapper, connection in states_to_insert
+ ),
+ (
+ (state, state_dict, mapper, connection, True)
+ for state, state_dict, mapper, connection,
+ update_version_id in states_to_update
+ )
)
)
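# A minimal sketch of the itertools.chain() pattern used just above: insert
# and update records are merged into one stream so the bookkeeping pass runs
# once, with the trailing boolean flagging which records are updates.  The
# tuples here are placeholder values rather than real ORM state.
from itertools import chain

states_to_insert = [('state1', {'id': 1}, 'mapper', 'conn')]
states_to_update = [('state2', {'id': 2}, 'mapper', 'conn', 'old_version')]

merged = chain(
    (
        (state, state_dict, mapper, connection, False)
        for state, state_dict, mapper, connection in states_to_insert
    ),
    (
        (state, state_dict, mapper, connection, True)
        for state, state_dict, mapper, connection,
        update_version_id in states_to_update
    )
)
for record in merged:
    print(record)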
has_all_defaults)
-def _collect_update_commands(uowtransaction, table, states_to_update):
+def _collect_update_commands(
+ uowtransaction, table, states_to_update,
+ bulk=False):
"""Identify sets of values to use in UPDATE statements for a
list of states.
pks = mapper._pks_by_table[table]
- params = {}
value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
- for propkey in set(propkey_to_col).intersection(state.committed_state):
- value = state_dict[propkey]
- col = propkey_to_col[propkey]
-
- if not state.manager[propkey].impl.is_equal(
- value, state.committed_state[propkey]):
- if isinstance(value, sql.ClauseElement):
- value_params[col] = value
- else:
- params[col.key] = value
+        if bulk:
+            # primary key attribute keys are handled separately below as the
+            # WHERE-clause parameters; they don't belong in the SET clause
+            params = dict(
+                (propkey_to_col[propkey].key, state_dict[propkey])
+                for propkey in
+                set(propkey_to_col).intersection(state_dict).difference(
+                    mapper._pk_keys_by_table[table])
+            )
+ else:
+ params = {}
+ for propkey in set(propkey_to_col).intersection(
+ state.committed_state):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+
+ if not state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]):
+ if isinstance(value, sql.ClauseElement):
+ value_params[col] = value
+ else:
+ params[col.key] = value
- if update_version_id is not None:
+ if update_version_id is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
col = mapper.version_id_col
params[col._label] = update_version_id
if not (params or value_params):
continue
- pk_params = {}
- for col in pks:
- propkey = mapper._columntoproperty[col].key
- history = state.manager[propkey].impl.get_history(
- state, state_dict, attributes.PASSIVE_OFF)
-
- if history.added:
- if not history.deleted or \
- ("pk_cascaded", state, col) in \
- uowtransaction.attributes:
- pk_params[col._label] = history.added[0]
- params.pop(col.key, None)
+ if bulk:
+ pk_params = dict(
+ (propkey_to_col[propkey]._label, state_dict.get(propkey))
+ for propkey in
+ set(propkey_to_col).
+ intersection(mapper._pk_keys_by_table[table])
+ )
+ else:
+ pk_params = {}
+ for col in pks:
+ propkey = mapper._columntoproperty[col].key
+
+ history = state.manager[propkey].impl.get_history(
+ state, state_dict, attributes.PASSIVE_OFF)
+
+ if history.added:
+ if not history.deleted or \
+ ("pk_cascaded", state, col) in \
+ uowtransaction.attributes:
+ pk_params[col._label] = history.added[0]
+ params.pop(col.key, None)
+ else:
+ # else, use the old value to locate the row
+ pk_params[col._label] = history.deleted[0]
+ params[col.key] = history.added[0]
else:
- # else, use the old value to locate the row
- pk_params[col._label] = history.deleted[0]
- params[col.key] = history.added[0]
- else:
- pk_params[col._label] = history.unchanged[0]
+ pk_params[col._label] = history.unchanged[0]
if params or value_params:
if None in pk_params.values():
state, state_dict, params, mapper,
connection, value_params)
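# A small illustration of the two kinds of parameter keys built above: SET
# values are keyed by col.key, while the WHERE-clause values that locate the
# row use col._label, so the same column can appear on both sides of the
# UPDATE without the bind names colliding.  The 'user' table is a throwaway
# example.
from sqlalchemy import Table, Column, Integer, MetaData

metadata = MetaData()
user = Table('user', metadata, Column('id', Integer, primary_key=True))

print(user.c.id.key)     # 'id' -- key used in params (SET clause)
print(user.c.id._label)  # typically 'user_id' -- key used in pk_params (WHERE)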
-def _collect_bulk_update_commands(mapper, table, mappings):
- label_pks = mapper._pks_by_table[table]
- if mapper.version_id_col is not None:
- label_pks = label_pks.union([mapper.version_id_col])
-
- to_translate = dict(
- (propkey, col.key if col not in label_pks else col._label)
- for propkey, col in mapper._propkey_to_col[table].items()
- )
-
- for mapping in mappings:
- params = dict(
- (to_translate[k], mapping[k]) for k in to_translate
- if k in mapping and k not in mapper._primary_key_propkeys
- )
-
- if not params:
- continue
-
- try:
- params.update(
- (to_translate[k], mapping[k]) for k in
- mapper._primary_key_propkeys.intersection(to_translate)
- )
- except KeyError as ke:
- raise orm_exc.FlushError(
- "Can't update table using NULL for primary "
- "key attribute: %s" % ke)
-
- if mapper.version_id_generator is not False and \
- mapper.version_id_col is not None and \
- mapper.version_id_col.key not in params:
- params[mapper.version_id_col.key] = \
- mapper.version_id_generator(
- params[mapper.version_id_col._label])
-
- yield params
-
def _collect_post_update_commands(base_mapper, uowtransaction, table,
states_to_update, post_update_cols):
"key value")
if update_version_id is not None and \
- table.c.contains_column(mapper.version_id_col):
+ mapper.version_id_col in mapper._cols_by_table[table]:
params[mapper.version_id_col.key] = update_version_id
yield params, connection
by _collect_update_commands()."""
needs_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
+ mapper.version_id_col in mapper._cols_by_table[table]
def update_stmt():
clause = sql.and_()
records in groupby(
update,
lambda rec: (
- rec[4],
- tuple(sorted(rec[2])),
- bool(rec[5]))):
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]))): # whether or not we have "value" parameters
rows = 0
records = list(records)
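# A hedged sketch of the statement shape that update_stmt() above produces
# (the 'user' table is an assumption for illustration): primary key columns
# are matched against bindparam(col._label), leaving the plain column keys
# free for the SET clause that each params dict supplies at execution time.
from sqlalchemy import Table, Column, Integer, String, MetaData, bindparam, sql

metadata = MetaData()
user = Table('user', metadata,
             Column('id', Integer, primary_key=True),
             Column('name', String(50)))

clause = sql.and_(*(
    col == bindparam(col._label, type_=col.type)
    for col in user.primary_key
))
stmt = user.update().where(clause)

# executed with e.g. {'user_id': 1, 'name': 'new name'}, this renders roughly
# as: UPDATE "user" SET name=:name WHERE "user".id = :user_id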
statement = base_mapper._memo(('insert', table), table.insert)
for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
- records in groupby(insert,
- lambda rec: (rec[4],
- tuple(sorted(rec[2].keys())),
- bool(rec[5]),
- rec[6], rec[7])
- ):
+ records in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+                    rec[6],  # has_all_pks
+                    rec[7])):  # has_all_defaults
if not bookkeeping or \
(
has_all_defaults
# also group them into common (connection, cols) sets
# to support executemany().
for key, grouper in groupby(
- update, lambda rec: (rec[1], sorted(rec[0]))
+ update, lambda rec: (
+ rec[1], # connection
+ set(rec[0]) # parameter keys
+ )
):
connection = key[0]
multiparams = [params for params, conn in grouper]
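# A minimal sketch of the grouping idiom used here and in the statement
# emitters above: itertools.groupby() merges only *adjacent* records and
# compares keys by equality rather than hashing, which is why an unordered
# set of parameter keys can be part of the key.  Records sharing a connection
# and the same key set then go out as a single executemany() batch.  The
# records below are made-up placeholders.
from itertools import groupby

update = [
    ({'user_id': 1, 'name': 'a'}, 'conn_1'),
    ({'user_id': 2, 'name': 'b'}, 'conn_1'),
    ({'user_id': 3}, 'conn_1'),
]

for (connection, paramkeys), grouper in groupby(
        update, lambda rec: (rec[1], set(rec[0]))):
    multiparams = [params for params, conn in grouper]
    print(connection, sorted(paramkeys), len(multiparams))
# conn_1 ['name', 'user_id'] 2
# conn_1 ['user_id'] 1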
by _collect_delete_commands()."""
need_version_id = mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col)
+ mapper.version_id_col in mapper._cols_by_table[table]
def delete_stmt():
clause = sql.and_()
statement = base_mapper._memo(('delete', table), delete_stmt)
for connection, recs in groupby(
delete,
- lambda rec: rec[1]
+ lambda rec: rec[1] # connection
):
- del_objects = [
- params
- for params, connection in recs
- ]
+ del_objects = [params for params, connection in recs]
connection = cached_connections[connection]
postfetch_cols = result.context.postfetch_cols
returning_cols = result.context.returning_cols
- if mapper.version_id_col is not None:
+ if mapper.version_id_col is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
if returning_cols: