]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
cleanup and callcount reduction in mapper._save_obj, _delete_obj.
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Apr 2010 01:42:41 +0000 (21:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Apr 2010 01:42:41 +0000 (21:42 -0400)
includes an untested fix for [ticket:1761]

lib/sqlalchemy/orm/mapper.py

index 8f0f2128bd86ea0a4abdac6ca5d13440fdc5be2d..f83275b360341df9886e6e2a32f570d8594c9b21 100644 (file)
@@ -1224,11 +1224,13 @@ class Mapper(object):
             try:
                 if item_type == 'property':
                     prop = iterator.next()
-                    visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None))
+                    visitables.append((prop.cascade_iterator(type_, parent_state, 
+                                            visited_instances, halt_on), 'mapper', None))
                 elif item_type == 'mapper':
                     instance, instance_mapper, corresponding_state  = iterator.next()
                     yield (instance, instance_mapper)
-                    visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state))
+                    visitables.append((instance_mapper._props.itervalues(), 
+                                            'property', corresponding_state))
             except StopIteration:
                 visitables.pop()
 
@@ -1263,55 +1265,46 @@ class Mapper(object):
         # if batch=false, call _save_obj separately for each object
         if not single and not self.batch:
             for state in _sort_states(states):
-                self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
+                self._save_obj([state], 
+                                uowtransaction, 
+                                postupdate=postupdate, 
+                                post_update_cols=post_update_cols, 
+                                single=True)
             return
-
+        
         # if session has a connection callable,
-        # organize individual states with the connection to use for insert/update
-        tups = []
+        # organize individual states with the connection 
+        # to use for insert/update
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
-            connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            for state in _sort_states(states):
-                m = _state_mapper(state)
-                tups.append(
-                    (
-                        state, 
-                        m, 
-                        connection_callable(self, state.obj()), 
-                        _state_has_identity(state), 
-                        state.key or m._identity_key_from_state(state)
-                    )
-                )
+            connection_callable = \
+                    uowtransaction.mapper_flush_opts['connection_callable']
         else:
             connection = uowtransaction.transaction.connection(self)
-            for state in _sort_states(states):
-                m = _state_mapper(state)
-                tups.append(
-                    (
-                        state, 
-                        m, 
-                        connection,
-                        _state_has_identity(state), 
-                        state.key or m._identity_key_from_state(state)
-                    )
-                )
+            connection_callable = None
 
-        if not postupdate:
-            # call before_XXX extensions
-            for state, mapper, connection, has_identity, instance_key in tups:
+        tups = []
+        for state in _sort_states(states):
+            conn = connection_callable and \
+                connection_callable(self, state.obj()) or \
+                connection
+
+            has_identity = _state_has_identity(state)
+            mapper = _state_mapper(state)
+            instance_key = state.key or mapper._identity_key_from_state(state)
+
+            row_switch = None
+            if not postupdate:
+                # call before_XXX extensions
                 if not has_identity:
                     if 'before_insert' in mapper.extension:
-                        mapper.extension.before_insert(mapper, connection, state.obj())
+                        mapper.extension.before_insert(mapper, conn, state.obj())
                 else:
                     if 'before_update' in mapper.extension:
-                        mapper.extension.before_update(mapper, connection, state.obj())
+                        mapper.extension.before_update(mapper, conn, state.obj())
 
-        row_switches = {}
-        if not postupdate:
-            for state, mapper, connection, has_identity, instance_key in tups:
                 # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
-                # and another instance with the same identity key already exists as persistent.  convert to an
-                # UPDATE if so.
+                # and another instance with the same identity key already exists as persistent. 
+                # convert to an UPDATE if so.
                 if not has_identity and instance_key in uowtransaction.session.identity_map:
                     instance = uowtransaction.session.identity_map[instance_key]
                     existing = attributes.instance_state(instance)
@@ -1320,28 +1313,42 @@ class Mapper(object):
                             "New instance %s with identity key %s conflicts "
                             "with persistent instance %s" % 
                             (state_str(state), instance_key, state_str(existing)))
