]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- proof of concept
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Aug 2014 19:38:30 +0000 (15:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 19:52:35 +0000 (15:52 -0400)
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py

index aa99673badaa7ef8d0427d069b4c846943a9acfe..097726c625e4caeefd5fc928cfb3287fa54a7475 100644 (file)
@@ -1453,6 +1453,15 @@ class SessionEvents(event.Events):
 
         """
 
+    def before_bulk_save(self, session, flush_context, objects):
+        """Execute before the objects in a bulk save are flushed."""
+
+    def after_bulk_save(self, session, flush_context, objects):
+        """Execute after the objects in a bulk save have been flushed."""
+
+    def after_bulk_save_postexec(self, session, flush_context, objects):
+        """Execute after bulk save results are finalized in the session."""
+
     def after_begin(self, session, transaction, connection):
         """Execute after a transaction is begun on a connection
 
index 9d39c39b07fba2baea0029d2180f57d6f24d56e4..511a324bee2e7540e25725877ba5ebbd4f12cc61 100644 (file)
@@ -23,7 +23,9 @@ from ..sql import expression
 from . import loading
 
 
-def save_obj(base_mapper, states, uowtransaction, single=False):
+def save_obj(
+    base_mapper, states, uowtransaction, single=False,
+        bookkeeping=True):
     """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
     of objects.
 
@@ -43,13 +45,14 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
     states_to_insert, states_to_update = _organize_states_for_save(
         base_mapper,
         states,
-        uowtransaction)
+        uowtransaction, bookkeeping)
 
     cached_connections = _cached_connection_dict(base_mapper)
 
     for table, mapper in base_mapper._sorted_tables.items():
         insert = _collect_insert_commands(base_mapper, uowtransaction,
-                                          table, states_to_insert)
+                                          table, states_to_insert,
+                                          bookkeeping)
 
         update = _collect_update_commands(base_mapper, uowtransaction,
                                           table, states_to_update)
@@ -65,7 +68,8 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
                                     mapper, table, insert)
 
     _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update)
+                                     states_to_insert, states_to_update,
+                                     bookkeeping)
 
 
 def post_update(base_mapper, states, uowtransaction, post_update_cols):
@@ -121,7 +125,8 @@ def delete_obj(base_mapper, states, uowtransaction):
         mapper.dispatch.after_delete(mapper, connection, state)
 
 
-def _organize_states_for_save(base_mapper, states, uowtransaction):
+def _organize_states_for_save(
+        base_mapper, states, uowtransaction, bookkeeping):
     """Make an initial pass across a set of states for INSERT or
     UPDATE.
 
@@ -158,7 +163,7 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
         # no instance_key attached to it), and another instance
         # with the same identity key already exists as persistent.
         # convert to an UPDATE if so.
-        if not has_identity and \
+        if bookkeeping and not has_identity and \
                 instance_key in uowtransaction.session.identity_map:
             instance = \
                 uowtransaction.session.identity_map[instance_key]
@@ -230,7 +235,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
 
 
 def _collect_insert_commands(base_mapper, uowtransaction, table,
-                             states_to_insert):
+                             states_to_insert, bookkeeping):
     """Identify sets of values to use in INSERT statements for a
     list of states.
 
@@ -261,12 +266,12 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
                 value = state_dict.get(prop.key, None)
 
                 if value is None:
-                    if col in pks:
+                    if bookkeeping and col in pks:
                         has_all_pks = False
                     elif col.default is None and \
                             col.server_default is None:
                         params[col.key] = value
-                    elif col.server_default is not None and \
+                    elif bookkeeping and col.server_default is not None and \
                             mapper.base_mapper.eager_defaults:
                         has_all_defaults = False
 
@@ -756,7 +761,8 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
 
 
 def _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update):
+                                     states_to_insert, states_to_update,
+                                     bookkeeping):
     """finalize state on states that have been inserted or updated,
     including calling after_insert/after_update events.
 
@@ -765,33 +771,34 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
             instance_key, row_switch in states_to_insert + \
             states_to_update:
 
