]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved "post update" functionality from _save_obj() into
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Jun 2010 23:12:52 +0000 (19:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Jun 2010 23:12:52 +0000 (19:12 -0400)
its own method, which also groups updates into executemanys.
[ticket:1831]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_cycles.py

diff --git a/CHANGES b/CHANGES
index 15c3a435999a7e54457eda5f4e237de94d3475cf..a1013546ee0ae29dd7f1a9762b1000012ea2e311 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -19,7 +19,9 @@ CHANGES
     calls, each affecting different foreign key
     columns of the same row, are executed in a single
     UPDATE statement, rather than one UPDATE 
-    statement per column per row.
+    statement per column per row.   Multiple row
+    updates are also batched into executemany()s as 
+    possible, while maintaining consistent row ordering.
     
   - Query.statement, Query.subquery(), etc. now transfer
     the values of bind parameters, i.e. those specified
index f35a4d76cae9964abd07327cdc43c2b74a0e692d..935cbc35d79539a417ceb6efdc7081ab1ad6c1f7 100644 (file)
@@ -17,7 +17,7 @@ available in :class:`~sqlalchemy.orm.`.
 import types
 import weakref
 import operator
-from itertools import chain
+from itertools import chain, groupby
 deque = __import__('collections').deque
 
 from sqlalchemy import sql, util, log, exc as sa_exc
@@ -1450,8 +1450,96 @@ class Mapper(object):
             self._memoized_values[key] = value = callable_()
             return value
     
-    def _save_obj(self, states, uowtransaction, postupdate=False, 
-                                post_update_cols=None, single=False):
+    def _post_update(self, states, uowtransaction, post_update_cols):
+        """Issue UPDATE statements on behalf of a relationship() which
+        specifies post_update.
+        
+        """
+        cached_connections = util.PopulateDict(
+            lambda conn:conn.execution_options(
+            compiled_cache=self._compiled_cache
+        ))
+
+        # if session has a connection callable,
+        # organize individual states with the connection 
+        # to use for update
+        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+            connection_callable = \
+                    uowtransaction.mapper_flush_opts['connection_callable']
+        else:
+            connection = uowtransaction.transaction.connection(self)
+            connection_callable = None
+
+        tups = []
+        for state in _sort_states(states):
+            if connection_callable:
+                conn = connection_callable(self, state.obj())
+            else:
+                conn = connection
+
+            mapper = _state_mapper(state)
+                
+            tups.append((state, state.dict, mapper, conn))
+
+        table_to_mapper = self._sorted_tables
+
+        for table in table_to_mapper:
+            update = []
+
+            for state, state_dict, mapper, connection in tups:
+                if table not in mapper._pks_by_table:
+                    continue
+                    
+                pks = mapper._pks_by_table[table]
+                params = {}
+                hasdata = False
+
+                for col in mapper._cols_by_table[table]:
+                    if col in pks:
+                        params[col._label] = \
+                                mapper._get_state_attr_by_column(
+                                                state,
+                                                state_dict, col)
+                    elif col in post_update_cols:
+                        prop = mapper._columntoproperty[col]
+                        history = attributes.get_state_history(
+                                        state, prop.key, passive=True)
+                        if history.added:
+                            params[col.key] = \
+                                prop.get_col_value(col,
+                                                    history.added[0])
+                            hasdata = True
+                if hasdata:
+                    update.append((state, state_dict, params, mapper, 
+                                    connection))
+        
+            if update:
+                mapper = table_to_mapper[table]
+
+                def update_stmt():
+                    clause = sql.and_()
+
+                    for col in mapper._pks_by_table[table]:
+                        clause.clauses.append(col == sql.bindparam(col._label,
+                                                        type_=col.type))
+
+                    return table.update(clause)
+
+                statement = self._memo(('post_update', table), update_stmt)
+
+                # execute each UPDATE in the order according to the original
+                # list of states to guarantee row access order, but
+                # also group them into common (connection, cols) sets 
+                # to support executemany().
+                for key, grouper in groupby(
+                    update, lambda rec: (rec[4], rec[2].keys())
+                ):
+                    multiparams = [params for state, state_dict, 
+                                            params, mapper, conn in grouper]
+                    cached_connections[connection].\
+                                        execute(statement, multiparams)
+                
+    def _save_obj(self, states, uowtransaction, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list 
         of objects.
 
@@ -1466,21 +1554,15 @@ class Mapper(object):
         mapper which does not inherit from any other mapper.
         
         """
+            
         # if batch=false, call _save_obj separately for each object
         if not single and not self.batch:
             for state in _sort_states(states):
                 self._save_obj([state], 
                                 uowtransaction, 
-                                postupdate=postupdate, 
-                                post_update_cols=post_update_cols, 
                                 single=True)
             return
 
-        cached_connections = util.PopulateDict(
-            lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache
-        ))
-
         # if session has a connection callable,
         # organize individual states with the connection 
         # to use for insert/update
@@ -1504,53 +1586,53 @@ class Mapper(object):
             instance_key = state.key or mapper._identity_key_from_state(state)
 
             row_switch = None
-            if not postupdate:
-                # call before_XXX extensions
-                if not has_identity:
-                    if 'before_insert' in mapper.extension:
-                        mapper.extension.before_insert(
-                                            mapper, conn, state.obj())
-                else:
-                    if 'before_update' in mapper.extension:
-                        mapper.extension.before_update(
-                                            mapper, conn, state.obj())
-
-                # detect if we have a "pending" instance (i.e. has 
-                # no instance_key attached to it), and another instance 
-                # with the same identity key already exists as persistent. 
-                # convert to an UPDATE if so.
-                if not has_identity and \
-                    instance_key in uowtransaction.session.identity_map:
-                    instance = \
-                        uowtransaction.session.identity_map[instance_key]
-                    existing = attributes.instance_state(instance)
-                    if not uowtransaction.is_deleted(existing):
-                        raise orm_exc.FlushError(
-                            "New instance %s with identity key %s conflicts "
-                            "with persistent instance %s" % 
-                            (state_str(state), instance_key,
-                             state_str(existing)))
-
-                    self._log_debug(
-                        "detected row switch for identity %s.  "
-                        "will update %s, remove %s from "
-                        "transaction", instance_key, 
-                        state_str(state), state_str(existing))
-
-                    # remove the "delete" flag from the existing element
-                    uowtransaction.remove_state_actions(existing)
-                    row_switch = existing
+            # call before_XXX extensions
+            if not has_identity:
+                if 'before_insert' in mapper.extension:
+                    mapper.extension.before_insert(
+                                        mapper, conn, state.obj())
+            else:
+                if 'before_update' in mapper.extension:
+                    mapper.extension.before_update(
+                                        mapper, conn, state.obj())
+
+            # detect if we have a "pending" instance (i.e. has 
+            # no instance_key attached to it), and another instance 
+            # with the same identity key already exists as persistent. 
+            # convert to an UPDATE if so.
+            if not has_identity and \
+                instance_key in uowtransaction.session.identity_map:
+                instance = \
+                    uowtransaction.session.identity_map[instance_key]
+                existing = attributes.instance_state(instance)
+                if not uowtransaction.is_deleted(existing):
+                    raise orm_exc.FlushError(
+                        "New instance %s with identity key %s conflicts "
+                        "with persistent instance %s" % 
+                        (state_str(state), instance_key,
+                         state_str(existing)))
+
+                self._log_debug(
+                    "detected row switch for identity %s.  "
+                    "will update %s, remove %s from "
+                    "transaction", instance_key, 
+                    state_str(state), state_str(existing))
+
+                # remove the "delete" flag from the existing element
+                uowtransaction.remove_state_actions(existing)
+                row_switch = existing
 
             tups.append(
-                (state,
-                    state.dict,
-                    mapper,
-                    conn,
-                    has_identity,
-                    instance_key,
-                    row_switch)
+                (state, state.dict, mapper, conn, 
+                has_identity, instance_key, row_switch)
             )
 
+        # dictionary of connection->connection_with_cache_options.
+        cached_connections = util.PopulateDict(
+            lambda conn:conn.execution_options(
+            compiled_cache=self._compiled_cache
+        ))
+
         table_to_mapper = self._sorted_tables
 
         for table in table_to_mapper:
@@ -1565,7 +1647,6 @@ class Mapper(object):
                 pks = mapper._pks_by_table[table]
                 
                 isinsert = not has_identity and \
-                                not postupdate and \
                                 not row_switch
                 
                 params = {}
@@ -1630,14 +1711,6 @@ class Mapper(object):
                                 col not in pks:
                             pass
                         else:
-                            if post_update_cols is not None and \
-                                col not in post_update_cols:
-                                if col in pks:
-                                    params[col._label] = mapper.\
-                                                _get_state_attr_by_column(
-                                                    state, state_dict, col)
-                                continue
-
                             prop = mapper._columntoproperty[col]
                             history = attributes.get_state_history(
                                             state, prop.key, passive=True)
@@ -1691,7 +1764,6 @@ class Mapper(object):
                                         connection, value_params))
 
             if update:
-                
                 mapper = table_to_mapper[table]
 
                 needs_version_id = mapper.version_id_col is not None and \
@@ -1701,10 +1773,8 @@ class Mapper(object):
                     clause = sql.and_()
                 
                     for col in mapper._pks_by_table[table]:
-                        clause.clauses.append(
-                                            col ==  sql.bindparam(col._label,
-                                                        type_=col.type)
-                                        )
+                        clause.clauses.append(col == sql.bindparam(col._label,
+                                                        type_=col.type))
 
                     if needs_version_id:
                         clause.clauses.append(mapper.version_id_col ==\
@@ -1777,40 +1847,40 @@ class Mapper(object):
                                         c.last_inserted_params(),
                                         value_params)
 
-        if not postupdate:
-            for state, state_dict, mapper, connection, has_identity, \
-                            instance_key, row_switch in tups:
+        for state, state_dict, mapper, connection, has_identity, \
+                        instance_key, row_switch in tups:
 
-                # expire readonly attributes
-                readonly = state.unmodified.intersection(
-                    p.key for p in mapper._readonly_props
-                )
-                
-                if readonly:
-                    _expire_state(state, state.dict, readonly)
-
-                # if eager_defaults option is enabled,
-                # refresh whatever has been expired.
-                if self.eager_defaults and state.unloaded:
-                    state.key = self._identity_key_from_state(state)
-                    uowtransaction.session.query(self)._get(
-                        state.key, refresh_state=state,
-                        only_load_props=state.unloaded)
-
-                # call after_XXX extensions
-                if not has_identity:
-                    if 'after_insert' in mapper.extension:
-                        mapper.extension.after_insert(
-                                            mapper, connection, state.obj())
-                else:
-                    if 'after_update' in mapper.extension:
-                        mapper.extension.after_update(
-                                            mapper, connection, state.obj())
+            # expire readonly attributes
+            readonly = state.unmodified.intersection(
+                p.key for p in mapper._readonly_props
+            )
+            
+            if readonly:
+                _expire_state(state, state.dict, readonly)
+
+            # if eager_defaults option is enabled,
+            # refresh whatever has been expired.
+            if self.eager_defaults and state.unloaded:
+                state.key = self._identity_key_from_state(state)
+                uowtransaction.session.query(self)._get(
+                    state.key, refresh_state=state,
+                    only_load_props=state.unloaded)
+
+            # call after_XXX extensions
+            if not has_identity:
+                if 'after_insert' in mapper.extension:
+                    mapper.extension.after_insert(
+                                        mapper, connection, state.obj())
+            else:
+                if 'after_update' in mapper.extension:
+                    mapper.extension.after_update(
+                                        mapper, connection, state.obj())
 
     def _postfetch(self, uowtransaction, table, 
                         state, dict_, resultproxy, 
                         params, value_params):
-        """Expire attributes in need of newly persisted database state."""
+        """During a flush, expire attributes in need of newly 
+        persisted database state."""
 
         postfetch_cols = resultproxy.postfetch_cols()
         generated_cols = list(resultproxy.prefetch_cols())
index 2a6f439bcabd3a6090d19f74fb76135b3f143968..e108919243dc9e53c395b5e647e950b62353a09a 100644 (file)
@@ -430,9 +430,7 @@ class IssuePostUpdate(PostSortRec):
         states, cols = uow.post_update_states[self.mapper]
         states = [s for s in states if uow.states[s][0] == self.isdelete]
         
-        self.mapper._save_obj(states, uow, \
-                            postupdate=True, \
-                            post_update_cols=cols)
+        self.mapper._post_update(states, uow, cols)
 
 class SaveUpdateAll(PostSortRec):
     def __init__(self, uow, mapper):
index 896f7f07d5242f36be468781e7c6a1800640698f..26765244a7e0ce1975ce5a7e9c25e0cb03cc94f2 100644 (file)
@@ -713,44 +713,30 @@ class OneToManyManyToOneTest(_base.MappedTest):
              "VALUES (:favorite_ball_id, :data)",
              lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}),
 
-            AllOf(
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b4.id})
+             CompiledSQL("UPDATE ball SET person_id=:person_id "
+              "WHERE ball.id = :ball_id",
+              lambda ctx:[
+                {'person_id':p.id,'ball_id':b.id},
+                {'person_id':p.id,'ball_id':b2.id},
+                {'person_id':p.id,'ball_id':b3.id},
+                {'person_id':p.id,'ball_id':b4.id}
+                ]
             ),
+
         )
         
         sess.delete(p)
         
         self.assert_sql_execution(testing.db, sess.flush, 
-            AllOf(CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id': None, 'ball_id': b.id}),
-
             CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id': None, 'ball_id': b2.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id': None, 'ball_id': b3.id}),
-
-            CompiledSQL("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id': None, 'ball_id': b4.id})),
-
+                "WHERE ball.id = :ball_id",
+                lambda ctx:[
+                    {'person_id': None, 'ball_id': b.id},
+                    {'person_id': None, 'ball_id': b2.id},
+                    {'person_id': None, 'ball_id': b3.id},
+                    {'person_id': None, 'ball_id': b4.id}
+                ]
+            ),
             CompiledSQL("DELETE FROM person WHERE person.id = :id",
              lambda ctx:[{'id':p.id}]),
 
@@ -874,16 +860,16 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
         )
 
         session.delete(root)
+
         self.assert_sql_execution(
             testing.db, 
             session.flush,
-            AllOf(
-                CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
-                            "WHERE node.id = :node_id", 
-                            lambda ctx:{'next_sibling_id':None, 'node_id':about.id}),
-                CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
-                            "WHERE node.id = :node_id",
-                            lambda ctx:{'node_id':stories.id, 'next_sibling_id':None})
+            CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
+                "WHERE node.id = :node_id", 
+                lambda ctx: [
+                            {'node_id': about.id, 'next_sibling_id': None}, 
+                            {'node_id': stories.id, 'next_sibling_id': None}
+                        ]
             ),
             AllOf(
                 CompiledSQL("DELETE FROM node WHERE node.id = :id",