]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- cleanup, converted unitofwork.py to standard fixtures
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Aug 2007 03:49:32 +0000 (03:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Aug 2007 03:49:32 +0000 (03:49 +0000)
test/orm/assorted_eager.py
test/orm/fixtures.py
test/orm/unitofwork.py
test/testlib/testing.py

index ce17e8dfd333da9949bf6e4a8e1c78cf2bfe471a..fb2707d4dce75925f230e127d1d064181f280945 100644 (file)
@@ -13,7 +13,10 @@ class EagerTest(AssertMixin):
         dbmeta = MetaData(testbase.db)
         
         # determine a literal value for "false" based on the dialect
-        false = Boolean().dialect_impl(testbase.db.dialect).bind_processor(testbase.db.dialect)(False)
+        false = False
+        bp = Boolean().dialect_impl(testbase.db.dialect).bind_processor(testbase.db.dialect)
+        if bp:
+            false = bp(false)
         
         owners = Table ( 'owners', dbmeta ,
                Column ( 'id', Integer, primary_key=True, nullable=False ),
index 89df774645b3554e036b68ac4f732c2f2b392885..d3d95a6b53ce6be7180de21ef88ce527d1a7f66f 100644 (file)
@@ -168,9 +168,8 @@ def install_fixture_data():
 
 class FixtureTest(ORMTest):
     def define_tables(self, meta):
-        # a slight dirty trick here. 
-        meta.tables = metadata.tables
-        metadata.connect(meta.bind)
+        pass
+FixtureTest.metadata = metadata
     
 class Fixtures(object):
     @property
index c7a5c055a03308c811bc51f5b9a01aa4c4607567..0a7df2a36ff75824422444246a192e14be4473f0 100644 (file)
@@ -2,37 +2,24 @@ import testbase
 import pickleable
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy.orm.mapper import global_extensions
-from sqlalchemy.orm import util as ormutil
 from testlib import *
 from testlib.tables import *
 from testlib import tables
 
 """tests unitofwork operations"""
 
-class UnitOfWorkTest(AssertMixin):
-    def setUpAll(self):
-        global Session, mapper
-        Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
-        mapper = Session.mapper
-    def tearDownAll(self):
-        global_extensions[:] = []
-    def tearDown(self):
-        Session.close_all()
-        clear_mappers()
+Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+mapper = Session.mapper
 
-class HistoryTest(UnitOfWorkTest):
-    def setUpAll(self):
-        tables.metadata.bind = testbase.db
-        UnitOfWorkTest.setUpAll(self)
-        users.create()
-        addresses.create()
-    def tearDownAll(self):
-        addresses.drop()
-        users.drop()
-        UnitOfWorkTest.tearDownAll(self)
-        
-    def testbackref(self):
+class UnitOfWorkTest(object):
+    pass
+    
+class HistoryTest(ORMTest):
+    metadata = tables.metadata
+    def define_tables(self, metadata):
+        pass
+        
+    def test_backref(self):
         s = Session()
         class User(object):pass
         class Address(object):pass
@@ -53,25 +40,16 @@ class HistoryTest(UnitOfWorkTest):
         u = s.query(m).select()[0]
         print u.addresses[0].user
             
-class VersioningTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        Session.close()
+class VersioningTest(ORMTest):
+    def define_tables(self, metadata):
         global version_table
-        version_table = Table('version_test', MetaData(testbase.db),
+        version_table = Table('version_test', metadata,
         Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
         Column('version_id', Integer, nullable=False),
         Column('value', String(40), nullable=False)
         )
-        version_table.create()
-    def tearDownAll(self):
-        UnitOfWorkTest.tearDownAll(self)
-        version_table.drop()
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        version_table.delete().execute()
     
-    def testbasic(self):
+    def test_basic(self):
         s = Session(scope=None)
         class Foo(object):pass
         mapper(Foo, version_table, version_id_col=version_table.c.version_id)
@@ -118,7 +96,7 @@ class VersioningTest(UnitOfWorkTest):
         if testbase.db.dialect.supports_sane_rowcount():
             assert success
 
-    def testversioncheck(self):
+    def test_versioncheck(self):
         """test that query.with_lockmode performs a 'version check' on an already loaded instance"""
         s1 = Session(scope=None)
         class Foo(object):pass
@@ -144,7 +122,7 @@ class VersioningTest(UnitOfWorkTest):
         s1.close()
         s1.query(Foo).with_lockmode('read').get(f1s1.id)
         
-    def testnoversioncheck(self):
+    def test_noversioncheck(self):
         """test that query.with_lockmode works OK when the mapper has no version id col"""
         s1 = Session()
         class Foo(object):pass
@@ -157,30 +135,17 @@ class VersioningTest(UnitOfWorkTest):
         assert f1s2.id == f1s1.id
         assert f1s2.value == f1s1.value
         
-class UnicodeTest(UnitOfWorkTest):
-    def setUpAll(self):
-        global metadata, uni_table, uni_table2
-        metadata = MetaData(testbase.db)
+class UnicodeTest(ORMTest):
+    def define_tables(self, metadata):
+        global uni_table, uni_table2
         uni_table = Table('uni_test', metadata,
             Column('id',  Integer, Sequence("uni_test_id_seq", optional=True), primary_key=True),
             Column('txt', Unicode(50), unique=True))
         uni_table2 = Table('uni2', metadata,
             Column('id',  Integer, Sequence("uni2_test_id_seq", optional=True), primary_key=True),
             Column('txt', Unicode(50), ForeignKey(uni_table.c.txt)))
-        metadata.create_all()
-        UnitOfWorkTest.setUpAll(self)
-
-    def tearDownAll(self):
-        UnitOfWorkTest.tearDownAll(self)
-        metadata.drop_all()
 
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        clear_mappers()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
-            
-    def testbasic(self):
+    def test_basic(self):
         class Test(object):
             def __init__(self, id, txt):
                 self.id = id
@@ -192,7 +157,8 @@ class UnicodeTest(UnitOfWorkTest):
         self.assert_(t1.txt == txt)
         Session.commit()
         self.assert_(t1.txt == txt)
-    def testrelation(self):
+        
+    def test_relation(self):
         class Test(object):
             def __init__(self, txt):
                 self.txt = txt
@@ -212,21 +178,15 @@ class UnicodeTest(UnitOfWorkTest):
         t1 = Session.query(Test).get_by(id=t1.id)
         assert len(t1.t2s) == 2
 
-class MutableTypesTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        global metadata, table
-        metadata = MetaData(testbase.db)
+class MutableTypesTest(ORMTest):
+    def define_tables(self, metadata):
+        global table
         table = Table('mutabletest', metadata,
             Column('id', Integer, Sequence('mutableidseq', optional=True), primary_key=True),
             Column('data', PickleType),
             Column('value', Unicode(30)))
-        table.create()
-    def tearDownAll(self):
-        table.drop()
-        UnitOfWorkTest.tearDownAll(self)
 
-    def testbasic(self):
+    def test_basic(self):
         """test that types marked as MutableType get changes detected on them"""
         class Foo(object):pass
         mapper(Foo, table)
@@ -244,7 +204,7 @@ class MutableTypesTest(UnitOfWorkTest):
         assert f3.data != f1.data
         assert f3.data == pickleable.Bar(4, 19)
 
-    def testmutablechanges(self):
+    def test_mutablechanges(self):
         """test that mutable changes are detected or not detected correctly"""
         class Foo(object):pass
         mapper(Foo, table)
@@ -272,7 +232,7 @@ class MutableTypesTest(UnitOfWorkTest):
         ])
         
         
-    def testnocomparison(self):
+    def test_nocomparison(self):
         """test that types marked as MutableType get changes detected on them when the type has no __eq__ method"""
         class Foo(object):pass
         mapper(Foo, table)
@@ -306,7 +266,7 @@ class MutableTypesTest(UnitOfWorkTest):
             Session.commit()
         self.assert_sql_count(testbase.db, go, 0)
         
-    def testunicode(self):
+    def test_unicode(self):
         """test that two equivalent unicode values dont get flagged as changed.
         
         apparently two equal unicode objects dont compare via "is" in all cases, so this
@@ -324,11 +284,10 @@ class MutableTypesTest(UnitOfWorkTest):
         self.assert_sql_count(testbase.db, go, 0)
         
         
-class PKTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        global table, table2, table3, metadata
-        metadata = MetaData(testbase.db)
+class PKTest(ORMTest):
+    def define_tables(self, metadata):
+        global table, table2, table3
+
         table = Table(
             'multipk', metadata, 
             Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True),
@@ -348,16 +307,11 @@ class PKTest(UnitOfWorkTest):
             Column('date_assigned', Date, key='assigned', primary_key=True),
             Column('data', String(30), )
             )
-        metadata.create_all()
 
-    def tearDownAll(self):
-        metadata.drop_all()
-        UnitOfWorkTest.tearDownAll(self)
-        
-    # not support on sqlite since sqlite's auto-pk generation only works with
+    # not supported on sqlite since sqlite's auto-pk generation only works with
     # single column primary keys    
     @testing.unsupported('sqlite')
-    def testprimarykey(self):
+    def test_primarykey(self):
         class Entry(object):
             pass
         Entry.mapper = mapper(Entry, table)
@@ -371,7 +325,7 @@ class PKTest(UnitOfWorkTest):
         self.assert_(e is not e2 and e._instance_key == e2._instance_key)
         
     # this one works with sqlite since we are manually setting up pk values
-    def testmanualpk(self):
+    def test_manualpk(self):
         class Entry(object):
             pass
         Entry.mapper = mapper(Entry, table2)
@@ -381,7 +335,7 @@ class PKTest(UnitOfWorkTest):
         e.data = 'im the data'
         Session.commit()
         
-    def testkeypks(self):
+    def test_keypks(self):
         import datetime
         class Entity(object):
             pass
@@ -393,7 +347,7 @@ class PKTest(UnitOfWorkTest):
         e.data = 'some more data'
         Session.commit()
 
-    def testpksimmutable(self):
+    def test_pksimmutable(self):
         class Entry(object):
             pass
         mapper(Entry, table)
@@ -411,13 +365,13 @@ class PKTest(UnitOfWorkTest):
             assert str(fe) == "Can't change the identity of instance Entry@%s in session (existing identity: (%s, (5, 5), None); new identity: (%s, (5, 6), None))" % (hex(id(e)), repr(e.__class__), repr(e.__class__))
             
             
-class ForeignPKTest(UnitOfWorkTest):
+class ForeignPKTest(ORMTest):
     """tests mapper detection of the relationship direction when parent/child tables are joined on their
     primary keys"""
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        global metadata, people, peoplesites
-        metadata = MetaData(testbase.db)
+    
+    def define_tables(self, metadata):
+        global people, peoplesites
+
         people = Table("people", metadata,
            Column('person', String(10), primary_key=True),
            Column('firstname', String(10)),
@@ -429,11 +383,8 @@ class ForeignPKTest(UnitOfWorkTest):
         primary_key=True),
             Column('site', String(10)),
         )
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-        UnitOfWorkTest.tearDownAll(self)
-    def testbasic(self):
+        
+    def test_basic(self):
         class PersonSite(object):pass
         class Person(object):pass
         m1 = mapper(PersonSite, peoplesites)
@@ -454,24 +405,13 @@ class ForeignPKTest(UnitOfWorkTest):
         Session.commit()
         assert people.count(people.c.person=='im the key').scalar() == peoplesites.count(peoplesites.c.person=='im the key').scalar() == 1
 
-class ClauseAttributesTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        global metadata, users_table
-        metadata = MetaData(testbase.db)
+class ClauseAttributesTest(ORMTest):
+    def define_tables(self, metadata):
+        global users_table
         users_table = Table('users', metadata,
             Column('id', Integer, Sequence('users_id_seq', optional=True), primary_key=True),
             Column('name', String(30)),
             Column('counter', Integer, default=1))
-        metadata.create_all()
-    
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        users_table.delete().execute()
-        
-    def tearDownAll(self):
-        metadata.drop_all()
-        UnitOfWorkTest.tearDownAll(self)
     
     def test_update(self):
         class User(object):
@@ -523,11 +463,10 @@ class ClauseAttributesTest(UnitOfWorkTest):
         
 
         
-class PassiveDeletesTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        global metadata, mytable,myothertable
-        metadata = MetaData(testbase.db)
+class PassiveDeletesTest(ORMTest):
+    def define_tables(self, metadata):
+        global mytable,myothertable
+
         mytable = Table('mytable', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)),
@@ -542,13 +481,8 @@ class PassiveDeletesTest(UnitOfWorkTest):
             test_needs_fk=True,
             )
 
-        metadata.create_all()
-    def tearDownAll(self):
-        metadata.drop_all()
-        UnitOfWorkTest.tearDownAll(self)
-
     @testing.unsupported('sqlite')
-    def testbasic(self):
+    def test_basic(self):
         class MyClass(object):
             pass
         class MyOtherClass(object):
@@ -576,14 +510,12 @@ class PassiveDeletesTest(UnitOfWorkTest):
         assert mytable.count().scalar() == 0
         assert myothertable.count().scalar() == 0
         
-
-        
-class DefaultTest(UnitOfWorkTest):
+class DefaultTest(ORMTest):
     """tests that when saving objects whose table contains DefaultGenerators, either python-side, preexec or database-side,
     the newly saved instances receive all the default values either through a post-fetch or getting the pre-exec'ed 
     defaults back from the engine."""
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
+    
+    def define_tables(self, metadata):
         db = testbase.db
         use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
 
@@ -597,19 +529,15 @@ class DefaultTest(UnitOfWorkTest):
             self.althohoval = 15
             
         global default_table
-        metadata = MetaData(db)
         default_table = Table('default_test', metadata,
         Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
         Column('hoho', hohotype, PassiveDefault(str(self.hohoval))),
         Column('counter', Integer, PassiveDefault("7")),
         Column('foober', String(30), default="im foober", onupdate="im the update")
         )
-        default_table.create()
-    def tearDownAll(self):
-        default_table.drop()
-        UnitOfWorkTest.tearDownAll(self)
+
         
-    def testinsert(self):
+    def test_insert(self):
         class Hoho(object):pass
         mapper(Hoho, default_table)
         
@@ -652,7 +580,7 @@ class DefaultTest(UnitOfWorkTest):
         self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
         self.assert_(h5.foober=='im the new foober')
     
-    def testinsertnopostfetch(self):
+    def test_insert_nopostfetch(self):
         # populates the PassiveDefaults explicitly so there is no "post-update"
         class Hoho(object):pass
         mapper(Hoho, default_table)
@@ -666,7 +594,7 @@ class DefaultTest(UnitOfWorkTest):
             self.assert_(h1.foober=="im foober")
         self.assert_sql_count(testbase.db, go, 0)
         
-    def testupdate(self):
+    def test_update(self):
         class Hoho(object):pass
         mapper(Hoho, default_table)
         h1 = Hoho()
@@ -676,20 +604,13 @@ class DefaultTest(UnitOfWorkTest):
         Session.commit()
         self.assert_(h1.foober == 'im the update')
 
-class OneToManyTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        tables.create()
-
-    def tearDownAll(self):
-        tables.drop()
-        UnitOfWorkTest.tearDownAll(self)
-
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        tables.delete()
+class OneToManyTest(ORMTest):
+    metadata = tables.metadata
+    
+    def define_tables(self, metadata):
+        pass
 
-    def testonetomany_1(self):
+    def test_onetomany_1(self):
         """test basic save of one to many."""
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True)
@@ -723,7 +644,7 @@ class OneToManyTest(UnitOfWorkTest):
         self.assertEqual(addresstable[0].values(), [addressid, userid, 'somethingnew@foo.com'])
         self.assert_(u.user_id == userid and a2.address_id == addressid)
 
-    def testonetomany_2(self):
+    def test_onetomany_2(self):
         """digs deeper into modifying the child items of an object to insure the correct
         updates take place"""
         m = mapper(User, users, properties = dict(
@@ -768,7 +689,7 @@ class OneToManyTest(UnitOfWorkTest):
                     ),
                 ])
 
-    def testchildmove(self):
+    def test_childmove(self):
         """tests moving a child from one parent to the other, then deleting the first parent, properly
         updates the child with the new parent.  this tests the 'trackparent' option in the attributes module."""
         m = mapper(User, users, properties = dict(
@@ -790,7 +711,7 @@ class OneToManyTest(UnitOfWorkTest):
         u2 = Session.get(User, u2.user_id)
         assert len(u2.addresses) == 1
 
-    def testchildmove_2(self):
+    def test_childmove_2(self):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True)
         ))
@@ -809,7 +730,7 @@ class OneToManyTest(UnitOfWorkTest):
         u2 = Session.get(User, u2.user_id)
         assert len(u2.addresses) == 1
 
-    def testo2mdeleteparent(self):
+    def test_o2m_delete_parent(self):
         m = mapper(User, users, properties = dict(
             address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False)
         ))
@@ -823,7 +744,7 @@ class OneToManyTest(UnitOfWorkTest):
         Session.commit()
         self.assert_(a.address_id is not None and a.user_id is None and not Session.identity_map.has_key(u._instance_key) and Session.identity_map.has_key(a._instance_key))
 
-    def testonetoone(self):
+    def test_onetoone(self):
         m = mapper(User, users, properties = dict(
             address = relation(mapper(Address, addresses), lazy = True, uselist = False)
         ))
@@ -837,7 +758,7 @@ class OneToManyTest(UnitOfWorkTest):
         u.address.email_address = 'imnew@foo.com'
         Session.commit()
 
-    def testbidirectional(self):
+    def test_bidirectional(self):
         m1 = mapper(User, users)
 
         m2 = mapper(Address, addresses, properties = dict(
@@ -866,7 +787,7 @@ class OneToManyTest(UnitOfWorkTest):
         Session.delete(u)
         Session.commit()
 
-    def testdoublerelation(self):
+    def test_doublerelation(self):
         m2 = mapper(Address, addresses)
         m = mapper(User, users, properties={
             'boston_addresses' : relation(m2, primaryjoin=
@@ -886,16 +807,13 @@ class OneToManyTest(UnitOfWorkTest):
         u.newyork_addresses.append(b)
         Session.commit()
 
-class SaveTest(UnitOfWorkTest):
-
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        tables.create()
-    def tearDownAll(self):
-        tables.drop()
-        UnitOfWorkTest.tearDownAll(self)
+class SaveTest(ORMTest):
+    metadata = tables.metadata
+    def define_tables(self, metadata):
+        pass
         
     def setUp(self):
+        super(SaveTest, self).setUp()
         keywords.insert().execute(
             dict(name='blue'),
             dict(name='red'),
@@ -906,11 +824,7 @@ class SaveTest(UnitOfWorkTest):
             dict(name='square')
         )
 
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        tables.delete()
-
-    def testbasic(self):
+    def test_basic(self):
         # save two users
         u = User()
         u.user_name = 'savetester'
@@ -948,9 +862,10 @@ class SaveTest(UnitOfWorkTest):
         self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname')
         self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2')
 
-    def testlazyattrcommit(self):
+    def test_lazyattr_commit(self):
         """tests that when a lazy-loaded list is unloaded, and a commit occurs, that the
         'passive' call on that list does not blow away its value"""
+        
         m1 = mapper(User, users, properties = {
             'addresses': relation(mapper(Address, addresses))
         })
@@ -968,7 +883,7 @@ class SaveTest(UnitOfWorkTest):
         Session.commit()
         self.assert_(len(u1.addresses) == 4)
         
-    def testinherits(self):
+    def test_inherits(self):
         m1 = mapper(User, users)
         
         class AddressUser(User):
@@ -987,7 +902,7 @@ class SaveTest(UnitOfWorkTest):
         l = Session.query(AddressUser).selectone()
         self.assert_(l.user_id == au.user_id and l.address_id == au.address_id)
     
-    def testdeferred(self):
+    def test_deferred(self):
         """test that a deferred load within a commit() doesnt screw up the connection"""
         mapper(User, users, properties={
             'user_name':deferred(users.c.user_name)
@@ -1011,7 +926,7 @@ class SaveTest(UnitOfWorkTest):
             Session.commit()
         self.assert_sql_count(testbase.db, go, 0)
 
-    def testmultitable(self):
+    def test_multitable(self):
         """tests a save of an object where each instance spans two tables. also tests
         redefinition of the keynames for the column properties."""
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
@@ -1052,7 +967,7 @@ class SaveTest(UnitOfWorkTest):
         u = Session.get(User, id)
         assert u.user_name == 'imnew'
     
-    def testhistoryget(self):
+    def test_history_get(self):
         """tests that the history properly lazy-fetches data when it wasnt otherwise loaded"""
         mapper(User, users, properties={
             'addresses':relation(Address, cascade="all, delete-orphan")
@@ -1072,7 +987,9 @@ class SaveTest(UnitOfWorkTest):
         
             
     
-    def testbatchmode(self):
+    def test_batchmode(self):
+        """test the 'batch=False' flag on mapper()"""
+        
         class TestExtension(MapperExtension):
             def before_insert(self, mapper, connection, instance):
                 self.current_instance = instance
@@ -1099,20 +1016,13 @@ class SaveTest(UnitOfWorkTest):
             assert True
         
     
-class ManyToOneTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        tables.create()
-
-    def tearDownAll(self):
-        tables.drop()
-        UnitOfWorkTest.tearDownAll(self)
-
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        tables.delete()
+class ManyToOneTest(ORMTest):
+    metadata = tables.metadata
     
-    def testm2oonetoone(self):
+    def define_tables(self, metadata):
+        pass
+    
+    def test_m2o_onetoone(self):
         # TODO: put assertion in here !!!
         m = mapper(Address, addresses, properties = dict(
             user = relation(mapper(User, users), lazy = True, uselist = False)
@@ -1170,7 +1080,7 @@ class ManyToOneTest(UnitOfWorkTest):
         assert l.fetchone().values() == [a.user.user_id, 'asdf8d', a.address_id, a.user_id, 'theater@foo.com']
 
 
-    def testmanytoone_1(self):
+    def test_manytoone_1(self):
         m = mapper(Address, addresses, properties = dict(
             user = relation(mapper(User, users), lazy = True)
         ))
@@ -1193,7 +1103,7 @@ class ManyToOneTest(UnitOfWorkTest):
         u1 = Session.query(User).get(u1.user_id)
         assert a1.user is None
 
-    def testmanytoone_2(self):
+    def test_manytoone_2(self):
         m = mapper(Address, addresses, properties = dict(
             user = relation(mapper(User, users), lazy = True)
         ))
@@ -1221,7 +1131,7 @@ class ManyToOneTest(UnitOfWorkTest):
         assert a1.user is None
         assert a2.user is u1
 
-    def testmanytoone_3(self):
+    def test_manytoone_3(self):
         m = mapper(Address, addresses, properties = dict(
             user = relation(mapper(User, users), lazy = True)
         ))
@@ -1248,20 +1158,13 @@ class ManyToOneTest(UnitOfWorkTest):
         u2 = Session.query(User).get(u2.user_id)
         assert a1.user is u2
         
-class ManyToManyTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        tables.create()
-
-    def tearDownAll(self):
-        tables.drop()
-        UnitOfWorkTest.tearDownAll(self)
-
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        tables.delete()
-
-    def testmanytomany(self):
+class ManyToManyTest(ORMTest):
+    metadata = tables.metadata
+    
+    def define_tables(self, metadata):
+        pass
+        
+    def test_manytomany(self):
         items = orderitems
 
         keywordmapper = mapper(Keyword, keywords)
@@ -1351,7 +1254,7 @@ class ManyToManyTest(UnitOfWorkTest):
         Session.delete(objects[3])
         Session.commit()
 
-    def testmanytomanyremove(self):
+    def test_manytomany_remove(self):
         """tests that setting a list-based attribute to '[]' properly affects the history and allows
         the many-to-many rows to be deleted"""
         keywordmapper = mapper(Keyword, keywords)
@@ -1372,7 +1275,7 @@ class ManyToManyTest(UnitOfWorkTest):
         Session.commit()
         assert itemkeywords.count().scalar() == 0
 
-    def testscalar(self):
+    def test_scalar(self):
         """test that dependency.py doesnt try to delete an m2m relation referencing None."""
         
         mapper(Keyword, keywords)
@@ -1388,7 +1291,7 @@ class ManyToManyTest(UnitOfWorkTest):
         
         
 
-    def testmanytomanyupdate(self):
+    def test_manytomany_update(self):
         """tests some history operations on a many to many"""
         class Keyword(object):
             def __init__(self, name):
@@ -1422,7 +1325,7 @@ class ManyToManyTest(UnitOfWorkTest):
         print item.keywords
         assert item.keywords == [k1, k2]
         
-    def testassociation(self):
+    def test_association(self):
         """basic test of an association object"""
         class IKAssociation(object):
             def __repr__(self):
@@ -1486,30 +1389,22 @@ class ManyToManyTest(UnitOfWorkTest):
         l = Item.query.filter(items.c.item_name.in_(*[e['item_name'] for e in data[1:]])).order_by(items.c.item_name).all()
         self.assert_result(l, *data)
     
-class SaveTest2(UnitOfWorkTest):
-
-    def setUp(self):
-        Session.close()
-        clear_mappers()
-        global meta, users, addresses
-        meta = MetaData(testbase.db)
-        users = Table('users', meta,
+class SaveTest2(ORMTest):
+    
+    def define_tables(self, metadata):
+        global users, addresses
+        users = Table('users', metadata,
             Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
             Column('user_name', String(20)),
         )
 
-        addresses = Table('email_addresses', meta,
+        addresses = Table('email_addresses', metadata,
             Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
             Column('rel_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
         )
-        meta.create_all()
-
-    def tearDown(self):
-        UnitOfWorkTest.tearDown(self)
-        meta.drop_all()
     
-    def testbackwardsnonmatch(self):
+    def test_m2o_nonmatch(self):
         m = mapper(Address, addresses, properties = dict(
             user = relation(mapper(User, users), lazy = True, uselist = False)
         ))
@@ -1564,39 +1459,27 @@ class SaveTest2(UnitOfWorkTest):
         )
 
 
-class SaveTest3(UnitOfWorkTest):
-    def setUpAll(self):
-        global st3_metadata, t1, t2, t3
-
-        UnitOfWorkTest.setUpAll(self)
+class SaveTest3(ORMTest):
+    def define_tables(self, metadata):
+        global t1, t2, t3
 
-        st3_metadata = MetaData(testbase.db)
-        t1 = Table('items', st3_metadata,
+        t1 = Table('items', metadata,
             Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
             Column('item_name', VARCHAR(50)),
         )
 
-        t3 = Table('keywords', st3_metadata,
+        t3 = Table('keywords', metadata,
             Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True),
             Column('name', VARCHAR(50)),
 
         )
-        t2 = Table('assoc', st3_metadata,
+        t2 = Table('assoc', metadata,
             Column('item_id', INT, ForeignKey("items")),
             Column('keyword_id', INT, ForeignKey("keywords")),
             Column('foo', Boolean, default=True)
         )
-        st3_metadata.create_all()
-    def tearDownAll(self):
-        st3_metadata.drop_all()
-        UnitOfWorkTest.tearDownAll(self)
-
-    def setUp(self):
-        pass
-    def tearDown(self):
-        pass
 
-    def testmanytomanyxtracolremove(self):
+    def test_manytomany_xtracol_delete(self):
         """test that a many-to-many on a table that has an extra column can properly delete rows from the table
         without referencing the extra column"""
         mapper(Keyword, t3)
index ba3670f4dcce238df390ca86da8e2a0e7cfb8b41..9afb2158792c17e6634d9502da4d76f270b72270 100644 (file)
@@ -293,14 +293,19 @@ _otest_metadata = None
 class ORMTest(AssertMixin):
     keep_mappers = False
     keep_data = False
-
+    metadata = None
+    
     def setUpAll(self):
         global MetaData, _otest_metadata
 
         if MetaData is None:
             from sqlalchemy import MetaData
         
-        _otest_metadata = MetaData(config.db)
+        if self.metadata is None:
+            _otest_metadata = MetaData(config.db)
+        else:
+            _otest_metadata = self.metadata
+            _otest_metadata.bind = config.db
         self.define_tables(_otest_metadata)
         _otest_metadata.create_all()
         self.insert_data()