]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Automatically await on function calls if necessary
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 12:18:20 +0000 (13:18 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 12:18:20 +0000 (13:18 +0100)
jinja2/asyncsupport.py
jinja2/compiler.py
tests/test_async.py

index eaa6ea94942bc1617f86ae3bc6f7c62ac8d21b18..534fb80f0bcc9ef92940531985e0ce3df67de564 100644 (file)
@@ -1,5 +1,6 @@
 import sys
 import asyncio
+import inspect
 
 from jinja2.utils import concat
 
@@ -41,3 +42,9 @@ def patch_template():
 
 def patch_all():
     patch_template()
+
+
+async def auto_await(value):
+    if inspect.isawaitable(value):
+        return await value
+    return value
index 1cc4aa8c1770d55760033138b6852ac7320bd163..a22904aed6df4b521c1a5be54c053482a682717f 100644 (file)
@@ -782,6 +782,9 @@ class CodeGenerator(NodeVisitor):
         if not unoptimize_before_dead_code:
             self.writeline('dummy = lambda *x: None')
 
+        if self.environment._async:
+            self.writeline('from jinja2.asyncsupport import auto_await')
+
         # if we want a deferred initialization we cannot move the
         # environment into a local name
         envenv = not self.defer_init and ', environment=environment' or ''
@@ -1625,6 +1628,8 @@ class CodeGenerator(NodeVisitor):
         self.write(')')
 
     def visit_Call(self, node, frame, forward_caller=False):
+        if self.environment._async:
+            self.write('await auto_await(')
         if self.environment.sandboxed:
             self.write('environment.call(context, ')
         else:
@@ -1633,6 +1638,8 @@ class CodeGenerator(NodeVisitor):
         extra_kwargs = forward_caller and {'caller': 'caller'} or None
         self.signature(node, frame, extra_kwargs)
         self.write(')')
+        if self.environment._async:
+            self.write(')')
 
     def visit_Keyword(self, node, frame):
         self.write(node.key + '=')
index a96eaa442f378e35691cd715ec2e92cba8ffdc38..c5d75ead810e45e31d7bc58b0f162533b72d63c4 100644 (file)
@@ -19,3 +19,24 @@ def test_basic_async():
 
     rv = run(func)
     assert rv == '[1][2][3]'
+
+
+@pytest.mark.skipif(not have_async_gen, reason='No async generators')
+def test_await_on_calls():
+    t = Template('{{ async_func() + normal_func() }}',
+                 enable_async=True)
+
+    async def async_func():
+        return 42
+
+    def normal_func():
+        return 23
+
+    async def func():
+        return await t.render_async(
+            async_func=async_func,
+            normal_func=normal_func
+        )
+
+    rv = run(func)
+    assert rv == '65'