]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Reset variables for scoping
authorArmin Ronacher <armin.ronacher@active-4.com>
Mon, 2 Jan 2017 12:52:37 +0000 (13:52 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Tue, 3 Jan 2017 22:45:29 +0000 (23:45 +0100)
jinja2/compiler.py
jinja2/idtracking.py

index f5c1641b44ac8735eb7137f065f57713cf630f03..c11de4e6e5ba8f96e247a0b2f85a9760a62e3d6a 100644 (file)
@@ -97,6 +97,15 @@ def find_undeclared(nodes, names):
     return visitor.undeclared
 
 
+class MacroRef(object):
+
+    def __init__(self, node):
+        self.node = node
+        self.accesses_caller = False
+        self.accesses_kwargs = False
+        self.accesses_varargs = False
+
+
 class Frame(object):
     """Holds compile time information for us."""
 
@@ -435,6 +444,7 @@ class CodeGenerator(NodeVisitor):
                                (mapping[name], dependency, name))
 
     def enter_frame(self, frame):
+        undefs = []
         for target, (action, param) in iteritems(frame.symbols.loads):
             if action == VAR_LOAD_PARAMETER:
                 pass
@@ -444,9 +454,19 @@ class CodeGenerator(NodeVisitor):
             elif action == VAR_LOAD_ALIAS:
                 self.writeline('%s = %s' % (target, param))
             elif action == VAR_LOAD_UNDEFINED:
-                self.writeline('%s = missing' % target)
+                undefs.append(target)
             else:
                 raise NotImplementedError('unknown load instruction')
+        if undefs:
+            self.writeline('%s = missing' % ' = '.join(undefs))
+
+    def leave_frame(self, frame, with_python_scope=False):
+        if not with_python_scope:
+            undefs = []
+            for target, _ in iteritems(frame.symbols.loads):
+                undefs.append(target)
+            if undefs:
+                self.writeline('%s = missing' % ' = '.join(undefs))
 
     def func(self, name):
         if self.environment.is_async:
@@ -455,22 +475,44 @@ class CodeGenerator(NodeVisitor):
 
     def macro_body(self, node, frame, children=None):
         """Dump the function def of a macro or call block."""
-        frame = self.function_scoping(node, frame, children)
+        if children is None:
+            children = list(node.iter_child_nodes())
+        children = list(children)
+        frame = frame.inner()
+        macro_ref = MacroRef(node)
+
+        args = []
+        for arg in node.args:
+            args.append(frame.symbols.ref(arg))
+
+        undeclared = find_undeclared(children, ('caller', 'kwargs', 'varargs'))
+        if 'caller' in undeclared:
+            args.append(frame.symbols.declare_parameter('caller'))
+            macro_ref.accesses_caller = True
+        if 'kwargs' in undeclared:
+            args.append(frame.symbols.declare_parameter('kwargs'))
+            macro_ref.accesses_kwargs = True
+        if 'varargs' in undeclared:
+            args.append(frame.symbols.declare_parameter('varargs'))
+            macro_ref.accesses_varargs = True
+
         # macros are delayed, they never require output checks
         frame.require_output_check = False
-        args = frame.arguments
+        frame.symbols.analyze_node(node)
         self.writeline('%s(%s):' % (self.func('macro'), ', '.join(args)), node)
         self.indent()
         self.buffer(frame)
-        self.pull_locals(frame)
+        self.enter_frame(frame)
         self.blockvisit(node.body, frame)
         self.return_buffer_contents(frame)
+        self.leave_frame(frame, with_python_scope=True)
         self.outdent()
-        return frame
 
-    def macro_def(self, node, frame):
+        return frame, macro_ref
+
+    def macro_def(self, macro_ref, frame):
         """Dump the macro definition for the def created by macro_body."""
-        arg_tuple = ', '.join(repr(x.name) for x in node.args)
+        arg_tuple = ', '.join(repr(x.name) for x in macro_ref.node.args)
         name = getattr(node, 'name', None)
         if len(node.args) == 1:
             arg_tuple += ','
@@ -480,9 +522,9 @@ class CodeGenerator(NodeVisitor):
             self.visit(arg, frame)
             self.write(', ')
         self.write('), %r, %r, %r)' % (
-            bool(frame.accesses_kwargs),
-            bool(frame.accesses_varargs),
-            bool(frame.accesses_caller)
+            bool(macro_ref.accesses_kwargs),
+            bool(macro_ref.accesses_varargs),
+            bool(macro_ref.accesses_caller)
         ))
 
     def position(self, node):
@@ -557,6 +599,7 @@ class CodeGenerator(NodeVisitor):
         self.enter_frame(frame)
         self.pull_dependencies(node.body)
         self.blockvisit(node.body, frame)
+        self.leave_frame(frame)
         self.outdent()
 
         # make sure that the parent root is called.
@@ -593,6 +636,7 @@ class CodeGenerator(NodeVisitor):
             self.enter_frame(block_frame)
             self.pull_dependencies(block.body)
             self.blockvisit(block.body, block_frame)
