]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Added support for async loop context
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 19:06:34 +0000 (20:06 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 19:06:34 +0000 (20:06 +0100)
jinja2/asyncsupport.py
jinja2/compiler.py
jinja2/runtime.py
tests/test_async.py

index 0a19b6ae5d5bc546efd20caaaa15197a7d64cda3..0c4fecf6a4af3d5f60706afa147a260bea795652 100644 (file)
@@ -4,6 +4,7 @@ import inspect
 
 from jinja2.utils import concat, internalcode, concat, Markup
 from jinja2.environment import TemplateModule
+from jinja2.runtime import LoopContextBase, _last_iteration
 
 
 async def concat_async(async_gen):
@@ -144,3 +145,45 @@ async def auto_iter(iterable):
         return
     for item in iterable:
         yield item
+
+
+class AsyncLoopContext(LoopContextBase):
+
+    def __init__(self, async_iterator, iterable, after, recurse=None, depth0=0):
+        self._async_iterator = async_iterator
+        LoopContextBase.__init__(self, iterable, recurse, depth0)
+        self._after = after
+
+    def __aiter__(self):
+        return AsyncLoopContextIterator(self)
+
+
+class AsyncLoopContextIterator(object):
+    __slots__ = ('context',)
+
+    def __init__(self, context):
+        self.context = context
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        ctx = self.context
+        ctx.index0 += 1
+        if ctx._after is _last_iteration:
+            raise StopAsyncIteration()
+        next_elem = ctx._after
+        try:
+            ctx._after = await ctx._async_iterator.__anext__()
+        except StopAsyncIteration:
+            ctx._after = _last_iteration
+        return next_elem, ctx
+
+
+async def make_async_loop_context(iterable, recurse=None, depth0=0):
+    async_iterator = auto_iter(iterable)
+    try:
+        after = await async_iterator.__anext__()
+    except StopAsyncIteration:
+        after = _last_iteration
+    return AsyncLoopContext(async_iterator, iterable, after, recurse, depth0)
index f2714c67af5c743975cc53b88d8fc6b6aa2773f3..667fec52df225154a62ee5b0db8a9a6e88fff372 100644 (file)
@@ -783,7 +783,8 @@ class CodeGenerator(NodeVisitor):
             self.writeline('dummy = lambda *x: None')
 
         if self.environment._async:
-            self.writeline('from jinja2.asyncsupport import auto_await, auto_iter')
+            self.writeline('from jinja2.asyncsupport import auto_await, '
+                           'auto_iter, make_async_loop_context')
 
         # if we want a deferred initialization we cannot move the
         # environment into a local name
@@ -1132,7 +1133,13 @@ class CodeGenerator(NodeVisitor):
 
         self.writeline(self.environment._async and 'async for ' or 'for ', node)
         self.visit(node.target, loop_frame)
-        self.write(extended_loop and ', l_loop in LoopContext(' or ' in ')
+        if extended_loop:
+            if self.environment._async:
+                self.write(', l_loop in await make_async_loop_context(')
+            else:
+                self.write(', l_loop in LoopContext(')
+        else:
+            self.write(' in ')
 
         # if we have an extened loop and a node test, we filter in the
         # "outer frame".
@@ -1158,10 +1165,10 @@ class CodeGenerator(NodeVisitor):
         elif node.recursive:
             self.write('reciter')
         else:
-            if self.environment._async:
+            if self.environment._async and not extended_loop:
                 self.write('auto_iter(')
             self.visit(node.iter, loop_frame)
-            if self.environment._async:
+            if self.environment._async and not extended_loop:
                 self.write(')')
 
         if node.recursive:
index 685a12da068c4808f2dfcec64d68e581f47e26fe..e0df9b71952b423c0fbd05d19b25721c1285caf7 100644 (file)
@@ -280,13 +280,14 @@ class BlockReference(object):
         return rv
 
 
-class LoopContext(object):
+class LoopContextBase(object):
     """A loop context for dynamic iteration."""
 
+    _after = _last_iteration
+    _length = None
+
     def __init__(self, iterable, recurse=None, depth0=0):
-        self._iterator = iter(iterable)
         self._recurse = recurse
-        self._after = self._safe_next()
         self.index0 = -1
         self.depth0 = depth0
 
@@ -315,15 +316,6 @@ class LoopContext(object):
     def __len__(self):
         return self.length
 
-    def __iter__(self):
-        return LoopContextIterator(self)
-
-    def _safe_next(self):
-        try:
-            return next(self._iterator)
-        except StopIteration:
-            return _last_iteration
-
     @internalcode
     def loop(self, iterable):
         if self._recurse is None:
@@ -357,6 +349,23 @@ class LoopContext(object):
         )
 
 
+class LoopContext(LoopContextBase):
+
+    def __init__(self, iterable, recurse=None, depth0=0):
+        self._iterator = iter(iterable)
+        LoopContextBase.__init__(self, iterable, recurse, depth0)
+        self._after = self._safe_next()
+
+    def __iter__(self):
+        return LoopContextIterator(self)
+
+    def _safe_next(self):
+        try:
+            return next(self._iterator)
+        except StopIteration:
+            return _last_iteration
+
+
 @implements_iterator
 class LoopContextIterator(object):
     """The iterator for a loop context."""
index 44463d105a0461bfb9669633540ea3fb134de81e..fff732befa2faba010e11f4e37f141f1356bc0bb 100644 (file)
@@ -97,7 +97,7 @@ def test_async_generate():
 
 
 @pytest.mark.skipif(not have_async_gen, reason='No async generators')
-def test_async_iteration_in_tmeplates():
+def test_async_iteration_in_templates():
     t = Template('{% for x in rng %}{{ x }}{% endfor %}',
                  enable_async=True)
     async def async_iterator():
@@ -107,6 +107,17 @@ def test_async_iteration_in_tmeplates():
     assert rv == ['1', '2', '3']
 
 
+@pytest.mark.skipif(not have_async_gen, reason='No async generators')
+def test_async_iteration_in_templates_extended():
+    t = Template('{% for x in rng %}{{ loop.index0 }}/{{ x }}{% endfor %}',
+                 enable_async=True)
+    async def async_iterator():
+        for item in [1, 2, 3]:
+            yield item
+    rv = list(t.generate(rng=async_iterator()))
+    assert rv == ['0/1', '1/2', '2/3']
+
+
 @pytest.fixture
 def test_env_async():
     env = Environment(loader=DictLoader(dict(