]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
had to take out the "treeification" of the dependency sort as it doenst really work...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 Apr 2006 18:11:54 +0000 (18:11 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 Apr 2006 18:11:54 +0000 (18:11 +0000)
CHANGES
lib/sqlalchemy/mapping/topological.py
lib/sqlalchemy/mapping/unitofwork.py
lib/sqlalchemy/sql.py
test/dependency.py
test/relationships.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 60b81cf7875b45c958e6325c6776f73e923f1325..99eb0bc10f6c3651622e7f999a7a2bd53554d173 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,6 @@
+next
+- some fixes to topological sort algorithm
+
 0.1.6
 - support for MS-SQL added courtesy Rick Morrison, Runar Petursson
 - the latest SQLSoup from J. Ellis
index 95807bf5f269f2298e9172d7521170371a41d004..779faab2d4b5e9bae93e1365e4c9ec2c2b7420f7 100644 (file)
@@ -141,15 +141,16 @@ class QueueDependencySorter(object):
         #print repr(output)
         head = None
         node = None
+        # put the sorted list into a "tree".  this is not much of a 
+        # "tree" at the moment as its more of a linked list.  it would be nice
+        # to group non-dependent nodes into sibling nodes, which allows better batching
+        # of SQL statements, but this algorithm has proved tricky
         for o in output:
             if head is None:
                 head = o
-                node = o
             else:
-                for x in node.children:
-                    if x.dependencies.has_key(o):
-                        node = x
                 node.children.append(o)
+            node = o
         return head
 
     def _add_edge(self, edges, edge):
index 3ef1d96aec85be3c7578dd949ec3f3f33aeb1437..873bed54832053d4c2cffa3162fb30b1b00a3d67 100644 (file)
@@ -422,7 +422,6 @@ class UOWTransaction(object):
         mappers = util.HashSet()
         for task in self.tasks.values():
             mappers.append(task.mapper)
-    
         head = DependencySorter(self.dependencies, mappers).sort(allow_all_cycles=True)
         #print str(head)
         task = sort_hier(head)
index 7129781a70d888ad0cade060e25e8a4bda7c7529..b18b0916e91fd42a40a9039a4a8d35b652ff4c58 100644 (file)
@@ -917,7 +917,6 @@ class Join(FromClause):
     def __init__(self, left, right, onclause=None, isouter = False):
         self.left = left
         self.right = right
-        
         # TODO: if no onclause, do NATURAL JOIN
         if onclause is None:
             self.onclause = self._match_primaries(left, right)
@@ -925,6 +924,8 @@ class Join(FromClause):
             self.onclause = onclause
         self.isouter = isouter
 
+    name = property(lambda self: "Join on %s, %s" % (self.left.name, self.right.name))
+
     def _locate_oid_column(self):
         return self.left.oid_column
     
index 5fd3df2fd4f2710571921b18cbdd183356aff9f0..0aede4c7e5555f901c7b598d4d3965d3927dfd30 100644 (file)
@@ -17,6 +17,26 @@ class thingy(object):
         return repr(self)
         
 class DependencySortTest(PersistTest):
+    
+    def _assert_sort(self, tuples, allnodes, **kwargs):
+
+        head = DependencySorter(tuples, allnodes).sort(**kwargs)
+
+        print "\n" + str(head)
+        def findnode(t, n, parent=False):
+            if n.item is t[0]:
+                parent=True
+            elif n.item is t[1]:
+                if not parent and t[0] not in [c.item for c in n.cycles]:
+                    self.assert_(False, "Node " + str(t[1]) + " not a child of " +str(t[0]))
+                else:
+                    return
+            for c in n.children:
+                findnode(t, c, parent)
+            
+        for t in tuples:
+            findnode(t, head)
+            
     def testsort(self):
         rootnode = thingy('root')
         node2 = thingy('node2')
@@ -27,6 +47,7 @@ class DependencySortTest(PersistTest):
         subnode3 = thingy('subnode3')
         subnode4 = thingy('subnode4')
         subsubnode1 = thingy('subsubnode1')
+        allnodes = [rootnode, node2,node3,node4,subnode1,subnode2,subnode3,subnode4,subsubnode1]
         tuples = [
             (subnode3, subsubnode1),
             (node2, subnode1),
@@ -37,8 +58,8 @@ class DependencySortTest(PersistTest):
             (node4, subnode3),
             (node4, subnode4)
         ]
-        head = DependencySorter(tuples, []).sort()
-        print "\n" + str(head)
+
+        self._assert_sort(tuples, allnodes)
 
     def testsort2(self):
         node1 = thingy('node1')
@@ -55,8 +76,7 @@ class DependencySortTest(PersistTest):
             (node5, node6),
             (node6, node2)
         ]
-        head = DependencySorter(tuples, [node7]).sort()
-        print "\n" + str(head)
+        self._assert_sort(tuples, [node1,node2,node3,node4,node5,node6,node7])
 
     def testsort3(self):
         ['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords']
@@ -68,15 +88,10 @@ class DependencySortTest(PersistTest):
             (node3, node2),
             (node1,node3)
         ]
