]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more fixes to topological sort with regards to cycles, fixes [ticket:365]
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Nov 2006 01:34:41 +0000 (01:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Nov 2006 01:34:41 +0000 (01:34 +0000)
lib/sqlalchemy/topological.py
test/base/dependency.py

index 9a7f2dd19214cbf83665bc1e0892c96c3b56fca6..cff5eadee05a2ca522eb587d66aec6a445eda123 100644 (file)
@@ -119,26 +119,25 @@ class QueueDependencySorter(object):
         cycles = {}
         output = []
         while len(edges) > 0:
-            #print self._dump_edges(edges)
             if len(queue) == 0:
                 # edges remain but no edgeless nodes to remove; this indicates
                 # a cycle
                 if allow_all_cycles:
-                    cycle = self._find_cycle(edges)
-                    lead = cycle[0][0]
-                    lead.cycles = Set()
-                    for edge in cycle:
-                        n = self._remove_edge(edges, edge)
-                        lead.cycles.add(edge[0])
-                        lead.cycles.add(edge[1])
-                        if n is not None:
-                            queue.append(n)
-                    for n in lead.cycles:
-                        if n is not lead:
-                            n._cyclical = True
-                    # loop through cycle
-                    # remove edges from the edge dictionary
-                    # install the cycled nodes in the "cycle" list of one of the nodes
+                    for cycle in self._find_cycles(edges):
+                        lead = cycle[0][0]
+                        lead.cycles = Set()
+                        for edge in cycle:
+                            n = self._remove_edge(edges, edge)
+                            lead.cycles.add(edge[0])
+                            lead.cycles.add(edge[1])
+                            if n is not None:
+                                queue.append(n)
+                        for n in lead.cycles:
+                            if n is not lead:
+                                n._cyclical = True
+                                for k in list(edges[n]):
+                                    self._add_edge(edges, (lead,k))
+                                    self._remove_edge(edges, (n,k))
                     continue
                 else:
                     # long cycles not allowed
@@ -207,33 +206,49 @@ class QueueDependencySorter(object):
         del parentnode.dependencies[childnode]
         if len(childnode.edges) == 0:
             return childnode
-        
-    def _find_cycle(self, edges):
-        """given a structure of edges, locates a cycle in the strucure and returns 
-        as a list of tuples representing edges involved in the cycle."""
-        seen = Set()
-        cycled_edges = []
-        def traverse(d, parent=None):
-            for key in d.keys():
-                if not edges.has_key(key):
+    
+    def _find_cycles(self, edges):
+        involved_in_cycles = util.Set()
+        cycles = {}
+        def traverse(node, goal=None, cycle=None):
+            if goal is None:
+                goal = node
+                cycle = []
+            elif node is goal:
+                return True
+                
+            for key in edges[node].keys():
+                if key in cycle:
                     continue
-                if key in seen:
-                    if parent is not None:
-                        cycled_edges.append((parent, key))
-                    return key
-                seen.add(key)
-                x = traverse(edges[key], parent=key)
-                if x is None:
-                    seen.remove(key)
-                else:
-                    if parent is not None:
-                        cycled_edges.append((parent, key))
-                    return x
-            else:
-                return None
-        s = traverse(edges)
-        if s is None:
-            return None
-        else:
-            return cycled_edges
+                cycle.append(key)
+                if traverse(key, goal, cycle):
+                    #print "adding cycle", list(cycle)
+                    cycset = util.Set(cycle)
+                    for x in cycle:
+                        involved_in_cycles.add(x)
+                        if cycles.has_key(x):
+                            existing_set = cycles[x]
+                            [existing_set.add(y) for y in cycset]
+                            for y in existing_set:
+                                cycles[y] = existing_set
+                            cycset = existing_set
+                        else:
+                            cycles[x] = cycset
+                cycle.pop()
+                    
+        for parent in edges.keys():
+            traverse(parent)
+
+        for cycle in dict((id(s), s) for s in cycles.values()).values():
+            edgecollection = []
+            for edge in self.edge_iterator(edges):
+                if edge[0] in cycle and edge[1] in cycle:
+                    edgecollection.append(edge)
+            yield edgecollection
+    
+    def edge_iterator(self, edges):
+        for key in edges.keys():
+            for value in edges[key].keys():
+                yield (key, value)
+                    
 
index 0d278ebd1c7368ef4ef3941c3b69604be08731d5..b794629d2060ebc17e3d4edc21ebe7a677d84ddf 100644 (file)
@@ -1,23 +1,17 @@
 from testbase import PersistTest
 import sqlalchemy.topological as topological
 import unittest, sys, os
-
+from sqlalchemy import util
 
 # TODO:  need assertion conditions in this suite
 
 
 class DependencySorter(topological.QueueDependencySorter):pass
     
-class thingy(object):
-    def __init__(self, name):
-        self.name = name
-    def __repr__(self):
-        return "thingy(%d, %s)" % (id(self), self.name)
-    def __str__(self):
-        return repr(self)
         
 class DependencySortTest(PersistTest):
     def assert_sort(self, tuples, node):
+        print str(node)
         def assert_tuple(tuple, node):
             if node.cycles:
                 cycles = [i.item for i in node.cycles]
@@ -34,17 +28,26 @@ class DependencySortTest(PersistTest):
         
         for tuple in tuples:
             assert_tuple(list(tuple), node)
-                    
+
+        items = util.Set()
+        def assert_unique(node):
+            for item in [n.item for n in node.cycles or [node,]]:
+                assert item not in items
+                items.add(item)
+            for c in node.children:
+                assert_unique(c)
+        assert_unique(node)
+      
     def testsort(self):
-        rootnode = thingy('root')
-        node2 = thingy('node2')
-        node3 = thingy('node3')
-        node4 = thingy('node4')
-        subnode1 = thingy('subnode1')
-        subnode2 = thingy('subnode2')
-        subnode3 = thingy('subnode3')
-        subnode4 = thingy('subnode4')
-        subsubnode1 = thingy('subsubnode1')
+        rootnode = 'root'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        subnode1 = 'subnode1'
+        subnode2 = 'subnode2'
+        subnode3 = 'subnode3'
+        subnode4 = 'subnode4'
+        subsubnode1 = 'subsubnode1'
         tuples = [
             (subnode3, subsubnode1),
             (node2, subnode1),
@@ -57,16 +60,15 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, []).sort()
         self.assert_sort(tuples, head)
-        print "\n" + str(head)
 
     def testsort2(self):
-        node1 = thingy('node1')
-        node2 = thingy('node2')
-        node3 = thingy('node3')
-        node4 = thingy('node4')
-        node5 = thingy('node5')
-        node6 = thingy('node6')
-        node7 = thingy('node7')
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        node5 = 'node5'
+        node6 = 'node6'
+        node7 = 'node7'
         tuples = [
             (node1, node2),
             (node3, node4),
@@ -76,13 +78,12 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, [node7]).sort()
         self.assert_sort(tuples, head)
-        print "\n" + str(head)
 
     def testsort3(self):
         ['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords']
-        node1 = thingy('keywords')
-        node2 = thingy('itemkeyowrds')
-        node3 = thingy('items')
+        node1 = 'keywords'
+        node2 = 'itemkeyowrds'
+        node3 = 'items'
         tuples = [
             (node1, node2),
             (node3, node2),
@@ -99,10 +100,10 @@ class DependencySortTest(PersistTest):
         print "\n" + str(head3)
 
     def testsort4(self):
-        node1 = thingy('keywords')
-        node2 = thingy('itemkeyowrds')
-        node3 = thingy('items')
-        node4 = thingy('hoho')
+        node1 = 'keywords'
+        node2 = 'itemkeyowrds'
+        node3 = 'items'
+        node4 = 'hoho'
         tuples = [
             (node1, node2),
             (node4, node1),
@@ -111,19 +112,13 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, []).sort()
         self.assert_sort(tuples, head)
-        print "\n" + str(head)
 
     def testsort5(self):
         # this one, depenending on the weather, 
-#    thingy(5780972, node4)  (idself=5781292, idparent=None)
-#     thingy(5780876, node1)  (idself=5781068, idparent=5781292)
-#      thingy(5780908, node2)  (idself=5781164, idparent=5781068)
-#       thingy(5780940, node3)  (idself=5781228, idparent=5781164)
-   
-        node1 = thingy('node1') #thingy('00B94190')
-        node2 = thingy('node2') #thingy('00B94990')
-        node3 = thingy('node3') #thingy('00B9A9B0')
-        node4 = thingy('node4') #thingy('00B4F210')
+        node1 = 'node1' #'00B94190'
+        node2 = 'node2' #'00B94990'
+        node3 = 'node3' #'00B9A9B0'
+        node4 = 'node4' #'00B4F210'
         tuples = [
             (node4, node1),
             (node1, node2),
@@ -140,14 +135,13 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, allitems).sort()
         self.assert_sort(tuples, head)
-        print "\n" + str(head)
 
     def testcircular(self):
-        node1 = thingy('node1')
-        node2 = thingy('node2')
-        node3 = thingy('node3')
-        node4 = thingy('node4')
-        node5 = thingy('node5')
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        node5 = 'node5'
         tuples = [
             (node4, node5),
             (node5, node4),
@@ -158,15 +152,14 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, []).sort(allow_all_cycles=True)
         self.assert_sort(tuples, head)
-        print "\n" + str(head)
         
     def testcircular2(self):
         # this condition was arising from ticket:362
         # and was not treated properly by topological sort
-        node1 = thingy('node1')
-        node2 = thingy('node2')
-        node3 = thingy('node3')
-        node4 = thingy('node4')
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
         tuples = [
             (node1, node2),
             (node3, node1),
@@ -176,7 +169,15 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, []).sort(allow_all_cycles=True)
         self.assert_sort(tuples, head)
-        print "\n" + str(head)
+    
+    def testcircular3(self):
+        nodes = {}
+        tuples = [('Question', 'Issue'), ('ProviderService', 'Issue'), ('Provider', 'Question'), ('Question', 'Provider'), ('ProviderService', 'Question'), ('Provider', 'ProviderService'), ('Question', 'Answer'), ('Issue', 'Question')]
+        head = DependencySorter(tuples, []).sort(allow_all_cycles=True)
+        self.assert_sort(tuples, head)
         
+            
+            
+            
 if __name__ == "__main__":
     unittest.main()