]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- move out checks for table in mapper._pks_by_table
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2014 21:02:52 +0000 (17:02 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Aug 2014 21:02:52 +0000 (17:02 -0400)
lib/sqlalchemy/orm/persistence.py

index f17b1d79cd8294da3005a8933e19c8a956712cf6..c949e47764ad8ba09f04d1649401b3923c996578 100644 (file)
@@ -63,16 +63,18 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
         if table not in mapper._pks_by_table:
             continue
         insert = (
-            (state, state_dict, mapper, connection)
-            for state, state_dict, mapper, connection, has_identity,
+            (state, state_dict, sub_mapper, connection)
+            for state, state_dict, sub_mapper, connection, has_identity,
             row_switch in states_to_insert
+            if table in sub_mapper._pks_by_table
         )
         insert = _collect_insert_commands(table, insert)
 
         update = (
-            (state, state_dict, mapper, connection, row_switch)
-            for state, state_dict, mapper, connection, has_identity,
+            (state, state_dict, sub_mapper, connection, row_switch)
+            for state, state_dict, sub_mapper, connection, has_identity,
             row_switch in states_to_update
+            if table in sub_mapper._pks_by_table
         )
         update = _collect_update_commands(uowtransaction, table, update)
 
@@ -108,8 +110,16 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
     for table, mapper in base_mapper._sorted_tables.items():
         if table not in mapper._pks_by_table:
             continue
+
+        update = (
+            (state, state_dict, sub_mapper, connection)
+            for
+            state, state_dict, sub_mapper, connection in states_to_update
+            if table in sub_mapper._pks_by_table
+        )
+
         update = _collect_post_update_commands(base_mapper, uowtransaction,
-                                               table, states_to_update,
+                                               table, update,
                                                post_update_cols)
 
         _emit_post_update_statements(base_mapper, uowtransaction,
@@ -139,8 +149,15 @@ def delete_obj(base_mapper, states, uowtransaction):
         if table not in mapper._pks_by_table:
             continue
 
+        delete = (
+            (state, state_dict, sub_mapper, connection)
+            for state, state_dict, sub_mapper, has_identity, connection
+            in states_to_delete if table in sub_mapper._pks_by_table
+            and has_identity
+        )
+
         delete = _collect_delete_commands(base_mapper, uowtransaction,
-                                          table, states_to_delete)
+                                          table, delete)
 
         _emit_delete_statements(base_mapper, uowtransaction,
                                 cached_connections, mapper, table, delete)
@@ -248,8 +265,7 @@ def _collect_insert_commands(table, states_to_insert):
     """
     for state, state_dict, mapper, connection in states_to_insert:
 
-        if table not in mapper._pks_by_table:
-            continue
+        # assert table in mapper._pks_by_table
 
         params = {}
         value_params = {}
@@ -309,8 +325,8 @@ def _collect_update_commands(uowtransaction, table, states_to_update):
     """
 
     for state, state_dict, mapper, connection, row_switch in states_to_update:
-        if table not in mapper._pks_by_table:
-            continue
+
+        # assert table in mapper._pks_by_table
 
         pks = mapper._pks_by_table[table]
 
@@ -394,8 +410,9 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
     """
 
     for state, state_dict, mapper, connection in states_to_update:
-        if table not in mapper._pks_by_table:
-            continue
+
+        # assert table in mapper._pks_by_table
+
         pks = mapper._pks_by_table[table]
         params = {}
         hasdata = False
@@ -425,10 +442,9 @@ def _collect_delete_commands(base_mapper, uowtransaction, table,
     """Identify values to use in DELETE statements for a list of
     states to be deleted."""
 
-    for state, state_dict, mapper, has_identity, connection \
-            in states_to_delete:
-        if not has_identity or table not in mapper._pks_by_table:
-            continue
+    for state, state_dict, mapper, connection in states_to_delete:
+
+        # assert table in mapper._pks_by_table
 
         params = {}
         for col in mapper._pks_by_table[table]: