from sqlalchemy import util
-__all__ = ['sort']
-
-class _EdgeCollection(object):
- """A collection of directed edges."""
-
- def __init__(self, edges):
- self.parent_to_children = util.defaultdict(set)
- self.child_to_parents = util.defaultdict(set)
- for parentnode, childnode in edges:
- self.parent_to_children[parentnode].add(childnode)
- self.child_to_parents[childnode].add(parentnode)
-
- def outgoing(self, node):
- return self.parent_to_children[node]
-
- def incoming(self, node):
- return self.child_to_parents[node]
-
- def __iter__(self):
- for parent in self.parent_to_children:
- for child in self.outgoing(parent):
- yield (parent, child)
-
- def __repr__(self):
- return repr(list(self))
+__all__ = ['sort', 'sort_as_subsets', 'find_cycles']
def sort_as_subsets(tuples, allitems):
output = set()
todo = set(allitems)
- edges = _EdgeCollection(tuples)
+
+ edges = util.defaultdict(set)
+ for parent, child in tuples:
+ edges[child].add(parent)
+
while todo:
for node in list(todo):
- if not todo.intersection(edges.incoming(node)):
+ if not todo.intersection(edges[node]):
output.add(node)
if not output:
raise CircularDependencyError(
- "Circular dependency detected: cycles: %r all edges: %r" %
- (find_cycles(tuples, allitems), edges))
+ "Circular dependency detected: cycles: %r all edges: %s" %
+ (find_cycles(tuples, allitems), _dump_edges(edges, True)))
todo.difference_update(output)
yield output
def find_cycles(tuples, allitems):
# straight from gvr with some mods
todo = set(allitems)
- edges = _EdgeCollection(tuples)
+
+ edges = util.defaultdict(set)
+ for parent, child in tuples:
+ edges[parent].add(child)
output = set()
stack = [node]
while stack:
top = stack[-1]
- for node in edges.outgoing(top):
+ for node in edges[top]:
if node in stack:
cyc = stack[stack.index(node):]
todo.difference_update(cyc)
node = stack.pop()
return output
+def _dump_edges(edges, reverse):
+ l = []
+ for left in edges:
+ for right in edges[left]:
+ if reverse:
+ l.append((right, left))
+ else:
+ l.append((left, right))
+ return repr(l)
+
+
def test_sorter( self ):
tables = metadata.sorted_tables
table_names = [t.name for t in tables]
- self.assert_( table_names == ['users', 'orders', 'items', 'email_addresses'] or table_names == ['users', 'email_addresses', 'orders', 'items'])
+ ua = [n for n in table_names if n in ('users', 'email_addresses')]
+ oi = [n for n in table_names if n in ('orders', 'items')]
+
+ eq_(ua, ['users', 'email_addresses'])
+ eq_(oi, ['orders', 'items'])
+
def testcheckfirst(self):
try: