]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
really got topological going. now that we aren't putting fricking mapped objects...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2010 19:12:29 +0000 (15:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2010 19:12:29 +0000 (15:12 -0400)
it all that id() stuff can go

lib/sqlalchemy/sql/util.py
lib/sqlalchemy/topological.py
test/base/test_dependency.py

index d5575e0e73e7f546428e9208adc28f08a6f017a0..4c59f50d5ad68d2a43b26c1f07587f650c63b364 100644 (file)
@@ -15,10 +15,13 @@ def sort_tables(tables):
         parent_table = fkey.column.table
         if parent_table in tables:
             child_table = fkey.parent.table
-            tuples.append( ( parent_table, child_table ) )
+            if parent_table is not child_table:
+                tuples.append((parent_table, child_table))
 
     for table in tables:
-        visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key})    
+        visitors.traverse(table, 
+                            {'schema_visitor':True}, 
+                            {'foreign_key':visit_foreign_key})
     return topological.sort(tuples, tables)
 
 def find_join_source(clauses, join_to):
index 324995889fffddd573d931bd4e4628afe795973c..15739ae8262ce82934b03c2724f49ea617af8d1d 100644 (file)
@@ -21,23 +21,8 @@ conditions.
 from sqlalchemy.exc import CircularDependencyError
 from sqlalchemy import util
 
-__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree']
+__all__ = ['sort']
 
-# TODO: obviate the need for a _Node class.
-# a straight tuple should be used.
-class _Node(tuple):
-    """Represent each item in the sort."""
-    
-    def __new__(cls, item):
-        children = []
-        t = tuple.__new__(cls, [item, children])
-        t.item = item
-        t.children = children
-        return t
-    
-    def __hash__(self):
-        return id(self)
-    
 class _EdgeCollection(object):
     """A collection of directed edges."""
 
@@ -52,20 +37,6 @@ class _EdgeCollection(object):
         self.parent_to_children[parentnode].add(childnode)
         self.child_to_parents[childnode].add(parentnode)
 
-    def remove(self, edge):
-        """Remove an edge from this collection.
-
-        Return the childnode if it has no other parents.
-        """
-
-        (parentnode, childnode) = edge
-        self.parent_to_children[parentnode].remove(childnode)
-        self.child_to_parents[childnode].remove(parentnode)
-        if not self.child_to_parents[childnode]:
-            return childnode
-        else:
-            return None
-
     def has_parents(self, node):
         return node in self.child_to_parents and bool(self.child_to_parents[node])
 
@@ -74,7 +45,12 @@ class _EdgeCollection(object):
             return [(node, child) for child in self.parent_to_children[node]]
         else:
             return []
-
+    
+    def outgoing(self, node):
+        """an iterable returning all nodes reached via node's outgoing edges"""
+        
+        return self.parent_to_children[node]
+        
     def get_parents(self):
         return self.parent_to_children.keys()
 
@@ -92,9 +68,6 @@ class _EdgeCollection(object):
                 if not self.child_to_parents[child]:
                     yield child
 
-    def __len__(self):
-        return sum(len(x) for x in self.parent_to_children.values())
-
     def __iter__(self):
         for parent, children in self.parent_to_children.iteritems():
             for child in children:
