]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a standardized test harness for ORM tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 Jan 2007 22:31:39 +0000 (22:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 Jan 2007 22:31:39 +0000 (22:31 +0000)
- added three-level mapping test.  needed some massaging for postgres

test/orm/inheritance.py
test/orm/inheritance2.py
test/orm/inheritance3.py
test/orm/inheritance4.py
test/orm/inheritance5.py
test/orm/manytomany.py
test/testbase.py

index d01fc5c4f1878738c8d564528b440fe0cb5eeb73..d09e685b3901d251b7efe3fd5a31a8c14b497621 100644 (file)
@@ -14,15 +14,14 @@ class User( Principal ):
 class Group( Principal ):
     pass
 
-class InheritTest(testbase.AssertMixin):
+class InheritTest(testbase.ORMTest):
     """deals with inheritance and many-to-many relationships"""
-    def setUpAll(self):
+    def define_tables(self, metadata):
         global principals
         global users
         global groups
         global user_group_map
-        global metadata
-        metadata = BoundMetaData(testbase.db)
+
         principals = Table(
             'principals',
             metadata,
@@ -57,14 +56,6 @@ class InheritTest(testbase.AssertMixin):
 
             )
 
-        metadata.create_all()
-        
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def setUp(self):
-        clear_mappers()
-        
     def testbasic(self):
         mapper( Principal, principals )
         mapper( 
@@ -87,11 +78,10 @@ class InheritTest(testbase.AssertMixin):
         sess.flush()
         # TODO: put an assertion
         
-class InheritTest2(testbase.AssertMixin):
+class InheritTest2(testbase.ORMTest):
     """deals with inheritance and many-to-many relationships"""
-    def setUpAll(self):
-        global foo, bar, foo_bar, metadata
-        metadata = BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global foo, bar, foo_bar
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
             Column('data', String(20)),
@@ -105,9 +95,6 @@ class InheritTest2(testbase.AssertMixin):
         foo_bar = Table('foo_bar', metadata,
             Column('foo_id', Integer, ForeignKey('foo.id')),
             Column('bar_id', Integer, ForeignKey('bar.bid')))
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
 
     def testbasic(self):
         class Foo(object): 
@@ -147,11 +134,11 @@ class InheritTest2(testbase.AssertMixin):
             {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])},
             )
 
-class InheritTest3(testbase.AssertMixin):
+class InheritTest3(testbase.ORMTest):
     """deals with inheritance and many-to-many relationships"""
-    def setUpAll(self):
-        global foo, bar, blub, bar_foo, blub_bar, blub_foo,metadata
-        metadata = BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global foo, bar, blub, bar_foo, blub_bar, blub_foo
+
         # the 'data' columns are to appease SQLite which cant handle a blank INSERT
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_seq'), primary_key=True),
@@ -176,13 +163,6 @@ class InheritTest3(testbase.AssertMixin):
         blub_foo = Table('blub_foo', metadata,
             Column('blub_id', Integer, ForeignKey('blub.id')),
             Column('foo_id', Integer, ForeignKey('foo.id')))
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def tearDown(self):
-        for table in metadata.table_iterator():
-            table.delete().execute()
             
     def testbasic(self):
         class Foo(object):
@@ -253,11 +233,10 @@ class InheritTest3(testbase.AssertMixin):
         self.echo(x)
         self.assert_(repr(x) == compare)
         
-class InheritTest4(testbase.AssertMixin):
+class InheritTest4(testbase.ORMTest):
     """deals with inheritance and one-to-many relationships"""
-    def setUpAll(self):
-        global foo, bar, blub, metadata
-        metadata = BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global foo, bar, blub
         # the 'data' columns are to appease SQLite which cant handle a blank INSERT
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_seq'), primary_key=True),
@@ -271,13 +250,6 @@ class InheritTest4(testbase.AssertMixin):
             Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
             Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
             Column('data', String(20)))
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def tearDown(self):
-        for table in metadata.table_iterator():
-            table.delete().execute()
 
     def testbasic(self):
         class Foo(object):
@@ -316,12 +288,11 @@ class InheritTest4(testbase.AssertMixin):
         self.assert_(compare == result)
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
-class InheritTest5(testbase.AssertMixin):
+class InheritTest5(testbase.ORMTest):
     """testing that construction of inheriting mappers works regardless of when extra properties
     are added to the superclass mapper"""
-    def setUpAll(self):
-        global content_type, content, product, metadata
-        metadata = BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global content_type, content, product
         content_type = Table('content_type', metadata, 
             Column('id', Integer, primary_key=True)
             )