-                            
+
                     self._log_debug(
-                        "detected row switch for identity %s.  will update %s, remove %s from "
-                        "transaction", instance_key, state_str(state), state_str(existing))
-                            
+                        "detected row switch for identity %s.  "
+                        "will update %s, remove %s from "
+                        "transaction", instance_key, 
+                        state_str(state), state_str(existing))
+
                     # remove the "delete" flag from the existing element
                     uowtransaction.set_row_switch(existing)
-                    row_switches[state] = existing
-        
-        table_to_mapper = self._sorted_tables
+                    row_switch = existing
+
+            tups.append(
+                (state,
+                    mapper,
+                    conn,
+                    has_identity,
+                    instance_key,
+                    row_switch)
+            )
 
-        for table in table_to_mapper.iterkeys():
+        table_to_mapper = self._sorted_tables
+        
+        for table in table_to_mapper:
             insert = []
             update = []
 
-            for state, mapper, connection, has_identity, instance_key in tups:
+            for state, mapper, connection, has_identity, \
+                            instance_key, row_switch in tups:
                 if table not in mapper._pks_by_table:
                     continue
                     
                 pks = mapper._pks_by_table[table]
                 
-                isinsert = not has_identity and not postupdate and state not in row_switches
+                isinsert = not has_identity and \
+                                not postupdate and \
+                                not row_switch
                 
                 params = {}
                 value_params = {}
@@ -1371,23 +1378,36 @@ class Mapper(object):
                                     value_params[col] = value
                                 else:
                                     params[col.key] = value
-                    insert.append((state, params, mapper, connection, value_params))
+                    insert.append((state, params, mapper, 
+                                    connection, value_params))
                 else:
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
-                            params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col)
-                            params[col.key] = mapper.version_id_generator(params[col._label])
+                            params[col._label] = \
+                                        mapper._get_state_attr_by_column(
+                                                    row_switch or state, 
+                                                    col)
+                            params[col.key] = \
+                                    mapper.version_id_generator(params[col._label])
+
+                            # HACK: check for history, in case the history is only
+                            # in a different table than the one where the version_id_col
+                            # is.
                             for prop in mapper._columntoproperty.itervalues():
-                                history = attributes.get_state_history(state, prop.key, passive=True)
+                                history = attributes.get_state_history(
+                                                state, prop.key, passive=True)
                                 if history.added:
                                     hasdata = True
                         elif mapper.polymorphic_on is not None and \
-                                mapper.polymorphic_on.shares_lineage(col) and col not in pks:
+                                mapper.polymorphic_on.shares_lineage(col) and \
+                                col not in pks:
                             pass
                         else:
-                            if post_update_cols is not None and col not in post_update_cols:
+                            if post_update_cols is not None and \
+                                col not in post_update_cols:
                                 if col in pks:
-                                    params[col._label] = mapper._get_state_attr_by_column(state, col)
+                                    params[col._label] = \
+                                                mapper._get_state_attr_by_column(state, col)
                                 continue
 
                             prop = mapper._columntoproperty[col]
@@ -1424,27 +1444,32 @@ class Mapper(object):
                             elif col in pks:
                                 params[col._label] = mapper._get_state_attr_by_column(state, col)
                     if hasdata:
-                        update.append((state, params, mapper, connection, value_params))
+                        update.append((state, params, mapper, 
+                                        connection, value_params))
 
             if update:
                 mapper = table_to_mapper[table]
                 clause = sql.and_()
                 
                 for col in mapper._pks_by_table[table]:
-                    clause.clauses.append(col == sql.bindparam(col._label, type_=col.type))
+                    clause.clauses.append(
+                                        col == 
+                                        sql.bindparam(col._label, type_=col.type)
+                                    )
 
-                if mapper.version_id_col is not None and \
-                            table.c.contains_column(mapper.version_id_col):
-                            
+                needs_version_id = mapper.version_id_col is not None and \
+                            table.c.contains_column(mapper.version_id_col)
+
+                if needs_version_id:
                     clause.clauses.append(mapper.version_id_col ==\
                             sql.bindparam(mapper.version_id_col._label, type_=col.type))
 
                 statement = table.update(clause)
-                
+
                 rows = 0
                 for state, params, mapper, connection, value_params in update:
                     c = connection.execute(statement.values(value_params), params)
-                    mapper._postfetch(uowtransaction, connection, table, 
+                    mapper._postfetch(uowtransaction, table, 
                                         state, c, c.last_updated_params(), value_params)
 
                     rows += c.rowcount
@@ -1452,13 +1477,15 @@ class Mapper(object):
                 if connection.dialect.supports_sane_rowcount:
                     if rows != len(update):
                         raise orm_exc.ConcurrentModificationError(
-                                "Updated rowcount %d does not match number of objects updated %d" %
+                                "Updated rowcount %d does not match number "
+                                "of objects updated %d" %
                                 (rows, len(update)))
-                        
-                elif mapper.version_id_col is not None:
+
+                elif needs_version_id:
                     util.warn("Dialect %s does not support updated rowcount "
-                            "- versioning cannot be verified." % c.dialect.dialect_description,
-                              stacklevel=12)
+                            "- versioning cannot be verified." % 
+                            c.dialect.dialect_description,
+                            stacklevel=12)
                     
             if insert:
                 statement = table.insert()
@@ -1473,12 +1500,12 @@ class Mapper(object):
                                                                 len(primary_key) > i:
                                 mapper._set_state_attr_by_column(state, col, primary_key[i])
                                 
-                    mapper._postfetch(uowtransaction, connection, table, 
+                    mapper._postfetch(uowtransaction, table, 
                                         state, c, c.last_inserted_params(), value_params)
 
-                        
         if not postupdate:
-            for state, mapper, connection, has_identity, instance_key in tups:
+            for state, mapper, connection, has_identity, \
+                            instance_key, row_switch in tups:
 
                 # expire readonly attributes
                 readonly = state.unmodified.intersection(
@@ -1488,8 +1515,8 @@ class Mapper(object):
                 if readonly:
                     _expire_state(state, state.dict, readonly)
 
-                # if specified, eagerly refresh whatever has
-                # been expired.
+                # if eager_defaults option is enabled,
+                # refresh whatever has been expired.
                 if self.eager_defaults and state.unloaded:
                     state.key = self._identity_key_from_state(state)
                     uowtransaction.session.query(self)._get(
@@ -1504,7 +1531,7 @@ class Mapper(object):
                     if 'after_update' in mapper.extension:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
-    def _postfetch(self, uowtransaction, connection, table, 
+    def _postfetch(self, uowtransaction, table, 
                                 state, resultproxy, params, value_params):
         """Expire attributes in need of newly persisted database state."""
 
@@ -1523,23 +1550,37 @@ class Mapper(object):
             if c.key in params and c in self._columntoproperty:
                 self._set_state_attr_by_column(state, c, params[c.key])
 
-        deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
-
-        if deferred_props:
-            _expire_state(state, state.dict, deferred_props)
+        if postfetch_cols:
+            _expire_state(state, state.dict, 
+                                [self._columntoproperty[c].key 
+                                for c in postfetch_cols]
+                            )
 
         # synchronize newly inserted ids from one table to the next
         # TODO: this still goes a little too often.  would be nice to
         # have definitive list of "columns that changed" here
-        cols = set(table.c)
-        for m in self.iterate_to_root():
-            if m._inherits_equated_pairs and \
-                        cols.intersection([l for l, r in m._inherits_equated_pairs]):
-                sync.populate(state, m, state, m, 
-                                                m._inherits_equated_pairs, 
-                                                uowtransaction,
-                                                self.passive_updates)
-
+        for m, equated_pairs in self._table_to_equated[table]:
+            sync.populate(state, m, state, m, 
+                                            equated_pairs, 
+                                            uowtransaction,
+                                            self.passive_updates)
+            
+    @util.memoized_property
+    def _table_to_equated(self):
+        """memoized map of tables to collections of columns to be 
+        synchronized upwards to the base mapper."""
+        
+        result = util.defaultdict(list)
+        
+        for table in self._sorted_tables:
+            cols = set(table.c)
+            for m in self.iterate_to_root():
+                if m._inherits_equated_pairs and \
+                            cols.intersection([l for l, r in m._inherits_equated_pairs]):
+                    result[table].append((m, m._inherits_equated_pairs))
+        
+        return result
+        
     def _delete_obj(self, states, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
 
@@ -1548,50 +1589,95 @@ class Mapper(object):
 
         """
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
-            connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)]
+            connection_callable = \
+                    uowtransaction.mapper_flush_opts['connection_callable']
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
-
-        for state, mapper, connection in tups:
+            connection_callable = None
+        
+        tups = []
+        for state in _sort_states(states):
+            mapper = _state_mapper(state)
+        
+            conn = connection_callable and \
+                connection_callable(self, state.obj()) or \
+                connection
+            
             if 'before_delete' in mapper.extension:
-                mapper.extension.before_delete(mapper, connection, state.obj())
+                mapper.extension.before_delete(mapper, conn, state.obj())
+            
+            tups.append((state, 
+                    _state_mapper(state), 
+                    _state_has_identity(state),
+                    conn))
 
         table_to_mapper = self._sorted_tables
 
         for table in reversed(table_to_mapper.keys()):
-            delete = {}
-            for state, mapper, connection in tups:
-                if table not in mapper._pks_by_table:
+            delete = util.defaultdict(list)
+            for state, mapper, has_identity, connection in tups:
+                if not has_identity or table not in mapper._pks_by_table:
                     continue
 
                 params = {}
-                if not _state_has_identity(state):
-                    continue
-                else:
-                    delete.setdefault(connection, []).append(params)
+                delete[connection].append(params)
                 for col in mapper._pks_by_table[table]:
                     params[col.key] = mapper._get_state_attr_by_column(state, col)
-                if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
-                    params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col)
+                if mapper.version_id_col is not None and \
+                            table.c.contains_column(mapper.version_id_col):
+                    params[mapper.version_id_col.key] = \
+                                mapper._get_state_attr_by_column(state, mapper.version_id_col)
 
             for connection, del_objects in delete.iteritems():
                 mapper = table_to_mapper[table]
                 clause = sql.and_()
                 for col in mapper._pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col.key, type_=col.type))
-                if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
+
+                need_version_id = mapper.version_id_col is not None and \
+                    table.c.contains_column(mapper.version_id_col)
+
+                if need_version_id:
                     clause.clauses.append(
                         mapper.version_id_col == 
-                        sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type))
+                        sql.bindparam(
+                                mapper.version_id_col.key, 
+                                type_=mapper.version_id_col.type
+                        )
+                    )
+
                 statement = table.delete(clause)
-                c = connection.execute(statement, del_objects)
-                if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
-                    raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match "
-                            "number of objects deleted %d" % (c.rowcount, len(del_objects)))
+                rows = -1
+
+                if need_version_id and \
+                        not connection.dialect.supports_sane_multi_rowcount:
+                    # TODO: need test coverage for this [ticket:1761]
+                    if connection.dialect.supports_sane_rowcount:
+                        rows = 0
+                        # execute deletes individually so that versioned
+                        # rows can be verified
+                        for params in del_objects:
+                            c = connection.execute(statement, params)
+                            rows += c.rowcount
+                    else:
+                        util.warn("Dialect %s does not support deleted rowcount "
+                                "- versioning cannot be verified." % 
+                                c.dialect.dialect_description,
+                                stacklevel=12)
+                        connection.execute(statement, del_objects)
+                else:
+                    c = connection.execute(statement, del_objects)
+                    if connection.dialect.supports_sane_multi_rowcount:
+                        rows = c.rowcount
+
+                if rows != -1 and rows != len(del_objects):
+                    raise orm_exc.ConcurrentModificationError(
+                        "Deleted rowcount %d does not match "
+                        "number of objects deleted %d" % 
+                        (c.rowcount, len(del_objects))
+                    )
 
-        for state, mapper, connection in tups:
+        for state, mapper, has_identity, connection in tups:
             if 'after_delete' in mapper.extension:
                 mapper.extension.after_delete(mapper, connection, state.obj())