m = m.inherits
return m is self
+ def iterate_to_root(self):
+ m = self
+ while m is not None:
+ yield m
+ m = m.inherits
+
+ def polymorphic_iterator(self):
+ m = self.base_mapper()
+ def iterate(m):
+ yield m
+ for mapper in m._inheriting_mappers:
+ for x in iterate(mapper):
+ yield x
+ return iterate(m)
+
def accept_mapper_option(self, option):
option.process_mapper(self)
self.columntoproperty[column][0].setattr(obj, value)
def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
- """called by a UnitOfWork object to save objects, which involves either an INSERT or
- an UPDATE statement for each table used by this mapper, for each element of the
- list."""
+ """save a list of objects.
+
+ this method is called within a unit of work flush() process. It saves objects that are mapped not just
+ by this mapper, but inherited mappers as well, so that insert ordering of polymorphic objects is maintained."""
+
self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
# if batch=false, call save_obj separately for each object
if not postupdate:
for obj in objects:
if not has_identity(obj):
- self.extension.before_insert(self, connection, obj)
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.before_insert(mapper, connection, obj)
else:
- self.extension.before_update(self, connection, obj)
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.before_update(mapper, connection, obj)
inserted_objects = util.Set()
updated_objects = util.Set()
- for table in self.tables.sort(reverse=False):
- #print "SAVE_OBJ table ", self.class_.__name__, table.name
- # looping through our set of tables, which are all "real" tables, as opposed
- # to our main table which might be a select statement or something non-writeable
-
- # the loop structure is tables on the outer loop, objects on the inner loop.
- # this allows us to bundle inserts/updates on the same table together...although currently
- # they are separate execs via execute(), not executemany()
+
+ table_to_mapper = {}
+ for mapper in self.polymorphic_iterator():
+ for t in mapper.tables:
+ table_to_mapper[t] = mapper
- if not self._has_pks(table):
- #print "NO PKS ?", str(table)
- # if we dont have a full set of primary keys for this table, we cant really
- # do any CRUD with it, so skip. this occurs if we are mapping against a query
- # that joins on other tables so its not really an error condition.
- continue
-
+ for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
# two lists to store parameters for each table/object pair located
insert = []
update = []
- # we have our own idea of the primary key columns
- # for this table, in the case that the user
- # specified custom primary key cols.
for obj in objects:
- instance_key = self.instance_key(obj)
+ mapper = object_mapper(obj)
+ if table not in mapper.tables or not mapper._has_pks(table):
+ continue
+ instance_key = mapper.instance_key(obj)
self.__log_debug("save_obj() instance %s identity %s" % (mapperutil.instance_str(obj), str(instance_key)))
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
params = {}
hasdata = False
for col in table.columns:
- if col is self.version_id_col:
+ if col is mapper.version_id_col:
if not isinsert:
- params[col._label] = self._getattrbycolumn(obj, col)
+ params[col._label] = mapper._getattrbycolumn(obj, col)
params[col.key] = params[col._label] + 1
else:
params[col.key] = 1
- elif col in self.pks_by_table[table]:
+ elif col in mapper.pks_by_table[table]:
# column is a primary key ?
if not isinsert:
# doing an UPDATE? put primary key values as "WHERE" parameters
# matching the bindparam we are creating below, i.e. "<tablename>_<colname>"
- params[col._label] = self._getattrbycolumn(obj, col)
+ params[col._label] = mapper._getattrbycolumn(obj, col)
else:
# doing an INSERT, primary key col ?
# if the primary key values are not populated,
# leave them out of the INSERT altogether, since PostGres doesn't want
# them to be present for SERIAL to take effect. A SQLEngine that uses
# explicit sequences will put them back in if they are needed
- value = self._getattrbycolumn(obj, col)
+ value = mapper._getattrbycolumn(obj, col)
if value is not None:
params[col.key] = value
- elif self.polymorphic_on is not None and self.polymorphic_on.shares_lineage(col):
+ elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
if isinsert:
- self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (self.polymorphic_identity, col.key))
- value = self.polymorphic_identity
+ self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
+ value = mapper.polymorphic_identity
if col.default is None or value is not None:
params[col.key] = value
else:
params[col.key] = self._getattrbycolumn(obj, col)
hasdata = True
continue
- prop = self._getpropbycolumn(col, False)
+ prop = mapper._getpropbycolumn(col, False)
if prop is None:
continue
history = prop.get_history(obj, passive=True)
# default. if its None and theres no default, we still might
# not want to put it in the col list but SQLIte doesnt seem to like that
# if theres no columns at all
- value = self._getattrbycolumn(obj, col, False)
+ value = mapper._getattrbycolumn(obj, col, False)
if value is NO_ATTRIBUTE:
continue
if col.default is None or value is not None:
update.append((obj, params))
else:
insert.append((obj, params))
-
+
+ mapper = table_to_mapper[table]
if len(update):
clause = sql.and_()
- for col in self.pks_by_table[table]:
+ for col in mapper.pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label, type=col.type))
- if self.version_id_col is not None:
- clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label, type=col.type))
+ if mapper.version_id_col is not None:
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
def comparator(a, b):
- for col in self.pks_by_table[table]:
+ for col in mapper.pks_by_table[table]:
x = cmp(a[1][col._label],b[1][col._label])
if x != 0:
return x
for rec in update:
(obj, params) = rec
c = connection.execute(statement, params)
- self._postfetch(connection, table, obj, c, c.last_updated_params())
+ mapper._postfetch(connection, table, obj, c, c.last_updated_params())
updated_objects.add(obj)
rows += c.cursor.rowcount
primary_key = c.last_inserted_ids()
if primary_key is not None:
i = 0
- for col in self.pks_by_table[table]:
- if self._getattrbycolumn(obj, col) is None and len(primary_key) > i:
- self._setattrbycolumn(obj, col, primary_key[i])
+ for col in mapper.pks_by_table[table]:
+ if mapper._getattrbycolumn(obj, col) is None and len(primary_key) > i:
+ mapper._setattrbycolumn(obj, col, primary_key[i])
i+=1
- self._postfetch(connection, table, obj, c, c.last_inserted_params())
+ mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
# synchronize newly inserted ids from one table to the next
def sync(mapper):
sync(inherit)
if mapper._synchronizer is not None:
mapper._synchronizer.execute(obj, obj)
- sync(self)
+ sync(mapper)
inserted_objects.add(obj)
if not postupdate:
- [self.extension.after_insert(self, connection, obj) for obj in inserted_objects]
- [self.extension.after_update(self, connection, obj) for obj in updated_objects]
+ for obj in inserted_objects:
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.after_insert(mapper, connection, obj)
+ for obj in updated_objects:
+ for mapper in object_mapper(obj).iterate_to_root():
+ mapper.extension.after_update(mapper, connection, obj)
def _postfetch(self, connection, table, obj, resultproxy, params):
"""after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
class TableCollection(object):
- def __init__(self):
- self.tables = []
+ def __init__(self, tables=None):
+ self.tables = tables or []
+ def __len__(self):
+ return len(self.tables)
+ def __getitem__(self, i):
+ return self.tables[i]
+ def __iter__(self):
+ return iter(self.tables)
+ def __contains__(self, obj):
+ return obj in self.tables
+ def __add__(self, obj):
+ return self.tables + list(obj)
def add(self, table):
self.tables.append(table)
if hasattr(self, '_sorted'):
import sqlalchemy.orm.topological
tuples = []
class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(self, fkey):
+ def visit_foreign_key(_self, fkey):
parent_table = fkey.column.table
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
+ if parent_table in self:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
vis = TVisitor()
for table in self.tables:
table.accept_schema_visitor(vis)
table.accept_visitor(self)
def visit_table(self, table):
self.tables.append(table)
- def __len__(self):
- return len(self.tables)
- def __getitem__(self, i):
- return self.tables[i]
- def __iter__(self):
- return iter(self.tables)
- def __contains__(self, obj):
- return obj in self.tables
- def __add__(self, obj):
- return self.tables + list(obj)
def visit_column(self, column):
if self.check_columns:
column.table.accept_visitor(self)
session.delete(c)
session.flush()
+ def test_insert_order(self):
+ person_join = polymorphic_union(
+ {
+ 'engineer':people.join(engineers),
+ 'manager':people.join(managers),
+ 'person':people.select(people.c.type=='person'),
+ }, None, 'pjoin')
+
+ person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
+
+ mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
+ mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
+ mapper(Company, companies, properties={
+ 'employees': relation(Person, private=True, backref='company', order_by=person_join.c.person_id)
+ })
+
+ session = create_session()
+ c = Company(name='company1')
+ c.employees.append(Manager(status='AAB', manager_name='manager1', name='pointy haired boss'))
+ c.employees.append(Engineer(status='BBA', engineer_name='engineer1', primary_language='java', name='dilbert'))
+ c.employees.append(Person(status='HHH', name='joesmith'))
+ c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', name='wally'))
+ c.employees.append(Manager(status='ABA', manager_name='manager2', name='jsmith'))
+ session.save(c)
+ session.flush()
+ session.clear()
+ c = session.query(Company).get(c.company_id)
+ for e in c.employees:
+ print e, e._instance_key, e.company
+
+ assert [e.get_name() for e in c.employees] == ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']
+
if __name__ == "__main__":
testbase.main()