from . import loading
-def bulk_insert(mapper, mappings, uowtransaction):
+def _bulk_insert(mapper, mappings, session_transaction, isstates):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
- if uowtransaction.session.connection_callable:
+ if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
"not supported in bulk_insert()")
- connection = uowtransaction.transaction.connection(base_mapper)
+ if isstates:
+ mappings = [state.dict for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ connection = session_transaction.connection(base_mapper)
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper):
continue
state, state_dict, params, mp,
conn, value_params, has_all_pks,
has_all_defaults in _collect_insert_commands(table, (
- (None, mapping, super_mapper, connection)
- for mapping in mappings)
+ (None, mapping, mapper, connection)
+ for mapping in mappings),
+ bulk=True
)
)
- _emit_insert_statements(base_mapper, uowtransaction,
+ _emit_insert_statements(base_mapper, None,
cached_connections,
super_mapper, table, records,
bookkeeping=False)
-def bulk_update(mapper, mappings, uowtransaction):
+def _bulk_update(mapper, mappings, session_transaction, isstates):
base_mapper = mapper.base_mapper
cached_connections = _cached_connection_dict(base_mapper)
- if uowtransaction.session.connection_callable:
+ def _changed_dict(mapper, state):
+ return dict(
+ (k, v)
+ for k, v in state.dict.items() if k in state.committed_state or k
+ in mapper._primary_key_propkeys
+ )
+
+ if isstates:
+ mappings = [_changed_dict(mapper, state) for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
"not supported in bulk_update()")
- connection = uowtransaction.transaction.connection(base_mapper)
+ connection = session_transaction.connection(base_mapper)
value_params = {}
+
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper):
continue
- label_pks = super_mapper._pks_by_table[table]
- if mapper.version_id_col is not None:
- label_pks = label_pks.union([mapper.version_id_col])
-
- to_translate = dict(
- (propkey, col._label if col in label_pks else col.key)
- for propkey, col in super_mapper._propkey_to_col[table].items()
+ records = (
+ (None, None, params, super_mapper, connection, value_params)
+ for
+ params in _collect_bulk_update_commands(mapper, table, mappings)
)
- records = []
- for mapping in mappings:
- params = dict(
- (to_translate[k], v) for k, v in mapping.items()
- )
-
- if mapper.version_id_generator is not False and \
- mapper.version_id_col is not None and \
- mapper.version_id_col.key not in params:
- params[mapper.version_id_col.key] = \
- mapper.version_id_generator(
- params[mapper.version_id_col._label])
-
- records.append(
- (None, None, params, super_mapper, connection, value_params)
- )
-
- _emit_update_statements(base_mapper, uowtransaction,
+ _emit_update_statements(base_mapper, None,
cached_connections,
super_mapper, table, records,
bookkeeping=False)
col = propkey_to_col[propkey]
if value is None:
continue
- elif isinstance(value, sql.ClauseElement):
+ elif not bulk and isinstance(value, sql.ClauseElement):
value_params[col.key] = value
else:
params[col.key] = value
state, state_dict, params, mapper,
connection, value_params)
+def _collect_bulk_update_commands(mapper, table, mappings):
+ label_pks = mapper._pks_by_table[table]
+ if mapper.version_id_col is not None:
+ label_pks = label_pks.union([mapper.version_id_col])
+
+ to_translate = dict(
+ (propkey, col.key if col not in label_pks else col._label)
+ for propkey, col in mapper._propkey_to_col[table].items()
+ )
+
+ for mapping in mappings:
+ params = dict(
+ (to_translate[k], mapping[k]) for k in to_translate
+ if k in mapping and k not in mapper._primary_key_propkeys
+ )
+
+ if not params:
+ continue
+
+ try:
+ params.update(
+ (to_translate[k], mapping[k]) for k in
+ mapper._primary_key_propkeys.intersection(to_translate)
+ )
+ except KeyError as ke:
+ raise orm_exc.FlushError(
+ "Can't update table using NULL for primary "
+ "key attribute: %s" % ke)
+
+ if mapper.version_id_generator is not False and \
+ mapper.version_id_col is not None and \
+ mapper.version_id_col.key not in params:
+ params[mapper.version_id_col.key] = \
+ mapper.version_id_generator(
+ params[mapper.version_id_col._label])
+
+ yield params
+
def _collect_post_update_commands(base_mapper, uowtransaction, table,
states_to_update, post_update_cols):
_none_set, state_str, instance_str
)
import itertools
+from . import persistence
from .unitofwork import UOWTransaction
from . import state as statelib
import sys
(attributes.instance_state(obj) for obj in objects),
lambda state: (state.mapper, state.key is not None)
):
- if isupdate:
- self.bulk_update_mappings(mapper, (s.dict for s in states))
- else:
- self.bulk_insert_mappings(mapper, (s.dict for s in states))
+ self._bulk_save_mappings(mapper, states, isupdate, True)
def bulk_insert_mappings(self, mapper, mappings):
- self._bulk_save_mappings(mapper, mappings, False)
+ self._bulk_save_mappings(mapper, mappings, False, False)
def bulk_update_mappings(self, mapper, mappings):
- self._bulk_save_mappings(mapper, mappings, True)
+ self._bulk_save_mappings(mapper, mappings, True, False)
- def _bulk_save_mappings(self, mapper, mappings, isupdate):
+ def _bulk_save_mappings(self, mapper, mappings, isupdate, isstates):
mapper = _class_to_mapper(mapper)
self._flushing = True
- flush_context = UOWTransaction(self)
- flush_context.transaction = transaction = self.begin(
+ transaction = self.begin(
subtransactions=True)
try:
if isupdate:
- self.dispatch.before_bulk_update(
- self, flush_context, mapper, mappings)
- flush_context.bulk_update(mapper, mappings)
- self.dispatch.after_bulk_update(
- self, flush_context, mapper, mappings)
+ persistence._bulk_update(
+ mapper, mappings, transaction, isstates)
else:
- self.dispatch.before_bulk_insert(
- self, flush_context, mapper, mappings)
- flush_context.bulk_insert(mapper, mappings)
- self.dispatch.after_bulk_insert(
- self, flush_context, mapper, mappings)
+ persistence._bulk_insert(
+ mapper, mappings, transaction, isstates)
transaction.commit()
except: