]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- that's it, feature is finished, needs tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Aug 2014 21:15:20 +0000 (17:15 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Aug 2014 21:15:20 +0000 (17:15 -0400)
lib/sqlalchemy/orm/persistence.py

index c2750eeb32123df39c180a2650d74b79749f8069..aa10da9f42dfb4feac1059be17f96a594e6b1712 100644 (file)
@@ -15,7 +15,7 @@ in unitofwork.py.
 """
 
 import operator
-from itertools import groupby
+from itertools import groupby, chain
 from .. import sql, util, exc as sa_exc, schema
 from . import attributes, sync, exc as orm_exc, evaluator
 from .base import state_str, _attr_as_key
@@ -86,17 +86,16 @@ def _bulk_update(mapper, mappings, session_transaction, isstates):
 
     connection = session_transaction.connection(base_mapper)
 
-    value_params = {}
-
     for table, super_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(super_mapper):
             continue
 
-        records = (
-            (None, None, params, super_mapper, connection, value_params)
-            for
-            params in _collect_bulk_update_commands(mapper, table, mappings)
-        )
+        records = _collect_update_commands(None, table, (
+            (None, mapping, mapper, connection,
+                (mapping[mapper._version_id_prop.key]
+                    if mapper._version_id_prop else None))
+            for mapping in mappings
+        ), bulk=True)
 
         _emit_update_statements(base_mapper, None,
                                 cached_connections,
@@ -158,17 +157,16 @@ def save_obj(
 
     _finalize_insert_update_commands(
         base_mapper, uowtransaction,
-        (
-            (state, state_dict, mapper, connection, False)
-            for state, state_dict, mapper, connection in states_to_insert
-        )
-    )
-    _finalize_insert_update_commands(
-        base_mapper, uowtransaction,
-        (
-            (state, state_dict, mapper, connection, True)
-            for state, state_dict, mapper, connection,
-            update_version_id in states_to_update
+        chain(
+            (
+                (state, state_dict, mapper, connection, False)
+                for state, state_dict, mapper, connection in states_to_insert
+            ),
+            (
+                (state, state_dict, mapper, connection, True)
+                for state, state_dict, mapper, connection,
+                update_version_id in states_to_update
+            )
         )
     )
 
@@ -394,7 +392,9 @@ def _collect_insert_commands(table, states_to_insert, bulk=False):
             has_all_defaults)
 
 
-def _collect_update_commands(uowtransaction, table, states_to_update):
+def _collect_update_commands(
+        uowtransaction, table, states_to_update,
+        bulk=False):
     """Identify sets of values to use in UPDATE statements for a
     list of states.
 
@@ -414,23 +414,32 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
 
         pks = mapper._pks_by_table[table]
 
-        params = {}
         value_params = {}
 
         propkey_to_col = mapper._propkey_to_col[table]
 
-        for propkey in set(propkey_to_col).intersection(state.committed_state):
-            value = state_dict[propkey]
-            col = propkey_to_col[propkey]
-
-            if not state.manager[propkey].impl.is_equal(
-                    value, state.committed_state[propkey]):
-                if isinstance(value, sql.ClauseElement):
-                    value_params[col] = value
-                else:
-                    params[col.key] = value
+        if bulk:
+            params = dict(
+                (propkey_to_col[propkey].key, state_dict[propkey])
+                for propkey in
+                set(propkey_to_col).intersection(state_dict)
+            )
+        else:
+            params = {}
+            for propkey in set(propkey_to_col).intersection(
+                    state.committed_state):
+                value = state_dict[propkey]
+                col = propkey_to_col[propkey]
+
+                if not state.manager[propkey].impl.is_equal(
+                        value, state.committed_state[propkey]):
+                    if isinstance(value, sql.ClauseElement):
+                        value_params[col] = value
+                    else:
+                        params[col.key] = value
 
-        if update_version_id is not None:
+        if update_version_id is not None and \
+                mapper.version_id_col in mapper._cols_by_table[table]:
             col = mapper.version_id_col
             params[col._label] = update_version_id
 
@@ -442,24 +451,33 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
         if not (params or value_params):
             continue
 
-        pk_params = {}
-        for col in pks:
-            propkey = mapper._columntoproperty[col].key
-            history = state.manager[propkey].impl.get_history(
-                state, state_dict, attributes.PASSIVE_OFF)
-
-            if history.added:
-                if not history.deleted or \
-                        ("pk_cascaded", state, col) in \
-                        uowtransaction.attributes:
-                    pk_params[col._label] = history.added[0]
-                    params.pop(col.key, None)
+        if bulk:
+            pk_params = dict(
+                (propkey_to_col[propkey]._label, state_dict.get(propkey))
+                for propkey in
+                set(propkey_to_col).
+                intersection(mapper._pk_keys_by_table[table])
+            )
+        else:
+            pk_params = {}
+            for col in pks:
+                propkey = mapper._columntoproperty[col].key
+
+                history = state.manager[propkey].impl.get_history(
+                    state, state_dict, attributes.PASSIVE_OFF)
+
+                if history.added:
+                    if not history.deleted or \
+                            ("pk_cascaded", state, col) in \
+                            uowtransaction.attributes:
+                        pk_params[col._label] = history.added[0]
+                        params.pop(col.key, None)
+                    else:
+                        # else, use the old value to locate the row
+                        pk_params[col._label] = history.deleted[0]
+                        params[col.key] = history.added[0]
                 else:
