def visit_FilterBlock(self, node):
self.visit(node.filter)
+ def visit_AssignBlock(self, node):
+ """Stop visiting at block assigns."""
+
def visit_Scope(self, node):
"""Stop visiting at scopes."""
if outdent_later:
self.outdent()
- def visit_Assign(self, node, frame):
- self.newline(node)
+ def make_assignment_frame(self, frame):
# toplevel assignments however go into the local namespace and
# the current template's context. We create a copy of the frame
# here and add a set so that the Name visitor can add the assigned
# names here.
- if frame.toplevel:
- assignment_frame = frame.copy()
- assignment_frame.toplevel_assignments = set()
+ if not frame.toplevel:
+ return frame
+ assignment_frame = frame.copy()
+ assignment_frame.toplevel_assignments = set()
+ return assignment_frame
+
+ def export_assigned_vars(self, frame, assignment_frame):
+ if not frame.toplevel:
+ return
+ public_names = [x for x in assignment_frame.toplevel_assignments
+ if not x.startswith('_')]
+ if len(assignment_frame.toplevel_assignments) == 1:
+ name = next(iter(assignment_frame.toplevel_assignments))
+ self.writeline('context.vars[%r] = l_%s' % (name, name))
else:
- assignment_frame = frame
+ self.writeline('context.vars.update({')
+ for idx, name in enumerate(assignment_frame.toplevel_assignments):
+ if idx:
+ self.write(', ')
+ self.write('%r: l_%s' % (name, name))
+ self.write('})')
+ if public_names:
+ if len(public_names) == 1:
+ self.writeline('context.exported_vars.add(%r)' %
+ public_names[0])
+ else:
+ self.writeline('context.exported_vars.update((%s))' %
+ ', '.join(imap(repr, public_names)))
+
+ def visit_Assign(self, node, frame):
+ self.newline(node)
+ assignment_frame = self.make_assignment_frame(frame)
self.visit(node.target, assignment_frame)
self.write(' = ')
self.visit(node.node, frame)
-
- # make sure toplevel assignments are added to the context.
- if frame.toplevel:
- public_names = [x for x in assignment_frame.toplevel_assignments
- if not x.startswith('_')]
- if len(assignment_frame.toplevel_assignments) == 1:
- name = next(iter(assignment_frame.toplevel_assignments))
- self.writeline('context.vars[%r] = l_%s' % (name, name))
- else:
- self.writeline('context.vars.update({')
- for idx, name in enumerate(assignment_frame.toplevel_assignments):
- if idx:
- self.write(', ')
- self.write('%r: l_%s' % (name, name))
- self.write('})')
- if public_names:
- if len(public_names) == 1:
- self.writeline('context.exported_vars.add(%r)' %
- public_names[0])
- else:
- self.writeline('context.exported_vars.update((%s))' %
- ', '.join(imap(repr, public_names)))
+ self.export_assigned_vars(frame, assignment_frame)
+
+ def visit_AssignBlock(self, node, frame):
+ block_frame = frame.inner()
+ block_frame.inspect(node.body)
+ aliases = self.push_scope(block_frame)
+ self.pull_locals(block_frame)
+ self.buffer(block_frame)
+ self.blockvisit(node.body, block_frame)
+ self.pop_scope(aliases, block_frame)
+
+ assignment_frame = self.make_assignment_frame(frame)
+ self.newline(node)
+ self.visit(node.target, assignment_frame)
+ self.write(' = concat(%s)' % block_frame.buffer)
+ self.export_assigned_vars(frame, assignment_frame)
# -- Expression Visitors
"""Parse an assign statement."""
lineno = next(self.stream).lineno
target = self.parse_assign_target()
- self.stream.expect('assign')
- expr = self.parse_tuple()
- return nodes.Assign(target, expr, lineno=lineno)
+ if self.stream.skip_if('assign'):
+ expr = self.parse_tuple()
+ return nodes.Assign(target, expr, lineno=lineno)
+ body = self.parse_statements(('name:endset',),
+ drop_needle=True)
+ return nodes.AssignBlock(target, body, lineno=lineno)
def parse_for(self):
"""Parse a for loop."""
assert tmpl.render() == '5|4|3|2|1'
+class SetTestCase(JinjaTestCase):
+ env = Environment(trim_blocks=True)
+
+ def test_normal(self):
+ tmpl = self.env.from_string('{% set foo = 1 %}{{ foo }}')
+ assert tmpl.render() == '1'
+ assert tmpl.module.foo == 1
+
+ def test_block(self):
+ tmpl = self.env.from_string('{% set foo %}42{% endset %}{{ foo }}')
+ assert tmpl.render() == '42'
+ assert tmpl.module.foo == u'42'
+
+
def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ForLoopTestCase))
suite.addTest(unittest.makeSuite(IfConditionTestCase))
suite.addTest(unittest.makeSuite(MacrosTestCase))
+ suite.addTest(unittest.makeSuite(SetTestCase))
return suite