]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2005 01:46:42 +0000 (01:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2005 01:46:42 +0000 (01:46 +0000)
examples/adjacencytree/tables.py [new file with mode: 0644]
lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/util.py

diff --git a/examples/adjacencytree/tables.py b/examples/adjacencytree/tables.py
new file mode 100644 (file)
index 0000000..3cfe1e9
--- /dev/null
@@ -0,0 +1,13 @@
+from sqlalchemy.schema import *
+import sqlalchemy.engine
+
+engine = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = True)
+
+trees = Table('treenodes', engine,
+    Column('node_id', Integer, primary_key=True),
+    Column('parent_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True),
+    Column('root_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True),
+    Column('node_name', String(50), nullable=False)
+    )
+    
+trees.create()
\ No newline at end of file
index a12b57accb72580cca360c16a8995aee5626d2a9..30d9a296d6b3bc7dad92671099abb7dae7c5faa4 100644 (file)
@@ -54,6 +54,8 @@ class PropHistory(object):
         self.obj = obj
         self.key = key
         self.orig = PropHistory.NONE
+    def history_contains(self, obj):
+        return self.orig is obj or self.obj.__dict__[self.key] is obj
     def setattr_clean(self, value):
         self.obj.__dict__[self.key] = value
     def setattr(self, value):
index 218c4fb66036275a2f0624eb43638e5cdeecce0c..7585dc0cd47a47d435098c5c46d1ffeed7dd6951 100644 (file)
@@ -38,8 +38,8 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin
     else:
         return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
     
-def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, **options):
-    return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, **options)
+def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, thiscol = None, **options):
+    return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, thiscol = thiscol, **options)
 
 class assignmapper(object):
     def __init__(self, table, **kwargs):
@@ -528,7 +528,7 @@ class ColumnProperty(MapperProperty):
 class PropertyLoader(MapperProperty):
     """describes an object property that holds a single item or list of items that correspond to a related
     database table."""
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None):
         self.uselist = uselist
         self.argument = argument
         self.secondary = secondary
@@ -536,6 +536,7 @@ class PropertyLoader(MapperProperty):
         self.secondaryjoin = secondaryjoin
         self.foreignkey = foreignkey
         self.private = private
+        self.thiscol = thiscol
         self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist))
 
     def _copy(self):
@@ -553,10 +554,12 @@ class PropertyLoader(MapperProperty):
             self.mapper = self.argument
             
         self.target = self.mapper.table
-            
         self.key = key
         self.parent = parent
-        
+
+        if self.parent.table is self.target and self.thiscol is None:
+            raise "Circular relationship requires 'thiscol' parameter"
+            
         # if join conditions were not specified, figure them out based on foreign keys
         if self.secondary is not None:
             if self.secondaryjoin is None:
@@ -660,15 +663,33 @@ class PropertyLoader(MapperProperty):
             raise " no foreign key ?"
 
     def get_object_dependencies(self, obj, uowcommit, passive = True):
-        """function to retreive the child list off of an object.  "passive" means, if its
-         a lazy loaded list that is not loaded yet, dont load it."""
         if self.uselist:
             return uowcommit.uow.attributes.get_list_history(obj, self.key, passive = passive)
         else: 
             return uowcommit.uow.attributes.get_history(obj, self.key)
-    
+
+    def whose_dependent_on_who(self, obj1, obj2, uowcommit):
+        if obj1 is obj2:
+            return None
+        hist = self.get_object_dependencies(obj1, uowcommit)
+        if hist.history_contains(obj2):
+            if self.thiscol.primary_key:
+                return (obj1, obj2)
+            else:
+                return (obj2, obj1)
+        else:
+            hist = self.get_object_dependencies(obj2, uowcommit)
+            if hist.history_contains(obj1):
+                if self.thiscol.primary_key:
+                    return (obj2, obj1)
+                else:
+                    return (obj1, obj2)
+            else:
+                return None
+            
+            
     def process_dependencies(self, deplist, uowcommit, delete = False):
-        print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete)
+        #print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete)
 
         # fucntion to set properties across a parent/child object plus an "association row",
         # based on a join condition
@@ -677,7 +698,10 @@ class PropertyLoader(MapperProperty):
         setter = BinaryVisitor(sync_foreign_keys)
 
         def getlist(obj, passive=True):
-            return self.get_object_dependencies(obj, uowcommit, passive)
+            if self.uselist:
+                return uowcommit.uow.attributes.get_list_history(obj, self.key, passive = passive)
+            else: 
+                return uowcommit.uow.attributes.get_history(obj, self.key)
             
         associationrow = {}
         
@@ -719,7 +743,6 @@ class PropertyLoader(MapperProperty):
                     statement = self.secondary.insert()
                     statement.execute(*secondary_insert)
         elif self.foreignkey.table == self.target:
-            print "HI"
             if delete and not self.private:
                 updates = []
                 clearkeys = True
@@ -739,11 +762,9 @@ class PropertyLoader(MapperProperty):
                     statement = self.target.update(self.lazywhere, values = values)
                     statement.execute(*updates)
             else:
-                print str(self.primaryjoin.compile())
                 for obj in deplist:
                     childlist = getlist(obj)
                     if childlist is None: return
-                    print "DEP: " +str(obj) + " LIST: " + repr([str(v) for v in childlist.added_items()])
                     uowcommit.register_saved_list(childlist)
                     clearkeys = False
                     for child in childlist.added_items():
@@ -768,7 +789,7 @@ class PropertyLoader(MapperProperty):
         else:
             raise " no foreign key ?"
     
-        print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
+        #print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
 
     def _sync_foreign_keys(self, binary, obj, child, associationrow, clearkeys):
         """given a binary clause with an = operator joining two table columns, synchronizes the values 
@@ -782,12 +803,12 @@ class PropertyLoader(MapperProperty):
                     source = binary.right
                 else:
                     raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname)
-                print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key
+                #print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key
                 self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source))
             else:
                 colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
                 if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target):
-                    print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key
+                    #print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key
                     if clearkeys:
                         self.mapper._setattrbycolumn(child, colmap[self.target], None)
                     else:
index f4639cb47f34a5caecf29b49c2813038838fd416..eb601e3cb300c93cf7027be56c859e245d442354 100644 (file)
@@ -24,7 +24,7 @@ import thread
 import sqlalchemy.util as util
 import sqlalchemy.attributes as attributes
 import weakref
-
+import string
 
 def get_id_key(ident, class_, table):
     """returns an identity-map key for use in storing/retrieving an item from the identity map, given
@@ -302,6 +302,7 @@ class UOWTransaction(object):
             task.mapper.register_dependencies(self)
         
         for task in self._sort_dependencies():
+            print "exec task: " + str(task)
             task.execute(self)
             
     def post_exec(self):
@@ -326,54 +327,10 @@ class UOWTransaction(object):
                 pass
 
     def _sort_dependencies(self):
-    
-        class Node:
-            def __init__(self, mapper):
-                #print "new node on " + str(mapper)
-                self.mapper = mapper
-                self.children = util.HashSet()
-                self.parent = None
-                
-        def maketree(tuples, allitems):
-            nodes = {}
-            head = None
-            for tup in tuples:
-                (parent, child) = (tup[0], tup[1])
-                #print "tuple: " + str(parent) + " " + str(child)
-                try:
-                    parentnode = nodes[parent]
-                except KeyError:
-                    parentnode = Node(parent)
-                    nodes[parent] = parentnode
-                try:
-                    childnode = nodes[child]
-                except KeyError:
-                    childnode = Node(child)
-                    nodes[child] = childnode
-
-                if head is None:
-                    head = parentnode
-                elif head is childnode:
-                    head = parentnode
-                if childnode.parent is not None:
-                    del childnode.parent.children[childnode]
-                    childnode.parent.children.append(parentnode)
-                parentnode.children.append(childnode)
-                childnode.parent = parentnode
-                
-            for item in allitems:
-                if not nodes.has_key(item):
-                    node = Node(item)
-                    if head is not None:
-                        head.parent = node
-                        node.children.append(head)
-                    head = node
-            return head
-        
         bymapper = {}
         
         def sort(node, isdel, res):
