]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dependency sort thing getting out of hand
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2005 03:16:44 +0000 (03:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2005 03:16:44 +0000 (03:16 +0000)
lib/sqlalchemy/objectstore.py

index a22178fcab075c77aa3822c18965ffe479e067d6..13cebcc15f407e9819c41cf50874fa55f1053e6e 100644 (file)
@@ -107,8 +107,8 @@ class UnitOfWork(object):
         self.identity_map[key] = obj
     
     def _remove_deleted(self, obj):
-        if hasattr(obj, "_instancekey"):
-            del self.identity_map[obj._instancekey]
+        if hasattr(obj, "_instance_key"):
+            del self.identity_map[obj._instance_key]
         del self.deleted[obj]
         self.attributes.remove(obj)
         
@@ -150,7 +150,11 @@ class UnitOfWork(object):
             return True
         
     def register_deleted(self, obj):
+        #if self.deleted.contains(obj):
+        #    return
         self.deleted.append(obj)  
+        mapper = object_mapper(obj)
+        mapper.register_deleted(obj, self)
 
     # TODO: tie in register_new/register_dirty with table transaction begins ?
     def begin(self):
@@ -158,8 +162,6 @@ class UnitOfWork(object):
         uow.set(u)
         
     def commit(self, *objects):
-        import sqlalchemy.mapper
-
         commit_context = UOWTransaction(self)
 
         if len(objects):
@@ -175,6 +177,7 @@ class UnitOfWork(object):
                 obj = item.obj
                 commit_context.append_task(obj)
             for obj in self.deleted:
+                print "going to delete.... " + repr(obj)
                 commit_context.add_item_to_delete(obj)
                 
         engines = util.HashSet()
@@ -213,8 +216,7 @@ class UnitOfWork(object):
 class UOWTransaction(object):
     def __init__(self, uow):
         self.uow = uow
-        # links objects to their mappers
-        self.object_mappers = {}
+
         #  unique list of all the mappers we come across
         self.mappers = util.HashSet()
         self.dependencies = {}
@@ -225,12 +227,14 @@ class UOWTransaction(object):
         self.deleted_lists = util.HashSet()
 
     def append_task(self, obj):
-        mapper = self.object_mapper(obj)
+        mapper = object_mapper(obj)
+        self.mappers.append(mapper)
         task = self.get_task_by_mapper(mapper)
         task.objects.append(obj)
 
     def add_item_to_delete(self, obj):
-        mapper = self.object_mapper(obj)
+        mapper = object_mapper(obj)
+        self.mappers.append(mapper)
         task = self.get_task_by_mapper(mapper, True)
         task.objects.append(obj)
 
@@ -268,44 +272,11 @@ class UOWTransaction(object):
     def register_deleted_object(self, obj):
         self.deleted_objects.append(obj)
         
-        
-    def object_mapper(self, obj):
-        import sqlalchemy.mapper
-        try:
-            return self.object_mappers[obj]
-        except KeyError:
-            mapper = sqlalchemy.mapper.object_mapper(obj)
-            self.object_mappers[obj] = mapper
-            self.mappers.append(mapper)
-            return mapper
-            
     def execute(self):
         for task in self.tasks.values():
             task.mapper.register_dependencies(self)
-            
-        tasklist = self.tasks.values()
-        def compare(a, b):
-            if a.mapper is b.mapper:
-                return a.isdelete and 1 or -1
-            elif self.dependencies.has_key((a.mapper, b.mapper)):
-                if a.isdelete is not b.isdelete:
-                    return a.isdelete and 1 or -1
-                else:
-                    return -1
-            elif self.dependencies.has_key((b.mapper, a.mapper)):
-                if a.isdelete is not b.isdelete:
-                    return a.isdelete and 1 or -1
-                else:
-                    return 1
-            else:
-                return 0
-            return c
-        tasklist.sort(compare)
-
-        import string
-        print string.join([str(t) for t in tasklist], ',')
-
-        for task in tasklist:
+        
+        for task in self._sort_dependencies():
             obj_list = task.objects
             if not task.isdelete:
                 task.mapper.save_obj(obj_list, self)
@@ -317,7 +288,7 @@ class UOWTransaction(object):
             
     def post_exec(self):
         for obj in self.saved_objects:
-            mapper = self.object_mapper(obj)
+            mapper = object_mapper(obj)
             obj._instance_key = mapper.identity_key(obj)
             self.uow.register_clean(obj)
 
@@ -335,6 +306,60 @@ class UOWTransaction(object):
                 del self.uow.modified_lists[obj]
             except KeyError:
                 pass
+
+    def _sort_dependencies(self):        
+        nodes = {}
+        def maketree(tuples):
+            head = None
+            for tup in tuples:
+                (parent, child) = (tup[0], tup[1])
+                
+                try:
+                    parentnode = nodes[parent]
+                except KeyError:
+                    parentnode = (parent, [])
+                    nodes[parent] = parentnode
+                try:
+                    childnode = nodes[child]
+                except KeyError:
+                    childnode = (child, [])
+                    nodes[child] = childnode
+
+                if head is None:
+                    head = parentnode
+                elif head is childnode:
+                    head = parentnode
+                parentnode[1].append(childnode)
+            return head
+        
+        bymapper = {}
+        
+        def sort(node, isdel, res):
+            task = bymapper.get((node[0], isdel), None)
+            if task is not None:
+                res.append(task)
+            for child in node[1]:
+                sort(child, isdel, res)
+            return res
+            
+        for task in self.tasks.values():
+            print "new node for " + str(task)
+            bymapper[(task.mapper, task.isdelete)] = task
+            
+    
+        head = maketree(self.dependencies)
+        res = []
+        tasklist = sort(head, False, res)
+        res = []
+        sort(head, True, res)
+        res.reverse()
+        tasklist += res
+
+        import string,sys
+        print string.join([str(t) for t in tasklist], ',')
+        #sys.exit(0)
+        
+        return tasklist
             
 class UOWTask(object):
     def __init__(self, mapper, isdelete = False):
@@ -342,6 +367,7 @@ class UOWTask(object):
         self.isdelete = isdelete
         self.objects = util.HashSet()
         self.dependencies = []
+        print "new task " + str(self)
     
     def __str__(self):
         if self.isdelete:
@@ -349,4 +375,9 @@ class UOWTask(object):
         else:
             return self.mapper.table.name + " saves"
             
-uow = util.ScopedRegistry(lambda: UnitOfWork(), "thread")
\ No newline at end of file
+uow = util.ScopedRegistry(lambda: UnitOfWork(), "thread")
+
+
+def object_mapper(obj):
+    import sqlalchemy.mapper
+    return sqlalchemy.mapper.object_mapper(obj)