]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Support the same set of loop functions for async mode
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 20:49:00 +0000 (21:49 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 20:49:00 +0000 (21:49 +0100)
jinja2/asyncsupport.py
jinja2/compiler.py
jinja2/environment.py
jinja2/runtime.py
tests/test_async.py

index 72bbc510bd7966639318cf07d9e3f73d81d432ff..f109beea3037ac82faadb4d1725b89dd16f09e38 100644 (file)
@@ -153,7 +153,7 @@ async def auto_await(value):
     return value
 
 
-async def auto_iter(iterable):
+async def auto_aiter(iterable):
     if hasattr(iterable, '__aiter__'):
         async for item in iterable:
             yield item
@@ -164,10 +164,19 @@ async def auto_iter(iterable):
 
 class AsyncLoopContext(LoopContextBase):
 
-    def __init__(self, async_iterator, iterable, after, recurse=None, depth0=0):
+    def __init__(self, async_iterator, after, length, recurse=None,
+                 depth0=0):
+        LoopContextBase.__init__(self, recurse, depth0)
         self._async_iterator = async_iterator
-        LoopContextBase.__init__(self, iterable, recurse, depth0)
         self._after = after
+        self._length = length
+
+    @property
+    def length(self):
+        if self._length is None:
+            raise TypeError('Loop length for some iterators cannot be '
+                            'lazily calculated in async mode')
+        return self._length
 
     def __aiter__(self):
         return AsyncLoopContextIterator(self)
@@ -196,9 +205,25 @@ class AsyncLoopContextIterator(object):
 
 
 async def make_async_loop_context(iterable, recurse=None, depth0=0):
-    async_iterator = auto_iter(iterable)
+    # Length is more complicated and less efficient in async mode.  The
+    # reason for this is that we cannot know if length will be used
+    # upfront but because length is a property we cannot lazily execute it
+    # later.  This means that we need to buffer it up and measure :(
+    #
+    # We however only do this for actual iterators, not for async
+    # iterators as blocking here does not seem like the best idea in the
+    # world.
+    try:
+        length = len(iterable)
+    except (TypeError, AttributeError):
+        if not hasattr(iterable, '__aiter__'):
+            iterable = tuple(iterable)
+            length = len(iterable)
+        else:
+            length = None
+    async_iterator = auto_aiter(iterable)
     try:
         after = await async_iterator.__anext__()
     except StopAsyncIteration:
         after = _last_iteration
-    return AsyncLoopContext(async_iterator, iterable, after, recurse, depth0)
+    return AsyncLoopContext(async_iterator, after, length, recurse, depth0)
index ff60b9c75979adcd0da180a44d16595de9e2cbe6..40f145e0430da94e6fffc6e685cf3270167e3be2 100644 (file)
@@ -784,7 +784,7 @@ class CodeGenerator(NodeVisitor):
 
         if self.environment._async:
             self.writeline('from jinja2.asyncsupport import auto_await, '
-                           'auto_iter, make_async_loop_context')
+                           'auto_aiter, make_async_loop_context')
 
         # if we want a deferred initialization we cannot move the
         # environment into a local name
@@ -1144,19 +1144,19 @@ class CodeGenerator(NodeVisitor):
         # if we have an extened loop and a node test, we filter in the
         # "outer frame".
         if extended_loop and node.test is not None:
-            if self.environment._async:
-                self.fail('loop filters in async mode are unavailable if the '
-                          'loop uses the special "loop" variable or is '
-                          'recursive.', node.lineno)
             self.write('(')
             self.visit(node.target, loop_frame)
-            self.write(' for ')
+            self.write(self.environment._async and ' async for ' or ' for ')
             self.visit(node.target, loop_frame)
             self.write(' in ')
             if node.recursive:
                 self.write('reciter')
             else:
+                if self.environment._async:
+                    self.write('auto_aiter(')
                 self.visit(node.iter, loop_frame)
+                if self.environment._async:
+                    self.write(')')
             self.write(' if (')
             test_frame = loop_frame.copy()
             self.visit(node.test, test_frame)
@@ -1166,7 +1166,7 @@ class CodeGenerator(NodeVisitor):
             self.write('reciter')
         else:
             if self.environment._async and not extended_loop:
-                self.write('auto_iter(')
+                self.write('auto_aiter(')
             self.visit(node.iter, loop_frame)
             if self.environment._async and not extended_loop:
                 self.write(')')
@@ -1208,9 +1208,11 @@ class CodeGenerator(NodeVisitor):
             self.return_buffer_contents(loop_frame)
             self.outdent()
             self.start_write(frame, node)
+            if self.environment._async:
+                self.write('await ')
             self.write('loop(')
             if self.environment._async:
-                self.write('auto_iter(')
+                self.write('auto_aiter(')
             self.visit(node.iter, frame)
             if self.environment._async:
                 self.write(')')
index efbed6a4a0af0f4bd3c313c1b9e0504093b755df..afe7686d5ca2286e70645c0db887be973a2aadd0 100644 (file)
@@ -992,6 +992,14 @@ class Template(object):
         return self.environment.handle_exception(exc_info, True)
 
     def render_async(self, *args, **kwargs):
+        """This works similar to :meth:`render` but returns a coroutine
+        that when awaited returns the entire rendered template string.  This
+        requires the async feature to be enabled.
+
+        Example usage::
+
+            await template.render_async(knights='that say nih; asynchronously')
+        """
         # see asyncsupport for the actual implementation
         raise NotImplementedError('This feature is not available for this '
                                   'version of Python')
@@ -1021,6 +1029,9 @@ class Template(object):
         yield self.environment.handle_exception(exc_info, True)
 
     def generate_async(self, *args, **kwargs):
+        """An async version of :meth:`generate`.  Works very similarly but
+        returns an async iterator instead.
+        """
         # see asyncsupport for the actual implementation
         raise NotImplementedError('This feature is not available for this '
                                   'version of Python')
@@ -1046,6 +1057,11 @@ class Template(object):
         return TemplateModule(self, self.new_context(vars, shared, locals))
 
     def make_module_async(self, vars=None, shared=False, locals=None):
+        """As template module creation can invoke template code for
+        asynchronous exections this method must be used instead of the
+        normal :meth:`make_module` one.  Likewise the module attribute
+        becomes unavailable in async mode.
+        """
         # see asyncsupport for the actual implementation
         raise NotImplementedError('This feature is not available for this '
                                   'version of Python')
@@ -1068,6 +1084,8 @@ class Template(object):
         '23'
         >>> t.module.foo() == u'42'
         True
+
+        This attribute is not available if async mode is enabled.
         """
         return self._get_default_module()
 
index e0df9b71952b423c0fbd05d19b25721c1285caf7..622a91b26cdde908242fa7c30cc18a25e5c5b9d3 100644 (file)
@@ -286,20 +286,11 @@ class LoopContextBase(object):
     _after = _last_iteration
     _length = None
 
-    def __init__(self, iterable, recurse=None, depth0=0):
+    def __init__(self, recurse=None, depth0=0):
         self._recurse = recurse
         self.index0 = -1
         self.depth0 = depth0
 
-        # try to get the length of the iterable early.  This must be done
-        # here because there are some broken iterators around where there
-        # __len__ is the number of iterations left (i'm looking at your
-        # listreverseiterator!).
-        try:
-            self._length = len(iterable)
-        except (TypeError, AttributeError):
-            self._length = None
-
     def cycle(self, *args):
         """Cycles among the arguments with the current loop index."""
         if not args:
@@ -328,19 +319,6 @@ class LoopContextBase(object):
     __call__ = loop
     del loop
 
-    @property
-    def length(self):
-        if self._length is None:
-            # if was not possible to get the length of the iterator when
-            # the loop context was created (ie: iterating over a generator)
-            # we have to convert the iterable into a sequence and use the
-            # length of that + the number of iterations so far.
-            iterable = tuple(self._iterator)
-            self._iterator = iter(iterable)
-            iterations_done = self.index0 + 2
-            self._length = len(iterable) + iterations_done
-        return self._length
-
     def __repr__(self):
         return '<%s %r/%r>' % (
             self.__class__.__name__,
@@ -352,10 +330,32 @@ class LoopContextBase(object):
 class LoopContext(LoopContextBase):
 
     def __init__(self, iterable, recurse=None, depth0=0):
+        LoopContextBase.__init__(self, recurse, depth0)
         self._iterator = iter(iterable)
-        LoopContextBase.__init__(self, iterable, recurse, depth0)
+
+        # try to get the length of the iterable early.  This must be done
+        # here because there are some broken iterators around where there
+        # __len__ is the number of iterations left (i'm looking at your
+        # listreverseiterator!).
+        try:
+            self._length = len(iterable)
+        except (TypeError, AttributeError):
+            self._length = None
         self._after = self._safe_next()
 
+    @property
+    def length(self):
+        if self._length is None:
+            # if was not possible to get the length of the iterator when
+            # the loop context was created (ie: iterating over a generator)
+            # we have to convert the iterable into a sequence and use the
+            # length of that + the number of iterations so far.
+            iterable = tuple(self._iterator)
+            self._iterator = iter(iterable)
+            iterations_done = self.index0 + 2
+            self._length = len(iterable) + iterations_done
+        return self._length
+
     def __iter__(self):
         return LoopContextIterator(self)
 
index fff732befa2faba010e11f4e37f141f1356bc0bb..94b7ac0563147cbc5127a492d7e48c7ac057c32f 100644 (file)
@@ -3,7 +3,8 @@ import asyncio
 
 from jinja2 import Template, Environment, DictLoader
 from jinja2.utils import have_async_gen
-from jinja2.exceptions import TemplateNotFound, TemplatesNotFound
+from jinja2.exceptions import TemplateNotFound, TemplatesNotFound, \
+     UndefinedError
 
 
 def run(coro):
@@ -252,3 +253,177 @@ class TestAsyncIncludes(object):
             {{ outer("FOO") }}
         """)
         assert t.render().strip() == '(FOO)'
+
+
+@pytest.mark.skipif(not have_async_gen, reason='No async generators')
+@pytest.mark.core_tags
+@pytest.mark.for_loop
+class TestAsyncForLoop(object):
+
+    def test_simple(self, test_env_async):
+        tmpl = test_env_async.from_string('{% for item in seq %}{{ item }}{% endfor %}')
+        assert tmpl.render(seq=list(range(10))) == '0123456789'
+
+    def test_else(self, test_env_async):
+        tmpl = test_env_async.from_string(
+            '{% for item in seq %}XXX{% else %}...{% endfor %}')
+        assert tmpl.render() == '...'
+
+    def test_empty_blocks(self, test_env_async):
+        tmpl = test_env_async.from_string('<{% for item in seq %}{% else %}{% endfor %}>')
+        assert tmpl.render() == '<>'
+
+    def test_context_vars(self, test_env_async):
+        slist = [42, 24]
+        for seq in [slist, iter(slist), reversed(slist), (_ for _ in slist)]:
+            tmpl = test_env_async.from_string('''{% for item in seq -%}
+            {{ loop.index }}|{{ loop.index0 }}|{{ loop.revindex }}|{{
+                loop.revindex0 }}|{{ loop.first }}|{{ loop.last }}|{{
+               loop.length }}###{% endfor %}''')
+            one, two, _ = tmpl.render(seq=seq).split('###')
+            (one_index, one_index0, one_revindex, one_revindex0, one_first,
+             one_last, one_length) = one.split('|')
+            (two_index, two_index0, two_revindex, two_revindex0, two_first,
+             two_last, two_length) = two.split('|')
+
+            assert int(one_index) == 1 and int(two_index) == 2
+            assert int(one_index0) == 0 and int(two_index0) == 1
+            assert int(one_revindex) == 2 and int(two_revindex) == 1
+            assert int(one_revindex0) == 1 and int(two_revindex0) == 0
+            assert one_first == 'True' and two_first == 'False'
+            assert one_last == 'False' and two_last == 'True'
+            assert one_length == two_length == '2'
+
+    def test_cycling(self, test_env_async):
+        tmpl = test_env_async.from_string('''{% for item in seq %}{{
+            loop.cycle('<1>', '<2>') }}{% endfor %}{%
+            for item in seq %}{{ loop.cycle(*through) }}{% endfor %}''')
+        output = tmpl.render(seq=list(range(4)), through=('<1>', '<2>'))
+        assert output == '<1><2>' * 4
+
+    def test_scope(self, test_env_async):
+        tmpl = test_env_async.from_string('{% for item in seq %}{% endfor %}{{ item }}')
+        output = tmpl.render(seq=list(range(10)))
+        assert not output
+
+    def test_varlen(self, test_env_async):
+        def inner():
+            for item in range(5):
+                yield item
+        tmpl = test_env_async.from_string('{% for item in iter %}{{ item }}{% endfor %}')
+        output = tmpl.render(iter=inner())
+        assert output == '01234'
+
+    def test_noniter(self, test_env_async):
+        tmpl = test_env_async.from_string('{% for item in none %}...{% endfor %}')
+        pytest.raises(TypeError, tmpl.render)
+
+    def test_recursive(self, test_env_async):
+        tmpl = test_env_async.from_string('''{% for item in seq recursive -%}
+            [{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}]
+        {%- endfor %}''')
+        assert tmpl.render(seq=[
+            dict(a=1, b=[dict(a=1), dict(a=2)]),
+            dict(a=2, b=[dict(a=1), dict(a=2)]),
+            dict(a=3, b=[dict(a='a')])
+        ]) == '[1<[1][2]>][2<[1][2]>][3<[a]>]'
+
+    def test_recursive_depth0(self, test_env_async):
+        tmpl = test_env_async.from_string('''{% for item in seq recursive -%}
+            [{{ loop.depth0 }}:{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}]
+        {%- endfor %}''')
+        assert tmpl.render(seq=[
+            dict(a=1, b=[dict(a=1), dict(a=2)]),
+            dict(a=2, b=[dict(a=1), dict(a=2)]),
+            dict(a=3, b=[dict(a='a')])
+        ]) == '[0:1<[1:1][1:2]>][0:2<[1:1][1:2]>][0:3<[1:a]>]'
+
+    def test_recursive_depth(self, test_env_async):
+        tmpl = test_env_async.from_string('''{% for item in seq recursive -%}
+            [{{ loop.depth }}:{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}]
+        {%- endfor %}''')
+        assert tmpl.render(seq=[
+            dict(a=1, b=[dict(a=1), dict(a=2)]),
+            dict(a=2, b=[dict(a=1), dict(a=2)]),
+            dict(a=3, b=[dict(a='a')])
+        ]) == '[1:1<[2:1][2:2]>][1:2<[2:1][2:2]>][1:3<[2:a]>]'
+
+    def test_looploop(self, test_env_async):
+        tmpl = test_env_async.from_string('''{% for row in table %}
+            {%- set rowloop = loop -%}
+            {% for cell in row -%}
+                [{{ rowloop.index }}|{{ loop.index }}]
+            {%- endfor %}
+        {%- endfor %}''')
+        assert tmpl.render(table=['ab', 'cd']) == '[1|1][1|2][2|1][2|2]'
+
+    def test_reversed_bug(self, test_env_async):
+        tmpl = test_env_async.from_string('{% for i in items %}{{ i }}'
+                               '{% if not loop.last %}'
+                               ',{% endif %}{% endfor %}')
+        assert tmpl.render(items=reversed([3, 2, 1])) == '1,2,3'
+
+    def test_loop_errors(self, test_env_async):
+        tmpl = test_env_async.from_string('''{% for item in [1] if loop.index
+                                      == 0 %}...{% endfor %}''')
+        pytest.raises(UndefinedError, tmpl.render)
+        tmpl = test_env_async.from_string('''{% for item in [] %}...{% else
+            %}{{ loop }}{% endfor %}''')
+        assert tmpl.render() == ''
+
+    def test_loop_filter(self, test_env_async):
+        tmpl = test_env_async.from_string('{% for item in range(10) if item '
+                               'is even %}[{{ item }}]{% endfor %}')
+        assert tmpl.render() == '[0][2][4][6][8]'
+        tmpl = test_env_async.from_string('''
+            {%- for item in range(10) if item is even %}[{{
+                loop.index }}:{{ item }}]{% endfor %}''')
+        assert tmpl.render() == '[1:0][2:2][3:4][4:6][5:8]'
+
+    def test_scoped_special_var(self, test_env_async):
+        t = test_env_async.from_string(
+            '{% for s in seq %}[{{ loop.first }}{% for c in s %}'
+            '|{{ loop.first }}{% endfor %}]{% endfor %}')
+        assert t.render(seq=('ab', 'cd')) \
+            == '[True|True|False][False|True|False]'
+
+    def test_scoped_loop_var(self, test_env_async):
+        t = test_env_async.from_string('{% for x in seq %}{{ loop.first }}'
+                            '{% for y in seq %}{% endfor %}{% endfor %}')
+        assert t.render(seq='ab') == 'TrueFalse'
+        t = test_env_async.from_string('{% for x in seq %}{% for y in seq %}'
+                            '{{ loop.first }}{% endfor %}{% endfor %}')
+        assert t.render(seq='ab') == 'TrueFalseTrueFalse'
+
+    def test_recursive_empty_loop_iter(self, test_env_async):
+        t = test_env_async.from_string('''
+        {%- for item in foo recursive -%}{%- endfor -%}
+        ''')
+        assert t.render(dict(foo=[])) == ''
+
+    def test_call_in_loop(self, test_env_async):
+        t = test_env_async.from_string('''
+        {%- macro do_something() -%}
+            [{{ caller() }}]
+        {%- endmacro %}
+
+        {%- for i in [1, 2, 3] %}
+            {%- call do_something() -%}
+                {{ i }}
+            {%- endcall %}
+        {%- endfor -%}
+        ''')
+        assert t.render() == '[1][2][3]'
+
+    def test_scoping_bug(self, test_env_async):
+        t = test_env_async.from_string('''
+        {%- for item in foo %}...{{ item }}...{% endfor %}
+        {%- macro item(a) %}...{{ a }}...{% endmacro %}
+        {{- item(2) -}}
+        ''')
+        assert t.render(foo=(1,)) == '...1......2...'
+
+    def test_unpacking(self, test_env_async):
+        tmpl = test_env_async.from_string('{% for a, b, c in [[1, 2, 3]] %}'
+                               '{{ a }}|{{ b }}|{{ c }}{% endfor %}')
+        assert tmpl.render() == '1|2|3'