]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Refactored test support code, moved most into 'testlib/'
authorJason Kirtland <jek@discorporate.us>
Mon, 23 Jul 2007 01:50:54 +0000 (01:50 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 23 Jul 2007 01:50:54 +0000 (01:50 +0000)
Cleaned up imports, all tests should be runnable stand-alone or suite now
Updated most of the perf tests
Removed dead test suites
Added new profiling decorator
Added new profilable perf test, 'ormsession' to try to capture a typical workload

82 files changed:
test/base/dependency.py
test/base/utils.py
test/dialect/mysql.py
test/dialect/postgres.py
test/engine/bind.py
test/engine/execute.py
test/engine/metadata.py
test/engine/parseconnect.py
test/engine/pool.py
test/engine/reconnect.py
test/engine/reflection.py
test/engine/transaction.py
test/ext/activemapper.py
test/ext/assignmapper.py
test/ext/associationproxy.py
test/ext/orderinglist.py
test/orm/association.py
test/orm/assorted_eager.py
test/orm/attributes.py
test/orm/cascade.py
test/orm/collection.py
test/orm/compile.py
test/orm/cycles.py
test/orm/eager_relations.py
test/orm/entity.py
test/orm/fixtures.py
test/orm/generative.py
test/orm/inheritance/abc_inheritance.py
test/orm/inheritance/basic.py
test/orm/inheritance/concrete.py
test/orm/inheritance/magazine.py
test/orm/inheritance/manytomany.py
test/orm/inheritance/poly_linked_list.py
test/orm/inheritance/polymorph.py
test/orm/inheritance/polymorph2.py
test/orm/inheritance/productspec.py
test/orm/inheritance/single.py
test/orm/lazy_relations.py
test/orm/lazytest1.py
test/orm/manytomany.py
test/orm/mapper.py
test/orm/memusage.py
test/orm/merge.py
test/orm/onetoone.py
test/orm/query.py
test/orm/relationships.py
test/orm/session.py
test/orm/sessioncontext.py
test/orm/unitofwork.py
test/perf/cascade_speed.py
test/perf/masscreate.py
test/perf/masscreate2.py
test/perf/masseagerload.py
test/perf/massload.py
test/perf/masssave.py
test/perf/ormsession.py [new file with mode: 0644]
test/perf/poolload.py
test/perf/threaded_compile.py
test/perf/wsgi.py
test/sql/alltests.py
test/sql/case_statement.py
test/sql/constraints.py
test/sql/defaults.py
test/sql/generative.py
test/sql/labels.py
test/sql/query.py
test/sql/quote.py
test/sql/rowcount.py
test/sql/select.py
test/sql/selectable.py
test/sql/testtypes.py
test/sql/unicode.py
test/testbase.py
test/testlib/__init__.py [new file with mode: 0644]
test/testlib/config.py [new file with mode: 0644]
test/testlib/coverage.py [moved from test/coverage.py with 99% similarity]
test/testlib/profiling.py [new file with mode: 0644]
test/testlib/schema.py [new file with mode: 0644]
test/testlib/tables.py [moved from test/tables.py with 96% similarity]
test/testlib/testing.py [new file with mode: 0644]
test/zblog/tables.py
test/zblog/tests.py

index c5e54fc9fa5630a02a90b1f99c8ec0d552707d95..ddadd1b3163d402e28f53d6f66d30ad637384dc8 100644 (file)
@@ -1,7 +1,8 @@
-from testbase import PersistTest
+import testbase
 import sqlalchemy.topological as topological
-import unittest, sys, os
 from sqlalchemy import util
+from testlib import *
+
 
 # TODO:  need assertion conditions in this suite
 
@@ -190,4 +191,4 @@ class DependencySortTest(PersistTest):
             
             
 if __name__ == "__main__":
-    unittest.main()
+    testbase.main()
index ccf6b4419738557b9a61b864f3fd53bdc93b26b7..96d3c96e432723a595102c34f92c7e35e46c81b1 100644 (file)
@@ -1,7 +1,9 @@
 import testbase
 from sqlalchemy import util
+from testlib import *
 
-class OrderedDictTest(testbase.PersistTest):
+
+class OrderedDictTest(PersistTest):
     def test_odict(self):
         o = util.OrderedDict()
         o['a'] = 1
index 086baf9eb0d79426fbe4acd71f4c98faa2331331..dbba78893d94c8dd8c33e84a65816b89241032ff 100644 (file)
@@ -1,16 +1,13 @@
-from testbase import PersistTest, AssertMixin
 import testbase
 from sqlalchemy import *
 from sqlalchemy.databases import mysql
-from testbase import Table, Column
-import sys, StringIO
+from testlib import *
 
-db = testbase.db
 
 class TypesTest(AssertMixin):
     "Test MySQL column types"
 
-    @testbase.supported('mysql')
+    @testing.supported('mysql')
     def test_numeric(self):
         "Exercise type specification and options for numeric types."
         
@@ -105,13 +102,13 @@ class TypesTest(AssertMixin):
              'SMALLINT(4) UNSIGNED ZEROFILL'),
            ]
 
-        table_args = ['test_mysql_numeric', MetaData(db)]
+        table_args = ['test_mysql_numeric', MetaData(testbase.db)]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         numeric_table = Table(*table_args)
-        gen = db.dialect.schemagenerator(db, None, None)
+        gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
         
         for col in numeric_table.c:
             index = int(col.name[1:])
@@ -125,7 +122,7 @@ class TypesTest(AssertMixin):
             raise
         numeric_table.drop()
     
-    @testbase.supported('mysql')
+    @testing.supported('mysql')
     def test_charset(self):
         """Exercise CHARACTER SET and COLLATE-related options on string-type
         columns."""
@@ -189,13 +186,13 @@ class TypesTest(AssertMixin):
              '''ENUM('foo','bar') UNICODE''')
            ]
 
-        table_args = ['test_mysql_charset', MetaData(db)]
+        table_args = ['test_mysql_charset', MetaData(testbase.db)]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw)))
 
         charset_table = Table(*table_args)
-        gen = db.dialect.schemagenerator(db, None, None)
+        gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
         
         for col in charset_table.c:
             index = int(col.name[1:])
@@ -209,11 +206,12 @@ class TypesTest(AssertMixin):
             raise
         charset_table.drop()
 
-    @testbase.supported('mysql')
+    @testing.supported('mysql')
     def test_enum(self):
         "Exercise the ENUM type"
-
-        enum_table = Table('mysql_enum', MetaData(db),
+        
+        db = testbase.db
+        enum_table = Table('mysql_enum', MetaData(testbase.db),
             Column('e1', mysql.MSEnum('"a"', "'b'")),
             Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False),
             Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)),
@@ -252,8 +250,8 @@ class TypesTest(AssertMixin):
         # This is known to fail with MySQLDB 1.2.2 beta versions
         # which return these as sets.Set(['a']), sets.Set(['b'])
         # (even on Pythons with __builtin__.set)
-        if db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
-           db.dialect.dbapi.version_info >= (1, 2, 2):
+        if testbase.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
+           testbase.db.dialect.dbapi.version_info >= (1, 2, 2):
             # these mysqldb seem to always uses 'sets', even on later pythons
             import sets 
             def convert(value):
@@ -272,10 +270,10 @@ class TypesTest(AssertMixin):
         self.assertEqual(res, expected)
         enum_table.drop()
 
-    @testbase.supported('mysql')
+    @testing.supported('mysql')
     def test_type_reflection(self):
         # FIXME: older versions need their own test
-        if db.dialect.get_version_info(db) < (5, 0):
+        if testbase.db.dialect.get_version_info(testbase.db) < (5, 0):
             return
 
         # (ask_for, roundtripped_as_if_different)
@@ -305,12 +303,12 @@ class TypesTest(AssertMixin):
 
         columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
 
-        m = MetaData(db)
+        m = MetaData(testbase.db)
         t_table = Table('mysql_types', m, *columns)
         m.drop_all()
         m.create_all()
         
-        m2 = MetaData(db)
+        m2 = MetaData(testbase.db)
         rt = Table('mysql_types', m2, autoload=True)
 
         #print
index 8b939eea0e760360907c86bf7ee6293f1e19c460..550966d0a40b978de7dc960eb64203c46fb83fb3 100644 (file)
@@ -1,61 +1,60 @@
-from testbase import AssertMixin
 import testbase
+import datetime
 from sqlalchemy import *
 from sqlalchemy.databases import postgres
-import datetime
+from testlib import *
 
-db = testbase.db
 
 class DomainReflectionTest(AssertMixin):
     "Test PostgreSQL domains"
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def setUpAll(self):
-        con = db.connect()
+        con = testbase.db.connect()
         con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
         con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
         con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
         con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
         con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def tearDownAll(self):
