]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- factor out determination of current version id out of
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Aug 2014 18:24:45 +0000 (14:24 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Aug 2014 18:24:45 +0000 (14:24 -0400)
_collect_update_commands and _collect_delete_commands

lib/sqlalchemy/orm/persistence.py

index 37b696d0f9b8dae0f2557c130a6fea71f600a01f..511a9cef0fb0194b0053751915a4b631f3395c57 100644 (file)
@@ -45,38 +45,26 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
     cached_connections = _cached_connection_dict(base_mapper)
 
     for (state, dict_, mapper, connection,
-            has_identity, row_switch) in _organize_states_for_save(
+            has_identity,
+            row_switch, update_version_id) in _organize_states_for_save(
             base_mapper, states, uowtransaction
     ):
         if has_identity or row_switch:
             states_to_update.append(
-                (state, dict_, mapper, connection,
-                    has_identity, row_switch)
+                (state, dict_, mapper, connection, update_version_id)
             )
         else:
             states_to_insert.append(
-                (state, dict_, mapper, connection,
-                    has_identity, row_switch)
+                (state, dict_, mapper, connection)
             )
 
     for table, mapper in base_mapper._sorted_tables.items():
         if table not in mapper._pks_by_table:
             continue
-        insert = (
-            (state, state_dict, sub_mapper, connection)
-            for state, state_dict, sub_mapper, connection, has_identity,
-            row_switch in states_to_insert
-            if table in sub_mapper._pks_by_table
-        )
-        insert = _collect_insert_commands(table, insert)
+        insert = _collect_insert_commands(table, states_to_insert)
 
-        update = (
-            (state, state_dict, sub_mapper, connection, row_switch)
-            for state, state_dict, sub_mapper, connection, has_identity,
-            row_switch in states_to_update
-            if table in sub_mapper._pks_by_table
-        )
-        update = _collect_update_commands(uowtransaction, table, update)
+        update = _collect_update_commands(
+            uowtransaction, table, states_to_update)
 
         _emit_update_statements(base_mapper, uowtransaction,
                                 cached_connections,
@@ -89,9 +77,16 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
     _finalize_insert_update_commands(
         base_mapper, uowtransaction,
         (
-            (state, state_dict, mapper, connection, has_identity)
-            for state, state_dict, mapper, connection, has_identity,
-            row_switch in states_to_insert + states_to_update
+            (state, state_dict, mapper, connection, False)
+            for state, state_dict, mapper, connection in states_to_insert
+        )
+    )
+    _finalize_insert_update_commands(
+        base_mapper, uowtransaction,
+        (
+            (state, state_dict, mapper, connection, True)
+            for state, state_dict, mapper, connection,
+            update_version_id in states_to_update
         )
     )
 
@@ -149,21 +144,14 @@ def delete_obj(base_mapper, states, uowtransaction):
         if table not in mapper._pks_by_table:
             continue
 
-        delete = (
-            (state, state_dict, sub_mapper, connection)
-            for state, state_dict, sub_mapper, has_identity, connection
-            in states_to_delete if table in sub_mapper._pks_by_table
-            and has_identity
-        )
-
         delete = _collect_delete_commands(base_mapper, uowtransaction,
-                                          table, delete)
+                                          table, states_to_delete)
 
         _emit_delete_statements(base_mapper, uowtransaction,
                                 cached_connections, mapper, table, delete)
 
-    for state, state_dict, mapper, has_identity, connection \
-            in states_to_delete:
+    for state, state_dict, mapper, connection, \
+            update_version_id in states_to_delete:
         mapper.dispatch.after_delete(mapper, connection, state)
 
 
@@ -187,7 +175,7 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
 
         instance_key = state.key or mapper._identity_key_from_state(state)
 
-        row_switch = None
+        row_switch = update_version_id = None
 
         # call before_XXX extensions
         if not has_identity:
@@ -224,8 +212,14 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
             uowtransaction.remove_state_actions(existing)
             row_switch = existing
 
+        if (has_identity or row_switch) and mapper.version_id_col is not None:
+            update_version_id = mapper._get_committed_state_attr_by_column(
+                row_switch if row_switch else state,
+                row_switch.dict if row_switch else dict_,
+                mapper.version_id_col)
+
         yield (state, dict_, mapper, connection,
-               has_identity, row_switch)
+               has_identity, row_switch, update_version_id)
 
 
 def _organize_states_for_post_update(base_mapper, states,
@@ -255,7 +249,16 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
 
         mapper.dispatch.before_delete(mapper, connection, state)
 
-        yield state, dict_, mapper, bool(state.key), connection
+        if mapper.version_id_col is not None:
+            update_version_id = \
+                mapper._get_committed_state_attr_by_column(
+                    state, dict_,
+                    mapper.version_id_col)
+        else:
+            update_version_id = None
+
+        yield (
+            state, dict_, mapper, connection, update_version_id)
 
 
 def _collect_insert_commands(table, states_to_insert):
@@ -264,8 +267,8 @@ def _collect_insert_commands(table, states_to_insert):
 
     """
     for state, state_dict, mapper, connection in states_to_insert:
-
-        # assert table in mapper._pks_by_table
+        if table not in mapper._pks_by_table:
+            continue
 
         params = {}
         value_params = {}
@@ -318,9 +321,11 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
 
     """
 
-    for state, state_dict, mapper, connection, row_switch in states_to_update:
+    for state, state_dict, mapper, connection, \
+            update_version_id in states_to_update:
 
-        # assert table in mapper._pks_by_table
+        if table not in mapper._pks_by_table:
+            continue
 
         pks = mapper._pks_by_table[table]
 
@@ -340,17 +345,13 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
                 else:
                     params[col.key] = value
 
-        if mapper.version_id_col is not None:
+        if update_version_id is not None:
             col = mapper.version_id_col
-            params[col._label] = \
-                mapper._get_committed_state_attr_by_column(
-                    row_switch if row_switch else state,
-                    row_switch.dict if row_switch else state_dict,
-                    col)
+            params[col._label] = update_version_id
 
             if col.key not in params and \
                     mapper.version_id_generator is not False:
-                val = mapper.version_id_generator(params[col._label])
+                val = mapper.version_id_generator(update_version_id)
                 params[col.key] = val
 
         if not (params or value_params):
@@ -364,7 +365,8 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
 
             if history.added:
                 if not history.deleted or \
-                    ("pk_cascaded", state, col) in uowtransaction.attributes:
+                        ("pk_cascaded", state, col) in \
+                        uowtransaction.attributes:
                     pk_params[col._label] = history.added[0]
                     params.pop(col.key, None)
                 else:
@@ -374,7 +376,6 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
             else:
                 pk_params[col._label] = history.unchanged[0]
 
-
         if params or value_params:
             if None in pk_params.values():
                 raise orm_exc.FlushError(
@@ -426,9 +427,11 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
     """Identify values to use in DELETE statements for a list of
     states to be deleted."""
 
-    for state, state_dict, mapper, connection in states_to_delete:
+    for state, state_dict, mapper, connection, \
+            update_version_id in states_to_delete:
 
-        # assert table in mapper._pks_by_table
+        if table not in mapper._pks_by_table:
+            continue
 
         params = {}
         for col in mapper._pks_by_table[table]:
@@ -442,12 +445,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
                     "using NULL for primary "
                     "key value")
 
-        if mapper.version_id_col is not None and \
+        if update_version_id is not None and \
                 table.c.contains_column(mapper.version_id_col):
-            params[mapper.version_id_col.key] = \
-                mapper._get_committed_state_attr_by_column(
-                    state, state_dict,
-                    mapper.version_id_col)
+            params[mapper.version_id_col.key] = update_version_id
         yield params, connection