From: Mike Bayer
Date: Sun, 8 Oct 2006 19:30:00 +0000 (+0000)
Subject: - mapper.save_obj() now functions across all mappers in its polymorphic
X-Git-Tag: rel_0_3_0~73
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ef77cfa61b6a894202495d460b055de6fea9eed6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

- mapper.save_obj() now functions across all mappers in its polymorphic
  series; UOWTask now calls the mapper appropriately in this manner
- polymorphic mappers (i.e. using inheritance) now produce INSERT
  statements in order of tables across all inherited classes
  [ticket:321]
---

diff --git a/CHANGES b/CHANGES
index 06268a1f4b..927f9c0959 100644
--- a/CHANGES
+++ b/CHANGES
@@ -98,6 +98,9 @@
     - more rearrangements of unit-of-work commit scheme to better allow
     dependencies within circular flushes to work properly...updated
     task traversal/logging implementation
+    - polymorphic mappers (i.e. using inheritance) now produce INSERT
+    statements in order of tables across all inherited classes
+    [ticket:321]
     - added an automatic "row switch" feature to mapping, which will
     detect a pending instance/deleted instance pair with the same
     identity key and convert the INSERT/DELETE to a single UPDATE
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index b60d9c0342..f8e9bca072 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -517,6 +517,21 @@ class Mapper(object):
             m = m.inherits
         return m is self
 
+    def iterate_to_root(self):
+        m = self
+        while m is not None:
+            yield m
+            m = m.inherits
+
+    def polymorphic_iterator(self):
+        m = self.base_mapper()
+        def iterate(m):
+            yield m
+            for mapper in m._inheriting_mappers:
+                for x in iterate(mapper):
+                    yield x
+        return iterate(m)
+
     def accept_mapper_option(self, option):
         option.process_mapper(self)
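As an aside on the two helpers just added: iterate_to_root() walks upward along the inherits chain, while polymorphic_iterator() starts at base_mapper() and descends through _inheriting_mappers, so the base mapper is always yielded first. A minimal standalone sketch of the traversal order, using stand-in objects rather than real Mapper instances (the person/engineer/manager names are illustrative only):

    # stand-in for Mapper, just enough to show the traversal order
    class FakeMapper(object):
        def __init__(self, name, inherits=None):
            self.name = name
            self.inherits = inherits
            self._inheriting_mappers = []
            if inherits is not None:
                inherits._inheriting_mappers.append(self)
        def base_mapper(self):
            m = self
            while m.inherits is not None:
                m = m.inherits
            return m
        def iterate_to_root(self):
            # same shape as the new Mapper.iterate_to_root()
            m = self
            while m is not None:
                yield m
                m = m.inherits
        def polymorphic_iterator(self):
            # same shape as the new Mapper.polymorphic_iterator()
            def iterate(m):
                yield m
                for mapper in m._inheriting_mappers:
                    for x in iterate(mapper):
                        yield x
            return iterate(self.base_mapper())

    person = FakeMapper('person')
    engineer = FakeMapper('engineer', inherits=person)
    manager = FakeMapper('manager', inherits=person)

    print [m.name for m in engineer.iterate_to_root()]     # ['engineer', 'person']
    print [m.name for m in person.polymorphic_iterator()]  # ['person', 'engineer', 'manager']

Yielding the base mapper first is what lets save_obj() below order tables from base class to subclass.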
@@ -702,9 +717,11 @@ class Mapper(object):
             self.columntoproperty[column][0].setattr(obj, value)
 
     def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
-        """called by a UnitOfWork object to save objects, which involves either an INSERT or
-        an UPDATE statement for each table used by this mapper, for each element of the
-        list."""
+        """save a list of objects.
+
+        this method is called within a unit of work flush() process.  It saves objects that are mapped not just
+        by this mapper, but by inherited mappers as well, so that insert ordering of polymorphic objects is maintained."""
+
         self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
 
         # if batch=false, call save_obj separately for each object
@@ -718,37 +735,30 @@ class Mapper(object):
         if not postupdate:
             for obj in objects:
                 if not has_identity(obj):
-                    self.extension.before_insert(self, connection, obj)
+                    for mapper in object_mapper(obj).iterate_to_root():
+                        mapper.extension.before_insert(mapper, connection, obj)
                 else:
-                    self.extension.before_update(self, connection, obj)
+                    for mapper in object_mapper(obj).iterate_to_root():
+                        mapper.extension.before_update(mapper, connection, obj)
         inserted_objects = util.Set()
         updated_objects = util.Set()
-        for table in self.tables.sort(reverse=False):
-            #print "SAVE_OBJ table ", self.class_.__name__, table.name
-            # looping through our set of tables, which are all "real" tables, as opposed
-            # to our main table which might be a select statement or something non-writeable
-
-            # the loop structure is tables on the outer loop, objects on the inner loop.
-            # this allows us to bundle inserts/updates on the same table together...although currently
-            # they are separate execs via execute(), not executemany()
+
+        table_to_mapper = {}
+        for mapper in self.polymorphic_iterator():
+            for t in mapper.tables:
+                table_to_mapper[t] = mapper
 
-            if not self._has_pks(table):
-                #print "NO PKS ?", str(table)
-                # if we dont have a full set of primary keys for this table, we cant really
-                # do any CRUD with it, so skip.  this occurs if we are mapping against a query
-                # that joins on other tables so its not really an error condition.
-                continue
-
+        for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
             # two lists to store parameters for each table/object pair located
             insert = []
             update = []
 
-            # we have our own idea of the primary key columns
-            # for this table, in the case that the user
-            # specified custom primary key cols.
             for obj in objects:
-                instance_key = self.instance_key(obj)
+                mapper = object_mapper(obj)
+                if table not in mapper.tables or not mapper._has_pks(table):
+                    continue
+                instance_key = mapper.instance_key(obj)
                 self.__log_debug("save_obj() instance %s identity %s" % (mapperutil.instance_str(obj), str(instance_key)))
 
                 # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
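The table_to_mapper collection above is the heart of the [ticket:321] fix: one save_obj() pass now sees the tables of every mapper in the hierarchy and visits them in foreign key dependency order, instead of looping over self.tables alone. A rough sketch of that ordering idea, with a hand-rolled dependency sort standing in for TableCollection.sort() and illustrative table names:

    # emit a table only after the tables it references via FK are emitted
    def sort_tables(tables, depends_on):
        result = []
        def visit(t):
            if t in result:
                return
            for parent in depends_on.get(t, []):
                visit(parent)
            result.append(t)
        for t in tables:
            visit(t)
        return result

    tables = ['engineers', 'managers', 'people']
    depends_on = {'engineers': ['people'], 'managers': ['people']}
    print sort_tables(tables, depends_on)   # ['people', 'engineers', 'managers']

The real implementation delegates to the topological sort in sqlalchemy.orm.topological, via the TableCollection changes further down.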
@@ -766,31 +776,31 @@ class Mapper(object):
                 params = {}
                 hasdata = False
                 for col in table.columns:
-                    if col is self.version_id_col:
+                    if col is mapper.version_id_col:
                         if not isinsert:
-                            params[col._label] = self._getattrbycolumn(obj, col)
+                            params[col._label] = mapper._getattrbycolumn(obj, col)
                             params[col.key] = params[col._label] + 1
                         else:
                             params[col.key] = 1
-                    elif col in self.pks_by_table[table]:
+                    elif col in mapper.pks_by_table[table]:
                         # column is a primary key ?
                         if not isinsert:
                             # doing an UPDATE?  put primary key values as "WHERE" parameters
                             # matching the bindparam we are creating below, i.e. "_"
-                            params[col._label] = self._getattrbycolumn(obj, col)
+                            params[col._label] = mapper._getattrbycolumn(obj, col)
                         else:
                             # doing an INSERT, primary key col ?
                             # if the primary key values are not populated,
                             # leave them out of the INSERT altogether, since PostGres doesn't want
                             # them to be present for SERIAL to take effect.  A SQLEngine that uses
                             # explicit sequences will put them back in if they are needed
-                            value = self._getattrbycolumn(obj, col)
+                            value = mapper._getattrbycolumn(obj, col)
                             if value is not None:
                                 params[col.key] = value
-                    elif self.polymorphic_on is not None and self.polymorphic_on.shares_lineage(col):
+                    elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
                         if isinsert:
-                            self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (self.polymorphic_identity, col.key))
-                            value = self.polymorphic_identity
+                            self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
+                            value = mapper.polymorphic_identity
                             if col.default is None or value is not None:
                                 params[col.key] = value
                     else:
@@ -805,7 +815,7 @@ class Mapper(object):
                                 params[col.key] = self._getattrbycolumn(obj, col)
                                 hasdata = True
                             continue
-                        prop = self._getpropbycolumn(col, False)
+                        prop = mapper._getpropbycolumn(col, False)
                         if prop is None:
                             continue
                         history = prop.get_history(obj, passive=True)
@@ -821,7 +831,7 @@ class Mapper(object):
                             # default.  if its None and theres no default, we still might
                             # not want to put it in the col list but SQLIte doesnt seem to like that
                             # if theres no columns at all
-                            value = self._getattrbycolumn(obj, col, False)
+                            value = mapper._getattrbycolumn(obj, col, False)
                             if value is NO_ATTRIBUTE:
                                 continue
                             if col.default is None or value is not None:
                                 params[col.key] = value
@@ -834,18 +844,19 @@ class Mapper(object):
                     update.append((obj, params))
                 else:
                     insert.append((obj, params))
-
+
+            mapper = table_to_mapper[table]
             if len(update):
                 clause = sql.and_()
-                for col in self.pks_by_table[table]:
+                for col in mapper.pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col._label, type=col.type))
-                if self.version_id_col is not None:
-                    clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label, type=col.type))
+                if mapper.version_id_col is not None:
+                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type))
                 statement = table.update(clause)
                 rows = 0
                 supports_sane_rowcount = True
                 def comparator(a, b):
-                    for col in self.pks_by_table[table]:
+                    for col in mapper.pks_by_table[table]:
                         x = cmp(a[1][col._label],b[1][col._label])
                         if x != 0:
                             return x
@@ -854,7 +865,7 @@ class Mapper(object):
                 for rec in update:
                     (obj, params) = rec
                     c = connection.execute(statement, params)
-                    self._postfetch(connection, table, obj, c, c.last_updated_params())
+                    mapper._postfetch(connection, table, obj, c, c.last_updated_params())
                     updated_objects.add(obj)
                     rows += c.cursor.rowcount
@@ -873,11 +884,11 @@ class Mapper(object):
                     primary_key = c.last_inserted_ids()
                     if primary_key is not None:
                         i = 0
-                        for col in self.pks_by_table[table]:
-                            if self._getattrbycolumn(obj, col) is None and len(primary_key) > i:
-                                self._setattrbycolumn(obj, col, primary_key[i])
+                        for col in mapper.pks_by_table[table]:
+                            if mapper._getattrbycolumn(obj, col) is None and len(primary_key) > i:
+                                mapper._setattrbycolumn(obj, col, primary_key[i])
                             i+=1
-                    self._postfetch(connection, table, obj, c, c.last_inserted_params())
+                    mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
 
                     # synchronize newly inserted ids from one table to the next
                     def sync(mapper):
@@ -886,12 +897,16 @@ class Mapper(object):
                             sync(inherit)
                         if mapper._synchronizer is not None:
                             mapper._synchronizer.execute(obj, obj)
-                    sync(self)
+                    sync(mapper)
                     inserted_objects.add(obj)
 
         if not postupdate:
-            [self.extension.after_insert(self, connection, obj) for obj in inserted_objects]
-            [self.extension.after_update(self, connection, obj) for obj in updated_objects]
+            for obj in inserted_objects:
+                for mapper in object_mapper(obj).iterate_to_root():
+                    mapper.extension.after_insert(mapper, connection, obj)
+            for obj in updated_objects:
+                for mapper in object_mapper(obj).iterate_to_root():
+                    mapper.extension.after_update(mapper, connection, obj)
 
     def _postfetch(self, connection, table, obj, resultproxy, params):
         """after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
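Taken together, the save_obj() changes above also mean that each mapper along an object's inheritance chain now fires its own extension hooks, rather than only the mapper save_obj() was called on. A standalone sketch of that hook walk, with stand-in classes in place of Mapper and MapperExtension:

    class FakeExtension(object):
        def __init__(self, name):
            self.name = name
        def before_insert(self, mapper, connection, obj):
            print "before_insert fired by", self.name

    class FakeMapper(object):
        def __init__(self, name, inherits=None):
            self.extension = FakeExtension(name)
            self.inherits = inherits
        def iterate_to_root(self):
            m = self
            while m is not None:
                yield m
                m = m.inherits

    person = FakeMapper('person')
    engineer = FakeMapper('engineer', inherits=person)

    # mirrors the new hook loop in save_obj()
    for mapper in engineer.iterate_to_root():
        mapper.extension.before_insert(mapper, None, object())
    # before_insert fired by engineer
    # before_insert fired by person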
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 2750b34d1c..2beb6a78eb 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -627,8 +627,7 @@ class UOWTask(object):
             pass
 
     def _save_objects(self, trans):
-        for task in self.polymorphic_tasks():
-            task.mapper.save_obj(task.tosave_objects, trans)
+        self.mapper.save_obj(self.polymorphic_tosave_objects, trans)
     def _delete_objects(self, trans):
         for task in self.polymorphic_tasks():
             task.mapper.delete_obj(task.todelete_objects, trans)
@@ -701,6 +700,7 @@ class UOWTask(object):
     todelete_elements = property(lambda self:[rec for rec in self.get_elements(polymorphic=False) if rec.isdelete])
     tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
     todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is True])
+    polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
 
     def _sort_circular_dependencies(self, trans, cycles):
         """for a single task, creates a hierarchical tree of "subtasks" which associate
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index 4015fd2442..94caade68b 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -6,8 +6,18 @@ import sqlalchemy.util as util
 
 class TableCollection(object):
-    def __init__(self):
-        self.tables = []
+    def __init__(self, tables=None):
+        self.tables = tables or []
+    def __len__(self):
+        return len(self.tables)
+    def __getitem__(self, i):
+        return self.tables[i]
+    def __iter__(self):
+        return iter(self.tables)
+    def __contains__(self, obj):
+        return obj in self.tables
+    def __add__(self, obj):
+        return self.tables + list(obj)
     def add(self, table):
         self.tables.append(table)
         if hasattr(self, '_sorted'):
@@ -29,10 +39,11 @@ class TableCollection(object):
         import sqlalchemy.orm.topological
         tuples = []
         class TVisitor(schema.SchemaVisitor):
-            def visit_foreign_key(self, fkey):
+            def visit_foreign_key(_self, fkey):
                 parent_table = fkey.column.table
-                child_table = fkey.parent.table
-                tuples.append( ( parent_table, child_table ) )
+                if parent_table in self:
+                    child_table = fkey.parent.table
+                    tuples.append( ( parent_table, child_table ) )
         vis = TVisitor()
         for table in self.tables:
             table.accept_schema_visitor(vis)
@@ -57,16 +68,6 @@ class TableFinder(TableCollection, sql.ClauseVisitor):
             table.accept_visitor(self)
     def visit_table(self, table):
         self.tables.append(table)
-    def __len__(self):
-        return len(self.tables)
-    def __getitem__(self, i):
-        return self.tables[i]
-    def __iter__(self):
-        return iter(self.tables)
-    def __contains__(self, obj):
-        return obj in self.tables
-    def __add__(self, obj):
-        return self.tables + list(obj)
     def visit_column(self, column):
         if self.check_columns:
             column.table.accept_visitor(self)
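The _self rename in TVisitor.visit_foreign_key is deliberate: by freeing up the name self, the nested visitor can refer to the enclosing TableCollection through the closure, so a dependency edge is recorded only when the foreign key's parent table is actually part of the collection being sorted. A minimal sketch of that closure pattern, with illustrative names:

    class Collection(object):
        def __init__(self, items):
            self.items = items
        def __contains__(self, obj):
            return obj in self.items
        def edges(self, all_edges):
            found = []
            class Visitor(object):
                def visit(_self, edge):
                    # "self" still refers to the enclosing Collection
                    parent, child = edge
                    if parent in self:
                        found.append((parent, child))
            v = Visitor()
            for e in all_edges:
                v.visit(e)
            return found

    c = Collection(['people', 'engineers'])
    print c.edges([('people', 'engineers'), ('addresses', 'people')])
    # [('people', 'engineers')] -- the edge from 'addresses' is filtered out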
diff --git a/test/orm/polymorph.py b/test/orm/polymorph.py
index 410af94d81..ce8f99320a 100644
--- a/test/orm/polymorph.py
+++ b/test/orm/polymorph.py
@@ -225,6 +225,38 @@ class MultipleTableTest(testbase.PersistTest):
         session.delete(c)
         session.flush()
 
+    def test_insert_order(self):
+        person_join = polymorphic_union(
+            {
+                'engineer':people.join(engineers),
+                'manager':people.join(managers),
+                'person':people.select(people.c.type=='person'),
+            }, None, 'pjoin')
+
+        person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
+
+        mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
+        mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
+        mapper(Company, companies, properties={
+            'employees': relation(Person, private=True, backref='company', order_by=person_join.c.person_id)
+        })
+
+        session = create_session()
+        c = Company(name='company1')
+        c.employees.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss'))
+        c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', name='dilbert'))
+        c.employees.append(Person(status='HHH', name='joesmith'))
+        c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', name='wally'))
+        c.employees.append(Manager(status='ABA', manager_name='manager2', name='jsmith'))
+        session.save(c)
+        session.flush()
+        session.clear()
+        c = session.query(Company).get(c.company_id)
+        for e in c.employees:
+            print e, e._instance_key, e.company
+
+        assert [e.get_name() for e in c.employees] == ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']
+
 if __name__ == "__main__":
     testbase.main()
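For reference, the statement ordering the new test is meant to exercise, written out as a rough trace rather than captured output; it assumes companies sorts before people and people before both subtables, while the relative order of engineers and managers is whatever the dependency sort picks, since neither references the other:

    # illustrative trace only -- not literal SQL from a real run
    expected_order = [
        'INSERT INTO companies',  # company1
        'INSERT INTO people',     # pointy haired boss
        'INSERT INTO people',     # dilbert
        'INSERT INTO people',     # joesmith
        'INSERT INTO people',     # wally
        'INSERT INTO people',     # jsmith
        'INSERT INTO engineers',  # dilbert
        'INSERT INTO engineers',  # wally
        'INSERT INTO managers',   # pointy haired boss
        'INSERT INTO managers',   # jsmith
    ]

Before this change, each polymorphic UOWTask saved only its own mapper's objects, so the people rows were inserted in per-class batches and the creation order that order_by=person_join.c.person_id relies on was lost.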