From 591f2e4ed2d455cb2c5b9ece43d79fde4b109510 Mon Sep 17 00:00:00 2001
From: Mike Bayer <mike_mp@zzzcomputing.com>
Date: Thu, 14 Aug 2014 19:47:23 -0400
Subject: [PATCH] - change to be represented as two very fast bulk_insert()
 and bulk_update() methods

---
 doc/build/faq.rst                 |  37 ++---
 lib/sqlalchemy/orm/events.py      |   9 +-
 lib/sqlalchemy/orm/persistence.py | 255 ++++++++++++++++++++----------
 lib/sqlalchemy/orm/session.py     |  57 +++----
 lib/sqlalchemy/orm/state.py       |  15 --
 lib/sqlalchemy/orm/unitofwork.py  |  22 +--
 6 files changed, 223 insertions(+), 172 deletions(-)

diff --git a/doc/build/faq.rst b/doc/build/faq.rst
index b777f908fa..487f5b953a 100644
--- a/doc/build/faq.rst
+++ b/doc/build/faq.rst
@@ -907,12 +907,11 @@ methods of inserting rows, going from the most automated to the least.
 With cPython 2.7, runtimes observed::
 
     classics-MacBook-Pro:sqlalchemy classic$ python test.py
-    SQLAlchemy ORM: Total time for 100000 records 12.4703581333 secs
-    SQLAlchemy ORM pk given: Total time for 100000 records 7.32723999023 secs
-    SQLAlchemy ORM bulk_save_objects(): Total time for 100000 records 3.43464708328 secs
-    SQLAlchemy ORM bulk_save_mappings(): Total time for 100000 records 2.37040805817 secs
-    SQLAlchemy Core: Total time for 100000 records 0.495043992996 secs
-    sqlite3: Total time for 100000 records 0.508063077927 sec
+    SQLAlchemy ORM: Total time for 100000 records 12.0471920967 secs
+    SQLAlchemy ORM pk given: Total time for 100000 records 7.06283402443 secs
+    SQLAlchemy ORM bulk_save_objects(): Total time for 100000 records 0.856323003769 secs
+    SQLAlchemy Core: Total time for 100000 records 0.485800027847 secs
+    sqlite3: Total time for 100000 records 0.487842082977 sec
 
 We can reduce the time by a factor of three using recent versions of `Pypy
 <http://pypy.org/>`_::
@@ -980,15 +979,16 @@ Script::
             " records " + str(time.time() - t0) + " secs")
 
 
-    def test_sqlalchemy_orm_bulk_save(n=100000):
+    def test_sqlalchemy_orm_bulk_insert(n=100000):
         init_sqlalchemy()
         t0 = time.time()
         n1 = n
         while n1 > 0:
             n1 = n1 - 10000
-            DBSession.bulk_save_objects(
+            DBSession.bulk_insert_mappings(
+                Customer,
                 [
-                    Customer(name="NAME " + str(i))
+                    dict(name="NAME " + str(i))
                     for i in xrange(min(10000, n1))
                 ]
             )
@@ -998,22 +998,6 @@ Script::
             " records " + str(time.time() - t0) + " secs")
 
 
-    def test_sqlalchemy_orm_bulk_save_mappings(n=100000):
-        init_sqlalchemy()
-        t0 = time.time()
-        DBSession.bulk_save_mappings(
-            Customer,
-            [
-                dict(name="NAME " + str(i))
-                for i in xrange(n)
-            ]
-        )
-        DBSession.commit()
-        print(
-            "SQLAlchemy ORM bulk_save_mappings(): Total time for " + str(n) +
-            " records " + str(time.time() - t0) + " secs")
-
-
     def test_sqlalchemy_core(n=100000):
         init_sqlalchemy()
         t0 = time.time()
@@ -1052,8 +1036,7 @@ Script::
     if __name__ == '__main__':
         test_sqlalchemy_orm(100000)
         test_sqlalchemy_orm_pk_given(100000)
-        test_sqlalchemy_orm_bulk_save(100000)
-        test_sqlalchemy_orm_bulk_save_mappings(100000)
+        test_sqlalchemy_orm_bulk_insert(100000)
         test_sqlalchemy_core(100000)
         test_sqlite3(100000)
 
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index 097726c625..37ea3071bc 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -1453,13 +1453,16 @@ class SessionEvents(event.Events):
 
         """
 
-    def before_bulk_save(self, session, flush_context, objects):
+    def before_bulk_insert(self, session, flush_context, mapper, mappings):
         """"""
 
-    def after_bulk_save(self, session, flush_context, objects):
+    def after_bulk_insert(self, session, flush_context, mapper, mappings):
         """"""
 
-    def after_bulk_save_postexec(self, session, flush_context, objects):
+    def before_bulk_update(self, session, flush_context, mapper, mappings):
+        """"""
+
+    def after_bulk_update(self, session, flush_context, mapper, mappings):
         """"""
 
     def after_begin(self, session, transaction, connection):
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 64c8440c4d..a8d4bd695d 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -23,9 +23,104 @@ from ..sql import expression
 from . import loading
 
 
+def bulk_insert(mapper, mappings, uowtransaction):
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if uowtransaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_insert()")
+
+    connection = uowtransaction.transaction.connection(base_mapper)
+
+    for table, sub_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(sub_mapper):
+            continue
+
+        to_translate = dict(
+            (mapper._columntoproperty[col].key, col.key)
+            for col in mapper._cols_by_table[table]
+        )
+        has_version_generator = mapper.version_id_generator is not False and \
+            mapper.version_id_col is not None
+        multiparams = []
+        for mapping in mappings:
+            params = dict(
+                (k, mapping.get(v)) for k, v in to_translate.items()
+            )
+            if has_version_generator:
+                params[mapper.version_id_col.key] = \
+                    mapper.version_id_generator(None)
+            multiparams.append(params)
+
+        statement = base_mapper._memo(('insert', table), table.insert)
+        cached_connections[connection].execute(statement, multiparams)
+
+
+def bulk_update(mapper, mappings, uowtransaction):
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if uowtransaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_update()")
+
+    connection = uowtransaction.transaction.connection(base_mapper)
+
+    for table, sub_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(sub_mapper):
+            continue
+
+        needs_version_id = sub_mapper.version_id_col is not None and \
+            table.c.contains_column(sub_mapper.version_id_col)
+
+        def update_stmt():
+            return _update_stmt_for_mapper(sub_mapper, table, needs_version_id)
+
+        statement = base_mapper._memo(('update', table), update_stmt)
+
+        pks = mapper._pks_by_table[table]
+        to_translate = dict(
+            (mapper._columntoproperty[col].key, col._label
+             if col in pks else col.key)
+            for col in mapper._cols_by_table[table]
+        )
+
+        for colnames, sub_mappings in groupby(
+                mappings,
+                lambda mapping: sorted(tuple(mapping.keys()))):
+
+            multiparams = []
+            for mapping in sub_mappings:
+                params = dict(
+                    (to_translate[k], v) for k, v in mapping.items()
+                )
+                multiparams.append(params)
+
+            c = cached_connections[connection].execute(statement, multiparams)
+
+            rows = c.rowcount
+
+            if connection.dialect.supports_sane_rowcount:
+                if rows != len(multiparams):
+                    raise orm_exc.StaleDataError(
+                        "UPDATE statement on table '%s' expected to "
+                        "update %d row(s); %d were matched." %
+                        (table.description, len(multiparams), rows))
+
+            elif needs_version_id:
+                util.warn("Dialect %s does not support updated rowcount "
+                          "- versioning cannot be verified." %
+                          c.dialect.dialect_description,
+                          stacklevel=12)
+
+
 def save_obj(
-        base_mapper, states, uowtransaction, single=False,
-        bookkeeping=True):
+        base_mapper, states, uowtransaction, single=False):
     """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
     of objects.
 
@@ -45,14 +140,13 @@ def save_obj(
     states_to_insert, states_to_update = _organize_states_for_save(
         base_mapper,
         states,
-        uowtransaction, bookkeeping)
+        uowtransaction)
 
     cached_connections = _cached_connection_dict(base_mapper)
 
     for table, mapper in base_mapper._sorted_tables.items():
         insert = _collect_insert_commands(base_mapper, uowtransaction,
-                                          table, states_to_insert,
-                                          bookkeeping)
+                                          table, states_to_insert)
 
         update = _collect_update_commands(base_mapper, uowtransaction,
                                           table, states_to_update)
@@ -65,12 +159,11 @@
         if insert:
             _emit_insert_statements(base_mapper, uowtransaction,
                                     cached_connections,
-                                    mapper, table, insert,
-                                    bookkeeping)
+                                    mapper, table, insert)
 
-    _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update,
-                                     bookkeeping)
+    _finalize_insert_update_commands(
+        base_mapper, uowtransaction,
+        states_to_insert, states_to_update)
 
 
 def post_update(base_mapper, states, uowtransaction, post_update_cols):
@@ -126,8 +219,7 @@
         mapper.dispatch.after_delete(mapper, connection, state)
 
 
-def _organize_states_for_save(
-        base_mapper, states, uowtransaction, bookkeeping):
+def _organize_states_for_save(base_mapper, states, uowtransaction):
     """Make an initial pass across a set of states for INSERT or
     UPDATE.
@@ -149,8 +241,7 @@
         has_identity = bool(state.key)
 
-        if bookkeeping:
-            instance_key = state.key or mapper._identity_key_from_state(state)
+        instance_key = state.key or mapper._identity_key_from_state(state)
 
         row_switch = None
 
@@ -167,7 +258,7 @@
             # no instance_key attached to it), and another instance
             # with the same identity key already exists as persistent.
             # convert to an UPDATE if so.
-            if bookkeeping and not has_identity and \
+            if not has_identity and \
                     instance_key in uowtransaction.session.identity_map:
                 instance = \
                     uowtransaction.session.identity_map[instance_key]
@@ -239,7 +330,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
 
 
 def _collect_insert_commands(base_mapper, uowtransaction, table,
-                             states_to_insert, bookkeeping):
+                             states_to_insert):
     """Identify sets of values to use in INSERT statements for
     a list of states.
 
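A note on the ``executemany()`` batching in ``bulk_update()`` above: a single
DBAPI ``executemany()`` call requires every parameter dictionary to carry the
same keys, so the incoming mappings are first grouped by their sorted key
sets.  A minimal, self-contained sketch of that grouping (the sample
``mappings`` data is invented for illustration); since ``itertools.groupby()``
only merges *adjacent* entries, mappings with identical key sets batch best
when they arrive contiguously::

    from itertools import groupby

    mappings = [
        {"id": 1, "name": "n1"},
        {"id": 2, "name": "n2"},
        {"id": 3, "name": "n3", "description": "d3"},  # different key set
    ]

    # sorted(m) sorts the dict's keys, mirroring the
    # sorted(tuple(mapping.keys())) key function used above
    for colnames, group in groupby(mappings, lambda m: sorted(m)):
        batch = list(group)
        print(colnames, "-> one executemany() batch of", len(batch), "row(s)")
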
@@ -270,7 +361,7 @@
             prop = mapper._columntoproperty[col]
             value = state_dict.get(prop.key, None)
 
-            if bookkeeping and value is None:
+            if value is None:
                 if col in pks:
                     has_all_pks = False
                 elif col.default is None and \
@@ -481,6 +572,28 @@
     return delete
 
 
+def _update_stmt_for_mapper(mapper, table, needs_version_id):
+    clause = sql.and_()
+
+    for col in mapper._pks_by_table[table]:
+        clause.clauses.append(col == sql.bindparam(col._label,
+                                                   type_=col.type))
+
+    if needs_version_id:
+        clause.clauses.append(
+            mapper.version_id_col == sql.bindparam(
+                mapper.version_id_col._label,
+                type_=mapper.version_id_col.type))
+
+    stmt = table.update(clause)
+    if mapper.base_mapper.eager_defaults:
+        stmt = stmt.return_defaults()
+    elif mapper.version_id_col is not None:
+        stmt = stmt.return_defaults(mapper.version_id_col)
+
+    return stmt
+
+
 def _emit_update_statements(base_mapper, uowtransaction,
                             cached_connections, mapper, table, update):
     """Emit UPDATE statements corresponding to value lists collected
@@ -490,25 +603,7 @@
         table.c.contains_column(mapper.version_id_col)
 
     def update_stmt():
-        clause = sql.and_()
-
-        for col in mapper._pks_by_table[table]:
-            clause.clauses.append(col == sql.bindparam(col._label,
-                                                       type_=col.type))
-
-        if needs_version_id:
-            clause.clauses.append(
-                mapper.version_id_col == sql.bindparam(
-                    mapper.version_id_col._label,
-                    type_=mapper.version_id_col.type))
-
-        stmt = table.update(clause)
-        if mapper.base_mapper.eager_defaults:
-            stmt = stmt.return_defaults()
-        elif mapper.version_id_col is not None:
-            stmt = stmt.return_defaults(mapper.version_id_col)
-
-        return stmt
+        return _update_stmt_for_mapper(mapper, table, needs_version_id)
 
     statement = base_mapper._memo(('update', table), update_stmt)
@@ -572,8 +667,7 @@
 
 
 def _emit_insert_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, insert,
-                            bookkeeping):
+                            cached_connections, mapper, table, insert):
     """Emit INSERT statements corresponding to value lists collected
     by _collect_insert_commands()."""
 
@@ -599,20 +693,19 @@
             c = cached_connections[connection].\
                 execute(statement, multiparams)
 
-            if bookkeeping:
-                for (state, state_dict, params, mapper_rec,
-                        conn, value_params, has_all_pks, has_all_defaults), \
-                        last_inserted_params in \
-                        zip(records, c.context.compiled_parameters):
-                    _postfetch(
-                        mapper_rec,
-                        uowtransaction,
-                        table,
-                        state,
-                        state_dict,
-                        c,
-                        last_inserted_params,
-                        value_params)
+            for (state, state_dict, params, mapper_rec,
+                    conn, value_params, has_all_pks, has_all_defaults), \
+                    last_inserted_params in \
+                    zip(records, c.context.compiled_parameters):
+                _postfetch(
+                    mapper_rec,
+                    uowtransaction,
+                    table,
+                    state,
+                    state_dict,
+                    c,
+                    last_inserted_params,
+                    value_params)
 
         else:
             if not has_all_defaults and base_mapper.eager_defaults:
@@ -768,8 +861,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
 
 
 def _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update,
-                                     bookkeeping):
+                                     states_to_insert, states_to_update):
     """finalize state on states that have been inserted or updated,
     including calling after_insert/after_update events.
 
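The helper factored out here, ``_update_stmt_for_mapper()``, lets the regular
flush and ``bulk_update()`` memoize one UPDATE construct per table.  Note that
the WHERE criterion binds each primary key under ``col._label`` rather than
``col.key``, so those parameter names cannot collide with the SET-clause
parameters, which are named after the columns themselves.  A rough Core-only
illustration with a made-up ``user`` table (``._label`` is an internal
attribute, used here only to show the naming)::

    from sqlalchemy import MetaData, Table, Column, Integer, String, bindparam

    user = Table(
        "user", MetaData(),
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    pk = user.c.id
    # the WHERE clause binds :user_id, leaving :id and :name
    # free for the SET clause of an unconstrained UPDATE
    stmt = user.update().where(pk == bindparam(pk._label, type_=pk.type))
    print(stmt)
    # UPDATE "user" SET id=:id, name=:name WHERE "user".id = :user_id
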
@@ -778,34 +870,33 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
             row_switch in states_to_insert + \
             states_to_update:
 
-        if bookkeeping:
-            if mapper._readonly_props:
-                readonly = state.unmodified_intersection(
-                    [p.key for p in mapper._readonly_props
-                     if p.expire_on_flush or p.key not in state.dict]
-                )
-                if readonly:
-                    state._expire_attributes(state.dict, readonly)
-
-            # if eager_defaults option is enabled, load
-            # all expired cols.  Else if we have a version_id_col, make sure
-            # it isn't expired.
-            toload_now = []
-
-            if base_mapper.eager_defaults:
-                toload_now.extend(state._unloaded_non_object)
-            elif mapper.version_id_col is not None and \
-                    mapper.version_id_generator is False:
-                prop = mapper._columntoproperty[mapper.version_id_col]
-                if prop.key in state.unloaded:
-                    toload_now.extend([prop.key])
-
-            if toload_now:
-                state.key = base_mapper._identity_key_from_state(state)
-                loading.load_on_ident(
-                    uowtransaction.session.query(base_mapper),
-                    state.key, refresh_state=state,
-                    only_load_props=toload_now)
+        if mapper._readonly_props:
+            readonly = state.unmodified_intersection(
+                [p.key for p in mapper._readonly_props
+                 if p.expire_on_flush or p.key not in state.dict]
+            )
+            if readonly:
+                state._expire_attributes(state.dict, readonly)
+
+        # if eager_defaults option is enabled, load
+        # all expired cols.  Else if we have a version_id_col, make sure
+        # it isn't expired.
+        toload_now = []
+
+        if base_mapper.eager_defaults:
+            toload_now.extend(state._unloaded_non_object)
+        elif mapper.version_id_col is not None and \
+                mapper.version_id_generator is False:
+            prop = mapper._columntoproperty[mapper.version_id_col]
+            if prop.key in state.unloaded:
+                toload_now.extend([prop.key])
+
+        if toload_now:
+            state.key = base_mapper._identity_key_from_state(state)
+            loading.load_on_ident(
+                uowtransaction.session.query(base_mapper),
+                state.key, refresh_state=state,
+                only_load_props=toload_now)
 
         # call after_XXX extensions
         if not has_identity:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 546355611a..3199a4332d 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -20,6 +20,7 @@ from .base import (
     _class_to_mapper, _state_mapper, object_state, _none_set, state_str,
     instance_str
 )
+import itertools
 from .unitofwork import UOWTransaction
 from . import state as statelib
 import sys
@@ -482,7 +483,8 @@ class Session(_SessionClassMethods):
         '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested',
         'close', 'commit', 'connection', 'delete', 'execute', 'expire',
         'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind',
-        'is_modified', 'bulk_save_objects', 'bulk_save_mappings',
+        'is_modified', 'bulk_save_objects', 'bulk_insert_mappings',
+        'bulk_update_mappings',
         'merge', 'query', 'refresh', 'rollback',
         'scalar')
 
@@ -2034,42 +2036,41 @@ class Session(_SessionClassMethods):
                 transaction.rollback(_capture_exception=True)
 
     def bulk_save_objects(self, objects):
-        self._bulk_save((attributes.instance_state(obj) for obj in objects))
+        for (mapper, isupdate), states in itertools.groupby(
+            (attributes.instance_state(obj) for obj in objects),
+            lambda state: (state.mapper, state.key is not None)
+        ):
+            if isupdate:
+                self.bulk_update_mappings(mapper, (s.dict for s in states))
+            else:
+                self.bulk_insert_mappings(mapper, (s.dict for s in states))
 
-    def bulk_save_mappings(self, mapper, mappings):
-        mapper = class_mapper(mapper)
+    def bulk_insert_mappings(self, mapper, mappings):
+        self._bulk_save_mappings(mapper, mappings, False)
 
-        self._bulk_save((
-            statelib.MappingState(mapper, mapping)
-            for mapping in mappings)
-        )
+    def bulk_update_mappings(self, mapper, mappings):
+        self._bulk_save_mappings(mapper, mappings, True)
 
-    def _bulk_save(self, states):
+    def _bulk_save_mappings(self, mapper, mappings, isupdate):
+        mapper = _class_to_mapper(mapper)
         self._flushing = True
         flush_context = UOWTransaction(self)
 
-        if self.dispatch.before_bulk_save:
-            self.dispatch.before_bulk_save(
-                self, flush_context, states)
-
         flush_context.transaction = transaction = self.begin(
             subtransactions=True)
         try:
-            self._warn_on_events = True
-            try:
-                flush_context.bulk_save(states)
-            finally:
-                self._warn_on_events = False
-
-            self.dispatch.after_bulk_save(
-                self, flush_context, states
-            )
-
-            flush_context.finalize_flush_changes()
-
-            self.dispatch.after_bulk_save_postexec(
-                self, flush_context, states)
-
+            if isupdate:
+                self.dispatch.before_bulk_update(
+                    self, flush_context, mapper, mappings)
+                flush_context.bulk_update(mapper, mappings)
+                self.dispatch.after_bulk_update(
+                    self, flush_context, mapper, mappings)
+            else:
+                self.dispatch.before_bulk_insert(
+                    self, flush_context, mapper, mappings)
+                flush_context.bulk_insert(mapper, mappings)
+                self.dispatch.after_bulk_insert(
+                    self, flush_context, mapper, mappings)
             transaction.commit()
         except:
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index e941bc1a47..fe8ccd222b 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -580,21 +580,6 @@ class InstanceState(interfaces.InspectionAttr):
             state._strong_obj = None
 
 
-class MappingState(InstanceState):
-    committed_state = {}
-    callables = {}
-
-    def __init__(self, mapper, mapping):
-        self.class_ = mapper.class_
-        self.manager = mapper.class_manager
-        self.modified = True
-        self._dict = mapping
-
-    @property
-    def dict(self):
-        return self._dict
-
-
 class AttributeState(object):
     """Provide an inspection interface corresponding
     to a particular attribute on a particular mapped object.
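With ``MappingState`` removed, ``bulk_save_objects()`` now works purely at the
Session level: ``itertools.groupby()`` walks the instance states in order, and
each contiguous run sharing a mapper and a "has an identity key" flag is
dispatched as one bulk insert or one bulk update.  A standalone sketch of that
partitioning, using a hypothetical ``FakeState`` stand-in for real instance
states::

    import itertools

    class FakeState(object):
        def __init__(self, mapper, key):
            self.mapper = mapper
            self.key = key

    states = [
        FakeState("Customer", None),                 # pending -> insert
        FakeState("Customer", None),
        FakeState("Customer", ("Customer", (5,))),   # has a key -> update
    ]

    for (mapper, isupdate), run in itertools.groupby(
            states, lambda s: (s.mapper, s.key is not None)):
        print(mapper, "update" if isupdate else "insert", len(list(run)))
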
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index bc8a0f5565..b3a1519c52 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -394,23 +394,11 @@ class UOWTransaction(object):
                 if other:
                     self.session._register_newly_persistent(other)
 
-    def bulk_save(self, states):
-        for (base_mapper, in_session), states_ in itertools.groupby(
-            states,
-            lambda state:
-            (
-                state.mapper.base_mapper,
-                state.key is self.session.hash_key
-            )):
-
-            persistence.save_obj(
-                base_mapper, list(states_), self, bookkeeping=in_session)
-
-            if in_session:
-                self.states.update(
-                    (state, (False, False))
-                    for state in states_
-                )
+    def bulk_insert(self, mapper, mappings):
+        persistence.bulk_insert(mapper, mappings, self)
+
+    def bulk_update(self, mapper, mappings):
+        persistence.bulk_update(mapper, mappings, self)
 
 
 class IterateMappersMixin(object):
-- 
2.47.3
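
For reference, a self-contained sketch of how the reworked API is called; the
``Customer`` model and in-memory SQLite URL are illustrative only, echoing the
faq.rst benchmark script above::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Customer(Base):
        __tablename__ = "customer"
        id = Column(Integer, primary_key=True)
        name = Column(String(255))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    # plain dictionaries; one executemany() per mapped table
    session.bulk_insert_mappings(
        Customer,
        [dict(name="NAME %d" % i) for i in range(10000)])

    # each mapping must include the primary key; the remaining
    # keys become the SET clause
    session.bulk_update_mappings(
        Customer,
        [dict(id=1, name="name 1 UPDATED"),
         dict(id=2, name="name 2 UPDATED")])

    # ORM instances still work; contiguous runs are routed to a bulk
    # insert or a bulk update depending on whether an identity key is set
    session.bulk_save_objects([Customer(name="one more")])
    session.commit()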