]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- support __only_on__ and __backend__ at the same time
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Jul 2014 22:12:32 +0000 (18:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Jul 2014 22:12:32 +0000 (18:12 -0400)
lib/sqlalchemy/testing/plugin/plugin_base.py
test/dialect/mysql/test_types.py

index 061848e27106de5523af76ab6c867499a748489e..a068272421c39485fc80e04de5db10cbde5920bd 100644 (file)
@@ -318,7 +318,7 @@ def want_class(cls):
 
 def generate_sub_tests(cls, module):
     if getattr(cls, '__backend__', False):
-        for cfg in config.Config.all_configs():
+        for cfg in _possible_configs_for_cls(cls):
             name = "%s_%s_%s" % (cls.__name__, cfg.db.name, cfg.db.driver)
             subcls = type(
                         name,
@@ -370,8 +370,25 @@ def after_test(test):
     engines.testing_reaper._after_test_ctx()
     warnings.resetwarnings()
 
-def _do_skips(cls):
+def _possible_configs_for_cls(cls):
     all_configs = set(config.Config.all_configs())
+    if cls.__unsupported_on__:
+        spec = exclusions.db_spec(*cls.__unsupported_on__)
+        for config_obj in list(all_configs):
+            if spec(config_obj):
+                all_configs.remove(config_obj)
+
+    if getattr(cls, '__only_on__', None):
+        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
+        for config_obj in list(all_configs):
+            if not spec(config_obj):
+                all_configs.remove(config_obj)
+
+
+    return all_configs
+
+def _do_skips(cls):
+    all_configs = _possible_configs_for_cls(cls)
     reasons = []
 
     if hasattr(cls, '__requires__'):
@@ -398,19 +415,6 @@ def _do_skips(cls):
         if all_configs.difference(non_preferred):
             all_configs.difference_update(non_preferred)
 
-    if cls.__unsupported_on__:
-        spec = exclusions.db_spec(*cls.__unsupported_on__)
-        for config_obj in list(all_configs):
-            if spec(config_obj):
-                all_configs.remove(config_obj)
-
-    if getattr(cls, '__only_on__', None):
-        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
-        for config_obj in list(all_configs):
-            if not spec(config_obj):
-                all_configs.remove(config_obj)
-
-
     if getattr(cls, '__skip_if__', False):
         for c in getattr(cls, '__skip_if__'):
             if c():
index f5901812e7e34f3bdb0d5da3781829c9253b0d52..ffb2240bb420bda3e3e703dc9b3a194a86d016d7 100644 (file)
@@ -15,6 +15,7 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     "Test MySQL column types"
 
     __dialect__ = mysql.dialect()
+    __only_on__ = 'mysql'
     __backend__ = True
 
     def test_numeric(self):
@@ -153,7 +154,6 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
                 res
             )
 
-    @testing.only_if('mysql')
     @testing.provide_metadata
     def test_precision_float_roundtrip(self):
         t = Table('t', self.metadata,
@@ -291,7 +291,6 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         ]:
             self.assert_compile(type_, expected)
 
-    @testing.only_if('mysql')
     @testing.exclude('mysql', '<', (5, 0, 5), 'a 5.0+ feature')
     @testing.fails_if(
             lambda: testing.against("mysql+oursql") and util.py3k,
@@ -350,7 +349,6 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         ]:
             self.assert_compile(type_, expected)
 
-    @testing.only_if('mysql')
     @testing.provide_metadata
     def test_boolean_roundtrip(self):
         bool_table = Table(
@@ -448,7 +446,6 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
 
             )
 
-    @testing.only_if('mysql')
     @testing.provide_metadata
     def test_timestamp_nullable(self):
         ts_table = Table('mysql_timestamp', self.metadata,
@@ -515,7 +512,6 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             datetime.time(8, 37, 35, 450)
         )
 
-    @testing.only_if('mysql')
     @testing.provide_metadata
     def test_time_roundtrip(self):
         t = Table('mysql_time', self.metadata,
@@ -525,7 +521,6 @@ class TypesTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         t.insert().values(t1=datetime.time(8, 37, 35)).execute()
         eq_(select([t.c.t1]).scalar(), datetime.time(8, 37, 35))
 
-    @testing.only_if('mysql')
     @testing.provide_metadata
     def test_year(self):
         """Exercise YEAR."""