git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refinements
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 19 Aug 2014 18:24:56 +0000 (14:24 -0400)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 19 Aug 2014 18:24:56 +0000 (14:24 -0400)
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py

lib/sqlalchemy/orm/events.py
index 37ea3071bc3596ac28281f1c94f1bc72e0ed69e7..aa99673badaa7ef8d0427d069b4c846943a9acfe 100644 (file)
@@ -1453,18 +1453,6 @@ class SessionEvents(event.Events):
 
         """
 
-    def before_bulk_insert(self, session, flush_context, mapper, mappings):
-        """"""
-
-    def after_bulk_insert(self, session, flush_context, mapper, mappings):
-        """"""
-
-    def before_bulk_update(self, session, flush_context, mapper, mappings):
-        """"""
-
-    def after_bulk_update(self, session, flush_context, mapper, mappings):
-        """"""
-
     def after_begin(self, session, transaction, connection):
         """Execute after a transaction is begun on a connection
 
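The four bulk hooks removed above (before_bulk_insert, after_bulk_insert,
before_bulk_update, after_bulk_update) only ever carried empty docstrings; they
go away together with the UOWTransaction-based bulk path further down. The
remaining SessionEvents hooks are untouched and keep their usual registration
style. A minimal sketch for the after_begin hook shown in the surrounding
context (the listener body is illustrative only):

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    # listen on the Session class so the hook applies to all sessions
    @event.listens_for(Session, "after_begin")
    def on_after_begin(session, transaction, connection):
        # invoked once a transaction has begun on a connection
        print("transaction begun on", connection.engine.url)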
lib/sqlalchemy/orm/mapper.py
index 89c092b580cc5f5921c2e9836627b4275d9ba0e8..b98fbda42069a8255435c35f479ecfb3e02de98a 100644 (file)
@@ -2366,6 +2366,10 @@ class Mapper(InspectionAttr):
     def _primary_key_props(self):
         return [self._columntoproperty[col] for col in self.primary_key]
 
+    @_memoized_configured_property
+    def _primary_key_propkeys(self):
+        return set([prop.key for prop in self._primary_key_props])
+
     def _get_state_attr_by_column(
             self, state, dict_, column,
             passive=attributes.PASSIVE_RETURN_NEVER_SET):
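The new _primary_key_propkeys helper memoizes the set of attribute names that
map to primary key columns; the bulk-update path below uses it to keep primary
key values present in every UPDATE it emits. A standalone sketch of what it
evaluates to, assuming a hypothetical single-column-PK mapping:

    from sqlalchemy import Column, Integer, String, inspect
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    # the memoized property collects the .key of each primary key property
    assert inspect(User)._primary_key_propkeys == set(['id'])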
lib/sqlalchemy/orm/persistence.py
index 145a7783a31d53bbe7205831c031d72ed54e3d2c..9c00089254a3cb1e98ee398d973596326120fed7 100644 (file)
@@ -23,17 +23,22 @@ from ..sql import expression
 from . import loading
 
 
-def bulk_insert(mapper, mappings, uowtransaction):
+def _bulk_insert(mapper, mappings, session_transaction, isstates):
     base_mapper = mapper.base_mapper
 
     cached_connections = _cached_connection_dict(base_mapper)
 
-    if uowtransaction.session.connection_callable:
+    if session_transaction.session.connection_callable:
         raise NotImplementedError(
             "connection_callable / per-instance sharding "
             "not supported in bulk_insert()")
 
-    connection = uowtransaction.transaction.connection(base_mapper)
+    if isstates:
+        mappings = [state.dict for state in mappings]
+    else:
+        mappings = list(mappings)
+
+    connection = session_transaction.connection(base_mapper)
     for table, super_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(super_mapper):
             continue
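Because the bulk path picks a single connection up front from the session
transaction, the guard above rejects sessions that define connection_callable
(per-instance sharding). A hedged illustration of what such a caller would see;
sharded_session, User and rows are assumed stand-ins for a horizontal-sharding
setup and are not defined here:

    try:
        # a Session with connection_callable set cannot route rows individually
        sharded_session.bulk_insert_mappings(User, rows)
    except NotImplementedError as err:
        print(err)  # connection_callable / per-instance sharding not supported ...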
@@ -45,61 +50,55 @@ def bulk_insert(mapper, mappings, uowtransaction):
             state, state_dict, params, mp,
             conn, value_params, has_all_pks,
             has_all_defaults in _collect_insert_commands(table, (
-                (None, mapping, super_mapper, connection)
-                for mapping in mappings)
+                (None, mapping, mapper, connection)
+                for mapping in mappings),
+                bulk=True
             )
         )
 
-        _emit_insert_statements(base_mapper, uowtransaction,
+        _emit_insert_statements(base_mapper, None,
                                 cached_connections,
                                 super_mapper, table, records,
                                 bookkeeping=False)
 
 
-def bulk_update(mapper, mappings, uowtransaction):
+def _bulk_update(mapper, mappings, session_transaction, isstates):
     base_mapper = mapper.base_mapper
 
     cached_connections = _cached_connection_dict(base_mapper)
 
-    if uowtransaction.session.connection_callable:
+    def _changed_dict(mapper, state):
+        return dict(
+            (k, v)
+            for k, v in state.dict.items() if k in state.committed_state or k
+            in mapper._primary_key_propkeys
+        )
+
+    if isstates:
+        mappings = [_changed_dict(mapper, state) for state in mappings]
+    else:
+        mappings = list(mappings)
+
+    if session_transaction.session.connection_callable:
         raise NotImplementedError(
             "connection_callable / per-instance sharding "
             "not supported in bulk_update()")
 
-    connection = uowtransaction.transaction.connection(base_mapper)
+    connection = session_transaction.connection(base_mapper)
 
     value_params = {}
+
     for table, super_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(super_mapper):
             continue
 
-        label_pks = super_mapper._pks_by_table[table]
-        if mapper.version_id_col is not None:
-            label_pks = label_pks.union([mapper.version_id_col])
-
-        to_translate = dict(
-            (propkey, col._label if col in label_pks else col.key)
-            for propkey, col in super_mapper._propkey_to_col[table].items()
+        records = (
+            (None, None, params, super_mapper, connection, value_params)
+            for
+            params in _collect_bulk_update_commands(mapper, table, mappings)
         )
 
-        records = []
-        for mapping in mappings:
-            params = dict(
-                (to_translate[k], v) for k, v in mapping.items()
-            )
-
-            if mapper.version_id_generator is not False and \
-                    mapper.version_id_col is not None and \
-                    mapper.version_id_col.key not in params:
-                params[mapper.version_id_col.key] = \
-                    mapper.version_id_generator(
-                        params[mapper.version_id_col._label])
-
-            records.append(
-                (None, None, params, super_mapper, connection, value_params)
-            )
-
-        _emit_update_statements(base_mapper, uowtransaction,
+        _emit_update_statements(base_mapper, None,
                                 cached_connections,
                                 super_mapper, table, records,
                                 bookkeeping=False)
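When _bulk_update is handed InstanceState objects (the bulk_save_objects path),
_changed_dict trims each state down to the attributes with pending changes
(those recorded in state.committed_state) plus the primary key, so the emitted
UPDATE only touches modified columns. A standalone sketch of that filtering,
using plain dicts outside the ORM:

    def changed_subset(attr_dict, committed_state_keys, pk_keys):
        """Keep attributes that were modified, plus the primary key columns."""
        return dict(
            (k, v) for k, v in attr_dict.items()
            if k in committed_state_keys or k in pk_keys
        )

    # a row whose 'name' changed: only 'name' and the primary key survive
    assert changed_subset(
        {'id': 5, 'name': 'ed', 'email': 'ed@example.com'},
        committed_state_keys={'name'},
        pk_keys={'id'},
    ) == {'id': 5, 'name': 'ed'}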
@@ -360,7 +359,7 @@ def _collect_insert_commands(table, states_to_insert, bulk=False):
             col = propkey_to_col[propkey]
             if value is None:
                 continue
-            elif isinstance(value, sql.ClauseElement):
+            elif not bulk and isinstance(value, sql.ClauseElement):
                 value_params[col.key] = value
             else:
                 params[col.key] = value
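With bulk=True, _collect_insert_commands no longer diverts SQL expression
values into value_params; bulk rows are expected to carry plain scalar values
suitable for an executemany, while the non-bulk path keeps its per-row handling
of inline SQL expressions. A simplified sketch of the branching, outside the
ORM:

    from sqlalchemy.sql import ClauseElement

    def place_value(params, value_params, key, value, bulk=False):
        if value is None:
            return
        elif not bulk and isinstance(value, ClauseElement):
            # non-bulk path: SQL expressions are rendered into the statement
            # and that row is executed individually
            value_params[key] = value
        else:
            params[key] = value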
@@ -481,6 +480,44 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
                 state, state_dict, params, mapper,
                 connection, value_params)
 
+def _collect_bulk_update_commands(mapper, table, mappings):
+    label_pks = mapper._pks_by_table[table]
+    if mapper.version_id_col is not None:
+        label_pks = label_pks.union([mapper.version_id_col])
+
+    to_translate = dict(
+        (propkey, col.key if col not in label_pks else col._label)
+        for propkey, col in mapper._propkey_to_col[table].items()
+    )
+
+    for mapping in mappings:
+        params = dict(
+            (to_translate[k], mapping[k]) for k in to_translate
+            if k in mapping and k not in mapper._primary_key_propkeys
+        )
+
+        if not params:
+            continue
+
+        try:
+            params.update(
+                (to_translate[k], mapping[k]) for k in 
+                mapper._primary_key_propkeys.intersection(to_translate)
+            )
+        except KeyError as ke:
+            raise orm_exc.FlushError(
+                "Can't update table using NULL for primary "
+                "key attribute: %s" % ke)
+
+        if mapper.version_id_generator is not False and \
+                mapper.version_id_col is not None and \
+                mapper.version_id_col.key not in params:
+            params[mapper.version_id_col.key] = \
+                mapper.version_id_generator(
+                    params[mapper.version_id_col._label])
+
+        yield params
+
 
 def _collect_post_update_commands(base_mapper, uowtransaction, table,
                                   states_to_update, post_update_cols):
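_collect_bulk_update_commands keys primary key values by their column ._label
(the bind names used in the WHERE clause) and everything else by plain column
.key, and it converts a missing primary key into a FlushError rather than a
bare KeyError. A hedged behavioral sketch; session and User are assumed to be
an already configured Session and a mapped class whose primary key attribute
is 'id':

    from sqlalchemy.orm.exc import FlushError

    try:
        # no 'id' in the mapping: the primary key lookup above fails
        session.bulk_update_mappings(User, [{'name': 'ed'}])
    except FlushError as err:
        print(err)  # Can't update table using NULL for primary key attribute: 'id'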
lib/sqlalchemy/orm/session.py
index 3199a4332d08740437195b6194d81f54ff73cc9a..968868e8437b70d510345492c65e7f8643c6e7bf 100644 (file)
@@ -21,6 +21,7 @@ from .base import (
     _none_set, state_str, instance_str
 )
 import itertools
+from . import persistence
 from .unitofwork import UOWTransaction
 from . import state as statelib
 import sys
@@ -2040,37 +2041,27 @@ class Session(_SessionClassMethods):
             (attributes.instance_state(obj) for obj in objects),
             lambda state: (state.mapper, state.key is not None)
         ):
-            if isupdate:
-                self.bulk_update_mappings(mapper, (s.dict for s in states))
-            else:
-                self.bulk_insert_mappings(mapper, (s.dict for s in states))
+            self._bulk_save_mappings(mapper, states, isupdate, True)
 
     def bulk_insert_mappings(self, mapper, mappings):
-        self._bulk_save_mappings(mapper, mappings, False)
+        self._bulk_save_mappings(mapper, mappings, False, False)
 
     def bulk_update_mappings(self, mapper, mappings):
-        self._bulk_save_mappings(mapper, mappings, True)
+        self._bulk_save_mappings(mapper, mappings, True, False)
 
-    def _bulk_save_mappings(self, mapper, mappings, isupdate):
+    def _bulk_save_mappings(self, mapper, mappings, isupdate, isstates):
         mapper = _class_to_mapper(mapper)
         self._flushing = True
-        flush_context = UOWTransaction(self)
 
-        flush_context.transaction = transaction = self.begin(
+        transaction = self.begin(
             subtransactions=True)
         try:
             if isupdate:
-                self.dispatch.before_bulk_update(
-                    self, flush_context, mapper, mappings)
-                flush_context.bulk_update(mapper, mappings)
-                self.dispatch.after_bulk_update(
-                    self, flush_context, mapper, mappings)
+                persistence._bulk_update(
+                    mapper, mappings, transaction, isstates)
             else:
-                self.dispatch.before_bulk_insert(
-                    self, flush_context, mapper, mappings)
-                flush_context.bulk_insert(mapper, mappings)
-                self.dispatch.after_bulk_insert(
-                    self, flush_context, mapper, mappings)
+                persistence._bulk_insert(
+                    mapper, mappings, transaction, isstates)
             transaction.commit()
 
         except:
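The public entry points shown above now all funnel into _bulk_save_mappings,
which begins a subtransaction and calls straight into the persistence module
with no UOWTransaction and no event dispatch. A hedged end-to-end sketch of the
three calls against a throwaway SQLite database; the table layout and data are
assumptions, bulk_insert_mappings and bulk_update_mappings appear in the hunk
above, and bulk_save_objects is assumed to be the states-based caller at the
top of it:

    from sqlalchemy import create_engine, Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = Session(engine)

    # plain dictionaries, no identity-map bookkeeping
    session.bulk_insert_mappings(User, [
        {'id': 1, 'name': 'ed'},
        {'id': 2, 'name': 'wendy'},
    ])

    # full ORM objects routed through the same _bulk_save_mappings entry point
    session.bulk_save_objects([User(id=3, name='jack')])

    # UPDATE matched on the primary key; only the given columns are changed
    session.bulk_update_mappings(User, [{'id': 1, 'name': 'edward'}])
    session.commit()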
lib/sqlalchemy/orm/unitofwork.py
index b3a1519c5299960beaeb63de5c4db3b00abb3026..05265b13fbfbcbe27b2445f67edd0d87e70b7ba5 100644 (file)
@@ -394,12 +394,6 @@ class UOWTransaction(object):
         if other:
             self.session._register_newly_persistent(other)
 
-    def bulk_insert(self, mapper, mappings):
-        persistence.bulk_insert(mapper, mappings, self)
-
-    def bulk_update(self, mapper, mappings):
-        persistence.bulk_update(mapper, mappings, self)
-
 
 class IterateMappersMixin(object):
     def _mappers(self, uow):