-                    # else, use the old value to locate the row
-                    pk_params[col._label] = history.deleted[0]
-                    params[col.key] = history.added[0]
-            else:
-                pk_params[col._label] = history.unchanged[0]
+                    pk_params[col._label] = history.unchanged[0]
 
         if params or value_params:
             if None in pk_params.values():
@@ -471,44 +489,6 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
                 state, state_dict, params, mapper,
                 connection, value_params)
 
-def _collect_bulk_update_commands(mapper, table, mappings):
-    label_pks = mapper._pks_by_table[table]
-    if mapper.version_id_col is not None:
-        label_pks = label_pks.union([mapper.version_id_col])
-
-    to_translate = dict(
-        (propkey, col.key if col not in label_pks else col._label)
-        for propkey, col in mapper._propkey_to_col[table].items()
-    )
-
-    for mapping in mappings:
-        params = dict(
-            (to_translate[k], mapping[k]) for k in to_translate
-            if k in mapping and k not in mapper._primary_key_propkeys
-        )
-
-        if not params:
-            continue
-
-        try:
-            params.update(
-                (to_translate[k], mapping[k]) for k in 
-                mapper._primary_key_propkeys.intersection(to_translate)
-            )
-        except KeyError as ke:
-            raise orm_exc.FlushError(
-                "Can't update table using NULL for primary "
-                "key attribute: %s" % ke)
-
-        if mapper.version_id_generator is not False and \
-                mapper.version_id_col is not None and \
-                mapper.version_id_col.key not in params:
-            params[mapper.version_id_col.key] = \
-                mapper.version_id_generator(
-                    params[mapper.version_id_col._label])
-
-        yield params
-
 
 def _collect_post_update_commands(base_mapper, uowtransaction, table,
                                   states_to_update, post_update_cols):
@@ -569,7 +549,7 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
                     "key value")
 
         if update_version_id is not None and \
-                table.c.contains_column(mapper.version_id_col):
+                mapper.version_id_col in mapper._cols_by_table[table]:
             params[mapper.version_id_col.key] = update_version_id
         yield params, connection
 
@@ -581,7 +561,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
     by _collect_update_commands()."""
 
     needs_version_id = mapper.version_id_col is not None and \
-        table.c.contains_column(mapper.version_id_col)
+        mapper.version_id_col in mapper._cols_by_table[table]
 
     def update_stmt():
         clause = sql.and_()
@@ -610,9 +590,9 @@ def _emit_update_statements(base_mapper, uowtransaction,
         records in groupby(
             update,
             lambda rec: (
-                rec[4],
-                tuple(sorted(rec[2])),
-                bool(rec[5]))):
+                rec[4],  # connection
+                set(rec[2]),  # set of parameter keys
+                bool(rec[5]))):  # whether or not we have "value" parameters
 
         rows = 0
         records = list(records)
@@ -692,12 +672,14 @@ def _emit_insert_statements(base_mapper, uowtransaction,
     statement = base_mapper._memo(('insert', table), table.insert)
 
     for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
-        records in groupby(insert,
-                           lambda rec: (rec[4],
-                                        tuple(sorted(rec[2].keys())),
-                                        bool(rec[5]),
-                                        rec[6], rec[7])
-                           ):
+        records in groupby(
+            insert,
+            lambda rec: (
+                rec[4],  # connection
+                set(rec[2]),  # parameter keys
+                bool(rec[5]),  # whether we have "value" parameters
+                rec[6],
+                rec[7])):
         if not bookkeeping or \
                 (
                     has_all_defaults
@@ -785,7 +767,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
     # also group them into common (connection, cols) sets
     # to support executemany().
     for key, grouper in groupby(
-        update, lambda rec: (rec[1], sorted(rec[0]))
+        update, lambda rec: (
+            rec[1],  # connection
+            set(rec[0])  # parameter keys
+        )
     ):
         connection = key[0]
         multiparams = [params for params, conn in grouper]
@@ -799,7 +784,7 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
     by _collect_delete_commands()."""
 
     need_version_id = mapper.version_id_col is not None and \
-        table.c.contains_column(mapper.version_id_col)
+        mapper.version_id_col in mapper._cols_by_table[table]
 
     def delete_stmt():
         clause = sql.and_()
@@ -821,12 +806,9 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
     statement = base_mapper._memo(('delete', table), delete_stmt)
     for connection, recs in groupby(
         delete,
-        lambda rec: rec[1]
+        lambda rec: rec[1]   # connection
     ):
-        del_objects = [
-            params
-            for params, connection in recs
-        ]
+        del_objects = [params for params, connection in recs]
 
         connection = cached_connections[connection]
 
@@ -931,7 +913,8 @@ def _postfetch(mapper, uowtransaction, table,
     postfetch_cols = result.context.postfetch_cols
     returning_cols = result.context.returning_cols
 
-    if mapper.version_id_col is not None:
+    if mapper.version_id_col is not None and \
+            mapper.version_id_col in mapper._cols_by_table[table]:
         prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
 
     if returning_cols: