]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
a new batching algorithm for the topological sort
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Apr 2006 23:47:26 +0000 (23:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Apr 2006 23:47:26 +0000 (23:47 +0000)
lib/sqlalchemy/mapping/topological.py
test/dependency.py

index 779faab2d4b5e9bae93e1365e4c9ec2c2b7420f7..495eec8cecd26c7dbdc8ea154431a04a0ae373b2 100644 (file)
@@ -57,6 +57,16 @@ class QueueDependencySorter(object):
             return "%s  (idself=%s)" % (str(self.item), repr(id(self)))
         def __repr__(self):
             return self.describe()
+        def is_dependent(self, child):
+            if self.cycles is not None:
+                for c in self.cycles:
+                    if c.dependencies.has_key(child):
+                        return True
+            if child.cycles is not None:
+                for c in child.cycles:
+                    if self.dependencies.has_key(c):
+                        return True
+            return self.dependencies.has_key(child)
             
     def __init__(self, tuples, allitems):
         self.tuples = tuples
@@ -138,21 +148,46 @@ class QueueDependencySorter(object):
                 if len(childnode.edges) == 0:
                     queue.append(childnode)
 
-        #print repr(output)
-        head = None
-        node = None
-        # put the sorted list into a "tree".  this is not much of a 
-        # "tree" at the moment as its more of a linked list.  it would be nice
-        # to group non-dependent nodes into sibling nodes, which allows better batching
-        # of SQL statements, but this algorithm has proved tricky
-        for o in output:
-            if head is None:
-                head = o
-            else:
-                node.children.append(o)
-            node = o
-        return head
+        return self._create_batched_tree(output)
+        
 
+    def _create_batched_tree(self, nodes):
+        """given a list of nodes from a topological sort, organizes the nodes into a tree structure,
+        with as many non-dependent nodes set as silbings to each other as possible."""
+        def sort(index=None, l=None):
+            if index is None:
+                index = 0
+            
+            if index >= len(nodes):
+                return None
+            
+            node = nodes[index]
+            l2 = []
+            sort(index + 1, l2)
+            for n in l2:
+                if l is None or search_dep(node, n):
+                    node.children.append(n)
+                else:
+                    l.append(n)
+            if l is not None:
+                l.append(node)
+            return node
+            
+        def search_dep(parent, child):
+            if child is None:
+                return False
+            elif parent.is_dependent(child):
+                return True
+            else:
+                for c in child.children:
+                    x = search_dep(parent, c)
+                    if x is True:
+                        return True
+                else:
+                    return False
+        return sort()
+        
+        
     def _add_edge(self, edges, edge):
         (parentnode, childnode) = edge
         edges[parentnode][childnode] = True
index 0aede4c7e5555f901c7b598d4d3965d3927dfd30..81165dc6d03d4e15168980b269bddf80be1f9382 100644 (file)
@@ -24,10 +24,10 @@ class DependencySortTest(PersistTest):
 
         print "\n" + str(head)
         def findnode(t, n, parent=False):
-            if n.item is t[0]:
+            if n.item is t[0] or (n.cycles is not None and t[0] in [c.item for c in n.cycles]):
                 parent=True
             elif n.item is t[1]:
-                if not parent and t[0] not in [c.item for c in n.cycles]:
+                if not parent and (n.cycles is None or t[0] not in [c.item for c in n.cycles]):
                     self.assert_(False, "Node " + str(t[1]) + " not a child of " +str(t[0]))
                 else:
                     return
@@ -148,6 +148,7 @@ class DependencySortTest(PersistTest):
         self._assert_sort(tuples, allitems)
 
     def testcircular(self):
+        #print "TESTCIRCULAR"
         node1 = thingy('node1')
         node2 = thingy('node2')
         node3 = thingy('node3')
@@ -162,6 +163,7 @@ class DependencySortTest(PersistTest):
             (node4, node1)
         ]
         self._assert_sort(tuples, [node1,node2,node3,node4,node5], allow_all_cycles=True)
+        #print "TESTCIRCULAR DONE"
         
 
 if __name__ == "__main__":