"""
from sqlalchemy.exc import CircularDependencyError
+from sqlalchemy import util
__all__ = ['sort', 'sort_with_cycles', 'sort_as_tree']
"""A collection of directed edges."""
def __init__(self):
- self.parent_to_children = {}
- self.child_to_parents = {}
+ self.parent_to_children = util.defaultdict(set)
+ self.child_to_parents = util.defaultdict(set)
def add(self, edge):
"""Add an edge to this collection."""
- (parentnode, childnode) = edge
- if parentnode not in self.parent_to_children:
- self.parent_to_children[parentnode] = set()
+ parentnode, childnode = edge
self.parent_to_children[parentnode].add(childnode)
- if childnode not in self.child_to_parents:
- self.child_to_parents[childnode] = set()
self.child_to_parents[childnode].add(parentnode)
parentnode.dependencies.add(childnode)
(parentnode, childnode) = edge
self.parent_to_children[parentnode].remove(childnode)
self.child_to_parents[childnode].remove(parentnode)
- if len(self.child_to_parents[childnode]) == 0:
+ if not self.child_to_parents[childnode]:
return childnode
else:
return None
def has_parents(self, node):
- return node in self.child_to_parents and len(self.child_to_parents[node]) > 0
+ return node in self.child_to_parents and bool(self.child_to_parents[node])
def edges_by_parent(self, node):
if node in self.parent_to_children:
for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
item_id = id(item)
if item_id not in nodes:
- node = _Node(item)
- nodes[item_id] = node
+ nodes[item_id] = _Node(item)
for t in tuples:
- id0 = id(t[0])
+ id0, id1 = id(t[0]), id(t[1])
if t[0] is t[1]:
if allow_cycles:
n = nodes[id0]
elif not ignore_self_cycles:
raise CircularDependencyError("Self-referential dependency detected " + repr(t))
continue
- childnode = nodes[id(t[1])]
+ childnode = nodes[id1]
parentnode = nodes[id0]
edges.add((parentnode, childnode))
for parent in edges.get_parents():
traverse(parent)
- # sets are not hashable, so uniquify with id
- unique_cycles = dict((id(s), s) for s in cycles.values()).values()
+ unique_cycles = set(tuple(s) for s in cycles.values())
+
for cycle in unique_cycles:
edgecollection = [edge for edge in edges
if edge[0] in cycle and edge[1] in cycle]
assert_tuple(list(tuple), node)
if collection is None:
- collection = []
+ collection = set()
items = set()
- def assert_unique(node):
- for item in [i for i in node[1] or [node[0]]]:
- assert item not in items, str(node)
+ def assert_unique(n):
+ for item in [i for i in n[1] or [n[0]]]:
+ assert item not in items, node
items.add(item)
if item in collection:
collection.remove(item)
- for c in node[2]:
- assert_unique(c)
+ for item in n[2]:
+ assert_unique(item)
assert_unique(node)
assert len(collection) == 0
self.assert_sort(tuples, head)
def testcircular3(self):
- nodes = {}
- tuples = [('Question', 'Issue'), ('ProviderService', 'Issue'), ('Provider', 'Question'), ('Question', 'Provider'), ('ProviderService', 'Question'), ('Provider', 'ProviderService'), ('Question', 'Answer'), ('Issue', 'Question')]
+ question, issue, providerservice, answer, provider = "Question", "Issue", "ProviderService", "Answer", "Provider"
+
+ tuples = [(question, issue), (providerservice, issue), (provider, question), (question, provider), (providerservice, question), (provider, providerservice), (question, answer), (issue, question)]
+
head = topological.sort_as_tree(tuples, [], with_cycles=True)
self.assert_sort(tuples, head)