git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- clean up the batch insert thing
author: Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Dec 2010 08:05:03 +0000 (03:05 -0500)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Dec 2010 08:05:03 +0000 (03:05 -0500)
- add a test for batch inserts
- don't need elaborate _inserted_primary_key thing
- take some cruft out of ExecutionContext, ResultProxy,
EC members can be non-underscored, have mapper just call the
EC members for now.
- simplify "connection_callable", no need for a "flush_opts"
dictionary since this point of expansion is not needed

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_unitofworkv2.py

index 4a00ebda2167566f7344434017d0e81ccae86a64..d58460fb8552978ed4def645f8aba0f4eb995dbc 100644 (file)
@@ -2448,8 +2448,16 @@ class ResultProxy(object):
         did not explicitly specify returning().
 
         """
+
+        if not self.context.isinsert:
+            raise exc.InvalidRequestError(
+                        "Statement is not an insert() expression construct.")
+        elif self.context._is_explicit_returning:
+            raise exc.InvalidRequestError(
+                        "Can't call inserted_primary_key when returning() "
+                        "is used.")
             
-        return self.context._inserted_primary_key
+        return self.context.inserted_primary_key
 
     @util.deprecated("0.6", "Use :attr:`.ResultProxy.inserted_primary_key`")
     def last_inserted_ids(self):
@@ -2458,22 +2466,24 @@ class ResultProxy(object):
         return self.inserted_primary_key
     
     def last_updated_params(self):
-        """Return ``last_updated_params()`` from the underlying
-        ExecutionContext.
-
-        See ExecutionContext for details.
-        """
+        """Return the collection of updated parameters from this
+        execution.
         
-        return self.context.last_updated_params
+        """
+        if self.context.executemany:
+            return self.context.compiled_parameters
+        else:
+            return self.context.compiled_parameters[0]
 
     def last_inserted_params(self):
-        """Return ``last_inserted_params()`` from the underlying
-        ExecutionContext.
-
-        See ExecutionContext for details.
+        """Return the collection of inserted parameters from this
+        execution.
+        
         """
-
-        return self.context.last_inserted_params
+        if self.context.executemany:
+            return self.context.compiled_parameters
+        else:
+            return self.context.compiled_parameters[0]
 
     def lastrow_has_defaults(self):
         """Return ``lastrow_has_defaults()`` from the underlying
index 63b9e44b37c1c84318823f7338919a95ff8e0d94..21603b2586a3f49e46b6a1abc990cd626466b000 100644 (file)
@@ -400,7 +400,9 @@ class DefaultExecutionContext(base.ExecutionContext):
         self.cursor = self.create_cursor()
         if self.isinsert or self.isupdate:
             self.__process_defaults()
-            
+            self.postfetch_cols = self.compiled.postfetch
+            self.prefetch_cols = self.compiled.prefetch
+                    
         processors = dict(
             (key, value) for key, value in
             ( (compiled.bind_names[bindparam],
@@ -541,7 +543,8 @@ class DefaultExecutionContext(base.ExecutionContext):
         """
 
         conn = self._connection
-        if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
+        if isinstance(stmt, unicode) and \
+            not self.dialect.supports_unicode_statements:
             stmt = stmt.encode(self.dialect.encoding)
 
         if self.dialect.positional:
@@ -614,13 +617,14 @@ class DefaultExecutionContext(base.ExecutionContext):
     
     def post_insert(self):
         if self.dialect.postfetch_lastrowid and \
-            (not len(self._inserted_primary_key) or \
-                        None in self._inserted_primary_key):
+            (not len(self.inserted_primary_key) or \
+                        None in self.inserted_primary_key):
             
             table = self.compiled.statement.table
             lastrowid = self.get_lastrowid()
-            self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
-                for c, v in zip(table.primary_key, self._inserted_primary_key)
+            self.inserted_primary_key = [
+                c is table._autoincrement_column and lastrowid or v
+                for c, v in zip(table.primary_key, self.inserted_primary_key)
             ]
             
     def _fetch_implicit_returning(self, resultproxy):
@@ -628,16 +632,17 @@ class DefaultExecutionContext(base.ExecutionContext):
         row = resultproxy.fetchone()
 
         ipk = []
-        for c, v in zip(table.primary_key, self._inserted_primary_key):
+        for c, v in zip(table.primary_key, self.inserted_primary_key):
             if v is not None:
                 ipk.append(v)
             else:
                 ipk.append(row[c])
         
-        self._inserted_primary_key = ipk
+        self.inserted_primary_key = ipk
     
     def lastrow_has_defaults(self):
-        return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
+        return (self.isinsert or self.isupdate) and \
+            bool(self.postfetch_cols)
 
     def set_input_sizes(self, translate=None, exclude_types=None):
         """Given a cursor and ClauseParameters, call the appropriate
@@ -709,31 +714,6 @@ class DefaultExecutionContext(base.ExecutionContext):
         else:
             return self._exec_default(column.onupdate)
     
-    @util.memoized_property
-    def _inserted_primary_key(self):
-
-        if not self.isinsert:
-            raise exc.InvalidRequestError(
-                        "Statement is not an insert() expression construct.")
-        elif self._is_explicit_returning:
-            raise exc.InvalidRequestError(
-                        "Can't call inserted_primary_key when returning() "
-                        "is used.")
-        
-        
-        # lazyily evaluate inserted_primary_key for executemany.
-        # for execute(), its already in __dict__.
-        if self.executemany:
-            return [
-                    [compiled_parameters.get(c.key, None) 
-                            for c in self.compiled.\
-                                                statement.table.primary_key
-                    ] for compiled_parameters in self.compiled_parameters
-            ]
-        else:
-            # _inserted_primary_key should be calced here
-            assert False
-    
     def __process_defaults(self):
         """Generate default values for compiled insert/update statements,
         and generate inserted_primary_key collection.
@@ -764,12 +744,6 @@ class DefaultExecutionContext(base.ExecutionContext):
                         if val is not None:
                             param[c.key] = val
                 del self.current_parameters
-
-            if self.isinsert:
-                self.last_inserted_params = self.compiled_parameters
-            else:
-                self.last_updated_params = self.compiled_parameters
-
         else:
             self.current_parameters = compiled_parameters = \
                                         self.compiled_parameters[0]
@@ -784,19 +758,12 @@ class DefaultExecutionContext(base.ExecutionContext):
                     compiled_parameters[c.key] = val
             del self.current_parameters
             
