]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- organize persistence methods in terms of generators,
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2014 20:32:48 +0000 (16:32 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2014 20:32:59 +0000 (16:32 -0400)
narrow down argument lists and generator items for each function
down to just what each function needs.   This will help for them
to be of more multipurpose use for bulk operations

lib/sqlalchemy/orm/persistence.py

index 228cfef3aab10e18afabb269ccb77b64e3576ad9..c7850ac1da2a903bd31c3b8cda2b9e26d3d61ae0 100644 (file)
@@ -40,32 +40,58 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
             save_obj(base_mapper, [state], uowtransaction, single=True)
         return
 
-    states_to_insert, states_to_update = _organize_states_for_save(
-        base_mapper,
-        states,
-        uowtransaction)
-
+    states_to_update = []
+    states_to_insert = []
     cached_connections = _cached_connection_dict(base_mapper)
 
-    for table, mapper in base_mapper._sorted_tables.items():
-        insert = _collect_insert_commands(base_mapper, uowtransaction,
-                                          table, states_to_insert)
-
-        update = _collect_update_commands(base_mapper, uowtransaction,
-                                          table, states_to_update)
-
-        if update:
-            _emit_update_statements(base_mapper, uowtransaction,
-                                    cached_connections,
-                                    mapper, table, update)
-
-        if insert:
-            _emit_insert_statements(base_mapper, uowtransaction,
-                                    cached_connections,
-                                    mapper, table, insert)
+    for (state, dict_, mapper, connection,
+            has_identity, row_switch) in _organize_states_for_save(
+            base_mapper, states, uowtransaction
+    ):
+        if has_identity or row_switch:
+            states_to_update.append(
+                (state, dict_, mapper, connection,
+                    has_identity, row_switch)
+            )
+        else:
+            states_to_insert.append(
+                (state, dict_, mapper, connection,
+                    has_identity, row_switch)
+            )
 
-    _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update)
+    for table, mapper in base_mapper._sorted_tables.items():
+        if table not in mapper._pks_by_table:
+            continue
+        insert = (
+            (state, state_dict, mapper, connection)
+            for state, state_dict, mapper, connection, has_identity,
+            row_switch in states_to_insert
+        )
+        insert = _collect_insert_commands(table, insert)
+
+        update = (
+            (state, state_dict, mapper, connection, row_switch)
+            for state, state_dict, mapper, connection, has_identity,
+            row_switch in states_to_update
+        )
+        update = _collect_update_commands(uowtransaction, table, update)
+
+        _emit_update_statements(base_mapper, uowtransaction,
+                                cached_connections,
+                                mapper, table, update)
+
+        _emit_insert_statements(base_mapper, uowtransaction,
+                                cached_connections,
+                                mapper, table, insert)
+
+    _finalize_insert_update_commands(
+        base_mapper, uowtransaction,
+        (
+            (state, state_dict, mapper, connection, has_identity)
+            for state, state_dict, mapper, connection, has_identity,
+            row_switch in states_to_insert + states_to_update
+        )
+    )
 
 
 def post_update(base_mapper, states, uowtransaction, post_update_cols):
@@ -75,19 +101,20 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
     """
     cached_connections = _cached_connection_dict(base_mapper)
 
-    states_to_update = _organize_states_for_post_update(
+    states_to_update = list(_organize_states_for_post_update(
         base_mapper,
-        states, uowtransaction)
+        states, uowtransaction))
 
     for table, mapper in base_mapper._sorted_tables.items():
+        if table not in mapper._pks_by_table:
+            continue
         update = _collect_post_update_commands(base_mapper, uowtransaction,
                                                table, states_to_update,
                                                post_update_cols)
 
-        if update:
-            _emit_post_update_statements(base_mapper, uowtransaction,
-                                         cached_connections,
-                                         mapper, table, update)
+        _emit_post_update_statements(base_mapper, uowtransaction,
+                                     cached_connections,
+                                     mapper, table, update)
 
 
 def delete_obj(base_mapper, states, uowtransaction):
@@ -100,19 +127,21 @@ def delete_obj(base_mapper, states, uowtransaction):
 
     cached_connections = _cached_connection_dict(base_mapper)
 
-    states_to_delete = _organize_states_for_delete(
+    states_to_delete = list(_organize_states_for_delete(
         base_mapper,
         states,
-        uowtransaction)
+        uowtransaction))
 
     table_to_mapper = base_mapper._sorted_tables
 
     for table in reversed(list(table_to_mapper.keys())):
+        mapper = table_to_mapper[table]
+        if table not in mapper._pks_by_table:
+            continue
+
         delete = _collect_delete_commands(base_mapper, uowtransaction,
                                           table, states_to_delete)
 
-        mapper = table_to_mapper[table]
-
         _emit_delete_statements(base_mapper, uowtransaction,
                                 cached_connections, mapper, table, delete)
 
@@ -133,9 +162,6 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
 
     """
 
-    states_to_insert = []
-    states_to_update = []
-
     for state, dict_, mapper, connection in _connections_for_states(
             base_mapper, uowtransaction,
             states):
@@ -181,18 +207,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
             uowtransaction.remove_state_actions(existing)
             row_switch = existing
 
-        if not has_identity and not row_switch:
-            states_to_insert.append(
-                (state, dict_, mapper, connection,
-                 has_identity, row_switch)
-            )
-        else:
-            states_to_update.append(
-                (state, dict_, mapper, connection,
-                 has_identity, row_switch)
-            )
-
-    return states_to_insert, states_to_update
+        yield (state, dict_, mapper, connection,
+               has_identity, row_switch)
 
 
 def _organize_states_for_post_update(base_mapper, states,
@@ -205,8 +221,7 @@ def _organize_states_for_post_update(base_mapper, states,
     the execution per state.
 
     """
-    return list(_connections_for_states(base_mapper, uowtransaction,
-                                        states))
+    return _connections_for_states(base_mapper, uowtransaction, states)
 
 
 def _organize_states_for_delete(base_mapper, states, uowtransaction):
@@ -217,28 +232,21 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
     mapper, the connection to use for the execution per state.
 
     """
-    states_to_delete = []
-
     for state, dict_, mapper, connection in _connections_for_states(
             base_mapper, uowtransaction,
             states):
 
         mapper.dispatch.before_delete(mapper, connection, state)
 
-        states_to_delete.append((state, dict_, mapper,
-                                 bool(state.key), connection))
-    return states_to_delete
+        yield state, dict_, mapper, bool(state.key), connection
 
 
-def _collect_insert_commands(base_mapper, uowtransaction, table,
-                             states_to_insert):
+def _collect_insert_commands(table, states_to_insert):
     """Identify sets of values to use in INSERT statements for a
     list of states.
 
     """
-    insert = []
-    for state, state_dict, mapper, connection, has_identity, \
-            row_switch in states_to_insert:
+    for state, state_dict, mapper, connection in states_to_insert:
 
         if table not in mapper._pks_by_table:
             continue
@@ -262,7 +270,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
 
         has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
 
-        if base_mapper.eager_defaults:
+        if mapper.base_mapper.eager_defaults:
             has_all_defaults = mapper._server_default_cols[table].\
                 issubset(params)
         else:
@@ -274,14 +282,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
             params[mapper.version_id_col.key] = \
                 mapper.version_id_generator(None)
 
-        insert.append((state, state_dict, params, mapper,
-                       connection, value_params, has_all_pks,
-                       has_all_defaults))
-    return insert
+        yield (
+            state, state_dict, params, mapper,
+            connection, value_params, has_all_pks,
+            has_all_defaults)
 
 
-def _collect_update_commands(base_mapper, uowtransaction,
-                             table, states_to_update):
+def _collect_update_commands(uowtransaction, table, states_to_update):
     """Identify sets of values to use in UPDATE statements for a
     list of states.
 
@@ -293,9 +300,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
 
     """
 
-    update = []
-    for state, state_dict, mapper, connection, has_identity, \
-            row_switch in states_to_update:
+    for state, state_dict, mapper, connection, row_switch in states_to_update:
         if table not in mapper._pks_by_table:
             continue
 
@@ -368,9 +373,9 @@ def _collect_update_commands(base_mapper, uowtransaction,
                     "Can't update table using NULL for primary "
                     "key value")
             params.update(pk_params)
-            update.append((state, state_dict, params, mapper,
-                           connection, value_params))
-    return update
+            yield (
+                state, state_dict, params, mapper,
+                connection, value_params)
 
 
 def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -380,7 +385,6 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
 
     """
 
-    update = []
     for state, state_dict, mapper, connection in states_to_update:
         if table not in mapper._pks_by_table:
             continue
@@ -405,9 +409,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
                     params[col.key] = value
                     hasdata = True
         if hasdata:
-            update.append((state, state_dict, params, mapper,
-                           connection))
-    return update
+            yield params, connection
 
 
 def _collect_delete_commands(base_mapper, uowtransaction, table,
@@ -415,15 +417,12 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
     """Identify values to use in DELETE statements for a list of
     states to be deleted."""
 
-    delete = util.defaultdict(list)
-
     for state, state_dict, mapper, has_identity, connection \
             in states_to_delete:
         if not has_identity or table not in mapper._pks_by_table:
             continue
 
         params = {}
-        delete[connection].append(params)
         for col in mapper._pks_by_table[table]:
             params[col.key] = \
                 value = \
@@ -441,7 +440,7 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
                 mapper._get_committed_state_attr_by_column(
                     state, state_dict,
                     mapper.version_id_col)
-    return delete
+        yield params, connection
 
 
 def _emit_update_statements(base_mapper, uowtransaction,
@@ -481,8 +480,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
             lambda rec: (
                 rec[4],
                 tuple(sorted(rec[2])),
-                bool(rec[5]))
-            ):
+                bool(rec[5]))):
 
         rows = 0
         records = list(records)
@@ -652,11 +650,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
     # also group them into common (connection, cols) sets
     # to support executemany().
     for key, grouper in groupby(
-        update, lambda rec: (rec[4], list(rec[2].keys()))
+        update, lambda rec: (rec[1], sorted(rec[0]))
     ):
         connection = key[0]
-        multiparams = [params for state, state_dict,
-                       params, mapper, conn in grouper]
+        multiparams = [params for params, conn in grouper]
         cached_connections[connection].\
             execute(statement, multiparams)
 
@@ -686,8 +683,15 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
 
         return table.delete(clause)
 
-    for connection, del_objects in delete.items():
-        statement = base_mapper._memo(('delete', table), delete_stmt)
+    statement = base_mapper._memo(('delete', table), delete_stmt)
+    for connection, recs in groupby(
+        delete,
+        lambda rec: rec[1]
+    ):
+        del_objects = [
+            params
+            for params, connection in recs
+        ]
 
         connection = cached_connections[connection]
 
@@ -740,15 +744,12 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
                 )
 
 
-def _finalize_insert_update_commands(base_mapper, uowtransaction,
-                                     states_to_insert, states_to_update):
+def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
     """finalize state on states that have been inserted or updated,
     including calling after_insert/after_update events.
 
     """
-    for state, state_dict, mapper, connection, has_identity, \
-            row_switch in states_to_insert + \
-            states_to_update:
+    for state, state_dict, mapper, connection, has_identity in states:
 
         if mapper._readonly_props:
             readonly = state.unmodified_intersection(