From 99cb7b09378703a4158d302888bd41734130f2da Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 3 Sep 2005 15:22:11 +0000 Subject: [PATCH] manytomany save --- lib/sqlalchemy/mapper.py | 86 +++++++++++++++++++++++++--------------- lib/sqlalchemy/util.py | 2 + test/mapper.py | 28 ++++++++++++- 3 files changed, 83 insertions(+), 33 deletions(-) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 40968412d3..d117553f30 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -231,6 +231,9 @@ class Mapper(object): def foo(): for table in self.tables: params = {} + # TODO: prepare the insert() and update() - (1) within the code or + # (2) as a real prepared statement, just once, and put them somewhere for + # an external loop to grab onto them for primary_key in table.primary_keys: if self._getattrbycolumn(obj, primary_key) is None: statement = table.insert() @@ -430,37 +433,45 @@ class PropertyLoader(MapperProperty): self.primaryjoin = match_primaries(parent.selectable, self.target) def save(self, obj, traverse, refetch): - # if a mapping table does not exist, save a row for all objects - # in our list normally, setting their primary keys - # else, determine the foreign key column in our table, set it to the parent - # of all child objects before saving - # if a mapping table exists, determine the two foreign key columns - # in the mapping table, set the two values, and insert that row, for - # each row in the list - if self.secondary is None: - setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, obj) - childlist = getattr(obj, self.key) - if not isinstance(childlist, util.HistoryArraySet): - childlist = util.HistoryArraySet(childlist) - clean_setattr(obj, self.key, childlist) - for child in childlist.added_items(): - setter.child = child - self.primaryjoin.accept_visitor(setter) - child.dirty = True - for child in childlist.deleted_items(): - setter.child = child - setter.clearkeys = True - self.primaryjoin.accept_visitor(setter) - child.dirty = True - self.mapper.save(child) - for child in childlist: - self.mapper.save(child) - # TODO: if transaction fails state is invalid - # use unit of work ? - childlist.clear_history() - else: - raise "TODO" + # saves child objects + + # TODO: put association table inserts/deletes into one batch + #if self.secondary is not None: + # secondary_delete = self.secondary.delete(sql.and_([c == bindparam(c.key) for c in setter.secondary.c])) + + setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj) + childlist = getattr(obj, self.key) + if not isinstance(childlist, util.HistoryArraySet): + childlist = util.HistoryArraySet(childlist) + clean_setattr(obj, self.key, childlist) + for child in childlist.deleted_items(): + setter.child = child + setter.clearkeys = True + self.primaryjoin.accept_visitor(setter) + child.dirty = True + self.mapper.save(child) + if self.secondary is not None: + self.secondaryjoin.accept_visitor(setter) + # TODO: prepare this above + statement = self.secondary.delete(sql.and_(*[c == setter.associationrow[c.key] for c in self.secondary.c])) + statement.echo = self.mapper.echo + statement.execute() + for child in childlist.added_items(): + setter.child = child + self.primaryjoin.accept_visitor(setter) + child.dirty = True self.mapper.save(child) + if self.secondary is not None: + self.secondaryjoin.accept_visitor(setter) + # TODO: prepare this above + statement = self.secondary.insert() + statement.echo = self.mapper.echo + statement.execute(**setter.associationrow) + for child in childlist.unchanged_items(): + self.mapper.save(child) + # TODO: if transaction fails state is invalid + # use unit of work ? + childlist.clear_history() def delete(self): @@ -615,17 +626,20 @@ class TableFinder(sql.ClauseVisitor): self.tables.append(table) class ForeignKeySetter(sql.ClauseVisitor): - def __init__(self, parentmapper, childmapper, primarytable, secondarytable, obj): + def __init__(self, parentmapper, childmapper, primarytable, secondarytable, associationtable, obj): self.parentmapper = parentmapper self.childmapper = childmapper self.primarytable = primarytable self.secondarytable = secondarytable + self.associationtable = associationtable self.obj = obj + self.associationrow = {} self.clearkeys = False self.child = None def visit_binary(self, binary): if binary.operator == '=': + # TODO: this code is silly if binary.left.table == self.primarytable and binary.right.table == self.secondarytable: if self.clearkeys: self.childmapper._setattrbycolumn(self.child, binary.right, None) @@ -636,7 +650,15 @@ class ForeignKeySetter(sql.ClauseVisitor): self.childmapper._setattrbycolumn(self.child, binary.left, None) else: self.childmapper._setattrbycolumn(self.child, binary.left, self.parentmapper._getattrbycolumn(self.obj, binary.right)) - + elif binary.right.table == self.associationtable and binary.left.table == self.primarytable: + self.associationrow[binary.right.key] = self.parentmapper._getattrbycolumn(self.obj, binary.left) + elif binary.left.table == self.associationtable and binary.right.table == self.primarytable: + self.associationrow[binary.left.key] = self.parentmapper._getattrbycolumn(self.obj, binary.right) + elif binary.right.table == self.associationtable and binary.left.table == self.secondarytable: + self.associationrow[binary.right.key] = self.childmapper._getattrbycolumn(self.child, binary.left) + elif binary.left.table == self.associationtable and binary.right.table == self.secondarytable: + self.associationrow[binary.left.key] = self.childmapper._getattrbycolumn(self.child, binary.right) + class LazyIzer(sql.ClauseVisitor): """converts an expression which refers to a table column into an expression refers to a Bind Param, i.e. a specific value. diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 1c86bf8c30..dd6d030457 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -157,6 +157,8 @@ class HistoryArraySet(UserList.UserList): return [key for key, value in self.records.iteritems() if value is True] def deleted_items(self): return [key for key, value in self.records.iteritems() if value is False] + def unchanged_items(self): + return [key for key, value in self.records.iteritems() if value is None] def append_nohistory(self, item): if not self.records.has_key(item): self.records[item] = None diff --git a/test/mapper.py b/test/mapper.py index ccd2a9fe40..7f129cf945 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -375,6 +375,32 @@ class SaveTest(PersistTest): addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(a.address_id, a2.address_id)).execute()).fetchall() self.assert_(addresstable[0].row == (a.address_id, u.user_id, 'one2many@test.org')) self.assert_(addresstable[1].row == (a2.address_id, None, 'lala@test.org')) - + + def testmanytomany(self): + items = orderitems + + m = mapper(Item, items, properties = dict( + keywords = relation(Keyword, keywords, itemkeywords, lazy = False), + ), echo = True) + + keywordmapper = mapper(Keyword, keywords) + + item = Item() + item.item_name = 'item1' + item.keywords = [] + k = Keyword() + k.name = 'purple' + item.keywords.append(k) + klist = keywordmapper.select(keywords.c.name.in_('blue', 'big', 'round')) + for k in klist: + item.keywords.append(k) + m.save(item) + print repr(m.select(items.c.item_id == item.item_id)) + + del item.keywords[2] + del item.keywords[2] + m.save(item) + print repr(m.select(items.c.item_id == item.item_id)) + if __name__ == "__main__": unittest.main() -- 2.47.2