From d1b8f08470b2f1cbfdf53932893b6e1f1b9d519b Mon Sep 17 00:00:00 2001 From: Armin Ronacher Date: Wed, 28 Dec 2016 13:18:20 +0100 Subject: [PATCH] Automatically await on function calls if necessary --- jinja2/asyncsupport.py | 7 +++++++ jinja2/compiler.py | 7 +++++++ tests/test_async.py | 21 +++++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/jinja2/asyncsupport.py b/jinja2/asyncsupport.py index eaa6ea94..534fb80f 100644 --- a/jinja2/asyncsupport.py +++ b/jinja2/asyncsupport.py @@ -1,5 +1,6 @@ import sys import asyncio +import inspect from jinja2.utils import concat @@ -41,3 +42,9 @@ def patch_template(): def patch_all(): patch_template() + + +async def auto_await(value): + if inspect.isawaitable(value): + return await value + return value diff --git a/jinja2/compiler.py b/jinja2/compiler.py index 1cc4aa8c..a22904ae 100644 --- a/jinja2/compiler.py +++ b/jinja2/compiler.py @@ -782,6 +782,9 @@ class CodeGenerator(NodeVisitor): if not unoptimize_before_dead_code: self.writeline('dummy = lambda *x: None') + if self.environment._async: + self.writeline('from jinja2.asyncsupport import auto_await') + # if we want a deferred initialization we cannot move the # environment into a local name envenv = not self.defer_init and ', environment=environment' or '' @@ -1625,6 +1628,8 @@ class CodeGenerator(NodeVisitor): self.write(')') def visit_Call(self, node, frame, forward_caller=False): + if self.environment._async: + self.write('await auto_await(') if self.environment.sandboxed: self.write('environment.call(context, ') else: @@ -1633,6 +1638,8 @@ class CodeGenerator(NodeVisitor): extra_kwargs = forward_caller and {'caller': 'caller'} or None self.signature(node, frame, extra_kwargs) self.write(')') + if self.environment._async: + self.write(')') def visit_Keyword(self, node, frame): self.write(node.key + '=') diff --git a/tests/test_async.py b/tests/test_async.py index a96eaa44..c5d75ead 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -19,3 +19,24 @@ def test_basic_async(): rv = run(func) assert rv == '[1][2][3]' + + +@pytest.mark.skipif(not have_async_gen, reason='No async generators') +def test_await_on_calls(): + t = Template('{{ async_func() + normal_func() }}', + enable_async=True) + + async def async_func(): + return 42 + + def normal_func(): + return 23 + + async def func(): + return await t.render_async( + async_func=async_func, + normal_func=normal_func + ) + + rv = run(func) + assert rv == '65' -- 2.47.3