@@ -332,10 +303,6 @@ class InheritTest5(testbase.AssertMixin):
         product = Table('product', metadata, 
             Column('id', Integer, ForeignKey('content.id'), primary_key=True)
         )
-    def tearDownAll(self):
-        pass
-    def tearDown(self):
-        pass
 
     def testbasic(self):
         class ContentType(object): pass
@@ -366,12 +333,11 @@ class InheritTest5(testbase.AssertMixin):
         p.contenttype = ContentType()
         # TODO: assertion ??
         
-class InheritTest6(testbase.AssertMixin):
+class InheritTest6(testbase.ORMTest):
     """tests eager load/lazy load of child items off inheritance mappers, tests that
     LazyLoader constructs the right query condition."""
-    def setUpAll(self):
-        global foo, bar, bar_foo, metadata
-        metadata=BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global foo, bar, bar_foo
         foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True),
         Column('data', String(30)))
         bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
@@ -381,10 +347,7 @@ class InheritTest6(testbase.AssertMixin):
         Column('bar_id', Integer, ForeignKey('bar.id')),
         Column('foo_id', Integer, ForeignKey('foo.id'))
         )
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-        
+
     def testbasic(self):
         class Foo(object): pass
         class Bar(Foo): pass
@@ -413,11 +376,10 @@ class InheritTest6(testbase.AssertMixin):
         self.assert_(len(q.selectfirst().eager) == 1)
 
 
-class InheritTest7(testbase.AssertMixin):
+class InheritTest7(testbase.ORMTest):
     """test dependency sorting among inheriting mappers"""
-    def setUpAll(self):
-        global users, roles, user_roles, admins, metadata
-        metadata=BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global users, roles, user_roles, admins
         users = Table('users', metadata,
             Column('id', Integer, primary_key=True),
             Column('email', String(128)),
@@ -438,12 +400,6 @@ class InheritTest7(testbase.AssertMixin):
             Column('id', Integer, primary_key=True),
             Column('user_id', Integer, ForeignKey('users.id'))
         )
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-    def tearDown(self):
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
             
     def testone(self):
         class User(object):pass
index 94962aa65fc6240061b4503d566c764946c7eb6a..90652645665060b852413c93b14ab76667cf6537 100644 (file)
@@ -2,12 +2,11 @@ import testbase
 from sqlalchemy import *
 from datetime import datetime
 
