git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- initial stab at using executemany() for inserts in the ORM when possible
author Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Dec 2010 02:38:46 +0000 (21:38 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Dec 2010 02:38:46 +0000 (21:38 -0500)
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/mapper.py
test/sql/test_query.py

index dad075b3454d8dae3c7b2337a8789778782e0a21..4a00ebda2167566f7344434017d0e81ccae86a64 100644 (file)
@@ -629,24 +629,6 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
-    def last_inserted_params(self):
-        """Return a dictionary of the full parameter dictionary for the last
-        compiled INSERT statement.
-
-        Includes any ColumnDefaults or Sequences that were pre-executed.
-        """
-
-        raise NotImplementedError()
-
-    def last_updated_params(self):
-        """Return a dictionary of the full parameter dictionary for the last
-        compiled UPDATE statement.
-
-        Includes any ColumnDefaults that were pre-executed.
-        """
-
-        raise NotImplementedError()
-
     def lastrow_has_defaults(self):
         """Return True if the last INSERT or UPDATE row contained
         inlined or database-side defaults.
@@ -2466,13 +2448,6 @@ class ResultProxy(object):
         did not explicitly specify returning().
 
         """
-        if not self.context.isinsert:
-            raise exc.InvalidRequestError(
-                        "Statement is not an insert() expression construct.")
-        elif self.context._is_explicit_returning:
-            raise exc.InvalidRequestError(
-                        "Can't call inserted_primary_key when returning() "
-                        "is used.")
             
         return self.context._inserted_primary_key
 
@@ -2481,15 +2456,15 @@ class ResultProxy(object):
         """Return the primary key for the row just inserted."""
         
         return self.inserted_primary_key
-        
+    
     def last_updated_params(self):
         """Return ``last_updated_params()`` from the underlying
         ExecutionContext.
 
         See ExecutionContext for details.
         """
-
-        return self.context.last_updated_params()
+        
+        return self.context.last_updated_params
 
     def last_inserted_params(self):
         """Return ``last_inserted_params()`` from the underlying
@@ -2498,7 +2473,7 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
 
-        return self.context.last_inserted_params()
+        return self.context.last_inserted_params
 
     def lastrow_has_defaults(self):
         """Return ``lastrow_has_defaults()`` from the underlying
index 0717a8fef3e541192bc9ecf8c1524a8c38fd6215..63b9e44b37c1c84318823f7338919a95ff8e0d94 100644 (file)
@@ -635,13 +635,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                 ipk.append(row[c])
         
         self._inserted_primary_key = ipk
-
-    def last_inserted_params(self):
-        return self._last_inserted_params
-
-    def last_updated_params(self):
-        return self._last_updated_params
-
+    
     def lastrow_has_defaults(self):
         return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
 
@@ -714,7 +708,32 @@ class DefaultExecutionContext(base.ExecutionContext):
             return None
         else:
             return self._exec_default(column.onupdate)
-
+    
+    @util.memoized_property
+    def _inserted_primary_key(self):
+
+        if not self.isinsert:
+            raise exc.InvalidRequestError(
+                        "Statement is not an insert() expression construct.")
+        elif self._is_explicit_returning:
+            raise exc.InvalidRequestError(
+                        "Can't call inserted_primary_key when returning() "
+                        "is used.")
+        
+        
+        # lazily evaluate inserted_primary_key for executemany.
+        # for execute(), it's already in __dict__.
+        if self.executemany:
+            return [
+                    [compiled_parameters.get(c.key, None) 
+                            for c in self.compiled.\
+                                                statement.table.primary_key
+                    ] for compiled_parameters in self.compiled_parameters
+            ]
+        else:
+            # _inserted_primary_key should be calculated here
+            assert False
+    
     def __process_defaults(self):
         """Generate default values for compiled insert/update statements,
         and generate inserted_primary_key collection.
@@ -746,6 +765,11 @@ class DefaultExecutionContext(base.ExecutionContext):
                             param[c.key] = val
                 del self.current_parameters
 
+            if self.isinsert:
+                self.last_inserted_params = self.compiled_parameters
+            else:
+                self.last_updated_params = self.compiled_parameters
+
         else:
             self.current_parameters = compiled_parameters = \
                                         self.compiled_parameters[0]
@@ -759,18 +783,20 @@ class DefaultExecutionContext(base.ExecutionContext):
                 if val is not None:
                     compiled_parameters[c.key] = val
             del self.current_parameters
-
-            if self.isinsert:
+            
+            if self.isinsert and not self._is_explicit_returning:
                 self._inserted_primary_key = [
-                                    compiled_parameters.get(c.key, None) 
-                                    for c in self.compiled.\
+                                self.compiled_parameters[0].get(c.key, None) 
+                                        for c in self.compiled.\
                                                 statement.table.primary_key
-                                        ]
-                self._last_inserted_params = compiled_parameters
+                                ]
+
+            if self.isinsert:
+                self.last_inserted_params = compiled_parameters
             else:
-                self._last_updated_params = compiled_parameters
+                self.last_updated_params = compiled_parameters
 
-            self.postfetch_cols = self.compiled.postfetch
-            self.prefetch_cols = self.compiled.prefetch
+        self.postfetch_cols = self.compiled.postfetch
+        self.prefetch_cols = self.compiled.prefetch
             
 DefaultDialect.execution_ctx_cls = DefaultExecutionContext
index a4662770eb6a3a4bd307c6bb92c208e814ecb38c..f6a5516d99e7426f6638ed79291836599bdca512 100644 (file)
@@ -1653,7 +1653,8 @@ class Mapper(object):
                 params = {}
                 value_params = {}
                 hasdata = False
-
+                has_all_pks = True
+                
                 if isinsert:
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
@@ -1673,13 +1674,15 @@ class Mapper(object):
                                      col not in pks:
                                      
                                      params[col.key] = value
+                                elif col in pks:
+                                    has_all_pks = False
                             elif isinstance(value, sql.ClauseElement):
                                 value_params[col] = value
                             else:
                                 params[col.key] = value
 
                     insert.append((state, state_dict, params, mapper, 
-                                    connection, value_params))
+                                    connection, value_params, has_all_pks))
                 else:
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
@@ -1793,6 +1796,7 @@ class Mapper(object):
                 statement = self._memo(('update', table), update_stmt)
                 
                 rows = 0
+                postfetch = []
                 for state, state_dict, params, mapper, \
                             connection, value_params in update:
                     
@@ -1803,13 +1807,18 @@ class Mapper(object):
                     else:
                         c = cached_connections[connection].\
                                             execute(statement, params)
-                        
-                    mapper._postfetch(uowtransaction, table, 
-                                        state, state_dict, c, 
-                                        c.last_updated_params(), value_params)
-
+                    
+                    postfetch.append((mapper, state, state_dict,
+                            c.prefetch_cols(), c.postfetch_cols(),
+                            c.last_updated_params(), value_params))
                     rows += c.rowcount
 
+                for mapper, pf in groupby(
+                    postfetch, lambda rec: rec[0]
+                ):
+                    mapper._postfetch(uowtransaction, table, pf)
+
+
                 if connection.dialect.supports_sane_rowcount:
                     if rows != len(update):
                         raise orm_exc.StaleDataError(
@@ -1825,38 +1834,61 @@ class Mapper(object):
                     
             if insert:
                 statement = self._memo(('insert', table), table.insert)
+                postfetch = []
 
-                for state, state_dict, params, mapper, \
-                            connection, value_params in insert:
-
-                    if value_params:
-                        c = connection.execute(
-                                            statement.values(value_params),
-                                            params)
-                    else:
+                for (connection, pkeys, hasvalue, has_all_pks), records in groupby(
+                    insert, lambda rec: (rec[4], rec[2].keys(), bool(rec[5]), rec[6])
+                ):
+                    if has_all_pks and not hasvalue:
+                        records = list(records)
+                        multiparams = [params for state, state_dict, 
+                                            params, mapper, conn, value_params, 
+                                            has_all_pks in records]
                         c = cached_connections[connection].\
-                                            execute(statement, params)
+                                            execute(statement, multiparams)
+                        
+                        for (state, state_dict, params, mapper, conn, value_params, has_all_pks), \
+                                last_inserted_params in zip(records, c.context.compiled_parameters):
+                            postfetch.append((mapper, state, state_dict,
+                                    c.prefetch_cols(), c.postfetch_cols(),
+                                    last_inserted_params, {}))
+                            
+                    else:
+                        for state, state_dict, params, mapper, \
+                                    connection, value_params, has_all_pks in records:
+
+                            if value_params:
+                                c = connection.execute(
+                                                    statement.values(value_params),
+                                                    params)
+                            else:
+                                c = cached_connections[connection].\
+                                                    execute(statement, params)
                     
-                    primary_key = c.inserted_primary_key
-
-                    if primary_key is not None:
-                        # set primary key attributes
-                        for pk, col in zip(primary_key, 
-                                        mapper._pks_by_table[table]):
-                            # TODO: make sure this inlined code is OK
-                            # with composites
-                            prop = mapper._columntoproperty[col]
-                            if state_dict.get(prop.key) is None:
-                                # TODO: would rather say:
-                                #state_dict[prop.key] = pk
-                                mapper._set_state_attr_by_column(state, 
-                                                                state_dict, 
-                                                                col, pk)
-                                
-                    mapper._postfetch(uowtransaction, table, 
-                                        state, state_dict, c,
-                                        c.last_inserted_params(),
-                                        value_params)
+                            primary_key = c.inserted_primary_key
+
+                            if primary_key is not None:
+                                # set primary key attributes
+                                for pk, col in zip(primary_key, 
+                                                mapper._pks_by_table[table]):
+                                    # TODO: make sure this inlined code is OK
+                                    # with composites
+                                    prop = mapper._columntoproperty[col]
+                                    if state_dict.get(prop.key) is None:
+                                        # TODO: would rather say:
+                                        #state_dict[prop.key] = pk
+                                        mapper._set_state_attr_by_column(state, 
+                                                                        state_dict, 
+                                                                        col, pk)
+
+                            postfetch.append((mapper, state, state_dict,
+                                    c.prefetch_cols(), c.postfetch_cols(),
+                                    c.last_inserted_params(), value_params))
+
+                for mapper, pf in groupby(
+                    postfetch, lambda rec: rec[0]
+                ):
+                    mapper._postfetch(uowtransaction, table, pf)
 
         for state, state_dict, mapper, connection, has_identity, \
                         instance_key, row_switch in tups:
@@ -1883,35 +1915,36 @@ class Mapper(object):
                 mapper.dispatch.on_after_update(mapper, connection, state)
 
     def _postfetch(self, uowtransaction, table, 
-                        state, dict_, resultproxy, 
-                        params, value_params):
+                        recs):
         """During a flush, expire attributes in need of newly 
         persisted database state."""
 
-        postfetch_cols = resultproxy.postfetch_cols()
-        generated_cols = list(resultproxy.prefetch_cols())
-
-        if self.version_id_col is not None:
-            generated_cols.append(self.version_id_col)
-
-        for c in generated_cols:
-            if c.key in params and c in self._columntoproperty:
-                self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
-        if postfetch_cols:
-            sessionlib._expire_state(state, state.dict, 
-                                [self._columntoproperty[c].key 
-                                for c in postfetch_cols]
-                            )
-
-        # synchronize newly inserted ids from one table to the next
-        # TODO: this still goes a little too often.  would be nice to
-        # have definitive list of "columns that changed" here
-        for m, equated_pairs in self._table_to_equated[table]:
-            sync.populate(state, m, state, m, 
-                                            equated_pairs, 
-                                            uowtransaction,
-                                            self.passive_updates)
+        for m, state, dict_, prefetch_cols, postfetch_cols, \
+                params, value_params in recs:
+            postfetch_cols = postfetch_cols
+            generated_cols = list(prefetch_cols)
+
+            if self.version_id_col is not None:
+                generated_cols.append(self.version_id_col)
+
+            for c in generated_cols:
+                if c.key in params and c in self._columntoproperty:
+                    self._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+            if postfetch_cols:
+                sessionlib._expire_state(state, state.dict, 
+                                    [self._columntoproperty[c].key 
+                                    for c in postfetch_cols]
+                                )
+
+            # synchronize newly inserted ids from one table to the next
+            # TODO: this still goes a little too often.  would be nice to
+            # have definitive list of "columns that changed" here
+            for m, equated_pairs in self._table_to_equated[table]:
+                sync.populate(state, m, state, m, 
+                                                equated_pairs, 
+                                                uowtransaction,
+                                                self.passive_updates)
             
     @util.memoized_property
     def _table_to_equated(self):
index f59b340766d05f8c79b0cced77d1a0178f1cd2ed..e14f5301e5f78a89430fec5e0f775d8494546eb1 100644 (file)
@@ -654,7 +654,24 @@ class QueryTest(TestBase):
                 getattr(result, meth),
             )
             trans.rollback()
-            
+    
+    def test_no_inserted_pk_on_non_insert(self):
+        result = testing.db.execute("select * from query_users")
+        assert_raises_message(
+            exc.InvalidRequestError,
+            r"Statement is not an insert\(\) expression construct.",
+            getattr, result, 'inserted_primary_key'
+        )
+    
+    @testing.requires.returning
+    def test_no_inserted_pk_on_returning(self):
+        result = testing.db.execute(users.insert().returning(users.c.user_id, users.c.user_name))
+        assert_raises_message(
+            exc.InvalidRequestError,
+            r"Can't call inserted_primary_key when returning\(\) is used.",
+            getattr, result, 'inserted_primary_key'
+        )
+        
     def test_fetchone_til_end(self):
         result = testing.db.execute("select * from query_users")
         eq_(result.fetchone(), None)