- change to be represented as two very fast bulk_insert() and bulk_update() methods
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 14 Aug 2014 23:47:23 +0000 (19:47 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Fri, 15 Aug 2014 19:53:38 +0000 (15:53 -0400)
doc/build/faq.rst
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/unitofwork.py

diff --git a/doc/build/faq.rst b/doc/build/faq.rst
index b777f908fadf8d32469013b86f4753e19bb4669b..487f5b953a5695c655ba8fd0543c49a352285534 100644
@@ -907,12 +907,11 @@ methods of inserting rows, going from the most automated to the least.
 With cPython 2.7, runtimes observed::
 
     classics-MacBook-Pro:sqlalchemy classic$ python test.py
-    SQLAlchemy ORM: Total time for 100000 records 12.4703581333 secs
-    SQLAlchemy ORM pk given: Total time for 100000 records 7.32723999023 secs
-    SQLAlchemy ORM bulk_save_objects(): Total time for 100000 records 3.43464708328 secs
-    SQLAlchemy ORM bulk_save_mappings(): Total time for 100000 records 2.37040805817 secs
-    SQLAlchemy Core: Total time for 100000 records 0.495043992996 secs
-    sqlite3: Total time for 100000 records 0.508063077927 sec
+    SQLAlchemy ORM: Total time for 100000 records 12.0471920967 secs
+    SQLAlchemy ORM pk given: Total time for 100000 records 7.06283402443 secs
+    SQLAlchemy ORM bulk_save_objects(): Total time for 100000 records 0.856323003769 secs
+    SQLAlchemy Core: Total time for 100000 records 0.485800027847 secs
+    sqlite3: Total time for 100000 records 0.487842082977 sec
 
 We can reduce the time by a factor of three using recent versions of `Pypy <http://pypy.org/>`_::
 
@@ -980,15 +979,16 @@ Script::
             " records " + str(time.time() - t0) + " secs")
 
 
-    def test_sqlalchemy_orm_bulk_save(n=100000):
+    def test_sqlalchemy_orm_bulk_insert(n=100000):
         init_sqlalchemy()
         t0 = time.time()
         n1 = n
         while n1 > 0:
             n1 = n1 - 10000
-            DBSession.bulk_save_objects(
+            DBSession.bulk_insert_mappings(
+                Customer,
                 [
-                    Customer(name="NAME " + str(i))
+                    dict(name="NAME " + str(i))
                     for i in xrange(min(10000, n1))
                 ]
             )
@@ -998,22 +998,6 @@ Script::
             " records " + str(time.time() - t0) + " secs")
 
 
-    def test_sqlalchemy_orm_bulk_save_mappings(n=100000):
-        init_sqlalchemy()
-        t0 = time.time()
-        DBSession.bulk_save_mappings(
-            Customer,
-            [
-                dict(name="NAME " + str(i))
-                for i in xrange(n)
-            ]
-        )
-        DBSession.commit()
-        print(
-            "SQLAlchemy ORM bulk_save_mappings(): Total time for " + str(n) +
-            " records " + str(time.time() - t0) + " secs")
-
-
     def test_sqlalchemy_core(n=100000):
         init_sqlalchemy()
         t0 = time.time()
@@ -1052,8 +1036,7 @@ Script::
     if __name__ == '__main__':
         test_sqlalchemy_orm(100000)
         test_sqlalchemy_orm_pk_given(100000)
-        test_sqlalchemy_orm_bulk_save(100000)
-        test_sqlalchemy_orm_bulk_save_mappings(100000)
+        test_sqlalchemy_orm_bulk_insert(100000)
         test_sqlalchemy_core(100000)
         test_sqlite3(100000)
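
A minimal, self-contained sketch of the bulk_insert_mappings() call the
rewritten benchmark exercises; the Customer model and in-memory SQLite
engine here are illustrative stand-ins for the script's own setup::

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class Customer(Base):
        __tablename__ = 'customer'
        id = Column(Integer, primary_key=True)
        name = Column(String(255))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    # plain dicts keyed by mapped attribute name; no per-object
    # InstanceState bookkeeping is performed, which is where the
    # speedup over a plain ORM flush comes from
    session.bulk_insert_mappings(
        Customer,
        [dict(name="NAME " + str(i)) for i in range(10000)]
    )
    session.commit()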
 
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index 097726c625e4caeefd5fc928cfb3287fa54a7475..37ea3071bc3596ac28281f1c94f1bc72e0ed69e7 100644
@@ -1453,13 +1453,16 @@ class SessionEvents(event.Events):
 
         """
 
-    def before_bulk_save(self, session, flush_context, objects):
+    def before_bulk_insert(self, session, flush_context, mapper, mappings):
         """"""
 
-    def after_bulk_save(self, session, flush_context, objects):
+    def after_bulk_insert(self, session, flush_context, mapper, mappings):
         """"""
 
-    def after_bulk_save_postexec(self, session, flush_context, objects):
+    def before_bulk_update(self, session, flush_context, mapper, mappings):
+        """"""
+
+    def after_bulk_update(self, session, flush_context, mapper, mappings):
         """"""
 
     def after_begin(self, session, transaction, connection):
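
The four new hooks above are added with empty docstrings; their signatures
mirror the dispatch calls made in Session._bulk_save_mappings() further
down.  A sketch of attaching a listener through the standard event API
(the listener body is a hypothetical example)::

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    @event.listens_for(Session, "before_bulk_insert")
    def log_bulk_insert(session, flush_context, mapper, mappings):
        # "mappings" is passed through as-is, so it may be a generator;
        # only iterate or materialize it if the caller passed a list
        print("bulk insert for %s" % mapper)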
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 64c8440c4d6b60207e968efa997f22572fa47eef..a8d4bd695db240994043caa52d2fdf00f09b3c64 100644
@@ -23,9 +23,104 @@ from ..sql import expression
 from . import loading
 
 
+def bulk_insert(mapper, mappings, uowtransaction):
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if uowtransaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_insert()")
+
+    connection = uowtransaction.transaction.connection(base_mapper)
+
+    for table, sub_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(sub_mapper):
+            continue
+
+        to_translate = dict(
+            (mapper._columntoproperty[col].key, col.key)
+            for col in mapper._cols_by_table[table]
+        )
+        has_version_generator = mapper.version_id_generator is not False and \
+            mapper.version_id_col is not None
+        multiparams = []
+        for mapping in mappings:
+            params = dict(
+                (k, mapping.get(v)) for k, v in to_translate.items()
+            )
+            if has_version_generator:
+                params[mapper.version_id_col.key] = \
+                    mapper.version_id_generator(None)
+            multiparams.append(params)
+
+        statement = base_mapper._memo(('insert', table), table.insert)
+        cached_connections[connection].execute(statement, multiparams)
+
+
+def bulk_update(mapper, mappings, uowtransaction):
+    base_mapper = mapper.base_mapper
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    if uowtransaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_update()")
+
+    connection = uowtransaction.transaction.connection(base_mapper)
+
+    for table, sub_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(sub_mapper):
+            continue
+
+        needs_version_id = sub_mapper.version_id_col is not None and \
+            table.c.contains_column(sub_mapper.version_id_col)
+
+        def update_stmt():
+            return _update_stmt_for_mapper(sub_mapper, table, needs_version_id)
+
+        statement = base_mapper._memo(('update', table), update_stmt)
+
+        pks = mapper._pks_by_table[table]
+        to_translate = dict(
+            (mapper._columntoproperty[col].key, col._label
+                if col in pks else col.key)
+            for col in mapper._cols_by_table[table]
+        )
+
+        for colnames, sub_mappings in groupby(
+                mappings,
+                lambda mapping: sorted(tuple(mapping.keys()))):
+
+            multiparams = []
+            for mapping in sub_mappings:
+                params = dict(
+                    (to_translate[k], v) for k, v in mapping.items()
+                )
+                multiparams.append(params)
+
+            c = cached_connections[connection].execute(statement, multiparams)
+
+            rows = c.rowcount
+
+            if connection.dialect.supports_sane_rowcount:
+                if rows != len(multiparams):
+                    raise orm_exc.StaleDataError(
+                        "UPDATE statement on table '%s' expected to "
+                        "update %d row(s); %d were matched." %
+                        (table.description, len(multiparams), rows))
+
+            elif needs_version_id:
+                util.warn("Dialect %s does not support updated rowcount "
+                          "- versioning cannot be verified." %
+                          c.dialect.dialect_description,
+                          stacklevel=12)
+
+
 def save_obj(
-    base_mapper, states, uowtransaction, single=False,
-        bookkeeping=True):
+        base_mapper, states, uowtransaction, single=False):
     """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
     of objects.
 
@@ -45,14 +140,13 @@ def save_obj(
     states_to_insert, states_to_update = _organize_states_for_save(
         base_mapper,
         states,
-        uowtransaction, bookkeeping)
+        uowtransaction)
 
     cached_connections = _cached_connection_dict(base_mapper)
 
     for table, mapper in base_mapper._sorted_tables.items():
         insert = _collect_insert_commands(base_mapper, uowtransaction,
-                                          table, states_to_insert,
-                                          bookkeeping)
+                                          table, states_to_insert)
 
         update = _collect_update_commands(base_mapper, uowtransaction,
                                           table, states_to_update)
@@ -65,12 +159,11 @@ def save_obj(
         if insert:
             _emit_insert_statements(base_mapper, uowtransaction,
                                     cached_connections,
-                                    mapper, table, insert,
-                                    bookkeeping)
+                                    mapper, table, insert)
 
-    _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update,
-                                     bookkeeping)
+    _finalize_insert_update_commands(
+        base_mapper, uowtransaction,
+        states_to_insert, states_to_update)
 
 
 def post_update(base_mapper, states, uowtransaction, post_update_cols):
@@ -126,8 +219,7 @@ def delete_obj(base_mapper, states, uowtransaction):
         mapper.dispatch.after_delete(mapper, connection, state)
 
 
-def _organize_states_for_save(
-        base_mapper, states, uowtransaction, bookkeeping):
+def _organize_states_for_save(base_mapper, states, uowtransaction):
     """Make an initial pass across a set of states for INSERT or
     UPDATE.
 
@@ -149,8 +241,7 @@ def _organize_states_for_save(
 
         has_identity = bool(state.key)
 
-        if bookkeeping:
-            instance_key = state.key or mapper._identity_key_from_state(state)
+        instance_key = state.key or mapper._identity_key_from_state(state)
 
         row_switch = None
 
@@ -167,7 +258,7 @@ def _organize_states_for_save(
         # no instance_key attached to it), and another instance
         # with the same identity key already exists as persistent.
         # convert to an UPDATE if so.
-        if bookkeeping and not has_identity and \
+        if not has_identity and \
                 instance_key in uowtransaction.session.identity_map:
             instance = \
                 uowtransaction.session.identity_map[instance_key]
@@ -239,7 +330,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
 
 
 def _collect_insert_commands(base_mapper, uowtransaction, table,
-                             states_to_insert, bookkeeping):
+                             states_to_insert):
     """Identify sets of values to use in INSERT statements for a
     list of states.
 
@@ -270,7 +361,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
                 prop = mapper._columntoproperty[col]
                 value = state_dict.get(prop.key, None)
 
-                if bookkeeping and value is None:
+                if value is None:
                     if col in pks:
                         has_all_pks = False
                     elif col.default is None and \
@@ -481,6 +572,28 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
     return delete
 
 
+def _update_stmt_for_mapper(mapper, table, needs_version_id):
+    clause = sql.and_()
+
+    for col in mapper._pks_by_table[table]:
+        clause.clauses.append(col == sql.bindparam(col._label,
+                                                   type_=col.type))
+
+    if needs_version_id:
+        clause.clauses.append(
+            mapper.version_id_col == sql.bindparam(
+                mapper.version_id_col._label,
+                type_=mapper.version_id_col.type))
+
+    stmt = table.update(clause)
+    if mapper.base_mapper.eager_defaults:
+        stmt = stmt.return_defaults()
+    elif mapper.version_id_col is not None:
+        stmt = stmt.return_defaults(mapper.version_id_col)
+
+    return stmt
+
+
 def _emit_update_statements(base_mapper, uowtransaction,
                             cached_connections, mapper, table, update):
     """Emit UPDATE statements corresponding to value lists collected
@@ -490,25 +603,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
         table.c.contains_column(mapper.version_id_col)
 
     def update_stmt():
-        clause = sql.and_()
-
-        for col in mapper._pks_by_table[table]:
-            clause.clauses.append(col == sql.bindparam(col._label,
-                                                       type_=col.type))
-
-        if needs_version_id:
-            clause.clauses.append(
-                mapper.version_id_col == sql.bindparam(
-                    mapper.version_id_col._label,
-                    type_=mapper.version_id_col.type))
-
-        stmt = table.update(clause)
-        if mapper.base_mapper.eager_defaults:
-            stmt = stmt.return_defaults()
-        elif mapper.version_id_col is not None:
-            stmt = stmt.return_defaults(mapper.version_id_col)
-
-        return stmt
+        return _update_stmt_for_mapper(mapper, table, needs_version_id)
 
     statement = base_mapper._memo(('update', table), update_stmt)
 
@@ -572,8 +667,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
 
 
 def _emit_insert_statements(base_mapper, uowtransaction,
-                            cached_connections, mapper, table, insert,
-                            bookkeeping):
+                            cached_connections, mapper, table, insert):
     """Emit INSERT statements corresponding to value lists collected
     by _collect_insert_commands()."""
 
@@ -599,20 +693,19 @@ def _emit_insert_statements(base_mapper, uowtransaction,
             c = cached_connections[connection].\
                 execute(statement, multiparams)
 
-            if bookkeeping:
-                for (state, state_dict, params, mapper_rec,
-                        conn, value_params, has_all_pks, has_all_defaults), \
-                        last_inserted_params in \
-                        zip(records, c.context.compiled_parameters):
-                    _postfetch(
-                        mapper_rec,
-                        uowtransaction,
-                        table,
-                        state,
-                        state_dict,
-                        c,
-                        last_inserted_params,
-                        value_params)
+            for (state, state_dict, params, mapper_rec,
+                    conn, value_params, has_all_pks, has_all_defaults), \
+                    last_inserted_params in \
+                    zip(records, c.context.compiled_parameters):
+                _postfetch(
+                    mapper_rec,
+                    uowtransaction,
+                    table,
+                    state,
+                    state_dict,
+                    c,
+                    last_inserted_params,
+                    value_params)
 
         else:
             if not has_all_defaults and base_mapper.eager_defaults:
@@ -768,8 +861,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
 
 
 def _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update,
-                                     bookkeeping):
+                                     states_to_insert, states_to_update):
     """finalize state on states that have been inserted or updated,
     including calling after_insert/after_update events.
 
@@ -778,34 +870,33 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
             row_switch in states_to_insert + \
             states_to_update:
 
-        if bookkeeping:
-            if mapper._readonly_props:
-                readonly = state.unmodified_intersection(
-                    [p.key for p in mapper._readonly_props
-                        if p.expire_on_flush or p.key not in state.dict]
-                )
-                if readonly:
-                    state._expire_attributes(state.dict, readonly)
-
-            # if eager_defaults option is enabled, load
-            # all expired cols.  Else if we have a version_id_col, make sure
-            # it isn't expired.
-            toload_now = []
-
-            if base_mapper.eager_defaults:
-                toload_now.extend(state._unloaded_non_object)
-            elif mapper.version_id_col is not None and \
-                    mapper.version_id_generator is False:
-                prop = mapper._columntoproperty[mapper.version_id_col]
-                if prop.key in state.unloaded:
-                    toload_now.extend([prop.key])
-
-            if toload_now:
-                state.key = base_mapper._identity_key_from_state(state)
-                loading.load_on_ident(
-                    uowtransaction.session.query(base_mapper),
-                    state.key, refresh_state=state,
-                    only_load_props=toload_now)
+        if mapper._readonly_props:
+            readonly = state.unmodified_intersection(
+                [p.key for p in mapper._readonly_props
+                    if p.expire_on_flush or p.key not in state.dict]
+            )
+            if readonly:
+                state._expire_attributes(state.dict, readonly)
+
+        # if eager_defaults option is enabled, load
+        # all expired cols.  Else if we have a version_id_col, make sure
+        # it isn't expired.
+        toload_now = []
+
+        if base_mapper.eager_defaults:
+            toload_now.extend(state._unloaded_non_object)
+        elif mapper.version_id_col is not None and \
+                mapper.version_id_generator is False:
+            prop = mapper._columntoproperty[mapper.version_id_col]
+            if prop.key in state.unloaded:
+                toload_now.extend([prop.key])
+
+        if toload_now:
+            state.key = base_mapper._identity_key_from_state(state)
+            loading.load_on_ident(
+                uowtransaction.session.query(base_mapper),
+                state.key, refresh_state=state,
+                only_load_props=toload_now)
 
         # call after_XXX extensions
         if not has_identity:
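
One behavior of the new bulk_update() worth noting: it batches parameter
sets with itertools.groupby(), which only groups *consecutive* mappings
sharing the same key set.  Interleaving mappings with different keys
therefore yields more, smaller executemany batches, as this standalone
sketch with made-up data shows::

    from itertools import groupby

    mappings = [
        {"id": 1, "name": "a"},
        {"id": 2, "name": "b"},
        {"id": 3, "description": "c"},  # new key set starts a new batch
        {"id": 4, "name": "d"},         # id/name again -> a third batch
    ]

    for colnames, batch in groupby(mappings, lambda m: sorted(m.keys())):
        print(colnames, [m["id"] for m in batch])

    # ['id', 'name'] [1, 2]
    # ['description', 'id'] [3]
    # ['id', 'name'] [4]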
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 546355611a581bca526ae853f68f5cb9d99e2d78..3199a4332d08740437195b6194d81f54ff73cc9a 100644
@@ -20,6 +20,7 @@ from .base import (
     _class_to_mapper, _state_mapper, object_state,
     _none_set, state_str, instance_str
 )
+import itertools
 from .unitofwork import UOWTransaction
 from . import state as statelib
 import sys
@@ -482,7 +483,8 @@ class Session(_SessionClassMethods):
         '__contains__', '__iter__', 'add', 'add_all', 'begin', 'begin_nested',
         'close', 'commit', 'connection', 'delete', 'execute', 'expire',
         'expire_all', 'expunge', 'expunge_all', 'flush', 'get_bind',
-        'is_modified', 'bulk_save_objects', 'bulk_save_mappings',
+        'is_modified', 'bulk_save_objects', 'bulk_insert_mappings',
+        'bulk_update_mappings',
         'merge', 'query', 'refresh', 'rollback',
         'scalar')
 
@@ -2034,42 +2036,41 @@ class Session(_SessionClassMethods):
                 transaction.rollback(_capture_exception=True)
 
     def bulk_save_objects(self, objects):
-        self._bulk_save((attributes.instance_state(obj) for obj in objects))
+        for (mapper, isupdate), states in itertools.groupby(
+            (attributes.instance_state(obj) for obj in objects),
+            lambda state: (state.mapper, state.key is not None)
+        ):
+            if isupdate:
+                self.bulk_update_mappings(mapper, (s.dict for s in states))
+            else:
+                self.bulk_insert_mappings(mapper, (s.dict for s in states))
 
-    def bulk_save_mappings(self, mapper, mappings):
-        mapper = class_mapper(mapper)
+    def bulk_insert_mappings(self, mapper, mappings):
+        self._bulk_save_mappings(mapper, mappings, False)
 
-        self._bulk_save((
-            statelib.MappingState(mapper, mapping)
-            for mapping in mappings)
-        )
+    def bulk_update_mappings(self, mapper, mappings):
+        self._bulk_save_mappings(mapper, mappings, True)
 
-    def _bulk_save(self, states):
+    def _bulk_save_mappings(self, mapper, mappings, isupdate):
+        mapper = _class_to_mapper(mapper)
         self._flushing = True
         flush_context = UOWTransaction(self)
 
-        if self.dispatch.before_bulk_save:
-            self.dispatch.before_bulk_save(
-                self, flush_context, states)
-
         flush_context.transaction = transaction = self.begin(
             subtransactions=True)
         try:
-            self._warn_on_events = True
-            try:
-                flush_context.bulk_save(states)
-            finally:
-                self._warn_on_events = False
-
-            self.dispatch.after_bulk_save(
-                self, flush_context, states
-            )
-
-            flush_context.finalize_flush_changes()
-
-            self.dispatch.after_bulk_save_postexec(
-                self, flush_context, states)
-
+            if isupdate:
+                self.dispatch.before_bulk_update(
+                    self, flush_context, mapper, mappings)
+                flush_context.bulk_update(mapper, mappings)
+                self.dispatch.after_bulk_update(
+                    self, flush_context, mapper, mappings)
+            else:
+                self.dispatch.before_bulk_insert(
+                    self, flush_context, mapper, mappings)
+                flush_context.bulk_insert(mapper, mappings)
+                self.dispatch.after_bulk_insert(
+                    self, flush_context, mapper, mappings)
             transaction.commit()
 
         except:
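
For completeness, a sketch of the update path as a caller sees it, reusing
the illustrative Customer model and session from the earlier sketch; each
mapping must carry the primary key so the generated UPDATE can match its
row::

    session.bulk_update_mappings(
        Customer,
        [
            dict(id=1, name="NEW NAME 1"),
            dict(id=2, name="NEW NAME 2"),
        ]
    )
    session.commit()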
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index e941bc1a4713ef074979e37dce43212cdd3be6a5..fe8ccd222b4eec4cc9b5a5d8698ccb8e78e9e4c9 100644
@@ -580,21 +580,6 @@ class InstanceState(interfaces.InspectionAttr):
             state._strong_obj = None
 
 
-class MappingState(InstanceState):
-    committed_state = {}
-    callables = {}
-
-    def __init__(self, mapper, mapping):
-        self.class_ = mapper.class_
-        self.manager = mapper.class_manager
-        self.modified = True
-        self._dict = mapping
-
-    @property
-    def dict(self):
-        return self._dict
-
-
 class AttributeState(object):
     """Provide an inspection interface corresponding
     to a particular attribute on a particular mapped object.
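
With MappingState removed, plain dictionaries no longer masquerade as
states; only real InstanceState objects flow through bulk_save_objects().
A sketch of what such a state exposes, via the public inspect() API and
the illustrative Customer model from above::

    from sqlalchemy import inspect

    c = Customer(name="NAME 1")
    state = inspect(c)              # the InstanceState for this instance
    print(state.dict)               # {'name': 'NAME 1'}: what bulk_save_objects() forwards
    print(state.key is not None)    # False -> routed to the INSERT path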
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index bc8a0f5565048bb195bcf601aa10f07d93ca6dce..b3a1519c5299960beaeb63de5c4db3b00abb3026 100644
@@ -394,23 +394,11 @@ class UOWTransaction(object):
         if other:
             self.session._register_newly_persistent(other)
 
-    def bulk_save(self, states):
-        for (base_mapper, in_session), states_ in itertools.groupby(
-                states,
-                lambda state:
-                (
-                    state.mapper.base_mapper,
-                    state.key is self.session.hash_key
-                )):
-
-            persistence.save_obj(
-                base_mapper, list(states_), self, bookkeeping=in_session)
-
-            if in_session:
-                self.states.update(
-                    (state, (False, False))
-                    for state in states_
-                )
+    def bulk_insert(self, mapper, mappings):
+        persistence.bulk_insert(mapper, mappings, self)
+
+    def bulk_update(self, mapper, mappings):
+        persistence.bulk_update(mapper, mappings, self)
 
 
 class IterateMappersMixin(object):
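
The UOWTransaction methods are now thin pass-throughs to the persistence
module.  End to end, bulk_save_objects() routes each consecutive run of
states by mapper and by whether the object already has an identity key;
a sketch using the illustrative model from above::

    existing = session.query(Customer).get(1)   # persistent -> UPDATE path
    existing.name = "UPDATED NAME"
    fresh = Customer(name="BRAND NEW")          # transient -> INSERT path

    # the grouping is itertools.groupby(), so keeping inserts and updates
    # contiguous yields the fewest statement batches
    session.bulk_save_objects([existing, fresh])
    session.commit()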