]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Implemented a block set tag.
authorArmin Ronacher <armin.ronacher@active-4.com>
Fri, 6 Jun 2014 18:56:05 +0000 (00:56 +0600)
committerArmin Ronacher <armin.ronacher@active-4.com>
Fri, 6 Jun 2014 18:56:05 +0000 (00:56 +0600)
CHANGES
jinja2/compiler.py
jinja2/nodes.py
jinja2/parser.py
jinja2/testsuite/core_tags.py

diff --git a/CHANGES b/CHANGES
index b5aee7937a259bb549156876fd9ea211d46b816f..d4f3c3bdf986dc182cf99ea659f4e622cd39c158 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -18,6 +18,7 @@ Version 2.8
   object that logs failures into a logger.
 - If unmarshalling of cached data fails the template will be
   reloaded now.
+- Implemented a block ``set`` tag.
 
 Version 2.7.3
 -------------
index 3f9fb97878fce4fe31f3b3fe19d6cf0be852d9de..96b32f7d6b4af673371c1a8dbd5298a6151cd6f3 100644 (file)
@@ -347,6 +347,9 @@ class FrameIdentifierVisitor(NodeVisitor):
     def visit_FilterBlock(self, node):
         self.visit(node.filter)
 
+    def visit_AssignBlock(self, node):
+        """Stop visiting at block assigns."""
+
     def visit_Scope(self, node):
         """Stop visiting at scopes."""
 
@@ -1332,42 +1335,62 @@ class CodeGenerator(NodeVisitor):
         if outdent_later:
             self.outdent()
 
-    def visit_Assign(self, node, frame):
-        self.newline(node)
+    def make_assignment_frame(self, frame):
         # toplevel assignments however go into the local namespace and
         # the current template's context.  We create a copy of the frame
         # here and add a set so that the Name visitor can add the assigned
         # names here.
-        if frame.toplevel:
-            assignment_frame = frame.copy()
-            assignment_frame.toplevel_assignments = set()
+        if not frame.toplevel:
+            return frame
+        assignment_frame = frame.copy()
+        assignment_frame.toplevel_assignments = set()
+        return assignment_frame
+
+    def export_assigned_vars(self, frame, assignment_frame):
+        if not frame.toplevel:
+            return
+        public_names = [x for x in assignment_frame.toplevel_assignments
+                        if not x.startswith('_')]
+        if len(assignment_frame.toplevel_assignments) == 1:
+            name = next(iter(assignment_frame.toplevel_assignments))
+            self.writeline('context.vars[%r] = l_%s' % (name, name))
         else:
-            assignment_frame = frame
+            self.writeline('context.vars.update({')
+            for idx, name in enumerate(assignment_frame.toplevel_assignments):
+                if idx:
+                    self.write(', ')
+                self.write('%r: l_%s' % (name, name))
+            self.write('})')
+        if public_names:
+            if len(public_names) == 1:
+                self.writeline('context.exported_vars.add(%r)' %
+                               public_names[0])
+            else:
+                self.writeline('context.exported_vars.update((%s))' %
+                               ', '.join(imap(repr, public_names)))
+
+    def visit_Assign(self, node, frame):
+        self.newline(node)
+        assignment_frame = self.make_assignment_frame(frame)
         self.visit(node.target, assignment_frame)
         self.write(' = ')
         self.visit(node.node, frame)
-
-        # make sure toplevel assignments are added to the context.
-        if frame.toplevel:
-            public_names = [x for x in assignment_frame.toplevel_assignments
-                            if not x.startswith('_')]
-            if len(assignment_frame.toplevel_assignments) == 1:
-                name = next(iter(assignment_frame.toplevel_assignments))
-                self.writeline('context.vars[%r] = l_%s' % (name, name))
-            else:
-                self.writeline('context.vars.update({')
-                for idx, name in enumerate(assignment_frame.toplevel_assignments):
-                    if idx:
-                        self.write(', ')
-                    self.write('%r: l_%s' % (name, name))
-                self.write('})')
-            if public_names:
-                if len(public_names) == 1:
-                    self.writeline('context.exported_vars.add(%r)' %
-                                   public_names[0])
-                else:
-                    self.writeline('context.exported_vars.update((%s))' %
-                                   ', '.join(imap(repr, public_names)))
+        self.export_assigned_vars(frame, assignment_frame)
+
+    def visit_AssignBlock(self, node, frame):
+        block_frame = frame.inner()
+        block_frame.inspect(node.body)
+        aliases = self.push_scope(block_frame)
+        self.pull_locals(block_frame)
+        self.buffer(block_frame)
+        self.blockvisit(node.body, block_frame)
+        self.pop_scope(aliases, block_frame)
+
+        assignment_frame = self.make_assignment_frame(frame)
+        self.newline(node)
+        self.visit(node.target, assignment_frame)
+        self.write(' = concat(%s)' % block_frame.buffer)
+        self.export_assigned_vars(frame, assignment_frame)
 
     # -- Expression Visitors
 
index cbb7a2fb2e6b6342edb4dbabc2a3021f5d372c06..d451da36986252eed2b1c99958afb01a06aa6e47 100644 (file)
@@ -347,6 +347,11 @@ class Assign(Stmt):
     fields = ('target', 'node')
 
 
+class AssignBlock(Stmt):
+    """Assigns a block to a target."""
+    fields = ('target', 'body')
+
+
 class Expr(Node):
     """Baseclass for all expressions."""
     abstract = True
index 43755a0f416322457b8165b8b8ed198e26f6ede0..2cf2bd2055d511b687a076ed7423326d8f02cc95 100644 (file)
@@ -168,9 +168,12 @@ class Parser(object):
         """Parse an assign statement."""
         lineno = next(self.stream).lineno
         target = self.parse_assign_target()
-        self.stream.expect('assign')
-        expr = self.parse_tuple()
-        return nodes.Assign(target, expr, lineno=lineno)
+        if self.stream.skip_if('assign'):
+            expr = self.parse_tuple()
+            return nodes.Assign(target, expr, lineno=lineno)
+        body = self.parse_statements(('name:endset',),
+                                     drop_needle=True)
+        return nodes.AssignBlock(target, body, lineno=lineno)
 
     def parse_for(self):
         """Parse a for loop."""
index d4c44c42217bc1aa1c7888d349b4dede89ee4504..ac0eb4cb2564e91e1176a58d8e46b9ca5863a67f 100644 (file)
@@ -299,9 +299,24 @@ class MacrosTestCase(JinjaTestCase):
         assert tmpl.render() == '5|4|3|2|1'
 
 
+class SetTestCase(JinjaTestCase):
+    env = Environment(trim_blocks=True)
+
+    def test_normal(self):
+        tmpl = self.env.from_string('{% set foo = 1 %}{{ foo }}')
+        assert tmpl.render() == '1'
+        assert tmpl.module.foo == 1
+
+    def test_block(self):
+        tmpl = self.env.from_string('{% set foo %}42{% endset %}{{ foo }}')
+        assert tmpl.render() == '42'
+        assert tmpl.module.foo == u'42'
+
+
 def suite():
     suite = unittest.TestSuite()
     suite.addTest(unittest.makeSuite(ForLoopTestCase))
     suite.addTest(unittest.makeSuite(IfConditionTestCase))
     suite.addTest(unittest.makeSuite(MacrosTestCase))
+    suite.addTest(unittest.makeSuite(SetTestCase))
     return suite