From: Mike Bayer
Date: Fri, 18 Jun 2010 23:12:52 +0000 (-0400)
Subject: - moved "post update" functionality from _save_obj() into
X-Git-Tag: rel_0_6_2~30
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fc7674bcc8b18b3e218c248d34d6755b7e14383d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

- moved "post update" functionality from _save_obj() into
  its own method, which also groups updates into executemanys.
  [ticket:1831]
---

diff --git a/CHANGES b/CHANGES
index 15c3a43599..a1013546ee 100644
--- a/CHANGES
+++ b/CHANGES
@@ -19,7 +19,9 @@ CHANGES
     calls, each affecting different foreign key
     columns of the same row, are executed in a single
     UPDATE statement, rather than one UPDATE
-    statement per column per row.
+    statement per column per row.  Multiple row
+    updates are also batched into executemany()s where
+    possible, while maintaining consistent row ordering.
 
   - Query.statement, Query.subquery(), etc. now transfer
     the values of bind parameters, i.e. those specified
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index f35a4d76ca..935cbc35d7 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -17,7 +17,7 @@ available in :class:`~sqlalchemy.orm.`.
 import types
 import weakref
 import operator
-from itertools import chain
+from itertools import chain, groupby
 deque = __import__('collections').deque
 
 from sqlalchemy import sql, util, log, exc as sa_exc
@@ -1450,8 +1450,96 @@ class Mapper(object):
             self._memoized_values[key] = value = callable_()
             return value
 
-    def _save_obj(self, states, uowtransaction, postupdate=False,
-                        post_update_cols=None, single=False):
+    def _post_update(self, states, uowtransaction, post_update_cols):
+        """Issue UPDATE statements on behalf of a relationship() which
+        specifies post_update.
+
+        """
+        cached_connections = util.PopulateDict(
+            lambda conn:conn.execution_options(
+            compiled_cache=self._compiled_cache
+        ))
+
+        # if session has a connection callable,
+        # organize individual states with the connection
+        # to use for update
+        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+            connection_callable = \
+                    uowtransaction.mapper_flush_opts['connection_callable']
+        else:
+            connection = uowtransaction.transaction.connection(self)
+            connection_callable = None
+
+        tups = []
+        for state in _sort_states(states):
+            if connection_callable:
+                conn = connection_callable(self, state.obj())
+            else:
+                conn = connection
+
+            mapper = _state_mapper(state)
+
+            tups.append((state, state.dict, mapper, conn))
+
+        table_to_mapper = self._sorted_tables
+
+        for table in table_to_mapper:
+            update = []
+
+            for state, state_dict, mapper, connection in tups:
+                if table not in mapper._pks_by_table:
+                    continue
+
+                pks = mapper._pks_by_table[table]
+                params = {}
+                hasdata = False
+
+                for col in mapper._cols_by_table[table]:
+                    if col in pks:
+                        params[col._label] = \
+                            mapper._get_state_attr_by_column(
+                                            state,
+                                            state_dict, col)
+                    elif col in post_update_cols:
+                        prop = mapper._columntoproperty[col]
+                        history = attributes.get_state_history(
+                                        state, prop.key, passive=True)
+                        if history.added:
+                            params[col.key] = \
+                                prop.get_col_value(col,
+                                        history.added[0])
+                            hasdata = True
+                if hasdata:
+                    update.append((state, state_dict, params, mapper,
+                                    connection))
+
+            if update:
+                mapper = table_to_mapper[table]
+
+                def update_stmt():
+                    clause = sql.and_()
+
+                    for col in mapper._pks_by_table[table]:
+                        clause.clauses.append(col == sql.bindparam(col._label,
+                                                        type_=col.type))
+
+                    return table.update(clause)
+
+                statement = self._memo(('post_update', table), update_stmt)
+
+                # execute each UPDATE in the order according to the original
+                # list of states to guarantee row access order, but
+                # also group them into common (connection, cols) sets
+                # to support executemany().
+                for key, grouper in groupby(
+                    update, lambda rec: (rec[4], rec[2].keys())
+                ):
+                    multiparams = [params for state, state_dict,
+                                    params, mapper, conn in grouper]
+                    cached_connections[connection].\
+                                        execute(statement, multiparams)
+
+    def _save_obj(self, states, uowtransaction, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
         of objects.
 
@@ -1466,21 +1554,15 @@ class Mapper(object):
         mapper which does not inherit from any other mapper.
 
         """
 
+        # if batch=false, call _save_obj separately for each object
         if not single and not self.batch:
             for state in _sort_states(states):
                 self._save_obj([state], uowtransaction,
-                                postupdate=postupdate,
-                                post_update_cols=post_update_cols,
                                 single=True)
             return
 
-        cached_connections = util.PopulateDict(
-            lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache
-        ))
-
         # if session has a connection callable,
         # organize individual states with the connection
         # to use for insert/update
@@ -1504,53 +1586,53 @@ class Mapper(object):
             instance_key = state.key or mapper._identity_key_from_state(state)
 
             row_switch = None
-            if not postupdate:
-                # call before_XXX extensions
-                if not has_identity:
-                    if 'before_insert' in mapper.extension:
-                        mapper.extension.before_insert(
-                                            mapper, conn, state.obj())
-                else:
-                    if 'before_update' in mapper.extension:
-                        mapper.extension.before_update(
-                                            mapper, conn, state.obj())
-
-                # detect if we have a "pending" instance (i.e. has
-                # no instance_key attached to it), and another instance
-                # with the same identity key already exists as persistent.
-                # convert to an UPDATE if so.
-                if not has_identity and \
-                    instance_key in uowtransaction.session.identity_map:
-                    instance = \
-                        uowtransaction.session.identity_map[instance_key]
-                    existing = attributes.instance_state(instance)
-                    if not uowtransaction.is_deleted(existing):
-                        raise orm_exc.FlushError(
-                            "New instance %s with identity key %s conflicts "
-                            "with persistent instance %s" %
-                            (state_str(state), instance_key,
-                             state_str(existing)))
-
-                    self._log_debug(
-                        "detected row switch for identity %s. "
-                        "will update %s, remove %s from "
-                        "transaction", instance_key,
-                        state_str(state), state_str(existing))
-
-                    # remove the "delete" flag from the existing element
-                    uowtransaction.remove_state_actions(existing)
-                    row_switch = existing
+            # call before_XXX extensions
+            if not has_identity:
+                if 'before_insert' in mapper.extension:
+                    mapper.extension.before_insert(
+                                        mapper, conn, state.obj())
+            else:
+                if 'before_update' in mapper.extension:
+                    mapper.extension.before_update(
+                                        mapper, conn, state.obj())
+
+            # detect if we have a "pending" instance (i.e. has
+            # no instance_key attached to it), and another instance
+            # with the same identity key already exists as persistent.
+            # convert to an UPDATE if so.
+            if not has_identity and \
+                instance_key in uowtransaction.session.identity_map:
+                instance = \
+                    uowtransaction.session.identity_map[instance_key]
+                existing = attributes.instance_state(instance)
+                if not uowtransaction.is_deleted(existing):
+                    raise orm_exc.FlushError(
+                        "New instance %s with identity key %s conflicts "
+                        "with persistent instance %s" %
+                        (state_str(state), instance_key,
+                         state_str(existing)))
+
+                self._log_debug(
+                    "detected row switch for identity %s. "
+                    "will update %s, remove %s from "
+                    "transaction", instance_key,
+                    state_str(state), state_str(existing))
+
+                # remove the "delete" flag from the existing element
+                uowtransaction.remove_state_actions(existing)
+                row_switch = existing
 
             tups.append(
-                (state,
-                 state.dict,
-                 mapper,
-                 conn,
-                 has_identity,
-                 instance_key,
-                 row_switch)
+                (state, state.dict, mapper, conn,
+                 has_identity, instance_key, row_switch)
             )
 
+        # dictionary of connection->connection_with_cache_options.
+
+        cached_connections = util.PopulateDict(
+            lambda conn:conn.execution_options(
+            compiled_cache=self._compiled_cache
+        ))
+
         table_to_mapper = self._sorted_tables
 
         for table in table_to_mapper:
@@ -1565,7 +1647,6 @@ class Mapper(object):
                 pks = mapper._pks_by_table[table]
 
                 isinsert = not has_identity and \
-                                not postupdate and \
                                 not row_switch
 
                 params = {}
@@ -1630,14 +1711,6 @@ class Mapper(object):
                                 col not in pks:
                             pass
                         else:
-                            if post_update_cols is not None and \
-                                col not in post_update_cols:
-                                if col in pks:
-                                    params[col._label] = mapper.\
-                                            _get_state_attr_by_column(
-                                                    state, state_dict, col)
-                                continue
-
                             prop = mapper._columntoproperty[col]
                             history = attributes.get_state_history(
                                             state, prop.key, passive=True)
@@ -1691,7 +1764,6 @@ class Mapper(object):
                                     connection, value_params))
 
             if update:
-                mapper = table_to_mapper[table]
 
                 needs_version_id = mapper.version_id_col is not None and \
@@ -1701,10 +1773,8 @@ class Mapper(object):
                     clause = sql.and_()
 
                     for col in mapper._pks_by_table[table]:
-                        clause.clauses.append(
-                                col == sql.bindparam(col._label,
-                                                type_=col.type)
-                        )
+                        clause.clauses.append(col == sql.bindparam(col._label,
+                                                        type_=col.type))
 
                     if needs_version_id:
                         clause.clauses.append(mapper.version_id_col ==\
@@ -1777,40 +1847,40 @@ class Mapper(object):
                                 c.last_inserted_params(),
                                 value_params)
 
-        if not postupdate:
-            for state, state_dict, mapper, connection, has_identity, \
-                            instance_key, row_switch in tups:
+        for state, state_dict, mapper, connection, has_identity, \
+                        instance_key, row_switch in tups:
 
-                # expire readonly attributes
-                readonly = state.unmodified.intersection(
-                    p.key for p in mapper._readonly_props
-                )
-
-                if readonly:
-                    _expire_state(state, state.dict, readonly)
-
-                # if eager_defaults option is enabled,
-                # refresh whatever has been expired.
-                if self.eager_defaults and state.unloaded:
-                    state.key = self._identity_key_from_state(state)
-                    uowtransaction.session.query(self)._get(
-                        state.key, refresh_state=state,
-                        only_load_props=state.unloaded)
-
-                # call after_XXX extensions
-                if not has_identity:
-                    if 'after_insert' in mapper.extension:
-                        mapper.extension.after_insert(
-                            mapper, connection, state.obj())
-                else:
-                    if 'after_update' in mapper.extension:
-                        mapper.extension.after_update(
-                            mapper, connection, state.obj())
+            # expire readonly attributes
+            readonly = state.unmodified.intersection(
+                p.key for p in mapper._readonly_props
+            )
+
+            if readonly:
+                _expire_state(state, state.dict, readonly)
+
+            # if eager_defaults option is enabled,
+            # refresh whatever has been expired.
+            if self.eager_defaults and state.unloaded:
+                state.key = self._identity_key_from_state(state)
+                uowtransaction.session.query(self)._get(
+                    state.key, refresh_state=state,
+                    only_load_props=state.unloaded)
+
+            # call after_XXX extensions
+            if not has_identity:
+                if 'after_insert' in mapper.extension:
+                    mapper.extension.after_insert(
+                        mapper, connection, state.obj())
+            else:
+                if 'after_update' in mapper.extension:
+                    mapper.extension.after_update(
+                        mapper, connection, state.obj())
 
     def _postfetch(self, uowtransaction, table,
                     state, dict_, resultproxy, params, value_params):
-        """Expire attributes in need of newly persisted database state."""
+        """During a flush, expire attributes in need of newly
+        persisted database state."""
 
         postfetch_cols = resultproxy.postfetch_cols()
         generated_cols = list(resultproxy.prefetch_cols())
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 2a6f439bca..e108919243 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -430,9 +430,7 @@ class IssuePostUpdate(PostSortRec):
         states, cols = uow.post_update_states[self.mapper]
         states = [s for s in states if uow.states[s][0] == self.isdelete]
 
-        self.mapper._save_obj(states, uow, \
-                                postupdate=True, \
-                                post_update_cols=cols)
+        self.mapper._post_update(states, uow, cols)
 
 class SaveUpdateAll(PostSortRec):
     def __init__(self, uow, mapper):
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index 896f7f07d5..26765244a7 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -713,44 +713,30 @@ class OneToManyManyToOneTest(_base.MappedTest):
                 "VALUES (:favorite_ball_id, :data)",
                 lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}),
-            AllOf(
-                CompiledSQL("UPDATE ball SET person_id=:person_id "
-                    "WHERE ball.id = :ball_id",
-                    lambda ctx:{'person_id':p.id,'ball_id':b.id}),
-                CompiledSQL("UPDATE ball SET person_id=:person_id "
-                    "WHERE ball.id = :ball_id",
-                    lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
-                CompiledSQL("UPDATE ball SET person_id=:person_id "
-                    "WHERE ball.id = :ball_id",
-                    lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
-                CompiledSQL("UPDATE ball SET person_id=:person_id "
-                    "WHERE ball.id = :ball_id",
-                    lambda ctx:{'person_id':p.id,'ball_id':b4.id})
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
+                "WHERE ball.id = :ball_id",
+                lambda ctx:[
+                    {'person_id':p.id,'ball_id':b.id},
+                    {'person_id':p.id,'ball_id':b2.id},
+                    {'person_id':p.id,'ball_id':b3.id},
+                    {'person_id':p.id,'ball_id':b4.id}
+                ]
             ),
+
         )
         sess.delete(p)
         self.assert_sql_execution(testing.db, sess.flush,
-            AllOf(CompiledSQL("UPDATE ball SET person_id=:person_id "
-                "WHERE ball.id = :ball_id",
-                lambda ctx:{'person_id': None, 'ball_id': b.id}),
             CompiledSQL("UPDATE ball SET person_id=:person_id "
-                "WHERE ball.id = :ball_id",
-                lambda ctx:{'person_id': None, 'ball_id': b2.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-                "WHERE ball.id = :ball_id",
-                lambda ctx:{'person_id': None, 'ball_id': b3.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-                "WHERE ball.id = :ball_id",
-                lambda ctx:{'person_id': None, 'ball_id': b4.id})),
-
+                "WHERE ball.id = :ball_id",
+                lambda ctx:[
+                    {'person_id': None, 'ball_id': b.id},
+                    {'person_id': None, 'ball_id': b2.id},
+                    {'person_id': None, 'ball_id': b3.id},
+                    {'person_id': None, 'ball_id': b4.id}
+                ]
+            ),
             CompiledSQL("DELETE FROM person WHERE person.id = :id",
                 lambda ctx:[{'id':p.id}]),
@@ -874,16 +860,16 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
         )
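
The grouping step inside the new Mapper._post_update() is easier to see in
isolation.  The following standalone sketch is illustrative only: the record
tuples, the placeholder values, and the "conn-A" name are invented stand-ins
for the (state, state_dict, params, mapper, connection) tuples the mapper
builds, and sorted() is substituted for the params.keys() comparison used in
the real code.  It shows how itertools.groupby() collapses adjacent per-row
parameter dictionaries into one executemany()-style list while leaving the
original row order untouched.

    # Illustrative sketch of the grouping idea in Mapper._post_update().
    # The records mimic (state, state_dict, params, mapper, connection)
    # tuples; the values are placeholders, not real ORM objects.
    from itertools import groupby

    update = [
        ("state-b1", {}, {"person_id": 7, "ball_id": 1}, "mapper", "conn-A"),
        ("state-b2", {}, {"person_id": 7, "ball_id": 2}, "mapper", "conn-A"),
        ("state-b3", {}, {"person_id": 7, "ball_id": 3}, "mapper", "conn-A"),
    ]

    # groupby() only merges *adjacent* records, so row order is preserved;
    # a batch lasts only while the connection (rec[4]) and the set of bound
    # columns (rec[2].keys()) stay the same.
    for (conn, cols), grouper in groupby(
            update, lambda rec: (rec[4], sorted(rec[2]))):
        multiparams = [rec[2] for rec in grouper]
        # the mapper would hand this list to connection.execute(statement, ...)
        print(conn, multiparams)

Because a change of connection or column set simply starts a new batch, the
"consistent row ordering" promise in the CHANGES entry holds even when not
every row can be folded into the same executemany() call.
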
         session.delete(root)
+
         self.assert_sql_execution(
             testing.db,
             session.flush,
-            AllOf(
-                CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
-                    "WHERE node.id = :node_id",
-                    lambda ctx:{'next_sibling_id':None, 'node_id':about.id}),
-                CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
-                    "WHERE node.id = :node_id",
-                    lambda ctx:{'node_id':stories.id, 'next_sibling_id':None})
+            CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
+                "WHERE node.id = :node_id",
+                lambda ctx: [
+                    {'node_id': about.id, 'next_sibling_id': None},
+                    {'node_id': stories.id, 'next_sibling_id': None}
+                ]
             ),
             AllOf(
                 CompiledSQL("DELETE FROM node WHERE node.id = :id",
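
For orientation, the kind of mapping that sends rows through
Mapper._post_update() looks like the hypothetical sketch below.  It is not
taken from the commit or its test suite; the Node class, table name, and
in-memory SQLite engine are invented, and the declarative style is only an
approximation of the 0.6-era API.  It shows why a relationship() configured
with post_update=True makes the ORM INSERT the rows first and set the cyclic
foreign key afterwards with UPDATE statements, the statements this change now
groups into executemany() batches when they share a table, column set, and
connection.

    # Hypothetical post_update mapping (illustrative only).
    from sqlalchemy import Column, Integer, ForeignKey, create_engine
    from sqlalchemy.orm import relationship, Session
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Node(Base):
        __tablename__ = 'node'
        id = Column(Integer, primary_key=True)
        next_sibling_id = Column(Integer, ForeignKey('node.id'))
        # post_update=True: INSERT the rows first, then issue a separate
        # UPDATE for next_sibling_id, breaking the row-level dependency cycle.
        next_sibling = relationship("Node", remote_side=[id], post_update=True)

    engine = create_engine("sqlite://", echo=True)
    Base.metadata.create_all(engine)

    session = Session(bind=engine)
    a, b = Node(), Node()
    a.next_sibling = b
    b.next_sibling = a      # mutual reference: neither row can be inserted "last"
    session.add_all([a, b])
    # flush emits two INSERTs, then the post-update UPDATEs for
    # next_sibling_id, which can now travel as a single executemany() batch.
    session.commit()
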