]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Support Python 3.5 async/await native coroutines.
authorBen Darnell <ben@bendarnell.com>
Thu, 11 Jun 2015 04:05:36 +0000 (00:05 -0400)
committerBen Darnell <ben@bendarnell.com>
Thu, 11 Jun 2015 04:13:29 +0000 (00:13 -0400)
Requires changes to be included in 3.5b3.

tornado/concurrent.py
tornado/gen.py
tornado/test/gen_test.py
tornado/testing.py
tornado/web.py

index 479ca022ef399d36883837d92ee7b12f055ab57b..2d6d803d352869588c3a461b38b68c72b461528c 100644 (file)
@@ -26,6 +26,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import functools
 import platform
+import textwrap
 import traceback
 import sys
 
@@ -170,6 +171,14 @@ class Future(object):
 
         self._callbacks = []
 
+    # Implement the Python 3.5 Awaitable protocol if possible
+    # (we can't use return and yield together until py33).
+    if sys.version_info >= (3, 3):
+        exec(textwrap.dedent("""
+        def __await__(self):
+            return (yield self)
+        """))
+
     def cancel(self):
         """Cancel the operation, if possible.
 
index 9145768951dbe035044bf442241222d308b79b75..e0ce1dde67c83eb4f7fce17257cc0ad58eba7063 100644 (file)
@@ -80,8 +80,8 @@ import collections
 import functools
 import itertools
 import sys
+import textwrap
 import types
-import weakref
 
 from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
 from tornado.ioloop import IOLoop
@@ -98,6 +98,17 @@ except ImportError as e:
         singledispatch = None
 
 
+try:
+    from collections.abc import Generator as GeneratorType  # py35+
+except ImportError:
+    from types import GeneratorType
+
+try:
+    from inspect import isawaitable  # py35+
+except ImportError:
+    def isawaitable(x): return False
+
+
 class KeyReuseError(Exception):
     pass
 
@@ -202,6 +213,10 @@ def _make_coroutine_wrapper(func, replace_callback):
     argument, so we cannot simply implement ``@engine`` in terms of
     ``@coroutine``.
     """
+    # On Python 3.5, set the coroutine flag on our generator, to allow it
+    # to be used with 'await'.
+    if hasattr(types, 'coroutine'):
+        func = types.coroutine(func)
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         future = TracebackFuture()
@@ -219,7 +234,7 @@ def _make_coroutine_wrapper(func, replace_callback):
             future.set_exc_info(sys.exc_info())
             return future
         else:
-            if isinstance(result, types.GeneratorType):
+            if isinstance(result, GeneratorType):
                 # Inline the first iteration of Runner.run.  This lets us
                 # avoid the cost of creating a Runner when the coroutine
                 # never actually yields, which in turn allows us to
@@ -1001,6 +1016,16 @@ def _argument_adapter(callback):
             callback(None)
     return wrapper
 
