git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refine this enough so that _collect_insert_commands() seems
author Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 22:22:08 +0000 (18:22 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 22:22:08 +0000 (18:22 -0400)
to be more than twice as fast now (.039 vs. .091); bulk_insert()
and bulk_update() do their own collection but now both call into
_emit_insert_statements() / _emit_update_statements(); the approach
seems to have no impact on insert speed, still .85 for the
insert test

lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py

index 06ec2bf144c64ed50ddc4a117d36808824cefa35..fc15769cd48d8c8fa0159ea798bf67f09cb0dd28 100644 (file)
@@ -1892,6 +1892,41 @@ class Mapper(InspectionAttr):
 
     """
 
+    @_memoized_configured_property
+    def _col_to_propkey(self):
+        return dict(
+            (
+                table,
+                [
+                    (col, self._columntoproperty[col].key)
+                    for col in columns
+                ]
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
+    @_memoized_configured_property
+    def _pk_keys_by_table(self):
+        return dict(
+            (
+                table,
+                frozenset([col.key for col in pks])
+            )
+            for table, pks in self._pks_by_table.items()
+        )
+
+    @_memoized_configured_property
+    def _server_default_cols(self):
+        return dict(
+            (
+                table,
+                frozenset([
+                    col for col in columns
+                    if col.server_default is not None])
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
     @property
     def selectable(self):
         """The :func:`.select` construct this :class:`.Mapper` selects from
index a8d4bd695db240994043caa52d2fdf00f09b3c64..782d94dc8a07072c620b40225816176d790c1308 100644 (file)
@@ -34,29 +34,35 @@ def bulk_insert(mapper, mappings, uowtransaction):
             "not supported in bulk_insert()")
 
     connection = uowtransaction.transaction.connection(base_mapper)
-
+    value_params = {}
     for table, sub_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(sub_mapper):
             continue
 
-        to_translate = dict(
-            (mapper._columntoproperty[col].key, col.key)
-            for col in mapper._cols_by_table[table]
-        )
         has_version_generator = mapper.version_id_generator is not False and \
             mapper.version_id_col is not None
-        multiparams = []
+
+        records = []
         for mapping in mappings:
             params = dict(
-                (k, mapping.get(v)) for k, v in to_translate.items()
+                (col.key, mapping[propkey])
+                for col, propkey in mapper._col_to_propkey[table]
+                if propkey in mapping
             )
+
             if has_version_generator:
                 params[mapper.version_id_col.key] = \
                     mapper.version_id_generator(None)
-            multiparams.append(params)
 
-        statement = base_mapper._memo(('insert', table), table.insert)
-        cached_connections[connection].execute(statement, multiparams)
+            records.append(
+                (None, None, params, sub_mapper,
+                    connection, value_params, True, True)
+            )
+
+        _emit_insert_statements(base_mapper, uowtransaction,
+                                cached_connections,
+                                mapper, table, records,
+                                bookkeeping=False)
 
 
 def bulk_update(mapper, mappings, uowtransaction):
@@ -71,52 +77,41 @@ def bulk_update(mapper, mappings, uowtransaction):
 
     connection = uowtransaction.transaction.connection(base_mapper)
 
+    value_params = {}
     for table, sub_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(sub_mapper):
             continue
 
-        needs_version_id = sub_mapper.version_id_col is not None and \
-            table.c.contains_column(sub_mapper.version_id_col)
-
-        def update_stmt():
-            return _update_stmt_for_mapper(sub_mapper, table, needs_version_id)
-
-        statement = base_mapper._memo(('update', table), update_stmt)
+        label_pks = mapper._pks_by_table[table]
+        if mapper.version_id_col is not None:
+            label_pks = label_pks.union([mapper.version_id_col])
 
-        pks = mapper._pks_by_table[table]
         to_translate = dict(
-            (mapper._columntoproperty[col].key, col._label
-                if col in pks else col.key)
-            for col in mapper._cols_by_table[table]
+            (propkey, col._label if col in label_pks else col.key)
+            for col, propkey in mapper._col_to_propkey[table]
         )
 
