]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2005 03:33:48 +0000 (03:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2005 03:33:48 +0000 (03:33 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py

index 7585dc0cd47a47d435098c5c46d1ffeed7dd6951..27caa33a64027d148c4b311914775177fb551c96 100644 (file)
@@ -524,8 +524,11 @@ class ColumnProperty(MapperProperty):
             #setattr(instance, self.key, row[self.columns[0].label])
         
 
-
 class PropertyLoader(MapperProperty):
+    LEFT = 0
+    RIGHT = 1
+    CENTER = 2
+
     """describes an object property that holds a single item or list of items that correspond to a related
     database table."""
     def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None):
@@ -583,7 +586,9 @@ class PropertyLoader(MapperProperty):
             else:
                 self.foreignkey = w.dependent
 
-        if self.uselist is None and self.foreignkey is not None and self.foreignkey.table == self.parent.primarytable:
+        self.direction = self.get_direction()
+        
+        if self.uselist is None and self.direction == PropertyLoader.RIGHT:
             self.uselist = False
 
         if self.uselist is None:
@@ -595,6 +600,19 @@ class PropertyLoader(MapperProperty):
             #print "regiser list col on class %s key %s" % (parent.class_.__name__, key)
             objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist)
 
+    def get_direction(self):
+        if self.thiscol is not None:
+            if self.thiscol.primary_key:
+                return PropertyLoader.LEFT
+            else:
+                return PropertyLoader.RIGHT
+        if self.secondaryjoin is not None:
+            return PropertyLoader.CENTER
+        elif self.foreignkey.table == self.target:
+            return PropertyLoader.LEFT
+        elif self.foreignkey.table == self.parent.primarytable:
+            return PropertyLoader.RIGHT
+
     class FindDependent(sql.ClauseVisitor):
         def __init__(self):
             self.dependent = None
@@ -642,19 +660,18 @@ class PropertyLoader(MapperProperty):
 
             
     def register_dependencies(self, uowcommit):
-        if self.secondaryjoin is not None:
+        if self.direction == PropertyLoader.CENTER:
             # with many-to-many, set the parent as dependent on us, then the 
             # list of associations as dependent on the parent
             # if only a list changes, the parent mapper is the only mapper that
             # gets added to the "todo" list
             uowcommit.register_dependency(self.mapper, self.parent)
             uowcommit.register_task(self.parent, False, self, self.parent, False)
-        elif self.foreignkey.table == self.target:
+        elif self.direction == PropertyLoader.LEFT:
             uowcommit.register_dependency(self.parent, self.mapper)
             uowcommit.register_task(self.parent, False, self, self.parent, False)
             uowcommit.register_task(self.parent, True, self, self.parent, True)
-                
-        elif self.foreignkey.table == self.parent.primarytable:
+        elif self.direction == PropertyLoader.RIGHT:
             uowcommit.register_dependency(self.mapper, self.parent)
             uowcommit.register_task(self.mapper, False, self, self.parent, False)
             # TODO: private deletion thing for one-to-one relationship
@@ -671,22 +688,10 @@ class PropertyLoader(MapperProperty):
     def whose_dependent_on_who(self, obj1, obj2, uowcommit):
         if obj1 is obj2:
             return None
-        hist = self.get_object_dependencies(obj1, uowcommit)
-        if hist.history_contains(obj2):
-            if self.thiscol.primary_key:
-                return (obj1, obj2)
-            else:
-                return (obj2, obj1)
+        elif self.thiscol.primary_key:
+            return (obj1, obj2)
         else:
-            hist = self.get_object_dependencies(obj2, uowcommit)
-            if hist.history_contains(obj1):
-                if self.thiscol.primary_key:
-                    return (obj2, obj1)
-                else:
-                    return (obj1, obj2)
-            else:
-                return None
-            
+            return (obj2, obj1)
             
     def process_dependencies(self, deplist, uowcommit, delete = False):
         #print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete)
@@ -698,16 +703,13 @@ class PropertyLoader(MapperProperty):
         setter = BinaryVisitor(sync_foreign_keys)
 
         def getlist(obj, passive=True):