-class InheritTest(testbase.AssertMixin):
+class InheritTest(testbase.ORMTest):
     """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
-    def setUpAll(self):
-        global metadata, products_table, specification_table, documents_table
+    def define_tables(self, metadata):
+        global products_table, specification_table, documents_table
         global Product, Detail, Assembly, SpecLine, Document, RasterDocument
-        metadata = BoundMetaData(testbase.db)
 
         products_table = Table('products', metadata,
            Column('product_id', Integer, primary_key=True),
@@ -37,8 +36,6 @@ class InheritTest(testbase.AssertMixin):
             Column('size', Integer, default=0),
             )
             
-        metadata.create_all()
-
         class Product(object):
             def __init__(self, name, mark=''):
                 self.name = name
@@ -77,14 +74,6 @@ class InheritTest(testbase.AssertMixin):
         class RasterDocument(Document): 
             pass
 
-    def tearDown(self):
-        clear_mappers()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
-            
-    def tearDownAll(self):
-        metadata.drop_all()
-
     def testone(self):
         product_mapper = mapper(Product, products_table,
             polymorphic_on=products_table.c.product_type,
index 8eca2fd57404855fe2bd30dda287799b4fa782fd..a9c88ef60c8dabc697091030a4975dbde27dd38b 100644 (file)
@@ -64,13 +64,11 @@ class MagazinePage(Page):
 class ClassifiedPage(MagazinePage):
     pass
 
-class InheritTest(testbase.AssertMixin):
+class InheritTest(testbase.ORMTest):
     """tests a large polymorphic relationship"""
-    def setUpAll(self):
-        global metadata, publication_table, issue_table, location_table, location_name_table, magazine_table, \
+    def define_tables(self, metadata):
+        global publication_table, issue_table, location_table, location_name_table, magazine_table, \
         page_table, magazine_page_table, classified_page_table, page_size_table
-        
-        metadata = BoundMetaData(testbase.db)
 
         zerodefault = {} #{'default':0}
         publication_table = Table('publication', metadata,
@@ -118,8 +116,6 @@ class InheritTest(testbase.AssertMixin):
             Column('name', String(45), default=''),
         )
 
-        metadata.create_all()
-        
         publication_mapper = mapper(Publication, publication_table)
 
         issue_mapper = mapper(Issue, issue_table, properties = {
@@ -163,14 +159,6 @@ class InheritTest(testbase.AssertMixin):
 
         classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c')
 
-    def tearDown(self):
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-        clear_mappers()
-        
     def testone(self):
         session = create_session()
 
index 7a97d06c149cac9fe04030d7623b10cdce6578da..9f4e275ae2df86a12dc80bce9f778b3a21cf57b9 100644 (file)
@@ -1,10 +1,9 @@
 from sqlalchemy import *
 import testbase
 
-class ConcreteTest1(testbase.AssertMixin):
-    def setUpAll(self):
-        global managers_table, engineers_table, metadata
-        metadata = BoundMetaData(testbase.db)
+class ConcreteTest1(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global managers_table, engineers_table
         managers_table = Table('managers', metadata, 
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
@@ -17,10 +16,6 @@ class ConcreteTest1(testbase.AssertMixin):
             Column('engineer_info', String(50)),
         )
 
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-        
     def testbasic(self):
         class Employee(object):
             def __init__(self, name):
index 4a90772b443795a648a15311917325d8233a9e05..ab948c0035a70800ad975e22437cfcfd9f71e95d 100644 (file)
@@ -8,11 +8,10 @@ class AttrSettable(object):
         return self.__class__.__name__ + ' ' + ','.join(["%s=%s" % (k,v) for k, v in self.__dict__.iteritems() if k[0] != '_'])
 
 
-class RelationTest1(testbase.PersistTest):
+class RelationTest1(testbase.ORMTest):
     """test self-referential relationships on polymorphic mappers"""
-    def setUpAll(self):
-        global people, managers, metadata
-        metadata = BoundMetaData(testbase.db)
+    def define_tables(self, metadata):
+        global people, managers
 
         people = Table('people', metadata, 
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
@@ -26,17 +25,6 @@ class RelationTest1(testbase.PersistTest):
            Column('manager_name', String(50))
            )
 
-        metadata.create_all()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def tearDown(self):
-        clear_mappers()
-        people.update().execute(manager_id=None)
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
-
     def testbasic(self):
         class Person(AttrSettable):
             pass
@@ -72,12 +60,10 @@ class RelationTest1(testbase.PersistTest):
         print p, m, p.manager
         assert p.manager is m
             
-class RelationTest2(testbase.AssertMixin):
+class RelationTest2(testbase.ORMTest):
     """test self-referential relationships on polymorphic mappers"""
-    def setUpAll(self):
-        global people, managers, metadata
-        metadata = BoundMetaData(testbase.db)
-
+    def define_tables(self, metadata):
+        global people, managers
         people = Table('people', metadata, 
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
            Column('name', String(50)),
@@ -89,16 +75,6 @@ class RelationTest2(testbase.AssertMixin):
            Column('status', String(30)),
            )
 
-        metadata.create_all()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def tearDown(self):
-        clear_mappers()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
-
     def testrelationonsubclass(self):
         class Person(AttrSettable):
             pass
@@ -130,12 +106,10 @@ class RelationTest2(testbase.AssertMixin):
         print m
         assert m.colleague is p
 
-class RelationTest3(testbase.AssertMixin):
+class RelationTest3(testbase.ORMTest):
     """test self-referential relationships on polymorphic mappers"""
-    def setUpAll(self):
-        global people, managers, metadata
-        metadata = BoundMetaData(testbase.db)
-
+    def define_tables(self, metadata):
+        global people, managers
         people = Table('people', metadata, 
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
            Column('colleague_id', Integer, ForeignKey('people.person_id')),
@@ -147,16 +121,6 @@ class RelationTest3(testbase.AssertMixin):
            Column('status', String(30)),
            )
 
-        metadata.create_all()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def tearDown(self):
-        clear_mappers()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
-
     def testrelationonbaseclass(self):
         class Person(AttrSettable):
             pass
@@ -193,10 +157,9 @@ class RelationTest3(testbase.AssertMixin):
         assert len(p.colleagues) == 1
         assert p.colleagues == [p2]
 
-class RelationTest4(testbase.AssertMixin):
-    def setUpAll(self):
-        global metadata, people, engineers, managers, cars
-        metadata = BoundMetaData(testbase.db)
+class RelationTest4(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global people, engineers, managers, cars
         people = Table('people', metadata, 
            Column('person_id', Integer, primary_key=True),
            Column('name', String(50)))
@@ -212,13 +175,6 @@ class RelationTest4(testbase.AssertMixin):
         cars = Table('cars', metadata, 
            Column('car_id', Integer, primary_key=True),
            Column('owner', Integer, ForeignKey('people.person_id')))
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-    def tearDown(self):
-        clear_mappers()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
     
     def testmanytoonepolymorphic(self):
         """in this test, the polymorphic union is between two subclasses, but does not include the base table by itself
