From: Mike Bayer Date: Fri, 19 Sep 2008 00:04:38 +0000 (+0000) Subject: un-stupified insert/update/delete sorting X-Git-Tag: rel_0_5rc2~40 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2c2ecbae867801c66b57770d5f7501bd4c0c3474;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git un-stupified insert/update/delete sorting --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 17fea7854f..296c019a7f 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -781,7 +781,8 @@ class InstanceState(object): key = None runid = None expired_attributes = EMPTY_SET - + insert_order = None + def __init__(self, obj, manager): self.class_ = obj.__class__ self.manager = manager @@ -797,6 +798,10 @@ class InstanceState(object): def dispose(self): del self.session_id + @property + def sort_key(self): + return self.key and self.key[1] or self.insert_order + def check_modified(self): if self.modified: return True diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 21cbe3f2b8..52b90d22a7 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1027,7 +1027,7 @@ class Mapper(object): def _get_committed_state_attr_by_column(self, state, column, passive=False): return self._get_col_to_prop(column).getcommitted(state, column, passive=passive) - + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. @@ -1047,9 +1047,7 @@ class Mapper(object): # if batch=false, call _save_obj separately for each object if not single and not self.batch: - def comparator(a, b): - return cmp(getattr(a, 'insert_order', 0), getattr(b, 'insert_order', 0)) - for state in sorted(states, comparator): + for state in _sort_states(states): self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return @@ -1057,10 +1055,10 @@ class Mapper(object): # organize individual states with the connection to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in states] + tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in states] + tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)] if not postupdate: # call before_XXX extensions @@ -1185,20 +1183,11 @@ class Mapper(object): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type)) statement = table.update(clause) - pks = mapper._pks_by_table[table] - def comparator(a, b): - for col in pks: - x = cmp(a[1][col._label], b[1][col._label]) - if x != 0: - return x - return 0 - update.sort(comparator) - rows = 0 for rec in update: (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) - mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params) + mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params) # testlib.pragma exempt:__hash__ updated_objects.add((state, connection)) @@ -1209,9 +1198,6 @@ class Mapper(object): if insert: statement = table.insert() - def comparator(a, b): - return cmp(a[0].insert_order, b[0].insert_order) - insert.sort(comparator) for rec in insert: (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) @@ -1222,7 +1208,7 @@ class Mapper(object): for i, col in enumerate(mapper._pks_by_table[table]): if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i: mapper._set_state_attr_by_column(state, col, primary_key[i]) - mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params) + mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this performs some unnecessary attribute transfers @@ -1263,7 +1249,7 @@ class Mapper(object): if 'after_update' in mapper.extension.methods: mapper.extension.after_update(mapper, connection, state.obj()) - def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): + def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): """For a given Table that has just been inserted/updated, mark as 'expired' those attributes which correspond to columns that are marked as 'postfetch', and populate attributes which @@ -1303,10 +1289,10 @@ class Mapper(object): if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in states] + tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection) for state in states] + tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)] for state, mapper, connection in tups: if 'before_delete' in mapper.extension.methods: @@ -1335,13 +1321,6 @@ class Mapper(object): for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] - def comparator(a, b): - for col in mapper._pks_by_table[table]: - x = cmp(a[col.key], b[col.key]) - if x != 0: - return x - return 0 - del_objects.sort(comparator) clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) @@ -1694,6 +1673,8 @@ def _event_on_init_failure(state, instance, args, kwargs): instrumenting_mapper, instrumenting_mapper.class_, state.manager.events.original_init, instance, args, kwargs) +def _sort_states(states): + return sorted(states, lambda a, b:cmp(a.sort_key, b.sort_key)) def _load_scalar_attributes(state, attribute_names): mapper = _state_mapper(state) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 66009be01c..ad987430a5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1039,9 +1039,6 @@ class Session(object): self.identity_map.remove(state) state.key = instance_key - if hasattr(state, 'insert_order'): - delattr(state, 'insert_order') - obj = state.obj() # prevent against last minute dereferences of the object # TODO: identify a code path where state.obj() is None diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py index a46f90563a..9ae7073b9f 100644 --- a/lib/sqlalchemy/orm/uowdumper.py +++ b/lib/sqlalchemy/orm/uowdumper.py @@ -45,25 +45,7 @@ class UOWDumper(unitofwork.UOWExecutor): def save_objects(self, trans, task): - # sort elements to be inserted by insert order - def comparator(a, b): - if a.state is None: - x = None - elif not hasattr(a.state, 'insert_order'): - x = None - else: - x = a.state.insert_order - if b.state is None: - y = None - elif not hasattr(b.state, 'insert_order'): - y = None - else: - y = b.state.insert_order - return cmp(x, y) - - l = list(task.polymorphic_tosave_elements) - l.sort(comparator) - for rec in l: + for rec in sorted(task.polymorphic_tosave_elements, lambda a, b:cmp(a.state.sort_key, b.state.sort_key)): if rec.listonly: continue self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n")