-            #print "Sort: " + (node and str(node.mapper) or 'None')
+            print "Sort: " + (node and str(node.mapper) or 'None')
             if node is None:
                 return res
             task = bymapper.get((node.mapper, isdel), None)
@@ -381,6 +338,8 @@ class UOWTransaction(object):
                 res.append(task)
             for child in node.children:
                 if child is node:
+                    print "setting circular: " + str(task)
+                    task.iscircular = True
                     continue
                 sort(child, isdel, res)
             return res
@@ -391,7 +350,7 @@ class UOWTransaction(object):
             mappers.append(task.mapper)
             bymapper[(task.mapper, task.isdelete)] = task
     
-        head = maketree(self.dependencies, mappers)
+        head = TupleSorter(self.dependencies, mappers).sort()
         res = []
         tasklist = sort(head, False, res)
 
@@ -415,9 +374,16 @@ class UOWTask(object):
         self.objects = util.HashSet(ordered = True)
         self.dependencies = []
         self.listonly = listonly
+        self.iscircular = False
         #print "new task " + str(self)
     
     def execute(self, trans):
+        print "exec " + str(self) + " circualr=" + repr(self.iscircular)
+        if self.iscircular:
+            task = self.sort_circular_dependencies(trans)
+            task.execute_circular(trans)
+            return
+            
         obj_list = self.objects
         if not self.listonly and not self.isdelete:
             self.mapper.save_obj(obj_list, trans)
@@ -427,32 +393,151 @@ class UOWTask(object):
         if not self.listonly and self.isdelete:
             self.mapper.delete_obj(obj_list, trans)
 
-    def sort_circular_dependencies(self):
+    def execute_circular(self, trans):
+        print "execcircular " + str(self)
+#        obj_list = self.objects
+ #       if not self.listonly and not self.isdelete:
+ #           self.mapper.save_obj(obj_list, trans)
+ #       raise "hi"
+        self.execute(trans)
+        for obj in self.objects:
+            childtask = self.taskhash[obj]
+            childtask.execute_circular(trans)
+    
+    
+    def sort_circular_dependencies(self, trans):
+        allobjects = self.objects
+        tuples = []
+        for obj in self.objects:
+            for dep in self.dependencies:
+                (processor, targettask) = dep
+                if targettask is self:
+                    childlist = processor.get_object_dependencies(obj, trans, passive = True)
+                    for o in childlist.added_items() + childlist.deleted_items():
+                        whosdep = processor.whose_dependent_on_who(obj, o, trans)
+                        if whosdep is not None:
+                            tuples.append(whosdep)
+        head = TupleSorter(tuples, allobjects).sort()
+        print "---------"
+        print str(head)
+        raise "hi"
+        
+    def old_sort_circular_dependencies(self, trans):
+        dependents = {}
         d = {}
-        head = None
-        for obj in obj_list:
-            d[obj] = UOWTask(self.mapper, self.isdelete, self.listonly)
-            d[obj].dependencies = self.dependencies
-            if head is None:
-                head = obj
+
+        def make_task():
+            t = UOWTask(self.mapper, self.isdelete, self.listonly)
+            t.dependencies = self.dependencies
+            t.taskhash = d
+            return t
+
+        head = make_task()
+        for obj in self.objects:
+            print "obj: " + str(obj)
+            task  = make_task()
+            d[obj] = task
+            if not dependents.has_key(obj):
+                head.objects.append(obj)
             for dep in self.dependencies:
                 (processor, targettask) = dep
