from . import loading
+def bulk_insert(mapper, mappings, uowtransaction):
+    """Emit INSERT statements (executemany) for *mappings*, a sequence
+    of dictionaries keyed on mapped attribute names, bypassing
+    unit-of-work state bookkeeping.
+
+    :param mapper: the Mapper the mappings correspond to.
+    :param mappings: iterable of attribute-keyed dictionaries.
+    :param uowtransaction: current UOWTransaction; used to locate the
+     connection.
+    """
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if uowtransaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_insert()")
+
+    connection = uowtransaction.transaction.connection(base_mapper)
+
+    # Materialize up front: callers (e.g. bulk_save_objects) pass a
+    # generator, which would be exhausted after the first table of a
+    # multi-table (joined inheritance) mapping.
+    mappings = list(mappings)
+
+    for table, sub_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(sub_mapper):
+            continue
+
+        # attribute key -> column key, restricted to this table's columns
+        to_translate = dict(
+            (mapper._columntoproperty[col].key, col.key)
+            for col in mapper._cols_by_table[table]
+        )
+        # Generate a version id only when the version column actually
+        # belongs to this table (mirrors bulk_update()'s
+        # needs_version_id check); previously the value was injected
+        # into the params for every table of the hierarchy, producing
+        # an unconsumed column name on sub-tables.
+        has_version_generator = \
+            mapper.version_id_generator is not False and \
+            mapper.version_id_col is not None and \
+            table.c.contains_column(mapper.version_id_col)
+        multiparams = []
+        for mapping in mappings:
+            # INSERT params must be keyed on the *column* key with the
+            # value taken from the *attribute* key.  The translation
+            # was previously applied inverted
+            # (params[prop.key] = mapping.get(col.key)), which only
+            # works when an attribute name happens to equal its column
+            # name; see the equivalent translation in bulk_update().
+            params = dict(
+                (colkey, mapping.get(propkey))
+                for propkey, colkey in to_translate.items()
+            )
+            if has_version_generator:
+                params[mapper.version_id_col.key] = \
+                    mapper.version_id_generator(None)
+            multiparams.append(params)
+
+        statement = base_mapper._memo(('insert', table), table.insert)
+        cached_connections[connection].execute(statement, multiparams)
+
+
+def bulk_update(mapper, mappings, uowtransaction):
+    """Emit UPDATE statements (executemany, matched on primary key)
+    for *mappings*, a sequence of dictionaries keyed on mapped
+    attribute names, bypassing unit-of-work state bookkeeping.
+
+    :param mapper: the Mapper the mappings correspond to.
+    :param mappings: iterable of attribute-keyed dictionaries; each
+     must include the primary key attributes.
+    :param uowtransaction: current UOWTransaction; used to locate the
+     connection.
+    :raises orm_exc.StaleDataError: when the dialect reports reliable
+     executemany rowcounts and fewer rows matched than expected.
+    """
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if uowtransaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_update()")
+
+    connection = uowtransaction.transaction.connection(base_mapper)
+
+    # Materialize up front: callers pass a generator, which would be
+    # exhausted after the first table of a multi-table (joined
+    # inheritance) mapping.
+    mappings = list(mappings)
+
+    for table, sub_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(sub_mapper):
+            continue
+
+        needs_version_id = sub_mapper.version_id_col is not None and \
+            table.c.contains_column(sub_mapper.version_id_col)
+
+        def update_stmt():
+            return _update_stmt_for_mapper(sub_mapper, table, needs_version_id)
+
+        statement = base_mapper._memo(('update', table), update_stmt)
+
+        pks = mapper._pks_by_table[table]
+        # attribute key -> bind param name.  Primary key columns (and
+        # the version id column, when present in this table) appear in
+        # the WHERE clause of the statement built by
+        # _update_stmt_for_mapper(), which binds them under col._label;
+        # all other columns are SET under their plain key.  (The
+        # version column previously translated to col.key, leaving the
+        # statement's col._label bind param unsupplied on versioned
+        # updates.)
+        to_translate = dict(
+            (mapper._columntoproperty[col].key,
+             col._label if col in pks or
+             (needs_version_id and col is sub_mapper.version_id_col)
+             else col.key)
+            for col in mapper._cols_by_table[table]
+        )
+
+        # NOTE: groupby batches only *consecutive* mappings that share
+        # the same key set; callers should order mappings accordingly
+        # to get large batches.
+        for colnames, sub_mappings in groupby(
+                mappings,
+                lambda mapping: sorted(mapping.keys())):
+
+            multiparams = [
+                dict((to_translate[k], v) for k, v in mapping.items())
+                for mapping in sub_mappings
+            ]
+
+            c = cached_connections[connection].execute(
+                statement, multiparams)
+
+            rows = c.rowcount
+
+            # This is an executemany; totalled rowcounts are only
+            # trustworthy on dialects reporting
+            # supports_sane_multi_rowcount (the single-statement
+            # supports_sane_rowcount flag was checked before, which
+            # over-trusts executemany counts on some DBAPIs).
+            if connection.dialect.supports_sane_multi_rowcount:
+                if rows != len(multiparams):
+                    raise orm_exc.StaleDataError(
+                        "UPDATE statement on table '%s' expected to "
+                        "update %d row(s); %d were matched." %
+                        (table.description, len(multiparams), rows))
+
+            elif needs_version_id:
+                util.warn("Dialect %s does not support updated rowcount "
+                          "- versioning cannot be verified." %
+                          c.dialect.dialect_description,
+                          stacklevel=12)
+
+
def save_obj(
- base_mapper, states, uowtransaction, single=False,
- bookkeeping=True):
+ base_mapper, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
states_to_insert, states_to_update = _organize_states_for_save(
base_mapper,
states,
- uowtransaction, bookkeeping)
+ uowtransaction)
cached_connections = _cached_connection_dict(base_mapper)
for table, mapper in base_mapper._sorted_tables.items():
insert = _collect_insert_commands(base_mapper, uowtransaction,
- table, states_to_insert,
- bookkeeping)
+ table, states_to_insert)
update = _collect_update_commands(base_mapper, uowtransaction,
table, states_to_update)
if insert:
_emit_insert_statements(base_mapper, uowtransaction,
cached_connections,
- mapper, table, insert,
- bookkeeping)
+ mapper, table, insert)
- _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update,
- bookkeeping)
+ _finalize_insert_update_commands(
+ base_mapper, uowtransaction,
+ states_to_insert, states_to_update)
def post_update(base_mapper, states, uowtransaction, post_update_cols):
mapper.dispatch.after_delete(mapper, connection, state)
-def _organize_states_for_save(
- base_mapper, states, uowtransaction, bookkeeping):
+def _organize_states_for_save(base_mapper, states, uowtransaction):
"""Make an initial pass across a set of states for INSERT or
UPDATE.
has_identity = bool(state.key)
- if bookkeeping:
- instance_key = state.key or mapper._identity_key_from_state(state)
+ instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
# no instance_key attached to it), and another instance
# with the same identity key already exists as persistent.
# convert to an UPDATE if so.
- if bookkeeping and not has_identity and \
+ if not has_identity and \
instance_key in uowtransaction.session.identity_map:
instance = \
uowtransaction.session.identity_map[instance_key]
def _collect_insert_commands(base_mapper, uowtransaction, table,
- states_to_insert, bookkeeping):
+ states_to_insert):
"""Identify sets of values to use in INSERT statements for a
list of states.
prop = mapper._columntoproperty[col]
value = state_dict.get(prop.key, None)
- if bookkeeping and value is None:
+ if value is None:
if col in pks:
has_all_pks = False
elif col.default is None and \
return delete
+def _update_stmt_for_mapper(mapper, table, needs_version_id):
+    """Construct the memoizable UPDATE statement for *table*: SET all
+    columns by key, matching rows on primary key (and, when
+    *needs_version_id* is set, on the version id column), with each
+    WHERE column bound under its ``_label`` name.
+    """
+    criteria = [
+        col == sql.bindparam(col._label, type_=col.type)
+        for col in mapper._pks_by_table[table]
+    ]
+
+    if needs_version_id:
+        vcol = mapper.version_id_col
+        criteria.append(
+            vcol == sql.bindparam(vcol._label, type_=vcol.type))
+
+    stmt = table.update(sql.and_(*criteria))
+
+    # eager_defaults fetches all server-generated values back;
+    # otherwise only the version id column needs to round-trip.
+    if mapper.base_mapper.eager_defaults:
+        stmt = stmt.return_defaults()
+    elif mapper.version_id_col is not None:
+        stmt = stmt.return_defaults(mapper.version_id_col)
+
+    return stmt
+
+
def _emit_update_statements(base_mapper, uowtransaction,
cached_connections, mapper, table, update):
"""Emit UPDATE statements corresponding to value lists collected
table.c.contains_column(mapper.version_id_col)
def update_stmt():
- clause = sql.and_()
-
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
-
- if needs_version_id:
- clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
- mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
-
- stmt = table.update(clause)
- if mapper.base_mapper.eager_defaults:
- stmt = stmt.return_defaults()
- elif mapper.version_id_col is not None:
- stmt = stmt.return_defaults(mapper.version_id_col)
-
- return stmt
+ return _update_stmt_for_mapper(mapper, table, needs_version_id)
statement = base_mapper._memo(('update', table), update_stmt)
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert,
- bookkeeping):
+ cached_connections, mapper, table, insert):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
c = cached_connections[connection].\
execute(statement, multiparams)
- if bookkeeping:
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params)
+ for (state, state_dict, params, mapper_rec,
+ conn, value_params, has_all_pks, has_all_defaults), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params)
else:
if not has_all_defaults and base_mapper.eager_defaults:
def _finalize_insert_update_commands(base_mapper, uowtransaction,
- states_to_insert, states_to_update,
- bookkeeping):
+ states_to_insert, states_to_update):
"""finalize state on states that have been inserted or updated,
including calling after_insert/after_update events.
row_switch in states_to_insert + \
states_to_update:
- if bookkeeping:
- if mapper._readonly_props:
- readonly = state.unmodified_intersection(
- [p.key for p in mapper._readonly_props
- if p.expire_on_flush or p.key not in state.dict]
- )
- if readonly:
- state._expire_attributes(state.dict, readonly)
-
- # if eager_defaults option is enabled, load
- # all expired cols. Else if we have a version_id_col, make sure
- # it isn't expired.
- toload_now = []
-
- if base_mapper.eager_defaults:
- toload_now.extend(state._unloaded_non_object)
- elif mapper.version_id_col is not None and \
- mapper.version_id_generator is False:
- prop = mapper._columntoproperty[mapper.version_id_col]
- if prop.key in state.unloaded:
- toload_now.extend([prop.key])
-
- if toload_now:
- state.key = base_mapper._identity_key_from_state(state)
- loading.load_on_ident(
- uowtransaction.session.query(base_mapper),
- state.key, refresh_state=state,
- only_load_props=toload_now)
+ if mapper._readonly_props:
+ readonly = state.unmodified_intersection(
+ [p.key for p in mapper._readonly_props
+ if p.expire_on_flush or p.key not in state.dict]
+ )
+ if readonly:
+ state._expire_attributes(state.dict, readonly)
+
+ # if eager_defaults option is enabled, load
+ # all expired cols. Else if we have a version_id_col, make sure
+ # it isn't expired.
+ toload_now = []
+
+ if base_mapper.eager_defaults:
+ toload_now.extend(state._unloaded_non_object)
+ elif mapper.version_id_col is not None and \
+ mapper.version_id_generator is False:
+ prop = mapper._columntoproperty[mapper.version_id_col]
+ if prop.key in state.unloaded:
+ toload_now.extend([prop.key])
+
+ if toload_now:
+ state.key = base_mapper._identity_key_from_state(state)
+ loading.load_on_ident(
+ uowtransaction.session.query(base_mapper),
+ state.key, refresh_state=state,
+ only_load_props=toload_now)
# call after_XXX extensions
if not has_identity:
_class_to_mapper, _state_mapper, object_state,
_none_set, state_str, instance_str
)
+import itertools
from .unitofwork import UOWTransaction
from . import state as statelib
import sys
'__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested',
'close', 'commit', 'connection', 'delete', 'execute', 'expire',
'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind',
- 'is_modified', 'bulk_save_objects', 'bulk_save_mappings',
+ 'is_modified', 'bulk_save_objects', 'bulk_insert_mappings',
+ 'bulk_update_mappings',
'merge', 'query', 'refresh', 'rollback',
'scalar')
transaction.rollback(_capture_exception=True)
def bulk_save_objects(self, objects):
+        """Perform bulk INSERT/UPDATE of the given objects, bypassing
+        unit-of-work state bookkeeping.
+
+        Objects are batched via ``itertools.groupby``, i.e. only
+        *consecutive* objects sharing the same mapper and the same
+        identity status are sent as one bulk operation; callers should
+        pre-order the objects to obtain large batches.
+        """
-        self._bulk_save((attributes.instance_state(obj) for obj in objects))
+        for (mapper, isupdate), states in itertools.groupby(
+            (attributes.instance_state(obj) for obj in objects),
+            lambda state: (state.mapper, state.key is not None)
+        ):
+            # state.key present -> the object has an identity key, so
+            # route to UPDATE; no key -> pending, route to INSERT.
+            if isupdate:
+                self.bulk_update_mappings(mapper, (s.dict for s in states))
+            else:
+                self.bulk_insert_mappings(mapper, (s.dict for s in states))
-    def bulk_save_mappings(self, mapper, mappings):
-        mapper = class_mapper(mapper)
+    def bulk_insert_mappings(self, mapper, mappings):
+        """Bulk INSERT the given attribute-keyed mapping dictionaries
+        for *mapper* (a mapped class or Mapper), bypassing
+        unit-of-work state bookkeeping.
+        """
+        self._bulk_save_mappings(mapper, mappings, False)
-        self._bulk_save((
-            statelib.MappingState(mapper, mapping)
-            for mapping in mappings)
-        )
+    def bulk_update_mappings(self, mapper, mappings):
+        """Bulk UPDATE by primary key using the given attribute-keyed
+        mapping dictionaries for *mapper* (a mapped class or Mapper),
+        bypassing unit-of-work state bookkeeping.
+        """
+        self._bulk_save_mappings(mapper, mappings, True)
- def _bulk_save(self, states):
+ def _bulk_save_mappings(self, mapper, mappings, isupdate):
+ mapper = _class_to_mapper(mapper)
self._flushing = True
flush_context = UOWTransaction(self)
- if self.dispatch.before_bulk_save:
- self.dispatch.before_bulk_save(
- self, flush_context, states)
-
flush_context.transaction = transaction = self.begin(
subtransactions=True)
try:
- self._warn_on_events = True
- try:
- flush_context.bulk_save(states)
- finally:
- self._warn_on_events = False
-
- self.dispatch.after_bulk_save(
- self, flush_context, states
- )
-
- flush_context.finalize_flush_changes()
-
- self.dispatch.after_bulk_save_postexec(
- self, flush_context, states)
-
+ if isupdate:
+ self.dispatch.before_bulk_update(
+ self, flush_context, mapper, mappings)
+ flush_context.bulk_update(mapper, mappings)
+ self.dispatch.after_bulk_update(
+ self, flush_context, mapper, mappings)
+ else:
+ self.dispatch.before_bulk_insert(
+ self, flush_context, mapper, mappings)
+ flush_context.bulk_insert(mapper, mappings)
+ self.dispatch.after_bulk_insert(
+ self, flush_context, mapper, mappings)
transaction.commit()
except: