]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
saves basic one 2 many
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Sep 2005 01:18:59 +0000 (01:18 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Sep 2005 01:18:59 +0000 (01:18 +0000)
lib/sqlalchemy/mapper.py
test/mapper.py

index 62a1bb8567b0ba96e35f67aa7debdcee1412f23c..d7e3b9c34da45a4be4523f220cfde77fd6ae8e1a 100644 (file)
@@ -423,6 +423,7 @@ class PropertyLoader(MapperProperty):
     def init(self, key, parent, root):
         self.key = key
         self.mapper.init(root)
+        self.parenttable = parent.selectable
         if self.secondary is not None:
             if self.secondaryjoin is None:
                 self.secondaryjoin = match_primaries(self.target, self.secondary)
@@ -440,20 +441,29 @@ class PropertyLoader(MapperProperty):
         # if a mapping table exists, determine the two foreign key columns 
         # in the mapping table, set the two values, and insert that row, for
         # each row in the list
-  #      for child in getattr(obj, self.key):
-  #          setter = ForeignKeySetter(obj, child)
-  #          self.primaryjoin.accept_visitor(setter)
-  #          self.mapper.save(child)
-        pass
+        setter = ForeignKeySetter(self.parenttable, self.target, obj)
+        for child in getattr(obj, self.key):
+            setter.child = child
+            self.primaryjoin.accept_visitor(setter)
+            self.mapper.save(child)
+        #pass
 
     def delete(self):
         self.mapper.delete()
 
-#class ForeignKeySetter(ClauseVisitor):
- #   def visit_binary(self, binary):
-  #      if binary.operator == '==':
-  #          if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
-  #              setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname))
+class ForeignKeySetter(sql.ClauseVisitor):
+    def __init__(self, primarytable, secondarytable, obj):
+        self.child = None
+        self.obj = obj
+        self.primarytable = primarytable
+        self.secondarytable = secondarytable
+
+    def visit_binary(self, binary):
+        if binary.operator == '=':
+            if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
+                setattr(self.child, binary.left.key, getattr(self.obj, binary.right.key))
+            elif binary.right.table == self.primarytable and binary.left.table == self.secondarytable:
+                setattr(self.child, binary.right.key, getattr(self.obj, binary.left.key))
 
 class LazyLoader(PropertyLoader):
 
index e8ed519f6fc5fe25cd22f750a13b4c70176da1d9..14e3f760df3b0297830a3b12ffcef8444fec017b 100644 (file)
@@ -231,7 +231,6 @@ class EagerTest(PersistTest):
         print repr(l)
 
 class SaveTest(PersistTest):
-        
 
     def testsave(self):
         # save two users
@@ -265,6 +264,8 @@ class SaveTest(PersistTest):
         self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2')
 
     def testsavemultitable(self):
+        """tests a save of an object where each instance spans two tables. also tests
+        redefinition of the keynames for the column properties."""
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
         m = mapper(User, usersaddresses, table = users, echo = True, properties = dict(email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id)))
         u = User()
@@ -280,10 +281,30 @@ class SaveTest(PersistTest):
         u.email = 'lala@hey.com'
         u.user_name = 'imnew'
         m.save(u)
-        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall()
-        self.assert_(usertable[0].row == (10, 'imnew'))
-        addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall()
-        self.assert_(addresstable[0].row == (4, 10, 'lala@hey.com'))
+        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.user_id)).execute()).fetchall()
+        self.assert_(usertable[0].row == (u.user_id, 'imnew'))
+        addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(u.address_id)).execute()).fetchall()
+        self.assert_(addresstable[0].row == (u.address_id, u.user_id, 'lala@hey.com'))
+
+    def testsaveonetomany(self):
+        m = mapper(User, users, properties = dict(
+            addresses = relation(Address, addresses, lazy = True)
+        ), echo = True)
+        u = User()
+        u.user_name = 'one2manytester'
+        u.addresses = []
+        a = Address()
+        a.email_address = 'one2many@test.org'
+        u.addresses.append(a)
+        a2 = Address()
+        a2.email_address = 'lala@test.org'
+        u.addresses.append(a2)
+        m.save(u)
+        usertable = engine.ResultProxy(users.select(users.c.user_id.in_(u.user_id)).execute()).fetchall()
+        self.assert_(usertable[0].row == (u.user_id, 'one2manytester'))
+        addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(a.address_id, a2.address_id)).execute()).fetchall()
+        self.assert_(addresstable[0].row == (a.address_id, u.user_id, 'one2many@test.org'))
+        self.assert_(addresstable[1].row == (a2.address_id, u.user_id, 'lala@test.org'))
 
 if __name__ == "__main__":
     unittest.main()