-                if targetttask is self:
-                    for o in processor.get_object_dependencies(obj, self, passive = True):
-                        if o is head:
-                            head = obj
-                        d[obj].objects.append(o)
-        if head is None:
-            return self
-        else:
-            return d[head]
+                if targettask is self:
+                    childlist = processor.get_object_dependencies(obj, trans, passive = True)
+                    for o in childlist.added_items() + childlist.deleted_items():
+                        whosdep = processor.whose_dependent_on_who(obj, o, trans)
+                        if whosdep is not None:
+                            (child, parent) = whosdep
+                            if not d.has_key(parent):
+                                d[parent] = make_task()
+                            if dependents.has_key(child):
+                                p2 = dependents[child]
+                                wd2 = processor.whose_dependent_on_who(parent, p2, trans)
+                                
+                            d[parent].objects.append(child)
+                            dependents[child] = parent
+                            print "dependent obj: " + str(child) + " is dependent in relation " + str(obj) + " " + str(o)
+                            if head.objects.contains(child):
+                                del head.objects[child]
+
+        def printtask(t):
+            print "l1"
+            print repr([str(v) for v in t.objects])
+            for v in t.objects:
+                t2 = t.taskhash[v]
+                print "l2"
+                print repr([str(v2) for v2 in t2.objects])
+                for v3 in t2.objects:
+                    t3 = t.taskhash[v3]
+                    print "l3"
+                    print repr([str(v4) for v4 in t3.objects])
+#                printtask(t2)
+        print "sorted hierarchical tasks: "
+        printtask(head)
+        raise "hi"
+        return head
         
     def __str__(self):
         if self.isdelete:
             return self.mapper.primarytable.name + " deletes " + repr(self.listonly)
         else:
             return self.mapper.primarytable.name + " saves " + repr(self.listonly)
+
+class TupleSorter(object):
+
+    class Node:
+        def __init__(self, mapper):
+            #print "new node on " + str(mapper)
+            self.mapper = mapper
+            self.children = util.HashSet()
+            self.parent = None
+        def __str__(self):
+            return self.safestr({})
+        def safestr(self, hash):
+            if hash.has_key(self):
+                return "[RECURSIVE:%s(%s, %s)]" % (str(self.mapper), repr(id(self)), repr(id(self.parent)))
+            hash[self] = True
+            return "%s(%s, %s)" % (str(self.mapper), repr(id(self)), repr(id(self.parent))) + "\n" + string.join([n.safestr(hash) for n in self.children])
+            
+    def __init__(self, tuples, allitems):
+        self.tuples = tuples
+        self.allitems = allitems
+    def sort(self):
+        (tuples, allitems) = (self.tuples, self.allitems)
+        nodes = {}
+        head = None
+        for tup in tuples:
+            (parent, child) = (tup[0], tup[1])
+            print "tuple: " + str(parent) + " " + str(child)
+            try:
+                parentnode = nodes[parent]
+            except KeyError:
+                parentnode = TupleSorter.Node(parent)
+                nodes[parent] = parentnode
+            try:
+                childnode = nodes[child]
+            except KeyError:
+                childnode = TupleSorter.Node(child)
+                nodes[child] = childnode
+
+            if head is None:
+                head = parentnode
+            elif head is childnode:
+                head = parentnode
+            if childnode.parent is not None:
+                del childnode.parent.children[childnode]
+                childnode.parent.children.append(parentnode)
+            parentnode.children.append(childnode)
+            childnode.parent = parentnode
             
+        for item in allitems:
+            if not nodes.has_key(item):
+                node = TupleSorter.Node(item)
+                if head is not None:
+                    head.parent = node
+                    node.children.append(head)
+                head = node
+        return head
+                    
 uow = util.ScopedRegistry(lambda: UnitOfWork(), "thread")
 
 
index bc13f5dbde6da1b75e50529bfc5230e189be92f4..446b87cb9516a04c311d5be49f4ddbcf70fe4782 100644 (file)
@@ -121,28 +121,20 @@ class HashSet(object):
         if iter is not None:
             for i in iter:
                 self.append(i)
-        
     def __iter__(self):
         return iter(self.map.values())
     def contains(self, item):
         return self.map.has_key(item)
-
     def clear(self):
         self.map.clear()
-        
     def append(self, item):
         self.map[item] = item
-
     def __add__(self, other):
         return HashSet(self.map.values() + [i for i in other])
-        
     def __len__(self):
         return len(self.map)
-        
     def __delitem__(self, key):
         del self.map[key]
     def __getitem__(self, key):
         return self.map[key]
         
@@ -179,7 +171,8 @@ class HistoryArraySet(UserList.UserList):
             if not self._setrecord(self.data[i]):
                del self.data[i]
                i -= 1
-
+    def history_contains(self, obj):
+        return self.records.has_key(obj)
     def __hash__(self):
         return id(self)
     def _setrecord(self, item):