@@ -301,10 +257,9 @@ class RelationTest4(testbase.AssertMixin):
         car1 = session.query(Car).options(eagerload('employee')).get(car1.car_id)
         assert str(car1.employee) == "Engineer E4, status X"
 
-class RelationTest5(testbase.AssertMixin):
-    def setUpAll(self):
-        global metadata, people, engineers, managers, cars
-        metadata = BoundMetaData(testbase.db)
+class RelationTest5(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global people, engineers, managers, cars
         people = Table('people', metadata, 
            Column('person_id', Integer, primary_key=True),
            Column('name', String(50)),
@@ -321,13 +276,6 @@ class RelationTest5(testbase.AssertMixin):
         cars = Table('cars', metadata, 
            Column('car_id', Integer, primary_key=True),
            Column('owner', Integer, ForeignKey('people.person_id')))
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-    def tearDown(self):
-        clear_mappers()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
     
     def testeagerempty(self):
         """an easy one...test parent object with child relation to an inheriting mapper, using eager loads,
@@ -368,6 +316,76 @@ class RelationTest5(testbase.AssertMixin):
         carlist = sess.query(Car).select()
         assert carlist[0].manager is None
         assert carlist[1].manager.person_id == car2.manager.person_id
+
+class MultiLevelTest(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global table_Employee, table_Engineer, table_Manager
+        table_Employee = Table( 'Employee', metadata,
+            Column( 'name', type= String, ),
+            Column( 'id', primary_key= True, type= Integer, ),
+            Column( 'atype', type= String, ),
+        )
+
+        table_Engineer = Table( 'Engineer', metadata,
+            Column( 'machine', type= String, ),
+            Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ),
+        )
+
+        table_Manager = Table( 'Manager', metadata,
+            Column( 'duties', type= String, ),
+            Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ),
+        )
+    def test_threelevels(self):
+        class Employee( object):
+            def set( me, **kargs):
+                for k,v in kargs.iteritems(): setattr( me, k, v)
+                return me
+            def __str__(me): return str(me.__class__.__name__)+':'+str(me.name)
+            __repr__ = __str__
+        class Engineer( Employee): pass
+        class Manager( Engineer): pass
+        pu_Employee = polymorphic_union( {
+                    'Manager':  table_Employee.join( table_Engineer).join( table_Manager),
+                    'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]),
+                    'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'),
+                }, None, 'pu_employee', )
+        
+        mapper_Employee = mapper( Employee, table_Employee,
+                    polymorphic_identity= 'Employee',
+                    polymorphic_on= pu_Employee.c.atype,
+                    select_table= pu_Employee,
+                )
+
+        pu_Engineer = polymorphic_union( {
+                    'Manager':  table_Employee.join( table_Engineer).join( table_Manager),
+                    'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]),
+                }, None, 'pu_engineer', )
+        mapper_Engineer = mapper( Engineer, table_Engineer,
+                    inherit_condition= table_Engineer.c.id == table_Employee.c.id,
+                    inherits= mapper_Employee,
+                    polymorphic_identity= 'Engineer',
+                    polymorphic_on= pu_Engineer.c.atype,
+                    select_table= pu_Engineer,
+                )
+
+        mapper_Manager = mapper( Manager, table_Manager,
+                    inherit_condition= table_Manager.c.id == table_Engineer.c.id,
+                    inherits= mapper_Engineer,
+                    polymorphic_identity= 'Manager',
+                )
+
+        a = Employee().set( name= 'one')
+        b = Engineer().set( egn= 'two', machine= 'any')
+        c = Manager().set( name= 'head', machine= 'fast', duties= 'many')
+
+        session = create_session()
+        session.save(a)
+        session.save(b)
+        session.save(c)
+        session.flush()
+        assert set(session.query(Employee).select()) == set([a,b,c])
+        assert set(session.query( Engineer).select()) == set([b,c])
+        assert session.query( Manager).select() == [c]
         
 if __name__ == "__main__":    
     testbase.main()
index 443e93ab24378645c9984744aad2894f8ee5868b..e8076b9e07749cccbadd6c3a4c21cdd289c05520 100644 (file)
@@ -25,10 +25,8 @@ class Transition(object):
     def __repr__(self):
         return object.__repr__(self)+ " " + repr(self.inputs) + " " + repr(self.outputs)
         
-class M2MTest(testbase.AssertMixin):
-    def setUpAll(self):
-        global metadata
-        metadata = testbase.metadata
+class M2MTest(testbase.ORMTest):
+    def define_tables(self, metadata):
         global place
         place = Table('place', metadata,
             Column('place_id', Integer, Sequence('pid_seq', optional=True), primary_key=True),
@@ -67,22 +65,6 @@ class M2MTest(testbase.AssertMixin):
             Column('pl1_id', Integer, ForeignKey('place.place_id')),
             Column('pl2_id', Integer, ForeignKey('place.place_id')),
             )
-        metadata.create_all()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-        clear_mappers()
-        #testbase.db.tables.clear()
-        
-    def setUp(self):
-        clear_mappers()
-
-    def tearDown(self):
-        place_place.delete().execute()
-        place_input.delete().execute()
-        place_output.delete().execute()
-        transition.delete().execute()
-        place.delete().execute()
 
     def testcircular(self):
         """tests a many-to-many relationship from a table to itself."""
@@ -194,9 +176,8 @@ class M2MTest(testbase.AssertMixin):
         self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
         self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
 
-class M2MTest2(testbase.AssertMixin):        
-    def setUpAll(self):
-        metadata = testbase.metadata
+class M2MTest2(testbase.ORMTest):
+    def define_tables(self, metadata):
         global studentTbl
         studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True))
         global courseTbl
@@ -205,19 +186,6 @@ class M2MTest2(testbase.AssertMixin):
         enrolTbl = Table('enrol', metadata,
             Column('student_id', String(20), ForeignKey('student.name'),primary_key=True),
             Column('course_id', String(20), ForeignKey('course.name'), primary_key=True))
-        metadata.create_all()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-        clear_mappers()
-        
-    def setUp(self):
-        clear_mappers()
-
-    def tearDown(self):
-        enrolTbl.delete().execute()
-        courseTbl.delete().execute()
-        studentTbl.delete().execute()
 
     def testcircular(self): 
         class Student(object):
@@ -249,9 +217,8 @@ class M2MTest2(testbase.AssertMixin):
         del s.courses[1]
         self.assert_(len(s.courses) == 2)
         
-class M2MTest3(testbase.AssertMixin):    
-    def setUpAll(self):
-        metadata = testbase.metadata
+class M2MTest3(testbase.ORMTest):
+    def define_tables(self, metadata):
         global c, c2a1, c2a2, b, a
         c = Table('c', metadata, 
             Column('c1', Integer, primary_key = True),
@@ -278,15 +245,6 @@ class M2MTest3(testbase.AssertMixin):
             Column('a1', Integer, ForeignKey('a.a1')),
             Column('b2', Boolean)
         )
-        metadata.create_all()
-        
-    def tearDownAll(self):
-        b.drop()
-        c2a2.drop()
-        c2a1.drop()
-        a.drop()
-        c.drop()
-        clear_mappers()
         
     def testbasic(self):
         class C(object):pass
index e96b30dcbe0dde3dcd48d0764e97ab6901aa99d4..7d981509d1c2acb66c998eb52af0297cf2e5801f 100644 (file)
@@ -10,7 +10,8 @@ import sqlalchemy.pool as pool
 import re
 import sqlalchemy
 import optparse
-
+from sqlalchemy.schema import BoundMetaData
+from sqlalchemy.orm import clear_mappers
 
 db = None
 metadata = None
@@ -203,6 +204,24 @@ class AssertMixin(PersistTest):
         finally:
             self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count))
 
+class ORMTest(AssertMixin):
+    def setUpAll(self):
+        global metadata
+        metadata = BoundMetaData(db)
+        self.define_tables(metadata)
+        metadata.create_all()
+    def define_tables(self, metadata):
+        raise NotImplementedError()
+    def get_metadata(self):
+        return metadata
+    def tearDownAll(self):
+        metadata.drop_all()
+    def tearDown(self):
+        clear_mappers()
+        for t in metadata.table_iterator(reverse=True):
+            t.delete().execute().close()
+
+
 class EngineAssert(proxy.BaseProxyEngine):
     """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
     def __init__(self, engine):