]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dev
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2014 22:30:14 +0000 (18:30 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2014 22:30:14 +0000 (18:30 -0400)
lib/sqlalchemy/orm/persistence.py

index 8d3e90cf4567a53fe39751f5cbc24b4d2a4fdcc9..f9e7eda285abb80c654460da55f6ff38b4ff9c13 100644 (file)
@@ -34,26 +34,18 @@ def bulk_insert(mapper, mappings, uowtransaction):
             "not supported in bulk_insert()")
 
     connection = uowtransaction.transaction.connection(base_mapper)
-    value_params = {}
     for table, sub_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(sub_mapper):
             continue
 
-        has_version_generator = mapper.version_id_generator is not False and \
-            mapper.version_id_col is not None
-
         records = []
-        for mapping in mappings:
-            params = dict(
-                (col.key, mapping[propkey])
-                for col, propkey in mapper._col_to_propkey[table]
-                if propkey in mapping
-            )
-
-            if has_version_generator:
-                params[mapper.version_id_col.key] = \
-                    mapper.version_id_generator(None)
-
+        for (
+            state, state_dict, params, mapper,
+            connection, value_params, has_all_pks,
+            has_all_defaults) in _collect_insert_commands(table, (
+                (None, mapping, sub_mapper, connection)
+                for mapping in mappings)
+        ):
             records.append(
                 (None, None, params, sub_mapper,
                     connection, value_params, True, True)
@@ -82,13 +74,13 @@ def bulk_update(mapper, mappings, uowtransaction):
         if not mapper.isa(sub_mapper):
             continue
 
-        label_pks = mapper._pks_by_table[table]
+        label_pks = sub_mapper._pks_by_table[table]
         if mapper.version_id_col is not None:
             label_pks = label_pks.union([mapper.version_id_col])
 
         to_translate = dict(
             (propkey, col._label if col in label_pks else col.key)
-            for col, propkey in mapper._col_to_propkey[table]
+            for propkey, col in sub_mapper._propkey_to_col[table].items()
         )
 
         records = []
@@ -350,7 +342,7 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
         yield state, dict_, mapper, bool(state.key), connection
 
 
-def _collect_insert_commands(table, states_to_insert):
+def _collect_insert_commands(table, states_to_insert, bulk=False):
     """Identify sets of values to use in INSERT statements for a
     list of states.
 
@@ -374,17 +366,20 @@ def _collect_insert_commands(table, states_to_insert):
             else:
                 params[col.key] = value
 
-        for colkey in mapper._insert_cols_as_none[table].\
-                difference(params).difference(value_params):
-            params[colkey] = None
+        if not bulk:
+            for colkey in mapper._insert_cols_as_none[table].\
+                    difference(params).difference(value_params):
+                params[colkey] = None
 
-        has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
+            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
 
-        if mapper.base_mapper.eager_defaults:
-            has_all_defaults = mapper._server_default_cols[table].\
-                issubset(params)
+            if mapper.base_mapper.eager_defaults:
+                has_all_defaults = mapper._server_default_cols[table].\
+                    issubset(params)
+            else:
+                has_all_defaults = True
         else:
-            has_all_defaults = True
+            has_all_defaults = has_all_pks = True
 
         if mapper.version_id_generator is not False \
                 and mapper.version_id_col is not None and \