From: Mike Bayer Date: Sat, 11 Nov 2006 03:03:55 +0000 (+0000) Subject: further refactoring of topological sort for clarity X-Git-Tag: rel_0_3_1~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2d4e0d27dcbf9e7ab5718b203812c54c61ec3a40;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git further refactoring of topological sort for clarity --- diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index cff5eadee0..25e57632bc 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -31,103 +31,145 @@ have been of this nature - very tricky to reproduce and track down, particularly realized this characteristic of the algorithm. """ import string, StringIO -from sets import * from sqlalchemy import util from sqlalchemy.exceptions import * +class _Node(object): + """represents each item in the sort. While the topological sort + produces a straight ordered list of items, _Node ultimately stores a tree-structure + of those items which are organized so that non-dependent nodes are siblings.""" + def __init__(self, item): + self.item = item + self.dependencies = util.Set() + self.children = [] + self.cycles = None + def __str__(self): + return self.safestr() + def safestr(self, indent=0): + return (' ' * indent * 2) + \ + str(self.item) + \ + (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \ + "\n" + \ + string.join([n.safestr(indent + 1) for n in self.children], '') + def __repr__(self): + return "%s" % (str(self.item)) + def is_dependent(self, child): + if self.cycles is not None: + for c in self.cycles: + if child in c.dependencies: + return True + if child.cycles is not None: + for c in child.cycles: + if c in self.dependencies: + return True + return child in self.dependencies + +class _EdgeCollection(object): + """a collection of directed edges.""" + def __init__(self): + self.parent_to_children = {} + self.child_to_parents = {} + def add(self, edge): + """add an edge to this collection.""" + (parentnode, childnode) = edge + if not self.parent_to_children.has_key(parentnode): + self.parent_to_children[parentnode] = util.Set() + self.parent_to_children[parentnode].add(childnode) + if not self.child_to_parents.has_key(childnode): + self.child_to_parents[childnode] = util.Set() + self.child_to_parents[childnode].add(parentnode) + parentnode.dependencies.add(childnode) + def remove(self, edge): + """remove an edge from this collection. return the childnode if it has no other parents""" + (parentnode, childnode) = edge + self.parent_to_children[parentnode].remove(childnode) + self.child_to_parents[childnode].remove(parentnode) + if len(self.child_to_parents[childnode]) == 0: + return childnode + else: + return None + def has_parents(self, node): + return self.child_to_parents.has_key(node) and len(self.child_to_parents[node]) > 0 + def edges_by_parent(self, node): + if self.parent_to_children.has_key(node): + return [(node, child) for child in self.parent_to_children[node]] + else: + return [] + def get_parents(self): + return self.parent_to_children.keys() + def pop_node(self, node): + """remove all edges where the given node is a parent. + + returns the collection of all nodes which were children of the given node, and have + no further parents.""" + children = self.parent_to_children.pop(node, None) + if children is not None: + for child in children: + self.child_to_parents[child].remove(node) + if not len(self.child_to_parents[child]): + yield child + def __len__(self): + return sum([len(x) for x in self.parent_to_children.values()]) + def __iter__(self): + for parent, children in self.parent_to_children.iteritems(): + for child in children: + yield (parent, child) + def __str__(self): + return repr(list(self)) + + class QueueDependencySorter(object): """topological sort adapted from wikipedia's article on the subject. it creates a straight-line list of elements, then a second pass groups non-dependent actions together to build more of a tree structure with siblings.""" - class Node: - """represents a node in a tree. stores an 'item' which represents the - dependent thing we are talking about. if node 'a' is an ancestor node of - node 'b', it means 'a's item is *not* dependent on that of 'b'.""" - def __init__(self, item): - self.item = item - self.edges = {} - self.dependencies = {} - self.children = [] - self.cycles = None - def __str__(self): - return self.safestr() - def safestr(self, indent=0): - return (' ' * indent * 2) + \ - str(self.item) + \ - (self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \ - "\n" + \ - string.join([n.safestr(indent + 1) for n in self.children], '') - - def describe(self): - return "%s" % (str(self.item)) - def __repr__(self): - return self.describe() - def is_dependent(self, child): - if self.cycles is not None: - for c in self.cycles: - if c.dependencies.has_key(child): - return True - if child.cycles is not None: - for c in child.cycles: - if self.dependencies.has_key(c): - return True - return self.dependencies.has_key(child) def __init__(self, tuples, allitems): self.tuples = tuples self.allitems = allitems - def _dump_edges(self, edges): - s = StringIO.StringIO() - for key, value in edges.iteritems(): - for c in value.keys(): - s.write("%s->%s\n" % (repr(key), repr(c))) - return s.getvalue() - def sort(self, allow_self_cycles=True, allow_all_cycles=False): (tuples, allitems) = (self.tuples, self.allitems) - #print "\n---------------------------------\n" #print repr([t for t in tuples]) #print repr([a for a in allitems]) #print "\n---------------------------------\n" nodes = {} - edges = {} + edges = _EdgeCollection() for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]: if not nodes.has_key(item): - node = QueueDependencySorter.Node(item) + node = _Node(item) nodes[item] = node - edges[node] = {} for t in tuples: if t[0] is t[1]: if allow_self_cycles: n = nodes[t[0]] - n.cycles = Set([n]) + n.cycles = util.Set([n]) continue else: raise FlushError("Self-referential dependency detected " + repr(t)) childnode = nodes[t[1]] parentnode = nodes[t[0]] - self._add_edge(edges, (parentnode, childnode)) + edges.add((parentnode, childnode)) queue = [] for n in nodes.values(): - if len(n.edges) == 0: + if not edges.has_parents(n): queue.append(n) cycles = {} output = [] - while len(edges) > 0: + while len(nodes) > 0: if len(queue) == 0: # edges remain but no edgeless nodes to remove; this indicates # a cycle if allow_all_cycles: + x = self._find_cycles(edges) for cycle in self._find_cycles(edges): lead = cycle[0][0] - lead.cycles = Set() + lead.cycles = util.Set() for edge in cycle: - n = self._remove_edge(edges, edge) + n = edges.remove(edge) lead.cycles.add(edge[0]) lead.cycles.add(edge[1]) if n is not None: @@ -135,9 +177,9 @@ class QueueDependencySorter(object): for n in lead.cycles: if n is not lead: n._cyclical = True - for k in list(edges[n]): - self._add_edge(edges, (lead,k)) - self._remove_edge(edges, (n,k)) + for (n,k) in list(edges.edges_by_parent(n)): + edges.add((lead, k)) + edges.remove((n,k)) continue else: # long cycles not allowed @@ -145,14 +187,9 @@ class QueueDependencySorter(object): node = queue.pop() if not hasattr(node, '_cyclical'): output.append(node) - nodeedges = edges.pop(node, None) - if nodeedges is None: - continue - for childnode in nodeedges.keys(): - del childnode.edges[node] - if len(childnode.edges) == 0: - queue.append(childnode) - + del nodes[node.item] + for childnode in edges.pop_node(node): + queue.append(childnode) return self._create_batched_tree(output) @@ -192,21 +229,6 @@ class QueueDependencySorter(object): return False return sort() - - def _add_edge(self, edges, edge): - (parentnode, childnode) = edge - edges[parentnode][childnode] = True - parentnode.dependencies[childnode] = True - childnode.edges[parentnode] = True - - def _remove_edge(self, edges, edge): - (parentnode, childnode) = edge - del edges[parentnode][childnode] - del childnode.edges[parentnode] - del parentnode.dependencies[childnode] - if len(childnode.edges) == 0: - return childnode - def _find_cycles(self, edges): involved_in_cycles = util.Set() cycles = {} @@ -217,12 +239,11 @@ class QueueDependencySorter(object): elif node is goal: return True - for key in edges[node].keys(): + for (n, key) in edges.edges_by_parent(node): if key in cycle: continue cycle.append(key) if traverse(key, goal, cycle): - #print "adding cycle", list(cycle) cycset = util.Set(cycle) for x in cycle: involved_in_cycles.add(x) @@ -236,19 +257,13 @@ class QueueDependencySorter(object): cycles[x] = cycset cycle.pop() - for parent in edges.keys(): + for parent in edges.get_parents(): traverse(parent) for cycle in dict((id(s), s) for s in cycles.values()).values(): edgecollection = [] - for edge in self.edge_iterator(edges): + for edge in edges: if edge[0] in cycle and edge[1] in cycle: edgecollection.append(edge) yield edgecollection - def edge_iterator(self, edges): - for key in edges.keys(): - for value in edges[key].keys(): - yield (key, value) - - diff --git a/test/base/dependency.py b/test/base/dependency.py index b794629d20..7c6578cb8d 100644 --- a/test/base/dependency.py +++ b/test/base/dependency.py @@ -10,7 +10,7 @@ class DependencySorter(topological.QueueDependencySorter):pass class DependencySortTest(PersistTest): - def assert_sort(self, tuples, node): + def assert_sort(self, tuples, node, collection=None): print str(node) def assert_tuple(tuple, node): if node.cycles: @@ -29,15 +29,20 @@ class DependencySortTest(PersistTest): for tuple in tuples: assert_tuple(list(tuple), node) + if collection is None: + collection = [] items = util.Set() def assert_unique(node): for item in [n.item for n in node.cycles or [node,]]: assert item not in items items.add(item) + if item in collection: + collection.remove(item) for c in node.children: assert_unique(c) assert_unique(node) - + assert len(collection) == 0 + def testsort(self): rootnode = 'root' node2 = 'node2' @@ -77,7 +82,7 @@ class DependencySortTest(PersistTest): (node6, node2) ] head = DependencySorter(tuples, [node7]).sort() - self.assert_sort(tuples, head) + self.assert_sort(tuples, head, [node7]) def testsort3(self): ['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords']