]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more adjustments to #321
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 20:24:30 +0000 (20:24 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 20:24:30 +0000 (20:24 +0000)
lib/sqlalchemy/orm/mapper.py

index 977a438e14c735ecfd0560b901ba6251d1d28f9a..84ff75a795a8670cff64bd50cc71f694e247ae7f 100644 (file)
@@ -759,10 +759,11 @@ class Mapper(object):
         updated_objects = util.Set()
         
         table_to_mapper = {}
+        tables = util.Set()
         for mapper in self.polymorphic_iterator():
             for t in mapper.tables:
-                table_to_mapper[t] = mapper
-            
+                table_to_mapper.setdefault(t, mapper)
+
         for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
             # two lists to store parameters for each table/object pair located
             insert = []
@@ -844,12 +845,12 @@ class Mapper(object):
                     if hasdata:
                         # if none of the attributes changed, dont even
                         # add the row to be updated.
-                        update.append((obj, params))
+                        update.append((obj, params, mapper))
                 else:
-                    insert.append((obj, params))
+                    insert.append((obj, params, mapper))
 
-            mapper = table_to_mapper[table]
             if len(update):
+                mapper = table_to_mapper[table]
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col._label, type=col.type))
@@ -866,7 +867,7 @@ class Mapper(object):
                     return 0
                 update.sort(comparator)
                 for rec in update:
-                    (obj, params) = rec
+                    (obj, params, mapper) = rec
                     c = connection.execute(statement, params)
                     mapper._postfetch(connection, table, obj, c, c.last_updated_params())
 
@@ -882,7 +883,7 @@ class Mapper(object):
                     return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
                 insert.sort(comparator)
                 for rec in insert:
-                    (obj, params) = rec
+                    (obj, params, mapper) = rec
                     c = connection.execute(statement, params)
                     primary_key = c.last_inserted_ids()
                     if primary_key is not None: