]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- mapper.save_obj() now functions across all mappers in its polymorphic
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 19:30:00 +0000 (19:30 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Oct 2006 19:30:00 +0000 (19:30 +0000)
series, UOWTask calls mapper appropriately in this manner
- polymorphic mappers (i.e. using inheritance) now produces INSERT
statements in order of tables across all inherited classes
[ticket:321]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/sql_util.py
test/orm/polymorph.py

diff --git a/CHANGES b/CHANGES
index 06268a1f4b68579d829343df54d787a79351a1ea..927f9c0959c93a68f2ed41abed85c4c61b3b73a5 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -98,6 +98,9 @@
     - more rearrangements of unit-of-work commit scheme to better allow
     dependencies within circular flushes to work properly...updated
     task traversal/logging implementation
+    - polymorphic mappers (i.e. using inheritance) now produces INSERT
+    statements in order of tables across all inherited classes
+    [ticket:321]
     - added an automatic "row switch" feature to mapping, which will
     detect a pending instance/deleted instance pair with the same 
     identity key and convert the INSERT/DELETE to a single UPDATE
index b60d9c03421a44e9fe1f33909dc6982dbe1eaf8d..f8e9bca072b979ddda5a7c62391cecab1d84e9de 100644 (file)
@@ -517,6 +517,21 @@ class Mapper(object):
             m = m.inherits
         return m is self
 
+    def iterate_to_root(self):
+        m = self
+        while m is not None:
+            yield m
+            m = m.inherits
+    
+    def polymorphic_iterator(self):
+        m = self.base_mapper()
+        def iterate(m):
+            yield m
+            for mapper in m._inheriting_mappers:
+                for x in iterate(mapper):
+                    yield x
+        return iterate(m)
+                
     def accept_mapper_option(self, option):
         option.process_mapper(self)
         
@@ -702,9 +717,11 @@ class Mapper(object):
         self.columntoproperty[column][0].setattr(obj, value)
     
     def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
-        """called by a UnitOfWork object to save objects, which involves either an INSERT or
-        an UPDATE statement for each table used by this mapper, for each element of the
-        list."""
+        """save a list of objects.
+        
+        this method is called within a unit of work flush() process.  It saves objects that are mapped not just
+        by this mapper, but inherited mappers as well, so that insert ordering of polymorphic objects is maintained."""
+        
         self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
         
         # if batch=false, call save_obj separately for each object
@@ -718,37 +735,30 @@ class Mapper(object):
         if not postupdate:
             for obj in objects:
                 if not has_identity(obj):
-                    self.extension.before_insert(self, connection, obj)
+                    for mapper in object_mapper(obj).iterate_to_root():
+                        mapper.extension.before_insert(mapper, connection, obj)
                 else:
-                    self.extension.before_update(self, connection, obj)
+                    for mapper in object_mapper(obj).iterate_to_root():
+                        mapper.extension.before_update(mapper, connection, obj)
 
         inserted_objects = util.Set()
         updated_objects = util.Set()
-        for table in self.tables.sort(reverse=False):
-            #print "SAVE_OBJ table ", self.class_.__name__, table.name
-            # looping through our set of tables, which are all "real" tables, as opposed
-            # to our main table which might be a select statement or something non-writeable
-            
-            # the loop structure is tables on the outer loop, objects on the inner loop.
-            # this allows us to bundle inserts/updates on the same table together...although currently
-            # they are separate execs via execute(), not executemany()
+        
+        table_to_mapper = {}
+        for mapper in self.polymorphic_iterator():
+            for t in mapper.tables:
+                table_to_mapper[t] = mapper
             
-            if not self._has_pks(table):
-                #print "NO PKS ?", str(table)
-                # if we dont have a full set of primary keys for this table, we cant really
-                # do any CRUD with it, so skip.  this occurs if we are mapping against a query
-                # that joins on other tables so its not really an error condition.
-                continue
-
+        for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
             # two lists to store parameters for each table/object pair located
             insert = []
             update = []
             
-            # we have our own idea of the primary key columns 
-            # for this table, in the case that the user
-            # specified custom primary key cols.
             for obj in objects:
-                instance_key = self.instance_key(obj)
+                mapper = object_mapper(obj)
+                if table not in mapper.tables or not mapper._has_pks(table):
+                    continue
+                instance_key = mapper.instance_key(obj)
                 self.__log_debug("save_obj() instance %s identity %s" % (mapperutil.instance_str(obj), str(instance_key)))
 
                 # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
@@ -766,31 +776,31 @@ class Mapper(object):
                 params = {}
                 hasdata = False
                 for col in table.columns:
-                    if col is self.version_id_col:
+                    if col is mapper.version_id_col:
                         if not isinsert:
-                            params[col._label] = self._getattrbycolumn(obj, col)
+                            params[col._label] = mapper._getattrbycolumn(obj, col)
                             params[col.key] = params[col._label] + 1
                         else:
                             params[col.key] = 1
-                    elif col in self.pks_by_table[table]:
+                    elif col in mapper.pks_by_table[table]:
                         # column is a primary key ?
                         if not isinsert:
                             # doing an UPDATE?  put primary key values as "WHERE" parameters
                             # matching the bindparam we are creating below, i.e. "<tablename>_<colname>"
-                            params[col._label] = self._getattrbycolumn(obj, col)
+                            params[col._label] = mapper._getattrbycolumn(obj, col)
                         else:
                             # doing an INSERT, primary key col ? 
                             # if the primary key values are not populated,
                             # leave them out of the INSERT altogether, since PostGres doesn't want
                             # them to be present for SERIAL to take effect.  A SQLEngine that uses
                             # explicit sequences will put them back in if they are needed
-                            value = self._getattrbycolumn(obj, col)
+                            value = mapper._getattrbycolumn(obj, col)
                             if value is not None:
                                 params[col.key] = value
-                    elif self.polymorphic_on is not None and self.polymorphic_on.shares_lineage(col):
+                    elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
                         if isinsert:
-                            self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (self.polymorphic_identity, col.key))
-                            value = self.polymorphic_identity
+                            self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
+                            value = mapper.polymorphic_identity
                             if col.default is None or value is not None:
                                 params[col.key] = value
                     else:
@@ -805,7 +815,7 @@ class Mapper(object):
                                 params[col.key] = self._getattrbycolumn(obj, col)
                                 hasdata = True
                                 continue
-                            prop = self._getpropbycolumn(col, False)
+                            prop = mapper._getpropbycolumn(col, False)
                             if prop is None:
                                 continue
                             history = prop.get_history(obj, passive=True)
@@ -821,7 +831,7 @@ class Mapper(object):
                             # default.  if its None and theres no default, we still might
                             # not want to put it in the col list but SQLIte doesnt seem to like that
                             # if theres no columns at all
-                            value = self._getattrbycolumn(obj, col, False)
+                            value = mapper._getattrbycolumn(obj, col, False)
                             if value is NO_ATTRIBUTE:
                                 continue
                             if col.default is None or value is not None:
@@ -834,18 +844,19 @@ class Mapper(object):
                         update.append((obj, params))
                 else:
                     insert.append((obj, params))
-                    
+
+            mapper = table_to_mapper[table]
             if len(update):
                 clause = sql.and_()
-                for col in self.pks_by_table[table]:
+                for col in mapper.pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col._label, type=col.type))
-                if self.version_id_col is not None:
-                    clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label, type=col.type))
+                if mapper.version_id_col is not None:
+                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type))
                 statement = table.update(clause)
                 rows = 0
                 supports_sane_rowcount = True
                 def comparator(a, b):
-                    for col in self.pks_by_table[table]:
+                    for col in mapper.pks_by_table[table]:
                         x = cmp(a[1][col._label],b[1][col._label])
                         if x != 0:
                             return x
@@ -854,7 +865,7 @@ class Mapper(object):
                 for rec in update:
                     (obj, params) = rec
                     c = connection.execute(statement, params)
-                    self._postfetch(connection, table, obj, c, c.last_updated_params())
+                    mapper._postfetch(connection, table, obj, c, c.last_updated_params())
 
                     updated_objects.add(obj)
                     rows += c.cursor.rowcount
@@ -873,11 +884,11 @@ class Mapper(object):
                     primary_key = c.last_inserted_ids()
                     if primary_key is not None:
                         i = 0
-                        for col in self.pks_by_table[table]:
-                            if self._getattrbycolumn(obj, col) is None and len(primary_key) > i:
-                                self._setattrbycolumn(obj, col, primary_key[i])
+                        for col in mapper.pks_by_table[table]:
+                            if mapper._getattrbycolumn(obj, col) is None and len(primary_key) > i:
+                                mapper._setattrbycolumn(obj, col, primary_key[i])
                             i+=1
-                    self._postfetch(connection, table, obj, c, c.last_inserted_params())
+                    mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
                     
                     # synchronize newly inserted ids from one table to the next
                     def sync(mapper):
@@ -886,12 +897,16 @@ class Mapper(object):
                             sync(inherit)
                         if mapper._synchronizer is not None:
                             mapper._synchronizer.execute(obj, obj)
-                    sync(self)
+                    sync(mapper)
                     
                     inserted_objects.add(obj)
         if not postupdate:
-            [self.extension.after_insert(self, connection, obj) for obj in inserted_objects]
-            [self.extension.after_update(self, connection, obj) for obj in updated_objects]
+            for obj in inserted_objects:
+                for mapper in object_mapper(obj).iterate_to_root():
+                    mapper.extension.after_insert(mapper, connection, obj)
+            for obj in updated_objects:
+                for mapper in object_mapper(obj).iterate_to_root():
+                    mapper.extension.after_update(mapper, connection, obj)
 
     def _postfetch(self, connection, table, obj, resultproxy, params):
         """after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
index 2750b34d1c680f2345be39b96db7a59ec8fdda88..2beb6a78eb68b7c2180a4e22a0b2167eba1300d5 100644 (file)
@@ -627,8 +627,7 @@ class UOWTask(object):
             pass
 
     def _save_objects(self, trans):
-        for task in self.polymorphic_tasks():
-            task.mapper.save_obj(task.tosave_objects, trans)
+        self.mapper.save_obj(self.polymorphic_tosave_objects, trans)
     def _delete_objects(self, trans):
         for task in self.polymorphic_tasks():
             task.mapper.delete_obj(task.todelete_objects, trans)
@@ -701,6 +700,7 @@ class UOWTask(object):
     todelete_elements = property(lambda self:[rec for rec in self.get_elements(polymorphic=False) if rec.isdelete])
     tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
     todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is True])
+    polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
         
     def _sort_circular_dependencies(self, trans, cycles):
         """for a single task, creates a hierarchical tree of "subtasks" which associate
index 4015fd244238e11ef1e3724418f2ff2b9505c670..94caade68bd74666f7d308d48340f02060a177f1 100644 (file)
@@ -6,8 +6,18 @@ import sqlalchemy.util as util
 
 
 class TableCollection(object):
-    def __init__(self):
-        self.tables = []
+    def __init__(self, tables=None):
+        self.tables = tables or []
+    def __len__(self):
+        return len(self.tables)
+    def __getitem__(self, i):
+        return self.tables[i]
+    def __iter__(self):
+        return iter(self.tables)
+    def __contains__(self, obj):
+        return obj in self.tables
+    def __add__(self, obj):
+        return self.tables + list(obj)
     def add(self, table):
         self.tables.append(table)
         if hasattr(self, '_sorted'):
@@ -29,10 +39,11 @@ class TableCollection(object):
         import sqlalchemy.orm.topological
         tuples = []
         class TVisitor(schema.SchemaVisitor):
-            def visit_foreign_key(self, fkey):
+            def visit_foreign_key(_self, fkey):
                 parent_table = fkey.column.table
-                child_table = fkey.parent.table
-                tuples.append( ( parent_table, child_table ) )
+                if parent_table in self:
+                    child_table = fkey.parent.table
+                    tuples.append( ( parent_table, child_table ) )
         vis = TVisitor()        
         for table in self.tables:
             table.accept_schema_visitor(vis)
@@ -57,16 +68,6 @@ class TableFinder(TableCollection, sql.ClauseVisitor):
             table.accept_visitor(self)
     def visit_table(self, table):
         self.tables.append(table)
-    def __len__(self):
-        return len(self.tables)
-    def __getitem__(self, i):
-        return self.tables[i]
-    def __iter__(self):
-        return iter(self.tables)
-    def __contains__(self, obj):
-        return obj in self.tables
-    def __add__(self, obj):
-        return self.tables + list(obj)
     def visit_column(self, column):
         if self.check_columns:
             column.table.accept_visitor(self)
index 410af94d8180a0205f2000c0a24c12e71291423d..ce8f99320af2028d923c5acb2f8e8659c58f4096 100644 (file)
@@ -225,6 +225,38 @@ class MultipleTableTest(testbase.PersistTest):
         session.delete(c)
         session.flush()
 
+    def test_insert_order(self):
+        person_join = polymorphic_union(
+            {
+                'engineer':people.join(engineers),
+                'manager':people.join(managers),
+                'person':people.select(people.c.type=='person'),
+            }, None, 'pjoin')
+
+        person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
+
+        mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
+        mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
+        mapper(Company, companies, properties={
+            'employees': relation(Person, private=True, backref='company', order_by=person_join.c.person_id)
+        })
+
+        session = create_session()
+        c = Company(name='company1')
+        c.employees.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss'))
+        c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', name='dilbert'))
+        c.employees.append(Person(status='HHH', name='joesmith'))
+        c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', name='wally'))
+        c.employees.append(Manager(status='ABA', manager_name='manager2', name='jsmith'))
+        session.save(c)
+        session.flush()
+        session.clear()
+        c = session.query(Company).get(c.company_id)
+        for e in c.employees:
+            print e, e._instance_key, e.company
+        
+        assert [e.get_name() for e in c.employees] == ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']
+
 if __name__ == "__main__":    
     testbase.main()