From: Mike Bayer Date: Tue, 19 Aug 2014 18:24:56 +0000 (-0400) Subject: - refinements X-Git-Tag: rel_1_0_0b1~213 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=91959122e0a12943e5ff9399024c65ad4d7489e1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - refinements --- diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 37ea3071bc..aa99673bad 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1453,18 +1453,6 @@ class SessionEvents(event.Events): """ - def before_bulk_insert(self, session, flush_context, mapper, mappings): - """""" - - def after_bulk_insert(self, session, flush_context, mapper, mappings): - """""" - - def before_bulk_update(self, session, flush_context, mapper, mappings): - """""" - - def after_bulk_update(self, session, flush_context, mapper, mappings): - """""" - def after_begin(self, session, transaction, connection): """Execute after a transaction is begun on a connection diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 89c092b580..b98fbda420 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2366,6 +2366,10 @@ class Mapper(InspectionAttr): def _primary_key_props(self): return [self._columntoproperty[col] for col in self.primary_key] + @_memoized_configured_property + def _primary_key_propkeys(self): + return set([prop.key for prop in self._primary_key_props]) + def _get_state_attr_by_column( self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NEVER_SET): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 145a7783a3..9c00089254 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -23,17 +23,22 @@ from ..sql import expression from . import loading -def bulk_insert(mapper, mappings, uowtransaction): +def _bulk_insert(mapper, mappings, session_transaction, isstates): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) - if uowtransaction.session.connection_callable: + if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_insert()") - connection = uowtransaction.transaction.connection(base_mapper) + if isstates: + mappings = [state.dict for state in mappings] + else: + mappings = list(mappings) + + connection = session_transaction.connection(base_mapper) for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue @@ -45,61 +50,55 @@ def bulk_insert(mapper, mappings, uowtransaction): state, state_dict, params, mp, conn, value_params, has_all_pks, has_all_defaults in _collect_insert_commands(table, ( - (None, mapping, super_mapper, connection) - for mapping in mappings) + (None, mapping, mapper, connection) + for mapping in mappings), + bulk=True ) ) - _emit_insert_statements(base_mapper, uowtransaction, + _emit_insert_statements(base_mapper, None, cached_connections, super_mapper, table, records, bookkeeping=False) -def bulk_update(mapper, mappings, uowtransaction): +def _bulk_update(mapper, mappings, session_transaction, isstates): base_mapper = mapper.base_mapper cached_connections = _cached_connection_dict(base_mapper) - if uowtransaction.session.connection_callable: + def _changed_dict(mapper, state): + return dict( + (k, v) + for k, v in state.dict.items() if k in state.committed_state or k + in mapper._primary_key_propkeys + ) + + if isstates: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = list(mappings) + + if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " "not supported in bulk_update()") - connection = uowtransaction.transaction.connection(base_mapper) + connection = session_transaction.connection(base_mapper) value_params = {} + for table, super_mapper in base_mapper._sorted_tables.items(): if not mapper.isa(super_mapper): continue - label_pks = super_mapper._pks_by_table[table] - if mapper.version_id_col is not None: - label_pks = label_pks.union([mapper.version_id_col]) - - to_translate = dict( - (propkey, col._label if col in label_pks else col.key) - for propkey, col in super_mapper._propkey_to_col[table].items() + records = ( + (None, None, params, super_mapper, connection, value_params) + for + params in _collect_bulk_update_commands(mapper, table, mappings) ) - records = [] - for mapping in mappings: - params = dict( - (to_translate[k], v) for k, v in mapping.items() - ) - - if mapper.version_id_generator is not False and \ - mapper.version_id_col is not None and \ - mapper.version_id_col.key not in params: - params[mapper.version_id_col.key] = \ - mapper.version_id_generator( - params[mapper.version_id_col._label]) - - records.append( - (None, None, params, super_mapper, connection, value_params) - ) - - _emit_update_statements(base_mapper, uowtransaction, + _emit_update_statements(base_mapper, None, cached_connections, super_mapper, table, records, bookkeeping=False) @@ -360,7 +359,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False): col = propkey_to_col[propkey] if value is None: continue - elif isinstance(value, sql.ClauseElement): + elif not bulk and isinstance(value, sql.ClauseElement): value_params[col.key] = value else: params[col.key] = value @@ -481,6 +480,44 @@ def _collect_update_commands(uowtransaction, table, states_to_update): state, state_dict, params, mapper, connection, value_params) +def _collect_bulk_update_commands(mapper, table, mappings): + label_pks = mapper._pks_by_table[table] + if mapper.version_id_col is not None: + label_pks = label_pks.union([mapper.version_id_col]) + + to_translate = dict( + (propkey, col.key if col not in label_pks else col._label) + for propkey, col in mapper._propkey_to_col[table].items() + ) + + for mapping in mappings: + params = dict( + (to_translate[k], mapping[k]) for k in to_translate + if k in mapping and k not in mapper._primary_key_propkeys + ) + + if not params: + continue + + try: + params.update( + (to_translate[k], mapping[k]) for k in + mapper._primary_key_propkeys.intersection(to_translate) + ) + except KeyError as ke: + raise orm_exc.FlushError( + "Can't update table using NULL for primary " + "key attribute: %s" % ke) + + if mapper.version_id_generator is not False and \ + mapper.version_id_col is not None and \ + mapper.version_id_col.key not in params: + params[mapper.version_id_col.key] = \ + mapper.version_id_generator( + params[mapper.version_id_col._label]) + + yield params + def _collect_post_update_commands(base_mapper, uowtransaction, table, states_to_update, post_update_cols): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 3199a4332d..968868e843 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -21,6 +21,7 @@ from .base import ( _none_set, state_str, instance_str ) import itertools +from . import persistence from .unitofwork import UOWTransaction from . import state as statelib import sys @@ -2040,37 +2041,27 @@ class Session(_SessionClassMethods): (attributes.instance_state(obj) for obj in objects), lambda state: (state.mapper, state.key is not None) ): - if isupdate: - self.bulk_update_mappings(mapper, (s.dict for s in states)) - else: - self.bulk_insert_mappings(mapper, (s.dict for s in states)) + self._bulk_save_mappings(mapper, states, isupdate, True) def bulk_insert_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, False) + self._bulk_save_mappings(mapper, mappings, False, False) def bulk_update_mappings(self, mapper, mappings): - self._bulk_save_mappings(mapper, mappings, True) + self._bulk_save_mappings(mapper, mappings, True, False) - def _bulk_save_mappings(self, mapper, mappings, isupdate): + def _bulk_save_mappings(self, mapper, mappings, isupdate, isstates): mapper = _class_to_mapper(mapper) self._flushing = True - flush_context = UOWTransaction(self) - flush_context.transaction = transaction = self.begin( + transaction = self.begin( subtransactions=True) try: if isupdate: - self.dispatch.before_bulk_update( - self, flush_context, mapper, mappings) - flush_context.bulk_update(mapper, mappings) - self.dispatch.after_bulk_update( - self, flush_context, mapper, mappings) + persistence._bulk_update( + mapper, mappings, transaction, isstates) else: - self.dispatch.before_bulk_insert( - self, flush_context, mapper, mappings) - flush_context.bulk_insert(mapper, mappings) - self.dispatch.after_bulk_insert( - self, flush_context, mapper, mappings) + persistence._bulk_insert( + mapper, mappings, transaction, isstates) transaction.commit() except: diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index b3a1519c52..05265b13fb 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -394,12 +394,6 @@ class UOWTransaction(object): if other: self.session._register_newly_persistent(other) - def bulk_insert(self, mapper, mappings): - persistence.bulk_insert(mapper, mappings, self) - - def bulk_update(self, mapper, mappings): - persistence.bulk_update(mapper, mappings, self) - class IterateMappersMixin(object): def _mappers(self, uow):