(cargs, cparams) = self.connect_args()
self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams)
self.echo = echo
- self.context = util.ThreadLocal()
+ self.context = util.ThreadLocal(raiseerror=False)
self.tables = {}
self.notes = {}
self.logger = sys.stdout
self.do_rollback(self.context.transaction)
self.context.transaction = None
self.context.tcount = None
+ else:
+ self.do_rollback(self.connection())
def commit(self):
if self.context.transaction is not None:
self.do_commit(self.context.transaction)
self.context.transaction = None
self.context.tcount = None
-
+ else:
+ self.do_commit(self.connection())
+
def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
pass
class ThreadLocal(object):
"""an object in which attribute access occurs only within the context of the current thread"""
- def __init__(self):
- object.__setattr__(self, 'tdict', {})
- def __getattribute__(self, key):
+ def __init__(self, raiseerror = True):
+ self.__dict__['_tdict'] = {}
+ self.__dict__['_raiseerror'] = raiseerror
+ def __getattr__(self, key):
try:
- return object.__getattribute__(self, 'tdict')["%d_%s" % (thread.get_ident(), key)]
+ return self._tdict["%d_%s" % (thread.get_ident(), key)]
except KeyError:
- raise AttributeError(key)
+ if self._raiseerror:
+ raise AttributeError(key)
+ else:
+ return None
def __setattr__(self, key, value):
- object.__getattribute__(self, 'tdict')["%d_%s" % (thread.get_ident(), key)] = value
+ self._tdict["%d_%s" % (thread.get_ident(), key)] = value
class HashSet(object):
"""implements a Set."""
testbase.echo = False
-
-modules_to_test = ('attributes', 'historyarray', 'pool', 'engines', 'query', 'types', 'mapper', 'objectstore')
+def suite():
+ modules_to_test = ('attributes', 'historyarray', 'pool', 'engines', 'query', 'types', 'mapper', 'objectstore')
+ alltests = unittest.TestSuite()
+ for module in map(__import__, modules_to_test):
+ alltests.addTest(unittest.findTestCases(module))
+ return alltests
if __name__ == '__main__':
testbase.runTests(suite())
from tables import *
-
-db.echo = testbase.echo
-
-itemkeywords.delete().execute()
-keywords.delete().execute()
-orderitems.delete().execute()
-orders.delete().execute()
-addresses.delete().execute()
-users.delete().execute()
-users.insert().execute(
- dict(user_id = 7, user_name = 'jack'),
- dict(user_id = 8, user_name = 'ed'),
- dict(user_id = 9, user_name = 'fred')
-)
-addresses.insert().execute(
- dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
- dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
- dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
-)
-orders.insert().execute(
- dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),
- dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0),
- dict(order_id = 3, user_id = 7, description = 'order 3', isopen=1),
- dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1),
- dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0)
-)
-orderitems.insert().execute(
- dict(item_id=1, order_id=2, item_name='item 1'),
- dict(item_id=3, order_id=3, item_name='item 3'),
- dict(item_id=2, order_id=2, item_name='item 2'),
- dict(item_id=5, order_id=3, item_name='item 5'),
- dict(item_id=4, order_id=3, item_name='item 4')
-)
-keywords.insert().execute(
- dict(keyword_id=1, name='blue'),
- dict(keyword_id=2, name='red'),
- dict(keyword_id=3, name='green'),
- dict(keyword_id=4, name='big'),
- dict(keyword_id=5, name='small'),
- dict(keyword_id=6, name='round'),
- dict(keyword_id=7, name='square')
-)
-itemkeywords.insert().execute(
- dict(keyword_id=2, item_id=1),
- dict(keyword_id=2, item_id=2),
- dict(keyword_id=4, item_id=1),
- dict(keyword_id=6, item_id=1),
- dict(keyword_id=7, item_id=2),
- dict(keyword_id=6, item_id=3),
- dict(keyword_id=3, item_id=3),
- dict(keyword_id=5, item_id=2),
- dict(keyword_id=4, item_id=3)
-)
-
-db.connection().commit()
-
-
-class MapperTest(AssertMixin):
-
+import tables
+
+class MapperSuperTest(AssertMixin):
+ def setUpAll(self):
+ db.echo = False
+ tables.create()
+ tables.data()
+ db.echo = testbase.echo
+ def tearDownAll(self):
+ db.echo = False
+ tables.drop()
+ db.echo = testbase.echo
def setUp(self):
objectstore.clear()
+
+class MapperTest(MapperSuperTest):
def testget(self):
m = mapper(User, users)
self.assert_(m.get(19) is None)
{'user_id' : 9, 'addresses' : (Address, [])}
)
-class PropertyTest(AssertMixin):
- def setUp(self):
- objectstore.clear()
-
+class PropertyTest(MapperSuperTest):
def testbasic(self):
"""tests that you can create mappers inline with class definitions"""
class _Address(object):
self.echo(repr(AddressUser.mapper.select(AddressUser.c.user_name == 'jack')))
-class LazyTest(AssertMixin):
- def setUp(self):
- objectstore.clear()
+class LazyTest(MapperSuperTest):
def testbasic(self):
"""tests a basic one-to-many lazy load"""
{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
)
-class EagerTest(AssertMixin):
-
- def setUp(self):
- objectstore.clear()
-
+class EagerTest(MapperSuperTest):
def testbasic(self):
"""tests a basic one-to-many eager load"""
l = m.select(order_by=[items.c.item_id, keywords.c.keyword_id])
self.assert_result(l, Item,
{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
- {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])},
- {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])},
+ {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])},
+ {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 6,'name':'round'}, {'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}])},
{'item_id' : 4, 'keywords' : (Keyword, [])},
{'item_id' : 5, 'keywords' : (Keyword, [])}
)
- l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id), order_by=[items.c.item_id, keywords.c.keyword_id])
+ l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
self.assert_result(l, Item,
{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
- {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
+ {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])},
)
def testoneandmany(self):
{'item_id':2, 'item_name':'item 2','keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])}
])},
{'order_id' : 3, 'items': (Item, [
- {'item_id':3, 'item_name':'item 3'},
+ {'item_id':3, 'item_name':'item 3', 'keywords' : (Keyword, [{'keyword_id' : 6, 'name':'round'}, {'keyword_id' : 3, 'name':'green'}, {'keyword_id' : 4, 'name':'big'}])},
{'item_id':4, 'item_name':'item 4'},
{'item_id':5, 'item_name':'item 5'}
])},
)
-if __name__ == "__main__":
- testbase.runTests()
+if __name__ == "__main__":
+ testbase.main()
__ALL__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'itemkeywords']
ECHO = testbase.echo
-DATA = True
-#CREATE = False
-CREATE = True
+
DBTYPE = 'sqlite_memory'
#DBTYPE = 'postgres'
#DBTYPE = 'sqlite_file'
Column('keyword_id', INT, ForeignKey("keywords"))
)
-if CREATE:
+def create():
users.create()
addresses.create()
orders.create()
orderitems.create()
keywords.create()
itemkeywords.create()
-
-
+
+def drop():
+ itemkeywords.drop()
+ keywords.drop()
+ orderitems.drop()
+ orders.drop()
+ addresses.drop()
+ users.drop()
+
+def data():
+ itemkeywords.delete().execute()
+ keywords.delete().execute()
+ orderitems.delete().execute()
+ orders.delete().execute()
+ addresses.delete().execute()
+ users.delete().execute()
+ users.insert().execute(
+ dict(user_id = 7, user_name = 'jack'),
+ dict(user_id = 8, user_name = 'ed'),
+ dict(user_id = 9, user_name = 'fred')
+ )
+ addresses.insert().execute(
+ dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
+ dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
+ dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
+ )
+ orders.insert().execute(
+ dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),
+ dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0),
+ dict(order_id = 3, user_id = 7, description = 'order 3', isopen=1),
+ dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1),
+ dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0)
+ )
+ orderitems.insert().execute(
+ dict(item_id=1, order_id=2, item_name='item 1'),
+ dict(item_id=3, order_id=3, item_name='item 3'),
+ dict(item_id=2, order_id=2, item_name='item 2'),
+ dict(item_id=5, order_id=3, item_name='item 5'),
+ dict(item_id=4, order_id=3, item_name='item 4')
+ )
+ keywords.insert().execute(
+ dict(keyword_id=1, name='blue'),
+ dict(keyword_id=2, name='red'),
+ dict(keyword_id=3, name='green'),
+ dict(keyword_id=4, name='big'),
+ dict(keyword_id=5, name='small'),
+ dict(keyword_id=6, name='round'),
+ dict(keyword_id=7, name='square')
+ )
+ itemkeywords.insert().execute(
+ dict(keyword_id=2, item_id=1),
+ dict(keyword_id=2, item_id=2),
+ dict(keyword_id=4, item_id=1),
+ dict(keyword_id=6, item_id=1),
+ dict(keyword_id=7, item_id=2),
+ dict(keyword_id=6, item_id=3),
+ dict(keyword_id=3, item_id=3),
+ dict(keyword_id=5, item_id=2),
+ dict(keyword_id=4, item_id=3)
+ )
+
+ db.commit()
+
class User(object):
def __init__(self):
self.user_id = None
import unittest
import StringIO
import sqlalchemy.engine as engine
-import re
+import re, sys
import sqlalchemy.databases.postgres as postgres
echo = True
def echo(self, text):
if echo:
print text
-
+ def setUpAll(self):
+ pass
+ def tearDownAll(self):
+ pass
class AssertMixin(PersistTest):
def assert_result(self, result, class_, *objects):
self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
return self.realexec(statement, parameters, **kwargs)
-
-def runTests(*modules):
- for m in modules:
- if m.__dict__.has_key('startUp'):
- m.startUp()
- s = suite(m)
- runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
- runner.run(s)
- if m.__dict__.has_key('tearDown'):
- m.tearDown()
+
+
+class TTestSuite(unittest.TestSuite):
+ def __init__(self, tests=()):
+ if len(tests) >0 and isinstance(tests[0], PersistTest):
+ self._initTest = tests[0]
+ else:
+ self._initTest = None
+ unittest.TestSuite.__init__(self, tests)
+
+ def run(self, result):
+ try:
+ if self._initTest is not None:
+ self._initTest.setUpAll()
+ except:
+ result.addError(self._initTest, self.__exc_info())
+ pass
+ try:
+ return unittest.TestSuite.run(self, result)
+ finally:
+ try:
+ if self._initTest is not None:
+ self._initTest.tearDownAll()
+ except:
+ result.addError(self._initTest, self.__exc_info())
+ pass
+
+ def __exc_info(self):
+ """Return a version of sys.exc_info() with the traceback frame
+ minimised; usually the top level of the traceback frame is not
+ needed.
+ ripped off out of unittest module since its double __
+ """
+ exctype, excvalue, tb = sys.exc_info()
+ if sys.platform[:4] == 'java': ## tracebacks look different in Jython
+ return (exctype, excvalue, tb)
+ return (exctype, excvalue, tb)
+
+
+unittest.TestLoader.suiteClass = TTestSuite
+
+def runTests(suite):
+ runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
+ runner.run(suite)
-def suite(modules):
- alltests = unittest.TestSuite()
- for module in map(__import__, modules):
- alltests.addTest(unittest.findTestCases(module))
- return alltests
+def main():
+ unittest.main()
\ No newline at end of file