]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
full mapper test suite works with postgres
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Oct 2005 21:44:37 +0000 (21:44 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Oct 2005 21:44:37 +0000 (21:44 +0000)
lib/sqlalchemy/engine.py
lib/sqlalchemy/util.py
test/alltests.py
test/mapper.py
test/tables.py
test/testbase.py

index 085c201883b2dc70126df97eeaa4ab0f675a925b..8be019ea53bf738471d1fd093eccd65640ecdc57 100644 (file)
@@ -58,7 +58,7 @@ class SQLEngine(schema.SchemaEngine):
         (cargs, cparams) = self.connect_args()
         self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams)
         self.echo = echo
-        self.context = util.ThreadLocal()
+        self.context = util.ThreadLocal(raiseerror=False)
         self.tables = {}
         self.notes = {}
         self.logger = sys.stdout
@@ -168,6 +168,8 @@ class SQLEngine(schema.SchemaEngine):
             self.do_rollback(self.context.transaction)
             self.context.transaction = None
             self.context.tcount = None
+        else:
+            self.do_rollback(self.connection())
             
     def commit(self):
         if self.context.transaction is not None:
@@ -177,7 +179,9 @@ class SQLEngine(schema.SchemaEngine):
                 self.do_commit(self.context.transaction)
                 self.context.transaction = None
                 self.context.tcount = None
-
+        else:
+            self.do_commit(self.connection())
+            
     def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
         pass
 
index 07eb85846eab6d9ddb8a988f6cac9529f11a5a29..2498316c1d1eafe63757a91dfa8d4884a2b85ff4 100644 (file)
@@ -101,15 +101,19 @@ class OrderedDict(dict):
 
 class ThreadLocal(object):
     """an object in which attribute access occurs only within the context of the current thread"""
-    def __init__(self):
-        object.__setattr__(self, 'tdict', {})
-    def __getattribute__(self, key):
+    def __init__(self, raiseerror = True):
+        self.__dict__['_tdict'] = {}
+        self.__dict__['_raiseerror'] = raiseerror
+    def __getattr__(self, key):
         try:
-            return object.__getattribute__(self, 'tdict')["%d_%s" % (thread.get_ident(), key)]
+            return self._tdict["%d_%s" % (thread.get_ident(), key)]
         except KeyError:
-            raise AttributeError(key)
+            if self._raiseerror:
+                raise AttributeError(key)
+            else:
+                return None
     def __setattr__(self, key, value):
-        object.__getattribute__(self, 'tdict')["%d_%s" % (thread.get_ident(), key)] = value
+        self._tdict["%d_%s" % (thread.get_ident(), key)] = value
 
 class HashSet(object):
     """implements a Set."""
index 167bef9eafd0f2924456559394244e4434cd05a5..cee36d0f0f69e0348af138c92c8731cba6520cf9 100644 (file)
@@ -3,8 +3,12 @@ import testbase
 
 testbase.echo = False
 
-
-modules_to_test = ('attributes', 'historyarray', 'pool', 'engines', 'query', 'types', 'mapper', 'objectstore')
+def suite():
+    modules_to_test = ('attributes', 'historyarray', 'pool', 'engines', 'query', 'types', 'mapper', 'objectstore')
+    alltests = unittest.TestSuite()
+    for module in map(__import__, modules_to_test):
+        alltests.addTest(unittest.findTestCases(module))
+    return alltests
 
 if __name__ == '__main__':
     testbase.runTests(suite())
index fc3601879aec005daf4ab2dd9683535dc0cb8d76..001225f9f340d508909cf48f1596cc8717fd27b6 100644 (file)
@@ -6,68 +6,23 @@ import sqlalchemy.objectstore as objectstore
 
 
 from tables import *
