]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
working on representing longer circular relationships
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jan 2006 06:18:08 +0000 (06:18 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jan 2006 06:18:08 +0000 (06:18 +0000)
lib/sqlalchemy/mapping/topological.py
test/dependency.py

index 70aa2105de9535be0359f5866fef6686f0d93941..92db01c237f91688060b44f4ce14688bd3a5ad8e 100644 (file)
@@ -46,10 +46,11 @@ class QueueDependencySorter(object):
             self.edges = {}
             self.dependencies = {}
             self.children = []
+            self.cycles = []
         def __str__(self):
             return self.safestr()
         def safestr(self, indent=0):
-            return (' ' * indent) + "%s  (idself=%s)" % (str(self.item), repr(id(self))) + "\n" + string.join([n.safestr(indent + 1) for n in self.children], '')
+            return (' ' * indent) + "%s  (idself=%s)" % (str(self.item), repr(id(self))) + repr(self.cycles) + "\n" + string.join([n.safestr(indent + 1) for n in self.children], '')
         def describe(self):
             return "%s  (idself=%s)" % (str(self.item), repr(id(self)))
         def __repr__(self):
@@ -59,7 +60,7 @@ class QueueDependencySorter(object):
         self.tuples = tuples
         self.allitems = allitems
         
-    def sort(self):
+    def sort(self, allow_self_cycles=True, allow_all_cycles=False):
         (tuples, allitems) = (self.tuples, self.allitems)
 
         #print "\n---------------------------------\n"        
@@ -77,8 +78,13 @@ class QueueDependencySorter(object):
         
         for t in tuples:
             if t[0] is t[1]:
-                nodes[t[0]].circular = True
-                continue
+                if allow_self_cycles:
+                    n = nodes[t[0]]
+                    n.circular = True
+                    n.cycles.append(n)
+                    continue
+                else:
+                    raise "Self-referential dependency detected " + repr(t)
             childnode = nodes[t[1]]
             parentnode = nodes[t[0]]
             edges[parentnode][childnode] = True
@@ -90,10 +96,34 @@ class QueueDependencySorter(object):
             if len(n.edges) == 0:
                 queue.append(n)
         
+        cycles = {}
         output = []
         while len(edges) > 0:
             if len(queue) == 0:
-                raise "Circular dependency detected " + repr(edges) + repr(queue)
+                # edges remain but no edgeless nodes to remove; this indicates
+                # a cycle
+                if allow_all_cycles:
+                    # for each cycle, throw all the nodes involved in that 
+                    # cycle into a list attached to one of those nodes, then 
+                    # add just that node to the output.
+                    for parentnode in edges.keys():
+                        d = edges[parentnode]
+                        for childnode in d.keys():
+                            if cycles.has_key(parentnode):
+                                cycles[parentnode].append(childnode)
+                                cycles[childnode] = cycles[parentnode]
+                            elif cycles.has_key(childnode):
+                                cycles[childnode].append(parentnode)
+                                cycles[parentnode] = cycles[childnode]
+                            else:
+                                cycles[parentnode] = parentnode.cycles
+                                parentnode.cycles.append(childnode)
+                                cycles[childnode] = parentnode.cycles
+                                output.append(parentnode)
+                    break
+                else:
+                    # long cycles not allowed
+                    raise "Circular dependency detected " + repr(edges) + repr(queue)
             node = queue.pop()
             output.append(node)
             nodeedges = edges.pop(node, None)
index 159a1864a847b2785991d8a1a59a612dbb58eefb..c2e5db164e5e75d85798d74e494eecfa5ac3c5b7 100644 (file)
@@ -115,6 +115,22 @@ class DependencySortTest(PersistTest):
         ]
         head = DependencySorter(tuples, allitems).sort()
         print "\n" + str(head)
+
+    def testcircular(self):
+        node1 = thingy('node1')
+        node2 = thingy('node2')
+        node3 = thingy('node3')
+        node4 = thingy('node4')
+        node5 = thingy('node5')
+        tuples = [
+            (node1, node2),
+            (node2, node3),
+            (node3, node1),
+            (node4, node5),
+            (node5, node4)
+        ]
+        head = DependencySorter(tuples, []).sort()
+        print "\n" + str(head)
         
 
 if __name__ == "__main__":