-        head1 = DependencySorter(tuples, [node1, node2, node3]).sort()
-        head2 = DependencySorter(tuples, [node3, node1, node2]).sort()
-        head3 = DependencySorter(tuples, [node3, node2, node1]).sort()
+        self._assert_sort(tuples, [node1, node2, node3])
+        self._assert_sort(tuples, [node3, node1, node2])
+        self._assert_sort(tuples, [node3, node2, node1])
         
-        # TODO: figure out a "node == node2" function
-        #self.assert_(str(head1) == str(head2) == str(head3))
-        print "\n" + str(head1)
-        print "\n" + str(head2)
-        print "\n" + str(head3)
 
     def testsort4(self):
         node1 = thingy('keywords')
@@ -89,8 +104,7 @@ class DependencySortTest(PersistTest):
             (node1, node3),
             (node3, node2)
         ]
-        head = DependencySorter(tuples, []).sort()
-        print "\n" + str(head)
+        self._assert_sort(tuples, [node1,node2,node3,node4])
 
     def testsort5(self):
         # this one, depenending on the weather, 
@@ -117,8 +131,21 @@ class DependencySortTest(PersistTest):
             node3,
             node4
         ]
-        head = DependencySorter(tuples, allitems).sort()
-        print "\n" + str(head)
+        self._assert_sort(tuples, allitems)
+
+    def testsort6(self):
+        #('tbl_c', 'tbl_d'), ('tbl_a', 'tbl_c'), ('tbl_b', 'tbl_d')
+        nodea = thingy('tbl_a')
+        nodeb = thingy('tbl_b')
+        nodec = thingy('tbl_c')
+        noded = thingy('tbl_d')
+        tuples = [
+            (nodec, noded),
+            (nodea, nodec),
+            (nodeb, noded)
+        ]
+        allitems = [nodea,nodeb,nodec,noded]
+        self._assert_sort(tuples, allitems)
 
     def testcircular(self):
         node1 = thingy('node1')
@@ -134,8 +161,7 @@ class DependencySortTest(PersistTest):
             (node3, node1),
             (node4, node1)
         ]
-        head = DependencySorter(tuples, []).sort(allow_all_cycles=True)
-        print "\n" + str(head)
+        self._assert_sort(tuples, [node1,node2,node3,node4,node5], allow_all_cycles=True)
         
 
 if __name__ == "__main__":
diff --git a/test/relationships.py b/test/relationships.py
new file mode 100644 (file)
index 0000000..36f5fe3
--- /dev/null
@@ -0,0 +1,99 @@
+"""Test complex relationships"""
+
+import testbase
+import unittest, sys, datetime
+
+db = testbase.db
+#db.echo_uow=True
+
+from sqlalchemy import *
+
+
+class RelationTest(testbase.PersistTest):
+    """this is essentially an extension of the "dependency.py" topological sort test.  this exposes
+    a particular issue that doesnt always occur with the straight dependency tests, due to the nature
+    of the sort being different based on random conditions"""
+    def setUpAll(self):
+        testbase.db.tables.clear()
+        global tbl_a
+        global tbl_b
+        global tbl_c
+        global tbl_d
+        tbl_a = Table("tbl_a", db,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )
+        tbl_b = Table("tbl_b", db,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )
+        tbl_c = Table("tbl_c", db,
+            Column("id", Integer, primary_key=True),
+            Column("tbl_a_id", Integer, ForeignKey("tbl_a.id"), nullable=False),
+            Column("name", String),
+        )
+        tbl_d = Table("tbl_d", db,
+            Column("id", Integer, primary_key=True),
+            Column("tbl_c_id", Integer, ForeignKey("tbl_c.id"), nullable=False),
+            Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
+            Column("name", String),
+        )
+    def setUp(self):
+        tbl_a.create()
+        tbl_b.create()
+        tbl_c.create()
+        tbl_d.create()
+
+        objectstore.clear()
+        clear_mappers()
+
+        class A(object):
+            pass
+        class B(object):
+            pass
+        class C(object):
+            pass
+        class D(object):
+            pass
+
+        D.mapper = mapper(D, tbl_d)
+        C.mapper = mapper(C, tbl_c, properties=dict(
+            d_rows=relation(D, private=True, backref="c_row"),
+        ))
+        B.mapper = mapper(B, tbl_b)
+        A.mapper = mapper(A, tbl_a, properties=dict(
+            c_rows=relation(C, private=True, backref="a_row"),
+        ))
+        D.mapper.add_property("b_row", relation(B))
+
+        global a
+        global c
+        a = A(); a.name = "a1"
+        b = B(); b.name = "b1"
+        c = C(); c.name = "c1"; c.a_row = a
+        # we must have more than one d row or it won't fail
+        d = D(); d.name = "d1"; d.b_row = b; d.c_row = c
+        d = D(); d.name = "d2"; d.b_row = b; d.c_row = c
+        d = D(); d.name = "d3"; d.b_row = b; d.c_row = c
+
+    def tearDown(self):
+        tbl_d.drop()
+        tbl_c.drop()
+        tbl_b.drop()
+        tbl_a.drop()
+    
+    def testDeleteRootTable(self):
+        session = objectstore.get_session()
+        session.commit()
+        session.delete(a) # works as expected
+        session.commit()
+
+    def testDeleteMiddleTable(self):
+        session = objectstore.get_session()
+        session.commit()
+        session.delete(c) # fails
+        session.commit()
+        
+        
+if __name__ == "__main__":
+    testbase.main()