-        con = db.connect()
+        con = testbase.db.connect()
         con.execute('DROP TABLE testtable')
         con.execute('DROP TABLE alt_schema.testtable')
         con.execute('DROP TABLE crosschema')
         con.execute('DROP DOMAIN testdomain')
         con.execute('DROP DOMAIN alt_schema.testdomain')
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_table_is_reflected(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         table = Table('testtable', metadata, autoload=True)
         self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
         self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
         
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_domain_is_reflected(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         table = Table('testtable', metadata, autoload=True)
         self.assertEquals(str(table.columns.answer.default.arg), '42', "Reflected default value didn't equal expected value")
         self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.")
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_table_is_reflected_alt_schema(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         table = Table('testtable', metadata, autoload=True, schema='alt_schema')
         self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
         self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_schema_domain_is_reflected(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         table = Table('testtable', metadata, autoload=True, schema='alt_schema')
         self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value")
         self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_crosschema_domain_is_reflected(self):
         metadata = MetaData(db)
         table = Table('crosschema', metadata, autoload=True)
@@ -63,7 +62,7 @@ class DomainReflectionTest(AssertMixin):
         self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
 
 class MiscTest(AssertMixin):
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_date_reflection(self):
         m1 = MetaData(testbase.db)
         t1 = Table('pgdate', m1, 
@@ -79,7 +78,7 @@ class MiscTest(AssertMixin):
         finally:
             m1.drop_all()
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_pg_weirdchar_reflection(self):
         meta1 = MetaData(testbase.db)
         subject = Table("subject", meta1,
@@ -100,7 +99,7 @@ class MiscTest(AssertMixin):
         finally:
             meta1.drop_all()
         
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_checksfor_sequence(self):
         meta1 = MetaData(testbase.db)
         t = Table('mytable', meta1, 
@@ -111,7 +110,7 @@ class MiscTest(AssertMixin):
         finally:
             t.drop(checkfirst=True)
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_schema_reflection(self):
         """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
 
@@ -142,7 +141,7 @@ class MiscTest(AssertMixin):
         finally:
             meta1.drop_all()
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_schema_reflection_2(self):
         meta1 = MetaData(testbase.db)
         subject = Table("subject", meta1,
@@ -163,7 +162,7 @@ class MiscTest(AssertMixin):
         finally:
             meta1.drop_all()
             
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_schema_reflection_3(self):
         meta1 = MetaData(testbase.db)
         subject = Table("subject", meta1,
@@ -186,7 +185,7 @@ class MiscTest(AssertMixin):
         finally:
             meta1.drop_all()
         
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_preexecute_passivedefault(self):
         """test that when we get a primary key column back 
         from reflecting a table which has a default value on it, we pre-execute
@@ -217,7 +216,7 @@ class TimezoneTest(AssertMixin):
     if postgres returns it.  python then will not let you compare a datetime with a tzinfo to a datetime
     that doesnt have one.  this test illustrates two ways to have datetime types with and without timezone
     info. """
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def setUpAll(self):
         global tztable, notztable, metadata
         metadata = MetaData(testbase.db)
@@ -234,11 +233,11 @@ class TimezoneTest(AssertMixin):
             Column("name", String(20)),
         )
         metadata.create_all()
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def tearDownAll(self):
         metadata.drop_all()
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_with_timezone(self):
         # get a date with a tzinfo
         somedate = testbase.db.connect().scalar(func.current_timestamp().select())
@@ -247,7 +246,7 @@ class TimezoneTest(AssertMixin):
         x = c.last_updated_params()
         print x['date'] == somedate
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_without_timezone(self):
         # get a date without a tzinfo
         somedate = datetime.datetime(2005, 10,20, 11, 52, 00)
@@ -257,7 +256,7 @@ class TimezoneTest(AssertMixin):
         print x['date'] == somedate
 
 class ArrayTest(AssertMixin):
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def setUpAll(self):
         global metadata, arrtable
         metadata = MetaData(testbase.db)
@@ -268,11 +267,11 @@ class ArrayTest(AssertMixin):
             Column('strarr', postgres.PGArray(String), nullable=False)
         )
         metadata.create_all()
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def tearDownAll(self):
         metadata.drop_all()
     
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_reflect_array_column(self):
         metadata2 = MetaData(testbase.db)
         tbl = Table('arrtable', metadata2, autoload=True)
@@ -281,7 +280,7 @@ class ArrayTest(AssertMixin):
         self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
         self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
         
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_insert_array(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         results = arrtable.select().execute().fetchall()
@@ -290,7 +289,7 @@ class ArrayTest(AssertMixin):
         self.assertEquals(results[0]['strarr'], ['abc','def'])
         arrtable.delete().execute()
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_array_where(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
@@ -299,7 +298,7 @@ class ArrayTest(AssertMixin):
         self.assertEquals(results[0]['intarr'], [1,2,3])
         arrtable.delete().execute()
     
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_array_concat(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
index a639a4a90ae4f343d349be47d6e7767a5f853a3b..6a0c78f5782edb4a46c823d827dc109a8814801b 100644 (file)
@@ -2,13 +2,10 @@
 including the deprecated versions of these arguments"""
 
 import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
 
-class BindTest(testbase.PersistTest):
+class BindTest(PersistTest):
     def test_create_drop_explicit(self):
         metadata = MetaData()
         table = Table('test_table', metadata,   
@@ -201,8 +198,6 @@ class BindTest(testbase.PersistTest):
                 assert False
             except exceptions.InvalidRequestError, e:
                 assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
-
-                
         finally:
             if isinstance(bind, engine.Connection):
                 bind.close()
@@ -210,4 +205,4 @@ class BindTest(testbase.PersistTest):
         
                
 if __name__ == '__main__':
-    testbase.main()
\ No newline at end of file
+    testbase.main()
index 5bc9dbfe91ae757deb5fc4bd2bb4b78c185e15ea..3d3b43f9b61d7802e761a30868408d2b9eef9159 100644 (file)
@@ -1,12 +1,8 @@
-
 import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
 
-class ExecuteTest(testbase.PersistTest):
+class ExecuteTest(PersistTest):
     def setUpAll(self):
         global users, metadata
         metadata = MetaData(testbase.db)
@@ -21,7 +17,7 @@ class ExecuteTest(testbase.PersistTest):
     def tearDownAll(self):
         metadata.drop_all()
         
-    @testbase.supported('sqlite')
+    @testing.supported('sqlite')
     def test_raw_qmark(self):
         for conn in (testbase.db, testbase.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
@@ -33,7 +29,7 @@ class ExecuteTest(testbase.PersistTest):
             assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
             conn.execute("delete from users")
 
-    @testbase.supported('mysql', 'postgres')
+    @testing.supported('mysql', 'postgres')
     def test_raw_sprintf(self):
         for conn in (testbase.db, testbase.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"])
@@ -46,7 +42,7 @@ class ExecuteTest(testbase.PersistTest):
 
     # pyformat is supported for mysql, but skipping because a few driver
     # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_raw_python(self):
         for conn in (testbase.db, testbase.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
@@ -56,7 +52,7 @@ class ExecuteTest(testbase.PersistTest):
             assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
             conn.execute("delete from users")
 
-    @testbase.supported('sqlite')
+    @testing.supported('sqlite')
     def test_raw_named(self):
         for conn in (testbase.db, testbase.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
index 28b0535a5761adcec7b00408516068e90a5a638a..973007fab8a4ddafebc2d57f2365b17313a27b45 100644 (file)
@@ -1,8 +1,8 @@
 import testbase
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
 
-class MetaDataTest(testbase.PersistTest):
+class MetaDataTest(PersistTest):
     def test_metadata_connect(self):
         metadata = MetaData()
         t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
index eb4f95619d059da85060aeeae74ec6b83a9d334f..3e186275d5bded67e64e91113ef064d4c9f6e04f 100644 (file)
@@ -1,8 +1,7 @@
-from testbase import PersistTest
 import testbase
-import sqlalchemy.engine.url as url
 from sqlalchemy import *
-import unittest
+import sqlalchemy.engine.url as url
+from testlib import *
 
         
 class ParseConnectTest(PersistTest):
index 85d44dfd388f7f403d429b605394219c9dfa8781..3bb25ec72a7fb11db44cfe80aa8c95d307a46d85 100644 (file)
@@ -1,10 +1,9 @@
 import testbase
-from testbase import PersistTest
-import unittest, sys, os, time
-import threading, thread
-
+import threading, thread, time
 import sqlalchemy.pool as pool
 import sqlalchemy.exceptions as exceptions
+from testlib import *
+
 
 mcid = 1
 class MockDBAPI(object):
@@ -45,7 +44,7 @@ class PoolTest(PersistTest):
         connection2 = manager.connect('foo.db')
         connection3 = manager.connect('bar.db')
         
-        self.echo( "connection " + repr(connection))
+        print "connection " + repr(connection)
         self.assert_(connection.cursor() is not None)
         self.assert_(connection is connection2)
         self.assert_(connection2 is not connection3)
@@ -64,7 +63,7 @@ class PoolTest(PersistTest):
         connection = manager.connect('foo.db')
         connection2 = manager.connect('foo.db')
 
-        self.echo( "connection " + repr(connection))
+        print "connection " + repr(connection)
 
         self.assert_(connection.cursor() is not None)
         self.assert_(connection is not connection2)
@@ -80,7 +79,7 @@ class PoolTest(PersistTest):
     
         def status(pool):
             tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
-            self.echo( "Pool size: %d  Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup)
+            print "Pool size: %d  Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
             return tup
                 
         c1 = p.connect()
index 1c8594d0eeb3918e451a1fd1b42a98072a065fe3..7c213695f2643385c9946a8480d144af93731ddd 100644 (file)
@@ -1,6 +1,8 @@
 import testbase
+import sys, weakref
 from sqlalchemy import create_engine, exceptions
-import gc, weakref, sys
+from testlib import *
+
 
 class MockDisconnect(Exception):
     pass
@@ -37,7 +39,7 @@ class MockCursor(object):
     def close(self):
         pass
         
-class ReconnectTest(testbase.PersistTest):
+class ReconnectTest(PersistTest):
     def test_reconnect(self):
         """test that an 'is_disconnect' condition will invalidate the connection, and additionally
         dispose the previous connection pool and recreate."""
@@ -92,4 +94,4 @@ class ReconnectTest(testbase.PersistTest):
         assert len(dbapi.connections) == 1
         
 if __name__ == '__main__':
-    testbase.main()
\ No newline at end of file
+    testbase.main()
index 78ffd1fdcf248a3482d54a4e9d7d61aee3513ec1..1fa4b4b90f2569af2a3d02971690adfbe55427d6 100644 (file)
@@ -1,13 +1,12 @@
-from testbase import PersistTest
 import testbase
-import pickle
-import sqlalchemy.ansisql as ansisql
+import pickle, StringIO
 
 from sqlalchemy import *
+import sqlalchemy.ansisql as ansisql
 from sqlalchemy.exceptions import NoSuchTableError
 import sqlalchemy.databases.mysql as mysql
-from testbase import Table, Column
-import unittest, re, StringIO
+from testlib import *
+
 
 class ReflectionTest(PersistTest):
     def testbasic(self):
@@ -207,7 +206,7 @@ class ReflectionTest(PersistTest):
         finally:
             meta.drop_all()
             
-    @testbase.supported('mysql')
+    @testing.supported('mysql')
     def testmysqltypes(self):
         meta1 = MetaData(testbase.db)
         table = Table(
@@ -284,7 +283,7 @@ class ReflectionTest(PersistTest):
         finally:
             testbase.db.execute("drop table book")
             
-    @testbase.supported('sqlite')
+    @testing.supported('sqlite')
     def test_goofy_sqlite(self):
         """test autoload of table where quotes were used with all the colnames.  quirky in sqlite."""
         testbase.db.execute("""CREATE TABLE "django_content_type" (
@@ -458,7 +457,7 @@ class ReflectionTest(PersistTest):
         finally:
             table.drop()
 
-    @testbase.supported('mssql')
+    @testing.supported('mssql')
     def testidentity(self):
         meta = MetaData(testbase.db)
         table = Table(
@@ -576,7 +575,7 @@ class CreateDropTest(PersistTest):
 
 class SchemaTest(PersistTest):
     # this test should really be in the sql tests somewhere, not engine
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testiteration(self):
         metadata = MetaData()
         table1 = Table('table1', metadata, 
@@ -600,7 +599,7 @@ class SchemaTest(PersistTest):
         assert buf.index("CREATE TABLE someschema.table1") > -1
         assert buf.index("CREATE TABLE someschema.table2") > -1
     
-    @testbase.supported('mysql','postgres')
+    @testing.supported('mysql','postgres')
     def testcreate(self):
         engine = testbase.db
         schema = engine.dialect.get_default_schema_name(engine)
index f86d0cbdb1c17ab97edc67457bd5ce2af8fe2104..593a069a96178cf9328c25dcde75cdce239e0e61 100644 (file)
@@ -1,13 +1,12 @@
-
 import testbase
-import unittest, sys, datetime, random, time, threading
-import tables
-db = testbase.db
+import sys, time, threading
+
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
+
 
-class TransactionTest(testbase.PersistTest):
+class TransactionTest(PersistTest):
     def setUpAll(self):
         global users, metadata
         metadata = MetaData()
@@ -116,7 +115,7 @@ class TransactionTest(testbase.PersistTest):
         assert len(result.fetchall()) == 0
         connection.close()
     
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testnestedsubtransactionrollback(self):
         connection = testbase.db.connect()
         transaction = connection.begin()
@@ -133,7 +132,7 @@ class TransactionTest(testbase.PersistTest):
         )
         connection.close()
 
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testnestedsubtransactioncommit(self):
         connection = testbase.db.connect()
         transaction = connection.begin()
@@ -150,7 +149,7 @@ class TransactionTest(testbase.PersistTest):
         )
         connection.close()
 
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testrollbacktosubtransaction(self):
         connection = testbase.db.connect()
         transaction = connection.begin()
@@ -169,7 +168,7 @@ class TransactionTest(testbase.PersistTest):
         )
         connection.close()
     
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def testtwophasetransaction(self):
         connection = testbase.db.connect()
         
@@ -197,7 +196,7 @@ class TransactionTest(testbase.PersistTest):
         )
         connection.close()
 
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def testmixedtransaction(self):
         connection = testbase.db.connect()
         
@@ -230,7 +229,7 @@ class TransactionTest(testbase.PersistTest):
         )
         connection.close()
         
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def testtwophaserecover(self):
         # MySQL recovery doesn't currently seem to work correctly
         # Prepared transactions disappear when connections are closed and even
@@ -262,7 +261,7 @@ class TransactionTest(testbase.PersistTest):
         )
         connection2.close()
 
-class AutoRollbackTest(testbase.PersistTest):
+class AutoRollbackTest(PersistTest):
     def setUpAll(self):
         global metadata
         metadata = MetaData()
@@ -270,7 +269,7 @@ class AutoRollbackTest(testbase.PersistTest):
     def tearDownAll(self):
         metadata.drop_all(testbase.db)
         
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testrollback_deadlock(self):
         """test that returning connections to the pool clears any object locks."""
         conn1 = testbase.db.connect()
@@ -289,10 +288,10 @@ class AutoRollbackTest(testbase.PersistTest):
         users.drop(conn2)
         conn2.close()
 
-class TLTransactionTest(testbase.PersistTest):
+class TLTransactionTest(PersistTest):
     def setUpAll(self):
         global users, metadata, tlengine
-        tlengine = create_engine(testbase.db_uri, strategy='threadlocal')
+        tlengine = create_engine(testbase.db.url, strategy='threadlocal')
         metadata = MetaData()
         users = Table('query_users', metadata,
             Column('user_id', INT, primary_key = True),
@@ -402,7 +401,7 @@ class TLTransactionTest(testbase.PersistTest):
         finally:
             external_connection.close()
         
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testnesting(self):
         """tests nesting of tranacstions"""
         external_connection = tlengine.connect()
@@ -496,7 +495,7 @@ class TLTransactionTest(testbase.PersistTest):
         c2.close()
         assert c1.connection.connection is not None
 
-class ForUpdateTest(testbase.PersistTest):
+class ForUpdateTest(PersistTest):
     def setUpAll(self):
         global counters, metadata
         metadata = MetaData()
@@ -512,7 +511,7 @@ class ForUpdateTest(testbase.PersistTest):
         counters.drop(testbase.db)
 
     def increment(self, count, errors, update_style=True, delay=0.005):
-        con = db.connect()
+        con = testbase.db.connect()
         sel = counters.select(for_update=update_style,
                               whereclause=counters.c.counter_id==1)
         
@@ -539,7 +538,7 @@ class ForUpdateTest(testbase.PersistTest):
 
         con.close()
 
-    @testbase.supported('mysql', 'oracle', 'postgres')
+    @testing.supported('mysql', 'oracle', 'postgres')
     def testqueued_update(self):
         """Test SELECT FOR UPDATE with concurrent modifications.
 
@@ -574,7 +573,7 @@ class ForUpdateTest(testbase.PersistTest):
     def overlap(self, ids, errors, update_style):
         sel = counters.select(for_update=update_style,
                               whereclause=counters.c.counter_id.in_(*ids))
-        con = db.connect()
+        con = testbase.db.connect()
         trans = con.begin()
         try:
             rows = con.execute(sel).fetchall()
@@ -600,7 +599,7 @@ class ForUpdateTest(testbase.PersistTest):
 
         return errors
         
-    @testbase.supported('mysql', 'oracle', 'postgres')
+    @testing.supported('mysql', 'oracle', 'postgres')
     def testqueued_select(self):
         """Simple SELECT FOR UPDATE conflict test"""
 
@@ -609,7 +608,7 @@ class ForUpdateTest(testbase.PersistTest):
             sys.stderr.write("Failure: %s\n" % e)
         self.assert_(len(errors) == 0)
 
-    @testbase.supported('oracle', 'postgres')
+    @testing.supported('oracle', 'postgres')
     def testnowait_select(self):
         """Simple SELECT FOR UPDATE NOWAIT conflict test"""
 
index da1726f6a6d757a5ca5a9b394a17c19c619984ec..e28c72cd7315459dc1bd6e198e7760f2d27adf39 100644 (file)
@@ -1,15 +1,16 @@
 import testbase
+from datetime import datetime
+
 from sqlalchemy.ext.activemapper           import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
 from sqlalchemy             import and_, or_, exceptions
 from sqlalchemy             import ForeignKey, String, Integer, DateTime, Table, Column
 from sqlalchemy.orm         import clear_mappers, backref, create_session, class_mapper
-from datetime               import datetime
-import sqlalchemy
-
 import sqlalchemy.ext.activemapper as activemapper
+import sqlalchemy
+from testlib import *
 
 
-class testcase(testbase.PersistTest):
+class testcase(PersistTest):
     def setUpAll(self):
         clear_mappers()
         objectstore.clear()
@@ -143,7 +144,7 @@ class testcase(testbase.PersistTest):
         self.assertEquals(len(person.addresses), 2)
         self.assertEquals(person.addresses[0].postal_code, '30338')
 
-    @testbase.unsupported('mysql')
+    @testing.unsupported('mysql')
     def test_update(self):
         p1 = self.create_person_one()
         objectstore.flush()
@@ -261,7 +262,7 @@ class testcase(testbase.PersistTest):
 
         self.assertEquals(Person.query.count(), 2)
 
-class testmanytomany(testbase.PersistTest):
+class testmanytomany(PersistTest):
      def setUpAll(self):
          clear_mappers()
          objectstore.clear()
@@ -316,7 +317,7 @@ class testmanytomany(testbase.PersistTest):
          foo1.bazrel.append(baz1)
          assert (foo1.bazrel == [baz1])
         
-class testselfreferential(testbase.PersistTest):
+class testselfreferential(PersistTest):
     def setUpAll(self):
         clear_mappers()
         objectstore.clear()
index f3ef3d180d7bd90c340df5aa6aa460cb1eadf929..31b3dd576f4095f1b2d376cdd81bd6d9314c45d7 100644 (file)
@@ -1,12 +1,11 @@
-from testbase import PersistTest, AssertMixin
 import testbase
 
 from sqlalchemy import *
 from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
-
 from sqlalchemy.ext.assignmapper import assign_mapper
 from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
+from testlib import *
+
 
 class AssignMapperTest(PersistTest):
     def setUpAll(self):
index 3156e429237abc1e81c7523d014430704a55ea07..60362501e032a2f165cb7af6570c06d837ce6ec5 100644 (file)
@@ -1,14 +1,11 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest
 import testbase
+
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm.collections import collection
 from sqlalchemy.ext.associationproxy import *
-from testbase import Table, Column
+from testlib import *
 
-db = testbase.db
 
 class DictCollection(dict):
     @collection.appender
@@ -40,7 +37,7 @@ class _CollectionOperations(PersistTest):
     def setUp(self):
         collection_class = self.collection_class
 
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
     
         parents_table = Table('Parent', metadata,
                               Column('id', Integer, primary_key=True),
@@ -476,7 +473,7 @@ class CustomObjectTest(_CollectionOperations):
 
 class ScalarTest(PersistTest):
     def test_scalar_proxy(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
     
         parents_table = Table('Parent', metadata,
                               Column('id', Integer, primary_key=True),
@@ -592,7 +589,7 @@ class ScalarTest(PersistTest):
 
 class LazyLoadTest(PersistTest):
     def setUp(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
     
         parents_table = Table('Parent', metadata,
                               Column('id', Integer, primary_key=True),
index cf6ab038e4f04c6649f5e36d261a0656b5743558..d16e20da73a228e4e938e4981e6219168f709133 100644 (file)
@@ -1,13 +1,10 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest, sys, os
 import testbase
+
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.ext.orderinglist import *
-from testbase import Table, Column
+from testlib import *
 
-db = testbase.db
 metadata = None
 
 # order in whole steps 
@@ -54,7 +51,7 @@ class OrderingListTest(PersistTest):
 
         global metadata, slides_table, bullets_table, Slide, Bullet
 
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         slides_table = Table('test_Slides', metadata,
                              Column('id', Integer, primary_key=True),
                              Column('name', String))
index 4bb8b97b4bcf237c7035a0a6bb38e668d73e171d..a2b8994188808caaf82478ac16122cee642e69f8 100644 (file)
@@ -2,9 +2,9 @@ import testbase
 
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
-class AssociationTest(testbase.PersistTest):
+class AssociationTest(PersistTest):
     def setUpAll(self):
         global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation
         metadata = MetaData(testbase.db)
@@ -139,7 +139,7 @@ class AssociationTest(testbase.PersistTest):
         sess.flush()
         self.assert_(item_keywords.count().scalar() == 0)
 
-class AssociationTest2(testbase.PersistTest):
+class AssociationTest2(PersistTest):
     def setUpAll(self):
         global table_originals, table_people, table_isauthor, metadata, Originals, People, IsAuthor
         metadata = MetaData(testbase.db)
index 4187ad8def7f516af3ccfb203f97bf4aa48973aa..652186b8e6171bae5a98ac9cad92d1b8c48f5c60 100644 (file)
@@ -1,12 +1,11 @@
 """eager loading unittests derived from mailing list-reported problems and trac tickets."""
 
-from testbase import PersistTest, AssertMixin, ORMTest
 import testbase
+import random, datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
-import random, datetime
+from testlib import *
 
 class EagerTest(AssertMixin):
     def setUpAll(self):
@@ -204,7 +203,7 @@ class EagerTest2(AssertMixin):
         obj = session.query(Left).get_by(tag='tag1')
         print obj.middle.right[0]
 
-class EagerTest3(testbase.ORMTest):
+class EagerTest3(ORMTest):
     """test eager loading combined with nested SELECT statements, functions, and aggregates"""
     def define_tables(self, metadata):
         global datas, foo, stats
@@ -272,7 +271,7 @@ class EagerTest3(testbase.ORMTest):
         # algorithms and there are repeated 'somedata' values in the list)
         assert verify_result == arb_result
 
-class EagerTest4(testbase.ORMTest):
+class EagerTest4(ORMTest):
     def define_tables(self, metadata):
         global departments, employees
         departments = Table('departments', metadata,
@@ -324,7 +323,7 @@ class EagerTest4(testbase.ORMTest):
         assert q.count() == 2
         assert q[0] is d2
 
-class EagerTest5(testbase.ORMTest):
+class EagerTest5(ORMTest):
     """test the construction of AliasedClauses for the same eager load property but different 
     parent mappers, due to inheritance"""
     def define_tables(self, metadata):
@@ -572,8 +571,8 @@ class EagerTest7(ORMTest):
 
         i = ctx.current.query(Invoice).get(invoice_id)
 
-        self.echo(repr(c))
-        self.echo(repr(i.company))
+        print repr(c)
+        print repr(i.company)
         self.assert_(repr(c) == repr(i.company))
 
     def testtwo(self):
@@ -638,7 +637,7 @@ class EagerTest7(ORMTest):
         ctx.current.clear()
 
         a = ctx.current.query(Company).get(company_id)
-        self.echo(repr(a))
+        print repr(a)
 
         # set up an invoice
         i1 = Invoice()
@@ -667,7 +666,7 @@ class EagerTest7(ORMTest):
         ctx.current.clear()
 
         c = ctx.current.query(Company).get(company_id)
-        self.echo(repr(c))
+        print repr(c)
 
         ctx.current.clear()
 
@@ -675,7 +674,7 @@ class EagerTest7(ORMTest):
 
         assert repr(i.company) == repr(c), repr(i.company) +  " does not match " + repr(c)
 
-class EagerTest8(testbase.ORMTest):
+class EagerTest8(ORMTest):
     def define_tables(self, metadata):
         global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t
 
index e63860b8d74743e8159495b4b0e63ac6308dc71a..9b5f738bf7f2fa7bfe8722ec8e8bc0814dcaed5d 100644 (file)
@@ -1,11 +1,9 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
+import testbase
+import pickle
 import sqlalchemy.orm.attributes as attributes
 from sqlalchemy.orm.collections import collection
 from sqlalchemy import exceptions
-import unittest, sys, os
-import pickle
-import testbase
+from testlib import *
 
 class MyTest(object):pass
 class MyTest2(object):pass
index d43f069bcb79972aa197572927c175d69c5e1e15..b832c427e0c252782345c866fd8f7850f972cf03 100644 (file)
@@ -1,12 +1,12 @@
-import testbase, tables
-import unittest, sys, datetime
+import testbase
 
-from sqlalchemy.ext.sessioncontext import SessionContext
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from sqlalchemy.ext.sessioncontext import SessionContext
+from testlib import *
+import testlib.tables as tables
 
-class O2MCascadeTest(testbase.AssertMixin):
+class O2MCascadeTest(AssertMixin):
     def tearDown(self):
         tables.delete()
 
@@ -114,7 +114,7 @@ class O2MCascadeTest(testbase.AssertMixin):
         sess = create_session()
         l = sess.query(tables.User).select()
         for u in l:
-            self.echo( repr(u.orders))
+            print repr(u.orders)
         self.assert_result(l, data[0], *data[1:])
 
         ids = (l[0].user_id, l[2].user_id)
@@ -174,7 +174,7 @@ class O2MCascadeTest(testbase.AssertMixin):
         self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids)  &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0)
 
 
-class M2OCascadeTest(testbase.AssertMixin):
+class M2OCascadeTest(AssertMixin):
     def tearDown(self):
         ctx.current.clear()
         for t in metadata.table_iterator(reverse=True):
@@ -262,7 +262,7 @@ class M2OCascadeTest(testbase.AssertMixin):
     
         
 
-class M2MCascadeTest(testbase.AssertMixin):
+class M2MCascadeTest(AssertMixin):
     def setUpAll(self):
         global metadata, a, b, atob
         metadata = MetaData(testbase.db)
@@ -337,7 +337,7 @@ class M2MCascadeTest(testbase.AssertMixin):
         assert b.count().scalar() == 0
         assert a.count().scalar() == 0
 
-class UnsavedOrphansTest(testbase.ORMTest):
+class UnsavedOrphansTest(ORMTest):
     """tests regarding pending entities that are orphans"""
     
     def define_tables(self, metadata):
@@ -397,7 +397,7 @@ class UnsavedOrphansTest(testbase.ORMTest):
         assert a.address_id is None, "Error: address should not be persistent"
 
 
-class UnsavedOrphansTest2(testbase.ORMTest):
+class UnsavedOrphansTest2(ORMTest):
     """same test as UnsavedOrphans only three levels deep"""
 
     def define_tables(self, meta):
@@ -457,7 +457,7 @@ class UnsavedOrphansTest2(testbase.ORMTest):
         assert item.id is None
         assert attr.id is None
 
-class DoubleParentOrphanTest(testbase.AssertMixin):
+class DoubleParentOrphanTest(AssertMixin):
     """test orphan detection for an entity with two parent relations"""
     
     def setUpAll(self):
@@ -523,7 +523,7 @@ class DoubleParentOrphanTest(testbase.AssertMixin):
             assert False
         except exceptions.FlushError, e:
             assert True
-    
-            
+
+
 if __name__ == "__main__":
     testbase.main()        
index e5cd1e935fe2b34c2ddfecebcf41c8b7daae9610..1f4f6492818e8e53a3f0d6849633b03c0c4c1875 100644 (file)
@@ -7,7 +7,7 @@ import sqlalchemy.orm.collections as collections
 from sqlalchemy.orm.collections import collection
 from sqlalchemy import util
 from operator import and_
-
+from testlib import *
 
 class Canary(interfaces.AttributeExtension):
     def __init__(self):
@@ -48,7 +48,7 @@ def dictable_entity(a=None, b=None, c=None):
     return Entity(a or str(_id), b or 'value %s' % _id, c)
 
 
-class CollectionsTest(testbase.PersistTest):
+class CollectionsTest(PersistTest):
     def _test_adapter(self, typecallable, creator=entity_maker,
                       to_set=None):
         class Foo(object):
@@ -980,7 +980,7 @@ class CollectionsTest(testbase.PersistTest):
         obj.attr[0] = e3
         self.assert_(e3 in canary.data)
 
-class DictHelpersTest(testbase.ORMTest):
+class DictHelpersTest(ORMTest):
     def define_tables(self, metadata):
         global parents, children, Parent, Child
         
index ef5faa21ddbfa3c2b36f96ff8207bec3640401f6..23f04db856048087c3f4d893590abf28314ffbc8 100644 (file)
@@ -1,10 +1,10 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
 
-class CompileTest(testbase.AssertMixin):
+class CompileTest(AssertMixin):
     """test various mapper compilation scenarios"""
     def tearDownAll(self):
         clear_mappers()
index bdbb146e9f5193dfca42d95ac7e560f8ce451148..ce3065f777bb294e80bef1c1a71b8412fe6f72af 100644 (file)
@@ -1,13 +1,8 @@
-from testbase import PersistTest, AssertMixin, ORMTest
-import unittest, sys, os
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-import StringIO
-import testbase
-from testbase import Table, Column
-
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
 
 """test cyclical mapper relationships.  Many of the assertions are provided
 via running with postgres, which is strict about foreign keys.
@@ -538,7 +533,7 @@ class OneToManyManyToOneTest(AssertMixin):
         sess.save(b)
         sess.save(p)
         
-        self.assert_sql(db, lambda: sess.flush(), [
+        self.assert_sql(testbase.db, lambda: sess.flush(), [
             (
                 "INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)",
                 {'favorite_ball_id': None, 'data':'some data'}
@@ -592,7 +587,7 @@ class OneToManyManyToOneTest(AssertMixin):
                 )
             ])
         sess.delete(p)
-        self.assert_sql(db, lambda: sess.flush(), [
+        self.assert_sql(testbase.db, lambda: sess.flush(), [
             # heres the post update (which is a pre-update with deletes)
             (
                 "UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id",
@@ -642,7 +637,7 @@ class OneToManyManyToOneTest(AssertMixin):
         sess = create_session()
         [sess.save(x) for x in [b,p,b2,b3,b4]]
 
-        self.assert_sql(db, lambda: sess.flush(), [
+        self.assert_sql(testbase.db, lambda: sess.flush(), [
                 (
                     "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
                     {'person_id':None, 'data':'some data'}
@@ -721,7 +716,7 @@ class OneToManyManyToOneTest(AssertMixin):
         ])
 
         sess.delete(p)
-        self.assert_sql(db, lambda: sess.flush(), [
+        self.assert_sql(testbase.db, lambda: sess.flush(), [
             (
                 "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
                 lambda ctx:{'person_id': None, 'ball_id': b.id}
@@ -833,7 +828,7 @@ class SelfReferentialPostUpdateTest(AssertMixin):
         remove_child(root, cats)
         # pre-trigger lazy loader on 'cats' to make the test easier
         cats.children
-        self.assert_sql(db, lambda: session.flush(), [
+        self.assert_sql(testbase.db, lambda: session.flush(), [
             (
                 "UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id",
                 lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}
index 396c28bf94d437c2628af5f29ac46e831cc17d39..49ea65153be3aaa793d0c1bb22bcea64cb1cf4aa 100644 (file)
@@ -1,9 +1,9 @@
 """basic tests of eager loaded attributes"""
 
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-import testbase
-
+from testlib import *
 from fixtures import *
 from query import QueryTest
 
@@ -438,7 +438,7 @@ class EagerTest(QueryTest):
         l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id)
         assert fixtures.user_address_result[1:2] == l.all()
 
-class SelfReferentialEagerTest(testbase.ORMTest):
+class SelfReferentialEagerTest(ORMTest):
     def define_tables(self, metadata):
         global nodes
         nodes = Table('nodes', metadata,
index 813ed327cf67eddf3013dd165e5736f5b595b1f6..da76e8df0533f2a449d2a2878566993c978345c0 100644 (file)
@@ -1,13 +1,9 @@
-from testbase import PersistTest, AssertMixin
-import unittest
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
-
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
 
 class EntityTest(AssertMixin):
     """tests mappers that are constructed based on "entity names", which allows the same class
index 39b0da383f0eba4cafe8b8f6c3333476c4d74243..4a7d41459f74c086bd032ab728da0f78b6b07afe 100644 (file)
@@ -1,5 +1,6 @@
+import testbase
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
 
 _recursion_stack = util.Set()
 class Base(object):
index fea783d268d400de35cae90731e3c62142ae76d3..5106388d733e9376af41a65e03be1c040bdd3507 100644 (file)
@@ -1,11 +1,9 @@
-from testbase import PersistTest, AssertMixin, ORMTest
 import testbase
-import tables
-
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy import exceptions
-from testbase import Table, Column
+from testlib import *
+import testlib.tables as tables
 
 # TODO: these are more tests that should be updated to be part of test/orm/query.py
 
@@ -40,7 +38,7 @@ class GenerativeQueryTest(PersistTest):
         assert res.order_by([Foo.c.bar])[0].bar == 5
         assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
         
-    @testbase.unsupported('mssql')
+    @testing.unsupported('mssql')
     def test_slice(self):
         sess = create_session()
         query = sess.query(Foo)
@@ -54,7 +52,7 @@ class GenerativeQueryTest(PersistTest):
         assert list(query[-5:]) == orig[-5:]
         assert query[10:20][5] == orig[10:20][5]
 
-    @testbase.supported('mssql')
+    @testing.supported('mssql')
     def test_slice_mssql(self):
         sess = create_session()
         query = sess.query(Foo)
@@ -71,23 +69,23 @@ class GenerativeQueryTest(PersistTest):
         assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
         assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
 
-    @testbase.unsupported('mysql')
+    @testing.unsupported('mysql')
     def test_aggregate_1(self):
         # this one fails in mysql as the result comes back as a string
         query = create_session().query(Foo)
         assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
 
-    @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
+    @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_2(self):
         query = create_session().query(Foo)
         assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
 
-    @testbase.supported('postgres', 'mysql', 'firebird', 'mssql')
+    @testing.supported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_2_int(self):
         query = create_session().query(Foo)
         assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
 
-    @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
+    @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_3(self):
         query = create_session().query(Foo)
         assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5
index 7677311739f114b64948e30dd6c671eb78c47135..3b35b3713de302f050b8ea5549515467ab2c259b 100644 (file)
@@ -1,15 +1,14 @@
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-
 from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE
-import testbase
+from testlib import *
 
 def produce_test(parent, child, direction):
     """produce a testcase for A->B->C inheritance with a self-referential
     relationship between two of the classes, using either one-to-many or
     many-to-one."""
-    class ABCTest(testbase.ORMTest):
+    class ABCTest(ORMTest):
         def define_tables(self, meta):
             global ta, tb, tc
             ta = ["a", meta]
index bcf269a8766803958bb4d77746ca19c88814a6f8..c6cd43f439d46540491bf76f32f38f0a2cca8430 100644 (file)
@@ -1,10 +1,10 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
 
-class O2MTest(testbase.ORMTest):
+class O2MTest(ORMTest):
     """deals with inheritance and one-to-many relationships"""
     def define_tables(self, metadata):
         global foo, bar, blub
@@ -58,11 +58,11 @@ class O2MTest(testbase.ORMTest):
         sess.clear()
         l = sess.query(Blub).select()
         result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
-        self.echo(result)
+        print result
         self.assert_(compare == result)
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
-class AddPropTest(testbase.ORMTest):
+class AddPropTest(ORMTest):
     """testing that construction of inheriting mappers works regardless of when extra properties
     are added to the superclass mapper"""
     def define_tables(self, metadata):
@@ -107,7 +107,7 @@ class AddPropTest(testbase.ORMTest):
         p.contenttype = ContentType()
         # TODO: assertion ??
         
-class EagerLazyTest(testbase.ORMTest):
+class EagerLazyTest(ORMTest):
     """tests eager load/lazy load of child items off inheritance mappers, tests that
     LazyLoader constructs the right query condition."""
     def define_tables(self, metadata):
@@ -149,7 +149,7 @@ class EagerLazyTest(testbase.ORMTest):
         self.assert_(len(q.selectfirst().eager) == 1)
 
 
-class FlushTest(testbase.ORMTest):
+class FlushTest(ORMTest):
     """test dependency sorting among inheriting mappers"""
     def define_tables(self, metadata):
         global users, roles, user_roles, admins
@@ -238,7 +238,7 @@ class FlushTest(testbase.ORMTest):
         sess.flush()
         assert user_roles.count().scalar() == 1
 
-class DistinctPKTest(testbase.ORMTest):
+class DistinctPKTest(ORMTest):
     """test the construction of mapper.primary_key when an inheriting relationship
     joins on a column other than primary key column."""
     keep_data = True
index d0d03210bf482a08943e93b87a9017f0e5a551aa..167b25256da9bbcb38d06fd9646c0cffca546af6 100644 (file)
@@ -1,9 +1,9 @@
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-import testbase
-from testbase import Table, Column
+from testlib import *
 
-class ConcreteTest1(testbase.ORMTest):
+class ConcreteTest1(ORMTest):
     def define_tables(self, metadata):
         global managers_table, engineers_table
         managers_table = Table('managers', metadata, 
index 509216d3745d888b38d8f0e488c6a0428ffd22ac..a0bf2414852d83237eadc7cedb14542ec5a37a54 100644 (file)
@@ -1,7 +1,7 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
 
 class BaseObject(object):
@@ -68,7 +68,7 @@ class ClassifiedPage(MagazinePage):
     pass
 
 
-class MagazineTest(testbase.ORMTest):
+class MagazineTest(ORMTest):
     def define_tables(self, metadata):
         global publication_table, issue_table, location_table, location_name_table, magazine_table, \
         page_table, magazine_page_table, classified_page_table, page_size_table
index 97885c1b087483ff21c67a95f1ee2a44dd928af0..df00f39d0b4a0262c1c63ebbcff1ffe87dec7523 100644 (file)
@@ -1,10 +1,10 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
 
-class InheritTest(testbase.ORMTest):
+class InheritTest(ORMTest):
     """deals with inheritance and many-to-many relationships"""
     def define_tables(self, metadata):
         global principals
@@ -76,7 +76,7 @@ class InheritTest(testbase.ORMTest):
         sess.flush()
         # TODO: put an assertion
         
-class InheritTest2(testbase.ORMTest):
+class InheritTest2(ORMTest):
     """deals with inheritance and many-to-many relationships"""
     def define_tables(self, metadata):
         global foo, bar, foo_bar
@@ -148,7 +148,7 @@ class InheritTest2(testbase.ORMTest):
             {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])},
             )
 
-class InheritTest3(testbase.ORMTest):
+class InheritTest3(ORMTest):
     """deals with inheritance and many-to-many relationships"""
     def define_tables(self, metadata):
         global foo, bar, blub, bar_foo, blub_bar, blub_foo
@@ -203,7 +203,7 @@ class InheritTest3(testbase.ORMTest):
         compare = repr(b) + repr(b.foos)
         sess.clear()
         l = sess.query(Bar).select()
-        self.echo(repr(l[0]) + repr(l[0].foos))
+        print repr(l[0]) + repr(l[0].foos)
         self.assert_(repr(l[0]) + repr(l[0].foos) == compare)
     
     def testadvanced(self):    
@@ -243,11 +243,11 @@ class InheritTest3(testbase.ORMTest):
         sess.clear()
 
         l = sess.query(Blub).select()
-        self.echo(l)
+        print l
         self.assert_(repr(l[0]) == compare)
         sess.clear()
         x = sess.query(Blub).get_by(id=blubid)
-        self.echo(x)
+        print x
         self.assert_(repr(x) == compare)
         
         
index 7858689b1dc2a7bc07e5d6f62f2cc4765b6e96b8..7297002f526a12223565140086bd2a38dd69dfa9 100644 (file)
@@ -1,10 +1,10 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
 
-class PolymorphicCircularTest(testbase.ORMTest):
+class PolymorphicCircularTest(ORMTest):
     keep_mappers = True
     def define_tables(self, metadata):
         global Table1, Table1B, Table2, Table3,  Data
index a3395924f3d43ec76419d894627741668bbf51f0..3eb2e032f0442ca036a222a48b092aace867c623 100644 (file)
@@ -1,10 +1,10 @@
 """tests basic polymorphic mapper loading/saving, minimal relations"""
 
 import testbase
+import sets
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-import sets
+from testlib import *
 
 
 class Person(object):
@@ -35,7 +35,7 @@ class Company(object):
     def __repr__(self):
         return "Company %s" % self.name
 
-class PolymorphTest(testbase.ORMTest):
+class PolymorphTest(ORMTest):
     def define_tables(self, metadata):
         global companies, people, engineers, managers, boss
         
index d18539aa10adfa383b8560897c803fc0a4bf7423..a2f9c4a5f001e10bec52970bd40f5b47c439d4bd 100644 (file)
@@ -1,7 +1,7 @@
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-import testbase
+from testlib import *
 
 
 class AttrSettable(object):
@@ -11,7 +11,7 @@ class AttrSettable(object):
         return self.__class__.__name__ + "(%s)" % (hex(id(self)))
 
 
-class RelationTest1(testbase.ORMTest):
+class RelationTest1(ORMTest):
     """test self-referential relationships on polymorphic mappers"""
     def define_tables(self, metadata):
         global people, managers
@@ -91,7 +91,7 @@ class RelationTest1(testbase.ORMTest):
         print p, m, m.employee
         assert m.employee is p
             
-class RelationTest2(testbase.ORMTest):
+class RelationTest2(ORMTest):
     """test self-referential relationships on polymorphic mappers"""
     def define_tables(self, metadata):
         global people, managers, data
@@ -186,7 +186,7 @@ class RelationTest2(testbase.ORMTest):
         if usedata:
             assert m.data.data == 'ms data'
 
-class RelationTest3(testbase.ORMTest):
+class RelationTest3(ORMTest):
     """test self-referential relationships on polymorphic mappers"""
     def define_tables(self, metadata):
         global people, managers, data
@@ -289,7 +289,7 @@ for jointype in ["join1", "join2", "join3", "join4"]:
         setattr(RelationTest3, func.__name__, func)
             
         
-class RelationTest4(testbase.ORMTest):
+class RelationTest4(ORMTest):
     def define_tables(self, metadata):
         global people, engineers, managers, cars
         people = Table('people', metadata, 
@@ -405,7 +405,7 @@ class RelationTest4(testbase.ORMTest):
         c = s.join("employee").filter(Person.name=="E4")[0]
         assert c.car_id==car1.car_id
 
-class RelationTest5(testbase.ORMTest):
+class RelationTest5(ORMTest):
     def define_tables(self, metadata):
         global people, engineers, managers, cars
         people = Table('people', metadata, 
@@ -465,7 +465,7 @@ class RelationTest5(testbase.ORMTest):
         assert carlist[0].manager is None
         assert carlist[1].manager.person_id == car2.manager.person_id
 
-class RelationTest6(testbase.ORMTest):
+class RelationTest6(ORMTest):
     """test self-referential relationships on a single joined-table inheritance mapper"""
     def define_tables(self, metadata):
         global people, managers, data
@@ -508,7 +508,7 @@ class RelationTest6(testbase.ORMTest):
         m2 = sess.query(Manager).get(m2.person_id)
         assert m.colleague is m2
 
-class RelationTest7(testbase.ORMTest):
+class RelationTest7(ORMTest):
     def define_tables(self, metadata):
         global people, engineers, managers, cars, offroad_cars
         cars = Table('cars', metadata,
@@ -607,7 +607,7 @@ class RelationTest7(testbase.ORMTest):
         for p in r:
             assert p.car_id == p.car.car_id
     
-class GenerativeTest(testbase.AssertMixin):
+class GenerativeTest(AssertMixin):
     def setUpAll(self):
         #  cars---owned by---  people (abstract) --- has a --- status
         #   |                  ^    ^                            |
@@ -733,7 +733,7 @@ class GenerativeTest(testbase.AssertMixin):
             r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id))
             assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
         
-class MultiLevelTest(testbase.ORMTest):
+class MultiLevelTest(ORMTest):
     def define_tables(self, metadata):
         global table_Employee, table_Engineer, table_Manager
         table_Employee = Table( 'Employee', metadata,
@@ -810,7 +810,7 @@ class MultiLevelTest(testbase.ORMTest):
         assert set(session.query( Engineer).select()) == set([b,c])
         assert session.query( Manager).select() == [c]
 
-class ManyToManyPolyTest(testbase.ORMTest):
+class ManyToManyPolyTest(ORMTest):
     def define_tables(self, metadata):
         global base_item_table, item_table, base_item_collection_table, collection_table
         base_item_table = Table(
@@ -860,7 +860,7 @@ class ManyToManyPolyTest(testbase.ORMTest):
         
         class_mapper(BaseItem)
 
-class CustomPKTest(testbase.ORMTest):
+class CustomPKTest(ORMTest):
     def define_tables(self, metadata):
         global t1, t2
         t1 = Table('t1', metadata, 
index aff89f5b77c65cdc4b59c5bf3d80ef4f79e9d937..2459cd36e1f58fb2dfb3dedb416d4673968e0e47 100644 (file)
@@ -1,11 +1,11 @@
 import testbase
+from datetime import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 
-from datetime import datetime
 
-class InheritTest(testbase.ORMTest):
+class InheritTest(ORMTest):
     """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
     def define_tables(self, metadata):
         global products_table, specification_table, documents_table
index 1432887e00a7e4ce4bf8ea55607ffe3e53faec93..68fe821af0dd7f342e1105d25f638b6b9bfe1d0e 100644 (file)
@@ -1,10 +1,10 @@
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-import testbase
+from testlib import *
 
 
-class SingleInheritanceTest(testbase.AssertMixin):
+class SingleInheritanceTest(AssertMixin):
     def setUpAll(self):
         metadata = MetaData(testbase.db)
         global employees_table
index e9d77e09c6df6f346f450dd79ce12f0f141b8280..6684c628815d3a8d256cf9a7a966dcb4b27e8912 100644 (file)
@@ -1,9 +1,9 @@
 """basic tests of lazy loaded attributes"""
 
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-import testbase
-
+from testlib import *
 from fixtures import *
 from query import QueryTest
 
index 83694d3961849d5dfb6d5f2594296723faf87915..b5296120b37899535c9815426f50842a8d4d2d66 100644 (file)
@@ -1,10 +1,7 @@
-from testbase import PersistTest, AssertMixin
 import testbase
-import unittest, sys, os
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-import datetime
+from testlib import *
 
 class LazyTest(AssertMixin):
     def setUpAll(self):
index 3cd680fd2a19dc9e7f233ff7eb88a55b5b7bf628..8b310f86c5e8eccd98b79246291cea5a614969e7 100644 (file)
@@ -1,8 +1,8 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-import string
+from testlib import *
+
 
 class Place(object):
     '''represents a place'''
@@ -27,7 +27,7 @@ class Transition(object):
     def __repr__(self):
         return object.__repr__(self)+ " " + repr(self.inputs) + " " + repr(self.outputs)
         
-class M2MTest(testbase.ORMTest):
+class M2MTest(ORMTest):
     def define_tables(self, metadata):
         global place
         place = Table('place', metadata,
@@ -112,7 +112,7 @@ class M2MTest(testbase.ORMTest):
 
         for p in l:
             pp = p.places
-            self.echo("Place " + str(p) +" places " + repr(pp))
+            print "Place " + str(p) +" places " + repr(pp)
 
         [sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7]
         sess.flush()
@@ -178,7 +178,7 @@ class M2MTest(testbase.ORMTest):
         self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
         self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
 
-class M2MTest2(testbase.ORMTest):
+class M2MTest2(ORMTest):
     def define_tables(self, metadata):
         global studentTbl
         studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True))
@@ -245,7 +245,7 @@ class M2MTest2(testbase.ORMTest):
         sess.flush()
         assert enrolTbl.count().scalar() == 0
         
-class M2MTest3(testbase.ORMTest):
+class M2MTest3(ORMTest):
     def define_tables(self, metadata):
         global c, c2a1, c2a2, b, a
         c = Table('c', metadata, 
index 5fc5e15a5ad15f0bd2c25225a6fe9c897d6bdd7f..f9c3aac841919fbb9237bbbaa2097cad6cf763c6 100644 (file)
@@ -1,14 +1,14 @@
 """tests general mapper operations with an emphasis on selecting/loading"""
 
-from testbase import PersistTest, AssertMixin, ORMTest
 import testbase
-import unittest, sys, os
 from sqlalchemy import *
 from sqlalchemy.orm import *
 import sqlalchemy.exceptions as exceptions
 from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
+
 
 class MapperSuperTest(AssertMixin):
     def setUpAll(self):
@@ -168,7 +168,7 @@ class MapperTest(MapperSuperTest):
         u = q2.selectfirst(users.c.user_id==8)
         def go():
             s.refresh(u)
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
 
     def testexpire(self):
         """test the expire function"""
@@ -300,7 +300,7 @@ class MapperTest(MapperSuperTest):
 #        l = create_session().query(User).select(order_by=None)
         
         
-    @testbase.unsupported('firebird') 
+    @testing.unsupported('firebird') 
     def testfunction(self):
         """test mapping to a SELECT statement that has functions in it."""
         s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
@@ -313,7 +313,7 @@ class MapperTest(MapperSuperTest):
         assert l[0].concat == l[0].user_id * 2 == 14
         assert l[1].concat == l[1].user_id * 2 == 16
 
-    @testbase.unsupported('firebird') 
+    @testing.unsupported('firebird') 
     def testcount(self):
         """test the count function on Query.
         
@@ -390,7 +390,7 @@ class MapperTest(MapperSuperTest):
         def go():
             u = sess.query(User).options(eagerload('adlist')).get_by(user_name='jack')
             self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
         
     def testextensionoptions(self):
         sess  = create_session()
@@ -429,7 +429,7 @@ class MapperTest(MapperSuperTest):
 
         def go():
             self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
 
     def testeageroptionswithlimit(self):
         sess = create_session()
@@ -441,7 +441,7 @@ class MapperTest(MapperSuperTest):
         def go():
             assert u.user_id == 8
             assert len(u.addresses) == 3
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
 
         sess.clear()
         
@@ -450,7 +450,7 @@ class MapperTest(MapperSuperTest):
             u = sess.query(User).get_by(user_id=8)
             assert u.user_id == 8
             assert len(u.addresses) == 3
-        assert "tbl_row_count" not in self.capture_sql(db, go)
+        assert "tbl_row_count" not in self.capture_sql(testbase.db, go)
         
     def testlazyoptionswithlimit(self):
         sess = create_session()
@@ -462,7 +462,7 @@ class MapperTest(MapperSuperTest):
         def go():
             assert u.user_id == 8
             assert len(u.addresses) == 3
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
 
     def testeagerdegrade(self):
         """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available"""
@@ -475,7 +475,7 @@ class MapperTest(MapperSuperTest):
         def go():
             l = sess.query(usermapper).select()
             self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
 
         sess.clear()
         
@@ -486,7 +486,7 @@ class MapperTest(MapperSuperTest):
             r = users.select().execute()
             l = usermapper.instances(r, sess)
             self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(db, go, 4)
+        self.assert_sql_count(testbase.db, go, 4)
         
         clear_mappers()
 
@@ -513,7 +513,7 @@ class MapperTest(MapperSuperTest):
         def go():
             l = sess.query(usermapper).select()
             self.assert_result(l, User, *user_all_result)
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
 
         sess.clear()
         
@@ -523,7 +523,7 @@ class MapperTest(MapperSuperTest):
             r = users.select().execute()
             l = usermapper.instances(r, sess)
             self.assert_result(l, User, *user_all_result)
-        self.assert_sql_count(db, go, 7)
+        self.assert_sql_count(testbase.db, go, 7)
         
         
     def testlazyoptions(self):
@@ -535,7 +535,7 @@ class MapperTest(MapperSuperTest):
         l = sess.query(User).options(lazyload('addresses')).select()
         def go():
             self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(db, go, 3)
+        self.assert_sql_count(testbase.db, go, 3)
 
     def testlatecompile(self):
         """tests mappers compiling late in the game"""
@@ -549,7 +549,7 @@ class MapperTest(MapperSuperTest):
         u = sess.query(User).select()
         def go():
             print u[0].orders[1].items[0].keywords[1]
-        self.assert_sql_count(db, go, 3)
+        self.assert_sql_count(testbase.db, go, 3)
 
     def testdeepoptions(self):
         mapper(User, users,
@@ -567,7 +567,7 @@ class MapperTest(MapperSuperTest):
         u = sess.query(User).select()
         def go():
             print u[0].orders[1].items[0].keywords[1]
-        self.assert_sql_count(db, go, 3)
+        self.assert_sql_count(testbase.db, go, 3)
         sess.clear()
         
         
@@ -578,7 +578,7 @@ class MapperTest(MapperSuperTest):
         def go():
             print u[0].orders[1].items[0].keywords[1]
         print "-------MARK2----------"
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
 
         sess.clear()
 
@@ -588,7 +588,7 @@ class MapperTest(MapperSuperTest):
         def go():
             print u[0].orders[1].items[0].keywords[1]
         print "-------MARK3----------"
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
         print "-------MARK4----------"
 
         sess.clear()
@@ -598,7 +598,7 @@ class MapperTest(MapperSuperTest):
         print "-------MARK5----------"
         q3 = sess.query(User).options(eagerload('orders.items.keywords'))
         u = q3.select()
-        self.assert_sql_count(db, go, 2)
+        self.assert_sql_count(testbase.db, go, 2)
             
     
 class DeferredTest(MapperSuperTest):
@@ -619,8 +619,8 @@ class DeferredTest(MapperSuperTest):
             o2 = l[2]
             print o2.description
 
-        orderby = str(orders.default_order_by()[0].compile(bind=db))
-        self.assert_sql(db, go, [
+        orderby = str(orders.default_order_by()[0].compile(bind=testbase.db))
+        self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
             ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
         ])
@@ -682,8 +682,8 @@ class DeferredTest(MapperSuperTest):
             assert o2.opened == 1
             assert o2.userident == 7
             assert o2.description == 'order 3'
-        orderby = str(orders.default_order_by()[0].compile(db))
-        self.assert_sql(db, go, [
+        orderby = str(orders.default_order_by()[0].compile(testbase.db))
+        self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
             ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
         ])
@@ -695,7 +695,7 @@ class DeferredTest(MapperSuperTest):
         o2.description = 'order 3'
         def go():
             sess.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
     
     def testcommitsstate(self):
         """test that when deferred elements are loaded via a group, they get the proper CommittedState
@@ -717,7 +717,7 @@ class DeferredTest(MapperSuperTest):
         def go():
             # therefore the flush() shouldnt actually issue any SQL
             sess.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
             
     def testoptions(self):
         """tests using options on a mapper to create deferred and undeferred columns"""
@@ -729,8 +729,8 @@ class DeferredTest(MapperSuperTest):
             l = q2.select()
             print l[2].user_id
             
-        orderby = str(orders.default_order_by()[0].compile(db))
-        self.assert_sql(db, go, [
+        orderby = str(orders.default_order_by()[0].compile(testbase.db))
+        self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
             ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
         ])
@@ -739,7 +739,7 @@ class DeferredTest(MapperSuperTest):
         def go():
             l = q3.select()
             print l[3].user_id
-        self.assert_sql(db, go, [
+        self.assert_sql(testbase.db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
         ])
 
@@ -759,8 +759,8 @@ class DeferredTest(MapperSuperTest):
             assert o2.opened == 1
             assert o2.userident == 7
             assert o2.description == 'order 3'
-        orderby = str(orders.default_order_by()[0].compile(db))
-        self.assert_sql(db, go, [
+        orderby = str(orders.default_order_by()[0].compile(testbase.db))
+        self.assert_sql(testbase.db, go, [
             ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen, orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
         ])
 
@@ -779,7 +779,7 @@ class DeferredTest(MapperSuperTest):
         item = l[0].orders[1].items[1]
         def go():
             print item.item_name
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
         self.assert_(item.item_name == 'item 4')
         sess.clear()
         q2 = q.options(undefer('orders.items.item_name'))
@@ -787,7 +787,7 @@ class DeferredTest(MapperSuperTest):
         item = l[0].orders[1].items[1]
         def go():
             print item.item_name
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
         self.assert_(item.item_name == 'item 4')
 
 class CompositeTypesTest(ORMTest):
index 5834974ec539ed1a013434d293b2398f7fe88f2b..26da7c010d6f29d828a494a1f3ac03def0539869 100644 (file)
@@ -1,23 +1,18 @@
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from sqlalchemy.orm import mapperlib, session, unitofwork, attributes
-Mapper = mapperlib.Mapper
-import gc
 import testbase
-from testbase import Table, Column
-import tables
+import gc
+from sqlalchemy import MetaData, Integer, String, ForeignKey
+from sqlalchemy.orm import mapper, relation, clear_mappers, create_session
+from sqlalchemy.orm.mapper import Mapper
+from testlib import *
 
 class A(object):pass
 class B(object):pass
 
-class MapperCleanoutTest(testbase.AssertMixin):
+class MapperCleanoutTest(AssertMixin):
     """test that clear_mappers() removes everything related to the class.
     
     does not include classes that use the assignmapper extension."""
-    def setUp(self):
-        global engine
-        engine = testbase.db
-    
+
     def test_mapper_cleanup(self):
         for x in range(0, 5):
             self.do_test()
@@ -35,7 +30,7 @@ class MapperCleanoutTest(testbase.AssertMixin):
         assert True
         
     def do_test(self):
-        metadata = MetaData(engine)
+        metadata = MetaData(testbase.db)
 
         table1 = Table("mytable", metadata, 
             Column('col1', Integer, primary_key=True),
index 66e21ed8d16b491f4c4ca843bd715baf26e9a9ca..3dd0a95a47fd5e5d9dd8f3c9d390c2a36390b671 100644 (file)
@@ -1,10 +1,9 @@
-from testbase import PersistTest, AssertMixin
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
 
 class MergeTest(AssertMixin):
     """tests session.merge() functionality"""
@@ -166,5 +165,3 @@ class MergeTest(AssertMixin):
         
 if __name__ == "__main__":    
     testbase.main()
-
-                
index ada4c98cdf44d06479d28b4a67b6307d02986e08..e41fa1d20f1d0a8f33a375d5b5c1f1b9fddd7887 100644 (file)
@@ -2,7 +2,7 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
+from testlib import *
 
 class Jack(object):
     def __repr__(self):
@@ -24,7 +24,7 @@ class Port(object):
         self.name=name
         self.description = description
 
-class O2OTest(testbase.AssertMixin):
+class O2OTest(AssertMixin):
     def setUpAll(self):
         global jack, port, metadata, ctx
         metadata = MetaData(testbase.db)
index a32a1439e65ec57aadd83f7e6cf80edcf58c57ee..9d516f3780b51237a2f5b809f6330c30bac7531f 100644 (file)
@@ -1,10 +1,10 @@
 import testbase
+import operator
 from sqlalchemy import *
 from sqlalchemy import ansisql
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 from fixtures import *
-import operator
 
 class Base(object):
     def __init__(self, **kwargs):
@@ -40,7 +40,7 @@ class Base(object):
         else:
             return True
 
-class QueryTest(testbase.ORMTest):
+class QueryTest(ORMTest):
     keep_mappers = True
     keep_data = True
     
@@ -703,7 +703,7 @@ class CustomJoinTest(QueryTest):
 
         assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all()
 
-class SelfReferentialJoinTest(testbase.ORMTest):
+class SelfReferentialJoinTest(ORMTest):
     def define_tables(self, metadata):
         global nodes
         nodes = Table('nodes', metadata,
index 2d32c8ac374b87e2e181fa65312be0a84fe6c58d..9fca22b2446e360f6e056ed7e34bed7ec0486f44 100644 (file)
@@ -1,14 +1,12 @@
 import testbase
-import unittest, sys, datetime
+import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm import collections
 from sqlalchemy.orm.collections import collection
-from testbase import Table, Column
+from testlib import *
 
-db = testbase.db
-
-class RelationTest(testbase.PersistTest):
+class RelationTest(PersistTest):
     """this is essentially an extension of the "dependency.py" topological sort test.  
     in this test, a table is dependent on two other tables that are otherwise unrelated to each other.
     the dependency sort must insure that this childmost table is below both parent tables in the outcome
@@ -17,10 +15,8 @@ class RelationTest(testbase.PersistTest):
     to subtle differences in program execution, this test case was exposing the bug whereas the simpler tests
     were not."""
     def setUpAll(self):
-        global tbl_a
-        global tbl_b
-        global tbl_c
-        global tbl_d
+        global metadata, tbl_a, tbl_b, tbl_c, tbl_d
+
         metadata = MetaData()
         tbl_a = Table("tbl_a", metadata,
             Column("id", Integer, primary_key=True),
@@ -89,7 +85,7 @@ class RelationTest(testbase.PersistTest):
         conn.drop(tbl_a)
 
     def tearDownAll(self):
-        testbase.metadata.tables.clear()
+        metadata.drop_all(testbase.db)
     
     def testDeleteRootTable(self):
         session.flush()
@@ -101,7 +97,7 @@ class RelationTest(testbase.PersistTest):
         session.delete(c) # fails
         session.flush()
         
-class RelationTest2(testbase.PersistTest):
+class RelationTest2(PersistTest):
     """this test tests a relationship on a column that is included in multiple foreign keys,
     as well as a self-referential relationship on a composite key where one column in the foreign key
     is 'joined to itself'."""
@@ -218,7 +214,7 @@ class RelationTest2(testbase.PersistTest):
         assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
         assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
         
-class RelationTest3(testbase.PersistTest):
+class RelationTest3(PersistTest):
     def setUpAll(self):
         global jobs, pageversions, pages, metadata, Job, Page, PageVersion, PageComment
         import datetime
@@ -352,7 +348,7 @@ class RelationTest3(testbase.PersistTest):
         s.delete(j)
         s.flush()
 
-class RelationTest4(testbase.ORMTest):
+class RelationTest4(ORMTest):
     """test syncrules on foreign keys that are also primary"""
     def define_tables(self, metadata):
         global tableA, tableB
@@ -500,7 +496,7 @@ class RelationTest4(testbase.ORMTest):
         assert a1 not in sess
         assert b1 not in sess
 
-class RelationTest5(testbase.ORMTest):
+class RelationTest5(ORMTest):
     """test a map to a select that relates to a map to the table"""
     def define_tables(self, metadata):
         global items
@@ -556,7 +552,7 @@ class RelationTest5(testbase.ORMTest):
             assert old.id == new.id
         
         
-class TypeMatchTest(testbase.ORMTest):
+class TypeMatchTest(ORMTest):
     """test errors raised when trying to add items whose type is not handled by a relation"""
     def define_tables(self, metadata):
         global a, b, c, d
@@ -674,7 +670,7 @@ class TypeMatchTest(testbase.ORMTest):
         except exceptions.AssertionError, err:
             assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
 
-class TypedAssociationTable(testbase.ORMTest):
+class TypedAssociationTable(ORMTest):
     def define_tables(self, metadata):
         global t1, t2, t3
         
@@ -724,7 +720,7 @@ class TypedAssociationTable(testbase.ORMTest):
         assert t3.count().scalar() == 1
         
 # TODO: move these tests to either attributes.py test or its own module
-class CustomCollectionsTest(testbase.ORMTest):
+class CustomCollectionsTest(ORMTest):
     def define_tables(self, metadata):
         global sometable, someothertable
         sometable = Table('sometable', metadata,
@@ -1005,7 +1001,7 @@ class CustomCollectionsTest(testbase.ORMTest):
         o = list(p2.children)
         assert len(o) == 3
 
-class ViewOnlyTest(testbase.ORMTest):
+class ViewOnlyTest(ORMTest):
     """test a view_only mapping where a third table is pulled into the primary join condition,
     using overlapping PK column names (should not produce "conflicting column" error)"""
     def define_tables(self, metadata):
@@ -1056,7 +1052,7 @@ class ViewOnlyTest(testbase.ORMTest):
         assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id])
         assert set([x.id for x in c1.t2_view]) == set([c2b.id])
 
-class ViewOnlyTest2(testbase.ORMTest):
+class ViewOnlyTest2(ORMTest):
     """test a view_only mapping where a third table is pulled into the primary join condition,
     using non-overlapping PK column names (should not produce "mapper has no column X" error)"""
     def define_tables(self, metadata):
index eca48f836da9702935aee5972b8e79ba2b50a7d5..43327967378e27d1a3911c2c86aca0d92f7de556 100644 (file)
@@ -1,14 +1,9 @@
-from testbase import AssertMixin
 import testbase
-import unittest, sys, datetime
-
-import tables
-from tables import *
-
-db = testbase.db
 from sqlalchemy import *
 from sqlalchemy.orm import *
-
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
 
 class SessionTest(AssertMixin):
     def setUpAll(self):
@@ -78,7 +73,7 @@ class SessionTest(AssertMixin):
         # then see if expunge fails
         session.expunge(u)
     
-    @testbase.unsupported('sqlite')    
+    @testing.unsupported('sqlite')    
     def test_transaction(self):
         class User(object):pass
         mapper(User, users)
@@ -95,7 +90,7 @@ class SessionTest(AssertMixin):
         assert conn1.execute("select count(1) from users").scalar() == 1
         assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
     
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def test_autoflush(self):
         class User(object):pass
         mapper(User, users)
@@ -114,7 +109,7 @@ class SessionTest(AssertMixin):
         assert conn1.execute("select count(1) from users").scalar() == 1
         assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
 
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def test_autoflush_unbound(self):
         class User(object):pass
         mapper(User, users)
@@ -163,7 +158,7 @@ class SessionTest(AssertMixin):
         trans.rollback() # rolls back
         assert len(sess.query(User).select()) == 0
 
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def test_external_nested_transaction(self):
         class User(object):pass
         mapper(User, users)
@@ -187,7 +182,7 @@ class SessionTest(AssertMixin):
             conn.close()
             raise
     
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def test_twophase(self):
         # TODO: mock up a failure condition here
         # to ensure a rollback succeeds
@@ -226,7 +221,7 @@ class SessionTest(AssertMixin):
         sess.rollback() # rolls back
         assert len(sess.query(User).select()) == 0
 
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def test_nested_transaction(self):
         class User(object):pass
         mapper(User, users)
@@ -248,7 +243,7 @@ class SessionTest(AssertMixin):
         sess.commit()
         assert len(sess.query(User).select()) == 1
 
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def test_nested_autotrans(self):
         class User(object):pass
         mapper(User, users)
index f6cd8f9f482a2c585ea6ca1482900c50b42ebdd2..7a60b47c7e390323f2b78dc2a7220ea6bb8b87a2 100644 (file)
@@ -1,12 +1,10 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.orm.session import object_session, Session
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.session import object_session, Session
+from testlib import *
 
-import testbase
-from testbase import Table, Column
 
 metadata = MetaData()
 users = Table('users', metadata,
index b7b584794301ada7b94d26758caf10836074aaaf..eb39549eeacd9e6d8aaec56502f81c865908e894 100644 (file)
@@ -1,15 +1,14 @@
-from testbase import PersistTest, AssertMixin
-from sqlalchemy import *
-from sqlalchemy.orm import *
 import testbase
-from testbase import Table, Column
 import pickleable
+from sqlalchemy import *
+from sqlalchemy.orm import *
 from sqlalchemy.orm.mapper import global_extensions
 from sqlalchemy.orm import util as ormutil
 from sqlalchemy.ext.sessioncontext import SessionContext
 import sqlalchemy.ext.assignmapper as assignmapper
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+from testlib import tables
 
 """tests unitofwork operations"""
 
@@ -28,6 +27,7 @@ class UnitOfWorkTest(AssertMixin):
 
 class HistoryTest(UnitOfWorkTest):
     def setUpAll(self):
+        tables.metadata.bind = testbase.db
         UnitOfWorkTest.setUpAll(self)
         users.create()
         addresses.create()
@@ -63,7 +63,7 @@ class VersioningTest(UnitOfWorkTest):
         UnitOfWorkTest.setUpAll(self)
         ctx.current.clear()
         global version_table
-        version_table = Table('version_test', MetaData(db),
+        version_table = Table('version_test', MetaData(testbase.db),
         Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
         Column('version_id', Integer, nullable=False),
         Column('value', String(40), nullable=False)
@@ -255,9 +255,9 @@ class MutableTypesTest(UnitOfWorkTest):
         ctx.current.flush()
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
         f1.value = unicode('someothervalue')
-        self.assert_sql(db, lambda: ctx.current.flush(), [
+        self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
             (
                 "UPDATE mutabletest SET value=:value WHERE mutabletest.id = :mutabletest_id",
                 {'mutabletest_id': f1.id, 'value': u'someothervalue'}
@@ -265,7 +265,7 @@ class MutableTypesTest(UnitOfWorkTest):
         ])
         f1.value = unicode('hi')
         f1.data.x = 9
-        self.assert_sql(db, lambda: ctx.current.flush(), [
+        self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
             (
                 "UPDATE mutabletest SET data=:data, value=:value WHERE mutabletest.id = :mutabletest_id",
                 {'mutabletest_id': f1.id, 'value': u'hi', 'data':f1.data}
@@ -283,7 +283,7 @@ class MutableTypesTest(UnitOfWorkTest):
         
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
         
         ctx.current.clear()
 
@@ -291,12 +291,12 @@ class MutableTypesTest(UnitOfWorkTest):
 
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
 
         f2.data.y = 19
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 1)
+        self.assert_sql_count(testbase.db, go, 1)
         
         ctx.current.clear()
         f3 = ctx.current.query(Foo).get_by(id=f1.id)
@@ -305,7 +305,7 @@ class MutableTypesTest(UnitOfWorkTest):
 
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
         
     def testunicode(self):
         """test that two equivalent unicode values dont get flagged as changed.
@@ -322,14 +322,14 @@ class MutableTypesTest(UnitOfWorkTest):
         f1.value = u'hi'
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
         
         
 class PKTest(UnitOfWorkTest):
     def setUpAll(self):
         UnitOfWorkTest.setUpAll(self)
         global table, table2, table3, metadata
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         table = Table(
             'multipk', metadata, 
             Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True),
@@ -357,7 +357,7 @@ class PKTest(UnitOfWorkTest):
         
     # not support on sqlite since sqlite's auto-pk generation only works with
     # single column primary keys    
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testprimarykey(self):
         class Entry(object):
             pass
@@ -479,7 +479,7 @@ class PassiveDeletesTest(UnitOfWorkTest):
         metadata.drop_all()
         UnitOfWorkTest.tearDownAll(self)
 
-    @testbase.unsupported('sqlite')
+    @testing.unsupported('sqlite')
     def testbasic(self):
         class MyClass(object):
             pass
@@ -516,6 +516,7 @@ class DefaultTest(UnitOfWorkTest):
     defaults back from the engine."""
     def setUpAll(self):
         UnitOfWorkTest.setUpAll(self)
+        db = testbase.db
         use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
 
         if use_string_defaults:
@@ -610,7 +611,7 @@ class OneToManyTest(UnitOfWorkTest):
         a2 = Address()
         a2.email_address = 'lala@test.org'
         u.addresses.append(a2)
-        self.echo( repr(u.addresses))
+        print repr(u.addresses)
         ctx.current.flush()
 
         usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall()
@@ -660,7 +661,7 @@ class OneToManyTest(UnitOfWorkTest):
         u2.user_name = 'user2modified'
         u1.addresses.append(a3)
         del u1.addresses[0]
-        self.assert_sql(db, lambda: ctx.current.flush(), 
+        self.assert_sql(testbase.db, lambda: ctx.current.flush(), 
                 [
                     (
                         "UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id",
@@ -832,7 +833,7 @@ class SaveTest(UnitOfWorkTest):
 
         # assert the first one retreives the same from the identity map
         nu = ctx.current.get(m, u.user_id)
-        self.echo( "U: " + repr(u) + "NU: " + repr(nu))
+        print "U: " + repr(u) + "NU: " + repr(nu)
         self.assert_(u is nu)
         
         # clear out the identity map, so next get forces a SELECT
@@ -913,7 +914,7 @@ class SaveTest(UnitOfWorkTest):
         u.user_name = ""
         def go():
             ctx.current.flush()
-        self.assert_sql_count(db, go, 0)
+        self.assert_sql_count(testbase.db, go, 0)
 
     def testmultitable(self):
         """tests a save of an object where each instance spans two tables. also tests
@@ -1039,7 +1040,7 @@ class ManyToOneTest(UnitOfWorkTest):
         objects[2].email_address = 'imnew@foo.bar'
         objects[3].user = User()
         objects[3].user.user_name = 'imnewlyadded'
-        self.assert_sql(db, lambda: ctx.current.flush(), [
+        self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
                 (
                     "INSERT INTO users (user_name) VALUES (:user_name)",
                     {'user_name': 'imnewlyadded'}
@@ -1209,7 +1210,7 @@ class ManyToManyTest(UnitOfWorkTest):
         k = Keyword()
         k.name = 'yellow'
         objects[5].keywords.append(k)
-        self.assert_sql(db, lambda:ctx.current.flush(), [
+        self.assert_sql(testbase.db, lambda:ctx.current.flush(), [
             {
                 "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
                 {'item_name': 'item4updated', 'items_item_id': objects[4].item_id}
@@ -1238,7 +1239,7 @@ class ManyToManyTest(UnitOfWorkTest):
         objects[2].keywords.append(k)
         dkid = objects[5].keywords[1].keyword_id
         del objects[5].keywords[1]
-        self.assert_sql(db, lambda:ctx.current.flush(), [
+        self.assert_sql(testbase.db, lambda:ctx.current.flush(), [
                 (
                     "DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id",
                     [{'item_id': objects[5].item_id, 'keyword_id': dkid}]
@@ -1423,7 +1424,7 @@ class SaveTest2(UnitOfWorkTest):
         ctx.current.clear()
         clear_mappers()
         global meta, users, addresses
-        meta = MetaData(db)
+        meta = MetaData(testbase.db)
         users = Table('users', meta,
             Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
             Column('user_name', String(20)),
@@ -1455,7 +1456,7 @@ class SaveTest2(UnitOfWorkTest):
             a.user = User()
             a.user.user_name = elem['user_name']
             objects.append(a)
-        self.assert_sql(db, lambda: ctx.current.flush(), [
+        self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
                 (
                     "INSERT INTO users (user_name) VALUES (:user_name)",
                     {'user_name': 'thesub'}
@@ -1494,30 +1495,32 @@ class SaveTest2(UnitOfWorkTest):
                         ]
         )
 
-class SaveTest3(UnitOfWorkTest):
 
+class SaveTest3(UnitOfWorkTest):
     def setUpAll(self):
+        global st3_metadata, t1, t2, t3
+
         UnitOfWorkTest.setUpAll(self)
-        global metadata, t1, t2, t3
-        metadata = testbase.metadata
-        t1 = Table('items', metadata,
+
+        st3_metadata = MetaData(testbase.db)
+        t1 = Table('items', st3_metadata,
             Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
             Column('item_name', VARCHAR(50)),
         )
 
-        t3 = Table('keywords', metadata,
+        t3 = Table('keywords', st3_metadata,
             Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True),
             Column('name', VARCHAR(50)),
 
         )
-        t2 = Table('assoc', metadata,
+        t2 = Table('assoc', st3_metadata,
             Column('item_id', INT, ForeignKey("items")),
             Column('keyword_id', INT, ForeignKey("keywords")),
             Column('foo', Boolean, default=True)
         )
-        metadata.create_all()
+        st3_metadata.create_all()
     def tearDownAll(self):
-        metadata.drop_all()
+        st3_metadata.drop_all()
         UnitOfWorkTest.tearDownAll(self)
 
     def setUp(self):
index dd095ab9aa1c8963705f58ecaf3563114971cc05..34d046381fe2be20c5eaee96d1768533f26c7567 100644 (file)
@@ -1,7 +1,7 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
 from timeit import Timer
 import sys
 
index e603e2c00258a0315c6ec12073d137c7e2c801f8..346a725e3544db395158f2dfa2eb59a08454fab8 100644 (file)
@@ -1,8 +1,7 @@
 # times how long it takes to create 26000 objects
-import sys
-sys.path.insert(0, './lib/')
+import testbase
 
-from sqlalchemy.attributes import *
+from sqlalchemy.orm.attributes import *
 import time
 import gc
 
index 3a68f3612d62b1cacd8733951cc709d286008918..2e29a63272f97b143fc83a8e38b61c4b0185284d 100644 (file)
@@ -1,11 +1,9 @@
-import sys
-sys.path.insert(0, './lib/')
-
+import testbase
 import gc
 
 import random, string
 
-from sqlalchemy.attributes import *
+from sqlalchemy.orm.attributes import *
 
 # with this test, run top.  make sure the Python process doenst grow in size arbitrarily.
 
index c2c0933a59300e5c93c424e7fd813eca573fe417..f1c0f292b0e5f75d17fa027bf9f6dc571431394f 100644 (file)
@@ -1,64 +1,54 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
+import testbase
+import hotshot, hotshot.stats
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testbase import Table, Column
-import StringIO
-import testbase
-import gc
-import time
-import hotshot
-import hotshot.stats
-
-db = testbase.db
+from testlib import *
 
 NUM = 500
 DIVISOR = 50
 
-class LoadTest(AssertMixin):
-    def setUpAll(self):
-        global items, meta,subitems
-        meta = MetaData(db)
-        items = Table('items', meta, 
-            Column('item_id', Integer, primary_key=True),
-            Column('value', String(100)))
-        subitems = Table('subitems', meta, 
-            Column('sub_id', Integer, primary_key=True),
-            Column('parent_id', Integer, ForeignKey('items.item_id')),
-            Column('value', String(100)))
-        meta.create_all()
-    def tearDownAll(self):
-        meta.drop_all()
-    def setUp(self):
-        clear_mappers()
+meta = MetaData(testbase.db)
+items = Table('items', meta, 
+              Column('item_id', Integer, primary_key=True),
+              Column('value', String(100)))
+subitems = Table('subitems', meta, 
+                 Column('sub_id', Integer, primary_key=True),
+                 Column('parent_id', Integer, ForeignKey('items.item_id')),
+                 Column('value', String(100)))
+
+class Item(object):pass
+class SubItem(object):pass
+mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)})
+mapper(SubItem, subitems)
+
+def load():
+    global l
+    l = []
+    for x in range(1,NUM/DIVISOR + 1):
+        l.append({'item_id':x, 'value':'this is item #%d' % x})
+    #print l
+    items.insert().execute(*l)
+    for x in range(1, NUM/DIVISOR + 1):
         l = []
-        for x in range(1,NUM/DIVISOR + 1):
-            l.append({'item_id':x, 'value':'this is item #%d' % x})
+        for y in range(1, DIVISOR + 1):
+            z = ((x-1) * DIVISOR) + y
+            l.append({'sub_id':z,'value':'this is item #%d' % z, 'parent_id':x})
         #print l
-        items.insert().execute(*l)
-        for x in range(1, NUM/DIVISOR + 1):
-            l = []
-            for y in range(1, DIVISOR + 1):
-                z = ((x-1) * DIVISOR) + y
-                l.append({'sub_id':z,'value':'this is iteim #%d' % z, 'parent_id':x})
-            #print l
-            subitems.insert().execute(*l)    
-    def testload(self):
-        class Item(object):pass
-        class SubItem(object):pass
-        mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)})
-        mapper(SubItem, subitems)
-        sess = create_session()
-        prof = hotshot.Profile("masseagerload.prof")
-        prof.start()
-        query = sess.query(Item)
-        l = query.select()
-        print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
-        prof.stop()
-        prof.close()
-        stats = hotshot.stats.load("masseagerload.prof")
-        stats.sort_stats('time', 'calls')
-        stats.print_stats()
-        
-if __name__ == "__main__":
-    testbase.main()        
+        subitems.insert().execute(*l)    
+
+@profiling.profiled('masseagerload', always=True)
+def masseagerload(session):
+    query = session.query(Item)
+    l = query.select()
+    print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
+
+def all():
+    meta.create_all()
+    try:
+        load()
+        masseagerload(create_session())
+    finally:
+        meta.drop_all()
+
+if __name__ == '__main__':
+    all()
index 02b847599dc9790d5fa62cb11c6a94ab97742f69..92cf0fe9200ef2aece67abf478e8051d42e02bc7 100644 (file)
@@ -1,14 +1,10 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import sqlalchemy.orm.attributes as attributes
-from testbase import Table, Column
-import StringIO
 import testbase
-import gc
 import time
-
-db = testbase.db
+#import gc
+#import sqlalchemy.orm.attributes as attributes
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
 
 NUM = 2500
 
@@ -21,7 +17,7 @@ for best results, dont run with sqlite :memory: database, and keep an eye on top
 class LoadTest(AssertMixin):
     def setUpAll(self):
         global items, meta
-        meta = MetaData(db)
+        meta = MetaData(testbase.db)
         items = Table('items', meta, 
             Column('item_id', Integer, primary_key=True),
             Column('value', String(100)))
@@ -29,8 +25,6 @@ class LoadTest(AssertMixin):
     def tearDownAll(self):
         items.drop()
     def setUp(self):
-        objectstore.clear()
-        clear_mappers()
         for x in range(1,NUM/500+1):
             l = []
             for y in range(x*500-500 + 1, x*500 + 1):
index 53f13119ed86234b5a079b18ee3a8a714e5559c5..dd03f3962918edb0fef91c46b440c799f7f044d5 100644 (file)
@@ -1,21 +1,16 @@
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import sqlalchemy.attributes as attributes
-from testbase import Table, Column
-import StringIO
 import testbase
-import gc
-import sqlalchemy.orm.session
 import types
-db = testbase.db
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
 
 NUM = 250000
 
 class SaveTest(AssertMixin):
     def setUpAll(self):
         global items, metadata
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         items = Table('items', metadata, 
             Column('item_id', Integer, primary_key=True),
             Column('value', String(100)))
diff --git a/test/perf/ormsession.py b/test/perf/ormsession.py
new file mode 100644 (file)
index 0000000..e41ec06
--- /dev/null
@@ -0,0 +1,196 @@
+import testbase
+import time
+from datetime import datetime
+
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.profiling import profiled
+
+class Item(object):
+    def __repr__(self):
+        return 'Item<#%s "%s">' % (self.id, self.name)
+class SubItem(object):
+    def __repr__(self):
+        return 'SubItem<#%s "%s">' % (self.id, self.name)
+class Customer(object):
+    def __repr__(self):
+        return 'Customer<#%s "%s">' % (self.id, self.name)
+class Purchase(object):
+    def __repr__(self):
+        return 'Purchase<#%s "%s">' % (self.id, self.purchase_date)
+
+items, subitems, customers, purchases, purchaseitems = \
+    None, None, None, None, None
+
+metadata = MetaData()
+
+@profiled('table')
+def define_tables():
+    global items, subitems, customers, purchases, purchaseitems
+    items = Table('items', metadata,
+                  Column('id', Integer, primary_key=True),
+                  Column('name', String(100)),
+                  test_needs_acid=True)
+    subitems = Table('subitems', metadata,
+                     Column('id', Integer, primary_key=True),
+                     Column('item_id', Integer, ForeignKey('items.id'),
+                            nullable=False),
+                     Column('name', String(100), PassiveDefault('no name')),
+                     test_needs_acid=True)
+    customers = Table('customers', metadata,
+                      Column('id', Integer, primary_key=True),
+                      Column('name', String(100)),
+                      *[Column("col_%s" % chr(i), String(64), default=str(i))
+                        for i in range(97,117)],
+                      **dict(test_needs_acid=True))
+    purchases = Table('purchases', metadata,
+                      Column('id', Integer, primary_key=True),
+                      Column('customer_id', Integer,
+                             ForeignKey('customers.id'), nullable=False),
+                      Column('purchase_date', DateTime,
+                             default=datetime.now),
+                      test_needs_acid=True)
+    purchaseitems = Table('purchaseitems', metadata,
+                      Column('purchase_id', Integer,
+                             ForeignKey('purchases.id'),
+                             nullable=False, primary_key=True),
+                      Column('item_id', Integer, ForeignKey('items.id'),
+                             nullable=False, primary_key=True),
+                      test_needs_acid=True)
+
+@profiled('mapper')
+def setup_mappers():
+    mapper(Item, items, properties={
+            'subitems': relation(SubItem, backref='item', lazy=True)
+            })
+    mapper(SubItem, subitems)
+    mapper(Customer, customers, properties={
+            'purchases': relation(Purchase, lazy=True, backref='customer')
+            })
+    mapper(Purchase, purchases, properties={
+            'items': relation(Item, lazy=True, secondary=purchaseitems)
+            })
+
+@profiled('inserts')
+def insert_data():
+    q_items = 1000
+    q_sub_per_item = 10
+    q_customers = 1000
+
+    con = testbase.db.connect()
+
+    transaction = con.begin()
+    data, subdata = [], []
+    for item_id in xrange(1, q_items + 1):
+        data.append({'name': "item number %s" % item_id})
+        for subitem_id in xrange(1, (item_id % q_sub_per_item) + 1):
+            subdata.append({'item_id': item_id,
+                         'name': "subitem number %s" % subitem_id})
+        if item_id % 100 == 0:
+            items.insert().execute(*data)
+            subitems.insert().execute(*subdata)
+            del data[:]
+            del subdata[:]
+    if data:
+        items.insert().execute(*data)
+    if subdata:
+        subitems.insert().execute(*subdata)
+    transaction.commit()
+
+    transaction = con.begin()
+    data = []
+    for customer_id in xrange(1, q_customers):
+        data.append({'name': "customer number %s" % customer_id})
+        if customer_id % 100 == 0:
+            customers.insert().execute(*data)
+            del data[:]
+    if data:
+        customers.insert().execute(*data)
+    transaction.commit()
+
+    transaction = con.begin()
+    data, subdata = [], []
+    order_t = int(time.time()) - (5000 * 5 * 60)
+    current = xrange(1, q_customers)
+    step, purchase_id = 1, 0
+    while current:
+        next = []
+        for customer_id in current:
+            order_t += 300
+            data.append({'customer_id': customer_id,
+                         'purchase_date': datetime.fromtimestamp(order_t)})
+            purchase_id += 1
+            for item_id in range(customer_id % 200, customer_id + 1, 200):
+                if item_id != 0:
+                    subdata.append({'purchase_id': purchase_id,
+                                    'item_id': item_id})
+            if customer_id % 10 > step:
+                next.append(customer_id)
+
+            if len(data) >= 100:
+                purchases.insert().execute(*data)
+                if subdata:
+                    purchaseitems.insert().execute(*subdata)
+                del data[:]
+                del subdata[:]
+        step, current = step + 1, next
+
+    if data:
+        purchases.insert().execute(*data)
+    if subdata:
+        purchaseitems.insert().execute(*subdata)
+    transaction.commit()
+
+@profiled('queries')
+def run_queries():
+    session = create_session()
+    # no explicit transaction here.
+    
+    # build a report of summarizing the last 50 purchases and 
+    # the top 20 items from all purchases
+
+    q = session.query(Purchase). \
+        limit(50).order_by(desc(Purchase.purchase_date)). \
+        options(eagerload('items'), eagerload('items.subitems'),
+                eagerload('customer'))
+
+    report = []
+    # "write" the report.  pretend it's going to a web template or something,
+    # the point is to actually pull data through attributes and collections.
+    for purchase in q:
+        report.append(purchase.customer.name)
+        report.append(purchase.customer.col_a)
+        report.append(purchase.purchase_date)
+        for item in purchase.items:
+            report.append(item.name)
+            report.extend([s.name for s in item.subitems])
+    
+    # pull a report of the top 20 items of all time
+    _item_id = purchaseitems.c.item_id
+    top_20_q = select([func.distinct(_item_id).label('id')],
+                      group_by=[purchaseitems.c.purchase_id, _item_id],
+                      order_by=[desc(func.count(_item_id)), _item_id],
+                      limit=20)
+    q2 = session.query(Item).filter(Item.id.in_(top_20_q))
+
+    for num, item in enumerate(q2):
+        report.append("number %s: %s" % (num + 1, item.name))
+
+def setup_db():
+    metadata.drop_all()
+    metadata.create_all()
+def cleanup_db():
+    metadata.drop_all()
+
+@profiled('all')
+def main():
+    metadata.bind = testbase.db
+    define_tables()
+    setup_mappers()
+    setup_db()
+    insert_data()
+    run_queries()
+    cleanup_db()
+
+main()
index d096f1c67ffa5e74dda353f275fff6b938b1948f..1a2ff6978b75fc6991917105d7343cdbcd6f0cf1 100644 (file)
@@ -1,10 +1,11 @@
 # load test of connection pool
 
+import testbase
 from sqlalchemy import *
 import sqlalchemy.pool as pool
 import thread,time
-db = create_engine('mysql://scott:tiger@127.0.0.1/test', pool_timeout=30, echo_pool=True)
 
+db = create_engine(testbase.db.url, pool_timeout=30, echo_pool=True)
 metadata = MetaData(db)
 
 users_table = Table('users', metadata,
@@ -18,7 +19,7 @@ users_table.insert().execute([{'user_name':'user#%d' % i, 'password':'pw#%d' % i
 
 def runfast():
     while True:
-        c = db.connection_provider._pool.connect()
+        c = db.pool.connect()
         time.sleep(.5)
         c.close()
 #        result = users_table.select(limit=100).execute()
index 38fe145cd1b040aee2f58640f87385630c2864c0..13ec31fd61dea2bde383ecd161b5ab0cbee0adb9 100644 (file)
@@ -2,11 +2,12 @@
 when additional mappers are created while the existing 
 collection is being compiled."""
 
+import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 import thread, time
 from sqlalchemy.orm import mapperlib
-from testbase import Table, Column
+from testlib import *
 
 meta = MetaData('sqlite:///foo.db')
 
index e40171f07fa90b47d960103ac1d06f5a54c3a5d7..d22eeb76a00a93ae2b89a375d043e07e523fa291 100644 (file)
@@ -1,54 +1,55 @@
 #!/usr/bin/python
+"""Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop."""
 
+import testbase
 from sqlalchemy import *
-import sqlalchemy.pool as pool
+from sqlalchemy.orm import *
 import thread
-from sqlalchemy import exceptions
-from testbase import Table, Column
+from testlib import *
+
+port = 8000
 
 import logging
 logging.basicConfig()
 logging.getLogger('sqlalchemy.pool').setLevel(logging.INFO)
 
 threadids = set()
-#meta = MetaData('postgres://scott:tiger@127.0.0.1/test')
-
-#meta = MetaData('mysql://scott:tiger@localhost/test', poolclass=pool.SingletonThreadPool)
-meta = MetaData('mysql://scott:tiger@localhost/test')
+meta = MetaData(testbase.db)
 foo = Table('foo', meta, 
     Column('id', Integer, primary_key=True),
     Column('data', String(30)))
-
-meta.drop_all()
-meta.create_all()
-
-data = []
-for x in range(1,500):
-    data.append({'id':x,'data':"this is x value %d" % x})
-foo.insert().execute(data)
-
 class Foo(object):
     pass
-
 mapper(Foo, foo)
 
-root = './'
-port = 8000
+def prep():
+    meta.drop_all()
+    meta.create_all()
+
+    data = []
+    for x in range(1,500):
+        data.append({'id':x,'data':"this is x value %d" % x})
+    foo.insert().execute(data)
 
 def serve(environ, start_response):
+    start_response("200 OK", [('Content-type', 'text/plain')])
     sess = create_session()
     l = sess.query(Foo).select()
-            
-    start_response("200 OK", [('Content-type','text/plain')])
     threadids.add(thread.get_ident())
-    print "sending response on thread", thread.get_ident(), " total threads ", len(threadids)
-    return ["\n".join([x.data for x in l])]
+
+    print ("sending response on thread", thread.get_ident(),
+           " total threads ", len(threadids))
+    return [str("\n".join([x.data for x in l]))]
 
         
 if __name__ == '__main__':
-    from wsgiutils import wsgiServer
-    server = wsgiServer.WSGIServer (('localhost', port), {'/': serve})
-    print "Server listening on port %d" % port
-    server.serve_forever()
+    from wsgiref import simple_server
+    try:
+        prep()
+        server = simple_server.make_server('localhost', port, serve)
+        print "Server listening on port %d" % port
+        server.serve_forever()
+    finally:
+        meta.drop_all()
 
 
index ebb3fe34c612b8417dbef3a0d67705e4aa9ecabc..a669a25f2df6f396634e5ac9996806efaf3fcd24 100644 (file)
@@ -32,7 +32,5 @@ def suite():
         alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
     return alltests
 
-
-
 if __name__ == '__main__':
     testbase.main(suite())
index bcf8849644f36366f7970d9714a64c8ac8ae556d..493545b228b7cd043c8f2b6c6550bbb83ec25f61 100644 (file)
@@ -1,10 +1,10 @@
-import sys
 import testbase
+import sys
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
 
 
-class CaseTest(testbase.PersistTest):
+class CaseTest(PersistTest):
 
     def setUpAll(self):
         metadata = MetaData(testbase.db)
index b5f1a17414fab23ff3007917a933735d37a64d83..1c2bd1b57c20e417b59d6ceb5546fbe734507832 100644 (file)
@@ -1,9 +1,8 @@
 import testbase
 from sqlalchemy import *
-from testbase import Table, Column
-import sys
+from testlib import *
 
-class ConstraintTest(testbase.AssertMixin):
+class ConstraintTest(AssertMixin):
     
     def setUp(self):
         global metadata
@@ -53,7 +52,7 @@ class ConstraintTest(testbase.AssertMixin):
             )
         metadata.create_all()
         
-    @testbase.unsupported('mysql')
+    @testing.unsupported('mysql')
     def test_check_constraint(self):
         foo = Table('foo', metadata, 
             Column('id', Integer, primary_key=True),
index a9dd2f5ad29e0a9e42af97e00f341c22530593b0..5cbdc3e3fb3151fdc1579f53817968b7b4bfe126 100644 (file)
@@ -1,20 +1,17 @@
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest, sys, os
-import sqlalchemy.schema as schema
 import testbase
 from sqlalchemy import *
+import sqlalchemy.util as util
+import sqlalchemy.schema as schema
 from sqlalchemy.orm import mapper, create_session
-from testbase import Table, Column
-import sqlalchemy
-
-db = testbase.db
+from testlib import *
 
 class DefaultTest(PersistTest):
 
     def setUpAll(self):
         global t, f, f2, ts, currenttime, metadata
-        metadata = MetaData(testbase.db)
+
+        db = testbase.db
+        metadata = MetaData(db)
         x = {'x':50}
         def mydefault():
             x['x'] += 1
@@ -80,7 +77,7 @@ class DefaultTest(PersistTest):
         t.delete().execute()
         
     def teststandalone(self):
-        c = db.engine.contextual_connect()
+        c = testbase.db.engine.contextual_connect()
         x = c.execute(t.c.col1.default)
         y = t.c.col2.default.execute()
         z = c.execute(t.c.col3.default)
@@ -97,7 +94,7 @@ class DefaultTest(PersistTest):
         t.insert().execute()
 
         ctexec = currenttime.scalar()
-        self.echo("Currenttime "+ repr(ctexec))
+        print "Currenttime "+ repr(ctexec)
         l = t.select().execute()
         self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)])
 
@@ -112,7 +109,7 @@ class DefaultTest(PersistTest):
         pk = r.last_inserted_ids()[0]
         t.update(t.c.col1==pk).execute(col4=None, col5=None)
         ctexec = currenttime.scalar()
-        self.echo("Currenttime "+ repr(ctexec))
+        print "Currenttime "+ repr(ctexec)
         l = t.select(t.c.col1==pk).execute()
         l = l.fetchone()
         self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False))
@@ -127,7 +124,7 @@ class DefaultTest(PersistTest):
         l = l.fetchone()
         self.assert_(l['col3'] == 55)
 
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def testpassiveoverride(self):
         """primarily for postgres, tests that when we get a primary key column back 
         from reflecting a table which has a default value on it, we pre-execute
@@ -155,7 +152,7 @@ class DefaultTest(PersistTest):
             testbase.db.execute("drop table speedy_users", None)
 
 class AutoIncrementTest(PersistTest):
-    @testbase.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql')
     def testnonautoincrement(self):
         meta = MetaData(testbase.db)
         nonai_table = Table("aitest", meta, 
@@ -219,7 +216,7 @@ class AutoIncrementTest(PersistTest):
         
 
 class SequenceTest(PersistTest):
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def setUpAll(self):
         global cartitems, sometable, metadata
         metadata = MetaData(testbase.db)
@@ -236,7 +233,7 @@ class SequenceTest(PersistTest):
         
         metadata.create_all()
     
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def testseqnonpk(self):
         """test sequences fire off as defaults on non-pk columns"""
         sometable.insert().execute(name="somename")
@@ -246,7 +243,7 @@ class SequenceTest(PersistTest):
             (2, "someother", 2),
         ]
         
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def testsequence(self):
         cartitems.insert().execute(description='hi')
         cartitems.insert().execute(description='there')
@@ -255,7 +252,7 @@ class SequenceTest(PersistTest):
         cartitems.select().execute().fetchall()
    
    
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def test_implicit_sequence_exec(self):
         s = Sequence("my_sequence", metadata=MetaData(testbase.db))
         s.create()
@@ -265,7 +262,7 @@ class SequenceTest(PersistTest):
         finally:
             s.drop()
 
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def teststandalone_explicit(self):
         s = Sequence("my_sequence")
         s.create(bind=testbase.db)
@@ -275,7 +272,7 @@ class SequenceTest(PersistTest):
         finally:
             s.drop(testbase.db)
     
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def test_checkfirst(self):
         s = Sequence("my_sequence")
         s.create(testbase.db, checkfirst=False)
@@ -283,12 +280,12 @@ class SequenceTest(PersistTest):
         s.drop(testbase.db, checkfirst=False)
         s.drop(testbase.db, checkfirst=True)
         
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def teststandalone2(self):
         x = cartitems.c.cart_id.sequence.execute()
         self.assert_(1 <= x <= 4)
         
-    @testbase.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle')
     def tearDownAll(self): 
         metadata.drop_all()
 
index 82f3b175b160f328b92e0f8dff5038fd839b7927..357a66fcdfc16e478b84f59246eb0ff1adcb947e 100644 (file)
@@ -1,9 +1,9 @@
 import testbase
 from sql import select as selecttests
-
 from sqlalchemy import *
+from testlib import *
 
-class TraversalTest(testbase.AssertMixin):
+class TraversalTest(AssertMixin):
     """test ClauseVisitor's traversal, particularly its ability to copy and modify
     a ClauseElement in place."""
     
@@ -269,10 +269,7 @@ class SelectTest(selecttests.SQLTest):
         s = select([t2], t2.c.col1==t1.c.col1, correlate=False)
         s = s.correlate(t1).order_by(t2.c.col3)
         self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3")
-        
-        
-        
-        
-        
+
+
 if __name__ == '__main__':
-    testbase.main()        
\ No newline at end of file
+    testbase.main()        
index 6029c83088ac02df99f99e2fcf238a178dfe5fd2..5cca4160e50b0cfd26e829fa21a23da62f6b8c34 100644 (file)
@@ -1,11 +1,12 @@
 import testbase
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
+
 
 # TODO: either create a mock dialect with named paramstyle and a short identifier length,
 # or find a way to just use sqlite dialect and make those changes
 
-class LabelTypeTest(testbase.PersistTest):
+class LabelTypeTest(PersistTest):
     def test_type(self):
         m = MetaData()
         t = Table('sometable', m, 
@@ -14,7 +15,7 @@ class LabelTypeTest(testbase.PersistTest):
         assert isinstance(t.c.col1.label('hi').type, Integer)
         assert isinstance(select([t.c.col2], scalar=True).label('lala').type, Float)
 
-class LongLabelsTest(testbase.PersistTest):
+class LongLabelsTest(PersistTest):
     def setUpAll(self):
         global metadata, table1, maxlen
         metadata = MetaData(testbase.db)
index 05b6d0419dbf2120ef56c58179683eb125c54f25..61734d5d92a5653f267f4a4f5b162801fb71f7a5 100644 (file)
@@ -1,14 +1,8 @@
-from testbase import PersistTest
 import testbase
-import unittest, sys, datetime
-
-import sqlalchemy.databases.sqlite as sqllite
-
-import tables
+import datetime
 from sqlalchemy import *
-from sqlalchemy.engine import ResultProxy, RowProxy
 from sqlalchemy import exceptions
-from testbase import Table, Column
+from testlib import *
 
 
 class QueryTest(PersistTest):
@@ -189,7 +183,7 @@ class QueryTest(PersistTest):
         r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall()
         self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r))
         
-    @testbase.unsupported('mssql')
+    @testing.unsupported('mssql')
     def testselectlimitoffset(self):
         users.insert().execute(user_id=1, user_name='john')
         users.insert().execute(user_id=2, user_name='jack')
@@ -203,7 +197,7 @@ class QueryTest(PersistTest):
         r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall()
         self.assert_(r==[(6, 'ralph'), (7, 'fido')])
         
-    @testbase.supported('mssql')
+    @testing.supported('mssql')
     def testselectlimitoffset_mssql(self):
         try:
             r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
@@ -211,7 +205,7 @@ class QueryTest(PersistTest):
         except exceptions.InvalidRequestError:
             pass
 
-    @testbase.unsupported('mysql')  
+    @testing.unsupported('mysql')  
     def test_scalar_select(self):
         """test that scalar subqueries with labels get their type propigated to the result set."""
         # mysql and/or mysqldb has a bug here, type isnt propigated for scalar subquery.
@@ -346,7 +340,7 @@ class QueryTest(PersistTest):
         finally:
             meta.drop_all()
             
-    @testbase.supported('postgres')
+    @testing.supported('postgres')
     def test_functions_with_cols(self):
         # TODO: shouldnt this work on oracle too ?
         x = testbase.db.func.current_date().execute().scalar()
@@ -379,7 +373,7 @@ class QueryTest(PersistTest):
         self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
         self.assertEqual(r.values(), ['foo', 1])
     
-    @testbase.unsupported('oracle', 'firebird') 
+    @testing.unsupported('oracle', 'firebird') 
     def test_column_accessor_shadow(self):
         meta = MetaData(testbase.db)
         shadowed = Table('test_shadowed', meta,
@@ -409,7 +403,7 @@ class QueryTest(PersistTest):
         finally:
             shadowed.drop(checkfirst=True)
     
-    @testbase.supported('mssql')
+    @testing.supported('mssql')
     def test_fetchid_trigger(self):
         meta = MetaData(testbase.db)
         t1 = Table('t1', meta,
@@ -435,7 +429,7 @@ class QueryTest(PersistTest):
             con.execute("""drop trigger paj""")
             meta.drop_all()
     
-    @testbase.supported('mssql')
+    @testing.supported('mssql')
     def test_insertid_schema(self):
         meta = MetaData(testbase.db)
         con = testbase.db.connect()
@@ -448,7 +442,7 @@ class QueryTest(PersistTest):
             tbl.drop()
             con.execute('drop schema paj')
 
-    @testbase.supported('mssql')
+    @testing.supported('mssql')
     def test_insertid_reserved(self):
         meta = MetaData(testbase.db)
         table = Table(
@@ -567,7 +561,7 @@ class CompoundTest(PersistTest):
         assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         
-    @testbase.unsupported('mysql')
+    @testing.unsupported('mysql')
     def test_intersect(self):
         i = intersect(
             select([t2.c.col3, t2.c.col4]),
@@ -576,7 +570,7 @@ class CompoundTest(PersistTest):
         assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
 
-    @testbase.unsupported('mysql', 'oracle')
+    @testing.unsupported('mysql', 'oracle')
     def test_except_style1(self):
         e = except_(union(
             select([t1.c.col3, t1.c.col4]),
@@ -585,7 +579,7 @@ class CompoundTest(PersistTest):
         ), select([t2.c.col3, t2.c.col4]))
         assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
-    @testbase.unsupported('mysql', 'oracle')
+    @testing.unsupported('mysql', 'oracle')
     def test_except_style2(self):
         e = except_(union(
             select([t1.c.col3, t1.c.col4]),
@@ -595,7 +589,7 @@ class CompoundTest(PersistTest):
         assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
         assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
-    @testbase.unsupported('sqlite', 'mysql', 'oracle')
+    @testing.unsupported('sqlite', 'mysql', 'oracle')
     def test_except_style3(self):
         # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
         e = except_(
@@ -607,7 +601,7 @@ class CompoundTest(PersistTest):
         )
         self.assertEquals(e.execute().fetchall(), [('ccc',)])
 
-    @testbase.unsupported('sqlite', 'mysql', 'oracle')
+    @testing.unsupported('sqlite', 'mysql', 'oracle')
     def test_union_union_all(self):
         e = union_all(
             select([t1.c.col3]),
@@ -618,7 +612,7 @@ class CompoundTest(PersistTest):
         )
         self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)])
 
-    @testbase.unsupported('mysql')
+    @testing.unsupported('mysql')
     def test_composite(self):
         u = intersect(
             select([t2.c.col3, t2.c.col4]),
index 08cef8e5f9b3cf9887867cb528d11ec0dfa93f3f..2fdf9dba0c106517bf63e198eaa33b453d915e7b 100644 (file)
@@ -1,7 +1,6 @@
-from testbase import PersistTest
 import testbase
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
 
 
 class QuoteTest(PersistTest):
@@ -80,7 +79,7 @@ class QuoteTest(PersistTest):
         assert t1.c.UcCol.case_sensitive is False
         assert t2.c.normalcol.case_sensitive is False
    
-    @testbase.unsupported('oracle') 
+    @testing.unsupported('oracle') 
     def testlabels(self):
         """test the quoting of labels.
         
index a74f3e0f84ffe2cdef3fc41230570e943991a346..e0da96a81dd2fb8065925de7b82b1d7bf6e56dfb 100644 (file)
@@ -1,9 +1,9 @@
-from sqlalchemy import *
-from testbase import Table, Column
 import testbase
+from sqlalchemy import *
+from testlib import *
 
 
-class FoundRowsTest(testbase.AssertMixin):
+class FoundRowsTest(AssertMixin):
     """tests rowcount functionality"""
     def setUpAll(self):
         metadata = MetaData(testbase.db)
index 3d5996df9daf748e56e8337a686476a89e58ce43..18709f55d91328353dea2ee191b74802f99eb592 100644 (file)
@@ -1,9 +1,8 @@
-from testbase import PersistTest
 import testbase
+import re, operator
 from sqlalchemy import *
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
-from testbase import Table, Column
-import unittest, re, operator
+from testlib import *
 
 
 # the select test now tests almost completely with TableClause/ColumnClause objects,
@@ -55,7 +54,7 @@ addresses = table('addresses',
 class SQLTest(PersistTest):
     def runtest(self, clause, result, dialect = None, params = None, checkparams = None):
         c = clause.compile(parameters=params, dialect=dialect)
-        self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
+        print "\nSQL String:\n" + str(c) + repr(c.get_params())
         cc = re.sub(r'\n', '', str(c))
         self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'")
         if checkparams is not None:
index bcf70bd2062baefe4749d7cf6386fb557f39a258..dcc8550747b1b9f122aaf8dee9354575a66011cf 100755 (executable)
@@ -3,14 +3,10 @@ useable primary keys and foreign keys.  Full relational algebra depends on
 every selectable unit behaving nicely with others.."""\r
  \r
 import testbase\r
-import unittest, sys, datetime\r
 from sqlalchemy import *\r
-from testbase import Table, Column\r
-\r
-db = testbase.db\r
-metadata = MetaData(db)\r
-\r
+from testlib import *\r
 \r
+metadata = MetaData()\r
 table = Table('table1', metadata, \r
     Column('col1', Integer, primary_key=True),\r
     Column('col2', String(20)),\r
@@ -26,7 +22,7 @@ table2 = Table('table2', metadata,
     Column('coly', Integer),\r
 )\r
 \r
-class SelectableTest(testbase.AssertMixin):\r
+class SelectableTest(AssertMixin):\r
     def testdistance(self):\r
         s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])\r
 \r
@@ -172,7 +168,7 @@ class SelectableTest(testbase.AssertMixin):
         self.assert_(criterion.compare(j.onclause))\r
         \r
 \r
-class PrimaryKeyTest(testbase.AssertMixin):\r
+class PrimaryKeyTest(AssertMixin):\r
     def test_join_pk_collapse_implicit(self):\r
         """test that redundant columns in a join get 'collapsed' into a minimal primary key, \r
         which is the root column along a chain of foreign key relationships."""\r
index 59b9da388406f098836b3e451d6cc36fe2890e60..28e7db3a3ed88f45906effaed730d9aca328c14a 100644 (file)
@@ -1,16 +1,12 @@
-from testbase import PersistTest, AssertMixin
 import testbase
 import pickleable
+import datetime, os
 from sqlalchemy import *
-import string,datetime, re, sys, os
 import sqlalchemy.engine.url as url
-import sqlalchemy.types
 from sqlalchemy.databases import mssql, oracle, mysql
-from testbase import Table, Column
+from testlib import *
 
 
-db = testbase.db
-
 class MyType(types.TypeEngine):
     def get_col_spec(self):
         return "VARCHAR(100)"
@@ -108,7 +104,7 @@ class OverrideTest(PersistTest):
 
     def setUpAll(self):
         global users
-        users = Table('type_users', MetaData(db), 
+        users = Table('type_users', MetaData(testbase.db),
             Column('user_id', Integer, primary_key = True),
             # totall custom type
             Column('goofy', MyType, nullable = False),
@@ -139,6 +135,7 @@ class ColumnsTest(AssertMixin):
                             'float_column': 'float_column NUMERIC(25, 2)'
                           }
 
+        db = testbase.db
         if not db.name=='sqlite' and not db.name=='oracle':
             expectedResults['float_column'] = 'float_column FLOAT(25)'
     
@@ -158,7 +155,7 @@ class UnicodeTest(AssertMixin):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""
     def setUpAll(self):
         global unicode_table
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
         unicode_table = Table('unicode_table', metadata, 
             Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
             Column('unicode_varchar', Unicode(250)),
@@ -177,49 +174,49 @@ class UnicodeTest(AssertMixin):
                                        unicode_text=unicodedata,
                                        plain_varchar=rawdata)
         x = unicode_table.select().execute().fetchone()
-        self.echo(repr(x['unicode_varchar']))
-        self.echo(repr(x['unicode_text']))
-        self.echo(repr(x['plain_varchar']))
+        print repr(x['unicode_varchar'])
+        print repr(x['unicode_text'])
+        print repr(x['plain_varchar'])
         self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
         self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
         if isinstance(x['plain_varchar'], unicode):
             # SQLLite and MSSQL return non-unicode data as unicode
-            self.assert_(db.name in ('sqlite', 'mssql'))
+            self.assert_(testbase.db.name in ('sqlite', 'mssql'))
             self.assert_(x['plain_varchar'] == unicodedata)
-            self.echo("it's %s!" % db.name)
+            print "it's %s!" % testbase.db.name
         else:
             self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
 
     def testengineparam(self):
         """tests engine-wide unicode conversion"""
-        prev_unicode = db.engine.dialect.convert_unicode
+        prev_unicode = testbase.db.engine.dialect.convert_unicode
         try:
-            db.engine.dialect.convert_unicode = True
+            testbase.db.engine.dialect.convert_unicode = True
             rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
             unicodedata = rawdata.decode('utf-8')
             unicode_table.insert().execute(unicode_varchar=unicodedata,
                                            unicode_text=unicodedata,
                                            plain_varchar=rawdata)
             x = unicode_table.select().execute().fetchone()
-            self.echo(repr(x['unicode_varchar']))
-            self.echo(repr(x['unicode_text']))
-            self.echo(repr(x['plain_varchar']))
+            print repr(x['unicode_varchar'])
+            print repr(x['unicode_text'])
+            print repr(x['plain_varchar'])
             self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
             self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
             self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata)
         finally:
-            db.engine.dialect.convert_unicode = prev_unicode
+            testbase.db.engine.dialect.convert_unicode = prev_unicode
 
-    @testbase.unsupported('oracle')
+    @testing.unsupported('oracle')
     def testlength(self):
         """checks the database correctly understands the length of a unicode string"""
         teststr = u'aaa\x1234'
-        self.assert_(db.func.length(teststr).scalar() == len(teststr))
+        self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr))
   
 class BinaryTest(AssertMixin):
     def setUpAll(self):
         global binary_table
-        binary_table = Table('binary_table', MetaData(db), 
+        binary_table = Table('binary_table', MetaData(testbase.db), 
         Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
         Column('data', Binary),
         Column('data_slice', Binary(100)),
@@ -270,6 +267,7 @@ class DateTest(AssertMixin):
     def setUpAll(self):
         global users_with_date, insert_data
 
+        db = testbase.db
         if db.engine.name == 'oracle':
             import sqlalchemy.databases.oracle as oracle
             insert_data =  [
@@ -309,7 +307,8 @@ class DateTest(AssertMixin):
             collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)),
                            Column('user_date', Date), Column('user_time', Time)]
  
-        users_with_date = Table('query_users_with_date', MetaData(db), *collist)
+        users_with_date = Table('query_users_with_date',
+                                MetaData(testbase.db), *collist)
         users_with_date.create()
         insert_dicts = [dict(zip(fnames, d)) for d in insert_data]
 
@@ -327,7 +326,7 @@ class DateTest(AssertMixin):
 
 
     def testtextdate(self):     
-        x = db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall()
+        x = testbase.db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall()
         
         print repr(x)
         self.assert_(isinstance(x[0][0], datetime.datetime))
@@ -336,7 +335,11 @@ class DateTest(AssertMixin):
         #print repr(x)
 
     def testdate2(self):
-        t = Table('testdate', testbase.metadata, Column('id', Integer, Sequence('datetest_id_seq', optional=True), primary_key=True),
+        meta = MetaData(testbase.db)
+        t = Table('testdate', meta,
+                  Column('id', Integer,
+                         Sequence('datetest_id_seq', optional=True),
+                         primary_key=True),
                 Column('adate', Date), Column('adatetime', DateTime))
         t.create()
         try:
index cfd2354da95c17e3d29e591be494c8bad8b294b6..f882c2a5f8a9f2c5671ec2cf437cd5705db27d11 100644 (file)
@@ -4,10 +4,10 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import mapper, relation, create_session, eagerload
-from testbase import Table, Column
+from testlib import *
 
 
-class UnicodeSchemaTest(testbase.PersistTest):
+class UnicodeSchemaTest(PersistTest):
     def setUpAll(self):
         global unicode_bind, metadata, t1, t2
 
index c11ca1457ebc7a93c58de3a01f0be59a244bd5d2..1195db340050b687acd6af1eb7634674552dcf3d 100644 (file)
-"""base import for all test cases.  Patches in enhancements to unittest.TestCase, 
-instruments SQLAlchemy dialect/engine to track SQL statements for assertion purposes,
-provides base test classes for common test scenarios."""
+"""First import for all test cases, sets sys.path and loads configuration."""
 
-import sys
-import coverage
+__all__ = 'db',
 
-import os, unittest, StringIO, re, ConfigParser, optparse
+import sys, os, logging
 sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
+logging.basicConfig()
 
-db = None
-metadata = None
-db_uri = None
-echo = True
-table_options = {}
-
-# redefine sys.stdout so all those print statements go to the echo func
-local_stdout = sys.stdout
-class Logger(object):
-    def write(self, msg):
-        if echo:
-            local_stdout.write(msg)
-    def flush(self):
-        pass
-
-def echo_text(text):
-    print text
-
-def parse_argv():
-    # we are using the unittest main runner, so we are just popping out the 
-    # arguments we need instead of using our own getopt type of thing
-    global db, db_uri, metadata
-    
-    DBTYPE = 'sqlite'
-    PROXY = False
-
-    base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:s@localhost/tmp/test.fdb
-"""
-    config = ConfigParser.ConfigParser()
-    config.readfp(StringIO.StringIO(base_config))
-    config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-
-    parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
-    parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)")
-    parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (%s)" % ', '.join(config.options('db')))
-    parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool (asserts only one connection used)")
-    parser.add_option("--verbose", action="store_true", dest="verbose", help="enable stdout echoing/printing")
-    parser.add_option("--quiet", action="store_true", dest="quiet", help="suppress unittest output")
-    parser.add_option("--log-info", action="append", dest="log_info", help="turn on info logging for <LOG> (multiple OK)")
-    parser.add_option("--log-debug", action="append", dest="log_debug", help="turn on debug logging for <LOG> (multiple OK)")
-    parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)")
-    parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
-    parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
-    parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
-    parser.add_option("--require", action="append", dest="require", help="Require a particular driver or module version", default=[])
-    parser.add_option("--mysql-engine", action="store", dest="mysql_engine", help="Use the specified MySQL storage engine for all tables, default is a db-default/InnoDB combo.", default=None)
-    parser.add_option("--table-option", action="append", dest="tableopts", help="Add a dialect-specific table option, key=value", default=[])
-    
-    (options, args) = parser.parse_args()
-    sys.argv[1:] = args
-    
-    if options.dburi:
-        db_uri = param = options.dburi
-        DBTYPE = db_uri[:db_uri.index(':')]
-    elif options.db:
-        DBTYPE = param = options.db
-
-    if options.require or (config.has_section('require') and
-                           config.items('require')):
-        try:
-            import pkg_resources
-        except ImportError:
-            raise "setuptools is required for version requirements"
-
-        cmdline = []
-        for requirement in options.require:
-            pkg_resources.require(requirement)
-            cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
-
-        if config.has_section('require'):
-            for label, requirement in config.items('require'):
-                if not label == DBTYPE or label.startswith('%s.' % DBTYPE):
-                    continue
-                seen = [c for c in cmdline if requirement.startswith(c)]
-                if seen:
-                    continue
-                pkg_resources.require(requirement)
-        
-    opts = {}
-    if (None == db_uri):
-        if DBTYPE not in config.options('db'):
-            raise ("Could not create engine.  specify --db <%s> to " 
-                   "test runner." % '|'.join(config.options('db')))
-
-        db_uri = config.get('db', DBTYPE)
-
-    if not db_uri:
-        raise "Could not create engine.  specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql|firebird> to test runner."
-
-    global table_options
-    for spec in options.tableopts:
-        key, value = spec.split('=')
-        table_options[key] = value
-
-    if options.mysql_engine:
-        table_options['mysql_engine'] = options.mysql_engine
-
-
-    global echo
-    echo = options.verbose and not options.quiet
-    
-    global quiet
-    quiet = options.quiet
-    
-    global with_coverage
-    with_coverage = options.coverage
-    if with_coverage:
-        coverage.erase()
-        coverage.start()
-
-    from sqlalchemy import engine, schema
-    
-    if options.serverside:
-        opts['server_side_cursors'] = True
-    
-    if options.enginestrategy is not None:
-        opts['strategy'] = options.enginestrategy    
-    if options.mockpool:
-        db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts)
-    else:
-        db = engine.create_engine(db_uri, **opts)
-
-    # decorate the dialect's create_execution_context() method
-    # to produce a wrapper
-    create_context = db.dialect.create_execution_context
-    def create_exec_context(*args, **kwargs):
-        return ExecutionContextWrapper(create_context(*args, **kwargs))
-    db.dialect.create_execution_context = create_exec_context
-    
-    if options.topological:
-        from sqlalchemy.orm import unitofwork
-        from sqlalchemy import topological
-        class RevQueueDepSort(topological.QueueDependencySorter):
-            def __init__(self, tuples, allitems):
-                self.tuples = list(tuples)
-                self.allitems = list(allitems)
-                self.tuples.reverse()
-                self.allitems.reverse()
-        topological.QueueDependencySorter = RevQueueDepSort
-        unitofwork.DependencySorter = RevQueueDepSort
-            
-    import logging
-    logging.basicConfig()
-    if options.log_info is not None:
-        for elem in options.log_info:
-            logging.getLogger(elem).setLevel(logging.INFO)
-    if options.log_debug is not None:
-        for elem in options.log_debug:
-            logging.getLogger(elem).setLevel(logging.DEBUG)
-    metadata = schema.MetaData(db)
-    
-def unsupported(*dbs):
-    """a decorator that marks a test as unsupported by one or more database implementations"""
-    
-    def decorate(func):
-        name = db.name
-        for d in dbs:
-            if d == name:
-                def lala(self):
-                    echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'")
-                lala.__name__ = func.__name__
-                return lala
-        else:
-            return func
-    return decorate
-
-def supported(*dbs):
-    """a decorator that marks a test as supported by one or more database implementations"""
-    
-    def decorate(func):
-        name = db.name
-        for d in dbs:
-            if d == name:
-                return func
-        else:
-            def lala(self):
-                echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'")
-            lala.__name__ = func.__name__
-            return lala
-    return decorate
-
-def Table(*args, **kw):
-    """A schema.Table wrapper/hook for dialect-specific tweaks."""
-
-    test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
-                      if k.startswith('test_')])
-
-    kw.update(table_options)
-
-    if db.engine.name == 'mysql':
-        if 'mysql_engine' not in kw and 'mysql_type' not in kw:
-            if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
-                kw['mysql_engine'] = 'InnoDB'
-
-    return schema.Table(*args, **kw)
-
-def Column(*args, **kw):
-    """A schema.Column wrapper/hook for dialect-specific tweaks."""
-
-    # TODO: a Column that creates a Sequence automatically for PK columns,
-    # which would help Oracle tests
-    return schema.Column(*args, **kw)
-        
-class PersistTest(unittest.TestCase):
-
-    def __init__(self, *args, **params):
-        unittest.TestCase.__init__(self, *args, **params)
-
-    def echo(self, text):
-        """DEPRECATED.  use print <statement>"""
-        echo_text(text)
-        
-
-    def setUpAll(self):
-        pass
-    def tearDownAll(self):
-        pass
-
-    def shortDescription(self):
-        """overridden to not return docstrings"""
-        return None
-
-class AssertMixin(PersistTest):
-    """given a list-based structure of keys/properties which represent information within an object structure, and
-    a list of actual objects, asserts that the list of objects corresponds to the structure."""
-    
-    def assert_result(self, result, class_, *objects):
-        result = list(result)
-        if echo:
-            print repr(result)
-        self.assert_list(result, class_, objects)
-        
-    def assert_list(self, result, class_, list):
-        self.assert_(len(result) == len(list), "result list is not the same size as test list, for class " + class_.__name__)
-        for i in range(0, len(list)):
-            self.assert_row(class_, result[i], list[i])
-            
-    def assert_row(self, class_, rowobj, desc):
-        self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_))
-        for key, value in desc.iteritems():
-            if isinstance(value, tuple):
-                if isinstance(value[1], list):
-                    self.assert_list(getattr(rowobj, key), value[0], value[1])
-                else:
-                    self.assert_row(value[0], getattr(rowobj, key), value[1])
-            else:
-                self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
-                
-    def assert_sql(self, db, callable_, list, with_sequences=None):
-        global testdata
-        testdata = TestData()
-        if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
-            testdata.set_assert_list(self, with_sequences)
-        else:
-            testdata.set_assert_list(self, list)
-        try:
-            callable_()
-        finally:
-            testdata.set_assert_list(None, None)
-
-    def assert_sql_count(self, db, callable_, count):
-        global testdata
-        testdata = TestData()
-        try:
-            callable_()
-        finally:
-            self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count))
-
-    def capture_sql(self, db, callable_):
-        global testdata
-        testdata = TestData()
-        buffer = StringIO.StringIO()
-        testdata.buffer = buffer
-        try:
-            callable_()
-            return buffer.getvalue()
-        finally:
-            testdata.buffer = None
-            
-class ORMTest(AssertMixin):
-    keep_mappers = False
-    keep_data = False
-    def setUpAll(self):
-        global _otest_metadata
-        _otest_metadata = MetaData(db)
-        self.define_tables(_otest_metadata)
-        _otest_metadata.create_all()
-        self.insert_data()
-    def define_tables(self, _otest_metadata):
-        raise NotImplementedError()
-    def insert_data(self):
-        pass
-    def get_metadata(self):
-        return _otest_metadata
-    def tearDownAll(self):
-        clear_mappers()
-        _otest_metadata.drop_all()
-    def tearDown(self):
-        if not self.keep_mappers:
-            clear_mappers()
-        if not self.keep_data:
-            for t in _otest_metadata.table_iterator(reverse=True):
-                t.delete().execute().close()
-
-class TestData(object):
-    """tracks SQL expressions as theyre executed via an instrumented ExecutionContext."""
-    
-    def __init__(self):
-        self.set_assert_list(None, None)
-        self.sql_count = 0
-        self.buffer = None
-        
-    def set_assert_list(self, unittest, list):
-        self.unittest = unittest
-        self.assert_list = list
-        if list is not None:
-            self.assert_list.reverse()
-
-testdata = TestData()
-
-class ExecutionContextWrapper(object):
-    """instruments the ExecutionContext created by the Engine so that SQL expressions
-    can be tracked."""
-    
-    def __init__(self, ctx):
-        self.__dict__['ctx'] = ctx
-    def __getattr__(self, key):
-        return getattr(self.ctx, key)
-    def __setattr__(self, key, value):
-        setattr(self.ctx, key, value)
-        
-    def post_exec(self):
-        ctx = self.ctx
-        statement = unicode(ctx.compiled)
-        statement = re.sub(r'\n', '', ctx.statement)
-        if testdata.buffer is not None:
-            testdata.buffer.write(statement + "\n")
-
-        if testdata.assert_list is not None:
-            assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
-            item = testdata.assert_list[-1]
-            if not isinstance(item, dict):
-                item = testdata.assert_list.pop()
-            else:
-                # asserting a dictionary of statements->parameters
-                # this is to specify query assertions where the queries can be in 
-                # multiple orderings
-                if not item.has_key('_converted'):
-                    for key in item.keys():
-                        ckey = self.convert_statement(key)
-                        item[ckey] = item[key]
-                        if ckey != key:
-                            del item[key]
-                    item['_converted'] = True
-                try:
-                    entry = item.pop(statement)
-                    if len(item) == 1:
-                        testdata.assert_list.pop()
-                    item = (statement, entry)
-                except KeyError:
-                    assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
-
-            (query, params) = item
-            if callable(params):
-                params = params(ctx)
-            if params is not None and isinstance(params, list) and len(params) == 1:
-                params = params[0]
-            
-            if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
-                parameters = ctx.compiled_parameters.get_original_dict()
-            elif isinstance(ctx.compiled_parameters, list):
-                parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
-                    
-            query = self.convert_statement(query)
-            if db.engine.name == 'mssql' and statement.endswith('; select scope_identity()'):
-                statement = statement[:-25]
-            testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
-        testdata.sql_count += 1
-        self.ctx.post_exec()
-        
-    def convert_statement(self, query):
-        paramstyle = self.ctx.dialect.paramstyle
-        if paramstyle == 'named':
-            pass
-        elif paramstyle =='pyformat':
-            query = re.sub(r':([\w_]+)', r"%(\1)s", query)
-        else:
-            # positional params
-            repl = None
-            if paramstyle=='qmark':
-                repl = "?"
-            elif paramstyle=='format':
-                repl = r"%s"
-            elif paramstyle=='numeric':
-                repl = None
-            query = re.sub(r':([\w_]+)', repl, query)
-        return query
-        
-class TTestSuite(unittest.TestSuite):
-    """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality"""
-    def __init__(self, tests=()):
-        if len(tests) >0 and isinstance(tests[0], PersistTest):
-            self._initTest = tests[0]
-        else:
-            self._initTest = None
-        unittest.TestSuite.__init__(self, tests)
-
-    def do_run(self, result):
-        """nice job unittest !  you switched __call__ and run() between py2.3 and 2.4 thereby
-        making straight subclassing impossible !"""
-        for test in self._tests:
-            if result.shouldStop:
-                break
-            test(result)
-        return result
-
-    def run(self, result):
-        return self(result)
-
-    def __call__(self, result):
-        try:
-            if self._initTest is not None:
-                self._initTest.setUpAll()
-        except:
-            result.addError(self._initTest, self.__exc_info())
-            pass
-        try:
-            return self.do_run(result)
-        finally:
-            try:
-                if self._initTest is not None:
-                    self._initTest.tearDownAll()
-            except:
-                result.addError(self._initTest, self.__exc_info())
-                pass
-
-    def __exc_info(self):
-        """Return a version of sys.exc_info() with the traceback frame
-           minimised; usually the top level of the traceback frame is not
-           needed.
-           ripped off out of unittest module since its double __
-        """
-        exctype, excvalue, tb = sys.exc_info()
-        if sys.platform[:4] == 'java': ## tracebacks look different in Jython
-            return (exctype, excvalue, tb)
-        return (exctype, excvalue, tb)
-
-unittest.TestLoader.suiteClass = TTestSuite
-
-parse_argv()
-
-import sqlalchemy
-from sqlalchemy import schema, MetaData, sql
-from sqlalchemy.orm import clear_mappers
-                    
-def runTests(suite):
-    sys.stdout = Logger()    
-    runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
-    try:
-        return runner.run(suite)
-    finally:
-        if with_coverage:
-            global echo
-            echo=True
-            coverage.stop()
-            coverage.report(list(covered_files()), show_missing=False)
-
-def covered_files():
-    for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):                          
-        for x in rec[2]:
-            if x.endswith('.py'):
-                yield os.path.join(rec[0], x)
-
-def main(suite=None):
-    
-    if not suite:
-        if len(sys.argv[1:]):
-            suite =unittest.TestLoader().loadTestsFromNames(sys.argv[1:], __import__('__main__'))
-        else:
-            suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
-
-    result = runTests(suite)
-    sys.exit(not result.wasSuccessful())
+import testlib.config
+testlib.config.configure()
 
+from testlib.testing import main
+db = testlib.config.db
 
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py
new file mode 100644 (file)
index 0000000..ff5c4c1
--- /dev/null
@@ -0,0 +1,11 @@
+"""Enhance unittest and instrument SQLAlchemy classes for testing.
+
+Load after sqlalchemy imports to use instrumented stand-ins like Table.
+"""
+
+import testlib.config
+from testlib.schema import Table, Column
+import testlib.testing as testing
+from testlib.testing import PersistTest, AssertMixin, ORMTest
+import testlib.profiling
+
diff --git a/test/testlib/config.py b/test/testlib/config.py
new file mode 100644 (file)
index 0000000..ca24a72
--- /dev/null
@@ -0,0 +1,255 @@
+import optparse, os, sys, ConfigParser, StringIO
+logging, require = None, None
+
+__all__ = 'parser', 'configure', 'options',
+
+db, db_uri, db_type, db_label = None, None, None, None
+
+options = None
+file_config = None
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:s@localhost/tmp/test.fdb
+"""
+
+parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
+
+def configure():
+    global options, config
+    global getopts_options, file_config
+
+    file_config = ConfigParser.ConfigParser()
+    file_config.readfp(StringIO.StringIO(base_config))
+    file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+
+    # Opt parsing can fire immediate actions, like logging and coverage
+    (options, args) = parser.parse_args()
+    sys.argv[1:] = args
+
+    # Lazy setup of other options (post coverage)
+    for fn in post_configure:
+        fn(options, file_config)
+
+    return options, file_config
+
+def _log(option, opt_str, value, parser):
+    global logging
+    if not logging:
+        import logging
+        logging.basicConfig()
+
+    if opt_str.endswith('-info'):
+        logging.getLogger(value).setLevel(logging.INFO)
+    elif opt_str.endswith('-debug'):
+        logging.getLogger(value).setLevel(logging.DEBUG)
+
+def _start_coverage(option, opt_str, value, parser):
+    import sys, atexit, coverage
+    true_out = sys.stdout
+
+    def _iter_covered_files():
+        import sqlalchemy
+        for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
+            for x in rec[2]:
+                if x.endswith('.py'):
+                    yield os.path.join(rec[0], x)
+    def _stop():
+        coverage.stop()
+        true_out.write("\nPreparing coverage report...\n")
+        coverage.report(list(_iter_covered_files()),
+                        show_missing=False, ignore_errors=False,
+                        file=true_out)
+    atexit.register(_stop)
+    coverage.erase()
+    coverage.start()
+    
+def _list_dbs(*args):
+    print "Available --db options (use --dburi to override)"
+    for macro in sorted(file_config.options('db')):
+        print "%20s\t%s" % (macro, file_config.get('db', macro))
+    sys.exit(0)
+
+opt = parser.add_option
+opt("--verbose", action="store_true", dest="verbose",
+    help="enable stdout echoing/printing")
+opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+opt("--log-info", action="callback", callback=_log,
+    help="turn on info logging for <LOG> (multiple OK)")
+opt("--log-debug", action="callback", callback=_log,
+    help="turn on debug logging for <LOG> (multiple OK)")
+opt("--require", action="append", dest="require", default=[],
+    help="require a particular driver or module version (multiple OK)")
+opt("--db", action="store", dest="db", default="sqlite",
+    help="Use prefab database uri")
+opt('--dbs', action='callback', callback=_list_dbs,
+    help="List available prefab dbs")
+opt("--dburi", action="store", dest="dburi",
+    help="Database uri (overrides --db)")
+opt("--mockpool", action="store_true", dest="mockpool",
+    help="Use mock pool (asserts only one connection used)")
+opt("--enginestrategy", action="store", dest="enginestrategy", default=None,
+    help="Engine strategy (plain or threadlocal, defaults toplain)")
+opt("--reversetop", action="store_true", dest="reversetop", default=False,
+    help="Reverse the collection ordering for topological sorts (helps "
+          "reveal dependency issues)")
+opt("--serverside", action="store_true", dest="serverside",
+    help="Turn on server side cursors for PG")
+opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+    help="Use the specified MySQL storage engine for all tables, default is "
+         "a db-default/InnoDB combo.")
+opt("--table-option", action="append", dest="tableopts", default=[],
+    help="Add a dialect-specific table option, key=value")
+opt("--coverage", action="callback", callback=_start_coverage,
+    help="Dump a full coverage report after running tests")
+opt("--profile", action="append", dest="profile_targets", default=[],
+    help="Enable a named profile target (multiple OK.)")
+opt("--profile-sort", action="store", dest="profile_sort", default=None,
+    help="Sort profile stats with this comma-separated sort order")
+opt("--profile-limit", type="int", action="store", dest="profile_limit",
+    default=None,
+    help="Limit function count in profile stats")
+
+class _ordered_map(object):
+    def __init__(self):
+        self._keys = list()
+        self._data = dict()
+
+    def __setitem__(self, key, value):
+        if key not in self._keys:
+            self._keys.append(key)
+        self._data[key] = value
+
+    def __iter__(self):
+        for key in self._keys:
+            yield self._data[key]
+    
+post_configure = _ordered_map()
+
+def _engine_uri(options, file_config):
+    global db_label, db_uri
+    db_label = 'sqlite'
+    if options.dburi:
+        db_uri = options.dburi
+        db_label = db_uri[:db_uri.index(':')]
+    elif options.db:
+        db_label = options.db
+        db_uri = None
+
+    if db_uri is None:
+        if db_label not in file_config.options('db'):
+            raise RuntimeError(
+                "Unknown engine.  Specify --dbs for known engines.")
+        db_uri = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+    if not(options.require or
+           (file_config.has_section('require') and
+            file_config.items('require'))):
+        return
+
+    try:
+        import pkg_resources
+    except ImportError:
+        raise RuntimeError("setuptools is required for version requirements")
+
+    cmdline = []
+    for requirement in options.require:
+        pkg_resources.require(requirement)
+        cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+    if file_config.has_section('require'):
+        for label, requirement in file_config.items('require'):
+            if not label == db_label or label.startswith('%s.' % db_label):
+                continue
+            seen = [c for c in cmdline if requirement.startswith(c)]
+            if seen:
+                continue
+            pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _create_testing_engine(options, file_config):
+    from sqlalchemy import engine
+    global db, db_type
+    engine_opts = {}
+    if options.serverside:
+        engine_opts['server_side_cursors'] = True
+    
+    if options.enginestrategy is not None:
+        engine_opts['strategy'] = options.enginestrategy    
+
+    if options.mockpool:
+        db = engine.create_engine(db_uri, poolclass=pool.AssertionPool,
+                                  **engine_opts)
+    else:
+        db = engine.create_engine(db_uri, **engine_opts)
+    db_type = db.name
+
+    # decorate the dialect's create_execution_context() method
+    # to produce a wrapper
+    from testlib.testing import ExecutionContextWrapper
+
+    create_context = db.dialect.create_execution_context
+    def create_exec_context(*args, **kwargs):
+        return ExecutionContextWrapper(create_context(*args, **kwargs))
+    db.dialect.create_execution_context = create_exec_context
+post_configure['create_engine'] = _create_testing_engine
+
+def _set_table_options(options, file_config):
+    import testlib.schema
+    
+    table_options = testlib.schema.table_options
+    for spec in options.tableopts:
+        key, value = spec.split('=')
+        table_options[key] = value
+
+    if options.mysql_engine:
+        table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+    if options.reversetop:
+        from sqlalchemy.orm import unitofwork
+        from sqlalchemy import topological
+        class RevQueueDepSort(topological.QueueDependencySorter):
+            def __init__(self, tuples, allitems):
+                self.tuples = list(tuples)
+                self.allitems = list(allitems)
+                self.tuples.reverse()
+                self.allitems.reverse()
+        topological.QueueDependencySorter = RevQueueDepSort
+        unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+
+def _set_profile_targets(options, file_config):
+    from testlib import profiling
+    
+    profile_config = profiling.profile_config
+
+    for target in options.profile_targets:
+        profile_config['targets'].add(target)
+
+    if options.profile_sort:
+        profile_config['sort'] = options.profile_sort.split(',')
+
+    if options.profile_limit:
+        profile_config['limit'] = options.profile_limit
+
+    if options.quiet:
+        profile_config['report'] = False
+
+    # magic "all" target
+    if 'all' in profiling.all_targets:
+        targets = profile_config['targets']
+        if 'all' in targets and len(targets) != 1:
+            targets.clear()
+            targets.add('all')
+post_configure['profile_targets'] = _set_profile_targets
similarity index 99%
rename from test/coverage.py
rename to test/testlib/coverage.py
index 618f962fef3ea839349fed3ba40b482433fdc55f..0203dbf7d5dfe33316b94250afd6e8b9bd678d0c 100644 (file)
@@ -1101,4 +1101,3 @@ if __name__ == '__main__':
 # DAMAGE.
 #
 # $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
-
diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py
new file mode 100644 (file)
index 0000000..51b8bb5
--- /dev/null
@@ -0,0 +1,73 @@
+"""Profiling support for unit and performance tests."""
+
+import time, hotshot, hotshot.stats
+from testlib.config import parser, post_configure
+import testlib.config
+
+__all__ = 'profiled',
+
+all_targets = set()
+profile_config = { 'targets': set(),
+                   'report': True,
+                   'sort': ('time', 'calls'),
+                   'limit': None }
+
+def profiled(target, **target_opts):
+    """Optional function profiling.
+
+    @profiled('label')
+    or
+    @profiled('label', report=True, sort=('calls',), limit=20)
+    
+    Enables profiling for a function when 'label' is targetted for
+    profiling.  Report options can be supplied, and override the global
+    configuration and command-line options.
+    """
+
+    # manual or automatic namespacing by module would remove conflict issues
+    if target in all_targets:
+        print "Warning: redefining profile target '%s'" % target
+    all_targets.add(target)
+
+    filename = "%s.prof" % target
+
+    def decorator(fn):
+        def profiled(*args, **kw):
+            if (target not in profile_config['targets'] and
+                not target_opts.get('always', None)):
+                return fn(*args, **kw)
+
+            prof = hotshot.Profile(filename)
+            began = time.time()
+            prof.start()
+            try:
+                result = fn(*args, **kw)
+            finally:
+                prof.stop()
+                ended = time.time()
+                prof.close()
+
+            if not testlib.config.options.quiet:
+                print "Profiled target '%s', wall time: %.2f seconds" % (
+                    target, ended - began)
+
+            report = target_opts.get('report', profile_config['report'])
+            if report:
+                sort_ = target_opts.get('sort', profile_config['sort'])
+                limit = target_opts.get('limit', profile_config['limit'])
+                print "Profile report for target '%s' (%s)" % (
+                    target, filename)
+
+                stats = hotshot.stats.load(filename)
+                stats.sort_stats(*sort_)
+                if limit:
+                    stats.print_stats(limit)
+                else:
+                    stats.print_stats()
+            return result
+        try:
+            profiled.__name__ = fn.__name__
+        except:
+            pass
+        return profiled
+    return decorator
diff --git a/test/testlib/schema.py b/test/testlib/schema.py
new file mode 100644 (file)
index 0000000..a2fc912
--- /dev/null
@@ -0,0 +1,28 @@
+import testbase
+from sqlalchemy import schema
+
+__all__ = 'Table', 'Column',
+
+table_options = {}
+
+def Table(*args, **kw):
+    """A schema.Table wrapper/hook for dialect-specific tweaks."""
+
+    test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+                      if k.startswith('test_')])
+
+    kw.update(table_options)
+
+    if testbase.db.name == 'mysql':
+        if 'mysql_engine' not in kw and 'mysql_type' not in kw:
+            if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
+                kw['mysql_engine'] = 'InnoDB'
+
+    return schema.Table(*args, **kw)
+
+def Column(*args, **kw):
+    """A schema.Column wrapper/hook for dialect-specific tweaks."""
+
+    # TODO: a Column that creates a Sequence automatically for PK columns,
+    # which would help Oracle tests
+    return schema.Column(*args, **kw)
similarity index 96%
rename from test/tables.py
rename to test/testlib/tables.py
index 6ae53a1dae0d68b729711c96963595e7ba96676c..2dc9200f35b6d18feabe1dd39780267ed8d0fd21 100644 (file)
@@ -1,12 +1,8 @@
-
-from sqlalchemy import *
-import os
 import testbase
-from testbase import Table, Column
+from sqlalchemy import *
+from testlib.schema import Table, Column
 
-ECHO = testbase.echo
-db = testbase.db
-metadata = MetaData(db)
+metadata = MetaData()
 
 users = Table('users', metadata,
     Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
@@ -51,13 +47,19 @@ itemkeywords = Table('itemkeywords', metadata,
 )
 
 def create():
+    if not metadata.bind:
+        metadata.bind = testbase.db
     metadata.create_all()
 def drop():
+    if not metadata.bind:
+        metadata.bind = testbase.db
     metadata.drop_all()
 def delete():
     for t in metadata.table_iterator(reverse=True):
         t.delete().execute()
 def user_data():
+    if not metadata.bind:
+        metadata.bind = testbase.db
     users.insert().execute(
         dict(user_id = 7, user_name = 'jack'),
         dict(user_id = 8, user_name = 'ed'),
@@ -209,4 +211,4 @@ order_result = [
 {'order_id' : 4, 'items':(Item, [])},
 {'order_id' : 5, 'items':(Item, [])},
 ]
-#db.echo = True
+
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
new file mode 100644 (file)
index 0000000..0361cfb
--- /dev/null
@@ -0,0 +1,363 @@
+"""TestCase and TestSuite artifacts and testing decorators."""
+
+# monkeypatches unittest.TestLoader.suiteClass at import time
+
+import unittest, re, sys, os
+from cStringIO import StringIO
+from sqlalchemy import MetaData, sql
+from sqlalchemy.orm import clear_mappers
+import testlib.config as config
+
+__all__ = 'PersistTest', 'AssertMixin', 'ORMTest'
+
+def unsupported(*dbs):
+    """Mark a test as unsupported by one or more database implementations"""
+    
+    def decorate(fn):
+        fn_name = fn.__name__
+        def maybe(*args, **kw):
+            if config.db.name in dbs:
+                print "'%s' unsupported on DB implementation '%s'" % (
+                    fn_name, config.db.name)
+                return True
+            else:
+                return fn(*args, **kw)
+        try:
+            maybe.__name__ = fn_name
+        except:
+            pass
+        return maybe
+    return decorate
+
+def supported(*dbs):
+    """Mark a test as supported by one or more database implementations"""
+    
+    def decorate(fn):
+        fn_name = fn.__name__
+        def maybe(*args, **kw):
+            if config.db.name in dbs:
+                return fn(*args, **kw)
+            else:
+                print "'%s' unsupported on DB implementation '%s'" % (
+                    fn_name, config.db.name)
+                return True
+        try:
+            maybe.__name__ = fn_name
+        except:
+            pass
+        return maybe
+    return decorate
+
+class TestData(object):
+    """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
+    
+    def __init__(self):
+        self.set_assert_list(None, None)
+        self.sql_count = 0
+        self.buffer = None
+        
+    def set_assert_list(self, unittest, list):
+        self.unittest = unittest
+        self.assert_list = list
+        if list is not None:
+            self.assert_list.reverse()
+
+testdata = TestData()
+
+
+class ExecutionContextWrapper(object):
+    """instruments the ExecutionContext created by the Engine so that SQL expressions
+    can be tracked."""
+    
+    def __init__(self, ctx):
+        self.__dict__['ctx'] = ctx
+    def __getattr__(self, key):
+        return getattr(self.ctx, key)
+    def __setattr__(self, key, value):
+        setattr(self.ctx, key, value)
+        
+    def post_exec(self):
+        ctx = self.ctx
+        statement = unicode(ctx.compiled)
+        statement = re.sub(r'\n', '', ctx.statement)
+        if testdata.buffer is not None:
+            testdata.buffer.write(statement + "\n")
+
+        if testdata.assert_list is not None:
+            assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
+            item = testdata.assert_list[-1]
+            if not isinstance(item, dict):
+                item = testdata.assert_list.pop()
+            else:
+                # asserting a dictionary of statements->parameters
+                # this is to specify query assertions where the queries can be in 
+                # multiple orderings
+                if not item.has_key('_converted'):
+                    for key in item.keys():
+                        ckey = self.convert_statement(key)
+                        item[ckey] = item[key]
+                        if ckey != key:
+                            del item[key]
+                    item['_converted'] = True
+                try:
+                    entry = item.pop(statement)
+                    if len(item) == 1:
+                        testdata.assert_list.pop()
+                    item = (statement, entry)
+                except KeyError:
+                    assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
+
+            (query, params) = item
+            if callable(params):
+                params = params(ctx)
+            if params is not None and isinstance(params, list) and len(params) == 1:
+                params = params[0]
+            
+            if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
+                parameters = ctx.compiled_parameters.get_original_dict()
+            elif isinstance(ctx.compiled_parameters, list):
+                parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
+                    
+            query = self.convert_statement(query)
+            if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'):
+                statement = statement[:-25]
+            testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+        testdata.sql_count += 1
+        self.ctx.post_exec()
+        
+    def convert_statement(self, query):
+        paramstyle = self.ctx.dialect.paramstyle
+        if paramstyle == 'named':
+            pass
+        elif paramstyle =='pyformat':
+            query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+        else:
+            # positional params
+            repl = None
+            if paramstyle=='qmark':
+                repl = "?"
+            elif paramstyle=='format':
+                repl = r"%s"
+            elif paramstyle=='numeric':
+                repl = None
+            query = re.sub(r':([\w_]+)', repl, query)
+        return query
+
+class PersistTest(unittest.TestCase):
+
+    def __init__(self, *args, **params):
+        unittest.TestCase.__init__(self, *args, **params)
+
+    def setUpAll(self):
+        pass
+
+    def tearDownAll(self):
+        pass
+
+    def shortDescription(self):
+        """overridden to not return docstrings"""
+        return None
+
+class AssertMixin(PersistTest):
+    """given a list-based structure of keys/properties which represent information within an object structure, and
+    a list of actual objects, asserts that the list of objects corresponds to the structure."""
+    
+    def assert_result(self, result, class_, *objects):
+        result = list(result)
+        print repr(result)
+        self.assert_list(result, class_, objects)
+        
+    def assert_list(self, result, class_, list):
+        self.assert_(len(result) == len(list),
+                     "result list is not the same size as test list, " +
+                     "for class " + class_.__name__)
+        for i in range(0, len(list)):
+            self.assert_row(class_, result[i], list[i])
+            
+    def assert_row(self, class_, rowobj, desc):
+        self.assert_(rowobj.__class__ is class_,
+                     "item class is not " + repr(class_))
+        for key, value in desc.iteritems():
+            if isinstance(value, tuple):
+                if isinstance(value[1], list):
+                    self.assert_list(getattr(rowobj, key), value[0], value[1])
+                else:
+                    self.assert_row(value[0], getattr(rowobj, key), value[1])
+            else:
+                self.assert_(getattr(rowobj, key) == value,
+                             "attribute %s value %s does not match %s" % (
+                             key, getattr(rowobj, key), value))
+                
+    def assert_sql(self, db, callable_, list, with_sequences=None):
+        global testdata
+        testdata = TestData()
+        if with_sequences is not None and (config.db.name == 'postgres' or
+                                           config.db.name == 'oracle'):
+            testdata.set_assert_list(self, with_sequences)
+        else:
+            testdata.set_assert_list(self, list)
+        try:
+            callable_()
+        finally:
+            testdata.set_assert_list(None, None)
+
+    def assert_sql_count(self, db, callable_, count):
+        global testdata
+        testdata = TestData()
+        try:
+            callable_()
+        finally:
+            self.assert_(testdata.sql_count == count,
+                         "desired statement count %d does not match %d" % (
+                         count, testdata.sql_count))
+
+    def capture_sql(self, db, callable_):
+        global testdata
+        testdata = TestData()
+        buffer = StringIO()
+        testdata.buffer = buffer
+        try:
+            callable_()
+            return buffer.getvalue()
+        finally:
+            testdata.buffer = None
+
+_otest_metadata = None
+class ORMTest(AssertMixin):
+    keep_mappers = False
+    keep_data = False
+
+    def setUpAll(self):
+        global _otest_metadata
+        _otest_metadata = MetaData(config.db)
+        self.define_tables(_otest_metadata)
+        _otest_metadata.create_all()
+        self.insert_data()
+
+    def define_tables(self, _otest_metadata):
+        raise NotImplementedError()
+
+    def insert_data(self):
+        pass
+
+    def get_metadata(self):
+        return _otest_metadata
+
+    def tearDownAll(self):
+        clear_mappers()
+        _otest_metadata.drop_all()
+
+    def tearDown(self):
+        if not self.keep_mappers:
+            clear_mappers()
+        if not self.keep_data:
+            for t in _otest_metadata.table_iterator(reverse=True):
+                t.delete().execute().close()
+
+
+class TTestSuite(unittest.TestSuite):
+    """A TestSuite with once per TestCase setUpAll() and tearDownAll()"""
+
+    def __init__(self, tests=()):
+        if len(tests) >0 and isinstance(tests[0], PersistTest):
+            self._initTest = tests[0]
+        else:
+            self._initTest = None
+        unittest.TestSuite.__init__(self, tests)
+
+    def do_run(self, result):
+        # nice job unittest !  you switched __call__ and run() between py2.3
+        # and 2.4 thereby making straight subclassing impossible !
+        for test in self._tests:
+            if result.shouldStop:
+                break
+            test(result)
+        return result
+
+    def run(self, result):
+        return self(result)
+
+    def __call__(self, result):
+        try:
+            if self._initTest is not None:
+                self._initTest.setUpAll()
+        except:
+            result.addError(self._initTest, self.__exc_info())
+            pass
+        try:
+            return self.do_run(result)
+        finally:
+            try:
+                if self._initTest is not None:
+                    self._initTest.tearDownAll()
+            except:
+                result.addError(self._initTest, self.__exc_info())
+                pass
+
+    def __exc_info(self):
+        """Return a version of sys.exc_info() with the traceback frame
+           minimised; usually the top level of the traceback frame is not
+           needed.
+           ripped off out of unittest module since its double __
+        """
+        exctype, excvalue, tb = sys.exc_info()
+        if sys.platform[:4] == 'java': ## tracebacks look different in Jython
+            return (exctype, excvalue, tb)
+        return (exctype, excvalue, tb)
+
+unittest.TestLoader.suiteClass = TTestSuite
+
+def _iter_covered_files():
+    import sqlalchemy
+    for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
+        for x in rec[2]:
+            if x.endswith('.py'):
+                yield os.path.join(rec[0], x)
+
+def cover(callable_, file_=None):
+    from testlib import coverage
+    coverage_client = coverage.the_coverage
+    coverage_client.get_ready()
+    coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
+    coverage_client.erase()
+    coverage_client.start()
+    try:
+        return callable_()
+    finally:
+        coverage_client.stop()
+        coverage_client.save()
+        coverage_client.report(list(_iter_covered_files()),
+                               show_missing=False, ignore_errors=False,
+                               file=file_)
+
+class DevNullWriter(object):
+    def write(self, msg):
+        pass
+    def flush(self):
+        pass
+
+def runTests(suite):
+    verbose = config.options.verbose
+    quiet = config.options.quiet
+    orig_stdout = sys.stdout
+
+    try:
+        if not verbose or quiet:
+            sys.stdout = DevNullWriter()
+        runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
+        return runner.run(suite)
+    finally:
+        if not verbose or quiet:
+            sys.stdout = orig_stdout
+
+def main(suite=None):
+    if not suite:
+        if len(sys.argv[1:]):
+            suite =unittest.TestLoader().loadTestsFromNames(
+                sys.argv[1:], __import__('__main__'))
+        else:
+            suite = unittest.TestLoader().loadTestsFromModule(
+                __import__('__main__'))
+
+    result = runTests(suite)
+    sys.exit(not result.wasSuccessful())
index 0ca66fb7227c69ac30574a56aed02c4258e636e4..3a5059972d1e5d101e3cc325a91fa9e26ebcd6ec 100644 (file)
@@ -1,8 +1,10 @@
+"""application table metadata objects are described here."""
+
 from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
+
 
 metadata = MetaData()
-"""application table metadata objects are described here."""
 
 users = Table('users', metadata, 
     Column('user_id', Integer, primary_key=True),
index 5bcad2da23bf243cb6ed91f00ab1140c7d85eb51..ad6876937d7e5ed11d8d929ac6bcf4503bd8eb45 100644 (file)
@@ -1,21 +1,20 @@
-from testbase import AssertMixin
 import testbase
-import unittest
 
-db = testbase.db
 from sqlalchemy import *
 from sqlalchemy.orm import *
-
+from testlib import *
 from zblog import mappers, tables
 from zblog.user import *
 from zblog.blog import *
 
+
 class ZBlogTest(AssertMixin):
 
     def create_tables(self):
-        tables.metadata.create_all(bind=db)
+        tables.metadata.drop_all(bind=testbase.db)
+        tables.metadata.create_all(bind=testbase.db)
     def drop_tables(self):
-        tables.metadata.drop_all(bind=db)
+        tables.metadata.drop_all(bind=testbase.db)
         
     def setUpAll(self):
         self.create_tables()
@@ -32,7 +31,7 @@ class SavePostTest(ZBlogTest):
         super(SavePostTest, self).setUpAll()
         mappers.zblog_mappers()
         global blog_id, user_id
-        s = create_session(bind=db)
+        s = create_session(bind=testbase.db)
         user = User('zbloguser', "Zblog User", "hello", group=administrator)
         blog = Blog(owner=user)
         blog.name = "this is a blog"
@@ -51,7 +50,7 @@ class SavePostTest(ZBlogTest):
         """test that a transient/pending instance has proper bi-directional behavior.
         
         this requires that lazy loaders do not fire off for a transient/pending instance."""
-        s = create_session(bind=db)
+        s = create_session(bind=testbase.db)
 
         s.begin()
         try:
@@ -67,7 +66,7 @@ class SavePostTest(ZBlogTest):
     def testoptimisticorphans(self):
         """test that instances in the session with un-loaded parents will not 
         get marked as "orphans" and then deleted """
-        s = create_session(bind=db)
+        s = create_session(bind=testbase.db)
         
         s.begin()
         try:
@@ -97,4 +96,4 @@ class SavePostTest(ZBlogTest):
 if __name__ == "__main__":
     testbase.main()
 
-        
\ No newline at end of file
+