-            if self.uselist:
-                return uowcommit.uow.attributes.get_list_history(obj, self.key, passive = passive)
-            else: 
-                return uowcommit.uow.attributes.get_history(obj, self.key)
+            return self.get_object_dependencies(obj, uowcommit, passive)
             
         associationrow = {}
         
         # plugin point
         
-        if self.secondaryjoin is not None:
+        if self.direction == PropertyLoader.CENTER:
             secondary_delete = []
             secondary_insert = []
             if delete:
@@ -742,7 +744,7 @@ class PropertyLoader(MapperProperty):
                 if len(secondary_insert):
                     statement = self.secondary.insert()
                     statement.execute(*secondary_insert)
-        elif self.foreignkey.table == self.target:
+        elif self.direction == PropertyLoader.LEFT:
             if delete and not self.private:
                 updates = []
                 clearkeys = True
@@ -772,7 +774,7 @@ class PropertyLoader(MapperProperty):
                     clearkeys = True
                     for child in childlist.deleted_items():
                          self.primaryjoin.accept_visitor(setter)
-        elif self.foreignkey.table == self.parent.primarytable:
+        elif self.direction == PropertyLoader.RIGHT:
             for child in deplist:
                 childlist = getlist(child)
                 if childlist is None: return
index eb601e3cb300c93cf7027be56c859e245d442354..3095f1bff7b2448ee3bdc9149e75c5bf447d5fb1 100644 (file)
@@ -330,21 +330,19 @@ class UOWTransaction(object):
         bymapper = {}
         
         def sort(node, isdel, res):
-            print "Sort: " + (node and str(node.mapper) or 'None')
+            print "Sort: " + (node and str(node.item) or 'None')
             if node is None:
                 return res
-            task = bymapper.get((node.mapper, isdel), None)
+            task = bymapper.get((node.item, isdel), None)
             if task is not None:
                 res.append(task)
-            for child in node.children:
-                if child is node:
-                    print "setting circular: " + str(task)
+                if node.circular:
                     task.iscircular = True
-                    continue
+            for child in node.children:
                 sort(child, isdel, res)
             return res
             
-        mappers = []
+        mappers = util.HashSet()
         for task in self.tasks.values():
             #print "new node for " + str(task)
             mappers.append(task.mapper)
@@ -359,6 +357,8 @@ class UOWTransaction(object):
         res.reverse()
         tasklist += res
 
+        print repr(self.tasks.values())
+        print repr(tasklist)
         assert(len(self.tasks.values()) == len(tasklist)) # "sorted task list not the same size as original task list"
 
         import string,sys
@@ -378,10 +378,10 @@ class UOWTask(object):
         #print "new task " + str(self)
     
     def execute(self, trans):
-        print "exec " + str(self) + " circualr=" + repr(self.iscircular)
         if self.iscircular:
             task = self.sort_circular_dependencies(trans)
-            task.execute_circular(trans)
+            if task is not None:
+                task.execute_circular(trans)
             return
             
         obj_list = self.objects
@@ -394,17 +394,11 @@ class UOWTask(object):
             self.mapper.delete_obj(obj_list, trans)
 
     def execute_circular(self, trans):
-        print "execcircular " + str(self)
-#        obj_list = self.objects
- #       if not self.listonly and not self.isdelete:
- #           self.mapper.save_obj(obj_list, trans)
- #       raise "hi"
         self.execute(trans)
         for obj in self.objects:
             childtask = self.taskhash[obj]
             childtask.execute_circular(trans)
     
-    
     def sort_circular_dependencies(self, trans):
         allobjects = self.objects
         tuples = []
@@ -418,9 +412,28 @@ class UOWTask(object):
                         if whosdep is not None:
                             tuples.append(whosdep)
         head = TupleSorter(tuples, allobjects).sort()
