From: Mike Bayer
Date: Fri, 15 Aug 2014 22:22:08 +0000 (-0400)
Subject: - refine this enough so that _collect_insert_commands() seems
X-Git-Tag: rel_1_0_0b1~217
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8773307257550e86801217f2b77d47047718807a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

- refine this enough so that _collect_insert_commands() seems
to be more than twice as fast now (.039 vs. .091); bulk_insert()
and bulk_update() do their own collection but now both call into
_emit_insert_statements() / _emit_update_statements(); the approach
seems to have no impact on insert speed, still .85 for the insert test
---

diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 06ec2bf144..fc15769cd4 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1892,6 +1892,41 @@ class Mapper(InspectionAttr):
 
     """
 
+    @_memoized_configured_property
+    def _col_to_propkey(self):
+        return dict(
+            (
+                table,
+                [
+                    (col, self._columntoproperty[col].key)
+                    for col in columns
+                ]
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
+    @_memoized_configured_property
+    def _pk_keys_by_table(self):
+        return dict(
+            (
+                table,
+                frozenset([col.key for col in pks])
+            )
+            for table, pks in self._pks_by_table.items()
+        )
+
+    @_memoized_configured_property
+    def _server_default_cols(self):
+        return dict(
+            (
+                table,
+                frozenset([
+                    col for col in columns
+                    if col.server_default is not None])
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
     @property
     def selectable(self):
         """The :func:`.select` construct this :class:`.Mapper` selects from
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index a8d4bd695d..782d94dc8a 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -34,29 +34,35 @@ def bulk_insert(mapper, mappings, uowtransaction):
             "not supported in bulk_insert()")
 
     connection = uowtransaction.transaction.connection(base_mapper)
-
+    value_params = {}
     for table, sub_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(sub_mapper):
             continue
 
-        to_translate = dict(
-            (mapper._columntoproperty[col].key, col.key)
-            for col in mapper._cols_by_table[table]
-        )
         has_version_generator = mapper.version_id_generator is not False and \
             mapper.version_id_col is not None
-        multiparams = []
+
+        records = []
         for mapping in mappings:
             params = dict(
-                (k, mapping.get(v)) for k, v in to_translate.items()
+                (col.key, mapping[propkey])
+                for col, propkey in mapper._col_to_propkey[table]
+                if propkey in mapping
             )
+
             if has_version_generator:
                 params[mapper.version_id_col.key] = \
                     mapper.version_id_generator(None)
-            multiparams.append(params)
 
-        statement = base_mapper._memo(('insert', table), table.insert)
-        cached_connections[connection].execute(statement, multiparams)
+            records.append(
+                (None, None, params, sub_mapper,
+                 connection, value_params, True, True)
+            )
+
+        _emit_insert_statements(base_mapper, uowtransaction,
+                                cached_connections,
+                                mapper, table, records,
+                                bookkeeping=False)
 
 
 def bulk_update(mapper, mappings, uowtransaction):
@@ -71,52 +77,41 @@ def bulk_update(mapper, mappings, uowtransaction):
 
     connection = uowtransaction.transaction.connection(base_mapper)
 
+    value_params = {}
     for table, sub_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(sub_mapper):
             continue
 
-        needs_version_id = sub_mapper.version_id_col is not None and \
-            table.c.contains_column(sub_mapper.version_id_col)
-
-        def update_stmt():
-            return _update_stmt_for_mapper(sub_mapper, table, needs_version_id)
-
-        statement = base_mapper._memo(('update', table), update_stmt)
+        label_pks = mapper._pks_by_table[table]
+        if mapper.version_id_col is not None:
+            label_pks = label_pks.union([mapper.version_id_col])
 
-        pks = mapper._pks_by_table[table]
         to_translate = dict(
-            (mapper._columntoproperty[col].key, col._label
-             if col in pks else col.key)
-            for col in mapper._cols_by_table[table]
+            (propkey, col._label if col in label_pks else col.key)
+            for col, propkey in mapper._col_to_propkey[table]
        )
 
-        for colnames, sub_mappings in groupby(
-                mappings,
-                lambda mapping: sorted(tuple(mapping.keys()))):
-
-            multiparams = []
-            for mapping in sub_mappings:
-                params = dict(
-                    (to_translate[k], v) for k, v in mapping.items()
-                )
-                multiparams.append(params)
-
-            c = cached_connections[connection].execute(statement, multiparams)
+        records = []
+        for mapping in mappings:
+            params = dict(
+                (to_translate[k], v) for k, v in mapping.items()
+            )
 
-            rows = c.rowcount
+            if mapper.version_id_generator is not False and \
+                    mapper.version_id_col is not None and \
+                    mapper.version_id_col.key not in params:
+                params[mapper.version_id_col.key] = \
+                    mapper.version_id_generator(
+                        params[mapper.version_id_col._label])
 
-            if connection.dialect.supports_sane_rowcount:
-                if rows != len(multiparams):
-                    raise orm_exc.StaleDataError(
-                        "UPDATE statement on table '%s' expected to "
-                        "update %d row(s); %d were matched." %
-                        (table.description, len(multiparams), rows))
+            records.append(
+                (None, None, params, sub_mapper, connection, value_params)
+            )
 
-            elif needs_version_id:
-                util.warn("Dialect %s does not support updated rowcount "
-                          "- versioning cannot be verified." %
-                          c.dialect.dialect_description,
-                          stacklevel=12)
+        _emit_update_statements(base_mapper, uowtransaction,
+                                cached_connections,
+                                mapper, table, records,
+                                bookkeeping=False)
 
 
 def save_obj(
@@ -342,39 +337,36 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
         if table not in mapper._pks_by_table:
             continue
 
-        pks = mapper._pks_by_table[table]
-
         params = {}
         value_params = {}
-
-        has_all_pks = True
-        has_all_defaults = True
-        has_version_id_generator = mapper.version_id_generator is not False \
-            and mapper.version_id_col is not None
-        for col in mapper._cols_by_table[table]:
-            if has_version_id_generator and col is mapper.version_id_col:
-                val = mapper.version_id_generator(None)
-                params[col.key] = val
+        for col, propkey in mapper._col_to_propkey[table]:
+            if propkey in state_dict:
+                value = state_dict[propkey]
+                if isinstance(value, sql.ClauseElement):
+                    value_params[col.key] = value
+                elif value is not None or (
+                        not col.primary_key and
+                        not col.server_default and
+                        not col.default):
+                    params[col.key] = value
             else:
-                # pull straight from the dict for
-                # pending objects
-                prop = mapper._columntoproperty[col]
-                value = state_dict.get(prop.key, None)
+                if not col.server_default \
+                        and not col.default and not col.primary_key:
+                    params[col.key] = None
 
-                if value is None:
-                    if col in pks:
-                        has_all_pks = False
-                    elif col.default is None and \
-                            col.server_default is None:
-                        params[col.key] = value
-                    elif col.server_default is not None and \
-                            mapper.base_mapper.eager_defaults:
-                        has_all_defaults = False
+        has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
 
-                elif isinstance(value, sql.ClauseElement):
-                    value_params[col] = value
-                else:
-                    params[col.key] = value
+        if base_mapper.eager_defaults:
+            has_all_defaults = mapper._server_default_cols[table].\
+                issubset(params)
+        else:
+            has_all_defaults = True
+
+        if mapper.version_id_generator is not False \
+                and mapper.version_id_col is not None and \
+                mapper.version_id_col in mapper._cols_by_table[table]:
+            params[mapper.version_id_col.key] = \
+                mapper.version_id_generator(None)
 
         insert.append((state, state_dict, params, mapper,
                        connection, value_params, has_all_pks,
@@ -572,30 +564,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
     return delete
 
 
-def _update_stmt_for_mapper(mapper, table, needs_version_id):
-    clause = sql.and_()
-
-    for col in mapper._pks_by_table[table]:
-        clause.clauses.append(col == sql.bindparam(col._label,
-                                                   type_=col.type))
-
-    if needs_version_id:
-        clause.clauses.append(
-            mapper.version_id_col == sql.bindparam(
-                mapper.version_id_col._label,
-                type_=mapper.version_id_col.type))
-
-    stmt = table.update(clause)
-    if mapper.base_mapper.eager_defaults:
-        stmt = stmt.return_defaults()
-    elif mapper.version_id_col is not None:
-        stmt = stmt.return_defaults(mapper.version_id_col)
-
-    return stmt
-
-
 def _emit_update_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, update):
+                            cached_connections, mapper, table, update,
+                            bookkeeping=True):
     """Emit UPDATE statements corresponding to value lists collected
     by _collect_update_commands()."""
 
@@ -603,7 +574,25 @@ def _emit_update_statements(base_mapper, uowtransaction,
         table.c.contains_column(mapper.version_id_col)
 
     def update_stmt():
-        return _update_stmt_for_mapper(mapper, table, needs_version_id)
+        clause = sql.and_()
+
+        for col in mapper._pks_by_table[table]:
+            clause.clauses.append(col == sql.bindparam(col._label,
+                                                       type_=col.type))
+
+        if needs_version_id:
+            clause.clauses.append(
+                mapper.version_id_col == sql.bindparam(
+                    mapper.version_id_col._label,
+                    type_=mapper.version_id_col.type))
+
+        stmt = table.update(clause)
+        if mapper.base_mapper.eager_defaults:
+            stmt = stmt.return_defaults()
+        elif mapper.version_id_col is not None:
+            stmt = stmt.return_defaults(mapper.version_id_col)
+
+        return stmt
 
     statement = base_mapper._memo(('update', table), update_stmt)
 
@@ -624,15 +613,16 @@ def _emit_update_statements(base_mapper, uowtransaction,
                 c = connection.execute(
                     statement.values(value_params),
                     params)
-                _postfetch(
-                    mapper,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    c.context.compiled_parameters[0],
-                    value_params)
+                if bookkeeping:
+                    _postfetch(
+                        mapper,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        c.context.compiled_parameters[0],
+                        value_params)
                 rows += c.rowcount
         else:
             multiparams = [rec[2] for rec in records]
@@ -640,17 +630,18 @@ def _emit_update_statements(base_mapper, uowtransaction,
                 execute(statement, multiparams)
 
             rows += c.rowcount
-            for state, state_dict, params, mapper, \
-                    connection, value_params in records:
-                _postfetch(
-                    mapper,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    c.context.compiled_parameters[0],
-                    value_params)
+            if bookkeeping:
+                for state, state_dict, params, mapper, \
+                        connection, value_params in records:
+                    _postfetch(
+                        mapper,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        c.context.compiled_parameters[0],
+                        value_params)
 
         if connection.dialect.supports_sane_rowcount:
             if rows != len(records):
@@ -667,7 +658,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
 
 def _emit_insert_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, insert):
+                            cached_connections, mapper, table, insert,
+                            bookkeeping=True):
     """Emit INSERT statements corresponding to value lists collected
     by _collect_insert_commands()."""
 
@@ -676,11 +668,11 @@ def _emit_insert_statements(base_mapper, uowtransaction,
     for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
         records in groupby(insert,
                            lambda rec: (rec[4],
-                                        list(rec[2].keys()),
+                                        tuple(sorted(rec[2].keys())),
                                         bool(rec[5]),
                                         rec[6], rec[7])
                            ):
-        if \
+        if not bookkeeping or \
             (
                 has_all_defaults
                 or not base_mapper.eager_defaults
@@ -693,19 +685,20 @@ def _emit_insert_statements(base_mapper, uowtransaction,
             c = cached_connections[connection].\
                 execute(statement, multiparams)
 
-            for (state, state_dict, params, mapper_rec,
-                    conn, value_params, has_all_pks, has_all_defaults), \
-                    last_inserted_params in \
-                    zip(records, c.context.compiled_parameters):
-                _postfetch(
-                    mapper_rec,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    last_inserted_params,
-                    value_params)
+            if bookkeeping:
+                for (state, state_dict, params, mapper_rec,
+                        conn, value_params, has_all_pks, has_all_defaults), \
+                        last_inserted_params in \
+                        zip(records, c.context.compiled_parameters):
+                    _postfetch(
+                        mapper_rec,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        last_inserted_params,
+                        value_params)
 
         else:
            if not has_all_defaults and base_mapper.eager_defaults:
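(Illustrative note, not part of the patch above.) The shape of the change is easier to see outside the ORM internals: per-table metadata lookups (column -> property key, primary-key key names, server-default columns) are memoized on the mapper, so per-row parameter collection in _collect_insert_commands() reduces to a single dict comprehension over precomputed pairs, and bulk_insert() / bulk_update() build their records and hand them to the same _emit_insert_statements() / _emit_update_statements() paths, with a bookkeeping flag that skips the _postfetch() work. The sketch below mimics that pattern in plain Python; ToyMapper, collect_insert_params and emit_insert are hypothetical stand-in names for illustration, not SQLAlchemy API.

# Illustrative sketch only -- hypothetical names, not SQLAlchemy API.

class ToyMapper(object):
    """Per-table metadata holder; the column -> property-key pairs are
    built once and cached, mirroring _memoized_configured_property."""

    def __init__(self, columns_to_propkeys):
        # e.g. {"user_id": "id", "user_name": "name"}
        self._columns_to_propkeys = columns_to_propkeys
        self._cached_pairs = None

    @property
    def col_to_propkey(self):
        # memoized: computed on first access, reused for every row
        if self._cached_pairs is None:
            self._cached_pairs = list(self._columns_to_propkeys.items())
        return self._cached_pairs


def collect_insert_params(mapper, state_dict):
    """Per-row collection: one dict comprehension over the precomputed
    pairs instead of a metadata lookup per column per row."""
    return dict(
        (col, state_dict[propkey])
        for col, propkey in mapper.col_to_propkey
        if propkey in state_dict
    )


def emit_insert(records, bookkeeping=True):
    """Single emit path shared by the regular flush and the bulk form;
    bulk callers pass bookkeeping=False to skip post-fetch work."""
    for params in records:
        print("INSERT", sorted(params.items()))  # stand-in for execute()
        if bookkeeping:
            print("  ... post-fetch / attribute bookkeeping")


if __name__ == '__main__':
    mapper = ToyMapper({"user_id": "id", "user_name": "name"})
    rows = [{"id": 1, "name": "ed"}, {"id": 2, "name": "wendy"}]
    records = [collect_insert_params(mapper, row) for row in rows]
    emit_insert(records, bookkeeping=False)

Hoisting the _columntoproperty lookups out of the per-row loop appears to be where the "more than twice as fast" collection number in the commit message comes from, while routing the bulk paths through the shared emit functions with bookkeeping=False is presumably why the regular insert benchmark stays at .85.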