]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
manytomany save
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Sep 2005 15:22:11 +0000 (15:22 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Sep 2005 15:22:11 +0000 (15:22 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/util.py
test/mapper.py

index 40968412d3a387bfd6bd27cd03feab27eb8c1140..d117553f307d969bf24af2109d83db6f079edad3 100644 (file)
@@ -231,6 +231,9 @@ class Mapper(object):
             def foo():
                 for table in self.tables:
                     params = {}
+                    # TODO: prepare the insert() and update() - (1) within the code or
+                    # (2) as a real prepared statement, just once, and put them somewhere for 
+                    # an external loop to grab onto them
                     for primary_key in table.primary_keys:
                         if self._getattrbycolumn(obj, primary_key) is None:
                             statement = table.insert()
@@ -430,37 +433,45 @@ class PropertyLoader(MapperProperty):
                 self.primaryjoin = match_primaries(parent.selectable, self.target)
 
     def save(self, obj, traverse, refetch):
-        # if a mapping table does not exist, save a row for all objects
-        # in our list normally, setting their primary keys
-        # else, determine the foreign key column in our table, set it to the parent
-        # of all child objects before saving
-        # if a mapping table exists, determine the two foreign key columns 
-        # in the mapping table, set the two values, and insert that row, for
-        # each row in the list
-        if self.secondary is None:
-            setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, obj)
-            childlist = getattr(obj, self.key)
-            if not isinstance(childlist, util.HistoryArraySet):
-                childlist = util.HistoryArraySet(childlist)
-                clean_setattr(obj, self.key, childlist)
-            for child in childlist.added_items():
-                setter.child = child
-                self.primaryjoin.accept_visitor(setter)
-                child.dirty = True
-            for child in childlist.deleted_items():
-                setter.child = child
-                setter.clearkeys = True
-                self.primaryjoin.accept_visitor(setter)
-                child.dirty = True
-                self.mapper.save(child)
-            for child in childlist:
-                self.mapper.save(child)
-            # TODO: if transaction fails state is invalid
-            # use unit of work ?
-            childlist.clear_history()
-        else:
-            raise "TODO"
+        # saves child objects
+        
+        # TODO: put association table inserts/deletes into one batch
+        #if self.secondary is not None:
+         #   secondary_delete = self.secondary.delete(sql.and_([c == bindparam(c.key) for c in setter.secondary.c]))
+            
+        setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj)
+        childlist = getattr(obj, self.key)
+        if not isinstance(childlist, util.HistoryArraySet):
+            childlist = util.HistoryArraySet(childlist)
+            clean_setattr(obj, self.key, childlist)
+        for child in childlist.deleted_items():
+            setter.child = child
+            setter.clearkeys = True
+            self.primaryjoin.accept_visitor(setter)
+            child.dirty = True
+            self.mapper.save(child)
+            if self.secondary is not None:
+                self.secondaryjoin.accept_visitor(setter)
+                # TODO: prepare this above
+                statement = self.secondary.delete(sql.and_(*[c == setter.associationrow[c.key] for c in self.secondary.c]))
+                statement.echo = self.mapper.echo
+                statement.execute()
+        for child in childlist.added_items():
+            setter.child = child
+            self.primaryjoin.accept_visitor(setter)
+            child.dirty = True
             self.mapper.save(child)
+            if self.secondary is not None:
+                self.secondaryjoin.accept_visitor(setter)
+                # TODO: prepare this above
+                statement = self.secondary.insert()
+                statement.echo = self.mapper.echo
+                statement.execute(**setter.associationrow)
+        for child in childlist.unchanged_items():
+            self.mapper.save(child)
+        # TODO: if transaction fails state is invalid
+        # use unit of work ?
+        childlist.clear_history()
             
             
     def delete(self):
@@ -615,17 +626,20 @@ class TableFinder(sql.ClauseVisitor):
         self.tables.append(table)
 
 class ForeignKeySetter(sql.ClauseVisitor):
-    def __init__(self, parentmapper, childmapper, primarytable, secondarytable, obj):
+    def __init__(self, parentmapper, childmapper, primarytable, secondarytable, associationtable, obj):
         self.parentmapper = parentmapper
         self.childmapper = childmapper
         self.primarytable = primarytable
         self.secondarytable = secondarytable
+        self.associationtable = associationtable
         self.obj = obj
+        self.associationrow = {}
         self.clearkeys = False
         self.child = None
 
     def visit_binary(self, binary):
         if binary.operator == '=':
+            # TODO: this code is silly
             if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
                 if self.clearkeys:
                     self.childmapper._setattrbycolumn(self.child, binary.right, None)
@@ -636,7 +650,15 @@ class ForeignKeySetter(sql.ClauseVisitor):
                     self.childmapper._setattrbycolumn(self.child, binary.left, None)
                 else:
                     self.childmapper._setattrbycolumn(self.child, binary.left, self.parentmapper._getattrbycolumn(self.obj, binary.right))
-
+            elif binary.right.table == self.associationtable and binary.left.table == self.primarytable:
+                self.associationrow[binary.right.key] = self.parentmapper._getattrbycolumn(self.obj, binary.left)
+            elif binary.left.table == self.associationtable and binary.right.table == self.primarytable:
+                self.associationrow[binary.left.key] = self.parentmapper._getattrbycolumn(self.obj, binary.right)
+            elif binary.right.table == self.associationtable and binary.left.table == self.secondarytable:
+                self.associationrow[binary.right.key] = self.childmapper._getattrbycolumn(self.child, binary.left)
+            elif binary.left.table == self.associationtable and binary.right.table == self.secondarytable:
+                self.associationrow[binary.left.key] = self.childmapper._getattrbycolumn(self.child, binary.right)
+                
 class LazyIzer(sql.ClauseVisitor):
     """converts an expression which refers to a table column into an
     expression refers to a Bind Param, i.e. a specific value.  
index 1c86bf8c30848455785231d4ed700c12203995ee..dd6d030457914578aa7a916f5b982ad4369d8769 100644 (file)
@@ -157,6 +157,8 @@ class HistoryArraySet(UserList.UserList):
         return [key for key, value in self.records.iteritems() if value is True]
     def deleted_items(self):
         return [key for key, value in self.records.iteritems() if value is False]
+    def unchanged_items(self):
+        return [key for key, value in self.records.iteritems() if value is None]
     def append_nohistory(self, item):
         if not self.records.has_key(item):
             self.records[item] = None
index ccd2a9fe40da7e8257d19c527194215b9d4265a9..7f129cf945fde189913d37461b02249fbadbb3b5 100644 (file)
@@ -375,6 +375,32 @@ class SaveTest(PersistTest):
         addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(a.address_id, a2.address_id)).execute()).fetchall()
         self.assert_(addresstable[0].row == (a.address_id, u.user_id, 'one2many@test.org'))
         self.assert_(addresstable[1].row == (a2.address_id, None, 'lala@test.org'))
-        
+
+    def testmanytomany(self):
+        items = orderitems
+
+        m = mapper(Item, items, properties = dict(
+                keywords = relation(Keyword, keywords, itemkeywords, lazy = False),
+            ), echo = True)
+
+        keywordmapper = mapper(Keyword, keywords)
+
+        item = Item()
+        item.item_name = 'item1'
+        item.keywords = []
+        k = Keyword()
+        k.name = 'purple'
+        item.keywords.append(k)
+        klist = keywordmapper.select(keywords.c.name.in_('blue', 'big', 'round'))
+        for k in klist:
+            item.keywords.append(k)
+        m.save(item)
+        print repr(m.select(items.c.item_id == item.item_id))
+
+        del item.keywords[2]
+        del item.keywords[2]
+        m.save(item)
+        print repr(m.select(items.c.item_id == item.item_id))
+
 if __name__ == "__main__":
     unittest.main()