-
-db.echo = testbase.echo
-
-itemkeywords.delete().execute()
-keywords.delete().execute()
-orderitems.delete().execute()
-orders.delete().execute()
-addresses.delete().execute()
-users.delete().execute()
-users.insert().execute(
-    dict(user_id = 7, user_name = 'jack'),
-    dict(user_id = 8, user_name = 'ed'),
-    dict(user_id = 9, user_name = 'fred')
-)
-addresses.insert().execute(
-    dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
-    dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
-    dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
-)
-orders.insert().execute(
-    dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),
-    dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0),
-    dict(order_id = 3, user_id = 7, description = 'order 3', isopen=1),
-    dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1),
-    dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0)
-)
-orderitems.insert().execute(
-    dict(item_id=1, order_id=2, item_name='item 1'),
-    dict(item_id=3, order_id=3, item_name='item 3'),
-    dict(item_id=2, order_id=2, item_name='item 2'),
-    dict(item_id=5, order_id=3, item_name='item 5'),
-    dict(item_id=4, order_id=3, item_name='item 4')
-)
-keywords.insert().execute(
-    dict(keyword_id=1, name='blue'),
-    dict(keyword_id=2, name='red'),
-    dict(keyword_id=3, name='green'),
-    dict(keyword_id=4, name='big'),
-    dict(keyword_id=5, name='small'),
-    dict(keyword_id=6, name='round'),
-    dict(keyword_id=7, name='square')
-)
-itemkeywords.insert().execute(
-    dict(keyword_id=2, item_id=1),
-    dict(keyword_id=2, item_id=2),
-    dict(keyword_id=4, item_id=1),
-    dict(keyword_id=6, item_id=1),
-    dict(keyword_id=7, item_id=2),
-    dict(keyword_id=6, item_id=3),
-    dict(keyword_id=3, item_id=3),
-    dict(keyword_id=5, item_id=2),
-    dict(keyword_id=4, item_id=3)
-)
-
-db.connection().commit()
-
-
-class MapperTest(AssertMixin):
-    
+import tables
+
+class MapperSuperTest(AssertMixin):
+    def setUpAll(self):
+        db.echo = False
+        tables.create()
+        tables.data()
+        db.echo = testbase.echo
+    def tearDownAll(self):
+        db.echo = False
+        tables.drop()
+        db.echo = testbase.echo
     def setUp(self):
         objectstore.clear()
 
+    
+class MapperTest(MapperSuperTest):
     def testget(self):
         m = mapper(User, users)
         self.assert_(m.get(19) is None)
@@ -116,10 +71,7 @@ class MapperTest(AssertMixin):
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
-class PropertyTest(AssertMixin):
-    def setUp(self):
-        objectstore.clear()
-
+class PropertyTest(MapperSuperTest):
     def testbasic(self):
         """tests that you can create mappers inline with class definitions"""
         class _Address(object):
@@ -154,9 +106,7 @@ class PropertyTest(AssertMixin):
         
         self.echo(repr(AddressUser.mapper.select(AddressUser.c.user_name == 'jack')))
             
-class LazyTest(AssertMixin):
-    def setUp(self):
-        objectstore.clear()
+class LazyTest(MapperSuperTest):
 
     def testbasic(self):
         """tests a basic one-to-many lazy load"""
@@ -206,11 +156,7 @@ class LazyTest(AssertMixin):
             {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
         )
 
-class EagerTest(AssertMixin):
-    
-    def setUp(self):
-        objectstore.clear()
-
+class EagerTest(MapperSuperTest):
     def testbasic(self):
         """tests a basic one-to-many eager load"""
         
@@ -365,16 +311,16 @@ class EagerTest(AssertMixin):
         l = m.select(order_by=[items.c.item_id, keywords.c.keyword_id])
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])},
-            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])},
+            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])},
+            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 6,'name':'round'}, {'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}])},
             {'item_id' : 4, 'keywords' : (Keyword, [])},
             {'item_id' : 5, 'keywords' : (Keyword, [])}
         )
         
-        l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id), order_by=[items.c.item_id, keywords.c.keyword_id])
+        l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
+            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])},
         )
     
     def testoneandmany(self):
@@ -396,12 +342,12 @@ class EagerTest(AssertMixin):
                 {'item_id':2, 'item_name':'item 2','keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])}
                ])},
             {'order_id' : 3, 'items': (Item, [
-                {'item_id':3, 'item_name':'item 3'}, 
+                {'item_id':3, 'item_name':'item 3', 'keywords' : (Keyword, [{'keyword_id' : 6, 'name':'round'}, {'keyword_id' : 3, 'name':'green'}, {'keyword_id' : 4, 'name':'big'}])}, 
                 {'item_id':4, 'item_name':'item 4'}, 
                 {'item_id':5, 'item_name':'item 5'}
                ])},
         )
         
         
-if __name__ == "__main__":
-    testbase.runTests()
+if __name__ == "__main__":    
+    testbase.main()
index a5e8bdbeec7b5b502c9209a13abb5dcc6d730fb8..1dee54236b46002784b2fdb8a2e6ffd89198ef8d 100644 (file)
@@ -9,9 +9,7 @@ import testbase
 __ALL__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'itemkeywords']
 
 ECHO = testbase.echo
-DATA = True
-#CREATE = False
-CREATE = True
+
 DBTYPE = 'sqlite_memory'
 #DBTYPE = 'postgres'
 #DBTYPE = 'sqlite_file'
@@ -62,15 +60,76 @@ itemkeywords = Table('itemkeywords', db,
     Column('keyword_id', INT, ForeignKey("keywords"))
 )
 
-if CREATE:
+def create():
     users.create()
     addresses.create()
     orders.create()
     orderitems.create()
     keywords.create()
     itemkeywords.create()
