From: Mike Bayer Date: Sat, 11 Dec 2010 02:38:46 +0000 (-0500) Subject: - initial stab at using executemany() for inserts in the ORM when possible X-Git-Tag: rel_0_7b1~180^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=66e5de30f2e01593182058091075780b41411a78;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - initial stab at using executemany() for inserts in the ORM when possible --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dad075b345..4a00ebda21 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -629,24 +629,6 @@ class ExecutionContext(object): raise NotImplementedError() - def last_inserted_params(self): - """Return a dictionary of the full parameter dictionary for the last - compiled INSERT statement. - - Includes any ColumnDefaults or Sequences that were pre-executed. - """ - - raise NotImplementedError() - - def last_updated_params(self): - """Return a dictionary of the full parameter dictionary for the last - compiled UPDATE statement. - - Includes any ColumnDefaults that were pre-executed. - """ - - raise NotImplementedError() - def lastrow_has_defaults(self): """Return True if the last INSERT or UPDATE row contained inlined or database-side defaults. @@ -2466,13 +2448,6 @@ class ResultProxy(object): did not explicitly specify returning(). """ - if not self.context.isinsert: - raise exc.InvalidRequestError( - "Statement is not an insert() expression construct.") - elif self.context._is_explicit_returning: - raise exc.InvalidRequestError( - "Can't call inserted_primary_key when returning() " - "is used.") return self.context._inserted_primary_key @@ -2481,15 +2456,15 @@ class ResultProxy(object): """Return the primary key for the row just inserted.""" return self.inserted_primary_key - + def last_updated_params(self): """Return ``last_updated_params()`` from the underlying ExecutionContext. See ExecutionContext for details. """ - - return self.context.last_updated_params() + + return self.context.last_updated_params def last_inserted_params(self): """Return ``last_inserted_params()`` from the underlying @@ -2498,7 +2473,7 @@ class ResultProxy(object): See ExecutionContext for details. """ - return self.context.last_inserted_params() + return self.context.last_inserted_params def lastrow_has_defaults(self): """Return ``lastrow_has_defaults()`` from the underlying diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 0717a8fef3..63b9e44b37 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -635,13 +635,7 @@ class DefaultExecutionContext(base.ExecutionContext): ipk.append(row[c]) self._inserted_primary_key = ipk - - def last_inserted_params(self): - return self._last_inserted_params - - def last_updated_params(self): - return self._last_updated_params - + def lastrow_has_defaults(self): return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) @@ -714,7 +708,32 @@ class DefaultExecutionContext(base.ExecutionContext): return None else: return self._exec_default(column.onupdate) - + + @util.memoized_property + def _inserted_primary_key(self): + + if not self.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() expression construct.") + elif self._is_explicit_returning: + raise exc.InvalidRequestError( + "Can't call inserted_primary_key when returning() " + "is used.") + + + # lazyily evaluate inserted_primary_key for executemany. + # for execute(), its already in __dict__. 
+ if self.executemany: + return [ + [compiled_parameters.get(c.key, None) + for c in self.compiled.\ + statement.table.primary_key + ] for compiled_parameters in self.compiled_parameters + ] + else: + # _inserted_primary_key should be calced here + assert False + def __process_defaults(self): """Generate default values for compiled insert/update statements, and generate inserted_primary_key collection. @@ -746,6 +765,11 @@ class DefaultExecutionContext(base.ExecutionContext): param[c.key] = val del self.current_parameters + if self.isinsert: + self.last_inserted_params = self.compiled_parameters + else: + self.last_updated_params = self.compiled_parameters + else: self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] @@ -759,18 +783,20 @@ class DefaultExecutionContext(base.ExecutionContext): if val is not None: compiled_parameters[c.key] = val del self.current_parameters - - if self.isinsert: + + if self.isinsert and not self._is_explicit_returning: self._inserted_primary_key = [ - compiled_parameters.get(c.key, None) - for c in self.compiled.\ + self.compiled_parameters[0].get(c.key, None) + for c in self.compiled.\ statement.table.primary_key - ] - self._last_inserted_params = compiled_parameters + ] + + if self.isinsert: + self.last_inserted_params = compiled_parameters else: - self._last_updated_params = compiled_parameters + self.last_updated_params = compiled_parameters - self.postfetch_cols = self.compiled.postfetch - self.prefetch_cols = self.compiled.prefetch + self.postfetch_cols = self.compiled.postfetch + self.prefetch_cols = self.compiled.prefetch DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a4662770eb..f6a5516d99 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1653,7 +1653,8 @@ class Mapper(object): params = {} value_params = {} hasdata = False - + has_all_pks = True + if isinsert: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: @@ -1673,13 +1674,15 @@ class Mapper(object): col not in pks: params[col.key] = value + elif col in pks: + has_all_pks = False elif isinstance(value, sql.ClauseElement): value_params[col] = value else: params[col.key] = value insert.append((state, state_dict, params, mapper, - connection, value_params)) + connection, value_params, has_all_pks)) else: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: @@ -1793,6 +1796,7 @@ class Mapper(object): statement = self._memo(('update', table), update_stmt) rows = 0 + postfetch = [] for state, state_dict, params, mapper, \ connection, value_params in update: @@ -1803,13 +1807,18 @@ class Mapper(object): else: c = cached_connections[connection].\ execute(statement, params) - - mapper._postfetch(uowtransaction, table, - state, state_dict, c, - c.last_updated_params(), value_params) - + + postfetch.append((mapper, state, state_dict, + c.prefetch_cols(), c.postfetch_cols(), + c.last_updated_params(), value_params)) rows += c.rowcount + for mapper, pf in groupby( + postfetch, lambda rec: rec[0] + ): + mapper._postfetch(uowtransaction, table, pf) + + if connection.dialect.supports_sane_rowcount: if rows != len(update): raise orm_exc.StaleDataError( @@ -1825,38 +1834,61 @@ class Mapper(object): if insert: statement = self._memo(('insert', table), table.insert) + postfetch = [] - for state, state_dict, params, mapper, \ - connection, value_params in insert: - - if value_params: - c = connection.execute( - 
statement.values(value_params), - params) - else: + for (connection, pkeys, hasvalue, has_all_pks), records in groupby( + insert, lambda rec: (rec[4], rec[2].keys(), bool(rec[5]), rec[6]) + ): + if has_all_pks and not hasvalue: + records = list(records) + multiparams = [params for state, state_dict, + params, mapper, conn, value_params, + has_all_pks in records] c = cached_connections[connection].\ - execute(statement, params) + execute(statement, multiparams) + + for (state, state_dict, params, mapper, conn, value_params, has_all_pks), \ + last_inserted_params in zip(records, c.context.compiled_parameters): + postfetch.append((mapper, state, state_dict, + c.prefetch_cols(), c.postfetch_cols(), + last_inserted_params, {})) + + else: + for state, state_dict, params, mapper, \ + connection, value_params, has_all_pks in records: + + if value_params: + c = connection.execute( + statement.values(value_params), + params) + else: + c = cached_connections[connection].\ + execute(statement, params) - primary_key = c.inserted_primary_key - - if primary_key is not None: - # set primary key attributes - for pk, col in zip(primary_key, - mapper._pks_by_table[table]): - # TODO: make sure this inlined code is OK - # with composites - prop = mapper._columntoproperty[col] - if state_dict.get(prop.key) is None: - # TODO: would rather say: - #state_dict[prop.key] = pk - mapper._set_state_attr_by_column(state, - state_dict, - col, pk) - - mapper._postfetch(uowtransaction, table, - state, state_dict, c, - c.last_inserted_params(), - value_params) + primary_key = c.inserted_primary_key + + if primary_key is not None: + # set primary key attributes + for pk, col in zip(primary_key, + mapper._pks_by_table[table]): + # TODO: make sure this inlined code is OK + # with composites + prop = mapper._columntoproperty[col] + if state_dict.get(prop.key) is None: + # TODO: would rather say: + #state_dict[prop.key] = pk + mapper._set_state_attr_by_column(state, + state_dict, + col, pk) + + postfetch.append((mapper, state, state_dict, + c.prefetch_cols(), c.postfetch_cols(), + c.last_inserted_params(), value_params)) + + for mapper, pf in groupby( + postfetch, lambda rec: rec[0] + ): + mapper._postfetch(uowtransaction, table, pf) for state, state_dict, mapper, connection, has_identity, \ instance_key, row_switch in tups: @@ -1883,35 +1915,36 @@ class Mapper(object): mapper.dispatch.on_after_update(mapper, connection, state) def _postfetch(self, uowtransaction, table, - state, dict_, resultproxy, - params, value_params): + recs): """During a flush, expire attributes in need of newly persisted database state.""" - postfetch_cols = resultproxy.postfetch_cols() - generated_cols = list(resultproxy.prefetch_cols()) - - if self.version_id_col is not None: - generated_cols.append(self.version_id_col) - - for c in generated_cols: - if c.key in params and c in self._columntoproperty: - self._set_state_attr_by_column(state, dict_, c, params[c.key]) - - if postfetch_cols: - sessionlib._expire_state(state, state.dict, - [self._columntoproperty[c].key - for c in postfetch_cols] - ) - - # synchronize newly inserted ids from one table to the next - # TODO: this still goes a little too often. 
would be nice to - # have definitive list of "columns that changed" here - for m, equated_pairs in self._table_to_equated[table]: - sync.populate(state, m, state, m, - equated_pairs, - uowtransaction, - self.passive_updates) + for m, state, dict_, prefetch_cols, postfetch_cols, \ + params, value_params in recs: + postfetch_cols = postfetch_cols + generated_cols = list(prefetch_cols) + + if self.version_id_col is not None: + generated_cols.append(self.version_id_col) + + for c in generated_cols: + if c.key in params and c in self._columntoproperty: + self._set_state_attr_by_column(state, dict_, c, params[c.key]) + + if postfetch_cols: + sessionlib._expire_state(state, state.dict, + [self._columntoproperty[c].key + for c in postfetch_cols] + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in self._table_to_equated[table]: + sync.populate(state, m, state, m, + equated_pairs, + uowtransaction, + self.passive_updates) @util.memoized_property def _table_to_equated(self): diff --git a/test/sql/test_query.py b/test/sql/test_query.py index f59b340766..e14f5301e5 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -654,7 +654,24 @@ class QueryTest(TestBase): getattr(result, meth), ) trans.rollback() - + + def test_no_inserted_pk_on_non_insert(self): + result = testing.db.execute("select * from query_users") + assert_raises_message( + exc.InvalidRequestError, + r"Statement is not an insert\(\) expression construct.", + getattr, result, 'inserted_primary_key' + ) + + @testing.requires.returning + def test_no_inserted_pk_on_returning(self): + result = testing.db.execute(users.insert().returning(users.c.user_id, users.c.user_name)) + assert_raises_message( + exc.InvalidRequestError, + r"Can't call inserted_primary_key when returning\(\) is used.", + getattr, result, 'inserted_primary_key' + ) + def test_fetchone_til_end(self): result = testing.db.execute("select * from query_users") eq_(result.fetchone(), None)