]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Basic async support for blocks
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 13:37:56 +0000 (14:37 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 13:37:56 +0000 (14:37 +0100)
jinja2/asyncsupport.py
jinja2/compiler.py
tests/test_async.py

index 534fb80f0bcc9ef92940531985e0ce3df67de564..e81e83cc05f85a31b8aa57b498acc0d8aebeea0a 100644 (file)
@@ -2,7 +2,16 @@ import sys
 import asyncio
 import inspect
 
-from jinja2.utils import concat
+from jinja2.utils import concat, internalcode, concat, Markup
+
+
+async def concat_async(async_gen):
+    rv = []
+    async def collect():
+        async for event in async_gen:
+            rv.append(event)
+    await collect()
+    return concat(rv)
 
 
 async def render_async(self, *args, **kwargs):
@@ -12,14 +21,9 @@ async def render_async(self, *args, **kwargs):
 
     vars = dict(*args, **kwargs)
     ctx = self.new_context(vars)
-    rv = []
-    async def collect():
-        async for event in self.root_render_func(ctx):
-            rv.append(event)
 
     try:
-        await collect()
-        return concat(rv)
+        return await concat_async(self.root_render_func(ctx))
     except Exception:
         exc_info = sys.exc_info()
     return self.environment.handle_exception(exc_info, True)
@@ -34,14 +38,37 @@ def wrap_render_func(original_render):
     return render
 
 
+def wrap_block_reference_call(original_call):
+    @internalcode
+    async def async_call(self):
+        rv = await concat_async(self._stack[self._depth](self._context))
+        if self._context.eval_ctx.autoescape:
+            rv = Markup(rv)
+        return rv
+
+    @internalcode
+    def __call__(self):
+        if not self._context.environment._async:
+            return original_call(self)
+        return async_call(self)
+
+    return __call__
+
+
 def patch_template():
     from jinja2 import Template
     Template.render_async = render_async
     Template.render = wrap_render_func(Template.render)
 
 
+def patch_runtime():
+    from jinja2.runtime import BlockReference
+    BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__)
+
+
 def patch_all():
     patch_template()
+    patch_runtime()
 
 
 async def auto_await(value):
index a22904aed6df4b521c1a5be54c053482a682717f..09ad42b912aea89e45c6e5215e955db654df5112 100644 (file)
@@ -887,8 +887,10 @@ class CodeGenerator(NodeVisitor):
                 self.indent()
                 level += 1
         context = node.scoped and 'context.derived(locals())' or 'context'
-        self.writeline('for event in context.blocks[%r][0](%s):' % (
-                       node.name, context), node)
+
+        loop = self.environment._async and 'async for' or 'for'
+        self.writeline('%s event in context.blocks[%r][0](%s):' % (
+                       loop, node.name, context), node)
         self.indent()
         self.simple_write('event', frame)
         self.outdent(level)
index 00dfffe13dbbfa2f9017b421af3ac802c2280103..fd88805b7f350c00ade2dcc291d88468501b2949 100644 (file)
@@ -55,3 +55,14 @@ def test_await_and_macros():
 
     rv = run(func)
     assert rv == '[42][42]'
+
+
+@pytest.mark.skipif(not have_async_gen, reason='No async generators')
+def test_async_blocks():
+    t = Template('{% block foo %}<Test>{% endblock %}{{ self.foo() }}',
+                 enable_async=True, autoescape=True)
+    async def func():
+        return await t.render_async()
+
+    rv = run(func)
+    assert rv == '<Test><Test>'