-
-
+    
+def drop():
+    itemkeywords.drop()
+    keywords.drop()
+    orderitems.drop()
+    orders.drop()
+    addresses.drop()
+    users.drop()
+
+def data():
+    itemkeywords.delete().execute()
+    keywords.delete().execute()
+    orderitems.delete().execute()
+    orders.delete().execute()
+    addresses.delete().execute()
+    users.delete().execute()
+    users.insert().execute(
+        dict(user_id = 7, user_name = 'jack'),
+        dict(user_id = 8, user_name = 'ed'),
+        dict(user_id = 9, user_name = 'fred')
+    )
+    addresses.insert().execute(
+        dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
+        dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
+        dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
+    )
+    orders.insert().execute(
+        dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),
+        dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0),
+        dict(order_id = 3, user_id = 7, description = 'order 3', isopen=1),
+        dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1),
+        dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0)
+    )
+    orderitems.insert().execute(
+        dict(item_id=1, order_id=2, item_name='item 1'),
+        dict(item_id=3, order_id=3, item_name='item 3'),
+        dict(item_id=2, order_id=2, item_name='item 2'),
+        dict(item_id=5, order_id=3, item_name='item 5'),
+        dict(item_id=4, order_id=3, item_name='item 4')
+    )
+    keywords.insert().execute(
+        dict(keyword_id=1, name='blue'),
+        dict(keyword_id=2, name='red'),
+        dict(keyword_id=3, name='green'),
+        dict(keyword_id=4, name='big'),
+        dict(keyword_id=5, name='small'),
+        dict(keyword_id=6, name='round'),
+        dict(keyword_id=7, name='square')
+    )
+    itemkeywords.insert().execute(
+        dict(keyword_id=2, item_id=1),
+        dict(keyword_id=2, item_id=2),
+        dict(keyword_id=4, item_id=1),
+        dict(keyword_id=6, item_id=1),
+        dict(keyword_id=7, item_id=2),
+        dict(keyword_id=6, item_id=3),
+        dict(keyword_id=3, item_id=3),
+        dict(keyword_id=5, item_id=2),
+        dict(keyword_id=4, item_id=3)
+    )
+
+    db.commit()
+    
 class User(object):
     def __init__(self):
         self.user_id = None
index c48c48ef243fff0608a7848eea18050c6fd22e26..2ddea27e882f5f1baadeea24aa392e2418a900c0 100644 (file)
@@ -1,7 +1,7 @@
 import unittest
 import StringIO
 import sqlalchemy.engine as engine
-import re
+import re, sys
 import sqlalchemy.databases.postgres as postgres
 
 echo = True
@@ -12,7 +12,10 @@ class PersistTest(unittest.TestCase):
     def echo(self, text):
         if echo:
             print text
-
+    def setUpAll(self):
+        pass
+    def tearDownAll(self):
+        pass
 
 class AssertMixin(PersistTest):
     def assert_result(self, result, class_, *objects):
@@ -70,19 +73,50 @@ class EngineAssert(object):
 
             self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         return self.realexec(statement, parameters, **kwargs)
-        
-def runTests(*modules):
-    for m in modules:
-        if m.__dict__.has_key('startUp'):
-            m.startUp()
-        s = suite(m)
-        runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
-        runner.run(s)
-        if m.__dict__.has_key('tearDown'):
-            m.tearDown()
+
+
+class TTestSuite(unittest.TestSuite):
+        def __init__(self, tests=()):
+            if len(tests) >0 and isinstance(tests[0], PersistTest):
+                self._initTest = tests[0]
+            else:
+                self._initTest = None
+            unittest.TestSuite.__init__(self, tests)
+
+        def run(self, result):
+            try:
+                if self._initTest is not None:
+                    self._initTest.setUpAll()
+            except:
+                result.addError(self._initTest, self.__exc_info())
+                pass
+            try:
+                return unittest.TestSuite.run(self, result)
+            finally:
+                try:
+                    if self._initTest is not None:
+                        self._initTest.tearDownAll()
+                except:
+                    result.addError(self._initTest, self.__exc_info())
+                    pass
+
+        def __exc_info(self):
+            """Return a version of sys.exc_info() with the traceback frame
+               minimised; usually the top level of the traceback frame is not
+               needed.
+               ripped off out of unittest module since its double __
+            """
+            exctype, excvalue, tb = sys.exc_info()
+            if sys.platform[:4] == 'java': ## tracebacks look different in Jython
+                return (exctype, excvalue, tb)
+            return (exctype, excvalue, tb)
+
+
+unittest.TestLoader.suiteClass = TTestSuite
+                    
+def runTests(suite):
+    runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
+    runner.run(suite)
     
-def suite(modules):
-    alltests = unittest.TestSuite()
-    for module in map(__import__, modules):
-        alltests.addTest(unittest.findTestCases(module))
-    return alltests
+def main():
+    unittest.main()
\ No newline at end of file