]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Use a separate scope for the loop else branch
authorArmin Ronacher <armin.ronacher@active-4.com>
Tue, 3 Jan 2017 00:50:05 +0000 (01:50 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Tue, 3 Jan 2017 22:45:30 +0000 (23:45 +0100)
jinja2/compiler.py
jinja2/idtracking.py
tests/test_core_tags.py

index a7fb325cb70e5f46438d4cdd7170203abbf545f0..673dae64b797d2178c963374dd0a1ba30e755d67 100644 (file)
@@ -841,10 +841,8 @@ class CodeGenerator(NodeVisitor):
                                'update((%s))' % ', '.join(imap(repr, discarded_names)))
 
     def visit_For(self, node, frame):
-        # TODO: this should really use two frames: one for the loop body
-        # and a separate one for the loop else block.  This also is needed
-        # because the loop variable must not be visible in the else block
         loop_frame = frame.inner()
+        else_frame = frame.inner()
 
         # try to figure out if we have an extended loop.  An extended loop
         # is necessary if the loop is in recursive mode if the special loop
@@ -858,6 +856,9 @@ class CodeGenerator(NodeVisitor):
             loop_ref = loop_frame.symbols.declare_parameter('loop')
         loop_frame.symbols.analyze_node(node)
 
+        if node.else_:
+            else_frame.symbols.analyze_node(node, for_branch='else')
+
         # if we don't have an recursive loop we have to find the shadowed
         # variables at that point.  Because loops can be nested but the loop
         # variable is a special one we have to enforce aliasing for it.
@@ -946,9 +947,10 @@ class CodeGenerator(NodeVisitor):
         if node.else_:
             self.writeline('if %s:' % iteration_indicator)
             self.indent()
-            self.enter_frame(loop_frame)
-            self.blockvisit(node.else_, loop_frame)
-            self.leave_frame(loop_frame)
+            print(else_frame.symbols.__dict__)
+            self.enter_frame(else_frame)
+            self.blockvisit(node.else_, else_frame)
+            self.leave_frame(else_frame)
             self.outdent()
 
         # if the node was recursive we have to return the buffer contents
index be6021119667f4a7fdff031cfb916b06e11a12ed..6cb58109357e84aed618d97c51eac7865d496c76 100644 (file)
@@ -34,9 +34,9 @@ class Symbols(object):
         self.loads = {}
         self.stores = set()
 
-    def analyze_node(self, node):
+    def analyze_node(self, node, **kwargs):
         visitor = RootVisitor(self)
-        visitor.visit(node)
+        visitor.visit(node, **kwargs)
 
     def _define_ref(self, name, load=None):
         ident = 'l_%d_%s' % (self.level, name)
@@ -123,7 +123,7 @@ class RootVisitor(NodeVisitor):
     def __init__(self, symbols):
         self.sym_visitor = FrameSymbolVisitor(symbols)
 
-    def _simple_visit(self, node):
+    def _simple_visit(self, node, **kwargs):
         for child in node.iter_child_nodes():
             self.sym_visitor.visit(child)
 
@@ -131,18 +131,26 @@ class RootVisitor(NodeVisitor):
         visit_Scope = visit_If = visit_ScopedEvalContextModifier = \
         _simple_visit
 
-    def visit_AssignBlock(self, node):
+    def visit_AssignBlock(self, node, **kwargs):
         for child in node.body:
             self.sym_visitor.visit(child)
 
-    def visit_CallBlock(self, node):
+    def visit_CallBlock(self, node, **kwargs):
         for child in node.iter_child_nodes(exclude=('call',)):
             self.sym_visitor.visit(child)
 
-    def visit_For(self, node):
-        self.sym_visitor.visit(node.target, store_as_param=True)
-        for child in node.iter_child_nodes(exclude=('iter', 'target')):
-            self.sym_visitor.visit(child)
+    def visit_For(self, node, for_branch='body', **kwargs):
+        if node.test is not None:
+            self.sym_visitor.visit(node.test)
+        if for_branch == 'body':
+            self.sym_visitor.visit(node.target, store_as_param=True)
+            branch = node.body
+        elif for_branch == 'else':
+            branch = node.else_
+        else:
+            raise RuntimeError('Unknown for branch')
+        for item in branch or ():
+            self.sym_visitor.visit(item)
 
     def generic_visit(self, node, *args, **kwargs):
         raise NotImplementedError('Cannot find symbols for %r' %
index 2ea7757e4862aa8d9f96f566909cbc988692e920..5391354f05fd74edba5e9c41cd75ed29f90c9073 100644 (file)
@@ -20,7 +20,7 @@ def env_trim():
 
 @pytest.mark.core_tags
 @pytest.mark.for_loop
-class TestForLoop():
+class TestForLoop(object):
 
     def test_simple(self, env):
         tmpl = env.from_string('{% for item in seq %}{{ item }}{% endfor %}')
@@ -31,6 +31,11 @@ class TestForLoop():
             '{% for item in seq %}XXX{% else %}...{% endfor %}')
         assert tmpl.render() == '...'
 
+    def test_else_scoping_item(self, env):
+        tmpl = env.from_string(
+            '{% for item in [] %}{% else %}{{ item }}{% endfor %}')
+        assert tmpl.render(item=42) == '42'
+
     def test_empty_blocks(self, env):
         tmpl = env.from_string('<{% for item in seq %}{% else %}{% endfor %}>')
         assert tmpl.render() == '<>'