]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Wrap generate to support async mode
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 15:11:09 +0000 (16:11 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 15:11:09 +0000 (16:11 +0100)
jinja2/asyncsupport.py
tests/test_async.py

index ca897a77ac0af3ad0f05f6c462703c229440cc00..dc12be9c94a5bacbabfee6537280972a0b96531e 100644 (file)
@@ -15,6 +15,33 @@ async def concat_async(async_gen):
     return concat(rv)
 
 
+async def generate_async(self, *args, **kwargs):
+    vars = dict(*args, **kwargs)
+    try:
+        async for event in self.root_render_func(self.new_context(vars)):
+            yield event
+    except Exception:
+        exc_info = sys.exc_info()
+    else:
+        return
+    yield self.environment.handle_exception(exc_info, True)
+
+
+def wrap_generate_func(original_generate):
+    def _convert_generator(self, loop, args, kwargs):
+        async_gen = self.generate_async(*args, **kwargs)
+        try:
+            while 1:
+                yield loop.run_until_complete(async_gen.__anext__())
+        except StopAsyncIteration:
+            pass
+    def generate(self, *args, **kwargs):
+        if not self.environment._async:
+            return original_generate(self, *args, **kwargs)
+        return _convert_generator(self, asyncio.get_event_loop(), args, kwargs)
+    return generate
+
+
 async def render_async(self, *args, **kwargs):
     if not self.environment._async:
         raise RuntimeError('The environment was not created with async mode '
@@ -84,6 +111,8 @@ async def make_module_async(self, vars=None, shared=False, locals=None):
 
 def patch_template():
     from jinja2 import Template
+    Template.generate_async = generate_async
+    Template.generate = wrap_generate_func(Template.generate)
     Template.render_async = render_async
     Template.render = wrap_render_func(Template.render)
     Template._get_default_module = wrap_default_module(
index e21f7413bf5a5a0fa40f5a82c9a53a5ccb9b6db1..13af44ddba476d827abc8ccb9e550d37fb4ab329 100644 (file)
@@ -88,6 +88,14 @@ def test_async_blocks():
     assert rv == '<Test><Test>'
 
 
+@pytest.mark.skipif(not have_async_gen, reason='No async generators')
+def test_async_generate():
+    t = Template('{% for x in [1, 2, 3] %}{{ x }}{% endfor %}',
+                 enable_async=True)
+    rv = list(t.generate())
+    assert rv == ['1', '2', '3']
+
+
 @pytest.fixture
 def test_env_async():
     env = Environment(loader=DictLoader(dict(