]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 06:25:43 +0000 (06:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 06:25:43 +0000 (06:25 +0000)
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/util.py
test/dependency.py [new file with mode: 0644]

index 1d0349a07b7069a6e1f982131ed34f557501f06b..547a2c0b78c9bbe61469c3f283d32edf143e2ca6 100644 (file)
@@ -26,8 +26,11 @@ from sqlalchemy.ansisql import *
 try:
     import psycopg2 as psycopg
 except:
-    import psycopg
-
+    try:
+        import psycopg
+    except:
+        psycopg = None
+        
 class PGNumeric(sqltypes.Numeric):
     def get_col_spec(self):
         return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
@@ -70,8 +73,9 @@ def engine(opts, **params):
 class PGSQLEngine(ansisql.ANSISQLEngine):
     def __init__(self, opts, module = None, **params):
         if module is None:
-            self.module = __import__('psycopg2')
-            #self.module = psycopg
+            if psycopg is None:
+                raise "Couldnt locate psycopg1 or psycopg2: specify postgres module argument"
+            self.module = psycopg
         else:
             self.module = module
         self.opts = opts or {}
index 5f65f02942d28159a3dabb0b84c2498a3af03fe2..c90b50769da1f7a93f454b7a756142d983d5bf17 100644 (file)
@@ -443,11 +443,12 @@ class Mapper(object):
                     (obj, params) = rec
                     statement.execute(**params)
                     primary_keys = table.engine.last_inserted_ids()
-                    i = 0
-                    for col in self.primary_keys[table]:
-                        if self._getattrbycolumn(obj, col) is None:
-                            self._setattrbycolumn(obj, col, primary_keys[i])
-                        i+=1
+                    if primary_keys is not None:
+                        i = 0
+                        for col in self.primary_keys[table]:
+                            if self._getattrbycolumn(obj, col) is None:
+                                self._setattrbycolumn(obj, col, primary_keys[i])
+                            i+=1
                     self.extension.after_insert(self, obj)
                     
     def delete_obj(self, objects, uow):
index 4fe35272b8914c60b55ada5c0ba584ed08dc5f7d..9b414f1818d9a934b67b527cc245f2c8033a401e 100644 (file)
@@ -324,7 +324,7 @@ class UOWTransaction(object):
             task.mapper.register_dependencies(self)
 
         head = self._sort_dependencies()
-        print "Task dump:\n" + head.dump()
+        #print "Task dump:\n" + head.dump()
         if head is not None:
             head.execute(self)
             
@@ -434,15 +434,12 @@ class UOWTask(object):
             self.circular.execute(trans)
             return
 
-#        print "task " + str(self) + " tosave: " + repr(self.tosave_objects())
         self.mapper.save_obj(self.tosave_objects(), trans)
         for dep in self.save_dependencies():
             (processor, targettask, isdelete) = dep
             processor.process_dependencies(targettask, [elem.obj for elem in targettask.tosave_elements()], trans, delete = False)
- #           print "processed dependencies on " + repr([elem.obj for elem in targettask.tosave_elements()])
         for element in self.tosave_elements():
             if element.childtask is not None:
-#                print "execute elem childtask " + str(element.childtask)
                 element.childtask.execute(trans)
         for dep in self.delete_dependencies():
             (processor, targettask, isdelete) = dep
index da5e7c7e82df5eb761d9fc44dd66518c957c3ba8..063291dde7b53b7efc385237fe42fb736c88ae93 100644 (file)
@@ -337,55 +337,46 @@ class DependencySorter(object):
             self.children = HashSet()
             self.parent = None
             self.circular = False
+        def append(self, node):
+            if node.parent is not None:
+                del node.parent.children[node]
+            self.children.append(node)
+            node.parent = self
         def __str__(self):
             return self.safestr({})
         def safestr(self, hash, indent = 0):
             if hash.has_key(self):
                 return (' ' * indent) + "RECURSIVE:%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or 'None')
             hash[self] = True
-            return (' ' * indent) + "%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '')
+            return (' ' * indent) + "%s  (idself=%s, idparent=%s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '')
 
     def __init__(self, tuples, allitems):
         self.tuples = tuples
         self.allitems = allitems
     def sort(self):
         (tuples, allitems) = (self.tuples, self.allitems)
+        
         nodes = {}
-        head = None
+        # make nodes for all the items and store in the hash
+        for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
+            if not nodes.has_key(item):
+                nodes[item] = DependencySorter.Node(item)
+
+        # loop through tuples
         for tup in tuples:
             (parent, child) = (tup[0], tup[1])
-            #print "tuple: " + str(parent) + " " + str(child)
-
             # get parent node
-            try:
-                parentnode = nodes[parent]
-            except KeyError:
-                parentnode = DependencySorter.Node(parent)
-                nodes[parent] = parentnode
+            parentnode = nodes[parent]
 
             # if parent is child, mark "circular" attribute on the node
             if parent is child:
                 parentnode.circular = True
-                # set head if its nothing
-                if head is None:
-                    head = parentnode
-                # nothing more to do for this one
+                # and just continue
                 continue
 
             # get child node
-            try:
-                childnode = nodes[child]
-            except KeyError:
-                childnode = DependencySorter.Node(child)
-                nodes[child] = childnode
-
-            # set head if its nothing, move it up to the parent
-            # if its the child node
-            if head is None:
-                head = parentnode
-            elif head is childnode:
-                head = parentnode
-
+            childnode = nodes[child]
+                    
             # now see, if the parent is an ancestor of the child
             c = childnode
             while c is not None and c is not parentnode:
@@ -394,28 +385,17 @@ class DependencySorter(object):
             # nope, so we have to move the child down from whereever
             # it currently is to a child of the parent
             if c is None:
-                if childnode.parent is not None:
-                    del childnode.parent.children[childnode]
-                    childnode.parent.children.append(parentnode)
-                parentnode.children.append(childnode)
-                childnode.parent = parentnode
-
-        # go through the total list of items.  for those 
-        # that had no dependency tuples, and therefore are not
-        # in the tree, add them as head nodes in a line
-        newhead = None
-        for item in allitems:
-            if not nodes.has_key(item):
-                if newhead is None:
-                    newhead = DependencySorter.Node(item)
-                    if head is not None:
-                        head.parent = newhead
-                        newhead.children.append(head)
-                    head = newhead
+                parentnode.append(childnode)
+        
+        # now we have a collection of subtrees which represent dependencies.
+        # go through the collection root nodes wire them together into one tree        
+        head = None
+        for node in nodes.values():
+            if node.parent is None:
+                if head is not None:
+                    head.append(node)
                 else:
-                    n = DependencySorter.Node(item)
-                    head.children.append(n)
-                    n.parent = head
-        #print str(head)
+                    head = node
+
         return head
             
\ No newline at end of file
diff --git a/test/dependency.py b/test/dependency.py
new file mode 100644 (file)
index 0000000..7e04e7e
--- /dev/null
@@ -0,0 +1,57 @@
+from testbase import PersistTest
+import sqlalchemy.util as util
+import unittest, sys, os
+
+class thingy(object):
+    def __init__(self, name):
+        self.name = name
+    def __repr__(self):
+        return "thingy(%d, %s)" % (id(self), self.name)
+    def __str__(self):
+        return repr(self)
+        
+class DependencySortTest(PersistTest):
+    def testsort(self):
+        rootnode = thingy('root')
+        node2 = thingy('node2')
+        node3 = thingy('node3')
+        node4 = thingy('node4')
+        subnode1 = thingy('subnode1')
+        subnode2 = thingy('subnode2')
+        subnode3 = thingy('subnode3')
+        subnode4 = thingy('subnode4')
+        subsubnode1 = thingy('subsubnode1')
+        tuples = [
+            (subnode3, subsubnode1),
+            (node2, subnode1),
+            (node2, subnode2),
+            (rootnode, node2),
+            (rootnode, node3),
+            (rootnode, node4),
+            (node4, subnode3),
+            (node4, subnode4)
+        ]
+        head = util.DependencySorter(tuples, []).sort()
+        print str(head)
+
+    def testsort2(self):
+        node1 = thingy('node1')
+        node2 = thingy('node2')
+        node3 = thingy('node3')
+        node4 = thingy('node4')
+        node5 = thingy('node5')
+        node6 = thingy('node6')
+        node7 = thingy('node7')
+        tuples = [
+            (node1, node2),
+            (node3, node4),
+            (node5, node6),
+            (node6, node2)
+        ]
+        head = util.DependencySorter(tuples, [node7]).sort()
+        print "\n" + str(head)
+
+
+
+if __name__ == "__main__":
+    unittest.main()