From: Armin Ronacher Date: Fri, 6 Jan 2017 19:57:30 +0000 (+0100) Subject: Fixed self references in macros X-Git-Tag: 2.9~19 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cb4b551e12a4a3b190a5ee4ac9f019359653f99e;p=thirdparty%2Fjinja.git Fixed self references in macros --- diff --git a/jinja2/compiler.py b/jinja2/compiler.py index 5cb70b8d..cc808e9d 100644 --- a/jinja2/compiler.py +++ b/jinja2/compiler.py @@ -283,6 +283,9 @@ class CodeGenerator(NodeVisitor): # Tracks toplevel assignments self._assign_stack = [] + # Tracks parameter definition blocks + self._param_def_block = [] + # -- Various compilation helpers def fail(self, msg, lineno): @@ -508,20 +511,24 @@ class CodeGenerator(NodeVisitor): self.buffer(frame) self.enter_frame(frame) + self.push_parameter_definitions(frame) for idx, arg in enumerate(node.args): - self.writeline('if %s is missing:' % frame.symbols.ref(arg.name)) + ref = frame.symbols.ref(arg.name) + self.writeline('if %s is missing:' % ref) self.indent() try: default = node.defaults[idx - len(node.args)] except IndexError: self.writeline('%s = undefined(%r, name=%r)' % ( - frame.symbols.ref(arg.name), + ref, 'parameter %r was not provided' % arg.name, arg.name)) else: - self.writeline('%s = ' % frame.symbols.ref(arg.name)) + self.writeline('%s = ' % ref) self.visit(default, frame) + self.mark_parameter_stored(ref) self.outdent() + self.pop_parameter_definitions() self.blockvisit(node.body, frame) self.return_buffer_contents(frame, force_unescaped=True) @@ -554,10 +561,72 @@ class CodeGenerator(NodeVisitor): in iteritems(frame.symbols.dump_stores())) def write_commons(self): + """Writes a common preamble that is used by root and block functions. + Primarily this sets up common local helpers and enforces a generator + through a dead branch. + """ self.writeline('resolve = context.resolve_or_missing') self.writeline('undefined = environment.undefined') self.writeline('if 0: yield None') + def push_parameter_definitions(self, frame): + """Pushes all parameter targets from the given frame into a local + stack that permits tracking of yet to be assigned parameters. In + particular this enables the optimization from `visit_Name` to skip + undefined expressions for parameters in macros as macros can reference + otherwise unbound parameters. + """ + self._param_def_block.append(frame.symbols.dump_param_targets()) + + def pop_parameter_definitions(self): + """Pops the current parameter definitions set.""" + self._param_def_block.pop() + + def mark_parameter_stored(self, target): + """Marks a parameter in the current parameter definitions as stored. + This will skip the enforced undefined checks. + """ + if self._param_def_block: + self._param_def_block[-1].discard(target) + + def parameter_is_undeclared(self, target): + """Checks if a given target is an undeclared parameter.""" + if not self._param_def_block: + return True + return target in self._param_def_block[-1] + + def push_assign_tracking(self): + """Pushes a new layer for assignment tracking.""" + self._assign_stack.append(set()) + + def pop_assign_tracking(self, frame): + """Pops the topmost level for assignment tracking and updates the + context variables if necessary. + """ + vars = self._assign_stack.pop() + if not frame.toplevel or not vars: + return + public_names = [x for x in vars if x[:1] != '_'] + if len(vars) == 1: + name = next(iter(vars)) + ref = frame.symbols.ref(name) + self.writeline('context.vars[%r] = %s' % (name, ref)) + else: + self.writeline('context.vars.update({') + for idx, name in enumerate(vars): + if idx: + self.write(', ') + ref = frame.symbols.ref(name) + self.write('%r: %s' % (name, ref)) + self.write('})') + if public_names: + if len(public_names) == 1: + self.writeline('context.exported_vars.add(%r)' % + public_names[0]) + else: + self.writeline('context.exported_vars.update((%s))' % + ', '.join(imap(repr, public_names))) + # -- Statement Visitors def visit_Template(self, node, frame=None): @@ -1207,34 +1276,6 @@ class CodeGenerator(NodeVisitor): if outdent_later: self.outdent() - def push_assign_tracking(self): - self._assign_stack.append(set()) - - def pop_assign_tracking(self, frame): - vars = self._assign_stack.pop() - if not frame.toplevel or not vars: - return - public_names = [x for x in vars if x[:1] != '_'] - if len(vars) == 1: - name = next(iter(vars)) - ref = frame.symbols.ref(name) - self.writeline('context.vars[%r] = %s' % (name, ref)) - else: - self.writeline('context.vars.update({') - for idx, name in enumerate(vars): - if idx: - self.write(', ') - ref = frame.symbols.ref(name) - self.write('%r: %s' % (name, ref)) - self.write('})') - if public_names: - if len(public_names) == 1: - self.writeline('context.exported_vars.add(%r)' % - public_names[0]) - else: - self.writeline('context.exported_vars.update((%s))' % - ', '.join(imap(repr, public_names))) - def visit_Assign(self, node, frame): self.push_assign_tracking() self.newline(node) @@ -1273,7 +1314,8 @@ class CodeGenerator(NodeVisitor): # instruction indicates a parameter which are always defined. if node.ctx == 'load': load = frame.symbols.find_load(ref) - if load is None or load[0] != VAR_LOAD_PARAMETER: + if not (load is not None and load[0] == VAR_LOAD_PARAMETER and \ + not self.parameter_is_undeclared(ref)): self.write('(undefined(name=%r) if %s is missing else %s)' % (node.name, ref, ref)) return diff --git a/jinja2/idtracking.py b/jinja2/idtracking.py index f99c568d..87ef5107 100644 --- a/jinja2/idtracking.py +++ b/jinja2/idtracking.py @@ -123,6 +123,16 @@ class Symbols(object): node = node.parent return rv + def dump_param_targets(self): + rv = set() + node = self + while node is not None: + for target, (instr, _) in iteritems(self.loads): + if instr == VAR_LOAD_PARAMETER: + rv.add(target) + node = node.parent + return rv + class RootVisitor(NodeVisitor): diff --git a/jinja2/runtime.py b/jinja2/runtime.py index 4ee47ee6..43f25063 100644 --- a/jinja2/runtime.py +++ b/jinja2/runtime.py @@ -66,7 +66,9 @@ def new_context(environment, template_name, blocks, vars=None, # we don't want to modify the dict passed if shared: parent = dict(parent) - parent.update(locals or ()) + for key, value in iteritems(locals): + if value is not missing: + parent[key] = value return environment.context_class(environment, parent, template_name, blocks) diff --git a/tests/test_core_tags.py b/tests/test_core_tags.py index 058bf85e..7d49d8a9 100644 --- a/tests/test_core_tags.py +++ b/tests/test_core_tags.py @@ -202,7 +202,7 @@ class TestForLoop(object): @pytest.mark.core_tags @pytest.mark.if_condition -class TestIfCondition(): +class TestIfCondition(object): def test_simple(self, env): tmpl = env.from_string('''{% if true %}...{% endif %}''') @@ -237,7 +237,7 @@ class TestIfCondition(): @pytest.mark.core_tags @pytest.mark.macros -class TestMacros(): +class TestMacros(object): def test_simple(self, env_trim): tmpl = env_trim.from_string('''\ {% macro say_hello(name) %}Hello {{ name }}!{% endmacro %} @@ -324,10 +324,20 @@ class TestMacros(): '{{ foo(5) }}') assert tmpl.render() == '5|4|3|2|1' + def test_macro_defaults_self_ref(self, env): + tmpl = env.from_string(''' + {%- set x = 42 %} + {%- macro m(a, b=x, x=23) %}{{ a }}|{{ b }}|{{ x }}{% endmacro -%} + ''') + assert tmpl.module.m(1) == '1||23' + assert tmpl.module.m(1, 2) == '1|2|23' + assert tmpl.module.m(1, 2, 3) == '1|2|3' + assert tmpl.module.m(1, x=7) == '1|7|7' + @pytest.mark.core_tags @pytest.mark.set -class TestSet(): +class TestSet(object): def test_normal(self, env_trim): tmpl = env_trim.from_string('{% set foo = 1 %}{{ foo }}')