-        if mapper._readonly_props:
-            readonly = state.unmodified_intersection(
-                [p.key for p in mapper._readonly_props
-                    if p.expire_on_flush or p.key not in state.dict]
-            )
-            if readonly:
-                state._expire_attributes(state.dict, readonly)
-
-        # if eager_defaults option is enabled, load
-        # all expired cols.  Else if we have a version_id_col, make sure
-        # it isn't expired.
-        toload_now = []
-
-        if base_mapper.eager_defaults:
-            toload_now.extend(state._unloaded_non_object)
-        elif mapper.version_id_col is not None and \
-                mapper.version_id_generator is False:
-            prop = mapper._columntoproperty[mapper.version_id_col]
-            if prop.key in state.unloaded:
-                toload_now.extend([prop.key])
-
-        if toload_now:
-            state.key = base_mapper._identity_key_from_state(state)
-            loading.load_on_ident(
-                uowtransaction.session.query(base_mapper),
-                state.key, refresh_state=state,
-                only_load_props=toload_now)
+        if bookkeeping:
+            if mapper._readonly_props:
+                readonly = state.unmodified_intersection(
+                    [p.key for p in mapper._readonly_props
+                        if p.expire_on_flush or p.key not in state.dict]
+                )
+                if readonly:
+                    state._expire_attributes(state.dict, readonly)
+
+            # if eager_defaults option is enabled, load
+            # all expired cols.  Else if we have a version_id_col, make sure
+            # it isn't expired.
+            toload_now = []
+
+            if base_mapper.eager_defaults:
+                toload_now.extend(state._unloaded_non_object)
+            elif mapper.version_id_col is not None and \
+                    mapper.version_id_generator is False:
+                prop = mapper._columntoproperty[mapper.version_id_col]
+                if prop.key in state.unloaded:
+                    toload_now.extend([prop.key])
+
+            if toload_now:
+                state.key = base_mapper._identity_key_from_state(state)
+                loading.load_on_ident(
+                    uowtransaction.session.query(base_mapper),
+                    state.key, refresh_state=state,
+                    only_load_props=toload_now)
 
         # call after_XXX extensions
         if not has_identity:
index 036045dba238527d1d37c0a7990a4fa03c916f5f..2455c803a7ec5785563be43c3550b1f826cc1dbb 100644 (file)
@@ -2033,6 +2033,40 @@ class Session(_SessionClassMethods):
             with util.safe_reraise():
                 transaction.rollback(_capture_exception=True)
 
+    def bulk_save(self, objects):
+        self._flushing = True
+        flush_context = UOWTransaction(self)
+
+        if self.dispatch.before_bulk_save:
+            self.dispatch.before_bulk_save(
+                self, flush_context, objects)
+
+        flush_context.transaction = transaction = self.begin(
+            subtransactions=True)
+        try:
+            self._warn_on_events = True
+            try:
+                flush_context.bulk_save(objects)
+            finally:
+                self._warn_on_events = False
+
+            self.dispatch.after_bulk_save(
+                self, flush_context, objects
+            )
+
+            flush_context.finalize_flush_changes()
+
+            self.dispatch.after_bulk_save_postexec(
+                self, flush_context, objects)
+
+            transaction.commit()
+
+        except:
+            with util.safe_reraise():
+                transaction.rollback(_capture_exception=True)
+        finally:
+            self._flushing = False
+
     def is_modified(self, instance, include_collections=True,
                     passive=True):
         """Return ``True`` if the given instance has locally
index 71e61827b1bf0f1f2a9370bdb685e5153ad2c957..8df24e95ac3540016e162b44c89d947bd70c287d 100644 (file)
@@ -16,6 +16,7 @@ organizes them in order of dependency, and executes.
 from .. import util, event
 from ..util import topological
 from . import attributes, persistence, util as orm_util
+import itertools
 
 
 def track_cascade_events(descriptor, prop):
@@ -379,14 +380,37 @@ class UOWTransaction(object):
         execute() method has succeeded and the transaction has been committed.
 
         """
+        if not self.states:
+            return
+
         states = set(self.states)
         isdel = set(
             s for (s, (isdelete, listonly)) in self.states.items()
             if isdelete
         )
         other = states.difference(isdel)
-        self.session._remove_newly_deleted(isdel)
-        self.session._register_newly_persistent(other)
+        if isdel:
+            self.session._remove_newly_deleted(isdel)
+        if other:
+            self.session._register_newly_persistent(other)
+
+    def bulk_save(self, objects):
+        for (base_mapper, in_session), states in itertools.groupby(
+                (attributes.instance_state(obj) for obj in objects),
+                lambda state:
+                (
+                    state.mapper.base_mapper,
+                    state.key is self.session.hash_key
+                )):
+
+            persistence.save_obj(
+                base_mapper, list(states), self, bookkeeping=in_session)
+
+            if in_session:
+                self.states.update(
+                    (state, (False, False))
+                    for state in states
+                )
 
 
 class IterateMappersMixin(object):