]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- further reduce what topological has to do, expects full list of nodes
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Apr 2010 16:24:01 +0000 (12:24 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Apr 2010 16:24:01 +0000 (12:24 -0400)
- fix some side-effect-dependent behaviors in uow.  we can now
unconditionally remove "disabled" actions without rewriting

lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/topological.py
test/base/test_dependency.py

index a34c4523327233211c1146458adadaee06583389..dec90adfec43990009dad36670eea0e4e1c7da07 100644 (file)
@@ -94,10 +94,11 @@ class DependencyProcessor(object):
         """
         # locate and disable the aggregate processors
         # for this dependency
-        after_save = unitofwork.ProcessAll(uow, self, False, True)
+        
         before_delete = unitofwork.ProcessAll(uow, self, True, True)
-        after_save.disabled = True
         before_delete.disabled = True
+        after_save = unitofwork.ProcessAll(uow, self, False, True)
+        after_save.disabled = True
 
         # check if the "child" side is part of the cycle
         child_saves = unitofwork.SaveUpdateAll(uow, self.mapper.base_mapper)
@@ -122,7 +123,7 @@ class DependencyProcessor(object):
         # check if the "parent" side is part of the cycle
         if not isdelete:
             parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
-            parent_deletes = before_delte = None
+            parent_deletes = before_delete = None
             if parent_saves in uow.cycles:
                 parent_in_cycles = True
         else:
@@ -133,19 +134,18 @@ class DependencyProcessor(object):
         
         # now create actions /dependencies for each state.
         for state in states:
+            # I'd like to emit the before_delete/after_save actions
+            # here and have the unit of work not get confused by that
+            # when it alters the list of dependencies...
             if isdelete:
                 before_delete = unitofwork.ProcessState(uow, self, True, state)
-                yield before_delete
+                if parent_in_cycles:
+                    parent_deletes = unitofwork.DeleteState(uow, state)
             else:
                 after_save = unitofwork.ProcessState(uow, self, False, state)
-                yield after_save
-                
-            if parent_in_cycles:
-                if isdelete:
-                    parent_deletes = unitofwork.DeleteState(uow, state)
-                else:
+                if parent_in_cycles:
                     parent_saves = unitofwork.SaveUpdateState(uow, state)
-                    
+                
             if child_in_cycles:
                 # locate each child state associated with the parent action,
                 # create dependencies for each.
@@ -174,7 +174,11 @@ class DependencyProcessor(object):
                                                 child_action, 
                                                 after_save, before_delete, 
                                                 isdelete, childisdelete)
-
+        
+        # ... but at the moment it 
+        # does so we emit a null iterator
+        return iter([])
+        
     def presort_deletes(self, uowcommit, states):
         pass
         
@@ -304,8 +308,8 @@ class OneToManyDP(DependencyProcessor):
             ])
         else:
             uow.dependencies.update([
-                (child_action, before_delete),
-                (before_delete, delete_parent),
+                (before_delete, child_action),
+                (child_action, delete_parent)
             ])
         
     def presort_deletes(self, uowcommit, states):
index b8373ff63075d7bb082438fda86b084b145d9f99..898be9139861eeaa236f7d97799d50f9b94c3ad4 100644 (file)
@@ -195,6 +195,8 @@ class UOWTransaction(object):
                     self.dependencies.remove(edge)
                 elif cycles.issuperset(edge):
                     self.dependencies.remove(edge)
+                elif edge[0].disabled or edge[1].disabled:
+                    self.dependencies.remove(edge)
                 elif edge[0] in cycles:
                     self.dependencies.remove(edge)
                     for dep in convert[edge[0]]:
@@ -203,19 +205,18 @@ class UOWTransaction(object):
                     self.dependencies.remove(edge)
                     for dep in convert[edge[1]]:
                         self.dependencies.add((edge[0], dep))
-                elif edge[0].disabled or edge[1].disabled:
-                    self.dependencies.remove(edge)
         
         postsort_actions = set(
-                                [a for a in self.postsort_actions.values() 
-                                if not a.disabled]
+                                [a for a in self.postsort_actions.values()
+                                if not a.disabled
+                                ]
                             ).difference(cycles)
         
         # execute actions
         sort = topological.sort(self.dependencies, postsort_actions)
-#        print "------------------------"
-#        print self.dependencies
-#        print sort
+        #print "------------------------"
+        #print self.dependencies
+        #print sort
         for rec in sort:
             rec.execute(self)
             
index 5fc982ae0debad0af95c9c3c33814f17fa92e9ac..bcf47bd6492bf4a4c2b3236d61690b24dddc32bc 100644 (file)
@@ -26,26 +26,16 @@ __all__ = ['sort']
 class _EdgeCollection(object):
     """A collection of directed edges."""
 
-    def __init__(self):
+    def __init__(self, edges):
         self.parent_to_children = util.defaultdict(set)
         self.child_to_parents = util.defaultdict(set)
-
-    def add(self, edge):
-        """Add an edge to this collection."""
-
-        parentnode, childnode = edge
-        self.parent_to_children[parentnode].add(childnode)
-        self.child_to_parents[childnode].add(parentnode)
-
+        for parentnode, childnode in edges:
+            self.parent_to_children[parentnode].add(childnode)
+            self.child_to_parents[childnode].add(parentnode)
+            
     def has_parents(self, node):
         return node in self.child_to_parents and bool(self.child_to_parents[node])
 
-    def edges_by_parent(self, node):
-        if node in self.parent_to_children:
-            return [(node, child) for child in self.parent_to_children[node]]
-        else:
-            return []
-    
     def outgoing(self, node):
         """an iterable returning all nodes reached via node's outgoing edges"""
         
@@ -79,13 +69,9 @@ def sort(tuples, allitems):
     'tuples' is a list of tuples representing a partial ordering.
     """
 
-    edges = _EdgeCollection()
+    edges = _EdgeCollection(tuples)
     nodes = set(allitems)
 
-    for t in tuples:
-        nodes.update(t)
-        edges.add(t)
-
     queue = []
     for n in nodes:
         if not edges.has_parents(n):
@@ -106,12 +92,8 @@ def sort(tuples, allitems):
 def find_cycles(tuples, allitems):
     # straight from gvr with some mods
     todo = set(allitems)
-    edges = _EdgeCollection()
+    edges = _EdgeCollection(tuples)
 
-    for t in tuples:
-        todo.update(t)
-        edges.add(t)
-    
     output = set()
     
     while todo:
index 8c38a98b0badf3d9a86b933e61449e1fb32024b3..462e923f1b34fbdb29a5d7d266f162b0a95bd288 100644 (file)
@@ -5,7 +5,14 @@ from sqlalchemy import exc
 import collections
 
 class DependencySortTest(TestBase):
-    def assert_sort(self, tuples, result):
+    def assert_sort(self, tuples, allitems=None):
+        
+        if allitems is None:
+            allitems = self._nodes_from_tuples(tuples)
+        else:
+            allitems = self._nodes_from_tuples(tuples).union(allitems)
+            
+        result = topological.sort(tuples, allitems)
         
         deps = collections.defaultdict(set)
         for parent, child in tuples:
@@ -16,6 +23,12 @@ class DependencySortTest(TestBase):
             for n in result[i:]:
                 assert node not in deps[n]
 
+    def _nodes_from_tuples(self, tups):
+        s = set()
+        for tup in tups:
+            s.update(tup)
+        return s
+        
     def test_sort_one(self):
         rootnode = 'root'
         node2 = 'node2'
@@ -36,7 +49,7 @@ class DependencySortTest(TestBase):
             (node4, subnode3),
             (node4, subnode4)
         ]
-        self.assert_sort(tuples, topological.sort(tuples, []))
+        self.assert_sort(tuples)
 
     def test_sort_two(self):
         node1 = 'node1'
@@ -53,7 +66,7 @@ class DependencySortTest(TestBase):
             (node5, node6),
             (node6, node2)
         ]
-        self.assert_sort(tuples, topological.sort(tuples, [node7]))
+        self.assert_sort(tuples, [node7])
 
     def test_sort_three(self):
         node1 = 'keywords'
@@ -66,7 +79,7 @@ class DependencySortTest(TestBase):
             (node1, node3),
             (node3, node2)
         ]
-        self.assert_sort(tuples, topological.sort(tuples, []))
+        self.assert_sort(tuples)
 
     def test_raise_on_cycle_one(self):
         node1 = 'node1'
@@ -82,7 +95,7 @@ class DependencySortTest(TestBase):
             (node3, node1),
             (node4, node1)
         ]
-        allitems = [node1, node2, node3, node4]
+        allitems = self._nodes_from_tuples(tuples)
         assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems)
 
         # TODO: test find_cycles
@@ -101,7 +114,8 @@ class DependencySortTest(TestBase):
             (node3, node2),
             (node2, node3)
         ]
-        assert_raises(exc.CircularDependencyError, topological.sort, tuples, [])
+        allitems = self._nodes_from_tuples(tuples)
+        assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems)
 
         # TODO: test find_cycles
 
@@ -112,24 +126,19 @@ class DependencySortTest(TestBase):
                     (question, provider), (providerservice, question), 
                     (provider, providerservice), (question, answer), (issue, question)]
 
-        assert_raises(exc.CircularDependencyError, topological.sort, tuples, [])
+        allitems = self._nodes_from_tuples(tuples)
+        assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems)
         
         # TODO: test find_cycles
         
     def test_large_sort(self):
         tuples = [(i, i + 1) for i in range(0, 1500, 2)]
-        self.assert_sort(
-            tuples,
-            topological.sort(tuples, [])
-        )
+        self.assert_sort(tuples)
 
     def test_ticket_1380(self):
         # ticket:1380 regression: would raise a KeyError
         tuples = [(id(i), i) for i in range(3)]
-        self.assert_sort(
-            tuples,
-            topological.sort(tuples, [])
-        )
+        self.assert_sort(tuples)
         
     def test_find_cycle_one(self):
         node1 = 'node1'
@@ -145,7 +154,7 @@ class DependencySortTest(TestBase):
         ]
 
         eq_(
-            topological.find_cycles(tuples),
+            topological.find_cycles(tuples, self._nodes_from_tuples(tuples)),
             set([node1, node2, node3])
         )