]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
un-stupified insert/update/delete sorting
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Sep 2008 00:04:38 +0000 (00:04 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Sep 2008 00:04:38 +0000 (00:04 +0000)
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/uowdumper.py

index 17fea7854f6a35a6df12e92f40ff527abdfabbbb..296c019a7f35b280b14079e860bf2708defd9784 100644 (file)
@@ -781,7 +781,8 @@ class InstanceState(object):
     key = None
     runid = None
     expired_attributes = EMPTY_SET
-
+    insert_order = None
+    
     def __init__(self, obj, manager):
         self.class_ = obj.__class__
         self.manager = manager
@@ -797,6 +798,10 @@ class InstanceState(object):
     def dispose(self):
         del self.session_id
 
+    @property
+    def sort_key(self):
+        return self.key and self.key[1] or self.insert_order
+        
     def check_modified(self):
         if self.modified:
             return True
index 21cbe3f2b8d1ee1dc7cc8dc5101211559433f85b..52b90d22a7adbd8a670b4971e58916148c783fd8 100644 (file)
@@ -1027,7 +1027,7 @@ class Mapper(object):
 
     def _get_committed_state_attr_by_column(self, state, column, passive=False):
         return self._get_col_to_prop(column).getcommitted(state, column, passive=passive)
-
+    
     def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
 
@@ -1047,9 +1047,7 @@ class Mapper(object):
 
         # if batch=false, call _save_obj separately for each object
         if not single and not self.batch:
-            def comparator(a, b):
-                return cmp(getattr(a, 'insert_order', 0), getattr(b, 'insert_order', 0))
-            for state in sorted(states, comparator):
+            for state in _sort_states(states):
                 self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
             return
 
@@ -1057,10 +1055,10 @@ class Mapper(object):
         # organize individual states with the connection to use for insert/update
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in states]
+            tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)]
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in states]
+            tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)]
 
         if not postupdate:
             # call before_XXX extensions
@@ -1185,20 +1183,11 @@ class Mapper(object):
                     clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type))
 
                 statement = table.update(clause)
-                pks = mapper._pks_by_table[table]
-                def comparator(a, b):
-                    for col in pks:
-                        x = cmp(a[1][col._label], b[1][col._label])
-                        if x != 0:
-                            return x
-                    return 0
-                update.sort(comparator)
-
                 rows = 0
                 for rec in update:
                     (state, params, mapper, connection, value_params) = rec
                     c = connection.execute(statement.values(value_params), params)
-                    mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
+                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
 
                     # testlib.pragma exempt:__hash__
                     updated_objects.add((state, connection))
@@ -1209,9 +1198,6 @@ class Mapper(object):
 
             if insert:
                 statement = table.insert()
-                def comparator(a, b):
-                    return cmp(a[0].insert_order, b[0].insert_order)
-                insert.sort(comparator)
                 for rec in insert:
                     (state, params, mapper, connection, value_params) = rec
                     c = connection.execute(statement.values(value_params), params)
@@ -1222,7 +1208,7 @@ class Mapper(object):
                         for i, col in enumerate(mapper._pks_by_table[table]):
                             if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i:
                                 mapper._set_state_attr_by_column(state, col, primary_key[i])
-                    mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
+                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
 
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this performs some unnecessary attribute transfers
@@ -1263,7 +1249,7 @@ class Mapper(object):
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
-    def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
+    def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
         """For a given Table that has just been inserted/updated,
         mark as 'expired' those attributes which correspond to columns
         that are marked as 'postfetch', and populate attributes which
@@ -1303,10 +1289,10 @@ class Mapper(object):
 
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in states]
+            tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)]
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, _state_mapper(state), connection) for state in states]
+            tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
 
         for state, mapper, connection in tups:
             if 'before_delete' in mapper.extension.methods:
@@ -1335,13 +1321,6 @@ class Mapper(object):
 
             for connection, del_objects in delete.iteritems():
                 mapper = table_to_mapper[table]
-                def comparator(a, b):
-                    for col in mapper._pks_by_table[table]:
-                        x = cmp(a[col.key], b[col.key])
-                        if x != 0:
-                            return x
-                    return 0
-                del_objects.sort(comparator)
                 clause = sql.and_()
                 for col in mapper._pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col.key, type_=col.type))
@@ -1694,6 +1673,8 @@ def _event_on_init_failure(state, instance, args, kwargs):
             instrumenting_mapper, instrumenting_mapper.class_,
             state.manager.events.original_init, instance, args, kwargs)
 
+def _sort_states(states):
+    return sorted(states, lambda a, b:cmp(a.sort_key, b.sort_key))
 
 def _load_scalar_attributes(state, attribute_names):
     mapper = _state_mapper(state)
index 66009be01cfc96670c9550b3097352d827890fce..ad987430a57c8f7925d02e69bcece337c80d63f0 100644 (file)
@@ -1039,9 +1039,6 @@ class Session(object):
             self.identity_map.remove(state)
             state.key = instance_key
 
-        if hasattr(state, 'insert_order'):
-            delattr(state, 'insert_order')
-
         obj = state.obj()
         # prevent against last minute dereferences of the object
         # TODO: identify a code path where state.obj() is None
index a46f90563a7e91c2c2d23d2e20541c84b630aca9..9ae7073b9fc5e7c5611de7f12e60532922edabe9 100644 (file)
@@ -45,25 +45,7 @@ class UOWDumper(unitofwork.UOWExecutor):
 
 
     def save_objects(self, trans, task):
-        # sort elements to be inserted by insert order
-        def comparator(a, b):
-            if a.state is None:
-                x = None
-            elif not hasattr(a.state, 'insert_order'):
-                x = None
-            else:
-                x = a.state.insert_order
-            if b.state is None:
-                y = None
-            elif not hasattr(b.state, 'insert_order'):
-                y = None
-            else:
-                y = b.state.insert_order
-            return cmp(x, y)
-
-        l = list(task.polymorphic_tosave_elements)
-        l.sort(comparator)
-        for rec in l:
+        for rec in sorted(task.polymorphic_tosave_elements, lambda a, b:cmp(a.state.sort_key, b.state.sort_key)):
             if rec.listonly:
                 continue
             self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec)  + "\n")