]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
polymorphic linked list test, tests polymorphic inheritance with circular refs
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 May 2006 22:15:16 +0000 (22:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 May 2006 22:15:16 +0000 (22:15 +0000)
test/alltests.py
test/poly_linked_list.py [new file with mode: 0644]
test/testbase.py

index 99573fc1d9ff9bd8440bab2ba3e7fd695050d683..a0ee689fc3d1343e9fca2a943e9afc93b12e57f3 100644 (file)
@@ -44,6 +44,7 @@ def suite():
         
         # cyclical ORM persistence
         'cycles',
+        'poly_linked_list',
         
         # more select/persistence, backrefs
         'entity',
diff --git a/test/poly_linked_list.py b/test/poly_linked_list.py
new file mode 100644 (file)
index 0000000..0dfeb31
--- /dev/null
@@ -0,0 +1,176 @@
+from sqlalchemy import *
+import testbase
+
+class PolymorphicCircularTest(testbase.PersistTest):
+    def setUpAll(self):
+        global metadata
+        global Table1, Table1B, Table2, Table3,  Data
+        metadata = BoundMetaData(testbase.db)
+
+        table1 = Table('table1', metadata,
+                       Column('id', Integer, primary_key=True),
+                       Column('related_id', Integer, ForeignKey('table1.id'), nullable=True),
+                       Column('type', String(30)),
+                       Column('name', String(30))
+                       )
+
+        table2 = Table('table2', metadata,
+                       Column('id', Integer, ForeignKey('table1.id'), primary_key=True),
+                       )
+
+        table3 = Table('table3', metadata,
+                      Column('id', Integer, ForeignKey('table1.id'), primary_key=True),
+                      )
+
+        data = Table('data', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('node_id', Integer, ForeignKey('table1.id')),
+            Column('data', String(30))
+            )
+            
+        metadata.create_all()
+
+        join = polymorphic_union(
+            {
+            'table3' : table1.join(table3),
+            'table2' : table1.join(table2),
+            'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
+            }, None, 'pjoin')
+
+        # still with us so far ?
+        
+        class Table1(object):
+            def __init__(self, name, data=None):
+                self.name = name
+                if data is not None:
+                    self.data = data
+            def __repr__(self):
+                return "%s(%d, %s, %s)" % (self.__class__.__name__, self.id, repr(str(self.name)), repr(self.data))
+
+        class Table1B(Table1):
+            pass
+            
+        class Table2(Table1):
+            pass
+
+        class Table3(Table1):
+            pass
+    
+        class Data(object):
+            def __init__(self, data):
+                self.data = data
+            def __repr__(self):
+                return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data)))
+            
+        # currently, all of these "eager" relationships degrade to lazy relationships
+        # due to the polymorphic load.
+        table1_mapper = mapper(Table1, table1,
+                               select_table=join,
+                               polymorphic_on=join.c.type,
+                               polymorphic_identity='table1',
+                               properties={
+                                'next': relation(Table1, 
+                                    backref=backref('prev', primaryjoin=join.c.id==join.c.related_id, foreignkey=join.c.id, uselist=False), 
+                                    uselist=False, lazy=False, primaryjoin=join.c.id==join.c.related_id),
+                                'data':relation(mapper(Data, data), lazy=False)
+                                }
+                        )
+
+        table1b_mapper = mapper(Table1B, inherits=table1_mapper, polymorphic_identity='table1b')
+
+        table2_mapper = mapper(Table2, table2,
+                               inherits=table1_mapper,
+                               polymorphic_identity='table2')
+
+        table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3')
+    def tearDown(self):
+        for t in metadata.table_iterator(reverse=True):
+            t.delete().execute()
+            
+    def tearDownAll(self):
+        clear_mappers()
+        metadata.drop_all()
+
+    def testone(self):
+        self.do_testlist([Table1, Table2, Table1, Table2])
+
+    def testtwo(self):
+        self.do_testlist([Table3])
+        
+    def testthree(self):
+        self.do_testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1])
+
+    def testfour(self):
+        self.do_testlist([
+                Table2('t2', [Data('data1'), Data('data2')]), 
+                Table1('t1', []),
+                Table3('t3', [Data('data3')]),
+                Table1B('t1b', [Data('data4'), Data('data5')])
+                ])
+        
+    def do_testlist(self, classes):
+        sess = create_session(echo_uow=False)
+
+        # create objects in a linked list
+        count = 1
+        obj = None
+        for c in classes:
+            if isinstance(c, type):
+                newobj = c('item %d' % count)
+                count += 1
+            else:
+                newobj = c
+            if obj is not None:
+                obj.next = newobj
+            else:
+                t = newobj
+            obj = newobj
+
+        # save to DB
+        sess.save(t)
+        sess.flush()
+        
+        # string version of the saved list
+        assertlist = []
+        node = t
+        while (node):
+            assertlist.append(node)
+            n = node.next
+            if n is not None:
+                assert n.prev is node
+            node = n
+        original = repr(assertlist)
+
+
+        # clear and query forwards
+        sess.clear()
+        node = sess.query(Table1).selectfirst(Table1.c.id==t.id)
+        assertlist = []
+        while (node):
+            assertlist.append(node)
+            n = node.next
+            if n is not None:
+                assert n.prev is node
+            node = n
+        forwards = repr(assertlist)
+
+        # clear and query backwards
+        sess.clear()
+        node = sess.query(Table1).selectfirst(Table1.c.id==obj.id)
+        assertlist = []
+        while (node):
+            assertlist.insert(0, node)
+            n = node.prev
+            if n is not None:
+                assert n.next is node
+            node = n
+        backwards = repr(assertlist)
+        
+        # everything should match !
+        print original
+        print backwards
+        print forwards
+        assert original == forwards == backwards
+
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file
index 57fcb58c30e2add6dc30251c3b289fe210f203e6..bec34487f946b6dc1753c265e8503496db76fe4c 100644 (file)
@@ -39,6 +39,7 @@ def parse_argv():
     parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (sqlite, sqlite_file, postgres, mysql, oracle, oracle8, mssql)")
     parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool")
     parser.add_option("--verbose", action="store_true", dest="verbose", help="full debug echoing")
+    parser.add_option("--noecho", action="store_true", dest="noecho", help="Disable SQL statement echoing")
     parser.add_option("--quiet", action="store_true", dest="quiet", help="be totally quiet")
     parser.add_option("--nothreadlocal", action="store_true", dest="nothreadlocal", help="dont use thread-local mod")
     parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to SA default)")
@@ -86,9 +87,9 @@ def parse_argv():
     if options.enginestrategy is not None:
         opts['strategy'] = options.enginestrategy    
     if options.mockpool:
-        db = engine.create_engine(db_uri, echo=True, default_ordering=True, poolclass=MockPool, **opts)
+        db = engine.create_engine(db_uri, echo=(not options.noecho), default_ordering=True, poolclass=MockPool, **opts)
     else:
-        db = engine.create_engine(db_uri, echo=True, default_ordering=True, **opts)
+        db = engine.create_engine(db_uri, echo=(not options.noecho), default_ordering=True, **opts)
     db = EngineAssert(db)
     metadata = sqlalchemy.BoundMetaData(db)