]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
so here is kind of the idea. but it doesn't work like it used to.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Apr 2010 21:49:58 +0000 (17:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Apr 2010 21:49:58 +0000 (17:49 -0400)
so I think I want to try to build a smarter "find everything without a dependency"
system that is more inline with how this is running now anyway - i.e.
go through the whole list, find nodes with no dependencies.  maybe the
original topological.sort() can do that, not sure.

lib/sqlalchemy/orm/unitofwork.py

index 9f888f8672408d6025f248e301e8ea9f3e6995f9..bb3bb4fb271fe13f155b12f229152787f0cac733 100644 (file)
@@ -227,10 +227,16 @@ class UOWTransaction(object):
             (head, children) = topological.organize_as_tree(self.dependencies, sort)
             stack = [(head, children)]
             
+            head.execute(self)
             while stack:
                 node, children = stack.pop()
-                node.execute(self)
-                stack += children
+                if children:
+                    related = set([n[0] for n in children])
+                    while related:
+                        n = related.pop()
+                        n.execute_aggregate(self, related)
+                
+                    stack += children
         else:
             for rec in sort:
                 rec.execute(self)
@@ -273,6 +279,9 @@ class PostSortRec(object):
             uow.postsort_actions[key] = ret = object.__new__(cls)
             return ret
     
+    def execute_aggregate(self, uow, recs):
+        self.execute(uow)
+        
     def __repr__(self):
         return "%s(%s)" % (
             self.__class__.__name__,
@@ -408,6 +417,17 @@ class SaveUpdateState(PostSortRec):
             uow
         )
 
+    def execute_aggregate(self, uow, recs):
+        cls_ = self.__class__
+        # TODO: have 'mapper' be present on SaveUpdateState already
+        mapper = self.state.manager.mapper.base_mapper
+        
+        our_recs = [r for r in recs 
+                        if r.__class__ is cls_ and 
+                        r.state.manager.mapper.base_mapper is mapper]
+        recs.difference_update(our_recs)
+        mapper._save_obj([self.state] + [r.state for r in our_recs], uow)
+
     def __repr__(self):
         return "%s(%s)" % (
             self.__class__.__name__,