]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Fixed self references in macros
authorArmin Ronacher <armin.ronacher@active-4.com>
Fri, 6 Jan 2017 19:57:30 +0000 (20:57 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Fri, 6 Jan 2017 19:57:30 +0000 (20:57 +0100)
jinja2/compiler.py
jinja2/idtracking.py
jinja2/runtime.py
tests/test_core_tags.py

index 5cb70b8d176d228c7d2d20227d9c56c68edef159..cc808e9d861f5874bd98643176f5642d2b6bfd08 100644 (file)
@@ -283,6 +283,9 @@ class CodeGenerator(NodeVisitor):
         # Tracks toplevel assignments
         self._assign_stack = []
 
+        # Tracks parameter definition blocks
+        self._param_def_block = []
+
     # -- Various compilation helpers
 
     def fail(self, msg, lineno):
@@ -508,20 +511,24 @@ class CodeGenerator(NodeVisitor):
         self.buffer(frame)
         self.enter_frame(frame)
 
+        self.push_parameter_definitions(frame)
         for idx, arg in enumerate(node.args):
-            self.writeline('if %s is missing:' % frame.symbols.ref(arg.name))
+            ref = frame.symbols.ref(arg.name)
+            self.writeline('if %s is missing:' % ref)
             self.indent()
             try:
                 default = node.defaults[idx - len(node.args)]
             except IndexError:
                 self.writeline('%s = undefined(%r, name=%r)' % (
-                    frame.symbols.ref(arg.name),
+                    ref,
                     'parameter %r was not provided' % arg.name,
                     arg.name))
             else:
-                self.writeline('%s = ' % frame.symbols.ref(arg.name))
+                self.writeline('%s = ' % ref)
                 self.visit(default, frame)
+            self.mark_parameter_stored(ref)
             self.outdent()
+        self.pop_parameter_definitions()
 
         self.blockvisit(node.body, frame)
         self.return_buffer_contents(frame, force_unescaped=True)
@@ -554,10 +561,72 @@ class CodeGenerator(NodeVisitor):
             in iteritems(frame.symbols.dump_stores()))
 
     def write_commons(self):
+        """Writes a common preamble that is used by root and block functions.
+        Primarily this sets up common local helpers and enforces a generator
+        through a dead branch.
+        """
         self.writeline('resolve = context.resolve_or_missing')
         self.writeline('undefined = environment.undefined')
         self.writeline('if 0: yield None')
 
+    def push_parameter_definitions(self, frame):
+        """Pushes all parameter targets from the given frame into a local
+        stack that permits tracking of yet to be assigned parameters.  In
+        particular this enables the optimization from `visit_Name` to skip
+        undefined expressions for parameters in macros as macros can reference
+        otherwise unbound parameters.
+        """
+        self._param_def_block.append(frame.symbols.dump_param_targets())
+
+    def pop_parameter_definitions(self):
+        """Pops the current parameter definitions set."""
+        self._param_def_block.pop()
+
+    def mark_parameter_stored(self, target):
+        """Marks a parameter in the current parameter definitions as stored.
+        This will skip the enforced undefined checks.
+        """
+        if self._param_def_block:
+            self._param_def_block[-1].discard(target)
+
+    def parameter_is_undeclared(self, target):
+        """Checks if a given target is an undeclared parameter."""
+        if not self._param_def_block:
+            return True
+        return target in self._param_def_block[-1]
+
+    def push_assign_tracking(self):
+        """Pushes a new layer for assignment tracking."""
+        self._assign_stack.append(set())
+
+    def pop_assign_tracking(self, frame):
+        """Pops the topmost level for assignment tracking and updates the
+        context variables if necessary.
+        """
+        vars = self._assign_stack.pop()
+        if not frame.toplevel or not vars:
+            return
+        public_names = [x for x in vars if x[:1] != '_']
+        if len(vars) == 1:
+            name = next(iter(vars))
+            ref = frame.symbols.ref(name)
+            self.writeline('context.vars[%r] = %s' % (name, ref))
+        else:
+            self.writeline('context.vars.update({')
+            for idx, name in enumerate(vars):
+                if idx:
+                    self.write(', ')
+                ref = frame.symbols.ref(name)
+                self.write('%r: %s' % (name, ref))
+            self.write('})')
+        if public_names:
+            if len(public_names) == 1:
+                self.writeline('context.exported_vars.add(%r)' %
+                               public_names[0])
+            else:
+                self.writeline('context.exported_vars.update((%s))' %
+                               ', '.join(imap(repr, public_names)))
+
     # -- Statement Visitors
 
     def visit_Template(self, node, frame=None):
