]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
further refactoring of topological sort for clarity
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Nov 2006 03:03:55 +0000 (03:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Nov 2006 03:03:55 +0000 (03:03 +0000)
lib/sqlalchemy/topological.py
test/base/dependency.py

index cff5eadee05a2ca522eb587d66aec6a445eda123..25e57632bc9445b1f5de0cd7ccb47a75a61dd99e 100644 (file)
@@ -31,103 +31,145 @@ have been of this nature - very tricky to reproduce and track down, particularly
 realized this characteristic of the algorithm.
 """
 import string, StringIO
-from sets import *
 from sqlalchemy import util
 from sqlalchemy.exceptions import *
 
+class _Node(object):
+    """represents each item in the sort.  While the topological sort
+    produces a straight ordered list of items, _Node ultimately stores a tree-structure
+    of those items which are organized so that non-dependent nodes are siblings."""
+    def __init__(self, item):
+        self.item = item
+        self.dependencies = util.Set()
+        self.children = []
+        self.cycles = None
+    def __str__(self):
+        return self.safestr()
+    def safestr(self, indent=0):
+        return (' ' * indent * 2) + \
+            str(self.item) + \
+            (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \
+            "\n" + \
+            string.join([n.safestr(indent + 1) for n in self.children], '')
+    def __repr__(self):
+        return "%s" % (str(self.item))
+    def is_dependent(self, child):
+        if self.cycles is not None:
+            for c in self.cycles:
+                if child in c.dependencies:
+                    return True
+        if child.cycles is not None:
+            for c in child.cycles:
+                if c in self.dependencies:
+                    return True
+        return child in self.dependencies
+
+class _EdgeCollection(object):
+    """a collection of directed edges."""
+    def __init__(self):
+        self.parent_to_children = {}
+        self.child_to_parents = {}
+    def add(self, edge):
+        """add an edge to this collection."""
+        (parentnode, childnode) = edge
+        if not self.parent_to_children.has_key(parentnode):
+            self.parent_to_children[parentnode] = util.Set()
+        self.parent_to_children[parentnode].add(childnode)
+        if not self.child_to_parents.has_key(childnode):
+            self.child_to_parents[childnode] = util.Set()
+        self.child_to_parents[childnode].add(parentnode)
+        parentnode.dependencies.add(childnode)
+    def remove(self, edge):
+        """remove an edge from this collection.  return the childnode if it has no other parents"""
+        (parentnode, childnode) = edge
+        self.parent_to_children[parentnode].remove(childnode)
+        self.child_to_parents[childnode].remove(parentnode)
+        if len(self.child_to_parents[childnode]) == 0:
+            return childnode
+        else:
+            return None
+    def has_parents(self, node):
+        return self.child_to_parents.has_key(node) and len(self.child_to_parents[node]) > 0
+    def edges_by_parent(self, node):
+        if self.parent_to_children.has_key(node):
+            return [(node, child) for child in self.parent_to_children[node]]
+        else:
+            return []
+    def get_parents(self):
+        return self.parent_to_children.keys()
+    def pop_node(self, node):
+        """remove all edges where the given node is a parent.  
+        
+        returns the collection of all nodes which were children of the given node, and have
+        no further parents."""
+        children = self.parent_to_children.pop(node, None)
+        if children is not None:
+            for child in children:
+                self.child_to_parents[child].remove(node)
+                if not len(self.child_to_parents[child]):
+                    yield child
+    def __len__(self):
+        return sum([len(x) for x in self.parent_to_children.values()])
+    def __iter__(self):
+        for parent, children in self.parent_to_children.iteritems():
+            for child in children:
+                yield (parent, child)
+    def __str__(self):
+        return repr(list(self))
+        
+        
 class QueueDependencySorter(object):
     """topological sort adapted from wikipedia's article on the subject.  it creates a straight-line
     list of elements, then a second pass groups non-dependent actions together to build
     more of a tree structure with siblings."""
-    class Node:
-        """represents a node in a tree.  stores an 'item' which represents the 
-        dependent thing we are talking about.  if node 'a' is an ancestor node of 
-        node 'b', it means 'a's item is *not* dependent on that of 'b'."""
-        def __init__(self, item):
-            self.item = item
-            self.edges = {}
-            self.dependencies = {}
-            self.children = []
-            self.cycles = None
-        def __str__(self):
-            return self.safestr()
-        def safestr(self, indent=0):
-            return (' ' * indent * 2) + \
-                str(self.item) + \
-                (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \
-                "\n" + \
-                string.join([n.safestr(indent + 1) for n in self.children], '')
-                
-        def describe(self):
-            return "%s" % (str(self.item))
-        def __repr__(self):
-            return self.describe()
-        def is_dependent(self, child):
-            if self.cycles is not None:
-                for c in self.cycles:
-                    if c.dependencies.has_key(child):
-                        return True
-            if child.cycles is not None:
-                for c in child.cycles:
-                    if self.dependencies.has_key(c):
-                        return True
-            return self.dependencies.has_key(child)
             
     def __init__(self, tuples, allitems):
         self.tuples = tuples
         self.allitems = allitems
 
-    def _dump_edges(self, edges):
-        s = StringIO.StringIO()
-        for key, value in edges.iteritems():
-            for c in value.keys():
-                s.write("%s->%s\n" % (repr(key), repr(c)))
-        return s.getvalue()
-        
     def sort(self, allow_self_cycles=True, allow_all_cycles=False):
         (tuples, allitems) = (self.tuples, self.allitems)
-
         #print "\n---------------------------------\n"        
         #print repr([t for t in tuples])
         #print repr([a for a in allitems])
         #print "\n---------------------------------\n"        
 
         nodes = {}
-        edges = {}
+        edges = _EdgeCollection()
         for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
             if not nodes.has_key(item):
-                node = QueueDependencySorter.Node(item)
+                node = _Node(item)
                 nodes[item] = node
-                edges[node] = {}
         
         for t in tuples:
             if t[0] is t[1]:
                 if allow_self_cycles:
                     n = nodes[t[0]]
-                    n.cycles = Set([n])
+                    n.cycles = util.Set([n])
                     continue
                 else:
                     raise FlushError("Self-referential dependency detected " + repr(t))
             childnode = nodes[t[1]]
             parentnode = nodes[t[0]]
-            self._add_edge(edges, (parentnode, childnode))
+            edges.add((parentnode, childnode))
 
         queue = []
         for n in nodes.values():
-            if len(n.edges) == 0:
+            if not edges.has_parents(n):
                 queue.append(n)
         cycles = {}
         output = []
-        while len(edges) > 0:
+        while len(nodes) > 0:
             if len(queue) == 0:
                 # edges remain but no edgeless nodes to remove; this indicates
                 # a cycle
                 if allow_all_cycles:
+                    x = self._find_cycles(edges)
                     for cycle in self._find_cycles(edges):
                         lead = cycle[0][0]
-                        lead.cycles = Set()
+                        lead.cycles = util.Set()
                         for edge in cycle:
-                            n = self._remove_edge(edges, edge)
+                            n = edges.remove(edge)
                             lead.cycles.add(edge[0])
                             lead.cycles.add(edge[1])
                             if n is not None:
@@ -135,9 +177,9 @@ class QueueDependencySorter(object):
                         for n in lead.cycles:
                             if n is not lead:
                                 n._cyclical = True
-                                for k in list(edges[n]):
-                                    self._add_edge(edges, (lead,k))
-                                    self._remove_edge(edges, (n,k))
+                                for (n,k) in list(edges.edges_by_parent(n)):
+                                    edges.add((lead, k))
+                                    edges.remove((n,k))
                     continue
                 else:
                     # long cycles not allowed
@@ -145,14 +187,9 @@ class QueueDependencySorter(object):
             node = queue.pop()
             if not hasattr(node, '_cyclical'):
                 output.append(node)
-            nodeedges = edges.pop(node, None)
-            if nodeedges is None:
-                continue
-            for childnode in nodeedges.keys():
-                del childnode.edges[node]
-                if len(childnode.edges) == 0:
-                    queue.append(childnode)
-
+            del nodes[node.item]
+            for childnode in edges.pop_node(node):
+                queue.append(childnode)
         return self._create_batched_tree(output)
         
 
@@ -192,21 +229,6 @@ class QueueDependencySorter(object):
                     return False
         return sort()
         
-        
-    def _add_edge(self, edges, edge):
-        (parentnode, childnode) = edge
-        edges[parentnode][childnode] = True
-        parentnode.dependencies[childnode] = True
-        childnode.edges[parentnode] = True
-
-    def _remove_edge(self, edges, edge):
-        (parentnode, childnode) = edge
-        del edges[parentnode][childnode]
-        del childnode.edges[parentnode]
-        del parentnode.dependencies[childnode]
-        if len(childnode.edges) == 0:
-            return childnode
-    
     def _find_cycles(self, edges):
         involved_in_cycles = util.Set()
         cycles = {}
@@ -217,12 +239,11 @@ class QueueDependencySorter(object):
             elif node is goal:
                 return True
                 
-            for key in edges[node].keys():
+            for (n, key) in edges.edges_by_parent(node):
                 if key in cycle:
                     continue
                 cycle.append(key)
                 if traverse(key, goal, cycle):
-                    #print "adding cycle", list(cycle)
                     cycset = util.Set(cycle)
                     for x in cycle:
                         involved_in_cycles.add(x)
@@ -236,19 +257,13 @@ class QueueDependencySorter(object):
                             cycles[x] = cycset
                 cycle.pop()
                     
-        for parent in edges.keys():
+        for parent in edges.get_parents():
             traverse(parent)
 
         for cycle in dict((id(s), s) for s in cycles.values()).values():
             edgecollection = []
-            for edge in self.edge_iterator(edges):
+            for edge in edges:
                 if edge[0] in cycle and edge[1] in cycle:
                     edgecollection.append(edge)
             yield edgecollection
     
-    def edge_iterator(self, edges):
-        for key in edges.keys():
-            for value in edges[key].keys():
-                yield (key, value)
-                    
-
index b794629d2060ebc17e3d4edc21ebe7a677d84ddf..7c6578cb8d96f74aa0c6284f6440b7dfbc0280b2 100644 (file)
@@ -10,7 +10,7 @@ class DependencySorter(topological.QueueDependencySorter):pass
     
         
 class DependencySortTest(PersistTest):
-    def assert_sort(self, tuples, node):
+    def assert_sort(self, tuples, node, collection=None):
         print str(node)
         def assert_tuple(tuple, node):
             if node.cycles:
@@ -29,15 +29,20 @@ class DependencySortTest(PersistTest):
         for tuple in tuples:
             assert_tuple(list(tuple), node)
 
+        if collection is None:
+            collection = []
         items = util.Set()
         def assert_unique(node):
             for item in [n.item for n in node.cycles or [node,]]:
                 assert item not in items
                 items.add(item)
+                if item in collection:
+                    collection.remove(item)
             for c in node.children:
                 assert_unique(c)
         assert_unique(node)
-      
+        assert len(collection) == 0
+        
     def testsort(self):
         rootnode = 'root'
         node2 = 'node2'
@@ -77,7 +82,7 @@ class DependencySortTest(PersistTest):
             (node6, node2)
         ]
         head = DependencySorter(tuples, [node7]).sort()
-        self.assert_sort(tuples, head)
+        self.assert_sort(tuples, head, [node7])
 
     def testsort3(self):
         ['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords']