]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added testing.fails_on('db') failure-asserter.
authorJason Kirtland <jek@discorporate.us>
Mon, 5 Nov 2007 22:02:00 +0000 (22:02 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 5 Nov 2007 22:02:00 +0000 (22:02 +0000)
test/testlib/testing.py

index 941ddf2497b3553a86b632b6737f76cf5f49b250..6342ce898de2699daa28172f4b7eb06c89fec302 100644 (file)
@@ -23,7 +23,7 @@ _ops = { '<': operator.lt,
 
 def unsupported(*dbs):
     """Mark a test as unsupported by one or more database implementations"""
-    
+
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
@@ -40,9 +40,42 @@ def unsupported(*dbs):
         return maybe
     return decorate
 
+def fails_on(*dbs):
+    """Mark a test as expected to fail on one or more database implementations.
+
+    Unlike ``unsupported``, tests marked as ``fails_on`` will be run
+    for the named databases.  The test is expected to fail and the unit test
+    logic is inverted: if the test fails, a success is reported.  If the test
+    succeeds, a failure is reported.
+    """
+
+    def decorate(fn):
+        fn_name = fn.__name__
+        def maybe(*args, **kw):
+            if config.db.name not in dbs:
+                return fn(*args, **kw)
+            else:
+                try:
+                    fn(*args, **kw)
+                except Exception, ex:
+                    print ("'%s' failed as expected on DB implementation "
+                           "'%s': %s" % (
+                        fn_name, config.db.name, str(ex)))
+                    return True
+                else:
+                    raise AssertionError(
+                        "Unexpected success for '%s' on DB implementation '%s'" %
+                        (fn_name, config.db.name))
+        try:
+            maybe.__name__ = fn_name
+        except:
+            pass
+        return maybe
+    return decorate
+
 def supported(*dbs):
     """Mark a test as supported by one or more database implementations"""
-    
+
     def decorate(fn):
         fn_name = fn.__name__
         def maybe(*args, **kw):
@@ -134,12 +167,12 @@ def rowset(results):
 
 class TestData(object):
     """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
-    
+
     def __init__(self):
         self.set_assert_list(None, None)
         self.sql_count = 0
         self.buffer = None
-        
+
     def set_assert_list(self, unittest, list):
         self.unittest = unittest
         self.assert_list = list
@@ -152,7 +185,7 @@ testdata = TestData()
 class ExecutionContextWrapper(object):
     """instruments the ExecutionContext created by the Engine so that SQL expressions
     can be tracked."""
-    
+
     def __init__(self, ctx):
         global sql
         if sql is None:
@@ -163,7 +196,7 @@ class ExecutionContextWrapper(object):
         return getattr(self.ctx, key)
     def __setattr__(self, key, value):
         setattr(self.ctx, key, value)
-        
+
     def post_execution(self):
         ctx = self.ctx
         statement = unicode(ctx.compiled)
@@ -180,7 +213,7 @@ class ExecutionContextWrapper(object):
                 item = testdata.assert_list.pop()
             else:
                 # asserting a dictionary of statements->parameters
-                # this is to specify query assertions where the queries can be in 
+                # this is to specify query assertions where the queries can be in
                 # multiple orderings
                 if '_converted' not in item:
                     for key in item.keys():
@@ -202,14 +235,14 @@ class ExecutionContextWrapper(object):
                 params = params(ctx)
             if params is not None and not isinstance(params, list):
                 params = [params]
-            
+
             parameters = ctx.compiled_parameters
-                    
+
             query = self.convert_statement(query)
             testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         testdata.sql_count += 1
         self.ctx.post_execution()
-        
+
     def convert_statement(self, query):
         paramstyle = self.ctx.dialect.paramstyle
         if paramstyle == 'named':
@@ -247,12 +280,12 @@ class SQLCompileTest(PersistTest):
     def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None):
         if dialect is None:
             dialect = getattr(self, '__dialect__', None)
-        
+
         if params is None:
             keys = None
         else:
             keys = params.keys()
-                
+
         c = clause.compile(column_keys=keys, dialect=dialect)
 
         print "\nSQL String:\n" + str(c) + repr(c.params)
@@ -267,19 +300,19 @@ class SQLCompileTest(PersistTest):
 class AssertMixin(PersistTest):
     """given a list-based structure of keys/properties which represent information within an object structure, and
     a list of actual objects, asserts that the list of objects corresponds to the structure."""
-    
+
     def assert_result(self, result, class_, *objects):
         result = list(result)
         print repr(result)
         self.assert_list(result, class_, objects)
-        
+
     def assert_list(self, result, class_, list):
         self.assert_(len(result) == len(list),
                      "result list is not the same size as test list, " +
                      "for class " + class_.__name__)
         for i in range(0, len(list)):
             self.assert_row(class_, result[i], list[i])
-            
+
     def assert_row(self, class_, rowobj, desc):
         self.assert_(rowobj.__class__ is class_,
                      "item class is not " + repr(class_))
@@ -382,13 +415,13 @@ class ORMTest(AssertMixin):
     keep_mappers = False
     keep_data = False
     metadata = None
-    
+
     def setUpAll(self):
         global MetaData, _otest_metadata
 
         if MetaData is None:
             from sqlalchemy import MetaData
-        
+
         if self.metadata is None:
             _otest_metadata = MetaData(config.db)
         else: