-from testbase import PersistTest
+import testbase
import sqlalchemy.topological as topological
-import unittest, sys, os
from sqlalchemy import util
+from testlib import *
+
# TODO: need assertion conditions in this suite
if __name__ == "__main__":
- unittest.main()
+ testbase.main()
import testbase
from sqlalchemy import util
+from testlib import *
-class OrderedDictTest(testbase.PersistTest):
+
+class OrderedDictTest(PersistTest):
def test_odict(self):
o = util.OrderedDict()
o['a'] = 1
-from testbase import PersistTest, AssertMixin
import testbase
from sqlalchemy import *
from sqlalchemy.databases import mysql
-from testbase import Table, Column
-import sys, StringIO
+from testlib import *
-db = testbase.db
class TypesTest(AssertMixin):
"Test MySQL column types"
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_numeric(self):
"Exercise type specification and options for numeric types."
'SMALLINT(4) UNSIGNED ZEROFILL'),
]
- table_args = ['test_mysql_numeric', MetaData(db)]
+ table_args = ['test_mysql_numeric', MetaData(testbase.db)]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw)))
numeric_table = Table(*table_args)
- gen = db.dialect.schemagenerator(db, None, None)
+ gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
for col in numeric_table.c:
index = int(col.name[1:])
raise
numeric_table.drop()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_charset(self):
"""Exercise CHARACTER SET and COLLATE-related options on string-type
columns."""
'''ENUM('foo','bar') UNICODE''')
]
- table_args = ['test_mysql_charset', MetaData(db)]
+ table_args = ['test_mysql_charset', MetaData(testbase.db)]
for index, spec in enumerate(columns):
type_, args, kw, res = spec
table_args.append(Column('c%s' % index, type_(*args, **kw)))
charset_table = Table(*table_args)
- gen = db.dialect.schemagenerator(db, None, None)
+ gen = testbase.db.dialect.schemagenerator(testbase.db, None, None)
for col in charset_table.c:
index = int(col.name[1:])
raise
charset_table.drop()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_enum(self):
"Exercise the ENUM type"
-
- enum_table = Table('mysql_enum', MetaData(db),
+
+ db = testbase.db
+ enum_table = Table('mysql_enum', MetaData(testbase.db),
Column('e1', mysql.MSEnum('"a"', "'b'")),
Column('e2', mysql.MSEnum('"a"', "'b'"), nullable=False),
Column('e3', mysql.MSEnum('"a"', "'b'", strict=True)),
# This is known to fail with MySQLDB 1.2.2 beta versions
# which return these as sets.Set(['a']), sets.Set(['b'])
# (even on Pythons with __builtin__.set)
- if db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
- db.dialect.dbapi.version_info >= (1, 2, 2):
+ if testbase.db.dialect.dbapi.version_info < (1, 2, 2, 'beta', 3) and \
+ testbase.db.dialect.dbapi.version_info >= (1, 2, 2):
            # these MySQLdb versions seem to always use 'sets', even on later pythons
import sets
def convert(value):
self.assertEqual(res, expected)
enum_table.drop()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def test_type_reflection(self):
# FIXME: older versions need their own test
- if db.dialect.get_version_info(db) < (5, 0):
+ if testbase.db.dialect.get_version_info(testbase.db) < (5, 0):
return
# (ask_for, roundtripped_as_if_different)
columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
- m = MetaData(db)
+ m = MetaData(testbase.db)
t_table = Table('mysql_types', m, *columns)
m.drop_all()
m.create_all()
- m2 = MetaData(db)
+ m2 = MetaData(testbase.db)
rt = Table('mysql_types', m2, autoload=True)
#print
-from testbase import AssertMixin
import testbase
+import datetime
from sqlalchemy import *
from sqlalchemy.databases import postgres
-import datetime
+from testlib import *
-db = testbase.db
class DomainReflectionTest(AssertMixin):
"Test PostgreSQL domains"
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def setUpAll(self):
- con = db.connect()
+ con = testbase.db.connect()
con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def tearDownAll(self):
- con = db.connect()
+ con = testbase.db.connect()
con.execute('DROP TABLE testtable')
con.execute('DROP TABLE alt_schema.testtable')
con.execute('DROP TABLE crosschema')
con.execute('DROP DOMAIN testdomain')
con.execute('DROP DOMAIN alt_schema.testdomain')
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_table_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True)
self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_domain_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True)
self.assertEquals(str(table.columns.answer.default.arg), '42', "Reflected default value didn't equal expected value")
self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.")
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_table_is_reflected_alt_schema(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_domain_is_reflected(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
self.assertEquals(str(table.columns.answer.default.arg), '0', "Reflected default value didn't equal expected value")
self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_crosschema_domain_is_reflected(self):
-        metadata = MetaData(db)
+        metadata = MetaData(testbase.db)
table = Table('crosschema', metadata, autoload=True)
self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
class MiscTest(AssertMixin):
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_date_reflection(self):
m1 = MetaData(testbase.db)
t1 = Table('pgdate', m1,
finally:
m1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_pg_weirdchar_reflection(self):
meta1 = MetaData(testbase.db)
subject = Table("subject", meta1,
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_checksfor_sequence(self):
meta1 = MetaData(testbase.db)
t = Table('mytable', meta1,
finally:
t.drop(checkfirst=True)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_reflection(self):
"""note: this test requires that the 'alt_schema' schema be separate and accessible by the test user"""
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_reflection_2(self):
meta1 = MetaData(testbase.db)
subject = Table("subject", meta1,
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_schema_reflection_3(self):
meta1 = MetaData(testbase.db)
subject = Table("subject", meta1,
finally:
meta1.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_preexecute_passivedefault(self):
"""test that when we get a primary key column back
from reflecting a table which has a default value on it, we pre-execute
if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime
that doesnt have one. this test illustrates two ways to have datetime types with and without timezone
info. """
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def setUpAll(self):
global tztable, notztable, metadata
metadata = MetaData(testbase.db)
Column("name", String(20)),
)
metadata.create_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def tearDownAll(self):
metadata.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_with_timezone(self):
# get a date with a tzinfo
somedate = testbase.db.connect().scalar(func.current_timestamp().select())
x = c.last_updated_params()
print x['date'] == somedate
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_without_timezone(self):
# get a date without a tzinfo
somedate = datetime.datetime(2005, 10,20, 11, 52, 00)
print x['date'] == somedate
class ArrayTest(AssertMixin):
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def setUpAll(self):
global metadata, arrtable
metadata = MetaData(testbase.db)
Column('strarr', postgres.PGArray(String), nullable=False)
)
metadata.create_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def tearDownAll(self):
metadata.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_reflect_array_column(self):
metadata2 = MetaData(testbase.db)
tbl = Table('arrtable', metadata2, autoload=True)
self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_insert_array(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
results = arrtable.select().execute().fetchall()
self.assertEquals(results[0]['strarr'], ['abc','def'])
arrtable.delete().execute()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_array_where(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
self.assertEquals(results[0]['intarr'], [1,2,3])
arrtable.delete().execute()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_array_concat(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
including the deprecated versions of these arguments"""
import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
-class BindTest(testbase.PersistTest):
+class BindTest(PersistTest):
def test_create_drop_explicit(self):
metadata = MetaData()
table = Table('test_table', metadata,
assert False
except exceptions.InvalidRequestError, e:
assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
-
-
finally:
if isinstance(bind, engine.Connection):
bind.close()
if __name__ == '__main__':
- testbase.main()
\ No newline at end of file
+ testbase.main()
-
import testbase
-import unittest, sys, datetime
-import tables
-db = testbase.db
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
-class ExecuteTest(testbase.PersistTest):
+class ExecuteTest(PersistTest):
def setUpAll(self):
global users, metadata
metadata = MetaData(testbase.db)
def tearDownAll(self):
metadata.drop_all()
- @testbase.supported('sqlite')
+ @testing.supported('sqlite')
def test_raw_qmark(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
conn.execute("delete from users")
- @testbase.supported('mysql', 'postgres')
+ @testing.supported('mysql', 'postgres')
def test_raw_sprintf(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"])
# pyformat is supported for mysql, but skipping because a few driver
# versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_raw_python(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
conn.execute("delete from users")
- @testbase.supported('sqlite')
+ @testing.supported('sqlite')
def test_raw_named(self):
for conn in (testbase.db, testbase.db.connect()):
conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
import testbase
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
-class MetaDataTest(testbase.PersistTest):
+class MetaDataTest(PersistTest):
def test_metadata_connect(self):
metadata = MetaData()
t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
-from testbase import PersistTest
import testbase
-import sqlalchemy.engine.url as url
from sqlalchemy import *
-import unittest
+import sqlalchemy.engine.url as url
+from testlib import *
class ParseConnectTest(PersistTest):
import testbase
-from testbase import PersistTest
-import unittest, sys, os, time
-import threading, thread
-
+import threading, thread, time
import sqlalchemy.pool as pool
import sqlalchemy.exceptions as exceptions
+from testlib import *
+
mcid = 1
class MockDBAPI(object):
connection2 = manager.connect('foo.db')
connection3 = manager.connect('bar.db')
- self.echo( "connection " + repr(connection))
+ print "connection " + repr(connection)
self.assert_(connection.cursor() is not None)
self.assert_(connection is connection2)
self.assert_(connection2 is not connection3)
connection = manager.connect('foo.db')
connection2 = manager.connect('foo.db')
- self.echo( "connection " + repr(connection))
+ print "connection " + repr(connection)
self.assert_(connection.cursor() is not None)
self.assert_(connection is not connection2)
def status(pool):
tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
- self.echo( "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup)
+ print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
return tup
c1 = p.connect()
import testbase
+import sys, weakref
from sqlalchemy import create_engine, exceptions
-import gc, weakref, sys
+from testlib import *
+
class MockDisconnect(Exception):
pass
def close(self):
pass
-class ReconnectTest(testbase.PersistTest):
+class ReconnectTest(PersistTest):
def test_reconnect(self):
"""test that an 'is_disconnect' condition will invalidate the connection, and additionally
dispose the previous connection pool and recreate."""
assert len(dbapi.connections) == 1
if __name__ == '__main__':
- testbase.main()
\ No newline at end of file
+ testbase.main()
-from testbase import PersistTest
import testbase
-import pickle
-import sqlalchemy.ansisql as ansisql
+import pickle, StringIO
from sqlalchemy import *
+import sqlalchemy.ansisql as ansisql
from sqlalchemy.exceptions import NoSuchTableError
import sqlalchemy.databases.mysql as mysql
-from testbase import Table, Column
-import unittest, re, StringIO
+from testlib import *
+
class ReflectionTest(PersistTest):
def testbasic(self):
finally:
meta.drop_all()
- @testbase.supported('mysql')
+ @testing.supported('mysql')
def testmysqltypes(self):
meta1 = MetaData(testbase.db)
table = Table(
finally:
testbase.db.execute("drop table book")
- @testbase.supported('sqlite')
+ @testing.supported('sqlite')
def test_goofy_sqlite(self):
"""test autoload of table where quotes were used with all the colnames. quirky in sqlite."""
testbase.db.execute("""CREATE TABLE "django_content_type" (
finally:
table.drop()
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def testidentity(self):
meta = MetaData(testbase.db)
table = Table(
class SchemaTest(PersistTest):
# this test should really be in the sql tests somewhere, not engine
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testiteration(self):
metadata = MetaData()
table1 = Table('table1', metadata,
assert buf.index("CREATE TABLE someschema.table1") > -1
assert buf.index("CREATE TABLE someschema.table2") > -1
- @testbase.supported('mysql','postgres')
+ @testing.supported('mysql','postgres')
def testcreate(self):
engine = testbase.db
schema = engine.dialect.get_default_schema_name(engine)
-
import testbase
-import unittest, sys, datetime, random, time, threading
-import tables
-db = testbase.db
+import sys, time, threading
+
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
+
-class TransactionTest(testbase.PersistTest):
+class TransactionTest(PersistTest):
def setUpAll(self):
global users, metadata
metadata = MetaData()
assert len(result.fetchall()) == 0
connection.close()
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testnestedsubtransactionrollback(self):
connection = testbase.db.connect()
transaction = connection.begin()
)
connection.close()
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testnestedsubtransactioncommit(self):
connection = testbase.db.connect()
transaction = connection.begin()
)
connection.close()
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testrollbacktosubtransaction(self):
connection = testbase.db.connect()
transaction = connection.begin()
)
connection.close()
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def testtwophasetransaction(self):
connection = testbase.db.connect()
)
connection.close()
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def testmixedtransaction(self):
connection = testbase.db.connect()
)
connection.close()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def testtwophaserecover(self):
# MySQL recovery doesn't currently seem to work correctly
# Prepared transactions disappear when connections are closed and even
)
connection2.close()
-class AutoRollbackTest(testbase.PersistTest):
+class AutoRollbackTest(PersistTest):
def setUpAll(self):
global metadata
metadata = MetaData()
def tearDownAll(self):
metadata.drop_all(testbase.db)
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testrollback_deadlock(self):
"""test that returning connections to the pool clears any object locks."""
conn1 = testbase.db.connect()
users.drop(conn2)
conn2.close()
-class TLTransactionTest(testbase.PersistTest):
+class TLTransactionTest(PersistTest):
def setUpAll(self):
global users, metadata, tlengine
- tlengine = create_engine(testbase.db_uri, strategy='threadlocal')
+ tlengine = create_engine(testbase.db.url, strategy='threadlocal')
metadata = MetaData()
users = Table('query_users', metadata,
Column('user_id', INT, primary_key = True),
finally:
external_connection.close()
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testnesting(self):
"""tests nesting of tranacstions"""
external_connection = tlengine.connect()
c2.close()
assert c1.connection.connection is not None
-class ForUpdateTest(testbase.PersistTest):
+class ForUpdateTest(PersistTest):
def setUpAll(self):
global counters, metadata
metadata = MetaData()
counters.drop(testbase.db)
def increment(self, count, errors, update_style=True, delay=0.005):
- con = db.connect()
+ con = testbase.db.connect()
sel = counters.select(for_update=update_style,
whereclause=counters.c.counter_id==1)
con.close()
- @testbase.supported('mysql', 'oracle', 'postgres')
+ @testing.supported('mysql', 'oracle', 'postgres')
def testqueued_update(self):
"""Test SELECT FOR UPDATE with concurrent modifications.
def overlap(self, ids, errors, update_style):
sel = counters.select(for_update=update_style,
whereclause=counters.c.counter_id.in_(*ids))
- con = db.connect()
+ con = testbase.db.connect()
trans = con.begin()
try:
rows = con.execute(sel).fetchall()
return errors
- @testbase.supported('mysql', 'oracle', 'postgres')
+ @testing.supported('mysql', 'oracle', 'postgres')
def testqueued_select(self):
"""Simple SELECT FOR UPDATE conflict test"""
sys.stderr.write("Failure: %s\n" % e)
self.assert_(len(errors) == 0)
- @testbase.supported('oracle', 'postgres')
+ @testing.supported('oracle', 'postgres')
def testnowait_select(self):
"""Simple SELECT FOR UPDATE NOWAIT conflict test"""
import testbase
+from datetime import datetime
+
from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
from sqlalchemy import and_, or_, exceptions
from sqlalchemy import ForeignKey, String, Integer, DateTime, Table, Column
from sqlalchemy.orm import clear_mappers, backref, create_session, class_mapper
-from datetime import datetime
-import sqlalchemy
-
import sqlalchemy.ext.activemapper as activemapper
+import sqlalchemy
+from testlib import *
-class testcase(testbase.PersistTest):
+class testcase(PersistTest):
def setUpAll(self):
clear_mappers()
objectstore.clear()
self.assertEquals(len(person.addresses), 2)
self.assertEquals(person.addresses[0].postal_code, '30338')
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_update(self):
p1 = self.create_person_one()
objectstore.flush()
self.assertEquals(Person.query.count(), 2)
-class testmanytomany(testbase.PersistTest):
+class testmanytomany(PersistTest):
def setUpAll(self):
clear_mappers()
objectstore.clear()
foo1.bazrel.append(baz1)
assert (foo1.bazrel == [baz1])
-class testselfreferential(testbase.PersistTest):
+class testselfreferential(PersistTest):
def setUpAll(self):
clear_mappers()
objectstore.clear()
-from testbase import PersistTest, AssertMixin
import testbase
from sqlalchemy import *
from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
-
from sqlalchemy.ext.assignmapper import assign_mapper
from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
+from testlib import *
+
class AssignMapperTest(PersistTest):
def setUpAll(self):
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest
import testbase
+
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm.collections import collection
from sqlalchemy.ext.associationproxy import *
-from testbase import Table, Column
+from testlib import *
-db = testbase.db
class DictCollection(dict):
@collection.appender
def setUp(self):
collection_class = self.collection_class
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
parents_table = Table('Parent', metadata,
Column('id', Integer, primary_key=True),
class ScalarTest(PersistTest):
def test_scalar_proxy(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
parents_table = Table('Parent', metadata,
Column('id', Integer, primary_key=True),
class LazyLoadTest(PersistTest):
def setUp(self):
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
parents_table = Table('Parent', metadata,
Column('id', Integer, primary_key=True),
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest, sys, os
import testbase
+
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.orderinglist import *
-from testbase import Table, Column
+from testlib import *
-db = testbase.db
metadata = None
# order in whole steps
global metadata, slides_table, bullets_table, Slide, Bullet
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
slides_table = Table('test_Slides', metadata,
Column('id', Integer, primary_key=True),
Column('name', String))
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
-class AssociationTest(testbase.PersistTest):
+class AssociationTest(PersistTest):
def setUpAll(self):
global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation
metadata = MetaData(testbase.db)
sess.flush()
self.assert_(item_keywords.count().scalar() == 0)
-class AssociationTest2(testbase.PersistTest):
+class AssociationTest2(PersistTest):
def setUpAll(self):
global table_originals, table_people, table_isauthor, metadata, Originals, People, IsAuthor
metadata = MetaData(testbase.db)
"""eager loading unittests derived from mailing list-reported problems and trac tickets."""
-from testbase import PersistTest, AssertMixin, ORMTest
import testbase
+import random, datetime
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
-import random, datetime
+from testlib import *
class EagerTest(AssertMixin):
def setUpAll(self):
obj = session.query(Left).get_by(tag='tag1')
print obj.middle.right[0]
-class EagerTest3(testbase.ORMTest):
+class EagerTest3(ORMTest):
"""test eager loading combined with nested SELECT statements, functions, and aggregates"""
def define_tables(self, metadata):
global datas, foo, stats
# algorithms and there are repeated 'somedata' values in the list)
assert verify_result == arb_result
-class EagerTest4(testbase.ORMTest):
+class EagerTest4(ORMTest):
def define_tables(self, metadata):
global departments, employees
departments = Table('departments', metadata,
assert q.count() == 2
assert q[0] is d2
-class EagerTest5(testbase.ORMTest):
+class EagerTest5(ORMTest):
"""test the construction of AliasedClauses for the same eager load property but different
parent mappers, due to inheritance"""
def define_tables(self, metadata):
i = ctx.current.query(Invoice).get(invoice_id)
- self.echo(repr(c))
- self.echo(repr(i.company))
+ print repr(c)
+ print repr(i.company)
self.assert_(repr(c) == repr(i.company))
def testtwo(self):
ctx.current.clear()
a = ctx.current.query(Company).get(company_id)
- self.echo(repr(a))
+ print repr(a)
# set up an invoice
i1 = Invoice()
ctx.current.clear()
c = ctx.current.query(Company).get(company_id)
- self.echo(repr(c))
+ print repr(c)
ctx.current.clear()
assert repr(i.company) == repr(c), repr(i.company) + " does not match " + repr(c)
-class EagerTest8(testbase.ORMTest):
+class EagerTest8(ORMTest):
def define_tables(self, metadata):
global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t
-from testbase import PersistTest
-import sqlalchemy.util as util
+import testbase
+import pickle
import sqlalchemy.orm.attributes as attributes
from sqlalchemy.orm.collections import collection
from sqlalchemy import exceptions
-import unittest, sys, os
-import pickle
-import testbase
+from testlib import *
class MyTest(object):pass
class MyTest2(object):pass
-import testbase, tables
-import unittest, sys, datetime
+import testbase
-from sqlalchemy.ext.sessioncontext import SessionContext
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from sqlalchemy.ext.sessioncontext import SessionContext
+from testlib import *
+import testlib.tables as tables
-class O2MCascadeTest(testbase.AssertMixin):
+class O2MCascadeTest(AssertMixin):
def tearDown(self):
tables.delete()
sess = create_session()
l = sess.query(tables.User).select()
for u in l:
- self.echo( repr(u.orders))
+ print repr(u.orders)
self.assert_result(l, data[0], *data[1:])
ids = (l[0].user_id, l[2].user_id)
self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(*ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0)
-class M2OCascadeTest(testbase.AssertMixin):
+class M2OCascadeTest(AssertMixin):
def tearDown(self):
ctx.current.clear()
for t in metadata.table_iterator(reverse=True):
-class M2MCascadeTest(testbase.AssertMixin):
+class M2MCascadeTest(AssertMixin):
def setUpAll(self):
global metadata, a, b, atob
metadata = MetaData(testbase.db)
assert b.count().scalar() == 0
assert a.count().scalar() == 0
-class UnsavedOrphansTest(testbase.ORMTest):
+class UnsavedOrphansTest(ORMTest):
"""tests regarding pending entities that are orphans"""
def define_tables(self, metadata):
assert a.address_id is None, "Error: address should not be persistent"
-class UnsavedOrphansTest2(testbase.ORMTest):
+class UnsavedOrphansTest2(ORMTest):
"""same test as UnsavedOrphans only three levels deep"""
def define_tables(self, meta):
assert item.id is None
assert attr.id is None
-class DoubleParentOrphanTest(testbase.AssertMixin):
+class DoubleParentOrphanTest(AssertMixin):
"""test orphan detection for an entity with two parent relations"""
def setUpAll(self):
assert False
except exceptions.FlushError, e:
assert True
-
-
+
+
if __name__ == "__main__":
testbase.main()
from sqlalchemy.orm.collections import collection
from sqlalchemy import util
from operator import and_
-
+from testlib import *
class Canary(interfaces.AttributeExtension):
def __init__(self):
return Entity(a or str(_id), b or 'value %s' % _id, c)
-class CollectionsTest(testbase.PersistTest):
+class CollectionsTest(PersistTest):
def _test_adapter(self, typecallable, creator=entity_maker,
to_set=None):
class Foo(object):
obj.attr[0] = e3
self.assert_(e3 in canary.data)
-class DictHelpersTest(testbase.ORMTest):
+class DictHelpersTest(ORMTest):
def define_tables(self, metadata):
global parents, children, Parent, Child
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
-class CompileTest(testbase.AssertMixin):
+class CompileTest(AssertMixin):
"""test various mapper compilation scenarios"""
def tearDownAll(self):
clear_mappers()
-from testbase import PersistTest, AssertMixin, ORMTest
-import unittest, sys, os
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-import StringIO
-import testbase
-from testbase import Table, Column
-
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
"""test cyclical mapper relationships. Many of the assertions are provided
via running with postgres, which is strict about foreign keys.
sess.save(b)
sess.save(p)
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
(
"INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)",
{'favorite_ball_id': None, 'data':'some data'}
)
])
sess.delete(p)
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
        # here's the post-update (which is a pre-update with deletes)
(
"UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id",
sess = create_session()
[sess.save(x) for x in [b,p,b2,b3,b4]]
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
(
"INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
{'person_id':None, 'data':'some data'}
])
sess.delete(p)
- self.assert_sql(db, lambda: sess.flush(), [
+ self.assert_sql(testbase.db, lambda: sess.flush(), [
(
"UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
lambda ctx:{'person_id': None, 'ball_id': b.id}
remove_child(root, cats)
# pre-trigger lazy loader on 'cats' to make the test easier
cats.children
- self.assert_sql(db, lambda: session.flush(), [
+ self.assert_sql(testbase.db, lambda: session.flush(), [
(
"UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id",
lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}
"""basic tests of eager loaded attributes"""
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-import testbase
-
+from testlib import *
from fixtures import *
from query import QueryTest
l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id)
assert fixtures.user_address_result[1:2] == l.all()
-class SelfReferentialEagerTest(testbase.ORMTest):
+class SelfReferentialEagerTest(ORMTest):
def define_tables(self, metadata):
global nodes
nodes = Table('nodes', metadata,
-from testbase import PersistTest, AssertMixin
-import unittest
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
-
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
class EntityTest(AssertMixin):
"""tests mappers that are constructed based on "entity names", which allows the same class
+import testbase
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
_recursion_stack = util.Set()
class Base(object):
-from testbase import PersistTest, AssertMixin, ORMTest
import testbase
-import tables
-
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy import exceptions
-from testbase import Table, Column
+from testlib import *
+import testlib.tables as tables
# TODO: these are more tests that should be updated to be part of test/orm/query.py
assert res.order_by([Foo.c.bar])[0].bar == 5
assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
- @testbase.unsupported('mssql')
+ @testing.unsupported('mssql')
def test_slice(self):
sess = create_session()
query = sess.query(Foo)
assert list(query[-5:]) == orig[-5:]
assert query[10:20][5] == orig[10:20][5]
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_slice_mssql(self):
sess = create_session()
query = sess.query(Foo)
assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_aggregate_1(self):
# this one fails in mysql as the result comes back as a string
query = create_session().query(Foo)
assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
- @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
+ @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_2(self):
query = create_session().query(Foo)
assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
- @testbase.supported('postgres', 'mysql', 'firebird', 'mssql')
+ @testing.supported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_2_int(self):
query = create_session().query(Foo)
assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
- @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
+ @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_3(self):
query = create_session().query(Foo)
assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-
from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE
-import testbase
+from testlib import *
def produce_test(parent, child, direction):
"""produce a testcase for A->B->C inheritance with a self-referential
relationship between two of the classes, using either one-to-many or
many-to-one."""
- class ABCTest(testbase.ORMTest):
+ class ABCTest(ORMTest):
def define_tables(self, meta):
global ta, tb, tc
ta = ["a", meta]
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
-class O2MTest(testbase.ORMTest):
+class O2MTest(ORMTest):
"""deals with inheritance and one-to-many relationships"""
def define_tables(self, metadata):
global foo, bar, blub
sess.clear()
l = sess.query(Blub).select()
result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
- self.echo(result)
+ print result
self.assert_(compare == result)
self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
-class AddPropTest(testbase.ORMTest):
+class AddPropTest(ORMTest):
"""testing that construction of inheriting mappers works regardless of when extra properties
are added to the superclass mapper"""
def define_tables(self, metadata):
p.contenttype = ContentType()
# TODO: assertion ??
-class EagerLazyTest(testbase.ORMTest):
+class EagerLazyTest(ORMTest):
"""tests eager load/lazy load of child items off inheritance mappers, tests that
LazyLoader constructs the right query condition."""
def define_tables(self, metadata):
self.assert_(len(q.selectfirst().eager) == 1)
-class FlushTest(testbase.ORMTest):
+class FlushTest(ORMTest):
"""test dependency sorting among inheriting mappers"""
def define_tables(self, metadata):
global users, roles, user_roles, admins
sess.flush()
assert user_roles.count().scalar() == 1
-class DistinctPKTest(testbase.ORMTest):
+class DistinctPKTest(ORMTest):
"""test the construction of mapper.primary_key when an inheriting relationship
joins on a column other than primary key column."""
keep_data = True
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-import testbase
-from testbase import Table, Column
+from testlib import *
-class ConcreteTest1(testbase.ORMTest):
+class ConcreteTest1(ORMTest):
def define_tables(self, metadata):
global managers_table, engineers_table
managers_table = Table('managers', metadata,
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
class BaseObject(object):
pass
-class MagazineTest(testbase.ORMTest):
+class MagazineTest(ORMTest):
def define_tables(self, metadata):
global publication_table, issue_table, location_table, location_name_table, magazine_table, \
page_table, magazine_page_table, classified_page_table, page_size_table
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
-class InheritTest(testbase.ORMTest):
+class InheritTest(ORMTest):
"""deals with inheritance and many-to-many relationships"""
def define_tables(self, metadata):
global principals
sess.flush()
# TODO: put an assertion
-class InheritTest2(testbase.ORMTest):
+class InheritTest2(ORMTest):
"""deals with inheritance and many-to-many relationships"""
def define_tables(self, metadata):
global foo, bar, foo_bar
{'id':b.id, 'data':'barfoo', 'foos':(Foo, [{'id':f1.id,'data':'subfoo1'}, {'id':f2.id,'data':'subfoo2'}])},
)
-class InheritTest3(testbase.ORMTest):
+class InheritTest3(ORMTest):
"""deals with inheritance and many-to-many relationships"""
def define_tables(self, metadata):
global foo, bar, blub, bar_foo, blub_bar, blub_foo
compare = repr(b) + repr(b.foos)
sess.clear()
l = sess.query(Bar).select()
- self.echo(repr(l[0]) + repr(l[0].foos))
+ print repr(l[0]) + repr(l[0].foos)
self.assert_(repr(l[0]) + repr(l[0].foos) == compare)
def testadvanced(self):
sess.clear()
l = sess.query(Blub).select()
- self.echo(l)
+ print l
self.assert_(repr(l[0]) == compare)
sess.clear()
x = sess.query(Blub).get_by(id=blubid)
- self.echo(x)
+ print x
self.assert_(repr(x) == compare)
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
-class PolymorphicCircularTest(testbase.ORMTest):
+class PolymorphicCircularTest(ORMTest):
keep_mappers = True
def define_tables(self, metadata):
global Table1, Table1B, Table2, Table3, Data
"""tests basic polymorphic mapper loading/saving, minimal relations"""
import testbase
+import sets
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-import sets
+from testlib import *
class Person(object):
def __repr__(self):
return "Company %s" % self.name
-class PolymorphTest(testbase.ORMTest):
+class PolymorphTest(ORMTest):
def define_tables(self, metadata):
global companies, people, engineers, managers, boss
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-import testbase
+from testlib import *
class AttrSettable(object):
return self.__class__.__name__ + "(%s)" % (hex(id(self)))
-class RelationTest1(testbase.ORMTest):
+class RelationTest1(ORMTest):
"""test self-referential relationships on polymorphic mappers"""
def define_tables(self, metadata):
global people, managers
print p, m, m.employee
assert m.employee is p
-class RelationTest2(testbase.ORMTest):
+class RelationTest2(ORMTest):
"""test self-referential relationships on polymorphic mappers"""
def define_tables(self, metadata):
global people, managers, data
if usedata:
assert m.data.data == 'ms data'
-class RelationTest3(testbase.ORMTest):
+class RelationTest3(ORMTest):
"""test self-referential relationships on polymorphic mappers"""
def define_tables(self, metadata):
global people, managers, data
setattr(RelationTest3, func.__name__, func)
-class RelationTest4(testbase.ORMTest):
+class RelationTest4(ORMTest):
def define_tables(self, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
c = s.join("employee").filter(Person.name=="E4")[0]
assert c.car_id==car1.car_id
-class RelationTest5(testbase.ORMTest):
+class RelationTest5(ORMTest):
def define_tables(self, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
assert carlist[0].manager is None
assert carlist[1].manager.person_id == car2.manager.person_id
-class RelationTest6(testbase.ORMTest):
+class RelationTest6(ORMTest):
"""test self-referential relationships on a single joined-table inheritance mapper"""
def define_tables(self, metadata):
global people, managers, data
m2 = sess.query(Manager).get(m2.person_id)
assert m.colleague is m2
-class RelationTest7(testbase.ORMTest):
+class RelationTest7(ORMTest):
def define_tables(self, metadata):
global people, engineers, managers, cars, offroad_cars
cars = Table('cars', metadata,
for p in r:
assert p.car_id == p.car.car_id
-class GenerativeTest(testbase.AssertMixin):
+class GenerativeTest(AssertMixin):
def setUpAll(self):
# cars---owned by--- people (abstract) --- has a --- status
# | ^ ^ |
r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id))
assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
-class MultiLevelTest(testbase.ORMTest):
+class MultiLevelTest(ORMTest):
def define_tables(self, metadata):
global table_Employee, table_Engineer, table_Manager
table_Employee = Table( 'Employee', metadata,
assert set(session.query( Engineer).select()) == set([b,c])
assert session.query( Manager).select() == [c]
-class ManyToManyPolyTest(testbase.ORMTest):
+class ManyToManyPolyTest(ORMTest):
def define_tables(self, metadata):
global base_item_table, item_table, base_item_collection_table, collection_table
base_item_table = Table(
class_mapper(BaseItem)
-class CustomPKTest(testbase.ORMTest):
+class CustomPKTest(ORMTest):
def define_tables(self, metadata):
global t1, t2
t1 = Table('t1', metadata,
import testbase
+from datetime import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
-from datetime import datetime
-class InheritTest(testbase.ORMTest):
+class InheritTest(ORMTest):
"""tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
def define_tables(self, metadata):
global products_table, specification_table, documents_table
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-import testbase
+from testlib import *
-class SingleInheritanceTest(testbase.AssertMixin):
+class SingleInheritanceTest(AssertMixin):
def setUpAll(self):
metadata = MetaData(testbase.db)
global employees_table
"""basic tests of lazy loaded attributes"""
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-import testbase
-
+from testlib import *
from fixtures import *
from query import QueryTest
-from testbase import PersistTest, AssertMixin
import testbase
-import unittest, sys, os
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-import datetime
+from testlib import *
class LazyTest(AssertMixin):
def setUpAll(self):
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-import string
+from testlib import *
+
class Place(object):
'''represents a place'''
def __repr__(self):
return object.__repr__(self)+ " " + repr(self.inputs) + " " + repr(self.outputs)
-class M2MTest(testbase.ORMTest):
+class M2MTest(ORMTest):
def define_tables(self, metadata):
global place
place = Table('place', metadata,
for p in l:
pp = p.places
- self.echo("Place " + str(p) +" places " + repr(pp))
+ print "Place " + str(p) +" places " + repr(pp)
[sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7]
sess.flush()
self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
-class M2MTest2(testbase.ORMTest):
+class M2MTest2(ORMTest):
def define_tables(self, metadata):
global studentTbl
studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True))
sess.flush()
assert enrolTbl.count().scalar() == 0
-class M2MTest3(testbase.ORMTest):
+class M2MTest3(ORMTest):
def define_tables(self, metadata):
global c, c2a1, c2a2, b, a
c = Table('c', metadata,
"""tests general mapper operations with an emphasis on selecting/loading"""
-from testbase import PersistTest, AssertMixin, ORMTest
import testbase
-import unittest, sys, os
from sqlalchemy import *
from sqlalchemy.orm import *
import sqlalchemy.exceptions as exceptions
from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
+
class MapperSuperTest(AssertMixin):
def setUpAll(self):
u = q2.selectfirst(users.c.user_id==8)
def go():
s.refresh(u)
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
def testexpire(self):
"""test the expire function"""
# l = create_session().query(User).select(order_by=None)
- @testbase.unsupported('firebird')
+ @testing.unsupported('firebird')
def testfunction(self):
"""test mapping to a SELECT statement that has functions in it."""
s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
assert l[0].concat == l[0].user_id * 2 == 14
assert l[1].concat == l[1].user_id * 2 == 16
- @testbase.unsupported('firebird')
+ @testing.unsupported('firebird')
def testcount(self):
"""test the count function on Query.
def go():
u = sess.query(User).options(eagerload('adlist')).get_by(user_name='jack')
self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
def testextensionoptions(self):
sess = create_session()
def go():
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testeageroptionswithlimit(self):
sess = create_session()
def go():
assert u.user_id == 8
assert len(u.addresses) == 3
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
sess.clear()
u = sess.query(User).get_by(user_id=8)
assert u.user_id == 8
assert len(u.addresses) == 3
- assert "tbl_row_count" not in self.capture_sql(db, go)
+ assert "tbl_row_count" not in self.capture_sql(testbase.db, go)
def testlazyoptionswithlimit(self):
sess = create_session()
def go():
assert u.user_id == 8
assert len(u.addresses) == 3
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
def testeagerdegrade(self):
"""tests that an eager relation automatically degrades to a lazy relation if eager columns are not available"""
def go():
l = sess.query(usermapper).select()
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
sess.clear()
r = users.select().execute()
l = usermapper.instances(r, sess)
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 4)
+ self.assert_sql_count(testbase.db, go, 4)
clear_mappers()
def go():
l = sess.query(usermapper).select()
self.assert_result(l, User, *user_all_result)
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
sess.clear()
r = users.select().execute()
l = usermapper.instances(r, sess)
self.assert_result(l, User, *user_all_result)
- self.assert_sql_count(db, go, 7)
+ self.assert_sql_count(testbase.db, go, 7)
def testlazyoptions(self):
l = sess.query(User).options(lazyload('addresses')).select()
def go():
self.assert_result(l, User, *user_address_result)
- self.assert_sql_count(db, go, 3)
+ self.assert_sql_count(testbase.db, go, 3)
def testlatecompile(self):
"""tests mappers compiling late in the game"""
u = sess.query(User).select()
def go():
print u[0].orders[1].items[0].keywords[1]
- self.assert_sql_count(db, go, 3)
+ self.assert_sql_count(testbase.db, go, 3)
def testdeepoptions(self):
mapper(User, users,
u = sess.query(User).select()
def go():
print u[0].orders[1].items[0].keywords[1]
- self.assert_sql_count(db, go, 3)
+ self.assert_sql_count(testbase.db, go, 3)
sess.clear()
def go():
print u[0].orders[1].items[0].keywords[1]
print "-------MARK2----------"
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
sess.clear()
def go():
print u[0].orders[1].items[0].keywords[1]
print "-------MARK3----------"
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
print "-------MARK4----------"
sess.clear()
print "-------MARK5----------"
q3 = sess.query(User).options(eagerload('orders.items.keywords'))
u = q3.select()
- self.assert_sql_count(db, go, 2)
+ self.assert_sql_count(testbase.db, go, 2)
class DeferredTest(MapperSuperTest):
o2 = l[2]
print o2.description
- orderby = str(orders.default_order_by()[0].compile(bind=db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(bind=testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
assert o2.opened == 1
assert o2.userident == 7
assert o2.description == 'order 3'
- orderby = str(orders.default_order_by()[0].compile(db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
o2.description = 'order 3'
def go():
sess.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testcommitsstate(self):
"""test that when deferred elements are loaded via a group, they get the proper CommittedState
def go():
# therefore the flush() shouldnt actually issue any SQL
sess.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testoptions(self):
"""tests using options on a mapper to create deferred and undeferred columns"""
l = q2.select()
print l[2].user_id
- orderby = str(orders.default_order_by()[0].compile(db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
("SELECT orders.user_id AS orders_user_id FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
])
def go():
l = q3.select()
print l[3].user_id
- self.assert_sql(db, go, [
+ self.assert_sql(testbase.db, go, [
("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
])
assert o2.opened == 1
assert o2.userident == 7
assert o2.description == 'order 3'
- orderby = str(orders.default_order_by()[0].compile(db))
- self.assert_sql(db, go, [
+ orderby = str(orders.default_order_by()[0].compile(testbase.db))
+ self.assert_sql(testbase.db, go, [
("SELECT orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen, orders.order_id AS orders_order_id FROM orders ORDER BY %s" % orderby, {}),
])
item = l[0].orders[1].items[1]
def go():
print item.item_name
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
self.assert_(item.item_name == 'item 4')
sess.clear()
q2 = q.options(undefer('orders.items.item_name'))
item = l[0].orders[1].items[1]
def go():
print item.item_name
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
self.assert_(item.item_name == 'item 4')
class CompositeTypesTest(ORMTest):
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from sqlalchemy.orm import mapperlib, session, unitofwork, attributes
-Mapper = mapperlib.Mapper
-import gc
import testbase
-from testbase import Table, Column
-import tables
+import gc
+from sqlalchemy import MetaData, Integer, String, ForeignKey
+from sqlalchemy.orm import mapper, relation, clear_mappers, create_session
+from sqlalchemy.orm.mapper import Mapper
+from testlib import *
class A(object):pass
class B(object):pass
-class MapperCleanoutTest(testbase.AssertMixin):
+class MapperCleanoutTest(AssertMixin):
"""test that clear_mappers() removes everything related to the class.
does not include classes that use the assignmapper extension."""
- def setUp(self):
- global engine
- engine = testbase.db
-
+
def test_mapper_cleanup(self):
for x in range(0, 5):
self.do_test()
assert True
def do_test(self):
- metadata = MetaData(engine)
+ metadata = MetaData(testbase.db)
table1 = Table("mytable", metadata,
Column('col1', Integer, primary_key=True),
-from testbase import PersistTest, AssertMixin
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
class MergeTest(AssertMixin):
"""tests session.merge() functionality"""
if __name__ == "__main__":
testbase.main()
-
-
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.sessioncontext import SessionContext
-from testbase import Table, Column
+from testlib import *
class Jack(object):
def __repr__(self):
self.name=name
self.description = description
-class O2OTest(testbase.AssertMixin):
+class O2OTest(AssertMixin):
def setUpAll(self):
global jack, port, metadata, ctx
metadata = MetaData(testbase.db)
import testbase
+import operator
from sqlalchemy import *
from sqlalchemy import ansisql
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
from fixtures import *
-import operator
class Base(object):
def __init__(self, **kwargs):
else:
return True
-class QueryTest(testbase.ORMTest):
+class QueryTest(ORMTest):
keep_mappers = True
keep_data = True
assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all()
-class SelfReferentialJoinTest(testbase.ORMTest):
+class SelfReferentialJoinTest(ORMTest):
def define_tables(self, metadata):
global nodes
nodes = Table('nodes', metadata,
import testbase
-import unittest, sys, datetime
+import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm import collections
from sqlalchemy.orm.collections import collection
-from testbase import Table, Column
+from testlib import *
-db = testbase.db
-
-class RelationTest(testbase.PersistTest):
+class RelationTest(PersistTest):
"""this is essentially an extension of the "dependency.py" topological sort test.
in this test, a table is dependent on two other tables that are otherwise unrelated to each other.
the dependency sort must insure that this childmost table is below both parent tables in the outcome
to subtle differences in program execution, this test case was exposing the bug whereas the simpler tests
were not."""
def setUpAll(self):
- global tbl_a
- global tbl_b
- global tbl_c
- global tbl_d
+ global metadata, tbl_a, tbl_b, tbl_c, tbl_d
+
metadata = MetaData()
tbl_a = Table("tbl_a", metadata,
Column("id", Integer, primary_key=True),
conn.drop(tbl_a)
def tearDownAll(self):
- testbase.metadata.tables.clear()
+ metadata.drop_all(testbase.db)
def testDeleteRootTable(self):
session.flush()
session.delete(c) # fails
session.flush()
-class RelationTest2(testbase.PersistTest):
+class RelationTest2(PersistTest):
"""this test tests a relationship on a column that is included in multiple foreign keys,
as well as a self-referential relationship on a composite key where one column in the foreign key
is 'joined to itself'."""
assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
-class RelationTest3(testbase.PersistTest):
+class RelationTest3(PersistTest):
def setUpAll(self):
global jobs, pageversions, pages, metadata, Job, Page, PageVersion, PageComment
import datetime
s.delete(j)
s.flush()
-class RelationTest4(testbase.ORMTest):
+class RelationTest4(ORMTest):
"""test syncrules on foreign keys that are also primary"""
def define_tables(self, metadata):
global tableA, tableB
assert a1 not in sess
assert b1 not in sess
-class RelationTest5(testbase.ORMTest):
+class RelationTest5(ORMTest):
"""test a map to a select that relates to a map to the table"""
def define_tables(self, metadata):
global items
assert old.id == new.id
-class TypeMatchTest(testbase.ORMTest):
+class TypeMatchTest(ORMTest):
"""test errors raised when trying to add items whose type is not handled by a relation"""
def define_tables(self, metadata):
global a, b, c, d
except exceptions.AssertionError, err:
assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
-class TypedAssociationTable(testbase.ORMTest):
+class TypedAssociationTable(ORMTest):
def define_tables(self, metadata):
global t1, t2, t3
assert t3.count().scalar() == 1
# TODO: move these tests to either attributes.py test or its own module
-class CustomCollectionsTest(testbase.ORMTest):
+class CustomCollectionsTest(ORMTest):
def define_tables(self, metadata):
global sometable, someothertable
sometable = Table('sometable', metadata,
o = list(p2.children)
assert len(o) == 3
-class ViewOnlyTest(testbase.ORMTest):
+class ViewOnlyTest(ORMTest):
"""test a view_only mapping where a third table is pulled into the primary join condition,
using overlapping PK column names (should not produce "conflicting column" error)"""
def define_tables(self, metadata):
assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id])
assert set([x.id for x in c1.t2_view]) == set([c2b.id])
-class ViewOnlyTest2(testbase.ORMTest):
+class ViewOnlyTest2(ORMTest):
"""test a view_only mapping where a third table is pulled into the primary join condition,
using non-overlapping PK column names (should not produce "mapper has no column X" error)"""
def define_tables(self, metadata):
-from testbase import AssertMixin
import testbase
-import unittest, sys, datetime
-
-import tables
-from tables import *
-
-db = testbase.db
from sqlalchemy import *
from sqlalchemy.orm import *
-
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
class SessionTest(AssertMixin):
def setUpAll(self):
# then see if expunge fails
session.expunge(u)
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def test_transaction(self):
class User(object):pass
mapper(User, users)
assert conn1.execute("select count(1) from users").scalar() == 1
assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def test_autoflush(self):
class User(object):pass
mapper(User, users)
assert conn1.execute("select count(1) from users").scalar() == 1
assert testbase.db.connect().execute("select count(1) from users").scalar() == 1
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def test_autoflush_unbound(self):
class User(object):pass
mapper(User, users)
trans.rollback() # rolls back
assert len(sess.query(User).select()) == 0
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def test_external_nested_transaction(self):
class User(object):pass
mapper(User, users)
conn.close()
raise
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def test_twophase(self):
# TODO: mock up a failure condition here
# to ensure a rollback succeeds
sess.rollback() # rolls back
assert len(sess.query(User).select()) == 0
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def test_nested_transaction(self):
class User(object):pass
mapper(User, users)
sess.commit()
assert len(sess.query(User).select()) == 1
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def test_nested_autotrans(self):
class User(object):pass
mapper(User, users)
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.orm.session import object_session, Session
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
+from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.session import object_session, Session
+from testlib import *
-import testbase
-from testbase import Table, Column
metadata = MetaData()
users = Table('users', metadata,
-from testbase import PersistTest, AssertMixin
-from sqlalchemy import *
-from sqlalchemy.orm import *
import testbase
-from testbase import Table, Column
import pickleable
+from sqlalchemy import *
+from sqlalchemy.orm import *
from sqlalchemy.orm.mapper import global_extensions
from sqlalchemy.orm import util as ormutil
from sqlalchemy.ext.sessioncontext import SessionContext
import sqlalchemy.ext.assignmapper as assignmapper
-from tables import *
-import tables
+from testlib import *
+from testlib.tables import *
+import testlib.tables as tables
"""tests unitofwork operations"""
class HistoryTest(UnitOfWorkTest):
def setUpAll(self):
+ tables.metadata.bind = testbase.db
UnitOfWorkTest.setUpAll(self)
users.create()
addresses.create()
UnitOfWorkTest.setUpAll(self)
ctx.current.clear()
global version_table
- version_table = Table('version_test', MetaData(db),
+ version_table = Table('version_test', MetaData(testbase.db),
Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
Column('version_id', Integer, nullable=False),
Column('value', String(40), nullable=False)
ctx.current.flush()
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
f1.value = unicode('someothervalue')
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"UPDATE mutabletest SET value=:value WHERE mutabletest.id = :mutabletest_id",
{'mutabletest_id': f1.id, 'value': u'someothervalue'}
])
f1.value = unicode('hi')
f1.data.x = 9
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"UPDATE mutabletest SET data=:data, value=:value WHERE mutabletest.id = :mutabletest_id",
{'mutabletest_id': f1.id, 'value': u'hi', 'data':f1.data}
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
ctx.current.clear()
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
f2.data.y = 19
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 1)
+ self.assert_sql_count(testbase.db, go, 1)
ctx.current.clear()
f3 = ctx.current.query(Foo).get_by(id=f1.id)
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testunicode(self):
"""test that two equivalent unicode values dont get flagged as changed.
f1.value = u'hi'
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
class PKTest(UnitOfWorkTest):
def setUpAll(self):
UnitOfWorkTest.setUpAll(self)
global table, table2, table3, metadata
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
table = Table(
'multipk', metadata,
Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True),
# not support on sqlite since sqlite's auto-pk generation only works with
# single column primary keys
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testprimarykey(self):
class Entry(object):
pass
metadata.drop_all()
UnitOfWorkTest.tearDownAll(self)
- @testbase.unsupported('sqlite')
+ @testing.unsupported('sqlite')
def testbasic(self):
class MyClass(object):
pass
defaults back from the engine."""
def setUpAll(self):
UnitOfWorkTest.setUpAll(self)
+ db = testbase.db
use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
if use_string_defaults:
a2 = Address()
a2.email_address = 'lala@test.org'
u.addresses.append(a2)
- self.echo( repr(u.addresses))
+ print repr(u.addresses)
ctx.current.flush()
usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall()
u2.user_name = 'user2modified'
u1.addresses.append(a3)
del u1.addresses[0]
- self.assert_sql(db, lambda: ctx.current.flush(),
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(),
[
(
"UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id",
# assert the first one retreives the same from the identity map
nu = ctx.current.get(m, u.user_id)
- self.echo( "U: " + repr(u) + "NU: " + repr(nu))
+ print "U: " + repr(u) + "NU: " + repr(nu)
self.assert_(u is nu)
# clear out the identity map, so next get forces a SELECT
u.user_name = ""
def go():
ctx.current.flush()
- self.assert_sql_count(db, go, 0)
+ self.assert_sql_count(testbase.db, go, 0)
def testmultitable(self):
"""tests a save of an object where each instance spans two tables. also tests
objects[2].email_address = 'imnew@foo.bar'
objects[3].user = User()
objects[3].user.user_name = 'imnewlyadded'
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'imnewlyadded'}
k = Keyword()
k.name = 'yellow'
objects[5].keywords.append(k)
- self.assert_sql(db, lambda:ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda:ctx.current.flush(), [
{
"UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}
objects[2].keywords.append(k)
dkid = objects[5].keywords[1].keyword_id
del objects[5].keywords[1]
- self.assert_sql(db, lambda:ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda:ctx.current.flush(), [
(
"DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id",
[{'item_id': objects[5].item_id, 'keyword_id': dkid}]
ctx.current.clear()
clear_mappers()
global meta, users, addresses
- meta = MetaData(db)
+ meta = MetaData(testbase.db)
users = Table('users', meta,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
Column('user_name', String(20)),
a.user = User()
a.user.user_name = elem['user_name']
objects.append(a)
- self.assert_sql(db, lambda: ctx.current.flush(), [
+ self.assert_sql(testbase.db, lambda: ctx.current.flush(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'thesub'}
]
)
-class SaveTest3(UnitOfWorkTest):
+class SaveTest3(UnitOfWorkTest):
def setUpAll(self):
+ global st3_metadata, t1, t2, t3
+
UnitOfWorkTest.setUpAll(self)
- global metadata, t1, t2, t3
- metadata = testbase.metadata
- t1 = Table('items', metadata,
+
+ st3_metadata = MetaData(testbase.db)
+ t1 = Table('items', st3_metadata,
Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
Column('item_name', VARCHAR(50)),
)
- t3 = Table('keywords', metadata,
+ t3 = Table('keywords', st3_metadata,
Column('keyword_id', Integer, Sequence('keyword_id_seq', optional=True), primary_key = True),
Column('name', VARCHAR(50)),
)
- t2 = Table('assoc', metadata,
+ t2 = Table('assoc', st3_metadata,
Column('item_id', INT, ForeignKey("items")),
Column('keyword_id', INT, ForeignKey("keywords")),
Column('foo', Boolean, default=True)
)
- metadata.create_all()
+ st3_metadata.create_all()
def tearDownAll(self):
- metadata.drop_all()
+ st3_metadata.drop_all()
UnitOfWorkTest.tearDownAll(self)
def setUp(self):
import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
+from testlib import *
from timeit import Timer
import sys
# times how long it takes to create 26000 objects
-import sys
-sys.path.insert(0, './lib/')
+import testbase
-from sqlalchemy.attributes import *
+from sqlalchemy.orm.attributes import *
import time
import gc
-import sys
-sys.path.insert(0, './lib/')
-
+import testbase
import gc
import random, string
-from sqlalchemy.attributes import *
+from sqlalchemy.orm.attributes import *
# with this test, run top. make sure the Python process doenst grow in size arbitrarily.
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
+import testbase
+import hotshot, hotshot.stats
from sqlalchemy import *
from sqlalchemy.orm import *
-from testbase import Table, Column
-import StringIO
-import testbase
-import gc
-import time
-import hotshot
-import hotshot.stats
-
-db = testbase.db
+from testlib import *
NUM = 500
DIVISOR = 50
-class LoadTest(AssertMixin):
- def setUpAll(self):
- global items, meta,subitems
- meta = MetaData(db)
- items = Table('items', meta,
- Column('item_id', Integer, primary_key=True),
- Column('value', String(100)))
- subitems = Table('subitems', meta,
- Column('sub_id', Integer, primary_key=True),
- Column('parent_id', Integer, ForeignKey('items.item_id')),
- Column('value', String(100)))
- meta.create_all()
- def tearDownAll(self):
- meta.drop_all()
- def setUp(self):
- clear_mappers()
+# Module-level schema + mapper setup for the mass-eagerload profiling script.
+# Tables are bound directly to the test engine; one-to-many Item -> SubItem
+# is mapped with lazy=False so that loading Items exercises the eager-load
+# (joined-load) code path being profiled.
+meta = MetaData(testbase.db)
+items = Table('items', meta,
+    Column('item_id', Integer, primary_key=True),
+    Column('value', String(100)))
+subitems = Table('subitems', meta,
+    Column('sub_id', Integer, primary_key=True),
+    Column('parent_id', Integer, ForeignKey('items.item_id')),
+    Column('value', String(100)))
+
+class Item(object):pass
+class SubItem(object):pass
+mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)})
+mapper(SubItem, subitems)
+
+def load():
+ global l
+ l = []
+ for x in range(1,NUM/DIVISOR + 1):
+ l.append({'item_id':x, 'value':'this is item #%d' % x})
+ #print l
+ items.insert().execute(*l)
+ for x in range(1, NUM/DIVISOR + 1):
l = []
- for x in range(1,NUM/DIVISOR + 1):
- l.append({'item_id':x, 'value':'this is item #%d' % x})
+ for y in range(1, DIVISOR + 1):
+ z = ((x-1) * DIVISOR) + y
+ l.append({'sub_id':z,'value':'this is item #%d' % z, 'parent_id':x})
#print l
- items.insert().execute(*l)
- for x in range(1, NUM/DIVISOR + 1):
- l = []
- for y in range(1, DIVISOR + 1):
- z = ((x-1) * DIVISOR) + y
- l.append({'sub_id':z,'value':'this is iteim #%d' % z, 'parent_id':x})
- #print l
- subitems.insert().execute(*l)
- def testload(self):
- class Item(object):pass
- class SubItem(object):pass
- mapper(Item, items, properties={'subs':relation(SubItem, lazy=False)})
- mapper(SubItem, subitems)
- sess = create_session()
- prof = hotshot.Profile("masseagerload.prof")
- prof.start()
- query = sess.query(Item)
- l = query.select()
- print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
- prof.stop()
- prof.close()
- stats = hotshot.stats.load("masseagerload.prof")
- stats.sort_stats('time', 'calls')
- stats.print_stats()
-
-if __name__ == "__main__":
- testbase.main()
+ subitems.insert().execute(*l)
+
+@profiling.profiled('masseagerload', always=True)
+def masseagerload(session):
+    """Load every Item (with eagerly-loaded subs) under the profiler.
+
+    The query itself is the profiling target; the print is only a sanity
+    check that rows and their eager collections actually came back.
+    """
+    query = session.query(Item)
+    l = query.select()
+    print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
+
+def all():
+    """Run the whole scenario: create schema, insert fixture rows,
+    run the profiled eager-load query, and always drop the schema."""
+    meta.create_all()
+    try:
+        load()
+        masseagerload(create_session())
+    finally:
+        # guarantee cleanup even if load/query fails
+        meta.drop_all()
+
+if __name__ == '__main__':
+    all()
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import sqlalchemy.orm.attributes as attributes
-from testbase import Table, Column
-import StringIO
import testbase
-import gc
import time
-
-db = testbase.db
+#import gc
+#import sqlalchemy.orm.attributes as attributes
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
NUM = 2500
class LoadTest(AssertMixin):
def setUpAll(self):
global items, meta
- meta = MetaData(db)
+ meta = MetaData(testbase.db)
items = Table('items', meta,
Column('item_id', Integer, primary_key=True),
Column('value', String(100)))
def tearDownAll(self):
items.drop()
def setUp(self):
- objectstore.clear()
- clear_mappers()
for x in range(1,NUM/500+1):
l = []
for y in range(x*500-500 + 1, x*500 + 1):
-from testbase import PersistTest, AssertMixin
-import unittest, sys, os
-from sqlalchemy import *
-import sqlalchemy.attributes as attributes
-from testbase import Table, Column
-import StringIO
import testbase
-import gc
-import sqlalchemy.orm.session
import types
-db = testbase.db
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
NUM = 250000
class SaveTest(AssertMixin):
def setUpAll(self):
global items, metadata
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
items = Table('items', metadata,
Column('item_id', Integer, primary_key=True),
Column('value', String(100)))
--- /dev/null
+import testbase
+import time
+from datetime import datetime
+
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.profiling import profiled
+
+# Plain mapped domain classes for the profiling scenario.  Each defines only
+# a __repr__ for readable debug output; attributes (id, name, purchase_date)
+# are supplied by the mappers set up in setup_mappers().
+class Item(object):
+    def __repr__(self):
+        return 'Item<#%s "%s">' % (self.id, self.name)
+class SubItem(object):
+    def __repr__(self):
+        return 'SubItem<#%s "%s">' % (self.id, self.name)
+class Customer(object):
+    def __repr__(self):
+        return 'Customer<#%s "%s">' % (self.id, self.name)
+class Purchase(object):
+    def __repr__(self):
+        return 'Purchase<#%s "%s">' % (self.id, self.purchase_date)
+
+# Table globals start as None; they are created in define_tables() so that
+# Table construction time itself can be measured by the profiler.
+items, subitems, customers, purchases, purchaseitems = \
+    None, None, None, None, None
+
+# unbound here; main() binds it to testbase.db before any DDL runs
+metadata = MetaData()
+
+@profiled('table')
+def define_tables():
+    """Build the five Table definitions (profiled under 'table').
+
+    Rebinds the module-level table globals; must run before
+    setup_mappers() and insert_data().
+    """
+    global items, subitems, customers, purchases, purchaseitems
+    items = Table('items', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('name', String(100)),
+        test_needs_acid=True)
+    subitems = Table('subitems', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('item_id', Integer, ForeignKey('items.id'),
+               nullable=False),
+        Column('name', String(100), PassiveDefault('no name')),
+        test_needs_acid=True)
+    # twenty filler columns col_a .. col_t (chr(97)..chr(116)), each with a
+    # string default, to make column-default handling measurable
+    customers = Table('customers', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('name', String(100)),
+        *[Column("col_%s" % chr(i), String(64), default=str(i))
+          for i in range(97,117)],
+        **dict(test_needs_acid=True))
+    purchases = Table('purchases', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('customer_id', Integer,
+               ForeignKey('customers.id'), nullable=False),
+        Column('purchase_date', DateTime,
+               default=datetime.now),
+        test_needs_acid=True)
+    # association table: composite pk over (purchase_id, item_id)
+    purchaseitems = Table('purchaseitems', metadata,
+        Column('purchase_id', Integer,
+               ForeignKey('purchases.id'),
+               nullable=False, primary_key=True),
+        Column('item_id', Integer, ForeignKey('items.id'),
+               nullable=False, primary_key=True),
+        test_needs_acid=True)
+
+@profiled('mapper')
+def setup_mappers():
+    """Map the four domain classes (profiled under 'mapper').
+
+    All relations are configured lazy here; run_queries() switches on
+    eager loading per-query via options(eagerload(...)).
+    """
+    mapper(Item, items, properties={
+        'subitems': relation(SubItem, backref='item', lazy=True)
+    })
+    mapper(SubItem, subitems)
+    mapper(Customer, customers, properties={
+        'purchases': relation(Purchase, lazy=True, backref='customer')
+    })
+    mapper(Purchase, purchases, properties={
+        # many-to-many through the purchaseitems association table
+        'items': relation(Item, lazy=True, secondary=purchaseitems)
+    })
+
+@profiled('inserts')
+def insert_data():
+    """Populate all tables with fixture data (profiled under 'inserts').
+
+    Three transactions: items+subitems, customers, then purchases with
+    their purchase-item associations.  Rows are buffered and flushed in
+    batches of ~100 executemany() calls to keep insert overhead realistic.
+    NOTE(review): purchase ids are assumed to be assigned sequentially
+    starting at 1, matching the hand-counted purchase_id -- confirm against
+    the dialect's autoincrement behavior.
+    """
+    q_items = 1000
+    q_sub_per_item = 10
+    q_customers = 1000
+
+    con = testbase.db.connect()
+
+    # phase 1: items and their subitems (item_id % 10 subitems each)
+    transaction = con.begin()
+    data, subdata = [], []
+    for item_id in xrange(1, q_items + 1):
+        data.append({'name': "item number %s" % item_id})
+        for subitem_id in xrange(1, (item_id % q_sub_per_item) + 1):
+            subdata.append({'item_id': item_id,
+                            'name': "subitem number %s" % subitem_id})
+        if item_id % 100 == 0:
+            items.insert().execute(*data)
+            subitems.insert().execute(*subdata)
+            del data[:]
+            del subdata[:]
+    if data:
+        items.insert().execute(*data)
+    if subdata:
+        subitems.insert().execute(*subdata)
+    transaction.commit()
+
+    # phase 2: customers, batched the same way
+    transaction = con.begin()
+    data = []
+    for customer_id in xrange(1, q_customers):
+        data.append({'name': "customer number %s" % customer_id})
+        if customer_id % 100 == 0:
+            customers.insert().execute(*data)
+            del data[:]
+    if data:
+        customers.insert().execute(*data)
+    transaction.commit()
+
+    # phase 3: purchases + purchaseitems.  Customers drop out of `current`
+    # once customer_id % 10 <= step, so higher ids accumulate more
+    # purchases; order_t spaces purchase dates 5 minutes apart, starting
+    # in the past.
+    transaction = con.begin()
+    data, subdata = [], []
+    order_t = int(time.time()) - (5000 * 5 * 60)
+    current = xrange(1, q_customers)
+    step, purchase_id = 1, 0
+    while current:
+        next = []
+        for customer_id in current:
+            order_t += 300
+            data.append({'customer_id': customer_id,
+                         'purchase_date': datetime.fromtimestamp(order_t)})
+            purchase_id += 1
+            for item_id in range(customer_id % 200, customer_id + 1, 200):
+                if item_id != 0:
+                    subdata.append({'purchase_id': purchase_id,
+                                    'item_id': item_id})
+            if customer_id % 10 > step:
+                next.append(customer_id)
+
+            if len(data) >= 100:
+                purchases.insert().execute(*data)
+                if subdata:
+                    purchaseitems.insert().execute(*subdata)
+                del data[:]
+                del subdata[:]
+        step, current = step + 1, next
+
+    if data:
+        purchases.insert().execute(*data)
+        if subdata:
+            purchaseitems.insert().execute(*subdata)
+    transaction.commit()
+
+@profiled('queries')
+def run_queries():
+    """Exercise typical read patterns (profiled under 'queries').
+
+    Pulls the 50 most recent purchases with eager-loaded items, subitems
+    and customer, then a 'top 20 items' subquery-driven query; every value
+    is appended to `report` to force attribute/collection access.
+    """
+    session = create_session()
+    # no explicit transaction here.
+
+    # build a report of summarizing the last 50 purchases and
+    # the top 20 items from all purchases
+
+    q = session.query(Purchase). \
+        limit(50).order_by(desc(Purchase.purchase_date)). \
+        options(eagerload('items'), eagerload('items.subitems'),
+                eagerload('customer'))
+
+    report = []
+    # "write" the report.  pretend it's going to a web template or something,
+    # the point is to actually pull data through attributes and collections.
+    for purchase in q:
+        report.append(purchase.customer.name)
+        report.append(purchase.customer.col_a)
+        report.append(purchase.purchase_date)
+        for item in purchase.items:
+            report.append(item.name)
+            report.extend([s.name for s in item.subitems])
+
+    # pull a report of the top 20 items of all time
+    _item_id = purchaseitems.c.item_id
+    top_20_q = select([func.distinct(_item_id).label('id')],
+                      group_by=[purchaseitems.c.purchase_id, _item_id],
+                      order_by=[desc(func.count(_item_id)), _item_id],
+                      limit=20)
+    q2 = session.query(Item).filter(Item.id.in_(top_20_q))
+
+    for num, item in enumerate(q2):
+        report.append("number %s: %s" % (num + 1, item.name))
+
+def setup_db():
+    # drop first in case a previous run left tables behind
+    metadata.drop_all()
+    metadata.create_all()
+def cleanup_db():
+    metadata.drop_all()
+
+@profiled('all')
+def main():
+    """Bind the schema to the test engine and run the full profiling
+    scenario: DDL, mappers, fixture inserts, queries, cleanup."""
+    metadata.bind = testbase.db
+    define_tables()
+    setup_mappers()
+    setup_db()
+    insert_data()
+    run_queries()
+    cleanup_db()
+
+# executed at import time: this module is a standalone profiling script,
+# not a unittest-style test module
+main()
# load test of connection pool
+import testbase
from sqlalchemy import *
import sqlalchemy.pool as pool
import thread,time
-db = create_engine('mysql://scott:tiger@127.0.0.1/test', pool_timeout=30, echo_pool=True)
+db = create_engine(testbase.db.url, pool_timeout=30, echo_pool=True)
metadata = MetaData(db)
users_table = Table('users', metadata,
def runfast():
while True:
- c = db.connection_provider._pool.connect()
+ c = db.pool.connect()
time.sleep(.5)
c.close()
# result = users_table.select(limit=100).execute()
when additional mappers are created while the existing
collection is being compiled."""
+import testbase
from sqlalchemy import *
from sqlalchemy.orm import *
import thread, time
from sqlalchemy.orm import mapperlib
-from testbase import Table, Column
+from testlib import *
meta = MetaData('sqlite:///foo.db')
#!/usr/bin/python
+"""Uses ``wsgiref``, standard in Python 2.5 and also in the cheeseshop."""
+import testbase
from sqlalchemy import *
-import sqlalchemy.pool as pool
+from sqlalchemy.orm import *
import thread
-from sqlalchemy import exceptions
-from testbase import Table, Column
+from testlib import *
+
+port = 8000
import logging
logging.basicConfig()
logging.getLogger('sqlalchemy.pool').setLevel(logging.INFO)
threadids = set()
-#meta = MetaData('postgres://scott:tiger@127.0.0.1/test')
-
-#meta = MetaData('mysql://scott:tiger@localhost/test', poolclass=pool.SingletonThreadPool)
-meta = MetaData('mysql://scott:tiger@localhost/test')
+meta = MetaData(testbase.db)
foo = Table('foo', meta,
Column('id', Integer, primary_key=True),
Column('data', String(30)))
-
-meta.drop_all()
-meta.create_all()
-
-data = []
-for x in range(1,500):
- data.append({'id':x,'data':"this is x value %d" % x})
-foo.insert().execute(data)
-
class Foo(object):
pass
-
mapper(Foo, foo)
-root = './'
-port = 8000
+def prep():
+    """Rebuild the foo table and seed it with 499 sample rows for the
+    threaded WSGI load test."""
+    meta.drop_all()
+    meta.create_all()
+
+    data = []
+    for x in range(1,500):
+        data.append({'id':x,'data':"this is x value %d" % x})
+    foo.insert().execute(data)
def serve(environ, start_response):
+ start_response("200 OK", [('Content-type', 'text/plain')])
sess = create_session()
l = sess.query(Foo).select()
-
- start_response("200 OK", [('Content-type','text/plain')])
threadids.add(thread.get_ident())
- print "sending response on thread", thread.get_ident(), " total threads ", len(threadids)
- return ["\n".join([x.data for x in l])]
+
+    print ("sending response on thread %s total threads %s"
+           % (thread.get_ident(), len(threadids)))
+ return [str("\n".join([x.data for x in l]))]
if __name__ == '__main__':
- from wsgiutils import wsgiServer
- server = wsgiServer.WSGIServer (('localhost', port), {'/': serve})
- print "Server listening on port %d" % port
- server.serve_forever()
+ from wsgiref import simple_server
+ try:
+ prep()
+ server = simple_server.make_server('localhost', port, serve)
+ print "Server listening on port %d" % port
+ server.serve_forever()
+ finally:
+ meta.drop_all()
alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
return alltests
-
-
if __name__ == '__main__':
testbase.main(suite())
-import sys
import testbase
+import sys
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
-class CaseTest(testbase.PersistTest):
+class CaseTest(PersistTest):
def setUpAll(self):
metadata = MetaData(testbase.db)
import testbase
from sqlalchemy import *
-from testbase import Table, Column
-import sys
+from testlib import *
-class ConstraintTest(testbase.AssertMixin):
+class ConstraintTest(AssertMixin):
def setUp(self):
global metadata
)
metadata.create_all()
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_check_constraint(self):
foo = Table('foo', metadata,
Column('id', Integer, primary_key=True),
-from testbase import PersistTest
-import sqlalchemy.util as util
-import unittest, sys, os
-import sqlalchemy.schema as schema
import testbase
from sqlalchemy import *
+import sqlalchemy.util as util
+import sqlalchemy.schema as schema
from sqlalchemy.orm import mapper, create_session
-from testbase import Table, Column
-import sqlalchemy
-
-db = testbase.db
+from testlib import *
class DefaultTest(PersistTest):
def setUpAll(self):
global t, f, f2, ts, currenttime, metadata
- metadata = MetaData(testbase.db)
+
+ db = testbase.db
+ metadata = MetaData(db)
x = {'x':50}
def mydefault():
x['x'] += 1
t.delete().execute()
def teststandalone(self):
- c = db.engine.contextual_connect()
+ c = testbase.db.engine.contextual_connect()
x = c.execute(t.c.col1.default)
y = t.c.col2.default.execute()
z = c.execute(t.c.col3.default)
t.insert().execute()
ctexec = currenttime.scalar()
- self.echo("Currenttime "+ repr(ctexec))
+ print "Currenttime "+ repr(ctexec)
l = t.select().execute()
self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False), (52, 'imthedefault', f, ts, ts, ctexec, True, False), (53, 'imthedefault', f, ts, ts, ctexec, True, False)])
pk = r.last_inserted_ids()[0]
t.update(t.c.col1==pk).execute(col4=None, col5=None)
ctexec = currenttime.scalar()
- self.echo("Currenttime "+ repr(ctexec))
+ print "Currenttime "+ repr(ctexec)
l = t.select(t.c.col1==pk).execute()
l = l.fetchone()
self.assert_(l == (pk, 'im the update', f2, None, None, ctexec, True, False))
l = l.fetchone()
self.assert_(l['col3'] == 55)
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def testpassiveoverride(self):
"""primarily for postgres, tests that when we get a primary key column back
from reflecting a table which has a default value on it, we pre-execute
testbase.db.execute("drop table speedy_users", None)
class AutoIncrementTest(PersistTest):
- @testbase.supported('postgres', 'mysql')
+ @testing.supported('postgres', 'mysql')
def testnonautoincrement(self):
meta = MetaData(testbase.db)
nonai_table = Table("aitest", meta,
class SequenceTest(PersistTest):
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def setUpAll(self):
global cartitems, sometable, metadata
metadata = MetaData(testbase.db)
metadata.create_all()
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def testseqnonpk(self):
"""test sequences fire off as defaults on non-pk columns"""
sometable.insert().execute(name="somename")
(2, "someother", 2),
]
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def testsequence(self):
cartitems.insert().execute(description='hi')
cartitems.insert().execute(description='there')
cartitems.select().execute().fetchall()
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def test_implicit_sequence_exec(self):
s = Sequence("my_sequence", metadata=MetaData(testbase.db))
s.create()
finally:
s.drop()
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def teststandalone_explicit(self):
s = Sequence("my_sequence")
s.create(bind=testbase.db)
finally:
s.drop(testbase.db)
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def test_checkfirst(self):
s = Sequence("my_sequence")
s.create(testbase.db, checkfirst=False)
s.drop(testbase.db, checkfirst=False)
s.drop(testbase.db, checkfirst=True)
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def teststandalone2(self):
x = cartitems.c.cart_id.sequence.execute()
self.assert_(1 <= x <= 4)
- @testbase.supported('postgres', 'oracle')
+ @testing.supported('postgres', 'oracle')
def tearDownAll(self):
metadata.drop_all()
import testbase
from sql import select as selecttests
-
from sqlalchemy import *
+from testlib import *
-class TraversalTest(testbase.AssertMixin):
+class TraversalTest(AssertMixin):
"""test ClauseVisitor's traversal, particularly its ability to copy and modify
a ClauseElement in place."""
s = select([t2], t2.c.col1==t1.c.col1, correlate=False)
s = s.correlate(t1).order_by(t2.c.col3)
self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3")
-
-
-
-
-
+
+
if __name__ == '__main__':
- testbase.main()
\ No newline at end of file
+ testbase.main()
import testbase
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
+
# TODO: either create a mock dialect with named paramstyle and a short identifier length,
# or find a way to just use sqlite dialect and make those changes
-class LabelTypeTest(testbase.PersistTest):
+class LabelTypeTest(PersistTest):
def test_type(self):
m = MetaData()
t = Table('sometable', m,
assert isinstance(t.c.col1.label('hi').type, Integer)
assert isinstance(select([t.c.col2], scalar=True).label('lala').type, Float)
-class LongLabelsTest(testbase.PersistTest):
+class LongLabelsTest(PersistTest):
def setUpAll(self):
global metadata, table1, maxlen
metadata = MetaData(testbase.db)
-from testbase import PersistTest
import testbase
-import unittest, sys, datetime
-
-import sqlalchemy.databases.sqlite as sqllite
-
-import tables
+import datetime
from sqlalchemy import *
-from sqlalchemy.engine import ResultProxy, RowProxy
from sqlalchemy import exceptions
-from testbase import Table, Column
+from testlib import *
class QueryTest(PersistTest):
r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall()
self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r))
- @testbase.unsupported('mssql')
+ @testing.unsupported('mssql')
def testselectlimitoffset(self):
users.insert().execute(user_id=1, user_name='john')
users.insert().execute(user_id=2, user_name='jack')
r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall()
self.assert_(r==[(6, 'ralph'), (7, 'fido')])
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def testselectlimitoffset_mssql(self):
try:
r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
except exceptions.InvalidRequestError:
pass
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_scalar_select(self):
"""test that scalar subqueries with labels get their type propigated to the result set."""
# mysql and/or mysqldb has a bug here, type isnt propigated for scalar subquery.
finally:
meta.drop_all()
- @testbase.supported('postgres')
+ @testing.supported('postgres')
def test_functions_with_cols(self):
# TODO: shouldnt this work on oracle too ?
x = testbase.db.func.current_date().execute().scalar()
self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
self.assertEqual(r.values(), ['foo', 1])
- @testbase.unsupported('oracle', 'firebird')
+ @testing.unsupported('oracle', 'firebird')
def test_column_accessor_shadow(self):
meta = MetaData(testbase.db)
shadowed = Table('test_shadowed', meta,
finally:
shadowed.drop(checkfirst=True)
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_fetchid_trigger(self):
meta = MetaData(testbase.db)
t1 = Table('t1', meta,
con.execute("""drop trigger paj""")
meta.drop_all()
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_insertid_schema(self):
meta = MetaData(testbase.db)
con = testbase.db.connect()
tbl.drop()
con.execute('drop schema paj')
- @testbase.supported('mssql')
+ @testing.supported('mssql')
def test_insertid_reserved(self):
meta = MetaData(testbase.db)
table = Table(
assert u.execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
assert u.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_intersect(self):
i = intersect(
select([t2.c.col3, t2.c.col4]),
assert i.execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
assert i.alias('bar').select().execute().fetchall() == [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
- @testbase.unsupported('mysql', 'oracle')
+ @testing.unsupported('mysql', 'oracle')
def test_except_style1(self):
e = except_(union(
select([t1.c.col3, t1.c.col4]),
), select([t2.c.col3, t2.c.col4]))
assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
- @testbase.unsupported('mysql', 'oracle')
+ @testing.unsupported('mysql', 'oracle')
def test_except_style2(self):
e = except_(union(
select([t1.c.col3, t1.c.col4]),
assert e.execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
- @testbase.unsupported('sqlite', 'mysql', 'oracle')
+ @testing.unsupported('sqlite', 'mysql', 'oracle')
def test_except_style3(self):
# aaa, bbb, ccc - (aaa, bbb, ccc - (ccc)) = ccc
e = except_(
)
self.assertEquals(e.execute().fetchall(), [('ccc',)])
- @testbase.unsupported('sqlite', 'mysql', 'oracle')
+ @testing.unsupported('sqlite', 'mysql', 'oracle')
def test_union_union_all(self):
e = union_all(
select([t1.c.col3]),
)
self.assertEquals(e.execute().fetchall(), [('aaa',),('bbb',),('ccc',),('aaa',),('bbb',),('ccc',)])
- @testbase.unsupported('mysql')
+ @testing.unsupported('mysql')
def test_composite(self):
u = intersect(
select([t2.c.col3, t2.c.col4]),
-from testbase import PersistTest
import testbase
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
class QuoteTest(PersistTest):
assert t1.c.UcCol.case_sensitive is False
assert t2.c.normalcol.case_sensitive is False
- @testbase.unsupported('oracle')
+ @testing.unsupported('oracle')
def testlabels(self):
"""test the quoting of labels.
-from sqlalchemy import *
-from testbase import Table, Column
import testbase
+from sqlalchemy import *
+from testlib import *
-class FoundRowsTest(testbase.AssertMixin):
+class FoundRowsTest(AssertMixin):
"""tests rowcount functionality"""
def setUpAll(self):
metadata = MetaData(testbase.db)
-from testbase import PersistTest
import testbase
+import re, operator
from sqlalchemy import *
from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
-from testbase import Table, Column
-import unittest, re, operator
+from testlib import *
# the select test now tests almost completely with TableClause/ColumnClause objects,
class SQLTest(PersistTest):
def runtest(self, clause, result, dialect = None, params = None, checkparams = None):
c = clause.compile(parameters=params, dialect=dialect)
- self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
+ print "\nSQL String:\n" + str(c) + repr(c.get_params())
cc = re.sub(r'\n', '', str(c))
self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'")
if checkparams is not None:
every selectable unit behaving nicely with others.."""\r
\r
import testbase\r
-import unittest, sys, datetime\r
from sqlalchemy import *\r
-from testbase import Table, Column\r
-\r
-db = testbase.db\r
-metadata = MetaData(db)\r
-\r
+from testlib import *\r
\r
+metadata = MetaData()\r
table = Table('table1', metadata, \r
Column('col1', Integer, primary_key=True),\r
Column('col2', String(20)),\r
Column('coly', Integer),\r
)\r
\r
-class SelectableTest(testbase.AssertMixin):\r
+class SelectableTest(AssertMixin):\r
def testdistance(self):\r
s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])\r
\r
self.assert_(criterion.compare(j.onclause))\r
\r
\r
-class PrimaryKeyTest(testbase.AssertMixin):\r
+class PrimaryKeyTest(AssertMixin):\r
def test_join_pk_collapse_implicit(self):\r
"""test that redundant columns in a join get 'collapsed' into a minimal primary key, \r
which is the root column along a chain of foreign key relationships."""\r
-from testbase import PersistTest, AssertMixin
import testbase
import pickleable
+import datetime, os
from sqlalchemy import *
-import string,datetime, re, sys, os
import sqlalchemy.engine.url as url
-import sqlalchemy.types
from sqlalchemy.databases import mssql, oracle, mysql
-from testbase import Table, Column
+from testlib import *
-db = testbase.db
-
class MyType(types.TypeEngine):
def get_col_spec(self):
return "VARCHAR(100)"
def setUpAll(self):
global users
- users = Table('type_users', MetaData(db),
+ users = Table('type_users', MetaData(testbase.db),
Column('user_id', Integer, primary_key = True),
# totall custom type
Column('goofy', MyType, nullable = False),
'float_column': 'float_column NUMERIC(25, 2)'
}
+ db = testbase.db
if not db.name=='sqlite' and not db.name=='oracle':
expectedResults['float_column'] = 'float_column FLOAT(25)'
"""tests the Unicode type. also tests the TypeDecorator with instances in the types package."""
def setUpAll(self):
global unicode_table
- metadata = MetaData(db)
+ metadata = MetaData(testbase.db)
unicode_table = Table('unicode_table', metadata,
Column('id', Integer, Sequence('uni_id_seq', optional=True), primary_key=True),
Column('unicode_varchar', Unicode(250)),
unicode_text=unicodedata,
plain_varchar=rawdata)
x = unicode_table.select().execute().fetchone()
- self.echo(repr(x['unicode_varchar']))
- self.echo(repr(x['unicode_text']))
- self.echo(repr(x['plain_varchar']))
+ print repr(x['unicode_varchar'])
+ print repr(x['unicode_text'])
+ print repr(x['plain_varchar'])
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
if isinstance(x['plain_varchar'], unicode):
# SQLLite and MSSQL return non-unicode data as unicode
- self.assert_(db.name in ('sqlite', 'mssql'))
+ self.assert_(testbase.db.name in ('sqlite', 'mssql'))
self.assert_(x['plain_varchar'] == unicodedata)
- self.echo("it's %s!" % db.name)
+ print "it's %s!" % testbase.db.name
else:
self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
def testengineparam(self):
"""tests engine-wide unicode conversion"""
- prev_unicode = db.engine.dialect.convert_unicode
+ prev_unicode = testbase.db.engine.dialect.convert_unicode
try:
- db.engine.dialect.convert_unicode = True
+ testbase.db.engine.dialect.convert_unicode = True
rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
unicodedata = rawdata.decode('utf-8')
unicode_table.insert().execute(unicode_varchar=unicodedata,
unicode_text=unicodedata,
plain_varchar=rawdata)
x = unicode_table.select().execute().fetchone()
- self.echo(repr(x['unicode_varchar']))
- self.echo(repr(x['unicode_text']))
- self.echo(repr(x['plain_varchar']))
+ print repr(x['unicode_varchar'])
+ print repr(x['unicode_text'])
+ print repr(x['plain_varchar'])
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata)
finally:
- db.engine.dialect.convert_unicode = prev_unicode
+ testbase.db.engine.dialect.convert_unicode = prev_unicode
- @testbase.unsupported('oracle')
+ @testing.unsupported('oracle')
def testlength(self):
"""checks the database correctly understands the length of a unicode string"""
teststr = u'aaa\x1234'
- self.assert_(db.func.length(teststr).scalar() == len(teststr))
+ self.assert_(testbase.db.func.length(teststr).scalar() == len(teststr))
class BinaryTest(AssertMixin):
def setUpAll(self):
global binary_table
- binary_table = Table('binary_table', MetaData(db),
+ binary_table = Table('binary_table', MetaData(testbase.db),
Column('primary_id', Integer, Sequence('binary_id_seq', optional=True), primary_key=True),
Column('data', Binary),
Column('data_slice', Binary(100)),
def setUpAll(self):
global users_with_date, insert_data
+ db = testbase.db
if db.engine.name == 'oracle':
import sqlalchemy.databases.oracle as oracle
insert_data = [
collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime(timezone=False)),
Column('user_date', Date), Column('user_time', Time)]
- users_with_date = Table('query_users_with_date', MetaData(db), *collist)
+ users_with_date = Table('query_users_with_date',
+ MetaData(testbase.db), *collist)
users_with_date.create()
insert_dicts = [dict(zip(fnames, d)) for d in insert_data]
def testtextdate(self):
- x = db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall()
+ x = testbase.db.text("select user_datetime from query_users_with_date", typemap={'user_datetime':DateTime}).execute().fetchall()
print repr(x)
self.assert_(isinstance(x[0][0], datetime.datetime))
#print repr(x)
def testdate2(self):
- t = Table('testdate', testbase.metadata, Column('id', Integer, Sequence('datetest_id_seq', optional=True), primary_key=True),
+ meta = MetaData(testbase.db)
+ t = Table('testdate', meta,
+ Column('id', Integer,
+ Sequence('datetest_id_seq', optional=True),
+ primary_key=True),
Column('adate', Date), Column('adatetime', DateTime))
t.create()
try:
import testbase
from sqlalchemy import *
from sqlalchemy.orm import mapper, relation, create_session, eagerload
-from testbase import Table, Column
+from testlib import *
-class UnicodeSchemaTest(testbase.PersistTest):
+class UnicodeSchemaTest(PersistTest):
def setUpAll(self):
global unicode_bind, metadata, t1, t2
-"""base import for all test cases. Patches in enhancements to unittest.TestCase,
-instruments SQLAlchemy dialect/engine to track SQL statements for assertion purposes,
-provides base test classes for common test scenarios."""
+"""First import for all test cases, sets sys.path and loads configuration."""
-import sys
-import coverage
+__all__ = 'db',
-import os, unittest, StringIO, re, ConfigParser, optparse
+import sys, os, logging
sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
+logging.basicConfig()
-db = None
-metadata = None
-db_uri = None
-echo = True
-table_options = {}
-
-# redefine sys.stdout so all those print statements go to the echo func
-local_stdout = sys.stdout
-class Logger(object):
- def write(self, msg):
- if echo:
- local_stdout.write(msg)
- def flush(self):
- pass
-
-def echo_text(text):
- print text
-
-def parse_argv():
- # we are using the unittest main runner, so we are just popping out the
- # arguments we need instead of using our own getopt type of thing
- global db, db_uri, metadata
-
- DBTYPE = 'sqlite'
- PROXY = False
-
- base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:s@localhost/tmp/test.fdb
-"""
- config = ConfigParser.ConfigParser()
- config.readfp(StringIO.StringIO(base_config))
- config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-
- parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
- parser.add_option("--dburi", action="store", dest="dburi", help="database uri (overrides --db)")
- parser.add_option("--db", action="store", dest="db", default="sqlite", help="prefab database uri (%s)" % ', '.join(config.options('db')))
- parser.add_option("--mockpool", action="store_true", dest="mockpool", help="use mock pool (asserts only one connection used)")
- parser.add_option("--verbose", action="store_true", dest="verbose", help="enable stdout echoing/printing")
- parser.add_option("--quiet", action="store_true", dest="quiet", help="suppress unittest output")
- parser.add_option("--log-info", action="append", dest="log_info", help="turn on info logging for <LOG> (multiple OK)")
- parser.add_option("--log-debug", action="append", dest="log_debug", help="turn on debug logging for <LOG> (multiple OK)")
- parser.add_option("--enginestrategy", action="store", default=None, dest="enginestrategy", help="engine strategy (plain or threadlocal, defaults to plain)")
- parser.add_option("--coverage", action="store_true", dest="coverage", help="Dump a full coverage report after running")
- parser.add_option("--reversetop", action="store_true", dest="topological", help="Reverse the collection ordering for topological sorts (helps reveal dependency issues)")
- parser.add_option("--serverside", action="store_true", dest="serverside", help="Turn on server side cursors for PG")
- parser.add_option("--require", action="append", dest="require", help="Require a particular driver or module version", default=[])
- parser.add_option("--mysql-engine", action="store", dest="mysql_engine", help="Use the specified MySQL storage engine for all tables, default is a db-default/InnoDB combo.", default=None)
- parser.add_option("--table-option", action="append", dest="tableopts", help="Add a dialect-specific table option, key=value", default=[])
-
- (options, args) = parser.parse_args()
- sys.argv[1:] = args
-
- if options.dburi:
- db_uri = param = options.dburi
- DBTYPE = db_uri[:db_uri.index(':')]
- elif options.db:
- DBTYPE = param = options.db
-
- if options.require or (config.has_section('require') and
- config.items('require')):
- try:
- import pkg_resources
- except ImportError:
- raise "setuptools is required for version requirements"
-
- cmdline = []
- for requirement in options.require:
- pkg_resources.require(requirement)
- cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
-
- if config.has_section('require'):
- for label, requirement in config.items('require'):
- if not label == DBTYPE or label.startswith('%s.' % DBTYPE):
- continue
- seen = [c for c in cmdline if requirement.startswith(c)]
- if seen:
- continue
- pkg_resources.require(requirement)
-
- opts = {}
- if (None == db_uri):
- if DBTYPE not in config.options('db'):
- raise ("Could not create engine. specify --db <%s> to "
- "test runner." % '|'.join(config.options('db')))
-
- db_uri = config.get('db', DBTYPE)
-
- if not db_uri:
- raise "Could not create engine. specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql|firebird> to test runner."
-
- global table_options
- for spec in options.tableopts:
- key, value = spec.split('=')
- table_options[key] = value
-
- if options.mysql_engine:
- table_options['mysql_engine'] = options.mysql_engine
-
-
- global echo
- echo = options.verbose and not options.quiet
-
- global quiet
- quiet = options.quiet
-
- global with_coverage
- with_coverage = options.coverage
- if with_coverage:
- coverage.erase()
- coverage.start()
-
- from sqlalchemy import engine, schema
-
- if options.serverside:
- opts['server_side_cursors'] = True
-
- if options.enginestrategy is not None:
- opts['strategy'] = options.enginestrategy
- if options.mockpool:
- db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts)
- else:
- db = engine.create_engine(db_uri, **opts)
-
- # decorate the dialect's create_execution_context() method
- # to produce a wrapper
- create_context = db.dialect.create_execution_context
- def create_exec_context(*args, **kwargs):
- return ExecutionContextWrapper(create_context(*args, **kwargs))
- db.dialect.create_execution_context = create_exec_context
-
- if options.topological:
- from sqlalchemy.orm import unitofwork
- from sqlalchemy import topological
- class RevQueueDepSort(topological.QueueDependencySorter):
- def __init__(self, tuples, allitems):
- self.tuples = list(tuples)
- self.allitems = list(allitems)
- self.tuples.reverse()
- self.allitems.reverse()
- topological.QueueDependencySorter = RevQueueDepSort
- unitofwork.DependencySorter = RevQueueDepSort
-
- import logging
- logging.basicConfig()
- if options.log_info is not None:
- for elem in options.log_info:
- logging.getLogger(elem).setLevel(logging.INFO)
- if options.log_debug is not None:
- for elem in options.log_debug:
- logging.getLogger(elem).setLevel(logging.DEBUG)
- metadata = schema.MetaData(db)
-
-def unsupported(*dbs):
- """a decorator that marks a test as unsupported by one or more database implementations"""
-
- def decorate(func):
- name = db.name
- for d in dbs:
- if d == name:
- def lala(self):
- echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'")
- lala.__name__ = func.__name__
- return lala
- else:
- return func
- return decorate
-
-def supported(*dbs):
- """a decorator that marks a test as supported by one or more database implementations"""
-
- def decorate(func):
- name = db.name
- for d in dbs:
- if d == name:
- return func
- else:
- def lala(self):
- echo_text("'" + func.__name__ + "' unsupported on DB implementation '" + name + "'")
- lala.__name__ = func.__name__
- return lala
- return decorate
-
-def Table(*args, **kw):
- """A schema.Table wrapper/hook for dialect-specific tweaks."""
-
- test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
- if k.startswith('test_')])
-
- kw.update(table_options)
-
- if db.engine.name == 'mysql':
- if 'mysql_engine' not in kw and 'mysql_type' not in kw:
- if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
- kw['mysql_engine'] = 'InnoDB'
-
- return schema.Table(*args, **kw)
-
-def Column(*args, **kw):
- """A schema.Column wrapper/hook for dialect-specific tweaks."""
-
- # TODO: a Column that creates a Sequence automatically for PK columns,
- # which would help Oracle tests
- return schema.Column(*args, **kw)
-
-class PersistTest(unittest.TestCase):
-
- def __init__(self, *args, **params):
- unittest.TestCase.__init__(self, *args, **params)
-
- def echo(self, text):
- """DEPRECATED. use print <statement>"""
- echo_text(text)
-
-
- def setUpAll(self):
- pass
- def tearDownAll(self):
- pass
-
- def shortDescription(self):
- """overridden to not return docstrings"""
- return None
-
-class AssertMixin(PersistTest):
- """given a list-based structure of keys/properties which represent information within an object structure, and
- a list of actual objects, asserts that the list of objects corresponds to the structure."""
-
- def assert_result(self, result, class_, *objects):
- result = list(result)
- if echo:
- print repr(result)
- self.assert_list(result, class_, objects)
-
- def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list), "result list is not the same size as test list, for class " + class_.__name__)
- for i in range(0, len(list)):
- self.assert_row(class_, result[i], list[i])
-
- def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_))
- for key, value in desc.iteritems():
- if isinstance(value, tuple):
- if isinstance(value[1], list):
- self.assert_list(getattr(rowobj, key), value[0], value[1])
- else:
- self.assert_row(value[0], getattr(rowobj, key), value[1])
- else:
- self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
-
- def assert_sql(self, db, callable_, list, with_sequences=None):
- global testdata
- testdata = TestData()
- if with_sequences is not None and (db.engine.name == 'postgres' or db.engine.name == 'oracle'):
- testdata.set_assert_list(self, with_sequences)
- else:
- testdata.set_assert_list(self, list)
- try:
- callable_()
- finally:
- testdata.set_assert_list(None, None)
-
- def assert_sql_count(self, db, callable_, count):
- global testdata
- testdata = TestData()
- try:
- callable_()
- finally:
- self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % (count, testdata.sql_count))
-
- def capture_sql(self, db, callable_):
- global testdata
- testdata = TestData()
- buffer = StringIO.StringIO()
- testdata.buffer = buffer
- try:
- callable_()
- return buffer.getvalue()
- finally:
- testdata.buffer = None
-
-class ORMTest(AssertMixin):
- keep_mappers = False
- keep_data = False
- def setUpAll(self):
- global _otest_metadata
- _otest_metadata = MetaData(db)
- self.define_tables(_otest_metadata)
- _otest_metadata.create_all()
- self.insert_data()
- def define_tables(self, _otest_metadata):
- raise NotImplementedError()
- def insert_data(self):
- pass
- def get_metadata(self):
- return _otest_metadata
- def tearDownAll(self):
- clear_mappers()
- _otest_metadata.drop_all()
- def tearDown(self):
- if not self.keep_mappers:
- clear_mappers()
- if not self.keep_data:
- for t in _otest_metadata.table_iterator(reverse=True):
- t.delete().execute().close()
-
-class TestData(object):
- """tracks SQL expressions as theyre executed via an instrumented ExecutionContext."""
-
- def __init__(self):
- self.set_assert_list(None, None)
- self.sql_count = 0
- self.buffer = None
-
- def set_assert_list(self, unittest, list):
- self.unittest = unittest
- self.assert_list = list
- if list is not None:
- self.assert_list.reverse()
-
-testdata = TestData()
-
-class ExecutionContextWrapper(object):
- """instruments the ExecutionContext created by the Engine so that SQL expressions
- can be tracked."""
-
- def __init__(self, ctx):
- self.__dict__['ctx'] = ctx
- def __getattr__(self, key):
- return getattr(self.ctx, key)
- def __setattr__(self, key, value):
- setattr(self.ctx, key, value)
-
- def post_exec(self):
- ctx = self.ctx
- statement = unicode(ctx.compiled)
- statement = re.sub(r'\n', '', ctx.statement)
- if testdata.buffer is not None:
- testdata.buffer.write(statement + "\n")
-
- if testdata.assert_list is not None:
- assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
- item = testdata.assert_list[-1]
- if not isinstance(item, dict):
- item = testdata.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if not item.has_key('_converted'):
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- testdata.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
-
- (query, params) = item
- if callable(params):
- params = params(ctx)
- if params is not None and isinstance(params, list) and len(params) == 1:
- params = params[0]
-
- if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
- parameters = ctx.compiled_parameters.get_original_dict()
- elif isinstance(ctx.compiled_parameters, list):
- parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
-
- query = self.convert_statement(query)
- if db.engine.name == 'mssql' and statement.endswith('; select scope_identity()'):
- statement = statement[:-25]
- testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- testdata.sql_count += 1
- self.ctx.post_exec()
-
- def convert_statement(self, query):
- paramstyle = self.ctx.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle =='pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle=='qmark':
- repl = "?"
- elif paramstyle=='format':
- repl = r"%s"
- elif paramstyle=='numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
- return query
-
-class TTestSuite(unittest.TestSuite):
- """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality"""
- def __init__(self, tests=()):
- if len(tests) >0 and isinstance(tests[0], PersistTest):
- self._initTest = tests[0]
- else:
- self._initTest = None
- unittest.TestSuite.__init__(self, tests)
-
- def do_run(self, result):
- """nice job unittest ! you switched __call__ and run() between py2.3 and 2.4 thereby
- making straight subclassing impossible !"""
- for test in self._tests:
- if result.shouldStop:
- break
- test(result)
- return result
-
- def run(self, result):
- return self(result)
-
- def __call__(self, result):
- try:
- if self._initTest is not None:
- self._initTest.setUpAll()
- except:
- result.addError(self._initTest, self.__exc_info())
- pass
- try:
- return self.do_run(result)
- finally:
- try:
- if self._initTest is not None:
- self._initTest.tearDownAll()
- except:
- result.addError(self._initTest, self.__exc_info())
- pass
-
- def __exc_info(self):
- """Return a version of sys.exc_info() with the traceback frame
- minimised; usually the top level of the traceback frame is not
- needed.
- ripped off out of unittest module since its double __
- """
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
-
-unittest.TestLoader.suiteClass = TTestSuite
-
-parse_argv()
-
-import sqlalchemy
-from sqlalchemy import schema, MetaData, sql
-from sqlalchemy.orm import clear_mappers
-
-def runTests(suite):
- sys.stdout = Logger()
- runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
- try:
- return runner.run(suite)
- finally:
- if with_coverage:
- global echo
- echo=True
- coverage.stop()
- coverage.report(list(covered_files()), show_missing=False)
-
-def covered_files():
- for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
- for x in rec[2]:
- if x.endswith('.py'):
- yield os.path.join(rec[0], x)
-
-def main(suite=None):
-
- if not suite:
- if len(sys.argv[1:]):
- suite =unittest.TestLoader().loadTestsFromNames(sys.argv[1:], __import__('__main__'))
- else:
- suite = unittest.TestLoader().loadTestsFromModule(__import__('__main__'))
-
- result = runTests(suite)
- sys.exit(not result.wasSuccessful())
+import testlib.config
+testlib.config.configure()
+from testlib.testing import main
+db = testlib.config.db
--- /dev/null
+"""Enhance unittest and instrument SQLAlchemy classes for testing.
+
+Load after sqlalchemy imports to use instrumented stand-ins like Table.
+"""
+
+import testlib.config
+from testlib.schema import Table, Column
+import testlib.testing as testing
+from testlib.testing import PersistTest, AssertMixin, ORMTest
+import testlib.profiling
+
--- /dev/null
+import optparse, os, re, sys, ConfigParser, StringIO
+logging, require = None, None
+
+__all__ = 'parser', 'configure', 'options',
+
+db, db_uri, db_type, db_label = None, None, None, None
+
+options = None
+file_config = None
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:s@localhost/tmp/test.fdb
+"""
+
+parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
+
+def configure():
+ global options, config
+ global getopts_options, file_config
+
+ file_config = ConfigParser.ConfigParser()
+ file_config.readfp(StringIO.StringIO(base_config))
+ file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+
+ # Opt parsing can fire immediate actions, like logging and coverage
+ (options, args) = parser.parse_args()
+ sys.argv[1:] = args
+
+ # Lazy setup of other options (post coverage)
+ for fn in post_configure:
+ fn(options, file_config)
+
+ return options, file_config
+
+def _log(option, opt_str, value, parser):
+    """optparse callback: set INFO or DEBUG level on the logger named by
+    ``value``.
+
+    Imports and configures the logging module lazily so merely loading
+    this module does not install a logging handler.
+    """
+    global logging
+    if not logging:
+        import logging
+        logging.basicConfig()
+
+    # the option string selects the level: --log-info vs --log-debug
+    if opt_str.endswith('-info'):
+        logging.getLogger(value).setLevel(logging.INFO)
+    elif opt_str.endswith('-debug'):
+        logging.getLogger(value).setLevel(logging.DEBUG)
+
+def _start_coverage(option, opt_str, value, parser):
+    """optparse callback: start coverage tracking immediately and
+    register an atexit hook that prints a report when the run ends."""
+    import sys, atexit, coverage
+    # capture the real stdout now, before test harnesses replace it
+    true_out = sys.stdout
+
+    def _iter_covered_files():
+        # every .py file under the installed sqlalchemy package
+        import sqlalchemy
+        for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
+            for x in rec[2]:
+                if x.endswith('.py'):
+                    yield os.path.join(rec[0], x)
+    def _stop():
+        coverage.stop()
+        true_out.write("\nPreparing coverage report...\n")
+        coverage.report(list(_iter_covered_files()),
+                        show_missing=False, ignore_errors=False,
+                        file=true_out)
+    atexit.register(_stop)
+    coverage.erase()
+    coverage.start()
+
+def _list_dbs(*args):
+    """optparse callback: print the [db] URI macros from the config
+    files and exit."""
+    print "Available --db options (use --dburi to override)"
+    for macro in sorted(file_config.options('db')):
+        print "%20s\t%s" % (macro, file_config.get('db', macro))
+    sys.exit(0)
+
+opt = parser.add_option
+opt("--verbose", action="store_true", dest="verbose",
+ help="enable stdout echoing/printing")
+opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+opt("--log-info", action="callback", callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)")
+opt("--log-debug", action="callback", callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)")
+opt("--require", action="append", dest="require", default=[],
+ help="require a particular driver or module version (multiple OK)")
+opt("--db", action="store", dest="db", default="sqlite",
+ help="Use prefab database uri")
+opt('--dbs', action='callback', callback=_list_dbs,
+ help="List available prefab dbs")
+opt("--dburi", action="store", dest="dburi",
+ help="Database uri (overrides --db)")
+opt("--mockpool", action="store_true", dest="mockpool",
+ help="Use mock pool (asserts only one connection used)")
+opt("--enginestrategy", action="store", dest="enginestrategy", default=None,
+ help="Engine strategy (plain or threadlocal, defaults toplain)")
+opt("--reversetop", action="store_true", dest="reversetop", default=False,
+ help="Reverse the collection ordering for topological sorts (helps "
+ "reveal dependency issues)")
+opt("--serverside", action="store_true", dest="serverside",
+ help="Turn on server side cursors for PG")
+opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+ help="Use the specified MySQL storage engine for all tables, default is "
+ "a db-default/InnoDB combo.")
+opt("--table-option", action="append", dest="tableopts", default=[],
+ help="Add a dialect-specific table option, key=value")
+opt("--coverage", action="callback", callback=_start_coverage,
+ help="Dump a full coverage report after running tests")
+opt("--profile", action="append", dest="profile_targets", default=[],
+ help="Enable a named profile target (multiple OK.)")
+opt("--profile-sort", action="store", dest="profile_sort", default=None,
+ help="Sort profile stats with this comma-separated sort order")
+opt("--profile-limit", type="int", action="store", dest="profile_limit",
+ default=None,
+ help="Limit function count in profile stats")
+
+class _ordered_map(object):
+    """Minimal insertion-ordered mapping; iteration yields the values in
+    the order their keys were first assigned."""
+
+    def __init__(self):
+        self._keys = list()
+        self._data = dict()
+
+    def __setitem__(self, key, value):
+        # re-assigning an existing key keeps its original position
+        if key not in self._keys:
+            self._keys.append(key)
+        self._data[key] = value
+
+    def __iter__(self):
+        for key in self._keys:
+            yield self._data[key]
+
+# post-parse configuration hooks, run by configure() in registration order
+post_configure = _ordered_map()
+
+def _engine_uri(options, file_config):
+    """Resolve --dburi/--db into a concrete database URI (``db_uri``)
+    and a short backend label (``db_label``)."""
+    global db_label, db_uri
+    db_label = 'sqlite'
+    if options.dburi:
+        db_uri = options.dburi
+        # label is the dialect prefix of the URI, e.g. 'postgres'
+        db_label = db_uri[:db_uri.index(':')]
+    elif options.db:
+        db_label = options.db
+        db_uri = None
+
+    if db_uri is None:
+        # look the label up in the [db] section of the config files
+        if db_label not in file_config.options('db'):
+            raise RuntimeError(
+                "Unknown engine. Specify --dbs for known engines.")
+        db_uri = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+ if not(options.require or
+ (file_config.has_section('require') and
+ file_config.items('require'))):
+ return
+
+ try:
+ import pkg_resources
+ except ImportError:
+ raise RuntimeError("setuptools is required for version requirements")
+
+ cmdline = []
+ for requirement in options.require:
+ pkg_resources.require(requirement)
+ cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+ if file_config.has_section('require'):
+ for label, requirement in file_config.items('require'):
+ if not label == db_label or label.startswith('%s.' % db_label):
+ continue
+ seen = [c for c in cmdline if requirement.startswith(c)]
+ if seen:
+ continue
+ pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _create_testing_engine(options, file_config):
+ from sqlalchemy import engine
+ global db, db_type
+ engine_opts = {}
+ if options.serverside:
+ engine_opts['server_side_cursors'] = True
+
+ if options.enginestrategy is not None:
+ engine_opts['strategy'] = options.enginestrategy
+
+ if options.mockpool:
+ db = engine.create_engine(db_uri, poolclass=pool.AssertionPool,
+ **engine_opts)
+ else:
+ db = engine.create_engine(db_uri, **engine_opts)
+ db_type = db.name
+
+ # decorate the dialect's create_execution_context() method
+ # to produce a wrapper
+ from testlib.testing import ExecutionContextWrapper
+
+ create_context = db.dialect.create_execution_context
+ def create_exec_context(*args, **kwargs):
+ return ExecutionContextWrapper(create_context(*args, **kwargs))
+ db.dialect.create_execution_context = create_exec_context
+post_configure['create_engine'] = _create_testing_engine
+
+def _set_table_options(options, file_config):
+    """Copy --table-option key=value pairs (and --mysql-engine) into
+    testlib.schema.table_options for the Table() wrapper to apply."""
+    import testlib.schema
+
+    table_options = testlib.schema.table_options
+    for spec in options.tableopts:
+        key, value = spec.split('=')
+        table_options[key] = value
+
+    if options.mysql_engine:
+        table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+    """With --reversetop, patch the topological sorter to consume its
+    input in reverse order, helping expose missing dependency
+    declarations."""
+    if options.reversetop:
+        from sqlalchemy.orm import unitofwork
+        from sqlalchemy import topological
+        class RevQueueDepSort(topological.QueueDependencySorter):
+            # same sort, but fed reversed tuple/item lists
+            def __init__(self, tuples, allitems):
+                self.tuples = list(tuples)
+                self.allitems = list(allitems)
+                self.tuples.reverse()
+                self.allitems.reverse()
+        topological.QueueDependencySorter = RevQueueDepSort
+        unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+
+def _set_profile_targets(options, file_config):
+    """Transfer the --profile* command-line options into
+    testlib.profiling.profile_config."""
+    from testlib import profiling
+
+    profile_config = profiling.profile_config
+
+    for target in options.profile_targets:
+        profile_config['targets'].add(target)
+
+    if options.profile_sort:
+        profile_config['sort'] = options.profile_sort.split(',')
+
+    if options.profile_limit:
+        profile_config['limit'] = options.profile_limit
+
+    if options.quiet:
+        profile_config['report'] = False
+
+    # magic "all" target
+    # NOTE(review): this guard tests profiling.all_targets, which is
+    # populated by @profiled decorators and is presumably still empty at
+    # configure time -- the requested targets set may have been meant;
+    # confirm intent before relying on the 'all' collapse below.
+    if 'all' in profiling.all_targets:
+        targets = profile_config['targets']
+        if 'all' in targets and len(targets) != 1:
+            targets.clear()
+            targets.add('all')
+post_configure['profile_targets'] = _set_profile_targets
# DAMAGE.
#
# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
-
--- /dev/null
+"""Profiling support for unit and performance tests."""
+
+import time, hotshot, hotshot.stats
+from testlib.config import parser, post_configure
+import testlib.config
+
+__all__ = 'profiled',
+
+all_targets = set()
+profile_config = { 'targets': set(),
+ 'report': True,
+ 'sort': ('time', 'calls'),
+ 'limit': None }
+
+def profiled(target, **target_opts):
+    """Optional function profiling.
+
+    @profiled('label')
+    or
+    @profiled('label', report=True, sort=('calls',), limit=20)
+
+    Enables profiling for a function when 'label' is targetted for
+    profiling.  Report options can be supplied, and override the global
+    configuration and command-line options.
+    """
+
+    # manual or automatic namespacing by module would remove conflict issues
+    if target in all_targets:
+        print "Warning: redefining profile target '%s'" % target
+    all_targets.add(target)
+
+    # each target profiles into its own hotshot data file
+    filename = "%s.prof" % target
+
+    def decorator(fn):
+        def profiled(*args, **kw):
+            # pass straight through unless this target was requested,
+            # or the decoration forces profiling with always=True
+            if (target not in profile_config['targets'] and
+                not target_opts.get('always', None)):
+                return fn(*args, **kw)
+
+            prof = hotshot.Profile(filename)
+            began = time.time()
+            prof.start()
+            try:
+                result = fn(*args, **kw)
+            finally:
+                # stop timing and close the profile even if fn raised
+                prof.stop()
+                ended = time.time()
+                prof.close()
+
+            if not testlib.config.options.quiet:
+                print "Profiled target '%s', wall time: %.2f seconds" % (
+                    target, ended - began)
+
+            # per-decoration options override the global profile_config
+            report = target_opts.get('report', profile_config['report'])
+            if report:
+                sort_ = target_opts.get('sort', profile_config['sort'])
+                limit = target_opts.get('limit', profile_config['limit'])
+                print "Profile report for target '%s' (%s)" % (
+                    target, filename)
+
+                stats = hotshot.stats.load(filename)
+                stats.sort_stats(*sort_)
+                if limit:
+                    stats.print_stats(limit)
+                else:
+                    stats.print_stats()
+            return result
+        try:
+            # keep the wrapped function's name for test discovery
+            profiled.__name__ = fn.__name__
+        except:
+            pass
+        return profiled
+    return decorator
--- /dev/null
+import testbase
+from sqlalchemy import schema
+
+__all__ = 'Table', 'Column',
+
+table_options = {}
+
+def Table(*args, **kw):
+    """A schema.Table wrapper/hook for dialect-specific tweaks."""
+
+    # pull out any 'test_*' marker keywords before they reach schema.Table
+    test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+                      if k.startswith('test_')])
+
+    # apply global --table-option / --mysql-engine settings
+    kw.update(table_options)
+
+    if testbase.db.name == 'mysql':
+        if 'mysql_engine' not in kw and 'mysql_type' not in kw:
+            # FK and transactional tests need a transactional MySQL engine
+            if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
+                kw['mysql_engine'] = 'InnoDB'
+
+    return schema.Table(*args, **kw)
+
+def Column(*args, **kw):
+    """A schema.Column wrapper/hook for dialect-specific tweaks.
+
+    Currently a plain pass-through to schema.Column."""
+
+    # TODO: a Column that creates a Sequence automatically for PK columns,
+    # which would help Oracle tests
+    return schema.Column(*args, **kw)
-
-from sqlalchemy import *
-import os
import testbase
-from testbase import Table, Column
+from sqlalchemy import *
+from testlib.schema import Table, Column
-ECHO = testbase.echo
-db = testbase.db
-metadata = MetaData(db)
+metadata = MetaData()
users = Table('users', metadata,
Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
)
def create():
+ if not metadata.bind:
+ metadata.bind = testbase.db
metadata.create_all()
def drop():
+ if not metadata.bind:
+ metadata.bind = testbase.db
metadata.drop_all()
def delete():
for t in metadata.table_iterator(reverse=True):
t.delete().execute()
def user_data():
+ if not metadata.bind:
+ metadata.bind = testbase.db
users.insert().execute(
dict(user_id = 7, user_name = 'jack'),
dict(user_id = 8, user_name = 'ed'),
{'order_id' : 4, 'items':(Item, [])},
{'order_id' : 5, 'items':(Item, [])},
]
-#db.echo = True
+
--- /dev/null
+"""TestCase and TestSuite artifacts and testing decorators."""
+
+# monkeypatches unittest.TestLoader.suiteClass at import time
+
+import unittest, re, sys, os
+from cStringIO import StringIO
+from sqlalchemy import MetaData, sql
+from sqlalchemy.orm import clear_mappers
+import testlib.config as config
+
+__all__ = 'PersistTest', 'AssertMixin', 'ORMTest'
+
+def unsupported(*dbs):
+ """Mark a test as unsupported by one or more database implementations"""
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name in dbs:
+ print "'%s' unsupported on DB implementation '%s'" % (
+ fn_name, config.db.name)
+ return True
+ else:
+ return fn(*args, **kw)
+ try:
+ maybe.__name__ = fn_name
+ except:
+ pass
+ return maybe
+ return decorate
+
+def supported(*dbs):
+ """Mark a test as supported by one or more database implementations"""
+
+ def decorate(fn):
+ fn_name = fn.__name__
+ def maybe(*args, **kw):
+ if config.db.name in dbs:
+ return fn(*args, **kw)
+ else:
+ print "'%s' unsupported on DB implementation '%s'" % (
+ fn_name, config.db.name)
+ return True
+ try:
+ maybe.__name__ = fn_name
+ except:
+ pass
+ return maybe
+ return decorate
+
+class TestData(object):
+ """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
+
+ def __init__(self):
+ self.set_assert_list(None, None)
+ self.sql_count = 0
+ self.buffer = None
+
+ def set_assert_list(self, unittest, list):
+ self.unittest = unittest
+ self.assert_list = list
+ if list is not None:
+ self.assert_list.reverse()
+
+testdata = TestData()
+
+
+class ExecutionContextWrapper(object):
+ """instruments the ExecutionContext created by the Engine so that SQL expressions
+ can be tracked."""
+
+ def __init__(self, ctx):
+ self.__dict__['ctx'] = ctx
+ def __getattr__(self, key):
+ return getattr(self.ctx, key)
+ def __setattr__(self, key, value):
+ setattr(self.ctx, key, value)
+
+ def post_exec(self):
+ ctx = self.ctx
+ statement = unicode(ctx.compiled)
+ statement = re.sub(r'\n', '', statement)
+ if testdata.buffer is not None:
+ testdata.buffer.write(statement + "\n")
+
+ if testdata.assert_list is not None:
+ assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
+ item = testdata.assert_list[-1]
+ if not isinstance(item, dict):
+ item = testdata.assert_list.pop()
+ else:
+ # asserting a dictionary of statements->parameters
+ # this is to specify query assertions where the queries can be in
+ # multiple orderings
+ if not item.has_key('_converted'):
+ for key in item.keys():
+ ckey = self.convert_statement(key)
+ item[ckey] = item[key]
+ if ckey != key:
+ del item[key]
+ item['_converted'] = True
+ try:
+ entry = item.pop(statement)
+ if len(item) == 1:
+ testdata.assert_list.pop()
+ item = (statement, entry)
+ except KeyError:
+ assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
+
+ (query, params) = item
+ if callable(params):
+ params = params(ctx)
+ if params is not None and isinstance(params, list) and len(params) == 1:
+ params = params[0]
+
+ if isinstance(ctx.compiled_parameters, sql.ClauseParameters):
+ parameters = ctx.compiled_parameters.get_original_dict()
+ elif isinstance(ctx.compiled_parameters, list):
+ parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
+
+ query = self.convert_statement(query)
+ if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'):
+ statement = statement[:-25]
+ testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
+ testdata.sql_count += 1
+ self.ctx.post_exec()
+
+ def convert_statement(self, query):
+ paramstyle = self.ctx.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle =='pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle=='qmark':
+ repl = "?"
+ elif paramstyle=='format':
+ repl = r"%s"
+ elif paramstyle=='numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+ return query
+
+class PersistTest(unittest.TestCase):
+
+ def __init__(self, *args, **params):
+ unittest.TestCase.__init__(self, *args, **params)
+
+ def setUpAll(self):
+ pass
+
+ def tearDownAll(self):
+ pass
+
+ def shortDescription(self):
+ """overridden to not return docstrings"""
+ return None
+
+class AssertMixin(PersistTest):
+ """given a list-based structure of keys/properties which represent information within an object structure, and
+ a list of actual objects, asserts that the list of objects corresponds to the structure."""
+
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print repr(result)
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list):
+ self.assert_(len(result) == len(list),
+ "result list is not the same size as test list, " +
+ "for class " + class_.__name__)
+ for i in range(0, len(list)):
+ self.assert_row(class_, result[i], list[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(rowobj.__class__ is class_,
+ "item class is not " + repr(class_))
+ for key, value in desc.iteritems():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s" % (
+ key, getattr(rowobj, key), value))
+
+ def assert_sql(self, db, callable_, list, with_sequences=None):
+ global testdata
+ testdata = TestData()
+ if with_sequences is not None and (config.db.name == 'postgres' or
+ config.db.name == 'oracle'):
+ testdata.set_assert_list(self, with_sequences)
+ else:
+ testdata.set_assert_list(self, list)
+ try:
+ callable_()
+ finally:
+ testdata.set_assert_list(None, None)
+
+ def assert_sql_count(self, db, callable_, count):
+ global testdata
+ testdata = TestData()
+ try:
+ callable_()
+ finally:
+ self.assert_(testdata.sql_count == count,
+ "desired statement count %d does not match %d" % (
+ count, testdata.sql_count))
+
+ def capture_sql(self, db, callable_):
+ global testdata
+ testdata = TestData()
+ buffer = StringIO()
+ testdata.buffer = buffer
+ try:
+ callable_()
+ return buffer.getvalue()
+ finally:
+ testdata.buffer = None
+
+_otest_metadata = None
+class ORMTest(AssertMixin):
+ keep_mappers = False
+ keep_data = False
+
+ def setUpAll(self):
+ global _otest_metadata
+ _otest_metadata = MetaData(config.db)
+ self.define_tables(_otest_metadata)
+ _otest_metadata.create_all()
+ self.insert_data()
+
+ def define_tables(self, _otest_metadata):
+ raise NotImplementedError()
+
+ def insert_data(self):
+ pass
+
+ def get_metadata(self):
+ return _otest_metadata
+
+ def tearDownAll(self):
+ clear_mappers()
+ _otest_metadata.drop_all()
+
+ def tearDown(self):
+ if not self.keep_mappers:
+ clear_mappers()
+ if not self.keep_data:
+ for t in _otest_metadata.table_iterator(reverse=True):
+ t.delete().execute().close()
+
+
+class TTestSuite(unittest.TestSuite):
+ """A TestSuite with once per TestCase setUpAll() and tearDownAll()"""
+
+ def __init__(self, tests=()):
+ if len(tests) >0 and isinstance(tests[0], PersistTest):
+ self._initTest = tests[0]
+ else:
+ self._initTest = None
+ unittest.TestSuite.__init__(self, tests)
+
+ def do_run(self, result):
+ # nice job unittest ! you switched __call__ and run() between py2.3
+ # and 2.4 thereby making straight subclassing impossible !
+ for test in self._tests:
+ if result.shouldStop:
+ break
+ test(result)
+ return result
+
+ def run(self, result):
+ return self(result)
+
+ def __call__(self, result):
+ try:
+ if self._initTest is not None:
+ self._initTest.setUpAll()
+ except:
+ result.addError(self._initTest, self.__exc_info())
+ pass
+ try:
+ return self.do_run(result)
+ finally:
+ try:
+ if self._initTest is not None:
+ self._initTest.tearDownAll()
+ except:
+ result.addError(self._initTest, self.__exc_info())
+ pass
+
+ def __exc_info(self):
+ """Return a version of sys.exc_info() with the traceback frame
+ minimised; usually the top level of the traceback frame is not
+ needed.
+ ripped off out of unittest module since its double __
+ """
+ exctype, excvalue, tb = sys.exc_info()
+ if sys.platform[:4] == 'java': ## tracebacks look different in Jython
+ return (exctype, excvalue, tb)
+ return (exctype, excvalue, tb)
+
+unittest.TestLoader.suiteClass = TTestSuite
+
+def _iter_covered_files():
+ import sqlalchemy
+ for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):
+ for x in rec[2]:
+ if x.endswith('.py'):
+ yield os.path.join(rec[0], x)
+
+def cover(callable_, file_=None):
+ from testlib import coverage
+ coverage_client = coverage.the_coverage
+ coverage_client.get_ready()
+ coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
+ coverage_client.erase()
+ coverage_client.start()
+ try:
+ return callable_()
+ finally:
+ coverage_client.stop()
+ coverage_client.save()
+ coverage_client.report(list(_iter_covered_files()),
+ show_missing=False, ignore_errors=False,
+ file=file_)
+
+class DevNullWriter(object):
+ def write(self, msg):
+ pass
+ def flush(self):
+ pass
+
+def runTests(suite):
+ verbose = config.options.verbose
+ quiet = config.options.quiet
+ orig_stdout = sys.stdout
+
+ try:
+ if not verbose or quiet:
+ sys.stdout = DevNullWriter()
+ runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
+ return runner.run(suite)
+ finally:
+ if not verbose or quiet:
+ sys.stdout = orig_stdout
+
+def main(suite=None):
+ if not suite:
+ if len(sys.argv[1:]):
+ suite =unittest.TestLoader().loadTestsFromNames(
+ sys.argv[1:], __import__('__main__'))
+ else:
+ suite = unittest.TestLoader().loadTestsFromModule(
+ __import__('__main__'))
+
+ result = runTests(suite)
+ sys.exit(not result.wasSuccessful())
+"""application table metadata objects are described here."""
+
from sqlalchemy import *
-from testbase import Table, Column
+from testlib import *
+
metadata = MetaData()
-"""application table metadata objects are described here."""
users = Table('users', metadata,
Column('user_id', Integer, primary_key=True),
-from testbase import AssertMixin
import testbase
-import unittest
-db = testbase.db
from sqlalchemy import *
from sqlalchemy.orm import *
-
+from testlib import *
from zblog import mappers, tables
from zblog.user import *
from zblog.blog import *
+
class ZBlogTest(AssertMixin):
def create_tables(self):
- tables.metadata.create_all(bind=db)
+ tables.metadata.drop_all(bind=testbase.db)
+ tables.metadata.create_all(bind=testbase.db)
def drop_tables(self):
- tables.metadata.drop_all(bind=db)
+ tables.metadata.drop_all(bind=testbase.db)
def setUpAll(self):
self.create_tables()
super(SavePostTest, self).setUpAll()
mappers.zblog_mappers()
global blog_id, user_id
- s = create_session(bind=db)
+ s = create_session(bind=testbase.db)
user = User('zbloguser', "Zblog User", "hello", group=administrator)
blog = Blog(owner=user)
blog.name = "this is a blog"
"""test that a transient/pending instance has proper bi-directional behavior.
this requires that lazy loaders do not fire off for a transient/pending instance."""
- s = create_session(bind=db)
+ s = create_session(bind=testbase.db)
s.begin()
try:
def testoptimisticorphans(self):
"""test that instances in the session with un-loaded parents will not
get marked as "orphans" and then deleted """
- s = create_session(bind=db)
+ s = create_session(bind=testbase.db)
s.begin()
try:
if __name__ == "__main__":
testbase.main()
-
\ No newline at end of file
+