From: Jason Kirtland Date: Mon, 23 Jul 2007 01:50:54 +0000 (+0000) Subject: Refactored test support code, moved most into 'testlib/' X-Git-Tag: rel_0_4_6~47 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0296f9c412cd3d21ac40ab03de24cec3f33c7064;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Refactored test support code, moved most into 'testlib/' Cleaned up imports, all tests should be runnable stand-alone or suite now Updated most of the perf tests Removed dead test suites Added new profiling decorator Added new profilable perf test, 'ormsession' to try to capture a typical workload --- diff --git a/test/base/dependency.py b/test/base/dependency.py index c5e54fc9fa..ddadd1b316 100644 --- a/test/base/dependency.py +++ b/test/base/dependency.py @@ -1,7 +1,8 @@ -from testbase import PersistTest +import testbase import sqlalchemy.topological as topological -import unittest, sys, os from sqlalchemy import util +from testlib import * + # TODO: need assertion conditions in this suite @@ -190,4 +191,4 @@ class DependencySortTest(PersistTest): if __name__ == "__main__": - unittest.main() + testbase.main() diff --git a/test/base/utils.py b/test/base/utils.py index ccf6b44197..96d3c96e43 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -1,7 +1,9 @@ import testbase from sqlalchemy import util +from testlib import * -class OrderedDictTest(testbase.PersistTest): + +class OrderedDictTest(PersistTest): def test_odict(self): o = util.OrderedDict() o['a'] = 1 diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index 086baf9eb0..dbba78893d 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -1,16 +1,13 @@ -from testbase import PersistTest, AssertMixin import testbase from sqlalchemy import * from sqlalchemy.databases import mysql -from testbase import Table, Column -import sys, StringIO +from testlib import * -db = testbase.db class TypesTest(AssertMixin): "Test MySQL column types" - @testbase.supported('mysql') + @testing.supported('mysql') def test_numeric(self): "Exercise type specification and options for numeric types." @@ -105,13 +102,13 @@ class TypesTest(AssertMixin): 'SMALLINT(4) UNSIGNED ZEROFILL'), ] - table_args = ['test_mysql_numeric', MetaData(db)] + table_args = ['test_mysql_numeric', MetaData(testbase.db)] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw))) numeric_table = Table(*table_args) - gen = db.dialect.schemagenerator(db, None, None) + gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) for col in numeric_table.c: index = int(col.name[1:]) @@ -125,7 +122,7 @@ class TypesTest(AssertMixin): raise numeric_table.drop() - @testbase.supported('mysql') + @testing.supported('mysql') def test_charset(self): """Exercise CHARACTER SET and COLLATE-related options on string-type columns.""" @@ -189,13 +186,13 @@ class TypesTest(AssertMixin): '''ENUM('foo','bar') UNICODE''') ] - table_args = ['test_mysql_charset', MetaData(db)] + table_args = ['test_mysql_charset', MetaData(testbase.db)] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw))) charset_table = Table(*table_args) - gen = db.dialect.schemagenerator(db, None, None) + gen = testbase.db.dialect.schemagenerator(testbase.db, None, None) for col in charset_table.c: index = int(col.name[1:]) @@ -209,11 +206,12 @@ class TypesTest(AssertMixin): raise charset_table.drop() - @testbase.supported('mysql') + @testing.supported('mysql') def test_enum(self): "Exercise the ENUM type" - - enum_table = Table('mysql_enum', MetaData(db), + + db = testbase.db + enum_table = Table('mysql_enum', MetaData(testbase.db), Column('e1', mysql.MSEnum('"a"', "'b'")), Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False), Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)), @@ -252,8 +250,8 @@ class TypesTest(AssertMixin): # This is known to fail with MySQLDB 1.2.2 beta versions # which return these as sets.Set(['a']), sets.Set(['b']) # (even on Pythons with __builtin__.set) - if db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \ - db.dialect.dbapi.version_info >= (1, 2, 2): + if testbase.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \ + testbase.db.dialect.dbapi.version_info >= (1, 2, 2): # these mysqldb seem to always uses 'sets', even on later pythons import sets def convert(value): @@ -272,10 +270,10 @@ class TypesTest(AssertMixin): self.assertEqual(res, expected) enum_table.drop() - @testbase.supported('mysql') + @testing.supported('mysql') def test_type_reflection(self): # FIXME: older versions need their own test - if db.dialect.get_version_info(db) < (5, 0): + if testbase.db.dialect.get_version_info(testbase.db) < (5, 0): return # (ask_for, roundtripped_as_if_different) @@ -305,12 +303,12 @@ class TypesTest(AssertMixin): columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)] - m = MetaData(db) + m = MetaData(testbase.db) t_table = Table('mysql_types', m, *columns) m.drop_all() m.create_all() - m2 = MetaData(db) + m2 = MetaData(testbase.db) rt = Table('mysql_types', m2, autoload=True) #print diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index 8b939eea0e..550966d0a4 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -1,61 +1,60 @@ -from testbase import AssertMixin import testbase +import datetime from sqlalchemy import * from sqlalchemy.databases import postgres -import datetime +from testlib import * -db = testbase.db class DomainReflectionTest(AssertMixin): "Test PostgreSQL domains" - @testbase.supported('postgres') + @testing.supported('postgres') def setUpAll(self): - con = db.connect() + con = testbase.db.connect() con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42') con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0') con.execute('CREATE TABLE testtable (question integer, answer testdomain)') con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)') con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)') - @testbase.supported('postgres') + @testing.supported('postgres') def tearDownAll(self): - con = db.connect() + con = testbase.db.connect() con.execute('DROP TABLE testtable') con.execute('DROP TABLE alt_schema.testtable') con.execute('DROP TABLE crosschema') con.execute('DROP DOMAIN testdomain') con.execute('DROP DOMAIN alt_schema.testdomain') - @testbase.supported('postgres') + @testing.supported('postgres') def test_table_is_reflected(self): - metadata = MetaData(db) + metadata = MetaData(testbase.db) table = Table('testtable', metadata, autoload=True) self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger) - @testbase.supported('postgres') + @testing.supported('postgres') def test_domain_is_reflected(self): - metadata = MetaData(db) + metadata = MetaData(testbase.db) table = Table('testtable', metadata, autoload=True) self.assertEquals(str(table.columns.answer.default.arg), '42', "Reflected default value didn't equal expected value") self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.") - @testbase.supported('postgres') + @testing.supported('postgres') def test_table_is_reflected_alt_schema(self): - metadata = MetaData(db) + metadata = MetaData(testbase.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger) - @testbase.supported('postgres') + @testing.supported('postgres') def test_schema_domain_is_reflected(self): - metadata = MetaData(db) + metadata = MetaData(testbase.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value") self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") - @testbase.supported('postgres') + @testing.supported('postgres') def test_crosschema_domain_is_reflected(self): metadata = MetaData(db) table = Table('crosschema', metadata, autoload=True) @@ -63,7 +62,7 @@ class DomainReflectionTest(AssertMixin): self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") class MiscTest(AssertMixin): - @testbase.supported('postgres') + @testing.supported('postgres') def test_date_reflection(self): m1 = MetaData(testbase.db) t1 = Table('pgdate', m1, @@ -79,7 +78,7 @@ class MiscTest(AssertMixin): finally: m1.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_pg_weirdchar_reflection(self): meta1 = MetaData(testbase.db) subject = Table("subject", meta1, @@ -100,7 +99,7 @@ class MiscTest(AssertMixin): finally: meta1.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_checksfor_sequence(self): meta1 = MetaData(testbase.db) t = Table('mytable', meta1, @@ -111,7 +110,7 @@ class MiscTest(AssertMixin): finally: t.drop(checkfirst=True) - @testbase.supported('postgres') + @testing.supported('postgres') def test_schema_reflection(self): """note: this test requires that the 'alt_schema' schema be separate and accessible by the test user""" @@ -142,7 +141,7 @@ class MiscTest(AssertMixin): finally: meta1.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_schema_reflection_2(self): meta1 = MetaData(testbase.db) subject = Table("subject", meta1, @@ -163,7 +162,7 @@ class MiscTest(AssertMixin): finally: meta1.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_schema_reflection_3(self): meta1 = MetaData(testbase.db) subject = Table("subject", meta1, @@ -186,7 +185,7 @@ class MiscTest(AssertMixin): finally: meta1.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_preexecute_passivedefault(self): """test that when we get a primary key column back from reflecting a table which has a default value on it, we pre-execute @@ -217,7 +216,7 @@ class TimezoneTest(AssertMixin): if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime that doesnt have one. this test illustrates two ways to have datetime types with and without timezone info. """ - @testbase.supported('postgres') + @testing.supported('postgres') def setUpAll(self): global tztable, notztable, metadata metadata = MetaData(testbase.db) @@ -234,11 +233,11 @@ class TimezoneTest(AssertMixin): Column("name", String(20)), ) metadata.create_all() - @testbase.supported('postgres') + @testing.supported('postgres') def tearDownAll(self): metadata.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_with_timezone(self): # get a date with a tzinfo somedate = testbase.db.connect().scalar(func.current_timestamp().select()) @@ -247,7 +246,7 @@ class TimezoneTest(AssertMixin): x = c.last_updated_params() print x['date'] == somedate - @testbase.supported('postgres') + @testing.supported('postgres') def test_without_timezone(self): # get a date without a tzinfo somedate = datetime.datetime(2005, 10,20, 11, 52, 00) @@ -257,7 +256,7 @@ class TimezoneTest(AssertMixin): print x['date'] == somedate class ArrayTest(AssertMixin): - @testbase.supported('postgres') + @testing.supported('postgres') def setUpAll(self): global metadata, arrtable metadata = MetaData(testbase.db) @@ -268,11 +267,11 @@ class ArrayTest(AssertMixin): Column('strarr', postgres.PGArray(String), nullable=False) ) metadata.create_all() - @testbase.supported('postgres') + @testing.supported('postgres') def tearDownAll(self): metadata.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_reflect_array_column(self): metadata2 = MetaData(testbase.db) tbl = Table('arrtable', metadata2, autoload=True) @@ -281,7 +280,7 @@ class ArrayTest(AssertMixin): self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer)) self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String)) - @testbase.supported('postgres') + @testing.supported('postgres') def test_insert_array(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = arrtable.select().execute().fetchall() @@ -290,7 +289,7 @@ class ArrayTest(AssertMixin): self.assertEquals(results[0]['strarr'], ['abc','def']) arrtable.delete().execute() - @testbase.supported('postgres') + @testing.supported('postgres') def test_array_where(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) arrtable.insert().execute(intarr=[4,5,6], strarr='ABC') @@ -299,7 +298,7 @@ class ArrayTest(AssertMixin): self.assertEquals(results[0]['intarr'], [1,2,3]) arrtable.delete().execute() - @testbase.supported('postgres') + @testing.supported('postgres') def test_array_concat(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall() diff --git a/test/engine/bind.py b/test/engine/bind.py index a639a4a90a..6a0c78f578 100644 --- a/test/engine/bind.py +++ b/test/engine/bind.py @@ -2,13 +2,10 @@ including the deprecated versions of these arguments""" import testbase -import unittest, sys, datetime -import tables -db = testbase.db from sqlalchemy import * -from testbase import Table, Column +from testlib import * -class BindTest(testbase.PersistTest): +class BindTest(PersistTest): def test_create_drop_explicit(self): metadata = MetaData() table = Table('test_table', metadata, @@ -201,8 +198,6 @@ class BindTest(testbase.PersistTest): assert False except exceptions.InvalidRequestError, e: assert str(e).startswith("Could not locate any Engine or Connection bound to mapper") - - finally: if isinstance(bind, engine.Connection): bind.close() @@ -210,4 +205,4 @@ class BindTest(testbase.PersistTest): if __name__ == '__main__': - testbase.main() \ No newline at end of file + testbase.main() diff --git a/test/engine/execute.py b/test/engine/execute.py index 5bc9dbfe91..3d3b43f9b6 100644 --- a/test/engine/execute.py +++ b/test/engine/execute.py @@ -1,12 +1,8 @@ - import testbase -import unittest, sys, datetime -import tables -db = testbase.db from sqlalchemy import * -from testbase import Table, Column +from testlib import * -class ExecuteTest(testbase.PersistTest): +class ExecuteTest(PersistTest): def setUpAll(self): global users, metadata metadata = MetaData(testbase.db) @@ -21,7 +17,7 @@ class ExecuteTest(testbase.PersistTest): def tearDownAll(self): metadata.drop_all() - @testbase.supported('sqlite') + @testing.supported('sqlite') def test_raw_qmark(self): for conn in (testbase.db, testbase.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) @@ -33,7 +29,7 @@ class ExecuteTest(testbase.PersistTest): assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')] conn.execute("delete from users") - @testbase.supported('mysql', 'postgres') + @testing.supported('mysql', 'postgres') def test_raw_sprintf(self): for conn in (testbase.db, testbase.db.connect()): conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"]) @@ -46,7 +42,7 @@ class ExecuteTest(testbase.PersistTest): # pyformat is supported for mysql, but skipping because a few driver # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2) - @testbase.supported('postgres') + @testing.supported('postgres') def test_raw_python(self): for conn in (testbase.db, testbase.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) @@ -56,7 +52,7 @@ class ExecuteTest(testbase.PersistTest): assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") - @testbase.supported('sqlite') + @testing.supported('sqlite') def test_raw_named(self): for conn in (testbase.db, testbase.db.connect()): conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'}) diff --git a/test/engine/metadata.py b/test/engine/metadata.py index 28b0535a57..973007fab8 100644 --- a/test/engine/metadata.py +++ b/test/engine/metadata.py @@ -1,8 +1,8 @@ import testbase from sqlalchemy import * -from testbase import Table, Column +from testlib import * -class MetaDataTest(testbase.PersistTest): +class MetaDataTest(PersistTest): def test_metadata_connect(self): metadata = MetaData() t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py index eb4f95619d..3e186275d5 100644 --- a/test/engine/parseconnect.py +++ b/test/engine/parseconnect.py @@ -1,8 +1,7 @@ -from testbase import PersistTest import testbase -import sqlalchemy.engine.url as url from sqlalchemy import * -import unittest +import sqlalchemy.engine.url as url +from testlib import * class ParseConnectTest(PersistTest): diff --git a/test/engine/pool.py b/test/engine/pool.py index 85d44dfd38..3bb25ec72a 100644 --- a/test/engine/pool.py +++ b/test/engine/pool.py @@ -1,10 +1,9 @@ import testbase -from testbase import PersistTest -import unittest, sys, os, time -import threading, thread - +import threading, thread, time import sqlalchemy.pool as pool import sqlalchemy.exceptions as exceptions +from testlib import * + mcid = 1 class MockDBAPI(object): @@ -45,7 +44,7 @@ class PoolTest(PersistTest): connection2 = manager.connect('foo.db') connection3 = manager.connect('bar.db') - self.echo( "connection " + repr(connection)) + print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is connection2) self.assert_(connection2 is not connection3) @@ -64,7 +63,7 @@ class PoolTest(PersistTest): connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') - self.echo( "connection " + repr(connection)) + print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) @@ -80,7 +79,7 @@ class PoolTest(PersistTest): def status(pool): tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) - self.echo( "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup) + print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup return tup c1 = p.connect() diff --git a/test/engine/reconnect.py b/test/engine/reconnect.py index 1c8594d0ee..7c213695f2 100644 --- a/test/engine/reconnect.py +++ b/test/engine/reconnect.py @@ -1,6 +1,8 @@ import testbase +import sys, weakref from sqlalchemy import create_engine, exceptions -import gc, weakref, sys +from testlib import * + class MockDisconnect(Exception): pass @@ -37,7 +39,7 @@ class MockCursor(object): def close(self): pass -class ReconnectTest(testbase.PersistTest): +class ReconnectTest(PersistTest): def test_reconnect(self): """test that an 'is_disconnect' condition will invalidate the connection, and additionally dispose the previous connection pool and recreate.""" @@ -92,4 +94,4 @@ class ReconnectTest(testbase.PersistTest): assert len(dbapi.connections) == 1 if __name__ == '__main__': - testbase.main() \ No newline at end of file + testbase.main() diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 78ffd1fdcf..1fa4b4b90f 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -1,13 +1,12 @@ -from testbase import PersistTest import testbase -import pickle -import sqlalchemy.ansisql as ansisql +import pickle, StringIO from sqlalchemy import * +import sqlalchemy.ansisql as ansisql from sqlalchemy.exceptions import NoSuchTableError import sqlalchemy.databases.mysql as mysql -from testbase import Table, Column -import unittest, re, StringIO +from testlib import * + class ReflectionTest(PersistTest): def testbasic(self): @@ -207,7 +206,7 @@ class ReflectionTest(PersistTest): finally: meta.drop_all() - @testbase.supported('mysql') + @testing.supported('mysql') def testmysqltypes(self): meta1 = MetaData(testbase.db) table = Table( @@ -284,7 +283,7 @@ class ReflectionTest(PersistTest): finally: testbase.db.execute("drop table book") - @testbase.supported('sqlite') + @testing.supported('sqlite') def test_goofy_sqlite(self): """test autoload of table where quotes were used with all the colnames. quirky in sqlite.""" testbase.db.execute("""CREATE TABLE "django_content_type" ( @@ -458,7 +457,7 @@ class ReflectionTest(PersistTest): finally: table.drop() - @testbase.supported('mssql') + @testing.supported('mssql') def testidentity(self): meta = MetaData(testbase.db) table = Table( @@ -576,7 +575,7 @@ class CreateDropTest(PersistTest): class SchemaTest(PersistTest): # this test should really be in the sql tests somewhere, not engine - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testiteration(self): metadata = MetaData() table1 = Table('table1', metadata, @@ -600,7 +599,7 @@ class SchemaTest(PersistTest): assert buf.index("CREATE TABLE someschema.table1") > -1 assert buf.index("CREATE TABLE someschema.table2") > -1 - @testbase.supported('mysql','postgres') + @testing.supported('mysql','postgres') def testcreate(self): engine = testbase.db schema = engine.dialect.get_default_schema_name(engine) diff --git a/test/engine/transaction.py b/test/engine/transaction.py index f86d0cbdb1..593a069a96 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -1,13 +1,12 @@ - import testbase -import unittest, sys, datetime, random, time, threading -import tables -db = testbase.db +import sys, time, threading + from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * + -class TransactionTest(testbase.PersistTest): +class TransactionTest(PersistTest): def setUpAll(self): global users, metadata metadata = MetaData() @@ -116,7 +115,7 @@ class TransactionTest(testbase.PersistTest): assert len(result.fetchall()) == 0 connection.close() - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testnestedsubtransactionrollback(self): connection = testbase.db.connect() transaction = connection.begin() @@ -133,7 +132,7 @@ class TransactionTest(testbase.PersistTest): ) connection.close() - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testnestedsubtransactioncommit(self): connection = testbase.db.connect() transaction = connection.begin() @@ -150,7 +149,7 @@ class TransactionTest(testbase.PersistTest): ) connection.close() - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testrollbacktosubtransaction(self): connection = testbase.db.connect() transaction = connection.begin() @@ -169,7 +168,7 @@ class TransactionTest(testbase.PersistTest): ) connection.close() - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def testtwophasetransaction(self): connection = testbase.db.connect() @@ -197,7 +196,7 @@ class TransactionTest(testbase.PersistTest): ) connection.close() - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def testmixedtransaction(self): connection = testbase.db.connect() @@ -230,7 +229,7 @@ class TransactionTest(testbase.PersistTest): ) connection.close() - @testbase.supported('postgres') + @testing.supported('postgres') def testtwophaserecover(self): # MySQL recovery doesn't currently seem to work correctly # Prepared transactions disappear when connections are closed and even @@ -262,7 +261,7 @@ class TransactionTest(testbase.PersistTest): ) connection2.close() -class AutoRollbackTest(testbase.PersistTest): +class AutoRollbackTest(PersistTest): def setUpAll(self): global metadata metadata = MetaData() @@ -270,7 +269,7 @@ class AutoRollbackTest(testbase.PersistTest): def tearDownAll(self): metadata.drop_all(testbase.db) - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testrollback_deadlock(self): """test that returning connections to the pool clears any object locks.""" conn1 = testbase.db.connect() @@ -289,10 +288,10 @@ class AutoRollbackTest(testbase.PersistTest): users.drop(conn2) conn2.close() -class TLTransactionTest(testbase.PersistTest): +class TLTransactionTest(PersistTest): def setUpAll(self): global users, metadata, tlengine - tlengine = create_engine(testbase.db_uri, strategy='threadlocal') + tlengine = create_engine(testbase.db.url, strategy='threadlocal') metadata = MetaData() users = Table('query_users', metadata, Column('user_id', INT, primary_key = True), @@ -402,7 +401,7 @@ class TLTransactionTest(testbase.PersistTest): finally: external_connection.close() - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testnesting(self): """tests nesting of tranacstions""" external_connection = tlengine.connect() @@ -496,7 +495,7 @@ class TLTransactionTest(testbase.PersistTest): c2.close() assert c1.connection.connection is not None -class ForUpdateTest(testbase.PersistTest): +class ForUpdateTest(PersistTest): def setUpAll(self): global counters, metadata metadata = MetaData() @@ -512,7 +511,7 @@ class ForUpdateTest(testbase.PersistTest): counters.drop(testbase.db) def increment(self, count, errors, update_style=True, delay=0.005): - con = db.connect() + con = testbase.db.connect() sel = counters.select(for_update=update_style, whereclause=counters.c.counter_id==1) @@ -539,7 +538,7 @@ class ForUpdateTest(testbase.PersistTest): con.close() - @testbase.supported('mysql', 'oracle', 'postgres') + @testing.supported('mysql', 'oracle', 'postgres') def testqueued_update(self): """Test SELECT FOR UPDATE with concurrent modifications. @@ -574,7 +573,7 @@ class ForUpdateTest(testbase.PersistTest): def overlap(self, ids, errors, update_style): sel = counters.select(for_update=update_style, whereclause=counters.c.counter_id.in_(*ids)) - con = db.connect() + con = testbase.db.connect() trans = con.begin() try: rows = con.execute(sel).fetchall() @@ -600,7 +599,7 @@ class ForUpdateTest(testbase.PersistTest): return errors - @testbase.supported('mysql', 'oracle', 'postgres') + @testing.supported('mysql', 'oracle', 'postgres') def testqueued_select(self): """Simple SELECT FOR UPDATE conflict test""" @@ -609,7 +608,7 @@ class ForUpdateTest(testbase.PersistTest): sys.stderr.write("Failure: %s\n" % e) self.assert_(len(errors) == 0) - @testbase.supported('oracle', 'postgres') + @testing.supported('oracle', 'postgres') def testnowait_select(self): """Simple SELECT FOR UPDATE NOWAIT conflict test""" diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py index da1726f6a6..e28c72cd73 100644 --- a/test/ext/activemapper.py +++ b/test/ext/activemapper.py @@ -1,15 +1,16 @@ import testbase +from datetime import datetime + from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore from sqlalchemy import and_, or_, exceptions from sqlalchemy import ForeignKey, String, Integer, DateTime, Table, Column from sqlalchemy.orm import clear_mappers, backref, create_session, class_mapper -from datetime import datetime -import sqlalchemy - import sqlalchemy.ext.activemapper as activemapper +import sqlalchemy +from testlib import * -class testcase(testbase.PersistTest): +class testcase(PersistTest): def setUpAll(self): clear_mappers() objectstore.clear() @@ -143,7 +144,7 @@ class testcase(testbase.PersistTest): self.assertEquals(len(person.addresses), 2) self.assertEquals(person.addresses[0].postal_code, '30338') - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_update(self): p1 = self.create_person_one() objectstore.flush() @@ -261,7 +262,7 @@ class testcase(testbase.PersistTest): self.assertEquals(Person.query.count(), 2) -class testmanytomany(testbase.PersistTest): +class testmanytomany(PersistTest): def setUpAll(self): clear_mappers() objectstore.clear() @@ -316,7 +317,7 @@ class testmanytomany(testbase.PersistTest): foo1.bazrel.append(baz1) assert (foo1.bazrel == [baz1]) -class testselfreferential(testbase.PersistTest): +class testselfreferential(PersistTest): def setUpAll(self): clear_mappers() objectstore.clear() diff --git a/test/ext/assignmapper.py b/test/ext/assignmapper.py index f3ef3d180d..31b3dd576f 100644 --- a/test/ext/assignmapper.py +++ b/test/ext/assignmapper.py @@ -1,12 +1,11 @@ -from testbase import PersistTest, AssertMixin import testbase from sqlalchemy import * from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper - from sqlalchemy.ext.assignmapper import assign_mapper from sqlalchemy.ext.sessioncontext import SessionContext -from testbase import Table, Column +from testlib import * + class AssignMapperTest(PersistTest): def setUpAll(self): diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py index 3156e42923..60362501e0 100644 --- a/test/ext/associationproxy.py +++ b/test/ext/associationproxy.py @@ -1,14 +1,11 @@ -from testbase import PersistTest -import sqlalchemy.util as util -import unittest import testbase + from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection from sqlalchemy.ext.associationproxy import * -from testbase import Table, Column +from testlib import * -db = testbase.db class DictCollection(dict): @collection.appender @@ -40,7 +37,7 @@ class _CollectionOperations(PersistTest): def setUp(self): collection_class = self.collection_class - metadata = MetaData(db) + metadata = MetaData(testbase.db) parents_table = Table('Parent', metadata, Column('id', Integer, primary_key=True), @@ -476,7 +473,7 @@ class CustomObjectTest(_CollectionOperations): class ScalarTest(PersistTest): def test_scalar_proxy(self): - metadata = MetaData(db) + metadata = MetaData(testbase.db) parents_table = Table('Parent', metadata, Column('id', Integer, primary_key=True), @@ -592,7 +589,7 @@ class ScalarTest(PersistTest): class LazyLoadTest(PersistTest): def setUp(self): - metadata = MetaData(db) + metadata = MetaData(testbase.db) parents_table = Table('Parent', metadata, Column('id', Integer, primary_key=True), diff --git a/test/ext/orderinglist.py b/test/ext/orderinglist.py index cf6ab038e4..d16e20da73 100644 --- a/test/ext/orderinglist.py +++ b/test/ext/orderinglist.py @@ -1,13 +1,10 @@ -from testbase import PersistTest -import sqlalchemy.util as util -import unittest, sys, os import testbase + from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.orderinglist import * -from testbase import Table, Column +from testlib import * -db = testbase.db metadata = None # order in whole steps @@ -54,7 +51,7 @@ class OrderingListTest(PersistTest): global metadata, slides_table, bullets_table, Slide, Bullet - metadata = MetaData(db) + metadata = MetaData(testbase.db) slides_table = Table('test_Slides', metadata, Column('id', Integer, primary_key=True), Column('name', String)) diff --git a/test/orm/association.py b/test/orm/association.py index 4bb8b97b4b..a2b8994188 100644 --- a/test/orm/association.py +++ b/test/orm/association.py @@ -2,9 +2,9 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * -class AssociationTest(testbase.PersistTest): +class AssociationTest(PersistTest): def setUpAll(self): global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation metadata = MetaData(testbase.db) @@ -139,7 +139,7 @@ class AssociationTest(testbase.PersistTest): sess.flush() self.assert_(item_keywords.count().scalar() == 0) -class AssociationTest2(testbase.PersistTest): +class AssociationTest2(PersistTest): def setUpAll(self): global table_originals, table_people, table_isauthor, metadata, Originals, People, IsAuthor metadata = MetaData(testbase.db) diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index 4187ad8def..652186b8e6 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -1,12 +1,11 @@ """eager loading unittests derived from mailing list-reported problems and trac tickets.""" -from testbase import PersistTest, AssertMixin, ORMTest import testbase +import random, datetime from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext -from testbase import Table, Column -import random, datetime +from testlib import * class EagerTest(AssertMixin): def setUpAll(self): @@ -204,7 +203,7 @@ class EagerTest2(AssertMixin): obj = session.query(Left).get_by(tag='tag1') print obj.middle.right[0] -class EagerTest3(testbase.ORMTest): +class EagerTest3(ORMTest): """test eager loading combined with nested SELECT statements, functions, and aggregates""" def define_tables(self, metadata): global datas, foo, stats @@ -272,7 +271,7 @@ class EagerTest3(testbase.ORMTest): # algorithms and there are repeated 'somedata' values in the list) assert verify_result == arb_result -class EagerTest4(testbase.ORMTest): +class EagerTest4(ORMTest): def define_tables(self, metadata): global departments, employees departments = Table('departments', metadata, @@ -324,7 +323,7 @@ class EagerTest4(testbase.ORMTest): assert q.count() == 2 assert q[0] is d2 -class EagerTest5(testbase.ORMTest): +class EagerTest5(ORMTest): """test the construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance""" def define_tables(self, metadata): @@ -572,8 +571,8 @@ class EagerTest7(ORMTest): i = ctx.current.query(Invoice).get(invoice_id) - self.echo(repr(c)) - self.echo(repr(i.company)) + print repr(c) + print repr(i.company) self.assert_(repr(c) == repr(i.company)) def testtwo(self): @@ -638,7 +637,7 @@ class EagerTest7(ORMTest): ctx.current.clear() a = ctx.current.query(Company).get(company_id) - self.echo(repr(a)) + print repr(a) # set up an invoice i1 = Invoice() @@ -667,7 +666,7 @@ class EagerTest7(ORMTest): ctx.current.clear() c = ctx.current.query(Company).get(company_id) - self.echo(repr(c)) + print repr(c) ctx.current.clear() @@ -675,7 +674,7 @@ class EagerTest7(ORMTest): assert repr(i.company) == repr(c), repr(i.company) + " does not match " + repr(c) -class EagerTest8(testbase.ORMTest): +class EagerTest8(ORMTest): def define_tables(self, metadata): global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t diff --git a/test/orm/attributes.py b/test/orm/attributes.py index e63860b8d7..9b5f738bf7 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -1,11 +1,9 @@ -from testbase import PersistTest -import sqlalchemy.util as util +import testbase +import pickle import sqlalchemy.orm.attributes as attributes from sqlalchemy.orm.collections import collection from sqlalchemy import exceptions -import unittest, sys, os -import pickle -import testbase +from testlib import * class MyTest(object):pass class MyTest2(object):pass diff --git a/test/orm/cascade.py b/test/orm/cascade.py index d43f069bcb..b832c427e0 100644 --- a/test/orm/cascade.py +++ b/test/orm/cascade.py @@ -1,12 +1,12 @@ -import testbase, tables -import unittest, sys, datetime +import testbase -from sqlalchemy.ext.sessioncontext import SessionContext from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from sqlalchemy.ext.sessioncontext import SessionContext +from testlib import * +import testlib.tables as tables -class O2MCascadeTest(testbase.AssertMixin): +class O2MCascadeTest(AssertMixin): def tearDown(self): tables.delete() @@ -114,7 +114,7 @@ class O2MCascadeTest(testbase.AssertMixin): sess = create_session() l = sess.query(tables.User).select() for u in l: - self.echo( repr(u.orders)) + print repr(u.orders) self.assert_result(l, data[0], *data[1:]) ids = (l[0].user_id, l[2].user_id) @@ -174,7 +174,7 @@ class O2MCascadeTest(testbase.AssertMixin): self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0) -class M2OCascadeTest(testbase.AssertMixin): +class M2OCascadeTest(AssertMixin): def tearDown(self): ctx.current.clear() for t in metadata.table_iterator(reverse=True): @@ -262,7 +262,7 @@ class M2OCascadeTest(testbase.AssertMixin): -class M2MCascadeTest(testbase.AssertMixin): +class M2MCascadeTest(AssertMixin): def setUpAll(self): global metadata, a, b, atob metadata = MetaData(testbase.db) @@ -337,7 +337,7 @@ class M2MCascadeTest(testbase.AssertMixin): assert b.count().scalar() == 0 assert a.count().scalar() == 0 -class UnsavedOrphansTest(testbase.ORMTest): +class UnsavedOrphansTest(ORMTest): """tests regarding pending entities that are orphans""" def define_tables(self, metadata): @@ -397,7 +397,7 @@ class UnsavedOrphansTest(testbase.ORMTest): assert a.address_id is None, "Error: address should not be persistent" -class UnsavedOrphansTest2(testbase.ORMTest): +class UnsavedOrphansTest2(ORMTest): """same test as UnsavedOrphans only three levels deep""" def define_tables(self, meta): @@ -457,7 +457,7 @@ class UnsavedOrphansTest2(testbase.ORMTest): assert item.id is None assert attr.id is None -class DoubleParentOrphanTest(testbase.AssertMixin): +class DoubleParentOrphanTest(AssertMixin): """test orphan detection for an entity with two parent relations""" def setUpAll(self): @@ -523,7 +523,7 @@ class DoubleParentOrphanTest(testbase.AssertMixin): assert False except exceptions.FlushError, e: assert True - - + + if __name__ == "__main__": testbase.main() diff --git a/test/orm/collection.py b/test/orm/collection.py index e5cd1e935f..1f4f649281 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -7,7 +7,7 @@ import sqlalchemy.orm.collections as collections from sqlalchemy.orm.collections import collection from sqlalchemy import util from operator import and_ - +from testlib import * class Canary(interfaces.AttributeExtension): def __init__(self): @@ -48,7 +48,7 @@ def dictable_entity(a=None, b=None, c=None): return Entity(a or str(_id), b or 'value %s' % _id, c) -class CollectionsTest(testbase.PersistTest): +class CollectionsTest(PersistTest): def _test_adapter(self, typecallable, creator=entity_maker, to_set=None): class Foo(object): @@ -980,7 +980,7 @@ class CollectionsTest(testbase.PersistTest): obj.attr[0] = e3 self.assert_(e3 in canary.data) -class DictHelpersTest(testbase.ORMTest): +class DictHelpersTest(ORMTest): def define_tables(self, metadata): global parents, children, Parent, Child diff --git a/test/orm/compile.py b/test/orm/compile.py index ef5faa21dd..23f04db856 100644 --- a/test/orm/compile.py +++ b/test/orm/compile.py @@ -1,10 +1,10 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * -class CompileTest(testbase.AssertMixin): +class CompileTest(AssertMixin): """test various mapper compilation scenarios""" def tearDownAll(self): clear_mappers() diff --git a/test/orm/cycles.py b/test/orm/cycles.py index bdbb146e9f..ce3065f777 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -1,13 +1,8 @@ -from testbase import PersistTest, AssertMixin, ORMTest -import unittest, sys, os +import testbase from sqlalchemy import * from sqlalchemy.orm import * -import StringIO -import testbase -from testbase import Table, Column - -from tables import * -import tables +from testlib import * +from testlib.tables import * """test cyclical mapper relationships. Many of the assertions are provided via running with postgres, which is strict about foreign keys. @@ -538,7 +533,7 @@ class OneToManyManyToOneTest(AssertMixin): sess.save(b) sess.save(p) - self.assert_sql(db, lambda: sess.flush(), [ + self.assert_sql(testbase.db, lambda: sess.flush(), [ ( "INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)", {'favorite_ball_id': None, 'data':'some data'} @@ -592,7 +587,7 @@ class OneToManyManyToOneTest(AssertMixin): ) ]) sess.delete(p) - self.assert_sql(db, lambda: sess.flush(), [ + self.assert_sql(testbase.db, lambda: sess.flush(), [ # heres the post update (which is a pre-update with deletes) ( "UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id", @@ -642,7 +637,7 @@ class OneToManyManyToOneTest(AssertMixin): sess = create_session() [sess.save(x) for x in [b,p,b2,b3,b4]] - self.assert_sql(db, lambda: sess.flush(), [ + self.assert_sql(testbase.db, lambda: sess.flush(), [ ( "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)", {'person_id':None, 'data':'some data'} @@ -721,7 +716,7 @@ class OneToManyManyToOneTest(AssertMixin): ]) sess.delete(p) - self.assert_sql(db, lambda: sess.flush(), [ + self.assert_sql(testbase.db, lambda: sess.flush(), [ ( "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id", lambda ctx:{'person_id': None, 'ball_id': b.id} @@ -833,7 +828,7 @@ class SelfReferentialPostUpdateTest(AssertMixin): remove_child(root, cats) # pre-trigger lazy loader on 'cats' to make the test easier cats.children - self.assert_sql(db, lambda: session.flush(), [ + self.assert_sql(testbase.db, lambda: session.flush(), [ ( "UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id", lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id} diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 396c28bf94..49ea65153b 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -1,9 +1,9 @@ """basic tests of eager loaded attributes""" +import testbase from sqlalchemy import * from sqlalchemy.orm import * -import testbase - +from testlib import * from fixtures import * from query import QueryTest @@ -438,7 +438,7 @@ class EagerTest(QueryTest): l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id) assert fixtures.user_address_result[1:2] == l.all() -class SelfReferentialEagerTest(testbase.ORMTest): +class SelfReferentialEagerTest(ORMTest): def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, diff --git a/test/orm/entity.py b/test/orm/entity.py index 813ed327cf..da76e8df05 100644 --- a/test/orm/entity.py +++ b/test/orm/entity.py @@ -1,13 +1,9 @@ -from testbase import PersistTest, AssertMixin -import unittest import testbase from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext -from testbase import Table, Column - -from tables import * -import tables +from testlib import * +from testlib.tables import * class EntityTest(AssertMixin): """tests mappers that are constructed based on "entity names", which allows the same class diff --git a/test/orm/fixtures.py b/test/orm/fixtures.py index 39b0da383f..4a7d41459f 100644 --- a/test/orm/fixtures.py +++ b/test/orm/fixtures.py @@ -1,5 +1,6 @@ +import testbase from sqlalchemy import * -from testbase import Table, Column +from testlib import * _recursion_stack = util.Set() class Base(object): diff --git a/test/orm/generative.py b/test/orm/generative.py index fea783d268..5106388d73 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -1,11 +1,9 @@ -from testbase import PersistTest, AssertMixin, ORMTest import testbase -import tables - from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exceptions -from testbase import Table, Column +from testlib import * +import testlib.tables as tables # TODO: these are more tests that should be updated to be part of test/orm/query.py @@ -40,7 +38,7 @@ class GenerativeQueryTest(PersistTest): assert res.order_by([Foo.c.bar])[0].bar == 5 assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 - @testbase.unsupported('mssql') + @testing.unsupported('mssql') def test_slice(self): sess = create_session() query = sess.query(Foo) @@ -54,7 +52,7 @@ class GenerativeQueryTest(PersistTest): assert list(query[-5:]) == orig[-5:] assert query[10:20][5] == orig[10:20][5] - @testbase.supported('mssql') + @testing.supported('mssql') def test_slice_mssql(self): sess = create_session() query = sess.query(Foo) @@ -71,23 +69,23 @@ class GenerativeQueryTest(PersistTest): assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29 assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29 - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_aggregate_1(self): # this one fails in mysql as the result comes back as a string query = create_session().query(Foo) assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435 - @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql') + @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_2(self): query = create_session().query(Foo) assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 - @testbase.supported('postgres', 'mysql', 'firebird', 'mssql') + @testing.supported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_2_int(self): query = create_session().query(Foo) assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14 - @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql') + @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_3(self): query = create_session().query(Foo) assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5 diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py index 7677311739..3b35b3713d 100644 --- a/test/orm/inheritance/abc_inheritance.py +++ b/test/orm/inheritance/abc_inheritance.py @@ -1,15 +1,14 @@ +import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column - from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE -import testbase +from testlib import * def produce_test(parent, child, direction): """produce a testcase for A->B->C inheritance with a self-referential relationship between two of the classes, using either one-to-many or many-to-one.""" - class ABCTest(testbase.ORMTest): + class ABCTest(ORMTest): def define_tables(self, meta): global ta, tb, tc ta = ["a", meta] diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index bcf269a876..c6cd43f439 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -1,10 +1,10 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * -class O2MTest(testbase.ORMTest): +class O2MTest(ORMTest): """deals with inheritance and one-to-many relationships""" def define_tables(self, metadata): global foo, bar, blub @@ -58,11 +58,11 @@ class O2MTest(testbase.ORMTest): sess.clear() l = sess.query(Blub).select() result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo) - self.echo(result) + print result self.assert_(compare == result) self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') -class AddPropTest(testbase.ORMTest): +class AddPropTest(ORMTest): """testing that construction of inheriting mappers works regardless of when extra properties are added to the superclass mapper""" def define_tables(self, metadata): @@ -107,7 +107,7 @@ class AddPropTest(testbase.ORMTest): p.contenttype = ContentType() # TODO: assertion ?? -class EagerLazyTest(testbase.ORMTest): +class EagerLazyTest(ORMTest): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" def define_tables(self, metadata): @@ -149,7 +149,7 @@ class EagerLazyTest(testbase.ORMTest): self.assert_(len(q.selectfirst().eager) == 1) -class FlushTest(testbase.ORMTest): +class FlushTest(ORMTest): """test dependency sorting among inheriting mappers""" def define_tables(self, metadata): global users, roles, user_roles, admins @@ -238,7 +238,7 @@ class FlushTest(testbase.ORMTest): sess.flush() assert user_roles.count().scalar() == 1 -class DistinctPKTest(testbase.ORMTest): +class DistinctPKTest(ORMTest): """test the construction of mapper.primary_key when an inheriting relationship joins on a column other than primary key column.""" keep_data = True diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/concrete.py index d0d03210bf..167b25256d 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/concrete.py @@ -1,9 +1,9 @@ +import testbase from sqlalchemy import * from sqlalchemy.orm import * -import testbase -from testbase import Table, Column +from testlib import * -class ConcreteTest1(testbase.ORMTest): +class ConcreteTest1(ORMTest): def define_tables(self, metadata): global managers_table, engineers_table managers_table = Table('managers', metadata, diff --git a/test/orm/inheritance/magazine.py b/test/orm/inheritance/magazine.py index 509216d374..a0bf241485 100644 --- a/test/orm/inheritance/magazine.py +++ b/test/orm/inheritance/magazine.py @@ -1,7 +1,7 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * class BaseObject(object): @@ -68,7 +68,7 @@ class ClassifiedPage(MagazinePage): pass -class MagazineTest(testbase.ORMTest): +class MagazineTest(ORMTest): def define_tables(self, metadata): global publication_table, issue_table, location_table, location_name_table, magazine_table, \ page_table, magazine_page_table, classified_page_table, page_size_table diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py index 97885c1b08..df00f39d0b 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/manytomany.py @@ -1,10 +1,10 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * -class InheritTest(testbase.ORMTest): +class InheritTest(ORMTest): """deals with inheritance and many-to-many relationships""" def define_tables(self, metadata): global principals @@ -76,7 +76,7 @@ class InheritTest(testbase.ORMTest): sess.flush() # TODO: put an assertion -class InheritTest2(testbase.ORMTest): +class InheritTest2(ORMTest): """deals with inheritance and many-to-many relationships""" def define_tables(self, metadata): global foo, bar, foo_bar @@ -148,7 +148,7 @@ class InheritTest2(testbase.ORMTest): {'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])}, ) -class InheritTest3(testbase.ORMTest): +class InheritTest3(ORMTest): """deals with inheritance and many-to-many relationships""" def define_tables(self, metadata): global foo, bar, blub, bar_foo, blub_bar, blub_foo @@ -203,7 +203,7 @@ class InheritTest3(testbase.ORMTest): compare = repr(b) + repr(b.foos) sess.clear() l = sess.query(Bar).select() - self.echo(repr(l[0]) + repr(l[0].foos)) + print repr(l[0]) + repr(l[0].foos) self.assert_(repr(l[0]) + repr(l[0].foos) == compare) def testadvanced(self): @@ -243,11 +243,11 @@ class InheritTest3(testbase.ORMTest): sess.clear() l = sess.query(Blub).select() - self.echo(l) + print l self.assert_(repr(l[0]) == compare) sess.clear() x = sess.query(Blub).get_by(id=blubid) - self.echo(x) + print x self.assert_(repr(x) == compare) diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/poly_linked_list.py index 7858689b1d..7297002f52 100644 --- a/test/orm/inheritance/poly_linked_list.py +++ b/test/orm/inheritance/poly_linked_list.py @@ -1,10 +1,10 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * -class PolymorphicCircularTest(testbase.ORMTest): +class PolymorphicCircularTest(ORMTest): keep_mappers = True def define_tables(self, metadata): global Table1, Table1B, Table2, Table3, Data diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index a3395924f3..3eb2e032f0 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -1,10 +1,10 @@ """tests basic polymorphic mapper loading/saving, minimal relations""" import testbase +import sets from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -import sets +from testlib import * class Person(object): @@ -35,7 +35,7 @@ class Company(object): def __repr__(self): return "Company %s" % self.name -class PolymorphTest(testbase.ORMTest): +class PolymorphTest(ORMTest): def define_tables(self, metadata): global companies, people, engineers, managers, boss diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/polymorph2.py index d18539aa10..a2f9c4a5f0 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/polymorph2.py @@ -1,7 +1,7 @@ +import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -import testbase +from testlib import * class AttrSettable(object): @@ -11,7 +11,7 @@ class AttrSettable(object): return self.__class__.__name__ + "(%s)" % (hex(id(self))) -class RelationTest1(testbase.ORMTest): +class RelationTest1(ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): global people, managers @@ -91,7 +91,7 @@ class RelationTest1(testbase.ORMTest): print p, m, m.employee assert m.employee is p -class RelationTest2(testbase.ORMTest): +class RelationTest2(ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): global people, managers, data @@ -186,7 +186,7 @@ class RelationTest2(testbase.ORMTest): if usedata: assert m.data.data == 'ms data' -class RelationTest3(testbase.ORMTest): +class RelationTest3(ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): global people, managers, data @@ -289,7 +289,7 @@ for jointype in ["join1", "join2", "join3", "join4"]: setattr(RelationTest3, func.__name__, func) -class RelationTest4(testbase.ORMTest): +class RelationTest4(ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars people = Table('people', metadata, @@ -405,7 +405,7 @@ class RelationTest4(testbase.ORMTest): c = s.join("employee").filter(Person.name=="E4")[0] assert c.car_id==car1.car_id -class RelationTest5(testbase.ORMTest): +class RelationTest5(ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars people = Table('people', metadata, @@ -465,7 +465,7 @@ class RelationTest5(testbase.ORMTest): assert carlist[0].manager is None assert carlist[1].manager.person_id == car2.manager.person_id -class RelationTest6(testbase.ORMTest): +class RelationTest6(ORMTest): """test self-referential relationships on a single joined-table inheritance mapper""" def define_tables(self, metadata): global people, managers, data @@ -508,7 +508,7 @@ class RelationTest6(testbase.ORMTest): m2 = sess.query(Manager).get(m2.person_id) assert m.colleague is m2 -class RelationTest7(testbase.ORMTest): +class RelationTest7(ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars, offroad_cars cars = Table('cars', metadata, @@ -607,7 +607,7 @@ class RelationTest7(testbase.ORMTest): for p in r: assert p.car_id == p.car.car_id -class GenerativeTest(testbase.AssertMixin): +class GenerativeTest(AssertMixin): def setUpAll(self): # cars---owned by--- people (abstract) --- has a --- status # | ^ ^ | @@ -733,7 +733,7 @@ class GenerativeTest(testbase.AssertMixin): r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id)) assert str(list(r)) == "[Engineer E4, field X, status Status dead]" -class MultiLevelTest(testbase.ORMTest): +class MultiLevelTest(ORMTest): def define_tables(self, metadata): global table_Employee, table_Engineer, table_Manager table_Employee = Table( 'Employee', metadata, @@ -810,7 +810,7 @@ class MultiLevelTest(testbase.ORMTest): assert set(session.query( Engineer).select()) == set([b,c]) assert session.query( Manager).select() == [c] -class ManyToManyPolyTest(testbase.ORMTest): +class ManyToManyPolyTest(ORMTest): def define_tables(self, metadata): global base_item_table, item_table, base_item_collection_table, collection_table base_item_table = Table( @@ -860,7 +860,7 @@ class ManyToManyPolyTest(testbase.ORMTest): class_mapper(BaseItem) -class CustomPKTest(testbase.ORMTest): +class CustomPKTest(ORMTest): def define_tables(self, metadata): global t1, t2 t1 = Table('t1', metadata, diff --git a/test/orm/inheritance/productspec.py b/test/orm/inheritance/productspec.py index aff89f5b77..2459cd36e1 100644 --- a/test/orm/inheritance/productspec.py +++ b/test/orm/inheritance/productspec.py @@ -1,11 +1,11 @@ import testbase +from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * -from datetime import datetime -class InheritTest(testbase.ORMTest): +class InheritTest(ORMTest): """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" def define_tables(self, metadata): global products_table, specification_table, documents_table diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/single.py index 1432887e00..68fe821af0 100644 --- a/test/orm/inheritance/single.py +++ b/test/orm/inheritance/single.py @@ -1,10 +1,10 @@ +import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -import testbase +from testlib import * -class SingleInheritanceTest(testbase.AssertMixin): +class SingleInheritanceTest(AssertMixin): def setUpAll(self): metadata = MetaData(testbase.db) global employees_table diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py index e9d77e09c6..6684c62881 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/lazy_relations.py @@ -1,9 +1,9 @@ """basic tests of lazy loaded attributes""" +import testbase from sqlalchemy import * from sqlalchemy.orm import * -import testbase - +from testlib import * from fixtures import * from query import QueryTest diff --git a/test/orm/lazytest1.py b/test/orm/lazytest1.py index 83694d3961..b5296120b3 100644 --- a/test/orm/lazytest1.py +++ b/test/orm/lazytest1.py @@ -1,10 +1,7 @@ -from testbase import PersistTest, AssertMixin import testbase -import unittest, sys, os from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -import datetime +from testlib import * class LazyTest(AssertMixin): def setUpAll(self): diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py index 3cd680fd2a..8b310f86c5 100644 --- a/test/orm/manytomany.py +++ b/test/orm/manytomany.py @@ -1,8 +1,8 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -import string +from testlib import * + class Place(object): '''represents a place''' @@ -27,7 +27,7 @@ class Transition(object): def __repr__(self): return object.__repr__(self)+ " " + repr(self.inputs) + " " + repr(self.outputs) -class M2MTest(testbase.ORMTest): +class M2MTest(ORMTest): def define_tables(self, metadata): global place place = Table('place', metadata, @@ -112,7 +112,7 @@ class M2MTest(testbase.ORMTest): for p in l: pp = p.places - self.echo("Place " + str(p) +" places " + repr(pp)) + print "Place " + str(p) +" places " + repr(pp) [sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7] sess.flush() @@ -178,7 +178,7 @@ class M2MTest(testbase.ORMTest): self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) -class M2MTest2(testbase.ORMTest): +class M2MTest2(ORMTest): def define_tables(self, metadata): global studentTbl studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True)) @@ -245,7 +245,7 @@ class M2MTest2(testbase.ORMTest): sess.flush() assert enrolTbl.count().scalar() == 0 -class M2MTest3(testbase.ORMTest): +class M2MTest3(ORMTest): def define_tables(self, metadata): global c, c2a1, c2a2, b, a c = Table('c', metadata, diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 5fc5e15a5a..f9c3aac841 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1,14 +1,14 @@ """tests general mapper operations with an emphasis on selecting/loading""" -from testbase import PersistTest, AssertMixin, ORMTest import testbase -import unittest, sys, os from sqlalchemy import * from sqlalchemy.orm import * import sqlalchemy.exceptions as exceptions from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt -from tables import * -import tables +from testlib import * +from testlib.tables import * +import testlib.tables as tables + class MapperSuperTest(AssertMixin): def setUpAll(self): @@ -168,7 +168,7 @@ class MapperTest(MapperSuperTest): u = q2.selectfirst(users.c.user_id==8) def go(): s.refresh(u) - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) def testexpire(self): """test the expire function""" @@ -300,7 +300,7 @@ class MapperTest(MapperSuperTest): # l = create_session().query(User).select(order_by=None) - @testbase.unsupported('firebird') + @testing.unsupported('firebird') def testfunction(self): """test mapping to a SELECT statement that has functions in it.""" s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')], @@ -313,7 +313,7 @@ class MapperTest(MapperSuperTest): assert l[0].concat == l[0].user_id * 2 == 14 assert l[1].concat == l[1].user_id * 2 == 16 - @testbase.unsupported('firebird') + @testing.unsupported('firebird') def testcount(self): """test the count function on Query. @@ -390,7 +390,7 @@ class MapperTest(MapperSuperTest): def go(): u = sess.query(User).options(eagerload('adlist')).get_by(user_name='jack') self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1])) - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) def testextensionoptions(self): sess = create_session() @@ -429,7 +429,7 @@ class MapperTest(MapperSuperTest): def go(): self.assert_result(l, User, *user_address_result) - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) def testeageroptionswithlimit(self): sess = create_session() @@ -441,7 +441,7 @@ class MapperTest(MapperSuperTest): def go(): assert u.user_id == 8 assert len(u.addresses) == 3 - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) sess.clear() @@ -450,7 +450,7 @@ class MapperTest(MapperSuperTest): u = sess.query(User).get_by(user_id=8) assert u.user_id == 8 assert len(u.addresses) == 3 - assert "tbl_row_count" not in self.capture_sql(db, go) + assert "tbl_row_count" not in self.capture_sql(testbase.db, go) def testlazyoptionswithlimit(self): sess = create_session() @@ -462,7 +462,7 @@ class MapperTest(MapperSuperTest): def go(): assert u.user_id == 8 assert len(u.addresses) == 3 - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) def testeagerdegrade(self): """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available""" @@ -475,7 +475,7 @@ class MapperTest(MapperSuperTest): def go(): l = sess.query(usermapper).select() self.assert_result(l, User, *user_address_result) - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) sess.clear() @@ -486,7 +486,7 @@ class MapperTest(MapperSuperTest): r = users.select().execute() l = usermapper.instances(r, sess) self.assert_result(l, User, *user_address_result) - self.assert_sql_count(db, go, 4) + self.assert_sql_count(testbase.db, go, 4) clear_mappers() @@ -513,7 +513,7 @@ class MapperTest(MapperSuperTest): def go(): l = sess.query(usermapper).select() self.assert_result(l, User, *user_all_result) - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) sess.clear() @@ -523,7 +523,7 @@ class MapperTest(MapperSuperTest): r = users.select().execute() l = usermapper.instances(r, sess) self.assert_result(l, User, *user_all_result) - self.assert_sql_count(db, go, 7) + self.assert_sql_count(testbase.db, go, 7) def testlazyoptions(self): @@ -535,7 +535,7 @@ class MapperTest(MapperSuperTest): l = sess.query(User).options(lazyload('addresses')).select() def go(): self.assert_result(l, User, *user_address_result) - self.assert_sql_count(db, go, 3) + self.assert_sql_count(testbase.db, go, 3) def testlatecompile(self): """tests mappers compiling late in the game""" @@ -549,7 +549,7 @@ class MapperTest(MapperSuperTest): u = sess.query(User).select() def go(): print u[0].orders[1].items[0].keywords[1] - self.assert_sql_count(db, go, 3) + self.assert_sql_count(testbase.db, go, 3) def testdeepoptions(self): mapper(User, users, @@ -567,7 +567,7 @@ class MapperTest(MapperSuperTest): u = sess.query(User).select() def go(): print u[0].orders[1].items[0].keywords[1] - self.assert_sql_count(db, go, 3) + self.assert_sql_count(testbase.db, go, 3) sess.clear() @@ -578,7 +578,7 @@ class MapperTest(MapperSuperTest): def go(): print u[0].orders[1].items[0].keywords[1] print "-------MARK2----------" - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) sess.clear() @@ -588,7 +588,7 @@ class MapperTest(MapperSuperTest): def go(): print u[0].orders[1].items[0].keywords[1] print "-------MARK3----------" - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) print "-------MARK4----------" sess.clear() @@ -598,7 +598,7 @@ class MapperTest(MapperSuperTest): print "-------MARK5----------" q3 = sess.query(User).options(eagerload('orders.items.keywords')) u = q3.select() - self.assert_sql_count(db, go, 2) + self.assert_sql_count(testbase.db, go, 2) class DeferredTest(MapperSuperTest): @@ -619,8 +619,8 @@ class DeferredTest(MapperSuperTest): o2 = l[2] print o2.description - orderby = str(orders.default_order_by()[0].compile(bind=db)) - self.assert_sql(db, go, [ + orderby = str(orders.default_order_by()[0].compile(bind=testbase.db)) + self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) ]) @@ -682,8 +682,8 @@ class DeferredTest(MapperSuperTest): assert o2.opened == 1 assert o2.userident == 7 assert o2.description == 'order 3' - orderby = str(orders.default_order_by()[0].compile(db)) - self.assert_sql(db, go, [ + orderby = str(orders.default_order_by()[0].compile(testbase.db)) + self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) ]) @@ -695,7 +695,7 @@ class DeferredTest(MapperSuperTest): o2.description = 'order 3' def go(): sess.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) def testcommitsstate(self): """test that when deferred elements are loaded via a group, they get the proper CommittedState @@ -717,7 +717,7 @@ class DeferredTest(MapperSuperTest): def go(): # therefore the flush() shouldnt actually issue any SQL sess.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) def testoptions(self): """tests using options on a mapper to create deferred and undeferred columns""" @@ -729,8 +729,8 @@ class DeferredTest(MapperSuperTest): l = q2.select() print l[2].user_id - orderby = str(orders.default_order_by()[0].compile(db)) - self.assert_sql(db, go, [ + orderby = str(orders.default_order_by()[0].compile(testbase.db)) + self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) ]) @@ -739,7 +739,7 @@ class DeferredTest(MapperSuperTest): def go(): l = q3.select() print l[3].user_id - self.assert_sql(db, go, [ + self.assert_sql(testbase.db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ]) @@ -759,8 +759,8 @@ class DeferredTest(MapperSuperTest): assert o2.opened == 1 assert o2.userident == 7 assert o2.description == 'order 3' - orderby = str(orders.default_order_by()[0].compile(db)) - self.assert_sql(db, go, [ + orderby = str(orders.default_order_by()[0].compile(testbase.db)) + self.assert_sql(testbase.db, go, [ ("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen, orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}), ]) @@ -779,7 +779,7 @@ class DeferredTest(MapperSuperTest): item = l[0].orders[1].items[1] def go(): print item.item_name - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) self.assert_(item.item_name == 'item 4') sess.clear() q2 = q.options(undefer('orders.items.item_name')) @@ -787,7 +787,7 @@ class DeferredTest(MapperSuperTest): item = l[0].orders[1].items[1] def go(): print item.item_name - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) self.assert_(item.item_name == 'item 4') class CompositeTypesTest(ORMTest): diff --git a/test/orm/memusage.py b/test/orm/memusage.py index 5834974ec5..26da7c010d 100644 --- a/test/orm/memusage.py +++ b/test/orm/memusage.py @@ -1,23 +1,18 @@ -from sqlalchemy import * -from sqlalchemy.orm import * -from sqlalchemy.orm import mapperlib, session, unitofwork, attributes -Mapper = mapperlib.Mapper -import gc import testbase -from testbase import Table, Column -import tables +import gc +from sqlalchemy import MetaData, Integer, String, ForeignKey +from sqlalchemy.orm import mapper, relation, clear_mappers, create_session +from sqlalchemy.orm.mapper import Mapper +from testlib import * class A(object):pass class B(object):pass -class MapperCleanoutTest(testbase.AssertMixin): +class MapperCleanoutTest(AssertMixin): """test that clear_mappers() removes everything related to the class. does not include classes that use the assignmapper extension.""" - def setUp(self): - global engine - engine = testbase.db - + def test_mapper_cleanup(self): for x in range(0, 5): self.do_test() @@ -35,7 +30,7 @@ class MapperCleanoutTest(testbase.AssertMixin): assert True def do_test(self): - metadata = MetaData(engine) + metadata = MetaData(testbase.db) table1 = Table("mytable", metadata, Column('col1', Integer, primary_key=True), diff --git a/test/orm/merge.py b/test/orm/merge.py index 66e21ed8d1..3dd0a95a47 100644 --- a/test/orm/merge.py +++ b/test/orm/merge.py @@ -1,10 +1,9 @@ -from testbase import PersistTest, AssertMixin import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -from tables import * -import tables +from testlib import * +from testlib.tables import * +import testlib.tables as tables class MergeTest(AssertMixin): """tests session.merge() functionality""" @@ -166,5 +165,3 @@ class MergeTest(AssertMixin): if __name__ == "__main__": testbase.main() - - diff --git a/test/orm/onetoone.py b/test/orm/onetoone.py index ada4c98cdf..e41fa1d20f 100644 --- a/test/orm/onetoone.py +++ b/test/orm/onetoone.py @@ -2,7 +2,7 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext -from testbase import Table, Column +from testlib import * class Jack(object): def __repr__(self): @@ -24,7 +24,7 @@ class Port(object): self.name=name self.description = description -class O2OTest(testbase.AssertMixin): +class O2OTest(AssertMixin): def setUpAll(self): global jack, port, metadata, ctx metadata = MetaData(testbase.db) diff --git a/test/orm/query.py b/test/orm/query.py index a32a1439e6..9d516f3780 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1,10 +1,10 @@ import testbase +import operator from sqlalchemy import * from sqlalchemy import ansisql from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * from fixtures import * -import operator class Base(object): def __init__(self, **kwargs): @@ -40,7 +40,7 @@ class Base(object): else: return True -class QueryTest(testbase.ORMTest): +class QueryTest(ORMTest): keep_mappers = True keep_data = True @@ -703,7 +703,7 @@ class CustomJoinTest(QueryTest): assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all() -class SelfReferentialJoinTest(testbase.ORMTest): +class SelfReferentialJoinTest(ORMTest): def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 2d32c8ac37..9fca22b244 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -1,14 +1,12 @@ import testbase -import unittest, sys, datetime +import datetime from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm import collections from sqlalchemy.orm.collections import collection -from testbase import Table, Column +from testlib import * -db = testbase.db - -class RelationTest(testbase.PersistTest): +class RelationTest(PersistTest): """this is essentially an extension of the "dependency.py" topological sort test. in this test, a table is dependent on two other tables that are otherwise unrelated to each other. the dependency sort must insure that this childmost table is below both parent tables in the outcome @@ -17,10 +15,8 @@ class RelationTest(testbase.PersistTest): to subtle differences in program execution, this test case was exposing the bug whereas the simpler tests were not.""" def setUpAll(self): - global tbl_a - global tbl_b - global tbl_c - global tbl_d + global metadata, tbl_a, tbl_b, tbl_c, tbl_d + metadata = MetaData() tbl_a = Table("tbl_a", metadata, Column("id", Integer, primary_key=True), @@ -89,7 +85,7 @@ class RelationTest(testbase.PersistTest): conn.drop(tbl_a) def tearDownAll(self): - testbase.metadata.tables.clear() + metadata.drop_all(testbase.db) def testDeleteRootTable(self): session.flush() @@ -101,7 +97,7 @@ class RelationTest(testbase.PersistTest): session.delete(c) # fails session.flush() -class RelationTest2(testbase.PersistTest): +class RelationTest2(PersistTest): """this test tests a relationship on a column that is included in multiple foreign keys, as well as a self-referential relationship on a composite key where one column in the foreign key is 'joined to itself'.""" @@ -218,7 +214,7 @@ class RelationTest2(testbase.PersistTest): assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' -class RelationTest3(testbase.PersistTest): +class RelationTest3(PersistTest): def setUpAll(self): global jobs, pageversions, pages, metadata, Job, Page, PageVersion, PageComment import datetime @@ -352,7 +348,7 @@ class RelationTest3(testbase.PersistTest): s.delete(j) s.flush() -class RelationTest4(testbase.ORMTest): +class RelationTest4(ORMTest): """test syncrules on foreign keys that are also primary""" def define_tables(self, metadata): global tableA, tableB @@ -500,7 +496,7 @@ class RelationTest4(testbase.ORMTest): assert a1 not in sess assert b1 not in sess -class RelationTest5(testbase.ORMTest): +class RelationTest5(ORMTest): """test a map to a select that relates to a map to the table""" def define_tables(self, metadata): global items @@ -556,7 +552,7 @@ class RelationTest5(testbase.ORMTest): assert old.id == new.id -class TypeMatchTest(testbase.ORMTest): +class TypeMatchTest(ORMTest): """test errors raised when trying to add items whose type is not handled by a relation""" def define_tables(self, metadata): global a, b, c, d @@ -674,7 +670,7 @@ class TypeMatchTest(testbase.ORMTest): except exceptions.AssertionError, err: assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B) -class TypedAssociationTable(testbase.ORMTest): +class TypedAssociationTable(ORMTest): def define_tables(self, metadata): global t1, t2, t3 @@ -724,7 +720,7 @@ class TypedAssociationTable(testbase.ORMTest): assert t3.count().scalar() == 1 # TODO: move these tests to either attributes.py test or its own module -class CustomCollectionsTest(testbase.ORMTest): +class CustomCollectionsTest(ORMTest): def define_tables(self, metadata): global sometable, someothertable sometable = Table('sometable', metadata, @@ -1005,7 +1001,7 @@ class CustomCollectionsTest(testbase.ORMTest): o = list(p2.children) assert len(o) == 3 -class ViewOnlyTest(testbase.ORMTest): +class ViewOnlyTest(ORMTest): """test a view_only mapping where a third table is pulled into the primary join condition, using overlapping PK column names (should not produce "conflicting column" error)""" def define_tables(self, metadata): @@ -1056,7 +1052,7 @@ class ViewOnlyTest(testbase.ORMTest): assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id]) assert set([x.id for x in c1.t2_view]) == set([c2b.id]) -class ViewOnlyTest2(testbase.ORMTest): +class ViewOnlyTest2(ORMTest): """test a view_only mapping where a third table is pulled into the primary join condition, using non-overlapping PK column names (should not produce "mapper has no column X" error)""" def define_tables(self, metadata): diff --git a/test/orm/session.py b/test/orm/session.py index eca48f836d..4332796737 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -1,14 +1,9 @@ -from testbase import AssertMixin import testbase -import unittest, sys, datetime - -import tables -from tables import * - -db = testbase.db from sqlalchemy import * from sqlalchemy.orm import * - +from testlib import * +from testlib.tables import * +import testlib.tables as tables class SessionTest(AssertMixin): def setUpAll(self): @@ -78,7 +73,7 @@ class SessionTest(AssertMixin): # then see if expunge fails session.expunge(u) - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def test_transaction(self): class User(object):pass mapper(User, users) @@ -95,7 +90,7 @@ class SessionTest(AssertMixin): assert conn1.execute("select count(1) from users").scalar() == 1 assert testbase.db.connect().execute("select count(1) from users").scalar() == 1 - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def test_autoflush(self): class User(object):pass mapper(User, users) @@ -114,7 +109,7 @@ class SessionTest(AssertMixin): assert conn1.execute("select count(1) from users").scalar() == 1 assert testbase.db.connect().execute("select count(1) from users").scalar() == 1 - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def test_autoflush_unbound(self): class User(object):pass mapper(User, users) @@ -163,7 +158,7 @@ class SessionTest(AssertMixin): trans.rollback() # rolls back assert len(sess.query(User).select()) == 0 - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def test_external_nested_transaction(self): class User(object):pass mapper(User, users) @@ -187,7 +182,7 @@ class SessionTest(AssertMixin): conn.close() raise - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def test_twophase(self): # TODO: mock up a failure condition here # to ensure a rollback succeeds @@ -226,7 +221,7 @@ class SessionTest(AssertMixin): sess.rollback() # rolls back assert len(sess.query(User).select()) == 0 - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def test_nested_transaction(self): class User(object):pass mapper(User, users) @@ -248,7 +243,7 @@ class SessionTest(AssertMixin): sess.commit() assert len(sess.query(User).select()) == 1 - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def test_nested_autotrans(self): class User(object):pass mapper(User, users) diff --git a/test/orm/sessioncontext.py b/test/orm/sessioncontext.py index f6cd8f9f48..7a60b47c7e 100644 --- a/test/orm/sessioncontext.py +++ b/test/orm/sessioncontext.py @@ -1,12 +1,10 @@ -from testbase import PersistTest, AssertMixin -import unittest, sys, os -from sqlalchemy.ext.sessioncontext import SessionContext -from sqlalchemy.orm.session import object_session, Session +import testbase from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.orm.session import object_session, Session +from testlib import * -import testbase -from testbase import Table, Column metadata = MetaData() users = Table('users', metadata, diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index b7b5847943..eb39549eea 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -1,15 +1,14 @@ -from testbase import PersistTest, AssertMixin -from sqlalchemy import * -from sqlalchemy.orm import * import testbase -from testbase import Table, Column import pickleable +from sqlalchemy import * +from sqlalchemy.orm import * from sqlalchemy.orm.mapper import global_extensions from sqlalchemy.orm import util as ormutil from sqlalchemy.ext.sessioncontext import SessionContext import sqlalchemy.ext.assignmapper as assignmapper -from tables import * -import tables +from testlib import * +from testlib.tables import * +from testlib import tables """tests unitofwork operations""" @@ -28,6 +27,7 @@ class UnitOfWorkTest(AssertMixin): class HistoryTest(UnitOfWorkTest): def setUpAll(self): + tables.metadata.bind = testbase.db UnitOfWorkTest.setUpAll(self) users.create() addresses.create() @@ -63,7 +63,7 @@ class VersioningTest(UnitOfWorkTest): UnitOfWorkTest.setUpAll(self) ctx.current.clear() global version_table - version_table = Table('version_test', MetaData(db), + version_table = Table('version_test', MetaData(testbase.db), Column('id', Integer, Sequence('version_test_seq'), primary_key=True ), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False) @@ -255,9 +255,9 @@ class MutableTypesTest(UnitOfWorkTest): ctx.current.flush() def go(): ctx.current.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) f1.value = unicode('someothervalue') - self.assert_sql(db, lambda: ctx.current.flush(), [ + self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ ( "UPDATE mutabletest SET value=:value WHERE mutabletest.id = :mutabletest_id", {'mutabletest_id': f1.id, 'value': u'someothervalue'} @@ -265,7 +265,7 @@ class MutableTypesTest(UnitOfWorkTest): ]) f1.value = unicode('hi') f1.data.x = 9 - self.assert_sql(db, lambda: ctx.current.flush(), [ + self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ ( "UPDATE mutabletest SET data=:data, value=:value WHERE mutabletest.id = :mutabletest_id", {'mutabletest_id': f1.id, 'value': u'hi', 'data':f1.data} @@ -283,7 +283,7 @@ class MutableTypesTest(UnitOfWorkTest): def go(): ctx.current.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) ctx.current.clear() @@ -291,12 +291,12 @@ class MutableTypesTest(UnitOfWorkTest): def go(): ctx.current.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) f2.data.y = 19 def go(): ctx.current.flush() - self.assert_sql_count(db, go, 1) + self.assert_sql_count(testbase.db, go, 1) ctx.current.clear() f3 = ctx.current.query(Foo).get_by(id=f1.id) @@ -305,7 +305,7 @@ class MutableTypesTest(UnitOfWorkTest): def go(): ctx.current.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) def testunicode(self): """test that two equivalent unicode values dont get flagged as changed. @@ -322,14 +322,14 @@ class MutableTypesTest(UnitOfWorkTest): f1.value = u'hi' def go(): ctx.current.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) class PKTest(UnitOfWorkTest): def setUpAll(self): UnitOfWorkTest.setUpAll(self) global table, table2, table3, metadata - metadata = MetaData(db) + metadata = MetaData(testbase.db) table = Table( 'multipk', metadata, Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True), @@ -357,7 +357,7 @@ class PKTest(UnitOfWorkTest): # not support on sqlite since sqlite's auto-pk generation only works with # single column primary keys - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testprimarykey(self): class Entry(object): pass @@ -479,7 +479,7 @@ class PassiveDeletesTest(UnitOfWorkTest): metadata.drop_all() UnitOfWorkTest.tearDownAll(self) - @testbase.unsupported('sqlite') + @testing.unsupported('sqlite') def testbasic(self): class MyClass(object): pass @@ -516,6 +516,7 @@ class DefaultTest(UnitOfWorkTest): defaults back from the engine.""" def setUpAll(self): UnitOfWorkTest.setUpAll(self) + db = testbase.db use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite') if use_string_defaults: @@ -610,7 +611,7 @@ class OneToManyTest(UnitOfWorkTest): a2 = Address() a2.email_address = 'lala@test.org' u.addresses.append(a2) - self.echo( repr(u.addresses)) + print repr(u.addresses) ctx.current.flush() usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall() @@ -660,7 +661,7 @@ class OneToManyTest(UnitOfWorkTest): u2.user_name = 'user2modified' u1.addresses.append(a3) del u1.addresses[0] - self.assert_sql(db, lambda: ctx.current.flush(), + self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ ( "UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id", @@ -832,7 +833,7 @@ class SaveTest(UnitOfWorkTest): # assert the first one retreives the same from the identity map nu = ctx.current.get(m, u.user_id) - self.echo( "U: " + repr(u) + "NU: " + repr(nu)) + print "U: " + repr(u) + "NU: " + repr(nu) self.assert_(u is nu) # clear out the identity map, so next get forces a SELECT @@ -913,7 +914,7 @@ class SaveTest(UnitOfWorkTest): u.user_name = "" def go(): ctx.current.flush() - self.assert_sql_count(db, go, 0) + self.assert_sql_count(testbase.db, go, 0) def testmultitable(self): """tests a save of an object where each instance spans two tables. also tests @@ -1039,7 +1040,7 @@ class ManyToOneTest(UnitOfWorkTest): objects[2].email_address = 'imnew@foo.bar' objects[3].user = User() objects[3].user.user_name = 'imnewlyadded' - self.assert_sql(db, lambda: ctx.current.flush(), [ + self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ ( "INSERT INTO users (user_name) VALUES (:user_name)", {'user_name': 'imnewlyadded'} @@ -1209,7 +1210,7 @@ class ManyToManyTest(UnitOfWorkTest): k = Keyword() k.name = 'yellow' objects[5].keywords.append(k) - self.assert_sql(db, lambda:ctx.current.flush(), [ + self.assert_sql(testbase.db, lambda:ctx.current.flush(), [ { "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id": {'item_name': 'item4updated', 'items_item_id': objects[4].item_id} @@ -1238,7 +1239,7 @@ class ManyToManyTest(UnitOfWorkTest): objects[2].keywords.append(k) dkid = objects[5].keywords[1].keyword_id del objects[5].keywords[1] - self.assert_sql(db, lambda:ctx.current.flush(), [ + self.assert_sql(testbase.db, lambda:ctx.current.flush(), [ ( "DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id", [{'item_id': objects[5].item_id, 'keyword_id': dkid}] @@ -1423,7 +1424,7 @@ class SaveTest2(UnitOfWorkTest): ctx.current.clear() clear_mappers() global meta, users, addresses - meta = MetaData(db) + meta = MetaData(testbase.db) users = Table('users', meta, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', String(20)), @@ -1455,7 +1456,7 @@ class SaveTest2(UnitOfWorkTest): a.user = User() a.user.user_name = elem['user_name'] objects.append(a) - self.assert_sql(db, lambda: ctx.current.flush(), [ + self.assert_sql(testbase.db, lambda: ctx.current.flush(), [ ( "INSERT INTO users (user_name) VALUES (:user_name)", {'user_name': 'thesub'} @@ -1494,30 +1495,32 @@ class SaveTest2(UnitOfWorkTest): ] ) -class SaveTest3(UnitOfWorkTest): +class SaveTest3(UnitOfWorkTest): def setUpAll(self): + global st3_metadata, t1, t2, t3 + UnitOfWorkTest.setUpAll(self) - global metadata, t1, t2, t3 - metadata = testbase.metadata - t1 = Table('items', metadata, + + st3_metadata = MetaData(testbase.db) + t1 = Table('items', st3_metadata, Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True), Column('item_name', VARCHAR(50)), ) - t3 = Table('keywords', metadata, + t3 = Table('keywords', st3_metadata, Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True), Column('name', VARCHAR(50)), ) - t2 = Table('assoc', metadata, + t2 = Table('assoc', st3_metadata, Column('item_id', INT, ForeignKey("items")), Column('keyword_id', INT, ForeignKey("keywords")), Column('foo', Boolean, default=True) ) - metadata.create_all() + st3_metadata.create_all() def tearDownAll(self): - metadata.drop_all() + st3_metadata.drop_all() UnitOfWorkTest.tearDownAll(self) def setUp(self): diff --git a/test/perf/cascade_speed.py b/test/perf/cascade_speed.py index dd095ab9aa..34d046381f 100644 --- a/test/perf/cascade_speed.py +++ b/test/perf/cascade_speed.py @@ -1,7 +1,7 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column +from testlib import * from timeit import Timer import sys diff --git a/test/perf/masscreate.py b/test/perf/masscreate.py index e603e2c002..346a725e35 100644 --- a/test/perf/masscreate.py +++ b/test/perf/masscreate.py @@ -1,8 +1,7 @@ # times how long it takes to create 26000 objects -import sys -sys.path.insert(0, './lib/') +import testbase -from sqlalchemy.attributes import * +from sqlalchemy.orm.attributes import * import time import gc diff --git a/test/perf/masscreate2.py b/test/perf/masscreate2.py index 3a68f3612d..2e29a63272 100644 --- a/test/perf/masscreate2.py +++ b/test/perf/masscreate2.py @@ -1,11 +1,9 @@ -import sys -sys.path.insert(0, './lib/') - +import testbase import gc import random, string -from sqlalchemy.attributes import * +from sqlalchemy.orm.attributes import * # with this test, run top. make sure the Python process doenst grow in size arbitrarily. diff --git a/test/perf/masseagerload.py b/test/perf/masseagerload.py index c2c0933a59..f1c0f292b0 100644 --- a/test/perf/masseagerload.py +++ b/test/perf/masseagerload.py @@ -1,64 +1,54 @@ -from testbase import PersistTest, AssertMixin -import unittest, sys, os +import testbase +import hotshot, hotshot.stats from sqlalchemy import * from sqlalchemy.orm import * -from testbase import Table, Column -import StringIO -import testbase -import gc -import time -import hotshot -import hotshot.stats - -db = testbase.db +from testlib import * NUM = 500 DIVISOR = 50 -class LoadTest(AssertMixin): - def setUpAll(self): - global items, meta,subitems - meta = MetaData(db) - items = Table('items', meta, - Column('item_id', Integer, primary_key=True), - Column('value', String(100))) - subitems = Table('subitems', meta, - Column('sub_id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('items.item_id')), - Column('value', String(100))) - meta.create_all() - def tearDownAll(self): - meta.drop_all() - def setUp(self): - clear_mappers() +meta = MetaData(testbase.db) +items = Table('items', meta, + Column('item_id', Integer, primary_key=True), + Column('value', String(100))) +subitems = Table('subitems', meta, + Column('sub_id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('items.item_id')), + Column('value', String(100))) + +class Item(object):pass +class SubItem(object):pass +mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)}) +mapper(SubItem, subitems) + +def load(): + global l + l = [] + for x in range(1,NUM/DIVISOR + 1): + l.append({'item_id':x, 'value':'this is item #%d' % x}) + #print l + items.insert().execute(*l) + for x in range(1, NUM/DIVISOR + 1): l = [] - for x in range(1,NUM/DIVISOR + 1): - l.append({'item_id':x, 'value':'this is item #%d' % x}) + for y in range(1, DIVISOR + 1): + z = ((x-1) * DIVISOR) + y + l.append({'sub_id':z,'value':'this is item #%d' % z, 'parent_id':x}) #print l - items.insert().execute(*l) - for x in range(1, NUM/DIVISOR + 1): - l = [] - for y in range(1, DIVISOR + 1): - z = ((x-1) * DIVISOR) + y - l.append({'sub_id':z,'value':'this is iteim #%d' % z, 'parent_id':x}) - #print l - subitems.insert().execute(*l) - def testload(self): - class Item(object):pass - class SubItem(object):pass - mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)}) - mapper(SubItem, subitems) - sess = create_session() - prof = hotshot.Profile("masseagerload.prof") - prof.start() - query = sess.query(Item) - l = query.select() - print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems" - prof.stop() - prof.close() - stats = hotshot.stats.load("masseagerload.prof") - stats.sort_stats('time', 'calls') - stats.print_stats() - -if __name__ == "__main__": - testbase.main() + subitems.insert().execute(*l) + +@profiling.profiled('masseagerload', always=True) +def masseagerload(session): + query = session.query(Item) + l = query.select() + print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems" + +def all(): + meta.create_all() + try: + load() + masseagerload(create_session()) + finally: + meta.drop_all() + +if __name__ == '__main__': + all() diff --git a/test/perf/massload.py b/test/perf/massload.py index 02b847599d..92cf0fe920 100644 --- a/test/perf/massload.py +++ b/test/perf/massload.py @@ -1,14 +1,10 @@ -from testbase import PersistTest, AssertMixin -import unittest, sys, os -from sqlalchemy import * -import sqlalchemy.orm.attributes as attributes -from testbase import Table, Column -import StringIO import testbase -import gc import time - -db = testbase.db +#import gc +#import sqlalchemy.orm.attributes as attributes +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * NUM = 2500 @@ -21,7 +17,7 @@ for best results, dont run with sqlite :memory: database, and keep an eye on top class LoadTest(AssertMixin): def setUpAll(self): global items, meta - meta = MetaData(db) + meta = MetaData(testbase.db) items = Table('items', meta, Column('item_id', Integer, primary_key=True), Column('value', String(100))) @@ -29,8 +25,6 @@ class LoadTest(AssertMixin): def tearDownAll(self): items.drop() def setUp(self): - objectstore.clear() - clear_mappers() for x in range(1,NUM/500+1): l = [] for y in range(x*500-500 + 1, x*500 + 1): diff --git a/test/perf/masssave.py b/test/perf/masssave.py index 53f13119ed..dd03f39629 100644 --- a/test/perf/masssave.py +++ b/test/perf/masssave.py @@ -1,21 +1,16 @@ -from testbase import PersistTest, AssertMixin -import unittest, sys, os -from sqlalchemy import * -import sqlalchemy.attributes as attributes -from testbase import Table, Column -import StringIO import testbase -import gc -import sqlalchemy.orm.session import types -db = testbase.db +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * + NUM = 250000 class SaveTest(AssertMixin): def setUpAll(self): global items, metadata - metadata = MetaData(db) + metadata = MetaData(testbase.db) items = Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('value', String(100))) diff --git a/test/perf/ormsession.py b/test/perf/ormsession.py new file mode 100644 index 0000000000..e41ec0629c --- /dev/null +++ b/test/perf/ormsession.py @@ -0,0 +1,196 @@ +import testbase +import time +from datetime import datetime + +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * +from testlib.profiling import profiled + +class Item(object): + def __repr__(self): + return 'Item<#%s "%s">' % (self.id, self.name) +class SubItem(object): + def __repr__(self): + return 'SubItem<#%s "%s">' % (self.id, self.name) +class Customer(object): + def __repr__(self): + return 'Customer<#%s "%s">' % (self.id, self.name) +class Purchase(object): + def __repr__(self): + return 'Purchase<#%s "%s">' % (self.id, self.purchase_date) + +items, subitems, customers, purchases, purchaseitems = \ + None, None, None, None, None + +metadata = MetaData() + +@profiled('table') +def define_tables(): + global items, subitems, customers, purchases, purchaseitems + items = Table('items', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(100)), + test_needs_acid=True) + subitems = Table('subitems', metadata, + Column('id', Integer, primary_key=True), + Column('item_id', Integer, ForeignKey('items.id'), + nullable=False), + Column('name', String(100), PassiveDefault('no name')), + test_needs_acid=True) + customers = Table('customers', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(100)), + *[Column("col_%s" % chr(i), String(64), default=str(i)) + for i in range(97,117)], + **dict(test_needs_acid=True)) + purchases = Table('purchases', metadata, + Column('id', Integer, primary_key=True), + Column('customer_id', Integer, + ForeignKey('customers.id'), nullable=False), + Column('purchase_date', DateTime, + default=datetime.now), + test_needs_acid=True) + purchaseitems = Table('purchaseitems', metadata, + Column('purchase_id', Integer, + ForeignKey('purchases.id'), + nullable=False, primary_key=True), + Column('item_id', Integer, ForeignKey('items.id'), + nullable=False, primary_key=True), + test_needs_acid=True) + +@profiled('mapper') +def setup_mappers(): + mapper(Item, items, properties={ + 'subitems': relation(SubItem, backref='item', lazy=True) + }) + mapper(SubItem, subitems) + mapper(Customer, customers, properties={ + 'purchases': relation(Purchase, lazy=True, backref='customer') + }) + mapper(Purchase, purchases, properties={ + 'items': relation(Item, lazy=True, secondary=purchaseitems) + }) + +@profiled('inserts') +def insert_data(): + q_items = 1000 + q_sub_per_item = 10 + q_customers = 1000 + + con = testbase.db.connect() + + transaction = con.begin() + data, subdata = [], [] + for item_id in xrange(1, q_items + 1): + data.append({'name': "item number %s" % item_id}) + for subitem_id in xrange(1, (item_id % q_sub_per_item) + 1): + subdata.append({'item_id': item_id, + 'name': "subitem number %s" % subitem_id}) + if item_id % 100 == 0: + items.insert().execute(*data) + subitems.insert().execute(*subdata) + del data[:] + del subdata[:] + if data: + items.insert().execute(*data) + if subdata: + subitems.insert().execute(*subdata) + transaction.commit() + + transaction = con.begin() + data = [] + for customer_id in xrange(1, q_customers): + data.append({'name': "customer number %s" % customer_id}) + if customer_id % 100 == 0: + customers.insert().execute(*data) + del data[:] + if data: + customers.insert().execute(*data) + transaction.commit() + + transaction = con.begin() + data, subdata = [], [] + order_t = int(time.time()) - (5000 * 5 * 60) + current = xrange(1, q_customers) + step, purchase_id = 1, 0 + while current: + next = [] + for customer_id in current: + order_t += 300 + data.append({'customer_id': customer_id, + 'purchase_date': datetime.fromtimestamp(order_t)}) + purchase_id += 1 + for item_id in range(customer_id % 200, customer_id + 1, 200): + if item_id != 0: + subdata.append({'purchase_id': purchase_id, + 'item_id': item_id}) + if customer_id % 10 > step: + next.append(customer_id) + + if len(data) >= 100: + purchases.insert().execute(*data) + if subdata: + purchaseitems.insert().execute(*subdata) + del data[:] + del subdata[:] + step, current = step + 1, next + + if data: + purchases.insert().execute(*data) + if subdata: + purchaseitems.insert().execute(*subdata) + transaction.commit() + +@profiled('queries') +def run_queries(): + session = create_session() + # no explicit transaction here. + + # build a report of summarizing the last 50 purchases and + # the top 20 items from all purchases + + q = session.query(Purchase). \ + limit(50).order_by(desc(Purchase.purchase_date)). \ + options(eagerload('items'), eagerload('items.subitems'), + eagerload('customer')) + + report = [] + # "write" the report. pretend it's going to a web template or something, + # the point is to actually pull data through attributes and collections. + for purchase in q: + report.append(purchase.customer.name) + report.append(purchase.customer.col_a) + report.append(purchase.purchase_date) + for item in purchase.items: + report.append(item.name) + report.extend([s.name for s in item.subitems]) + + # pull a report of the top 20 items of all time + _item_id = purchaseitems.c.item_id + top_20_q = select([func.distinct(_item_id).label('id')], + group_by=[purchaseitems.c.purchase_id, _item_id], + order_by=[desc(func.count(_item_id)), _item_id], + limit=20) + q2 = session.query(Item).filter(Item.id.in_(top_20_q)) + + for num, item in enumerate(q2): + report.append("number %s: %s" % (num + 1, item.name)) + +def setup_db(): + metadata.drop_all() + metadata.create_all() +def cleanup_db(): + metadata.drop_all() + +@profiled('all') +def main(): + metadata.bind = testbase.db + define_tables() + setup_mappers() + setup_db() + insert_data() + run_queries() + cleanup_db() + +main() diff --git a/test/perf/poolload.py b/test/perf/poolload.py index d096f1c67f..1a2ff6978b 100644 --- a/test/perf/poolload.py +++ b/test/perf/poolload.py @@ -1,10 +1,11 @@ # load test of connection pool +import testbase from sqlalchemy import * import sqlalchemy.pool as pool import thread,time -db = create_engine('mysql://scott:tiger@127.0.0.1/test', pool_timeout=30, echo_pool=True) +db = create_engine(testbase.db.url, pool_timeout=30, echo_pool=True) metadata = MetaData(db) users_table = Table('users', metadata, @@ -18,7 +19,7 @@ users_table.insert().execute([{'user_name':'user#%d' % i, 'password':'pw#%d' % i def runfast(): while True: - c = db.connection_provider._pool.connect() + c = db.pool.connect() time.sleep(.5) c.close() # result = users_table.select(limit=100).execute() diff --git a/test/perf/threaded_compile.py b/test/perf/threaded_compile.py index 38fe145cd1..13ec31fd61 100644 --- a/test/perf/threaded_compile.py +++ b/test/perf/threaded_compile.py @@ -2,11 +2,12 @@ when additional mappers are created while the existing collection is being compiled.""" +import testbase from sqlalchemy import * from sqlalchemy.orm import * import thread, time from sqlalchemy.orm import mapperlib -from testbase import Table, Column +from testlib import * meta = MetaData('sqlite:///foo.db') diff --git a/test/perf/wsgi.py b/test/perf/wsgi.py index e40171f07f..d22eeb76a0 100644 --- a/test/perf/wsgi.py +++ b/test/perf/wsgi.py @@ -1,54 +1,55 @@ #!/usr/bin/python +"""Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop.""" +import testbase from sqlalchemy import * -import sqlalchemy.pool as pool +from sqlalchemy.orm import * import thread -from sqlalchemy import exceptions -from testbase import Table, Column +from testlib import * + +port = 8000 import logging logging.basicConfig() logging.getLogger('sqlalchemy.pool').setLevel(logging.INFO) threadids = set() -#meta = MetaData('postgres://scott:tiger@127.0.0.1/test') - -#meta = MetaData('mysql://scott:tiger@localhost/test', poolclass=pool.SingletonThreadPool) -meta = MetaData('mysql://scott:tiger@localhost/test') +meta = MetaData(testbase.db) foo = Table('foo', meta, Column('id', Integer, primary_key=True), Column('data', String(30))) - -meta.drop_all() -meta.create_all() - -data = [] -for x in range(1,500): - data.append({'id':x,'data':"this is x value %d" % x}) -foo.insert().execute(data) - class Foo(object): pass - mapper(Foo, foo) -root = './' -port = 8000 +def prep(): + meta.drop_all() + meta.create_all() + + data = [] + for x in range(1,500): + data.append({'id':x,'data':"this is x value %d" % x}) + foo.insert().execute(data) def serve(environ, start_response): + start_response("200 OK", [('Content-type', 'text/plain')]) sess = create_session() l = sess.query(Foo).select() - - start_response("200 OK", [('Content-type','text/plain')]) threadids.add(thread.get_ident()) - print "sending response on thread", thread.get_ident(), " total threads ", len(threadids) - return ["\n".join([x.data for x in l])] + + print ("sending response on thread", thread.get_ident(), + " total threads ", len(threadids)) + return [str("\n".join([x.data for x in l]))] if __name__ == '__main__': - from wsgiutils import wsgiServer - server = wsgiServer.WSGIServer (('localhost', port), {'/': serve}) - print "Server listening on port %d" % port - server.serve_forever() + from wsgiref import simple_server + try: + prep() + server = simple_server.make_server('localhost', port, serve) + print "Server listening on port %d" % port + server.serve_forever() + finally: + meta.drop_all() diff --git a/test/sql/alltests.py b/test/sql/alltests.py index ebb3fe34c6..a669a25f2d 100644 --- a/test/sql/alltests.py +++ b/test/sql/alltests.py @@ -32,7 +32,5 @@ def suite(): alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) return alltests - - if __name__ == '__main__': testbase.main(suite()) diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index bcf8849644..493545b228 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -1,10 +1,10 @@ -import sys import testbase +import sys from sqlalchemy import * -from testbase import Table, Column +from testlib import * -class CaseTest(testbase.PersistTest): +class CaseTest(PersistTest): def setUpAll(self): metadata = MetaData(testbase.db) diff --git a/test/sql/constraints.py b/test/sql/constraints.py index b5f1a17414..1c2bd1b57c 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -1,9 +1,8 @@ import testbase from sqlalchemy import * -from testbase import Table, Column -import sys +from testlib import * -class ConstraintTest(testbase.AssertMixin): +class ConstraintTest(AssertMixin): def setUp(self): global metadata @@ -53,7 +52,7 @@ class ConstraintTest(testbase.AssertMixin): ) metadata.create_all() - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_check_constraint(self): foo = Table('foo', metadata, Column('id', Integer, primary_key=True), diff --git a/test/sql/defaults.py b/test/sql/defaults.py index a9dd2f5ad2..5cbdc3e3fb 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -1,20 +1,17 @@ -from testbase import PersistTest -import sqlalchemy.util as util -import unittest, sys, os -import sqlalchemy.schema as schema import testbase from sqlalchemy import * +import sqlalchemy.util as util +import sqlalchemy.schema as schema from sqlalchemy.orm import mapper, create_session -from testbase import Table, Column -import sqlalchemy - -db = testbase.db +from testlib import * class DefaultTest(PersistTest): def setUpAll(self): global t, f, f2, ts, currenttime, metadata - metadata = MetaData(testbase.db) + + db = testbase.db + metadata = MetaData(db) x = {'x':50} def mydefault(): x['x'] += 1 @@ -80,7 +77,7 @@ class DefaultTest(PersistTest): t.delete().execute() def teststandalone(self): - c = db.engine.contextual_connect() + c = testbase.db.engine.contextual_connect() x = c.execute(t.c.col1.default) y = t.c.col2.default.execute() z = c.execute(t.c.col3.default) @@ -97,7 +94,7 @@ class DefaultTest(PersistTest): t.insert().execute() ctexec = currenttime.scalar() - self.echo("Currenttime "+ repr(ctexec)) + print "Currenttime "+ repr(ctexec) l = t.select().execute() self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)]) @@ -112,7 +109,7 @@ class DefaultTest(PersistTest): pk = r.last_inserted_ids()[0] t.update(t.c.col1==pk).execute(col4=None, col5=None) ctexec = currenttime.scalar() - self.echo("Currenttime "+ repr(ctexec)) + print "Currenttime "+ repr(ctexec) l = t.select(t.c.col1==pk).execute() l = l.fetchone() self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False)) @@ -127,7 +124,7 @@ class DefaultTest(PersistTest): l = l.fetchone() self.assert_(l['col3'] == 55) - @testbase.supported('postgres') + @testing.supported('postgres') def testpassiveoverride(self): """primarily for postgres, tests that when we get a primary key column back from reflecting a table which has a default value on it, we pre-execute @@ -155,7 +152,7 @@ class DefaultTest(PersistTest): testbase.db.execute("drop table speedy_users", None) class AutoIncrementTest(PersistTest): - @testbase.supported('postgres', 'mysql') + @testing.supported('postgres', 'mysql') def testnonautoincrement(self): meta = MetaData(testbase.db) nonai_table = Table("aitest", meta, @@ -219,7 +216,7 @@ class AutoIncrementTest(PersistTest): class SequenceTest(PersistTest): - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def setUpAll(self): global cartitems, sometable, metadata metadata = MetaData(testbase.db) @@ -236,7 +233,7 @@ class SequenceTest(PersistTest): metadata.create_all() - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def testseqnonpk(self): """test sequences fire off as defaults on non-pk columns""" sometable.insert().execute(name="somename") @@ -246,7 +243,7 @@ class SequenceTest(PersistTest): (2, "someother", 2), ] - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def testsequence(self): cartitems.insert().execute(description='hi') cartitems.insert().execute(description='there') @@ -255,7 +252,7 @@ class SequenceTest(PersistTest): cartitems.select().execute().fetchall() - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def test_implicit_sequence_exec(self): s = Sequence("my_sequence", metadata=MetaData(testbase.db)) s.create() @@ -265,7 +262,7 @@ class SequenceTest(PersistTest): finally: s.drop() - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def teststandalone_explicit(self): s = Sequence("my_sequence") s.create(bind=testbase.db) @@ -275,7 +272,7 @@ class SequenceTest(PersistTest): finally: s.drop(testbase.db) - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def test_checkfirst(self): s = Sequence("my_sequence") s.create(testbase.db, checkfirst=False) @@ -283,12 +280,12 @@ class SequenceTest(PersistTest): s.drop(testbase.db, checkfirst=False) s.drop(testbase.db, checkfirst=True) - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def teststandalone2(self): x = cartitems.c.cart_id.sequence.execute() self.assert_(1 <= x <= 4) - @testbase.supported('postgres', 'oracle') + @testing.supported('postgres', 'oracle') def tearDownAll(self): metadata.drop_all() diff --git a/test/sql/generative.py b/test/sql/generative.py index 82f3b175b1..357a66fcdf 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -1,9 +1,9 @@ import testbase from sql import select as selecttests - from sqlalchemy import * +from testlib import * -class TraversalTest(testbase.AssertMixin): +class TraversalTest(AssertMixin): """test ClauseVisitor's traversal, particularly its ability to copy and modify a ClauseElement in place.""" @@ -269,10 +269,7 @@ class SelectTest(selecttests.SQLTest): s = select([t2], t2.c.col1==t1.c.col1, correlate=False) s = s.correlate(t1).order_by(t2.c.col3) self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3") - - - - - + + if __name__ == '__main__': - testbase.main() \ No newline at end of file + testbase.main() diff --git a/test/sql/labels.py b/test/sql/labels.py index 6029c83088..5cca4160e5 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -1,11 +1,12 @@ import testbase from sqlalchemy import * -from testbase import Table, Column +from testlib import * + # TODO: either create a mock dialect with named paramstyle and a short identifier length, # or find a way to just use sqlite dialect and make those changes -class LabelTypeTest(testbase.PersistTest): +class LabelTypeTest(PersistTest): def test_type(self): m = MetaData() t = Table('sometable', m, @@ -14,7 +15,7 @@ class LabelTypeTest(testbase.PersistTest): assert isinstance(t.c.col1.label('hi').type, Integer) assert isinstance(select([t.c.col2], scalar=True).label('lala').type, Float) -class LongLabelsTest(testbase.PersistTest): +class LongLabelsTest(PersistTest): def setUpAll(self): global metadata, table1, maxlen metadata = MetaData(testbase.db) diff --git a/test/sql/query.py b/test/sql/query.py index 05b6d0419d..61734d5d92 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -1,14 +1,8 @@ -from testbase import PersistTest import testbase -import unittest, sys, datetime - -import sqlalchemy.databases.sqlite as sqllite - -import tables +import datetime from sqlalchemy import * -from sqlalchemy.engine import ResultProxy, RowProxy from sqlalchemy import exceptions -from testbase import Table, Column +from testlib import * class QueryTest(PersistTest): @@ -189,7 +183,7 @@ class QueryTest(PersistTest): r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r)) - @testbase.unsupported('mssql') + @testing.unsupported('mssql') def testselectlimitoffset(self): users.insert().execute(user_id=1, user_name='john') users.insert().execute(user_id=2, user_name='jack') @@ -203,7 +197,7 @@ class QueryTest(PersistTest): r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r==[(6, 'ralph'), (7, 'fido')]) - @testbase.supported('mssql') + @testing.supported('mssql') def testselectlimitoffset_mssql(self): try: r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall() @@ -211,7 +205,7 @@ class QueryTest(PersistTest): except exceptions.InvalidRequestError: pass - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_scalar_select(self): """test that scalar subqueries with labels get their type propigated to the result set.""" # mysql and/or mysqldb has a bug here, type isnt propigated for scalar subquery. @@ -346,7 +340,7 @@ class QueryTest(PersistTest): finally: meta.drop_all() - @testbase.supported('postgres') + @testing.supported('postgres') def test_functions_with_cols(self): # TODO: shouldnt this work on oracle too ? x = testbase.db.func.current_date().execute().scalar() @@ -379,7 +373,7 @@ class QueryTest(PersistTest): self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id']) self.assertEqual(r.values(), ['foo', 1]) - @testbase.unsupported('oracle', 'firebird') + @testing.unsupported('oracle', 'firebird') def test_column_accessor_shadow(self): meta = MetaData(testbase.db) shadowed = Table('test_shadowed', meta, @@ -409,7 +403,7 @@ class QueryTest(PersistTest): finally: shadowed.drop(checkfirst=True) - @testbase.supported('mssql') + @testing.supported('mssql') def test_fetchid_trigger(self): meta = MetaData(testbase.db) t1 = Table('t1', meta, @@ -435,7 +429,7 @@ class QueryTest(PersistTest): con.execute("""drop trigger paj""") meta.drop_all() - @testbase.supported('mssql') + @testing.supported('mssql') def test_insertid_schema(self): meta = MetaData(testbase.db) con = testbase.db.connect() @@ -448,7 +442,7 @@ class QueryTest(PersistTest): tbl.drop() con.execute('drop schema paj') - @testbase.supported('mssql') + @testing.supported('mssql') def test_insertid_reserved(self): meta = MetaData(testbase.db) table = Table( @@ -567,7 +561,7 @@ class CompoundTest(PersistTest): assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_intersect(self): i = intersect( select([t2.c.col3, t2.c.col4]), @@ -576,7 +570,7 @@ class CompoundTest(PersistTest): assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - @testbase.unsupported('mysql', 'oracle') + @testing.unsupported('mysql', 'oracle') def test_except_style1(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -585,7 +579,7 @@ class CompoundTest(PersistTest): ), select([t2.c.col3, t2.c.col4])) assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testbase.unsupported('mysql', 'oracle') + @testing.unsupported('mysql', 'oracle') def test_except_style2(self): e = except_(union( select([t1.c.col3, t1.c.col4]), @@ -595,7 +589,7 @@ class CompoundTest(PersistTest): assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] - @testbase.unsupported('sqlite', 'mysql', 'oracle') + @testing.unsupported('sqlite', 'mysql', 'oracle') def test_except_style3(self): # aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc e = except_( @@ -607,7 +601,7 @@ class CompoundTest(PersistTest): ) self.assertEquals(e.execute().fetchall(), [('ccc',)]) - @testbase.unsupported('sqlite', 'mysql', 'oracle') + @testing.unsupported('sqlite', 'mysql', 'oracle') def test_union_union_all(self): e = union_all( select([t1.c.col3]), @@ -618,7 +612,7 @@ class CompoundTest(PersistTest): ) self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)]) - @testbase.unsupported('mysql') + @testing.unsupported('mysql') def test_composite(self): u = intersect( select([t2.c.col3, t2.c.col4]), diff --git a/test/sql/quote.py b/test/sql/quote.py index 08cef8e5f9..2fdf9dba0c 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -1,7 +1,6 @@ -from testbase import PersistTest import testbase from sqlalchemy import * -from testbase import Table, Column +from testlib import * class QuoteTest(PersistTest): @@ -80,7 +79,7 @@ class QuoteTest(PersistTest): assert t1.c.UcCol.case_sensitive is False assert t2.c.normalcol.case_sensitive is False - @testbase.unsupported('oracle') + @testing.unsupported('oracle') def testlabels(self): """test the quoting of labels. diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index a74f3e0f84..e0da96a81d 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -1,9 +1,9 @@ -from sqlalchemy import * -from testbase import Table, Column import testbase +from sqlalchemy import * +from testlib import * -class FoundRowsTest(testbase.AssertMixin): +class FoundRowsTest(AssertMixin): """tests rowcount functionality""" def setUpAll(self): metadata = MetaData(testbase.db) diff --git a/test/sql/select.py b/test/sql/select.py index 3d5996df9d..18709f55d9 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -1,9 +1,8 @@ -from testbase import PersistTest import testbase +import re, operator from sqlalchemy import * from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql -from testbase import Table, Column -import unittest, re, operator +from testlib import * # the select test now tests almost completely with TableClause/ColumnClause objects, @@ -55,7 +54,7 @@ addresses = table('addresses', class SQLTest(PersistTest): def runtest(self, clause, result, dialect = None, params = None, checkparams = None): c = clause.compile(parameters=params, dialect=dialect) - self.echo("\nSQL String:\n" + str(c) + repr(c.get_params())) + print "\nSQL String:\n" + str(c) + repr(c.get_params()) cc = re.sub(r'\n', '', str(c)) self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'") if checkparams is not None: diff --git a/test/sql/selectable.py b/test/sql/selectable.py index bcf70bd206..dcc8550747 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -3,14 +3,10 @@ useable primary keys and foreign keys. Full relational algebra depends on every selectable unit behaving nicely with others..""" import testbase -import unittest, sys, datetime from sqlalchemy import * -from testbase import Table, Column - -db = testbase.db -metadata = MetaData(db) - +from testlib import * +metadata = MetaData() table = Table('table1', metadata, Column('col1', Integer, primary_key=True), Column('col2', String(20)), @@ -26,7 +22,7 @@ table2 = Table('table2', metadata, Column('coly', Integer), ) -class SelectableTest(testbase.AssertMixin): +class SelectableTest(AssertMixin): def testdistance(self): s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) @@ -172,7 +168,7 @@ class SelectableTest(testbase.AssertMixin): self.assert_(criterion.compare(j.onclause)) -class PrimaryKeyTest(testbase.AssertMixin): +class PrimaryKeyTest(AssertMixin): def test_join_pk_collapse_implicit(self): """test that redundant columns in a join get 'collapsed' into a minimal primary key, which is the root column along a chain of foreign key relationships.""" diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 59b9da3884..28e7db3a3e 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -1,16 +1,12 @@ -from testbase import PersistTest, AssertMixin import testbase import pickleable +import datetime, os from sqlalchemy import * -import string,datetime, re, sys, os import sqlalchemy.engine.url as url -import sqlalchemy.types from sqlalchemy.databases import mssql, oracle, mysql -from testbase import Table, Column +from testlib import * -db = testbase.db - class MyType(types.TypeEngine): def get_col_spec(self): return "VARCHAR(100)" @@ -108,7 +104,7 @@ class OverrideTest(PersistTest): def setUpAll(self): global users - users = Table('type_users', MetaData(db), + users = Table('type_users', MetaData(testbase.db), Column('user_id', Integer, primary_key = True), # totall custom type Column('goofy', MyType, nullable = False), @@ -139,6 +135,7 @@ class ColumnsTest(AssertMixin): 'float_column': 'float_column NUMERIC(25, 2)' } + db = testbase.db if not db.name=='sqlite' and not db.name=='oracle': expectedResults['float_column'] = 'float_column FLOAT(25)' @@ -158,7 +155,7 @@ class UnicodeTest(AssertMixin): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" def setUpAll(self): global unicode_table - metadata = MetaData(db) + metadata = MetaData(testbase.db) unicode_table = Table('unicode_table', metadata, Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True), Column('unicode_varchar', Unicode(250)), @@ -177,49 +174,49 @@ class UnicodeTest(AssertMixin): unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - self.echo(repr(x['unicode_varchar'])) - self.echo(repr(x['unicode_text'])) - self.echo(repr(x['plain_varchar'])) + print repr(x['unicode_varchar']) + print repr(x['unicode_text']) + print repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) if isinstance(x['plain_varchar'], unicode): # SQLLite and MSSQL return non-unicode data as unicode - self.assert_(db.name in ('sqlite', 'mssql')) + self.assert_(testbase.db.name in ('sqlite', 'mssql')) self.assert_(x['plain_varchar'] == unicodedata) - self.echo("it's %s!" % db.name) + print "it's %s!" % testbase.db.name else: self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata) def testengineparam(self): """tests engine-wide unicode conversion""" - prev_unicode = db.engine.dialect.convert_unicode + prev_unicode = testbase.db.engine.dialect.convert_unicode try: - db.engine.dialect.convert_unicode = True + testbase.db.engine.dialect.convert_unicode = True rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n' unicodedata = rawdata.decode('utf-8') unicode_table.insert().execute(unicode_varchar=unicodedata, unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - self.echo(repr(x['unicode_varchar'])) - self.echo(repr(x['unicode_text'])) - self.echo(repr(x['plain_varchar'])) + print repr(x['unicode_varchar']) + print repr(x['unicode_text']) + print repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) finally: - db.engine.dialect.convert_unicode = prev_unicode + testbase.db.engine.dialect.convert_unicode = prev_unicode - @testbase.unsupported('oracle') + @testing.unsupported('oracle') def testlength(self): """checks the database correctly understands the length of a unicode string""" teststr = u'aaa\x1234' - self.assert_(db.func.length(teststr).scalar() == len(teststr)) + self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr)) class BinaryTest(AssertMixin): def setUpAll(self): global binary_table - binary_table = Table('binary_table', MetaData(db), + binary_table = Table('binary_table', MetaData(testbase.db), Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True), Column('data', Binary), Column('data_slice', Binary(100)), @@ -270,6 +267,7 @@ class DateTest(AssertMixin): def setUpAll(self): global users_with_date, insert_data + db = testbase.db if db.engine.name == 'oracle': import sqlalchemy.databases.oracle as oracle insert_data = [ @@ -309,7 +307,8 @@ class DateTest(AssertMixin): collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)), Column('user_date', Date), Column('user_time', Time)] - users_with_date = Table('query_users_with_date', MetaData(db), *collist) + users_with_date = Table('query_users_with_date', + MetaData(testbase.db), *collist) users_with_date.create() insert_dicts = [dict(zip(fnames, d)) for d in insert_data] @@ -327,7 +326,7 @@ class DateTest(AssertMixin): def testtextdate(self): - x = db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall() + x = testbase.db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall() print repr(x) self.assert_(isinstance(x[0][0], datetime.datetime)) @@ -336,7 +335,11 @@ class DateTest(AssertMixin): #print repr(x) def testdate2(self): - t = Table('testdate', testbase.metadata, Column('id', Integer, Sequence('datetest_id_seq', optional=True), primary_key=True), + meta = MetaData(testbase.db) + t = Table('testdate', meta, + Column('id', Integer, + Sequence('datetest_id_seq', optional=True), + primary_key=True), Column('adate', Date), Column('adatetime', DateTime)) t.create() try: diff --git a/test/sql/unicode.py b/test/sql/unicode.py index cfd2354da9..f882c2a5f8 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -4,10 +4,10 @@ import testbase from sqlalchemy import * from sqlalchemy.orm import mapper, relation, create_session, eagerload -from testbase import Table, Column +from testlib import * -class UnicodeSchemaTest(testbase.PersistTest): +class UnicodeSchemaTest(PersistTest): def setUpAll(self): global unicode_bind, metadata, t1, t2 diff --git a/test/testbase.py b/test/testbase.py index c11ca1457e..1195db3400 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -1,504 +1,14 @@ -"""base import for all test cases. Patches in enhancements to unittest.TestCase, -instruments SQLAlchemy dialect/engine to track SQL statements for assertion purposes, -provides base test classes for common test scenarios.""" +"""First import for all test cases, sets sys.path and loads configuration.""" -import sys -import coverage +__all__ = 'db', -import os, unittest, StringIO, re, ConfigParser, optparse +import sys, os, logging sys.path.insert(0, os.path.join(os.getcwd(), 'lib')) +logging.basicConfig() -db = None -metadata = None -db_uri = None -echo = True -table_options = {} - -# redefine sys.stdout so all those print statements go to the echo func -local_stdout = sys.stdout -class Logger(object): - def write(self, msg): - if echo: - local_stdout.write(msg) - def flush(self): - pass - -def echo_text(text): - print text - -def parse_argv(): - # we are using the unittest main runner, so we are just popping out the - # arguments we need instead of using our own getopt type of thing - global db, db_uri, metadata - - DBTYPE = 'sqlite' - PROXY = False - - base_config = """ -[db] -sqlite=sqlite:///:memory: -sqlite_file=sqlite:///querytest.db -postgres=postgres://scott:tiger@127.0.0.1:5432/test -mysql=mysql://scott:tiger@127.0.0.1:3306/test -oracle=oracle://scott:tiger@127.0.0.1:1521 -oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 -mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test -firebird=firebird://sysdba:s@localhost/tmp/test.fdb -""" - config = ConfigParser.ConfigParser() - config.readfp(StringIO.StringIO(base_config)) - config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) - - parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") - parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)") - parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (%s)" % ', '.join(config.options('db'))) - parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool (asserts only one connection used)") - parser.add_option("--verbose", action="store_true", dest="verbose", help="enable stdout echoing/printing") - parser.add_option("--quiet", action="store_true", dest="quiet", help="suppress unittest output") - parser.add_option("--log-info", action="append", dest="log_info", help="turn on info logging for (multiple OK)") - parser.add_option("--log-debug", action="append", dest="log_debug", help="turn on debug logging for (multiple OK)") - parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)") - parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running") - parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)") - parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG") - parser.add_option("--require", action="append", dest="require", help="Require a particular driver or module version", default=[]) - parser.add_option("--mysql-engine", action="store", dest="mysql_engine", help="Use the specified MySQL storage engine for all tables, default is a db-default/InnoDB combo.", default=None) - parser.add_option("--table-option", action="append", dest="tableopts", help="Add a dialect-specific table option, key=value", default=[]) - - (options, args) = parser.parse_args() - sys.argv[1:] = args - - if options.dburi: - db_uri = param = options.dburi - DBTYPE = db_uri[:db_uri.index(':')] - elif options.db: - DBTYPE = param = options.db - - if options.require or (config.has_section('require') and - config.items('require')): - try: - import pkg_resources - except ImportError: - raise "setuptools is required for version requirements" - - cmdline = [] - for requirement in options.require: - pkg_resources.require(requirement) - cmdline.append(re.split('\s*(=)', requirement, 1)[0]) - - if config.has_section('require'): - for label, requirement in config.items('require'): - if not label == DBTYPE or label.startswith('%s.' % DBTYPE): - continue - seen = [c for c in cmdline if requirement.startswith(c)] - if seen: - continue - pkg_resources.require(requirement) - - opts = {} - if (None == db_uri): - if DBTYPE not in config.options('db'): - raise ("Could not create engine. specify --db <%s> to " - "test runner." % '|'.join(config.options('db'))) - - db_uri = config.get('db', DBTYPE) - - if not db_uri: - raise "Could not create engine. specify --db to test runner." - - global table_options - for spec in options.tableopts: - key, value = spec.split('=') - table_options[key] = value - - if options.mysql_engine: - table_options['mysql_engine'] = options.mysql_engine - - - global echo - echo = options.verbose and not options.quiet - - global quiet - quiet = options.quiet - - global with_coverage - with_coverage = options.coverage - if with_coverage: - coverage.erase() - coverage.start() - - from sqlalchemy import engine, schema - - if options.serverside: - opts['server_side_cursors'] = True - - if options.enginestrategy is not None: - opts['strategy'] = options.enginestrategy - if options.mockpool: - db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts) - else: - db = engine.create_engine(db_uri, **opts) - - # decorate the dialect's create_execution_context() method - # to produce a wrapper - create_context = db.dialect.create_execution_context - def create_exec_context(*args, **kwargs): - return ExecutionContextWrapper(create_context(*args, **kwargs)) - db.dialect.create_execution_context = create_exec_context - - if options.topological: - from sqlalchemy.orm import unitofwork - from sqlalchemy import topological - class RevQueueDepSort(topological.QueueDependencySorter): - def __init__(self, tuples, allitems): - self.tuples = list(tuples) - self.allitems = list(allitems) - self.tuples.reverse() - self.allitems.reverse() - topological.QueueDependencySorter = RevQueueDepSort - unitofwork.DependencySorter = RevQueueDepSort - - import logging - logging.basicConfig() - if options.log_info is not None: - for elem in options.log_info: - logging.getLogger(elem).setLevel(logging.INFO) - if options.log_debug is not None: - for elem in options.log_debug: - logging.getLogger(elem).setLevel(logging.DEBUG) - metadata = schema.MetaData(db) - -def unsupported(*dbs): - """a decorator that marks a test as unsupported by one or more database implementations""" - - def decorate(func): - name = db.name - for d in dbs: - if d == name: - def lala(self): - echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'") - lala.__name__ = func.__name__ - return lala - else: - return func - return decorate - -def supported(*dbs): - """a decorator that marks a test as supported by one or more database implementations""" - - def decorate(func): - name = db.name - for d in dbs: - if d == name: - return func - else: - def lala(self): - echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'") - lala.__name__ = func.__name__ - return lala - return decorate - -def Table(*args, **kw): - """A schema.Table wrapper/hook for dialect-specific tweaks.""" - - test_opts = dict([(k,kw.pop(k)) for k in kw.keys() - if k.startswith('test_')]) - - kw.update(table_options) - - if db.engine.name == 'mysql': - if 'mysql_engine' not in kw and 'mysql_type' not in kw: - if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: - kw['mysql_engine'] = 'InnoDB' - - return schema.Table(*args, **kw) - -def Column(*args, **kw): - """A schema.Column wrapper/hook for dialect-specific tweaks.""" - - # TODO: a Column that creates a Sequence automatically for PK columns, - # which would help Oracle tests - return schema.Column(*args, **kw) - -class PersistTest(unittest.TestCase): - - def __init__(self, *args, **params): - unittest.TestCase.__init__(self, *args, **params) - - def echo(self, text): - """DEPRECATED. use print """ - echo_text(text) - - - def setUpAll(self): - pass - def tearDownAll(self): - pass - - def shortDescription(self): - """overridden to not return docstrings""" - return None - -class AssertMixin(PersistTest): - """given a list-based structure of keys/properties which represent information within an object structure, and - a list of actual objects, asserts that the list of objects corresponds to the structure.""" - - def assert_result(self, result, class_, *objects): - result = list(result) - if echo: - print repr(result) - self.assert_list(result, class_, objects) - - def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), "result list is not the same size as test list, for class " + class_.__name__) - for i in range(0, len(list)): - self.assert_row(class_, result[i], list[i]) - - def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_)) - for key, value in desc.iteritems(): - if isinstance(value, tuple): - if isinstance(value[1], list): - self.assert_list(getattr(rowobj, key), value[0], value[1]) - else: - self.assert_row(value[0], getattr(rowobj, key), value[1]) - else: - self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value)) - - def assert_sql(self, db, callable_, list, with_sequences=None): - global testdata - testdata = TestData() - if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'): - testdata.set_assert_list(self, with_sequences) - else: - testdata.set_assert_list(self, list) - try: - callable_() - finally: - testdata.set_assert_list(None, None) - - def assert_sql_count(self, db, callable_, count): - global testdata - testdata = TestData() - try: - callable_() - finally: - self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count)) - - def capture_sql(self, db, callable_): - global testdata - testdata = TestData() - buffer = StringIO.StringIO() - testdata.buffer = buffer - try: - callable_() - return buffer.getvalue() - finally: - testdata.buffer = None - -class ORMTest(AssertMixin): - keep_mappers = False - keep_data = False - def setUpAll(self): - global _otest_metadata - _otest_metadata = MetaData(db) - self.define_tables(_otest_metadata) - _otest_metadata.create_all() - self.insert_data() - def define_tables(self, _otest_metadata): - raise NotImplementedError() - def insert_data(self): - pass - def get_metadata(self): - return _otest_metadata - def tearDownAll(self): - clear_mappers() - _otest_metadata.drop_all() - def tearDown(self): - if not self.keep_mappers: - clear_mappers() - if not self.keep_data: - for t in _otest_metadata.table_iterator(reverse=True): - t.delete().execute().close() - -class TestData(object): - """tracks SQL expressions as theyre executed via an instrumented ExecutionContext.""" - - def __init__(self): - self.set_assert_list(None, None) - self.sql_count = 0 - self.buffer = None - - def set_assert_list(self, unittest, list): - self.unittest = unittest - self.assert_list = list - if list is not None: - self.assert_list.reverse() - -testdata = TestData() - -class ExecutionContextWrapper(object): - """instruments the ExecutionContext created by the Engine so that SQL expressions - can be tracked.""" - - def __init__(self, ctx): - self.__dict__['ctx'] = ctx - def __getattr__(self, key): - return getattr(self.ctx, key) - def __setattr__(self, key, value): - setattr(self.ctx, key, value) - - def post_exec(self): - ctx = self.ctx - statement = unicode(ctx.compiled) - statement = re.sub(r'\n', '', ctx.statement) - if testdata.buffer is not None: - testdata.buffer.write(statement + "\n") - - if testdata.assert_list is not None: - assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement - item = testdata.assert_list[-1] - if not isinstance(item, dict): - item = testdata.assert_list.pop() - else: - # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in - # multiple orderings - if not item.has_key('_converted'): - for key in item.keys(): - ckey = self.convert_statement(key) - item[ckey] = item[key] - if ckey != key: - del item[key] - item['_converted'] = True - try: - entry = item.pop(statement) - if len(item) == 1: - testdata.assert_list.pop() - item = (statement, entry) - except KeyError: - assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement) - - (query, params) = item - if callable(params): - params = params(ctx) - if params is not None and isinstance(params, list) and len(params) == 1: - params = params[0] - - if isinstance(ctx.compiled_parameters, sql.ClauseParameters): - parameters = ctx.compiled_parameters.get_original_dict() - elif isinstance(ctx.compiled_parameters, list): - parameters = [p.get_original_dict() for p in ctx.compiled_parameters] - - query = self.convert_statement(query) - if db.engine.name == 'mssql' and statement.endswith('; select scope_identity()'): - statement = statement[:-25] - testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - testdata.sql_count += 1 - self.ctx.post_exec() - - def convert_statement(self, query): - paramstyle = self.ctx.dialect.paramstyle - if paramstyle == 'named': - pass - elif paramstyle =='pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - repl = None - if paramstyle=='qmark': - repl = "?" - elif paramstyle=='format': - repl = r"%s" - elif paramstyle=='numeric': - repl = None - query = re.sub(r':([\w_]+)', repl, query) - return query - -class TTestSuite(unittest.TestSuite): - """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality""" - def __init__(self, tests=()): - if len(tests) >0 and isinstance(tests[0], PersistTest): - self._initTest = tests[0] - else: - self._initTest = None - unittest.TestSuite.__init__(self, tests) - - def do_run(self, result): - """nice job unittest ! you switched __call__ and run() between py2.3 and 2.4 thereby - making straight subclassing impossible !""" - for test in self._tests: - if result.shouldStop: - break - test(result) - return result - - def run(self, result): - return self(result) - - def __call__(self, result): - try: - if self._initTest is not None: - self._initTest.setUpAll() - except: - result.addError(self._initTest, self.__exc_info()) - pass - try: - return self.do_run(result) - finally: - try: - if self._initTest is not None: - self._initTest.tearDownAll() - except: - result.addError(self._initTest, self.__exc_info()) - pass - - def __exc_info(self): - """Return a version of sys.exc_info() with the traceback frame - minimised; usually the top level of the traceback frame is not - needed. - ripped off out of unittest module since its double __ - """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) - -unittest.TestLoader.suiteClass = TTestSuite - -parse_argv() - -import sqlalchemy -from sqlalchemy import schema, MetaData, sql -from sqlalchemy.orm import clear_mappers - -def runTests(suite): - sys.stdout = Logger() - runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) - try: - return runner.run(suite) - finally: - if with_coverage: - global echo - echo=True - coverage.stop() - coverage.report(list(covered_files()), show_missing=False) - -def covered_files(): - for rec in os.walk(os.path.dirname(sqlalchemy.__file__)): - for x in rec[2]: - if x.endswith('.py'): - yield os.path.join(rec[0], x) - -def main(suite=None): - - if not suite: - if len(sys.argv[1:]): - suite =unittest.TestLoader().loadTestsFromNames(sys.argv[1:], __import__('__main__')) - else: - suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__')) - - result = runTests(suite) - sys.exit(not result.wasSuccessful()) +import testlib.config +testlib.config.configure() +from testlib.testing import main +db = testlib.config.db diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py new file mode 100644 index 0000000000..ff5c4c125e --- /dev/null +++ b/test/testlib/__init__.py @@ -0,0 +1,11 @@ +"""Enhance unittest and instrument SQLAlchemy classes for testing. + +Load after sqlalchemy imports to use instrumented stand-ins like Table. +""" + +import testlib.config +from testlib.schema import Table, Column +import testlib.testing as testing +from testlib.testing import PersistTest, AssertMixin, ORMTest +import testlib.profiling + diff --git a/test/testlib/config.py b/test/testlib/config.py new file mode 100644 index 0000000000..ca24a72560 --- /dev/null +++ b/test/testlib/config.py @@ -0,0 +1,255 @@ +import optparse, os, sys, ConfigParser, StringIO +logging, require = None, None + +__all__ = 'parser', 'configure', 'options', + +db, db_uri, db_type, db_label = None, None, None, None + +options = None +file_config = None + +base_config = """ +[db] +sqlite=sqlite:///:memory: +sqlite_file=sqlite:///querytest.db +postgres=postgres://scott:tiger@127.0.0.1:5432/test +mysql=mysql://scott:tiger@127.0.0.1:3306/test +oracle=oracle://scott:tiger@127.0.0.1:1521 +oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 +mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test +firebird=firebird://sysdba:s@localhost/tmp/test.fdb +""" + +parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") + +def configure(): + global options, config + global getopts_options, file_config + + file_config = ConfigParser.ConfigParser() + file_config.readfp(StringIO.StringIO(base_config)) + file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) + + # Opt parsing can fire immediate actions, like logging and coverage + (options, args) = parser.parse_args() + sys.argv[1:] = args + + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + return options, file_config + +def _log(option, opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + +def _start_coverage(option, opt_str, value, parser): + import sys, atexit, coverage + true_out = sys.stdout + + def _iter_covered_files(): + import sqlalchemy + for rec in os.walk(os.path.dirname(sqlalchemy.__file__)): + for x in rec[2]: + if x.endswith('.py'): + yield os.path.join(rec[0], x) + def _stop(): + coverage.stop() + true_out.write("\nPreparing coverage report...\n") + coverage.report(list(_iter_covered_files()), + show_missing=False, ignore_errors=False, + file=true_out) + atexit.register(_stop) + coverage.erase() + coverage.start() + +def _list_dbs(*args): + print "Available --db options (use --dburi to override)" + for macro in sorted(file_config.options('db')): + print "%20s\t%s" % (macro, file_config.get('db', macro)) + sys.exit(0) + +opt = parser.add_option +opt("--verbose", action="store_true", dest="verbose", + help="enable stdout echoing/printing") +opt("--quiet", action="store_true", dest="quiet", help="suppress output") +opt("--log-info", action="callback", callback=_log, + help="turn on info logging for (multiple OK)") +opt("--log-debug", action="callback", callback=_log, + help="turn on debug logging for (multiple OK)") +opt("--require", action="append", dest="require", default=[], + help="require a particular driver or module version (multiple OK)") +opt("--db", action="store", dest="db", default="sqlite", + help="Use prefab database uri") +opt('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") +opt("--dburi", action="store", dest="dburi", + help="Database uri (overrides --db)") +opt("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") +opt("--enginestrategy", action="store", dest="enginestrategy", default=None, + help="Engine strategy (plain or threadlocal, defaults toplain)") +opt("--reversetop", action="store_true", dest="reversetop", default=False, + help="Reverse the collection ordering for topological sorts (helps " + "reveal dependency issues)") +opt("--serverside", action="store_true", dest="serverside", + help="Turn on server side cursors for PG") +opt("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") +opt("--table-option", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") +opt("--coverage", action="callback", callback=_start_coverage, + help="Dump a full coverage report after running tests") +opt("--profile", action="append", dest="profile_targets", default=[], + help="Enable a named profile target (multiple OK.)") +opt("--profile-sort", action="store", dest="profile_sort", default=None, + help="Sort profile stats with this comma-separated sort order") +opt("--profile-limit", type="int", action="store", dest="profile_limit", + default=None, + help="Limit function count in profile stats") + +class _ordered_map(object): + def __init__(self): + self._keys = list() + self._data = dict() + + def __setitem__(self, key, value): + if key not in self._keys: + self._keys.append(key) + self._data[key] = value + + def __iter__(self): + for key in self._keys: + yield self._data[key] + +post_configure = _ordered_map() + +def _engine_uri(options, file_config): + global db_label, db_uri + db_label = 'sqlite' + if options.dburi: + db_uri = options.dburi + db_label = db_uri[:db_uri.index(':')] + elif options.db: + db_label = options.db + db_uri = None + + if db_uri is None: + if db_label not in file_config.options('db'): + raise RuntimeError( + "Unknown engine. Specify --dbs for known engines.") + db_uri = file_config.get('db', db_label) +post_configure['engine_uri'] = _engine_uri + +def _require(options, file_config): + if not(options.require or + (file_config.has_section('require') and + file_config.items('require'))): + return + + try: + import pkg_resources + except ImportError: + raise RuntimeError("setuptools is required for version requirements") + + cmdline = [] + for requirement in options.require: + pkg_resources.require(requirement) + cmdline.append(re.split('\s*(=)', requirement, 1)[0]) + + if file_config.has_section('require'): + for label, requirement in file_config.items('require'): + if not label == db_label or label.startswith('%s.' % db_label): + continue + seen = [c for c in cmdline if requirement.startswith(c)] + if seen: + continue + pkg_resources.require(requirement) +post_configure['require'] = _require + +def _create_testing_engine(options, file_config): + from sqlalchemy import engine + global db, db_type + engine_opts = {} + if options.serverside: + engine_opts['server_side_cursors'] = True + + if options.enginestrategy is not None: + engine_opts['strategy'] = options.enginestrategy + + if options.mockpool: + db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, + **engine_opts) + else: + db = engine.create_engine(db_uri, **engine_opts) + db_type = db.name + + # decorate the dialect's create_execution_context() method + # to produce a wrapper + from testlib.testing import ExecutionContextWrapper + + create_context = db.dialect.create_execution_context + def create_exec_context(*args, **kwargs): + return ExecutionContextWrapper(create_context(*args, **kwargs)) + db.dialect.create_execution_context = create_exec_context +post_configure['create_engine'] = _create_testing_engine + +def _set_table_options(options, file_config): + import testlib.schema + + table_options = testlib.schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine +post_configure['table_options'] = _set_table_options + +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm import unitofwork + from sqlalchemy import topological + class RevQueueDepSort(topological.QueueDependencySorter): + def __init__(self, tuples, allitems): + self.tuples = list(tuples) + self.allitems = list(allitems) + self.tuples.reverse() + self.allitems.reverse() + topological.QueueDependencySorter = RevQueueDepSort + unitofwork.DependencySorter = RevQueueDepSort +post_configure['topological'] = _reverse_topological + +def _set_profile_targets(options, file_config): + from testlib import profiling + + profile_config = profiling.profile_config + + for target in options.profile_targets: + profile_config['targets'].add(target) + + if options.profile_sort: + profile_config['sort'] = options.profile_sort.split(',') + + if options.profile_limit: + profile_config['limit'] = options.profile_limit + + if options.quiet: + profile_config['report'] = False + + # magic "all" target + if 'all' in profiling.all_targets: + targets = profile_config['targets'] + if 'all' in targets and len(targets) != 1: + targets.clear() + targets.add('all') +post_configure['profile_targets'] = _set_profile_targets diff --git a/test/coverage.py b/test/testlib/coverage.py similarity index 99% rename from test/coverage.py rename to test/testlib/coverage.py index 618f962fef..0203dbf7d5 100644 --- a/test/coverage.py +++ b/test/testlib/coverage.py @@ -1101,4 +1101,3 @@ if __name__ == '__main__': # DAMAGE. # # $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $ - diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py new file mode 100644 index 0000000000..51b8bb5d8d --- /dev/null +++ b/test/testlib/profiling.py @@ -0,0 +1,73 @@ +"""Profiling support for unit and performance tests.""" + +import time, hotshot, hotshot.stats +from testlib.config import parser, post_configure +import testlib.config + +__all__ = 'profiled', + +all_targets = set() +profile_config = { 'targets': set(), + 'report': True, + 'sort': ('time', 'calls'), + 'limit': None } + +def profiled(target, **target_opts): + """Optional function profiling. + + @profiled('label') + or + @profiled('label', report=True, sort=('calls',), limit=20) + + Enables profiling for a function when 'label' is targetted for + profiling. Report options can be supplied, and override the global + configuration and command-line options. + """ + + # manual or automatic namespacing by module would remove conflict issues + if target in all_targets: + print "Warning: redefining profile target '%s'" % target + all_targets.add(target) + + filename = "%s.prof" % target + + def decorator(fn): + def profiled(*args, **kw): + if (target not in profile_config['targets'] and + not target_opts.get('always', None)): + return fn(*args, **kw) + + prof = hotshot.Profile(filename) + began = time.time() + prof.start() + try: + result = fn(*args, **kw) + finally: + prof.stop() + ended = time.time() + prof.close() + + if not testlib.config.options.quiet: + print "Profiled target '%s', wall time: %.2f seconds" % ( + target, ended - began) + + report = target_opts.get('report', profile_config['report']) + if report: + sort_ = target_opts.get('sort', profile_config['sort']) + limit = target_opts.get('limit', profile_config['limit']) + print "Profile report for target '%s' (%s)" % ( + target, filename) + + stats = hotshot.stats.load(filename) + stats.sort_stats(*sort_) + if limit: + stats.print_stats(limit) + else: + stats.print_stats() + return result + try: + profiled.__name__ = fn.__name__ + except: + pass + return profiled + return decorator diff --git a/test/testlib/schema.py b/test/testlib/schema.py new file mode 100644 index 0000000000..a2fc912650 --- /dev/null +++ b/test/testlib/schema.py @@ -0,0 +1,28 @@ +import testbase +from sqlalchemy import schema + +__all__ = 'Table', 'Column', + +table_options = {} + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = dict([(k,kw.pop(k)) for k in kw.keys() + if k.startswith('test_')]) + + kw.update(table_options) + + if testbase.db.name == 'mysql': + if 'mysql_engine' not in kw and 'mysql_type' not in kw: + if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: + kw['mysql_engine'] = 'InnoDB' + + return schema.Table(*args, **kw) + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + # TODO: a Column that creates a Sequence automatically for PK columns, + # which would help Oracle tests + return schema.Column(*args, **kw) diff --git a/test/tables.py b/test/testlib/tables.py similarity index 96% rename from test/tables.py rename to test/testlib/tables.py index 6ae53a1dae..2dc9200f35 100644 --- a/test/tables.py +++ b/test/testlib/tables.py @@ -1,12 +1,8 @@ - -from sqlalchemy import * -import os import testbase -from testbase import Table, Column +from sqlalchemy import * +from testlib.schema import Table, Column -ECHO = testbase.echo -db = testbase.db -metadata = MetaData(db) +metadata = MetaData() users = Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), @@ -51,13 +47,19 @@ itemkeywords = Table('itemkeywords', metadata, ) def create(): + if not metadata.bind: + metadata.bind = testbase.db metadata.create_all() def drop(): + if not metadata.bind: + metadata.bind = testbase.db metadata.drop_all() def delete(): for t in metadata.table_iterator(reverse=True): t.delete().execute() def user_data(): + if not metadata.bind: + metadata.bind = testbase.db users.insert().execute( dict(user_id = 7, user_name = 'jack'), dict(user_id = 8, user_name = 'ed'), @@ -209,4 +211,4 @@ order_result = [ {'order_id' : 4, 'items':(Item, [])}, {'order_id' : 5, 'items':(Item, [])}, ] -#db.echo = True + diff --git a/test/testlib/testing.py b/test/testlib/testing.py new file mode 100644 index 0000000000..0361cfb682 --- /dev/null +++ b/test/testlib/testing.py @@ -0,0 +1,363 @@ +"""TestCase and TestSuite artifacts and testing decorators.""" + +# monkeypatches unittest.TestLoader.suiteClass at import time + +import unittest, re, sys, os +from cStringIO import StringIO +from sqlalchemy import MetaData, sql +from sqlalchemy.orm import clear_mappers +import testlib.config as config + +__all__ = 'PersistTest', 'AssertMixin', 'ORMTest' + +def unsupported(*dbs): + """Mark a test as unsupported by one or more database implementations""" + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name in dbs: + print "'%s' unsupported on DB implementation '%s'" % ( + fn_name, config.db.name) + return True + else: + return fn(*args, **kw) + try: + maybe.__name__ = fn_name + except: + pass + return maybe + return decorate + +def supported(*dbs): + """Mark a test as supported by one or more database implementations""" + + def decorate(fn): + fn_name = fn.__name__ + def maybe(*args, **kw): + if config.db.name in dbs: + return fn(*args, **kw) + else: + print "'%s' unsupported on DB implementation '%s'" % ( + fn_name, config.db.name) + return True + try: + maybe.__name__ = fn_name + except: + pass + return maybe + return decorate + +class TestData(object): + """Tracks SQL expressions as they are executed via an instrumented ExecutionContext.""" + + def __init__(self): + self.set_assert_list(None, None) + self.sql_count = 0 + self.buffer = None + + def set_assert_list(self, unittest, list): + self.unittest = unittest + self.assert_list = list + if list is not None: + self.assert_list.reverse() + +testdata = TestData() + + +class ExecutionContextWrapper(object): + """instruments the ExecutionContext created by the Engine so that SQL expressions + can be tracked.""" + + def __init__(self, ctx): + self.__dict__['ctx'] = ctx + def __getattr__(self, key): + return getattr(self.ctx, key) + def __setattr__(self, key, value): + setattr(self.ctx, key, value) + + def post_exec(self): + ctx = self.ctx + statement = unicode(ctx.compiled) + statement = re.sub(r'\n', '', ctx.statement) + if testdata.buffer is not None: + testdata.buffer.write(statement + "\n") + + if testdata.assert_list is not None: + assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement + item = testdata.assert_list[-1] + if not isinstance(item, dict): + item = testdata.assert_list.pop() + else: + # asserting a dictionary of statements->parameters + # this is to specify query assertions where the queries can be in + # multiple orderings + if not item.has_key('_converted'): + for key in item.keys(): + ckey = self.convert_statement(key) + item[ckey] = item[key] + if ckey != key: + del item[key] + item['_converted'] = True + try: + entry = item.pop(statement) + if len(item) == 1: + testdata.assert_list.pop() + item = (statement, entry) + except KeyError: + assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement) + + (query, params) = item + if callable(params): + params = params(ctx) + if params is not None and isinstance(params, list) and len(params) == 1: + params = params[0] + + if isinstance(ctx.compiled_parameters, sql.ClauseParameters): + parameters = ctx.compiled_parameters.get_original_dict() + elif isinstance(ctx.compiled_parameters, list): + parameters = [p.get_original_dict() for p in ctx.compiled_parameters] + + query = self.convert_statement(query) + if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'): + statement = statement[:-25] + testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) + testdata.sql_count += 1 + self.ctx.post_exec() + + def convert_statement(self, query): + paramstyle = self.ctx.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + return query + +class PersistTest(unittest.TestCase): + + def __init__(self, *args, **params): + unittest.TestCase.__init__(self, *args, **params) + + def setUpAll(self): + pass + + def tearDownAll(self): + pass + + def shortDescription(self): + """overridden to not return docstrings""" + return None + +class AssertMixin(PersistTest): + """given a list-based structure of keys/properties which represent information within an object structure, and + a list of actual objects, asserts that the list of objects corresponds to the structure.""" + + def assert_result(self, result, class_, *objects): + result = list(result) + print repr(result) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list): + self.assert_(len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + class_.__name__) + for i in range(0, len(list)): + self.assert_row(class_, result[i], list[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_(rowobj.__class__ is class_, + "item class is not " + repr(class_)) + for key, value in desc.iteritems(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_(getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" % ( + key, getattr(rowobj, key), value)) + + def assert_sql(self, db, callable_, list, with_sequences=None): + global testdata + testdata = TestData() + if with_sequences is not None and (config.db.name == 'postgres' or + config.db.name == 'oracle'): + testdata.set_assert_list(self, with_sequences) + else: + testdata.set_assert_list(self, list) + try: + callable_() + finally: + testdata.set_assert_list(None, None) + + def assert_sql_count(self, db, callable_, count): + global testdata + testdata = TestData() + try: + callable_() + finally: + self.assert_(testdata.sql_count == count, + "desired statement count %d does not match %d" % ( + count, testdata.sql_count)) + + def capture_sql(self, db, callable_): + global testdata + testdata = TestData() + buffer = StringIO() + testdata.buffer = buffer + try: + callable_() + return buffer.getvalue() + finally: + testdata.buffer = None + +_otest_metadata = None +class ORMTest(AssertMixin): + keep_mappers = False + keep_data = False + + def setUpAll(self): + global _otest_metadata + _otest_metadata = MetaData(config.db) + self.define_tables(_otest_metadata) + _otest_metadata.create_all() + self.insert_data() + + def define_tables(self, _otest_metadata): + raise NotImplementedError() + + def insert_data(self): + pass + + def get_metadata(self): + return _otest_metadata + + def tearDownAll(self): + clear_mappers() + _otest_metadata.drop_all() + + def tearDown(self): + if not self.keep_mappers: + clear_mappers() + if not self.keep_data: + for t in _otest_metadata.table_iterator(reverse=True): + t.delete().execute().close() + + +class TTestSuite(unittest.TestSuite): + """A TestSuite with once per TestCase setUpAll() and tearDownAll()""" + + def __init__(self, tests=()): + if len(tests) >0 and isinstance(tests[0], PersistTest): + self._initTest = tests[0] + else: + self._initTest = None + unittest.TestSuite.__init__(self, tests) + + def do_run(self, result): + # nice job unittest ! you switched __call__ and run() between py2.3 + # and 2.4 thereby making straight subclassing impossible ! + for test in self._tests: + if result.shouldStop: + break + test(result) + return result + + def run(self, result): + return self(result) + + def __call__(self, result): + try: + if self._initTest is not None: + self._initTest.setUpAll() + except: + result.addError(self._initTest, self.__exc_info()) + pass + try: + return self.do_run(result) + finally: + try: + if self._initTest is not None: + self._initTest.tearDownAll() + except: + result.addError(self._initTest, self.__exc_info()) + pass + + def __exc_info(self): + """Return a version of sys.exc_info() with the traceback frame + minimised; usually the top level of the traceback frame is not + needed. + ripped off out of unittest module since its double __ + """ + exctype, excvalue, tb = sys.exc_info() + if sys.platform[:4] == 'java': ## tracebacks look different in Jython + return (exctype, excvalue, tb) + return (exctype, excvalue, tb) + +unittest.TestLoader.suiteClass = TTestSuite + +def _iter_covered_files(): + import sqlalchemy + for rec in os.walk(os.path.dirname(sqlalchemy.__file__)): + for x in rec[2]: + if x.endswith('.py'): + yield os.path.join(rec[0], x) + +def cover(callable_, file_=None): + from testlib import coverage + coverage_client = coverage.the_coverage + coverage_client.get_ready() + coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]') + coverage_client.erase() + coverage_client.start() + try: + return callable_() + finally: + coverage_client.stop() + coverage_client.save() + coverage_client.report(list(_iter_covered_files()), + show_missing=False, ignore_errors=False, + file=file_) + +class DevNullWriter(object): + def write(self, msg): + pass + def flush(self): + pass + +def runTests(suite): + verbose = config.options.verbose + quiet = config.options.quiet + orig_stdout = sys.stdout + + try: + if not verbose or quiet: + sys.stdout = DevNullWriter() + runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) + return runner.run(suite) + finally: + if not verbose or quiet: + sys.stdout = orig_stdout + +def main(suite=None): + if not suite: + if len(sys.argv[1:]): + suite =unittest.TestLoader().loadTestsFromNames( + sys.argv[1:], __import__('__main__')) + else: + suite = unittest.TestLoader().loadTestsFromModule( + __import__('__main__')) + + result = runTests(suite) + sys.exit(not result.wasSuccessful()) diff --git a/test/zblog/tables.py b/test/zblog/tables.py index 0ca66fb722..3a5059972d 100644 --- a/test/zblog/tables.py +++ b/test/zblog/tables.py @@ -1,8 +1,10 @@ +"""application table metadata objects are described here.""" + from sqlalchemy import * -from testbase import Table, Column +from testlib import * + metadata = MetaData() -"""application table metadata objects are described here.""" users = Table('users', metadata, Column('user_id', Integer, primary_key=True), diff --git a/test/zblog/tests.py b/test/zblog/tests.py index 5bcad2da23..ad6876937d 100644 --- a/test/zblog/tests.py +++ b/test/zblog/tests.py @@ -1,21 +1,20 @@ -from testbase import AssertMixin import testbase -import unittest -db = testbase.db from sqlalchemy import * from sqlalchemy.orm import * - +from testlib import * from zblog import mappers, tables from zblog.user import * from zblog.blog import * + class ZBlogTest(AssertMixin): def create_tables(self): - tables.metadata.create_all(bind=db) + tables.metadata.drop_all(bind=testbase.db) + tables.metadata.create_all(bind=testbase.db) def drop_tables(self): - tables.metadata.drop_all(bind=db) + tables.metadata.drop_all(bind=testbase.db) def setUpAll(self): self.create_tables() @@ -32,7 +31,7 @@ class SavePostTest(ZBlogTest): super(SavePostTest, self).setUpAll() mappers.zblog_mappers() global blog_id, user_id - s = create_session(bind=db) + s = create_session(bind=testbase.db) user = User('zbloguser', "Zblog User", "hello", group=administrator) blog = Blog(owner=user) blog.name = "this is a blog" @@ -51,7 +50,7 @@ class SavePostTest(ZBlogTest): """test that a transient/pending instance has proper bi-directional behavior. this requires that lazy loaders do not fire off for a transient/pending instance.""" - s = create_session(bind=db) + s = create_session(bind=testbase.db) s.begin() try: @@ -67,7 +66,7 @@ class SavePostTest(ZBlogTest): def testoptimisticorphans(self): """test that instances in the session with un-loaded parents will not get marked as "orphans" and then deleted """ - s = create_session(bind=db) + s = create_session(bind=testbase.db) s.begin() try: @@ -97,4 +96,4 @@ class SavePostTest(ZBlogTest): if __name__ == "__main__": testbase.main() - \ No newline at end of file +