From: Mike Bayer Date: Fri, 2 Sep 2005 01:18:59 +0000 (+0000) Subject: saves basic one 2 many X-Git-Tag: rel_0_1_0~822 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f3681772cdb0e668ec8908406361b5541a57af9e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git saves basic one 2 many --- diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 62a1bb8567..d7e3b9c34d 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -423,6 +423,7 @@ class PropertyLoader(MapperProperty): def init(self, key, parent, root): self.key = key self.mapper.init(root) + self.parenttable = parent.selectable if self.secondary is not None: if self.secondaryjoin is None: self.secondaryjoin = match_primaries(self.target, self.secondary) @@ -440,20 +441,29 @@ class PropertyLoader(MapperProperty): # if a mapping table exists, determine the two foreign key columns # in the mapping table, set the two values, and insert that row, for # each row in the list - # for child in getattr(obj, self.key): - # setter = ForeignKeySetter(obj, child) - # self.primaryjoin.accept_visitor(setter) - # self.mapper.save(child) - pass + setter = ForeignKeySetter(self.parenttable, self.target, obj) + for child in getattr(obj, self.key): + setter.child = child + self.primaryjoin.accept_visitor(setter) + self.mapper.save(child) + #pass def delete(self): self.mapper.delete() -#class ForeignKeySetter(ClauseVisitor): - # def visit_binary(self, binary): - # if binary.operator == '==': - # if binary.left.table == self.primarytable and binary.right.table == self.secondarytable: - # setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname)) +class ForeignKeySetter(sql.ClauseVisitor): + def __init__(self, primarytable, secondarytable, obj): + self.child = None + self.obj = obj + self.primarytable = primarytable + self.secondarytable = secondarytable + + def visit_binary(self, binary): + if binary.operator == '=': + if binary.left.table == self.primarytable and binary.right.table == self.secondarytable: + setattr(self.child, binary.left.key, getattr(self.obj, binary.right.key)) + elif binary.right.table == self.primarytable and binary.left.table == self.secondarytable: + setattr(self.child, binary.right.key, getattr(self.obj, binary.left.key)) class LazyLoader(PropertyLoader): diff --git a/test/mapper.py b/test/mapper.py index e8ed519f6f..14e3f760df 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -231,7 +231,6 @@ class EagerTest(PersistTest): print repr(l) class SaveTest(PersistTest): - def testsave(self): # save two users @@ -265,6 +264,8 @@ class SaveTest(PersistTest): self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2') def testsavemultitable(self): + """tests a save of an object where each instance spans two tables. also tests + redefinition of the keynames for the column properties.""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) m = mapper(User, usersaddresses, table = users, echo = True, properties = dict(email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id))) u = User() @@ -280,10 +281,30 @@ class SaveTest(PersistTest): u.email = 'lala@hey.com' u.user_name = 'imnew' m.save(u) - usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall() - self.assert_(usertable[0].row == (10, 'imnew')) - addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall() - self.assert_(addresstable[0].row == (4, 10, 'lala@hey.com')) + usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.user_id)).execute()).fetchall() + self.assert_(usertable[0].row == (u.user_id, 'imnew')) + addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(u.address_id)).execute()).fetchall() + self.assert_(addresstable[0].row == (u.address_id, u.user_id, 'lala@hey.com')) + + def testsaveonetomany(self): + m = mapper(User, users, properties = dict( + addresses = relation(Address, addresses, lazy = True) + ), echo = True) + u = User() + u.user_name = 'one2manytester' + u.addresses = [] + a = Address() + a.email_address = 'one2many@test.org' + u.addresses.append(a) + a2 = Address() + a2.email_address = 'lala@test.org' + u.addresses.append(a2) + m.save(u) + usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.user_id)).execute()).fetchall() + self.assert_(usertable[0].row == (u.user_id, 'one2manytester')) + addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(a.address_id, a2.address_id)).execute()).fetchall() + self.assert_(addresstable[0].row == (a.address_id, u.user_id, 'one2many@test.org')) + self.assert_(addresstable[1].row == (a2.address_id, u.user_id, 'lala@test.org')) if __name__ == "__main__": unittest.main()