+            self.leave_frame(block_frame, with_python_scope=True)
             self.outdent()
 
         self.writeline('blocks = {%s}' % ', '.join('%r: block_%s' % (x, x)
@@ -799,7 +843,7 @@ class CodeGenerator(NodeVisitor):
                                'update((%s))' % ', '.join(imap(repr, discarded_names)))
 
     def visit_For(self, node, frame):
-        loop_frame = Frame(frame.eval_ctx, frame)
+        loop_frame = frame.inner()
 
         # try to figure out if we have an extended loop.  An extended loop
         # is necessary if the loop is in recursive mode if the special loop
@@ -916,6 +960,8 @@ class CodeGenerator(NodeVisitor):
             self.blockvisit(node.else_, loop_frame)
             self.outdent()
 
+        self.leave_frame(loop_frame)
+
         # if the node was recursive we have to return the buffer contents
         # and start the iteration code
         if node.recursive:
@@ -948,38 +994,36 @@ class CodeGenerator(NodeVisitor):
             self.outdent()
 
     def visit_Macro(self, node, frame):
-        macro_frame = self.macro_body(node, frame)
+        macro_frame, macro_ref = self.macro_body(node, frame)
         self.newline()
         if frame.toplevel:
             if not node.name.startswith('_'):
                 self.write('context.exported_vars.add(%r)' % node.name)
-            ref = frame.symbols.find_ref(node.name)
-            assert ref is not None, 'unknown reference for macro'
+            ref = frame.symbols.ref(node.name)
             self.writeline('context.vars[%r] = ' % ref)
         self.write('l_%s = ' % node.name)
-        self.macro_def(node, macro_frame)
+        self.macro_def(macro_ref)
         frame.assigned_names.add(node.name)
 
     def visit_CallBlock(self, node, frame):
         children = node.iter_child_nodes(exclude=('call',))
-        call_frame = self.macro_body(node, frame, children)
+        call_frame, macro_ref = self.macro_body(node, frame, children)
         self.writeline('caller = ')
-        self.macro_def(node, call_frame)
+        self.macro_def(macro_ref)
         self.start_write(frame, node)
         self.visit_Call(node.call, call_frame, forward_caller=True)
         self.end_write(frame)
 
     def visit_FilterBlock(self, node, frame):
         filter_frame = frame.inner()
-        filter_frame.inspect(node.iter_child_nodes())
-        aliases = self.push_scope(filter_frame)
-        self.pull_locals(filter_frame)
+        filter_frame.symbols.analyze_node(node)
+        self.enter_frame(filter_frame)
         self.buffer(filter_frame)
         self.blockvisit(node.body, filter_frame)
         self.start_write(frame, node)
         self.visit_Filter(node.filter, filter_frame)
         self.end_write(frame)
-        self.pop_scope(aliases, filter_frame)
+        self.leave_frame(filter_frame)
 
     def visit_ExprStmt(self, node, frame):
         self.newline(node)
@@ -1148,16 +1192,14 @@ class CodeGenerator(NodeVisitor):
                         if not x.startswith('_')]
         if len(frame.toplevel_assignments) == 1:
             name = next(iter(frame.toplevel_assignments))
-            ref = frame.symbols.find_ref(name)
-            assert ref is not None, 'missing ref in export'
+            ref = frame.symbols.ref(name)
             self.writeline('context.vars[%r] = %s' % (name, ref))
         else:
             self.writeline('context.vars.update({')
             for idx, name in enumerate(assignment_frame.toplevel_assignments):
                 if idx:
                     self.write(', ')
-                ref = frame.symbols.find_ref(name)
-                assert ref is not None, 'missing ref in export'
+                ref = frame.symbols.ref(name)
                 self.write('%r: %s' % (name, ref))
             self.write('})')
         if public_names:
@@ -1198,8 +1240,7 @@ class CodeGenerator(NodeVisitor):
     def visit_Name(self, node, frame):
         if node.ctx == 'store' and frame.toplevel:
             frame.toplevel_assignments.add(node.name)
-        ref = frame.symbols.find_ref(node.name)
-        assert ref is not None, 'compiler error: undefined ref (%r)' % node.name
+        ref = frame.symbols.ref(node.name)
         if node.ctx == 'load':
             self.write('(environment.undefined(name=%r) if %s is missing else %s)' %
                        (node.name, ref, ref))
index be96707f96fa2f60ffa4844afc70df1e09b39940..32fb6e978f7233287cccabd049c34f75ac3cd6ad 100644 (file)
@@ -51,6 +51,13 @@ class Symbols(object):
         if self.parent is not None:
             return self.parent.find_ref(name)
 
+    def ref(self, name):
+        rv = self.find_ref(name)
+        if rv is None:
+            raise AssertionError('Tried to resolve a name to a reference that '
+                                 'was unknown to the frame (%r)' % name)
+        return rv
+
     def copy(self):
         rv = object.__new__(self.__class__)
         rv.__dict__.update(self.__dict__)