-        for colnames, sub_mappings in groupby(
-                mappings,
-                lambda mapping: sorted(tuple(mapping.keys()))):
-
-            multiparams = []
-            for mapping in sub_mappings:
-                params = dict(
-                    (to_translate[k], v) for k, v in mapping.items()
-                )
-                multiparams.append(params)
-
-            c = cached_connections[connection].execute(statement, multiparams)
+        records = []
+        for mapping in mappings:
+            params = dict(
+                (to_translate[k], v) for k, v in mapping.items()
+            )
 
-            rows = c.rowcount
+            if mapper.version_id_generator is not False and \
+                    mapper.version_id_col is not None and \
+                    mapper.version_id_col.key not in params:
+                params[mapper.version_id_col.key] = \
+                    mapper.version_id_generator(
+                        params[mapper.version_id_col._label])
 
-            if connection.dialect.supports_sane_rowcount:
-                if rows != len(multiparams):
-                    raise orm_exc.StaleDataError(
-                        "UPDATE statement on table '%s' expected to "
-                        "update %d row(s); %d were matched." %
-                        (table.description, len(multiparams), rows))
+            records.append(
+                (None, None, params, sub_mapper, connection, value_params)
+            )
 
-            elif needs_version_id:
-                util.warn("Dialect %s does not support updated rowcount "
-                          "- versioning cannot be verified." %
-                          c.dialect.dialect_description,
-                          stacklevel=12)
+        _emit_update_statements(base_mapper, uowtransaction,
+                                cached_connections,
+                                mapper, table, records,
+                                bookkeeping=False)
 
 
 def save_obj(
@@ -342,39 +337,36 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
         if table not in mapper._pks_by_table:
             continue
 
-        pks = mapper._pks_by_table[table]
-
         params = {}
         value_params = {}
-
-        has_all_pks = True
-        has_all_defaults = True
-        has_version_id_generator = mapper.version_id_generator is not False \
-            and mapper.version_id_col is not None
-        for col in mapper._cols_by_table[table]:
-            if has_version_id_generator and col is mapper.version_id_col:
-                val = mapper.version_id_generator(None)
-                params[col.key] = val
+        for col, propkey in mapper._col_to_propkey[table]:
+            if propkey in state_dict:
+                value = state_dict[propkey]
+                if isinstance(value, sql.ClauseElement):
+                    value_params[col.key] = value
+                elif value is not None or (
+                        not col.primary_key and
+                        not col.server_default and
+                        not col.default):
+                    params[col.key] = value
             else:
-                # pull straight from the dict for
-                # pending objects
-                prop = mapper._columntoproperty[col]
-                value = state_dict.get(prop.key, None)
+                if not col.server_default \
+                        and not col.default and not col.primary_key:
+                    params[col.key] = None
 
-                if value is None:
-                    if col in pks:
-                        has_all_pks = False
-                    elif col.default is None and \
-                            col.server_default is None:
-                        params[col.key] = value
-                    elif col.server_default is not None and \
-                            mapper.base_mapper.eager_defaults:
-                        has_all_defaults = False
+        has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
 
-                elif isinstance(value, sql.ClauseElement):
-                    value_params[col] = value
-                else:
-                    params[col.key] = value
+        if base_mapper.eager_defaults:
+            has_all_defaults = mapper._server_default_cols[table].\
+                issubset(params)
+        else:
+            has_all_defaults = True
+
+        if mapper.version_id_generator is not False \
+                and mapper.version_id_col is not None and \
+                mapper.version_id_col in mapper._cols_by_table[table]:
+            params[mapper.version_id_col.key] = \
+                mapper.version_id_generator(None)
 
         insert.append((state, state_dict, params, mapper,
                        connection, value_params, has_all_pks,
@@ -572,30 +564,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
     return delete
 
 
-def _update_stmt_for_mapper(mapper, table, needs_version_id):
-    clause = sql.and_()
-
-    for col in mapper._pks_by_table[table]:
-        clause.clauses.append(col == sql.bindparam(col._label,
-                                                   type_=col.type))
-
-    if needs_version_id:
-        clause.clauses.append(
-            mapper.version_id_col == sql.bindparam(
-                mapper.version_id_col._label,
-                type_=mapper.version_id_col.type))
-
-    stmt = table.update(clause)
-    if mapper.base_mapper.eager_defaults:
-        stmt = stmt.return_defaults()
-    elif mapper.version_id_col is not None:
-        stmt = stmt.return_defaults(mapper.version_id_col)
-
-    return stmt
-
-
 def _emit_update_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, update):
+                            cached_connections, mapper, table, update,
+                            bookkeeping=True):
     """Emit UPDATE statements corresponding to value lists collected
     by _collect_update_commands()."""
 
@@ -603,7 +574,25 @@ def _emit_update_statements(base_mapper, uowtransaction,
         table.c.contains_column(mapper.version_id_col)
 
     def update_stmt():
-        return _update_stmt_for_mapper(mapper, table, needs_version_id)
+        clause = sql.and_()
+
+        for col in mapper._pks_by_table[table]:
+            clause.clauses.append(col == sql.bindparam(col._label,
+                                                       type_=col.type))
+
+        if needs_version_id:
+            clause.clauses.append(
+                mapper.version_id_col == sql.bindparam(
+                    mapper.version_id_col._label,
+                    type_=mapper.version_id_col.type))
+
+        stmt = table.update(clause)
+        if mapper.base_mapper.eager_defaults:
+            stmt = stmt.return_defaults()
+        elif mapper.version_id_col is not None:
+            stmt = stmt.return_defaults(mapper.version_id_col)
+
+        return stmt
 
     statement = base_mapper._memo(('update', table), update_stmt)
 
@@ -624,15 +613,16 @@ def _emit_update_statements(base_mapper, uowtransaction,
                 c = connection.execute(
                     statement.values(value_params),
                     params)
-                _postfetch(
-                    mapper,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    c.context.compiled_parameters[0],
-                    value_params)
+                if bookkeeping:
+                    _postfetch(
+                        mapper,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        c.context.compiled_parameters[0],
+                        value_params)
                 rows += c.rowcount
         else:
             multiparams = [rec[2] for rec in records]
@@ -640,17 +630,18 @@ def _emit_update_statements(base_mapper, uowtransaction,
                 execute(statement, multiparams)
 
             rows += c.rowcount
-            for state, state_dict, params, mapper, \
-                    connection, value_params in records:
-                _postfetch(
-                    mapper,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    c.context.compiled_parameters[0],
-                    value_params)
+            if bookkeeping:
+                for state, state_dict, params, mapper, \
+                        connection, value_params in records:
+                    _postfetch(
+                        mapper,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        c.context.compiled_parameters[0],
+                        value_params)
 
         if connection.dialect.supports_sane_rowcount:
             if rows != len(records):
@@ -667,7 +658,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
 
 def _emit_insert_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, insert):
+                            cached_connections, mapper, table, insert,
+                            bookkeeping=True):
     """Emit INSERT statements corresponding to value lists collected
     by _collect_insert_commands()."""
 
@@ -676,11 +668,11 @@ def _emit_insert_statements(base_mapper, uowtransaction,
     for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
         records in groupby(insert,
                            lambda rec: (rec[4],
-                                        list(rec[2].keys()),
+                                        tuple(sorted(rec[2].keys())),
                                         bool(rec[5]),
                                         rec[6], rec[7])
                            ):
-        if \
+        if not bookkeeping or \
                 (
                     has_all_defaults
                     or not base_mapper.eager_defaults
@@ -693,19 +685,20 @@ def _emit_insert_statements(base_mapper, uowtransaction,
             c = cached_connections[connection].\
                 execute(statement, multiparams)
 
-            for (state, state_dict, params, mapper_rec,
-                    conn, value_params, has_all_pks, has_all_defaults), \
-                    last_inserted_params in \
-                    zip(records, c.context.compiled_parameters):
-                _postfetch(
-                    mapper_rec,
-                    uowtransaction,
-                    table,
-                    state,
-                    state_dict,
-                    c,
-                    last_inserted_params,
-                    value_params)
+            if bookkeeping:
+                for (state, state_dict, params, mapper_rec,
+                        conn, value_params, has_all_pks, has_all_defaults), \
+                        last_inserted_params in \
+                        zip(records, c.context.compiled_parameters):
+                    _postfetch(
+                        mapper_rec,
+                        uowtransaction,
+                        table,
+                        state,
+                        state_dict,
+                        c,
+                        last_inserted_params,
+                        value_params)
 
         else:
             if not has_all_defaults and base_mapper.eager_defaults: