]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Update gen_test for native coroutines. 1550/head pr1550/r1
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 12 Oct 2015 17:45:15 +0000 (13:45 -0400)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 12 Oct 2015 17:46:17 +0000 (13:46 -0400)
tornado/test/testing_test.py
tornado/testing.py

index ded2569b8875f1def3971d4ce541eb2fa83907c2..e00058ac343128afc1fe5e26d9c5e45d841c97df 100644 (file)
@@ -5,11 +5,11 @@ from __future__ import absolute_import, division, print_function, with_statement
 from tornado import gen, ioloop
 from tornado.log import app_log
 from tornado.testing import AsyncTestCase, gen_test, ExpectLog
-from tornado.test.util import unittest
-
+from tornado.test.util import unittest, skipBefore35, exec_test
 import contextlib
 import os
 import traceback
+import warnings
 
 
 @contextlib.contextmanager
@@ -86,6 +86,26 @@ class AsyncTestCaseWrapperTest(unittest.TestCase):
         self.assertEqual(len(result.errors), 1)
         self.assertIn("should be decorated", result.errors[0][1])
 
+    @skipBefore35
+    def test_undecorated_coroutine(self):
+        namespace = exec_test(globals(), locals(), """
+        class Test(AsyncTestCase):
+            async def test_coro(self):
+                pass
+        """)
+
+        test_class = namespace['Test']
+        test = test_class('test_coro')
+        result = unittest.TestResult()
+
+        # Silence "RuntimeWarning: coroutine 'test_coro' was never awaited".
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            test.run(result)
+
+        self.assertEqual(len(result.errors), 1)
+        self.assertIn("should be decorated", result.errors[0][1])
+
     def test_undecorated_generator_with_skip(self):
         class Test(AsyncTestCase):
             @unittest.skip("don't run this")
@@ -228,5 +248,31 @@ class GenTest(AsyncTestCase):
         test_with_kwargs(self, test='test')
         self.finished = True
 
+    @skipBefore35
+    def test_native_coroutine(self):
+        namespace = exec_test(globals(), locals(), """
+        @gen_test
+        async def test(self):
+            self.finished = True
+        """)
+
+        namespace['test'](self)
+
+    @skipBefore35
+    def test_native_coroutine_timeout(self):
+        # Set a short timeout and exceed it.
+        namespace = exec_test(globals(), locals(), """
+        @gen_test(timeout=0.1)
+        async def test(self):
+            await gen.sleep(1)
+        """)
+
+        try:
+            namespace['test'](self)
+            self.fail("did not get expected exception")
+        except ioloop.TimeoutError:
+            self.finished = True
+
+
 if __name__ == '__main__':
     unittest.main()
index f5e9f153581502ca9cbcd9dddc8ea306291668b6..54d76fe40fecf451287e7c82e4069cb43f1604b0 100644 (file)
@@ -34,6 +34,7 @@ from tornado.log import gen_log, app_log
 from tornado.stack_context import ExceptionStackContext
 from tornado.util import raise_exc_info, basestring_type
 import functools
+import inspect
 import logging
 import os
 import re
@@ -51,6 +52,12 @@ try:
 except ImportError:
     from types import GeneratorType
 
+if sys.version_info >= (3, 5):
+    iscoroutine = inspect.iscoroutine
+    iscoroutinefunction = inspect.iscoroutinefunction
+else:
+    iscoroutine = iscoroutinefunction = lambda f: False
+
 # Tornado's own test suite requires the updated unittest module
 # (either py27+ or unittest2) so tornado.test.util enforces
 # this requirement, but for other users of tornado.testing we want
@@ -123,9 +130,9 @@ class _TestMethodWrapper(object):
 
     def __call__(self, *args, **kwargs):
         result = self.orig_method(*args, **kwargs)
-        if isinstance(result, GeneratorType):
-            raise TypeError("Generator test methods should be decorated with "
-                            "tornado.testing.gen_test")
+        if isinstance(result, GeneratorType) or iscoroutine(result):
+            raise TypeError("Generator and coroutine test methods should be"
+                            " decorated with tornado.testing.gen_test")
         elif result is not None:
             raise ValueError("Return value from test method ignored: %r" %
                              result)
@@ -499,13 +506,16 @@ def gen_test(func=None, timeout=None):
         @functools.wraps(f)
         def pre_coroutine(self, *args, **kwargs):
             result = f(self, *args, **kwargs)
-            if isinstance(result, GeneratorType):
+            if isinstance(result, GeneratorType) or iscoroutine(result):
                 self._test_generator = result
             else:
                 self._test_generator = None
             return result
 
-        coro = gen.coroutine(pre_coroutine)
+        if iscoroutinefunction(f):
+            coro = pre_coroutine
+        else:
+            coro = gen.coroutine(pre_coroutine)
 
         @functools.wraps(coro)
         def post_coroutine(self, *args, **kwargs):
@@ -515,8 +525,8 @@ def gen_test(func=None, timeout=None):
                     timeout=timeout)
             except TimeoutError as e:
                 # run_sync raises an error with an unhelpful traceback.
-                # If we throw it back into the generator the stack trace
-                # will be replaced by the point where the test is stopped.
+                # Throw it back into the generator or coroutine so the stack
+                # trace is replaced by the point where the test is stopped.
                 self._test_generator.throw(e)
                 # In case the test contains an overly broad except clause,
                 # we may get back here.  In this case re-raise the original