@@ -108,69 +81,57 @@ def sort(tuples, allitems):
 
     'tuples' is a list of tuples representing a partial ordering.
     """
-    nodes = {}
-    edges = _EdgeCollection()
 
-    for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
-        item_id = id(item)
-        if item_id not in nodes:
-            nodes[item_id] = _Node(item)
+    edges = _EdgeCollection()
+    nodes = set(allitems)
 
     for t in tuples:
-        id0, id1 = id(t[0]), id(t[1])
-        if t[0] is t[1]:
-            continue
-        childnode = nodes[id1]
-        parentnode = nodes[id0]
-        edges.add((parentnode, childnode))
+        nodes.update(t)
+        edges.add(t)
 
     queue = []
-    for n in nodes.values():
+    for n in nodes:
         if not edges.has_parents(n):
             queue.append(n)
 
     output = []
     while nodes:
         if not queue:
-            raise CircularDependencyError("Circular dependency detected " + 
-                                repr(edges) + repr(queue))
+            raise CircularDependencyError("Circular dependency detected: %r" % edges)
         node = queue.pop()
-        output.append(node.item)
-        del nodes[id(node.item)]
+        output.append(node)
+        nodes.remove(node)
         for childnode in edges.pop_node(node):
             queue.append(childnode)
     return output
 
+def find_cycles(tuples, allitems):
+    # straight from gvr with some mods
+    todo = set(allitems)
+    edges = _EdgeCollection()
 
-def _find_cycles(edges):
-    cycles = {}
-
-    def traverse(node, cycle, goal):
-        for (n, key) in edges.edges_by_parent(node):
-            if key in cycle:
-                continue
-            cycle.add(key)
-            if key is goal:
-                cycset = set(cycle)
-                for x in cycle:
-                    if x in cycles:
-                        existing_set = cycles[x]
-                        existing_set.update(cycset)
-                        for y in existing_set:
-                            cycles[y] = existing_set
-                        cycset = existing_set
-                    else:
-                        cycles[x] = cycset
+    for t in tuples:
+        todo.update(t)
+        edges.add(t)
+    
+    output = set()
+    
+    while todo:
+        node = todo.pop()
+        stack = [node]
+        while stack:
+            top = stack[-1]
+            for node in edges.outgoing(top):
+                if node in stack:
+                    cyc = stack[stack.index(node):]
+                    todo.difference_update(cyc)
+                    output.update(cyc)
+                    
+                if node in todo:
+                    stack.append(node)
+                    todo.remove(node)
+                    break
             else:
-                traverse(key, cycle, goal)
-            cycle.pop()
-
-    for parent in edges.get_parents():
-        traverse(parent, set(), parent)
-
-    unique_cycles = set(tuple(s) for s in cycles.values())
+                node = stack.pop()
+    return output
     
-    for cycle in unique_cycles:
-        edgecollection = [edge for edge in edges
-                          if edge[0] in cycle and edge[1] in cycle]
-        yield edgecollection
index 7dc55ea99e89d514949cada64f596fad1bece626..8c38a98b0badf3d9a86b933e61449e1fb32024b3 100644 (file)
@@ -1,6 +1,6 @@
 import sqlalchemy.topological as topological
 from sqlalchemy.test import TestBase
-from sqlalchemy.test.testing import assert_raises
+from sqlalchemy.test.testing import assert_raises, eq_
 from sqlalchemy import exc
 import collections
 
@@ -16,7 +16,7 @@ class DependencySortTest(TestBase):
             for n in result[i:]:
                 assert node not in deps[n]
 
-    def testsort(self):
+    def test_sort_one(self):
         rootnode = 'root'
         node2 = 'node2'
         node3 = 'node3'
@@ -38,7 +38,7 @@ class DependencySortTest(TestBase):
         ]
         self.assert_sort(tuples, topological.sort(tuples, []))
 
-    def testsort2(self):
+    def test_sort_two(self):
         node1 = 'node1'
         node2 = 'node2'
         node3 = 'node3'
@@ -55,7 +55,7 @@ class DependencySortTest(TestBase):
         ]
         self.assert_sort(tuples, topological.sort(tuples, [node7]))
 
-    def testsort4(self):
+    def test_sort_three(self):
         node1 = 'keywords'
         node2 = 'itemkeyowrds'
         node3 = 'items'
@@ -68,7 +68,7 @@ class DependencySortTest(TestBase):
         ]
         self.assert_sort(tuples, topological.sort(tuples, []))
 
-    def testcircular(self):
+    def test_raise_on_cycle_one(self):
         node1 = 'node1'
         node2 = 'node2'
         node3 = 'node3'
@@ -87,7 +87,7 @@ class DependencySortTest(TestBase):
 
         # TODO: test find_cycles
 
-    def testcircular2(self):
+    def test_raise_on_cycle_two(self):
         # this condition was arising from ticket:362
         # and was not treated properly by topological sort
         node1 = 'node1'
@@ -105,7 +105,7 @@ class DependencySortTest(TestBase):
 
         # TODO: test find_cycles
 
-    def testcircular3(self):
+    def test_raise_on_cycle_three(self):
         question, issue, providerservice, answer, provider = "Question", "Issue", "ProviderService", "Answer", "Provider"
 
         tuples = [(question, issue), (providerservice, issue), (provider, question), 
@@ -116,15 +116,14 @@ class DependencySortTest(TestBase):
         
         # TODO: test find_cycles
         
-    def testbigsort(self):
+    def test_large_sort(self):
         tuples = [(i, i + 1) for i in range(0, 1500, 2)]
         self.assert_sort(
             tuples,
             topological.sort(tuples, [])
         )
 
-
-    def testids(self):
+    def test_ticket_1380(self):
         # ticket:1380 regression: would raise a KeyError
         tuples = [(id(i), i) for i in range(3)]
         self.assert_sort(
@@ -132,5 +131,122 @@ class DependencySortTest(TestBase):
             topological.sort(tuples, [])
         )
         
+    def test_find_cycle_one(self):
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        tuples = [
+            (node1, node2),
+            (node3, node1),
+            (node2, node4),
+            (node3, node2),
+            (node2, node3)
+        ]
+
+        eq_(
+            topological.find_cycles(tuples),
+            set([node1, node2, node3])
+        )
+
+    def test_find_multiple_cycles_one(self):
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        node5 = 'node5'
+        node6 = 'node6'
+        node7 = 'node7'
+        node8 = 'node8'
+        node9 = 'node9'
+        tuples = [
+            # cycle 1
+            (node1, node2),
+            (node2, node4),
+            (node4, node1),
+
+            # cycle 2
+            (node9, node9),
+
+            # cycle 3
+            (node7, node5),
+            (node5, node7),
+
+            # cycle 4, but only if cycle 1 nodes are present
+            (node1, node6),
+            (node6, node8),
+            (node8, node4),
+
+            (node3, node1),
+            (node3, node2),
+        ]
+        
+        allnodes = set([node1, node2, node3, node4, node5, node6, node7, node8, node9])
+        eq_(
+            topological.find_cycles(tuples, allnodes),
+            set(['node8', 'node1', 'node2', 'node5', 'node4', 'node7', 'node6', 'node9'])
+        )
+
+    def test_find_multiple_cycles_two(self):
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        node5 = 'node5'
+        node6 = 'node6'
+        tuples = [
+            # cycle 1
+            (node1, node2),
+            (node2, node4),
+            (node4, node1),
+
+            # cycle 2
+            (node1, node6),
+            (node6, node2),
+            (node2, node4),
+            (node4, node1),
+        ]
+
+        allnodes = set([node1, node2, node3, node4, node5, node6])
+        eq_(
+            topological.find_cycles(tuples, allnodes),
+            set(['node1', 'node2', 'node4'])
+        )
+
+    def test_find_multiple_cycles_three(self):
+        node1 = 'node1'
+        node2 = 'node2'
+        node3 = 'node3'
+        node4 = 'node4'
+        node5 = 'node5'
+        node6 = 'node6'
+        tuples = [
+
+            # cycle 1
+            (node1, node2),
+            (node2, node1),
+
+            # cycle 2
+            (node2, node3),
+            (node3, node2),
+
+            # cycle3
+            (node2, node4),
+            (node4, node2),
+            
+            # cycle4
+            (node2, node5),
+            (node5, node6),
+            (node6, node2)
+        ]
+
+        allnodes = set([node1, node2, node3, node4, node5, node6])
+        eq_(
+            topological.find_cycles(tuples, allnodes),
+            allnodes
+        )
+        
+        
+