+if sys.version_info >= (3, 3):
+    exec(textwrap.dedent("""
+    @coroutine
+    def _wrap_awaitable(x):
+        return (yield from x)
+    """))
+else:
+    def _wrap_awaitable(x):
+        raise NotImplementedError()
+
 
 def convert_yielded(yielded):
     """Convert a yielded object into a `.Future`.
@@ -1022,6 +1047,8 @@ def convert_yielded(yielded):
         return multi_future(yielded)
     elif is_future(yielded):
         return yielded
+    elif isawaitable(yielded):
+        return _wrap_awaitable(yielded)
     else:
         raise BadYieldError("yielded unknown object %r" % (yielded,))
 
index fdaa0ec804dd56c1efc3a9c46837bf453bb61ea4..7b47f13021cdf7763b3d0c37600d6f77226f3c69 100644 (file)
@@ -26,7 +26,8 @@ try:
 except ImportError:
     futures = None
 
-skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 not available')
+skipBefore33 = unittest.skipIf(sys.version_info < (3, 3), 'PEP 380 (yield from) not available')
+skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available')
 skipNotCPython = unittest.skipIf(platform.python_implementation() != 'CPython',
                                  'Not CPython implementation')
 
@@ -728,6 +729,23 @@ class GenCoroutineTest(AsyncTestCase):
         self.assertEqual(result, 42)
         self.finished = True
 
+    @skipBefore35
+    @gen_test
+    def test_async_await(self):
+        # This test verifies that an async function can await a
+        # yield-based gen.coroutine, and that a gen.coroutine
+        # (the test method itself) can yield an async function.
+        global_namespace = dict(globals(), **locals())
+        local_namespace = {}
+        exec(textwrap.dedent("""
+        async def f():
+            await gen.Task(self.io_loop.add_callback)
+            return 42
+        """), global_namespace, local_namespace)
+        result = yield local_namespace['f']()
+        self.assertEqual(result, 42)
+        self.finished = True
+
     @gen_test
     def test_sync_return_no_value(self):
         @gen.coroutine
@@ -1041,6 +1059,15 @@ class AsyncPrepareErrorHandler(RequestHandler):
         self.finish('ok')
 
 
+class NativeCoroutineHandler(RequestHandler):
+    if sys.version_info > (3, 5):
+        exec(textwrap.dedent("""
+        async def get(self):
+            await gen.Task(IOLoop.current().add_callback)
+            self.write("ok")
+        """))
+
+
 class GenWebTest(AsyncHTTPTestCase):
     def get_app(self):
         return Application([
@@ -1054,6 +1081,7 @@ class GenWebTest(AsyncHTTPTestCase):
             ('/yield_exception', GenYieldExceptionHandler),
             ('/undecorated_coroutine', UndecoratedCoroutinesHandler),
             ('/async_prepare_error', AsyncPrepareErrorHandler),
+            ('/native_coroutine', NativeCoroutineHandler),
         ])
 
     def test_sequence_handler(self):
@@ -1096,6 +1124,12 @@ class GenWebTest(AsyncHTTPTestCase):
         response = self.fetch('/async_prepare_error')
         self.assertEqual(response.code, 403)
 
+    @skipBefore35
+    def test_native_coroutine_handler(self):
+        response = self.fetch('/native_coroutine')
+        self.assertEqual(response.code, 200)
+        self.assertEqual(response.body, b'ok')
+
 
 class WithTimeoutTest(AsyncTestCase):
     @gen_test
index 93f0dbe14196569f467b015789140e88eae4a8fc..6dd1ac247daf441690c848002238c4366bf0c2e5 100644 (file)
@@ -47,6 +47,11 @@ try:
 except ImportError:
     from io import StringIO  # py3
 
+try:
+    from collections.abc import Generator as GeneratorType  # py35+
+except ImportError:
+    from types import GeneratorType
+
 # Tornado's own test suite requires the updated unittest module
 # (either py27+ or unittest2) so tornado.test.util enforces
 # this requirement, but for other users of tornado.testing we want
@@ -118,7 +123,7 @@ class _TestMethodWrapper(object):
 
     def __call__(self, *args, **kwargs):
         result = self.orig_method(*args, **kwargs)
-        if isinstance(result, types.GeneratorType):
+        if isinstance(result, GeneratorType):
             raise TypeError("Generator test methods should be decorated with "
                             "tornado.testing.gen_test")
         elif result is not None:
@@ -485,7 +490,7 @@ def gen_test(func=None, timeout=None):
         @functools.wraps(f)
         def pre_coroutine(self, *args, **kwargs):
             result = f(self, *args, **kwargs)
-            if isinstance(result, types.GeneratorType):
+            if isinstance(result, GeneratorType):
                 self._test_generator = result
             else:
                 self._test_generator = None
index 0a50f79357fb338f4f43088b6abafe79ae2a51cf..802811b3779e06eff6f23e5617a39fda602c7b7e 100644 (file)
@@ -1388,10 +1388,8 @@ class RequestHandler(object):
                 self.check_xsrf_cookie()
 
             result = self.prepare()
-            if is_future(result):
-                result = yield result
             if result is not None:
-                raise TypeError("Expected None, got %r" % result)
+                result = yield result
             if self._prepared_future is not None:
                 # Tell the Application we've finished with prepare()
                 # and are ready for the body to arrive.
@@ -1411,10 +1409,8 @@ class RequestHandler(object):
 
             method = getattr(self, self.request.method.lower())
             result = method(*self.path_args, **self.path_kwargs)
-            if is_future(result):
-                result = yield result
             if result is not None:
-                raise TypeError("Expected None, got %r" % result)
+                result = yield result
             if self._auto_finish and not self._finished:
                 self.finish()
         except Exception as e: