]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Correctly scope loop filters. Fixes #649
authorArmin Ronacher <armin.ronacher@active-4.com>
Mon, 9 Jan 2017 11:23:18 +0000 (12:23 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Mon, 9 Jan 2017 11:23:18 +0000 (12:23 +0100)
CHANGES
jinja2/compiler.py
jinja2/idtracking.py
tests/test_async.py
tests/test_regression.py

diff --git a/CHANGES b/CHANGES
index 8baa1236d0948b5af4f3faeaf4943bf07ec65052..84a698a3a029cbba702c41bfb97d3bf81c18f058 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,6 +8,7 @@ Version 2.9.4
 - Solved some warnings for string literals.  (#646)
 - Increment the bytecode cache version which was not done due to an
   oversight before.
+- Corrected bad code generation and scoping for filtered loops.  (#649)
 
 Version 2.9.3
 -------------
index cdfe38ebfd230191f107b74c0715c211cef0e847..2fde59a1ec6fd583b411e3b38725304bb2fed921 100644 (file)
@@ -993,6 +993,7 @@ class CodeGenerator(NodeVisitor):
 
     def visit_For(self, node, frame):
         loop_frame = frame.inner()
+        test_frame = frame.inner()
         else_frame = frame.inner()
 
         # try to figure out if we have an extended loop.  An extended loop
@@ -1005,11 +1006,32 @@ class CodeGenerator(NodeVisitor):
         loop_ref = None
         if extended_loop:
             loop_ref = loop_frame.symbols.declare_parameter('loop')
-        loop_frame.symbols.analyze_node(node)
 
+        loop_frame.symbols.analyze_node(node, for_branch='body')
         if node.else_:
             else_frame.symbols.analyze_node(node, for_branch='else')
 
+        if node.test:
+            loop_filter_func = self.temporary_identifier()
+            test_frame.symbols.analyze_node(node, for_branch='test')
+            self.writeline('%s(fiter):' % self.func(loop_filter_func), node.test)
+            self.indent()
+            self.enter_frame(test_frame)
+            self.writeline(self.environment.is_async and 'async for ' or 'for ')
+            self.visit(node.target, loop_frame)
+            self.write(' in ')
+            self.write(self.environment.is_async and 'auto_aiter(fiter)' or 'fiter')
+            self.write(':')
+            self.indent()
+            self.writeline('if ', node.test)
+            self.visit(node.test, test_frame)
+            self.write(':')
+            self.indent()
+            self.writeline('yield ')
+            self.visit(node.target, loop_frame)
+            self.outdent(3)
+            self.leave_frame(test_frame, with_python_scope=True)
+
         # if we don't have an recursive loop we have to find the shadowed
         # variables at that point.  Because loops can be nested but the loop
         # variable is a special one we have to enforce aliasing for it.
@@ -1043,27 +1065,9 @@ class CodeGenerator(NodeVisitor):
         else:
             self.write(' in ')
 
-        # if we have an extened loop and a node test, we filter in the
-        # "outer frame".
-        if extended_loop and node.test is not None:
-            self.write('(')
-            self.visit(node.target, loop_frame)
-            self.write(self.environment.is_async and ' async for ' or ' for ')
-            self.visit(node.target, loop_frame)
-            self.write(' in ')
-            if node.recursive:
-                self.write('reciter')
-            else:
-                if self.environment.is_async:
-                    self.write('auto_aiter(')
-                self.visit(node.iter, frame)
-                if self.environment.is_async:
-                    self.write(')')
-            self.write(' if (')
-            self.visit(node.test, loop_frame)
-            self.write('))')
-
-        elif node.recursive:
+        if node.test:
+            self.write('%s(' % loop_filter_func)
+        if node.recursive:
             self.write('reciter')
         else:
             if self.environment.is_async and not extended_loop:
@@ -1071,6 +1075,8 @@ class CodeGenerator(NodeVisitor):
             self.visit(node.iter, frame)
             if self.environment.is_async and not extended_loop:
                 self.write(')')
+        if node.test:
+            self.write(')')
 
         if node.recursive:
             self.write(', loop_render_func, depth):')
@@ -1080,15 +1086,6 @@ class CodeGenerator(NodeVisitor):
         self.indent()
         self.enter_frame(loop_frame)
 
-        # tests in not extended loops become a continue
-        if not extended_loop and node.test is not None:
-            self.writeline('if not ')
-            self.visit(node.test, loop_frame)
-            self.write(':')
-            self.indent()
-            self.writeline('continue')
-            self.outdent()
-
         self.blockvisit(node.body, loop_frame)
         if node.else_:
             self.writeline('%s = 0' % iteration_indicator)
index 433b92c8f453c0400ab920ac815e8347a1253a95..8479b72c2309c6bf8db7cee4c160f10c9a824fdf 100644 (file)
@@ -168,13 +168,16 @@ class RootVisitor(NodeVisitor):
             self.sym_visitor.visit(child)
 
     def visit_For(self, node, for_branch='body', **kwargs):
-        if node.test is not None:
-            self.sym_visitor.visit(node.test)
         if for_branch == 'body':
             self.sym_visitor.visit(node.target, store_as_param=True)
             branch = node.body
         elif for_branch == 'else':
             branch = node.else_
+        elif for_branch == 'test':
+            self.sym_visitor.visit(node.target, store_as_param=True)
+            if node.test is not None:
+                self.sym_visitor.visit(node.test)
+            return
         else:
             raise RuntimeError('Unknown for branch')
         for item in branch or ():
index 88590c7ed9a5fcb788b40c4ff9d0f9a20d654576..bd3ef50aab58c5b910c0cb2c7afc7c796ada4a8a 100644 (file)
@@ -415,3 +415,52 @@ class TestAsyncForLoop(object):
         tmpl = test_env_async.from_string('{% for a, b, c in [[1, 2, 3]] %}'
                                '{{ a }}|{{ b }}|{{ c }}{% endfor %}')
         assert tmpl.render() == '1|2|3'
+
+    def test_recursive_loop_filter(self, test_env_async):
+        t = test_env_async.from_string('''
+        <?xml version="1.0" encoding="UTF-8"?>
+        <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+          {%- for page in [site.root] if page.url != this recursive %}
+          <url><loc>{{ page.url }}</loc></url>
+          {{- loop(page.children) }}
+          {%- endfor %}
+        </urlset>
+        ''')
+        sm  =t.render(this='/foo', site={'root': {
+            'url': '/',
+            'children': [
+                {'url': '/foo'},
+                {'url': '/bar'},
+            ]
+        }})
+        lines = [x.strip() for x in sm.splitlines() if x.strip()]
+        assert lines == [
+            '<?xml version="1.0" encoding="UTF-8"?>',
+            '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
+            '<url><loc>/</loc></url>',
+            '<url><loc>/bar</loc></url>',
+            '</urlset>',
+        ]
+
+    def test_nonrecursive_loop_filter(self, test_env_async):
+        t = test_env_async.from_string('''
+        <?xml version="1.0" encoding="UTF-8"?>
+        <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+          {%- for page in items if page.url != this %}
+          <url><loc>{{ page.url }}</loc></url>
+          {%- endfor %}
+        </urlset>
+        ''')
+        sm  =t.render(this='/foo', items=[
+            {'url': '/'},
+            {'url': '/foo'},
+            {'url': '/bar'},
+        ])
+        lines = [x.strip() for x in sm.splitlines() if x.strip()]
+        assert lines == [
+            '<?xml version="1.0" encoding="UTF-8"?>',
+            '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
+            '<url><loc>/</loc></url>',
+            '<url><loc>/bar</loc></url>',
+            '</urlset>',
+        ]
index 1706c8b51001bde38a39c5908ef0b6035e3e21f4..83e78afbff68f9f7f2e88d4b460877ba612ae31e 100644 (file)
@@ -452,3 +452,30 @@ class TestBug(object):
         t = env.from_string('{% set x = 1 %}{% with x = 2 %}{% block y scoped %}'
                             '{{ x }}{% endblock %}{% endwith %}')
         assert t.render() == '2'
+
+    def test_recursive_loop_filter(self, env):
+        t = env.from_string('''
+        <?xml version="1.0" encoding="UTF-8"?>
+        <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+          {%- for page in [site.root] if page.url != this recursive %}
+          <url><loc>{{ page.url }}</loc></url>
+          {{- loop(page.children) }}
+          {%- endfor %}
+        </urlset>
+        ''')
+        sm  =t.render(this='/foo', site={'root': {
+            'url': '/',
+            'children': [
+                {'url': '/foo'},
+                {'url': '/bar'},
+            ]
+        }})
+        lines = [x.strip() for x in sm.splitlines() if x.strip()]
+        print(lines)
+        assert lines == [
+            '<?xml version="1.0" encoding="UTF-8"?>',
+            '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">',
+            '<url><loc>/</loc></url>',
+            '<url><loc>/bar</loc></url>',
+            '</urlset>',
+        ]