-            if self.isinsert and not self._is_explicit_returning:
-                self._inserted_primary_key = [
+            if self.isinsert:
+                self.inserted_primary_key = [
                                 self.compiled_parameters[0].get(c.key, None) 
                                         for c in self.compiled.\
                                                 statement.table.primary_key
                                 ]
 
-            if self.isinsert:
-                self.last_inserted_params = compiled_parameters
-            else:
-                self.last_updated_params = compiled_parameters
-
-        self.postfetch_cols = self.compiled.postfetch
-        self.prefetch_cols = self.compiled.prefetch
             
 DefaultDialect.execution_ctx_cls = DefaultExecutionContext
index 78e3f59530373734e858a2588ec9b2fc263f4ae8..e48cb9fcbbddfbcd7f8d5624ecc8b1d63ef7e918 100644 (file)
@@ -50,12 +50,12 @@ class ShardedSession(Session):
         self.id_chooser = id_chooser
         self.query_chooser = query_chooser
         self.__binds = {}
-        self._mapper_flush_opts = {'connection_callable':self.connection}
+        self.connection_callable = self.connection
         self._query_cls = ShardedQuery
         if shards is not None:
             for k in shards:
                 self.bind_shard(k, shards[k])
-        
+    
     def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
         if shard_id is None:
             shard_id = self.shard_chooser(mapper, instance)
index f6a5516d99e7426f6638ed79291836599bdca512..20242c97c0ab1a62329448b1392b36e25a35e8f5 100644 (file)
@@ -1468,9 +1468,9 @@ class Mapper(object):
         # if session has a connection callable,
         # organize individual states with the connection 
         # to use for update
-        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+        if uowtransaction.session.connection_callable:
             connection_callable = \
-                    uowtransaction.mapper_flush_opts['connection_callable']
+                    uowtransaction.session.connection_callable
         else:
             connection = uowtransaction.transaction.connection(self)
             connection_callable = None
@@ -1550,15 +1550,10 @@ class Mapper(object):
         of objects.
 
         This is called within the context of a UOWTransaction during a
-        flush operation.
+        flush operation, given a list of states to be flushed.  The
+        base mapper in an inheritance hierarchy handles the inserts/
+        updates for all descendant mappers.
 
-        `_save_obj` issues SQL statements not just for instances mapped
-        directly by this mapper, but for instances mapped by all
-        inheriting mappers as well.  This is to maintain proper insert
-        ordering among a polymorphic chain of instances. Therefore
-        _save_obj is typically called only on a *base mapper*, or a
-        mapper which does not inherit from any other mapper.
-        
         """
             
         # if batch=false, call _save_obj separately for each object
@@ -1572,9 +1567,9 @@ class Mapper(object):
         # if session has a connection callable,
         # organize individual states with the connection 
         # to use for insert/update
-        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+        if uowtransaction.session.connection_callable:
             connection_callable = \
-                    uowtransaction.mapper_flush_opts['connection_callable']
+                    uowtransaction.session.connection_callable
         else:
             connection = uowtransaction.transaction.connection(self)
             connection_callable = None
@@ -1592,6 +1587,7 @@ class Mapper(object):
             instance_key = state.key or mapper._identity_key_from_state(state)
 
             row_switch = None
+            
             # call before_XXX extensions
             if not has_identity:
                 mapper.dispatch.on_before_insert(mapper, conn, state)
@@ -1652,10 +1648,9 @@ class Mapper(object):
                 
                 params = {}
                 value_params = {}
-                hasdata = False
-                has_all_pks = True
                 
                 if isinsert:
+                    has_all_pks = True
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
                             params[col.key] = \
@@ -1669,13 +1664,12 @@ class Mapper(object):
                                 value = prop.get_col_value(col, value)
 
                             if value is None:
-                                if col.default is None and \
-                                     col.server_default is None and \
-                                     col not in pks:
-                                     
-                                     params[col.key] = value
-                                elif col in pks:
+                                if col in pks:
                                     has_all_pks = False
+                                elif col.default is None and \
+                                     col.server_default is None:
+                                     params[col.key] = value
+
                             elif isinstance(value, sql.ClauseElement):
                                 value_params[col] = value
                             else:
@@ -1684,6 +1678,7 @@ class Mapper(object):
                     insert.append((state, state_dict, params, mapper, 
                                     connection, value_params, has_all_pks))
                 else:
+                    hasdata = False
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
                             params[col._label] = \
@@ -1765,7 +1760,8 @@ class Mapper(object):
                                 else:
                                     hasdata = True
                             elif col in pks:
-                                value = state.manager[prop.key].impl.get(state, state_dict)
+                                value = state.manager[prop.key].\
+                                            impl.get(state, state_dict)
                                 if prop.get_col_value:
                                     value = prop.get_col_value(col, value)
                                 params[col._label] = value
@@ -1796,7 +1792,6 @@ class Mapper(object):
                 statement = self._memo(('update', table), update_stmt)
                 
                 rows = 0
-                postfetch = []
                 for state, state_dict, params, mapper, \
                             connection, value_params in update:
                     
@@ -1808,17 +1803,17 @@ class Mapper(object):
                         c = cached_connections[connection].\
                                             execute(statement, params)
                     
-                    postfetch.append((mapper, state, state_dict,
-                            c.prefetch_cols(), c.postfetch_cols(),
-                            c.last_updated_params(), value_params))
+                    mapper._postfetch(
+                            uowtransaction, 
+                            table, 
+                            state, 
+                            state_dict, 
+                            c.context.prefetch_cols, 
+                            c.context.postfetch_cols,
+                            c.context.compiled_parameters[0], 
+                            value_params)
                     rows += c.rowcount
 
-                for mapper, pf in groupby(
-                    postfetch, lambda rec: rec[0]
-                ):
-                    mapper._postfetch(uowtransaction, table, pf)
-
-
                 if connection.dialect.supports_sane_rowcount:
                     if rows != len(update):
                         raise orm_exc.StaleDataError(
@@ -1834,61 +1829,72 @@ class Mapper(object):
                     
             if insert:
                 statement = self._memo(('insert', table), table.insert)
-                postfetch = []
 
-                for (connection, pkeys, hasvalue, has_all_pks), records in groupby(
-                    insert, lambda rec: (rec[4], rec[2].keys(), bool(rec[5]), rec[6])
+                for (connection, pkeys, hasvalue, has_all_pks), \
+                    records in groupby(insert, 
+                                        lambda rec: (rec[4], 
+                                                rec[2].keys(), 
+                                                bool(rec[5]), 
+                                                rec[6])
                 ):
                     if has_all_pks and not hasvalue:
                         records = list(records)
-                        multiparams = [params for state, state_dict, 
-                                            params, mapper, conn, value_params, 
-                                            has_all_pks in records]
+                        multiparams = [rec[2] for rec in records]
                         c = cached_connections[connection].\
                                             execute(statement, multiparams)
                         
-                        for (state, state_dict, params, mapper, conn, value_params, has_all_pks), \
-                                last_inserted_params in zip(records, c.context.compiled_parameters):
-                            postfetch.append((mapper, state, state_dict,
-                                    c.prefetch_cols(), c.postfetch_cols(),
-                                    last_inserted_params, {}))
+                        for (state, state_dict, params, mapper, 
+                                conn, value_params, has_all_pks), \
+                                last_inserted_params in \
+                                zip(records, c.context.compiled_parameters):
+                            mapper._postfetch(
+                                    uowtransaction, 
+                                    table,
+                                    state, 
+                                    state_dict,
+                                    c.context.prefetch_cols,
+                                    c.context.postfetch_cols,
+                                    last_inserted_params, 
+                                    value_params)
                             
                     else:
                         for state, state_dict, params, mapper, \
-                                    connection, value_params, has_all_pks in records:
+                                    connection, value_params, \
+                                    has_all_pks in records:
 
                             if value_params:
-                                c = connection.execute(
-                                                    statement.values(value_params),
-                                                    params)
+                                result = connection.execute(
+                                            statement.values(value_params),
+                                            params)
                             else:
-                                c = cached_connections[connection].\
+                                result = cached_connections[connection].\
                                                     execute(statement, params)
                     
-                            primary_key = c.inserted_primary_key
+                            primary_key = result.context.inserted_primary_key
 
                             if primary_key is not None:
                                 # set primary key attributes
                                 for pk, col in zip(primary_key, 
                                                 mapper._pks_by_table[table]):
-                                    # TODO: make sure this inlined code is OK
-                                    # with composites
                                     prop = mapper._columntoproperty[col]
                                     if state_dict.get(prop.key) is None:
                                         # TODO: would rather say:
                                         #state_dict[prop.key] = pk
-                                        mapper._set_state_attr_by_column(state, 
-                                                                        state_dict, 
-                                                                        col, pk)
+                                        mapper._set_state_attr_by_column(
+                                                    state, 
+                                                    state_dict, 
+                                                    col, pk)
+
+                            mapper._postfetch(
+                                    uowtransaction, 
+                                    table, 
+                                    state, 
+                                    state_dict,
+                                    result.context.prefetch_cols, 
+                                    result.context.postfetch_cols,
+                                    result.context.compiled_parameters[0], 
+                                    value_params)
 
-                            postfetch.append((mapper, state, state_dict,
-                                    c.prefetch_cols(), c.postfetch_cols(),
-                                    c.last_inserted_params(), value_params))
-
-                for mapper, pf in groupby(
-                    postfetch, lambda rec: rec[0]
-                ):
-                    mapper._postfetch(uowtransaction, table, pf)
 
         for state, state_dict, mapper, connection, has_identity, \
                         instance_key, row_switch in tups:
@@ -1915,36 +1921,32 @@ class Mapper(object):
                 mapper.dispatch.on_after_update(mapper, connection, state)
 
     def _postfetch(self, uowtransaction, table, 
-                        recs):
+                    state, dict_, prefetch_cols, postfetch_cols,
+                                params, value_params):
         """During a flush, expire attributes in need of newly 
         persisted database state."""
 
-        for m, state, dict_, prefetch_cols, postfetch_cols, \
-                params, value_params in recs:
-            postfetch_cols = postfetch_cols
-            generated_cols = list(prefetch_cols)
-
-            if self.version_id_col is not None:
-                generated_cols.append(self.version_id_col)
-
-            for c in generated_cols:
-                if c.key in params and c in self._columntoproperty:
-                    self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
-            if postfetch_cols:
-                sessionlib._expire_state(state, state.dict, 
-                                    [self._columntoproperty[c].key 
-                                    for c in postfetch_cols]
-                                )
-
-            # synchronize newly inserted ids from one table to the next
-            # TODO: this still goes a little too often.  would be nice to
-            # have definitive list of "columns that changed" here
-            for m, equated_pairs in self._table_to_equated[table]:
-                sync.populate(state, m, state, m, 
-                                                equated_pairs, 
-                                                uowtransaction,
-                                                self.passive_updates)
+        if self.version_id_col is not None:
+            prefetch_cols = list(prefetch_cols) + [self.version_id_col]
+
+        for c in prefetch_cols:
+            if c.key in params and c in self._columntoproperty:
+                self._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+        if postfetch_cols:
+            sessionlib._expire_state(state, state.dict, 
+                                [self._columntoproperty[c].key 
+                                for c in postfetch_cols]
+                            )
+
+        # synchronize newly inserted ids from one table to the next
+        # TODO: this still goes a little too often.  would be nice to
+        # have definitive list of "columns that changed" here
+        for m, equated_pairs in self._table_to_equated[table]:
+            sync.populate(state, m, state, m, 
+                                            equated_pairs, 
+                                            uowtransaction,
+                                            self.passive_updates)
             
     @util.memoized_property
     def _table_to_equated(self):
@@ -1970,9 +1972,9 @@ class Mapper(object):
         flush operation.
 
         """
-        if 'connection_callable' in uowtransaction.mapper_flush_opts:
+        if uowtransaction.session.connection_callable:
             connection_callable = \
-                    uowtransaction.mapper_flush_opts['connection_callable']
+                    uowtransaction.session.connection_callable
         else:
             connection = uowtransaction.transaction.connection(self)
             connection_callable = None
index 3517eab2b3c0f2760c14e3dc678d630236d08e60..30a84bf1ab49bfbb08cb276556b39369939a6bdf 100644 (file)
@@ -511,7 +511,6 @@ class Session(object):
         self._enable_transaction_accounting = _enable_transaction_accounting
         self.twophase = twophase
         self._query_cls = query_cls
-        self._mapper_flush_opts = {}
         
         if extension:
             for ext in util.to_list(extension):
@@ -530,6 +529,8 @@ class Session(object):
 
     dispatch = event.dispatcher(SessionEvents)
 
+    connection_callable = None
+    
     def begin(self, subtransactions=False, nested=False):
         """Begin a transaction on this Session.
 
index 875ce634ba855397eb7005023af67a82abc9b0eb..d9d64fe391f7ce77fee71d80e1769d56b8e63169 100644 (file)
@@ -76,7 +76,6 @@ class UOWEventHandler(interfaces.AttributeExtension):
 class UOWTransaction(object):
     def __init__(self, session):
         self.session = session
-        self.mapper_flush_opts = session._mapper_flush_opts
 
         # dictionary used by external actors to 
         # store arbitrary state information.
@@ -316,7 +315,7 @@ class UOWTransaction(object):
                                     postsort_actions):
                 rec.execute(self)
             
-
+            
     def finalize_flush_changes(self):
         """mark processed objects as clean / deleted after a successful flush().
 
index 73a884e0c1db7630c00158019c259d11b6faa7b3..766addc05b958f30c4413d6ecc684e3c1140be50 100644 (file)
@@ -4,7 +4,8 @@ from test.lib.schema import Table, Column
 from sqlalchemy import Integer, String, ForeignKey, func
 from test.orm import _fixtures, _base
 from sqlalchemy.orm import mapper, relationship, backref, \
-                            create_session, unitofwork, attributes
+                            create_session, unitofwork, attributes,\
+                            Session
 from test.lib.assertsql import AllOf, CompiledSQL
 
 from test.orm._fixtures import keywords, addresses, Base, Keyword,  \
@@ -776,5 +777,72 @@ class RowswitchAccountingTest(_base.MappedTest):
 
         sess.flush()
 
+class BatchInsertsTest(_base.MappedTest, testing.AssertsExecutionResults):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('t', metadata,
+            Column('id', Integer, primary_key=True, 
+                        test_needs_autoincrement=True),
+            Column('data', String(50)),
+            Column('def_', String(50), server_default='def1')
+        )
 
+    @testing.resolve_artifact_names
+    def test_batch_interaction(self):
+        """test batching groups same-structured, primary 
+        key present statements together.
+        
+        """
+        class T(Base):
+            pass
+        mapper(T, t)
+        sess = Session()
+        sess.add_all([
+            T(data='t1'),
+            T(data='t2'),
+            T(id=3, data='t3'),
+            T(id=4, data='t4'),
+            T(id=5, data='t5'),
+            T(id=6, data=func.lower('t6')),
+            T(id=7, data='t7'),
+            T(id=8, data='t8'),
+            T(id=9, data='t9', def_='def2'),
+            T(id=10, data='t10', def_='def3'),
+            T(id=11, data='t11'),
+        ])
+        self.assert_sql_execution(
+            testing.db,
+            sess.flush,
+            CompiledSQL(
+                "INSERT INTO t (data) VALUES (:data)",
+                {'data': 't1'}
+            ),
+            CompiledSQL(
+                "INSERT INTO t (data) VALUES (:data)",
+                {'data': 't2'}
+            ),
+            CompiledSQL(
+                "INSERT INTO t (id, data) VALUES (:id, :data)",
+                [{'data': 't3', 'id': 3}, 
+                    {'data': 't4', 'id': 4}, 
+                    {'data': 't5', 'id': 5}]
+            ),
+            CompiledSQL(
+                "INSERT INTO t (id, data) VALUES (:id, lower(:lower_1))",
+                {'lower_1': 't6', 'id': 6}
+            ),
+            CompiledSQL(
+                "INSERT INTO t (id, data) VALUES (:id, :data)",
+                [{'data': 't7', 'id': 7}, {'data': 't8', 'id': 8}]
+            ),
+            CompiledSQL(
+                "INSERT INTO t (id, data, def_) VALUES (:id, :data, :def_)",
+                [{'data': 't9', 'id': 9, 'def_':'def2'}, 
+                {'data': 't10', 'id': 10, 'def_':'def3'}]
+            ),
+            CompiledSQL(
+                "INSERT INTO t (id, data) VALUES (:id, :data)",
+                {'data': 't11', 'id': 11}
+            ),
+        )