-        print "---------"
-        print str(head)
-        raise "hi"
+        if head is None:
+            return None
+
+        d = {}
+        def make_task():
+            t = UOWTask(self.mapper, self.isdelete, self.listonly)
+            t.dependencies = self.dependencies
+            t.taskhash = d
+            return t
+        
+        def make_task_tree(node, parenttask):
+            if node is None:
+                return
+            parenttask.objects.append(node.item)
+            t = make_task()
+            d[node.item] = t
+            for n in node.children:
+                make_task_tree(n, t)
+        
+        t = make_task()
+        make_task_tree(head, t)
+        return t
         
     def old_sort_circular_dependencies(self, trans):
         dependents = {}
@@ -485,18 +498,19 @@ class UOWTask(object):
 class TupleSorter(object):
 
     class Node:
-        def __init__(self, mapper):
-            #print "new node on " + str(mapper)
-            self.mapper = mapper
+        def __init__(self, item):
+            #print "new node on " + str(item)
+            self.item = item
             self.children = util.HashSet()
             self.parent = None
+            self.circular = False
         def __str__(self):
             return self.safestr({})
-        def safestr(self, hash):
+        def safestr(self, hash, indent = 0):
             if hash.has_key(self):
-                return "[RECURSIVE:%s(%s, %s)]" % (str(self.mapper), repr(id(self)), repr(id(self.parent)))
+                return (' ' * indent) + "RECURSIVE:%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or 'None')
             hash[self] = True
-            return "%s(%s, %s)" % (str(self.mapper), repr(id(self)), repr(id(self.parent))) + "\n" + string.join([n.safestr(hash) for n in self.children])
+            return (' ' * indent) + "%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '')
             
     def __init__(self, tuples, allitems):
         self.tuples = tuples
@@ -507,35 +521,69 @@ class TupleSorter(object):
         head = None
         for tup in tuples:
             (parent, child) = (tup[0], tup[1])
-            print "tuple: " + str(parent) + " " + str(child)
+            #print "tuple: " + str(parent) + " " + str(child)
+            
+            # get parent node
             try:
                 parentnode = nodes[parent]
             except KeyError:
                 parentnode = TupleSorter.Node(parent)
                 nodes[parent] = parentnode
+
+            # if parent is child, mark "circular" attribute on the node
+            if parent is child:
+                parentnode.circular = True
+                # set head if its nothing
+                if head is None:
+                    head = parentnode
+                # nothing more to do for this one
+                continue
+
+            # get child node
             try:
                 childnode = nodes[child]
             except KeyError:
                 childnode = TupleSorter.Node(child)
                 nodes[child] = childnode
 
+            # set head if its nothing, move it up to the parent
+            # if its the child node
             if head is None:
                 head = parentnode
             elif head is childnode:
                 head = parentnode
-            if childnode.parent is not None:
-                del childnode.parent.children[childnode]
-                childnode.parent.children.append(parentnode)
-            parentnode.children.append(childnode)
-            childnode.parent = parentnode
             
+            # now see, if the parent is an ancestor of the child
+            c = childnode
+            while c is not None and c is not parentnode:
+                c = c.parent
+            
+            # nope, so we have to move the child down from whereever
+            # it currently is to a child of the parent
+            if c is None:
+                if childnode.parent is not None:
+                    del childnode.parent.children[childnode]
+                    childnode.parent.children.append(parentnode)
+                parentnode.children.append(childnode)
+                childnode.parent = parentnode
+            #print str(head)
+        
+        # go through the total list of items.  for those 
+        # that had no dependency tuples, and therefore are not
+        # in the tree, add them as head nodes in a line
+        newhead = None
         for item in allitems:
             if not nodes.has_key(item):
-                node = TupleSorter.Node(item)
-                if head is not None:
-                    head.parent = node
-                    node.children.append(head)
-                head = node
+                if newhead is None:
+                    newhead = TupleSorter.Node(item)
+                    if head is not None:
+                        head.parent = newhead
+                        newhead.children.append(head)
+                    head = newhead
+                else:
+                    n = TupleSorter.Node(item)
+                    head.children.append(n)
+                    n.parent = head
         return head
                     
 uow = util.ScopedRegistry(lambda: UnitOfWork(), "thread")