@@ -1207,34 +1276,6 @@ class CodeGenerator(NodeVisitor):
         if outdent_later:
             self.outdent()
 
-    def push_assign_tracking(self):
-        self._assign_stack.append(set())
-
-    def pop_assign_tracking(self, frame):
-        vars = self._assign_stack.pop()
-        if not frame.toplevel or not vars:
-            return
-        public_names = [x for x in vars if x[:1] != '_']
-        if len(vars) == 1:
-            name = next(iter(vars))
-            ref = frame.symbols.ref(name)
-            self.writeline('context.vars[%r] = %s' % (name, ref))
-        else:
-            self.writeline('context.vars.update({')
-            for idx, name in enumerate(vars):
-                if idx:
-                    self.write(', ')
-                ref = frame.symbols.ref(name)
-                self.write('%r: %s' % (name, ref))
-            self.write('})')
-        if public_names:
-            if len(public_names) == 1:
-                self.writeline('context.exported_vars.add(%r)' %
-                               public_names[0])
-            else:
-                self.writeline('context.exported_vars.update((%s))' %
-                               ', '.join(imap(repr, public_names)))
-
     def visit_Assign(self, node, frame):
         self.push_assign_tracking()
         self.newline(node)
@@ -1273,7 +1314,8 @@ class CodeGenerator(NodeVisitor):
         # instruction indicates a parameter which are always defined.
         if node.ctx == 'load':
             load = frame.symbols.find_load(ref)
-            if load is None or load[0] != VAR_LOAD_PARAMETER:
+            if not (load is not None and load[0] == VAR_LOAD_PARAMETER and \
+                    not self.parameter_is_undeclared(ref)):
                 self.write('(undefined(name=%r) if %s is missing else %s)' %
                            (node.name, ref, ref))
                 return
index f99c568dc69262b16457502ca45cc9fffbb98526..87ef5107ffdbc72cada152c86b72ad59fb75103c 100644 (file)
@@ -123,6 +123,16 @@ class Symbols(object):
             node = node.parent
         return rv
 
+    def dump_param_targets(self):
+        rv = set()
+        node = self
+        while node is not None:
+            for target, (instr, _) in iteritems(self.loads):
+                if instr == VAR_LOAD_PARAMETER:
+                    rv.add(target)
+            node = node.parent
+        return rv
+
 
 class RootVisitor(NodeVisitor):
 
index 4ee47ee64829f186ace34fd277d8733286c4538d..43f25063bab129dc81baf0830bb330daf8d98365 100644 (file)
@@ -66,7 +66,9 @@ def new_context(environment, template_name, blocks, vars=None,
         # we don't want to modify the dict passed
         if shared:
             parent = dict(parent)
-        parent.update(locals or ())
+        for key, value in iteritems(locals):
+            if value is not missing:
+                parent[key] = value
     return environment.context_class(environment, parent, template_name,
                                      blocks)
 
index 058bf85e68cb005c0643d05880589ee1f6fd8504..7d49d8a9f9365098882d0e23e5cafa24092506ba 100644 (file)
@@ -202,7 +202,7 @@ class TestForLoop(object):
 
 @pytest.mark.core_tags
 @pytest.mark.if_condition
-class TestIfCondition():
+class TestIfCondition(object):
 
     def test_simple(self, env):
         tmpl = env.from_string('''{% if true %}...{% endif %}''')
@@ -237,7 +237,7 @@ class TestIfCondition():
 
 @pytest.mark.core_tags
 @pytest.mark.macros
-class TestMacros():
+class TestMacros(object):
     def test_simple(self, env_trim):
         tmpl = env_trim.from_string('''\
 {% macro say_hello(name) %}Hello {{ name }}!{% endmacro %}
@@ -324,10 +324,20 @@ class TestMacros():
                                     '{{ foo(5) }}')
         assert tmpl.render() == '5|4|3|2|1'
 
+    def test_macro_defaults_self_ref(self, env):
+        tmpl = env.from_string('''
+            {%- set x = 42 %}
+            {%- macro m(a, b=x, x=23) %}{{ a }}|{{ b }}|{{ x }}{% endmacro -%}
+        ''')
+        assert tmpl.module.m(1) == '1||23'
+        assert tmpl.module.m(1, 2) == '1|2|23'
+        assert tmpl.module.m(1, 2, 3) == '1|2|3'
+        assert tmpl.module.m(1, x=7) == '1|7|7'
+
 
 @pytest.mark.core_tags
 @pytest.mark.set
-class TestSet():
+class TestSet(object):
 
     def test_normal(self, env_trim):
         tmpl = env_trim.from_string('{% set foo = 1 %}{{ foo }}')