From fe12e56166ba6da0466fb36c2bf499005f2746d7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 22 Oct 2005 21:44:37 +0000 Subject: [PATCH] full mapper test suite works with postgres --- lib/sqlalchemy/engine.py | 8 ++- lib/sqlalchemy/util.py | 16 +++--- test/alltests.py | 8 ++- test/mapper.py | 102 +++++++++------------------------------ test/tables.py | 71 ++++++++++++++++++++++++--- test/testbase.py | 68 +++++++++++++++++++------- 6 files changed, 162 insertions(+), 111 deletions(-) diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 085c201883..8be019ea53 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -58,7 +58,7 @@ class SQLEngine(schema.SchemaEngine): (cargs, cparams) = self.connect_args() self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams) self.echo = echo - self.context = util.ThreadLocal() + self.context = util.ThreadLocal(raiseerror=False) self.tables = {} self.notes = {} self.logger = sys.stdout @@ -168,6 +168,8 @@ class SQLEngine(schema.SchemaEngine): self.do_rollback(self.context.transaction) self.context.transaction = None self.context.tcount = None + else: + self.do_rollback(self.connection()) def commit(self): if self.context.transaction is not None: @@ -177,7 +179,9 @@ class SQLEngine(schema.SchemaEngine): self.do_commit(self.context.transaction) self.context.transaction = None self.context.tcount = None - + else: + self.do_commit(self.connection()) + def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs): pass diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 07eb85846e..2498316c1d 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -101,15 +101,19 @@ class OrderedDict(dict): class ThreadLocal(object): """an object in which attribute access occurs only within the context of the current thread""" - def __init__(self): - object.__setattr__(self, 'tdict', {}) - def __getattribute__(self, key): + def __init__(self, raiseerror = True): + self.__dict__['_tdict'] = {} + self.__dict__['_raiseerror'] = raiseerror + def __getattr__(self, key): try: - return object.__getattribute__(self, 'tdict')["%d_%s" % (thread.get_ident(), key)] + return self._tdict["%d_%s" % (thread.get_ident(), key)] except KeyError: - raise AttributeError(key) + if self._raiseerror: + raise AttributeError(key) + else: + return None def __setattr__(self, key, value): - object.__getattribute__(self, 'tdict')["%d_%s" % (thread.get_ident(), key)] = value + self._tdict["%d_%s" % (thread.get_ident(), key)] = value class HashSet(object): """implements a Set.""" diff --git a/test/alltests.py b/test/alltests.py index 167bef9eaf..cee36d0f0f 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -3,8 +3,12 @@ import testbase testbase.echo = False - -modules_to_test = ('attributes', 'historyarray', 'pool', 'engines', 'query', 'types', 'mapper', 'objectstore') +def suite(): + modules_to_test = ('attributes', 'historyarray', 'pool', 'engines', 'query', 'types', 'mapper', 'objectstore') + alltests = unittest.TestSuite() + for module in map(__import__, modules_to_test): + alltests.addTest(unittest.findTestCases(module)) + return alltests if __name__ == '__main__': testbase.runTests(suite()) diff --git a/test/mapper.py b/test/mapper.py index fc3601879a..001225f9f3 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -6,68 +6,23 @@ import sqlalchemy.objectstore as objectstore from tables import * - -db.echo = testbase.echo - -itemkeywords.delete().execute() -keywords.delete().execute() -orderitems.delete().execute() -orders.delete().execute() -addresses.delete().execute() -users.delete().execute() -users.insert().execute( - dict(user_id = 7, user_name = 'jack'), - dict(user_id = 8, user_name = 'ed'), - dict(user_id = 9, user_name = 'fred') -) -addresses.insert().execute( - dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"), - dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"), - dict(address_id = 3, user_id = 8, email_address = "ed@lala.com") -) -orders.insert().execute( - dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0), - dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0), - dict(order_id = 3, user_id = 7, description = 'order 3', isopen=1), - dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1), - dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0) -) -orderitems.insert().execute( - dict(item_id=1, order_id=2, item_name='item 1'), - dict(item_id=3, order_id=3, item_name='item 3'), - dict(item_id=2, order_id=2, item_name='item 2'), - dict(item_id=5, order_id=3, item_name='item 5'), - dict(item_id=4, order_id=3, item_name='item 4') -) -keywords.insert().execute( - dict(keyword_id=1, name='blue'), - dict(keyword_id=2, name='red'), - dict(keyword_id=3, name='green'), - dict(keyword_id=4, name='big'), - dict(keyword_id=5, name='small'), - dict(keyword_id=6, name='round'), - dict(keyword_id=7, name='square') -) -itemkeywords.insert().execute( - dict(keyword_id=2, item_id=1), - dict(keyword_id=2, item_id=2), - dict(keyword_id=4, item_id=1), - dict(keyword_id=6, item_id=1), - dict(keyword_id=7, item_id=2), - dict(keyword_id=6, item_id=3), - dict(keyword_id=3, item_id=3), - dict(keyword_id=5, item_id=2), - dict(keyword_id=4, item_id=3) -) - -db.connection().commit() - - -class MapperTest(AssertMixin): - +import tables + +class MapperSuperTest(AssertMixin): + def setUpAll(self): + db.echo = False + tables.create() + tables.data() + db.echo = testbase.echo + def tearDownAll(self): + db.echo = False + tables.drop() + db.echo = testbase.echo def setUp(self): objectstore.clear() + +class MapperTest(MapperSuperTest): def testget(self): m = mapper(User, users) self.assert_(m.get(19) is None) @@ -116,10 +71,7 @@ class MapperTest(AssertMixin): {'user_id' : 9, 'addresses' : (Address, [])} ) -class PropertyTest(AssertMixin): - def setUp(self): - objectstore.clear() - +class PropertyTest(MapperSuperTest): def testbasic(self): """tests that you can create mappers inline with class definitions""" class _Address(object): @@ -154,9 +106,7 @@ class PropertyTest(AssertMixin): self.echo(repr(AddressUser.mapper.select(AddressUser.c.user_name == 'jack'))) -class LazyTest(AssertMixin): - def setUp(self): - objectstore.clear() +class LazyTest(MapperSuperTest): def testbasic(self): """tests a basic one-to-many lazy load""" @@ -206,11 +156,7 @@ class LazyTest(AssertMixin): {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, ) -class EagerTest(AssertMixin): - - def setUp(self): - objectstore.clear() - +class EagerTest(MapperSuperTest): def testbasic(self): """tests a basic one-to-many eager load""" @@ -365,16 +311,16 @@ class EagerTest(AssertMixin): l = m.select(order_by=[items.c.item_id, keywords.c.keyword_id]) self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, - {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])}, - {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])}, + {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])}, + {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 6,'name':'round'}, {'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}])}, {'item_id' : 4, 'keywords' : (Keyword, [])}, {'item_id' : 5, 'keywords' : (Keyword, [])} ) - l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id), order_by=[items.c.item_id, keywords.c.keyword_id]) + l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id)) self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, - {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, + {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])}, ) def testoneandmany(self): @@ -396,12 +342,12 @@ class EagerTest(AssertMixin): {'item_id':2, 'item_name':'item 2','keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])} ])}, {'order_id' : 3, 'items': (Item, [ - {'item_id':3, 'item_name':'item 3'}, + {'item_id':3, 'item_name':'item 3', 'keywords' : (Keyword, [{'keyword_id' : 6, 'name':'round'}, {'keyword_id' : 3, 'name':'green'}, {'keyword_id' : 4, 'name':'big'}])}, {'item_id':4, 'item_name':'item 4'}, {'item_id':5, 'item_name':'item 5'} ])}, ) -if __name__ == "__main__": - testbase.runTests() +if __name__ == "__main__": + testbase.main() diff --git a/test/tables.py b/test/tables.py index a5e8bdbeec..1dee54236b 100644 --- a/test/tables.py +++ b/test/tables.py @@ -9,9 +9,7 @@ import testbase __ALL__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'itemkeywords'] ECHO = testbase.echo -DATA = True -#CREATE = False -CREATE = True + DBTYPE = 'sqlite_memory' #DBTYPE = 'postgres' #DBTYPE = 'sqlite_file' @@ -62,15 +60,76 @@ itemkeywords = Table('itemkeywords', db, Column('keyword_id', INT, ForeignKey("keywords")) ) -if CREATE: +def create(): users.create() addresses.create() orders.create() orderitems.create() keywords.create() itemkeywords.create() - - + +def drop(): + itemkeywords.drop() + keywords.drop() + orderitems.drop() + orders.drop() + addresses.drop() + users.drop() + +def data(): + itemkeywords.delete().execute() + keywords.delete().execute() + orderitems.delete().execute() + orders.delete().execute() + addresses.delete().execute() + users.delete().execute() + users.insert().execute( + dict(user_id = 7, user_name = 'jack'), + dict(user_id = 8, user_name = 'ed'), + dict(user_id = 9, user_name = 'fred') + ) + addresses.insert().execute( + dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"), + dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"), + dict(address_id = 3, user_id = 8, email_address = "ed@lala.com") + ) + orders.insert().execute( + dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0), + dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0), + dict(order_id = 3, user_id = 7, description = 'order 3', isopen=1), + dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1), + dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0) + ) + orderitems.insert().execute( + dict(item_id=1, order_id=2, item_name='item 1'), + dict(item_id=3, order_id=3, item_name='item 3'), + dict(item_id=2, order_id=2, item_name='item 2'), + dict(item_id=5, order_id=3, item_name='item 5'), + dict(item_id=4, order_id=3, item_name='item 4') + ) + keywords.insert().execute( + dict(keyword_id=1, name='blue'), + dict(keyword_id=2, name='red'), + dict(keyword_id=3, name='green'), + dict(keyword_id=4, name='big'), + dict(keyword_id=5, name='small'), + dict(keyword_id=6, name='round'), + dict(keyword_id=7, name='square') + ) + itemkeywords.insert().execute( + dict(keyword_id=2, item_id=1), + dict(keyword_id=2, item_id=2), + dict(keyword_id=4, item_id=1), + dict(keyword_id=6, item_id=1), + dict(keyword_id=7, item_id=2), + dict(keyword_id=6, item_id=3), + dict(keyword_id=3, item_id=3), + dict(keyword_id=5, item_id=2), + dict(keyword_id=4, item_id=3) + ) + + db.commit() + class User(object): def __init__(self): self.user_id = None diff --git a/test/testbase.py b/test/testbase.py index c48c48ef24..2ddea27e88 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -1,7 +1,7 @@ import unittest import StringIO import sqlalchemy.engine as engine -import re +import re, sys import sqlalchemy.databases.postgres as postgres echo = True @@ -12,7 +12,10 @@ class PersistTest(unittest.TestCase): def echo(self, text): if echo: print text - + def setUpAll(self): + pass + def tearDownAll(self): + pass class AssertMixin(PersistTest): def assert_result(self, result, class_, *objects): @@ -70,19 +73,50 @@ class EngineAssert(object): self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) return self.realexec(statement, parameters, **kwargs) - -def runTests(*modules): - for m in modules: - if m.__dict__.has_key('startUp'): - m.startUp() - s = suite(m) - runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) - runner.run(s) - if m.__dict__.has_key('tearDown'): - m.tearDown() + + +class TTestSuite(unittest.TestSuite): + def __init__(self, tests=()): + if len(tests) >0 and isinstance(tests[0], PersistTest): + self._initTest = tests[0] + else: + self._initTest = None + unittest.TestSuite.__init__(self, tests) + + def run(self, result): + try: + if self._initTest is not None: + self._initTest.setUpAll() + except: + result.addError(self._initTest, self.__exc_info()) + pass + try: + return unittest.TestSuite.run(self, result) + finally: + try: + if self._initTest is not None: + self._initTest.tearDownAll() + except: + result.addError(self._initTest, self.__exc_info()) + pass + + def __exc_info(self): + """Return a version of sys.exc_info() with the traceback frame + minimised; usually the top level of the traceback frame is not + needed. + ripped off out of unittest module since its double __ + """ + exctype, excvalue, tb = sys.exc_info() + if sys.platform[:4] == 'java': ## tracebacks look different in Jython + return (exctype, excvalue, tb) + return (exctype, excvalue, tb) + + +unittest.TestLoader.suiteClass = TTestSuite + +def runTests(suite): + runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) + runner.run(suite) -def suite(modules): - alltests = unittest.TestSuite() - for module in map(__import__, modules): - alltests.addTest(unittest.findTestCases(module)) - return alltests +def main(): + unittest.main() \ No newline at end of file -- 2.47.2