--- /dev/null
+from jinja2.visitor import NodeVisitor
+from jinja2._compat import iteritems
+
+
+VAR_LOAD_PARAMETER = 'param'
+VAR_LOAD_RESOLVE = 'resolve'
+VAR_LOAD_ALIAS = 'alias'
+VAR_LOAD_UNDEFINED = 'undefined'
+
+
+def find_symbols(nodes, parent_symbols=None):
+ sym = Symbols(parent=parent_symbols)
+ visitor = FrameSymbolVisitor(sym)
+ for node in nodes:
+ visitor.visit(node)
+ return sym
+
+
+def symbols_for_node(node, parent_symbols=None):
+ sym = Symbols(parent=parent_symbols)
+ sym.analyze_node(node)
+ return sym
+
+
+class Symbols(object):
+
+ def __init__(self, parent=None):
+ if parent is None:
+ self.level = 0
+ else:
+ self.level = parent.level + 1
+ self.parent = parent
+ self.refs = {}
+ self.loads = {}
+ self.stores = set()
+
+ def analyze_node(self, node):
+ visitor = RootVisitor(self)
+ visitor.visit(node)
+
+ def _define_ref(self, name, load=None):
+ ident = 'l_%d_%s' % (self.level, name)
+ self.refs[name] = ident
+ if load is not None:
+ self.loads[ident] = load
+ return ident
+
+ def find_ref(self, name):
+ if name in self.refs:
+ return self.refs[name]
+ if self.parent is not None:
+ return self.parent.find_ref(name)
+
+ def copy(self):
+ rv = object.__new__(self.__class__)
+ rv.__dict__.update(self.__dict__)
+ rv.refs = self.refs.copy()
+ rv.loads = self.loads.copy()
+ rv.stores = self.stores.copy()
+ return rv
+
+ def store(self, name):
+ # We already have that name locally, so we can just bail
+ if name not in self.refs:
+ self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))
+ self.stores.add(name)
+
+ def declare_parameter(self, name):
+ self.stores.add(name)
+ return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))
+
+ def load(self, name):
+ target = self.find_ref(name)
+ if target is None:
+ self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))
+
+ def branch_update(self, branch_symbols):
+ stores = {}
+ for branch in branch_symbols:
+ for target in branch.stores:
+ if target in self.stores:
+ continue
+ stores[target] = stores.get(target, 0) + 1
+
+ for sym in branch_symbols:
+ self.refs.update(sym.refs)
+ self.loads.update(sym.loads)
+ self.stores.update(sym.stores)
+
+ for name, branch_count in iteritems(stores):
+ if branch_count == len(branch_symbols):
+ continue
+ target = self.find_ref(name)
+ assert target is not None, 'should not happen'
+
+ if self.parent is not None:
+ outer_target = self.parent.find_ref(name)
+ if outer_target is not None:
+ self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
+ continue
+ self.loads[target] = (VAR_LOAD_RESOLVE, name)
+
+ def dump_stores(self):
+ rv = {}
+ node = self
+ while node is not None:
+ for name in node.stores:
+ if name not in rv:
+ rv[name] = self.find_ref(name)
+ node = node.parent
+ return rv
+
+
+class RootVisitor(NodeVisitor):
+
+ def __init__(self, symbols):
+ self.sym_visitor = FrameSymbolVisitor(symbols)
+
+ def _simple_visit(self, node):
+ for child in node.iter_child_nodes():
+ self.sym_visitor.visit(child)
+
+ visit_Template = visit_Block = visit_Macro = visit_FilterBlock = \
+ visit_Scope = visit_If = visit_ScopedEvalContextModifier = \
+ _simple_visit
+
+ def visit_AssignBlock(self, node):
+ for child in self.body:
+ self.sym_visitor.visit(child)
+
+ def visit_CallBlock(self, node):
+ for child in node.iter_child_nodes(exclude=('call',)):
+ self.sym_visitor.visit(child)
+
+ def visit_For(self, node):
+ self.sym_visitor.visit(node.target, store_as_param=True)
+ for child in node.iter_child_nodes(exclude=('iter', 'target')):
+ self.sym_visitor.visit(child)
+
+ def generic_visit(self, node, *args, **kwargs):
+ raise NotImplementedError('Cannot find symbols for %r' %
+ node.__class__.__name__)
+
+
+class FrameSymbolVisitor(NodeVisitor):
+ """A visitor for `Frame.inspect`."""
+
+ def __init__(self, symbols):
+ self.symbols = symbols
+
+ def visit_Name(self, node, store_as_param=False, **kwargs):
+ """All assignments to names go through this function."""
+ if store_as_param or node.ctx == 'param':
+ self.symbols.declare_parameter(node.name)
+ elif node.ctx == 'store':
+ self.symbols.store(node.name)
+ elif node.ctx == 'load':
+ self.symbols.load(node.name)
+
+ def visit_If(self, node, **kwargs):
+ self.visit(node.test, **kwargs)
+
+ original_symbols = self.symbols
+
+ def inner_visit(nodes):
+ self.symbols = rv = original_symbols.copy()
+ for subnode in nodes:
+ self.visit(subnode, **kwargs)
+ self.symbols = original_symbols
+ return rv
+
+ body_symbols = inner_visit(node.body)
+ else_symbols = inner_visit(node.else_ or ())
+
+ self.symbols.branch_update([body_symbols, else_symbols])
+
+ def visit_Macro(self, node, **kwargs):
+ self.symbols.store(node.name)
+
+ def visit_Import(self, node, **kwargs):
+ self.generic_visit(node, **kwargs)
+ self.symbols.store(node.target)
+
+ def visit_FromImport(self, node, **kwargs):
+ self.generic_visit(node, **kwargs)
+ for name in node.names:
+ if isinstance(name, tuple):
+ self.symbols.store(name[1])
+ else:
+ self.symbols.store(name)
+
+ def visit_Assign(self, node, **kwargs):
+ """Visit assignments in the correct order."""
+ self.visit(node.node, **kwargs)
+ self.visit(node.target, **kwargs)
+
+ def visit_For(self, node, **kwargs):
+ """Visiting stops at for blocks. However the block sequence
+ is visited as part of the outer scope.
+ """
+ self.visit(node.iter, **kwargs)
+
+ def visit_CallBlock(self, node, **kwargs):
+ self.visit(node.call, **kwargs)
+
+ def visit_FilterBlock(self, node, **kwargs):
+ self.visit(node.filter, **kwargs)
+
+ def visit_AssignBlock(self, node, **kwargs):
+ """Stop visiting at block assigns."""
+
+ def visit_Scope(self, node, **kwargs):
+ """Stop visiting at scopes."""
+
+ def visit_Block(self, node, **kwargs):
+ """Stop visiting at blocks."""
--- /dev/null
+import pytest
+
+from jinja2 import nodes
+from jinja2.idtracking import symbols_for_node
+
+
+def test_basics():
+ for_loop = nodes.For(
+ nodes.Name('foo', 'store'),
+ nodes.Name('seq', 'load'),
+ [nodes.Output([nodes.Name('foo', 'load')])],
+ [], None, False)
+ tmpl = nodes.Template([
+ nodes.Assign(
+ nodes.Name('foo', 'store'),
+ nodes.Name('bar', 'load')),
+ for_loop])
+
+ sym = symbols_for_node(tmpl)
+ assert sym.refs == {
+ 'foo': 'l_0_foo',
+ 'bar': 'l_0_bar',
+ 'seq': 'l_0_seq',
+ }
+ assert sym.loads == {
+ 'l_0_foo': ('undefined', None),
+ 'l_0_bar': ('resolve', 'bar'),
+ 'l_0_seq': ('resolve', 'seq'),
+ }
+
+ sym = symbols_for_node(for_loop, sym)
+ assert sym.refs == {
+ 'foo': 'l_1_foo',
+ }
+ assert sym.loads == {
+ 'l_1_foo': ('undefined', None),
+ }
+
+
+def test_complex():
+ title_block = nodes.Block('title', [
+ nodes.Output([nodes.TemplateData(u'Page Title')])
+ ], False)
+
+ render_title_macro = nodes.Macro('render_title', [nodes.Name('title', 'param')], [], [
+ nodes.Output([
+ nodes.TemplateData(u'\n <div class="title">\n <h1>'),
+ nodes.Name('title', 'load'),
+ nodes.TemplateData(u'</h1>\n <p>'),
+ nodes.Name('subtitle', 'load'),
+ nodes.TemplateData(u'</p>\n ')]),
+ nodes.Assign(
+ nodes.Name('subtitle', 'store'), nodes.Const('something else')),
+ nodes.Output([
+ nodes.TemplateData(u'\n <p>'),
+ nodes.Name('subtitle', 'load'),
+ nodes.TemplateData(u'</p>\n </div>\n'),
+ nodes.If(
+ nodes.Name('something', 'load'), [
+ nodes.Assign(nodes.Name('title_upper', 'store'),
+ nodes.Filter(nodes.Name('title', 'load'),
+ 'upper', [], [], None, None)),
+ nodes.Output([
+ nodes.Name('title_upper', 'load'),
+ nodes.Call(nodes.Name('render_title', 'load'), [
+ nodes.Const('Aha')], [], None, None)])], [])])])
+
+ for_loop = nodes.For(
+ nodes.Name('item', 'store'),
+ nodes.Name('seq', 'load'), [
+ nodes.Output([
+ nodes.TemplateData(u'\n <li>'),
+ nodes.Name('item', 'load'),
+ nodes.TemplateData(u'</li>\n <span>')]),
+ nodes.Include(nodes.Const('helper.html'), True, False),
+ nodes.Output([
+ nodes.TemplateData(u'</span>\n ')])], [], None, False)
+
+ body_block = nodes.Block('body', [
+ nodes.Output([
+ nodes.TemplateData(u'\n '),
+ nodes.Call(nodes.Name('render_title', 'load'), [
+ nodes.Name('item', 'load')], [], None, None),
+ nodes.TemplateData(u'\n <ul>\n ')]),
+ for_loop,
+ nodes.Output([nodes.TemplateData(u'\n </ul>\n')])],
+ False)
+
+ tmpl = nodes.Template([
+ nodes.Extends(nodes.Const('layout.html')),
+ title_block,
+ render_title_macro,
+ body_block,
+ ])
+
+ tmpl_sym = symbols_for_node(tmpl)
+ assert tmpl_sym.refs == {
+ 'render_title': 'l_0_render_title',
+ }
+ assert tmpl_sym.loads == {
+ 'l_0_render_title': ('undefined', None),
+ }
+ assert tmpl_sym.stores == set(['render_title'])
+ assert tmpl_sym.dump_stores() == {
+ 'render_title': 'l_0_render_title',
+ }
+
+ macro_sym = symbols_for_node(render_title_macro, tmpl_sym)
+ assert macro_sym.refs == {
+ 'subtitle': 'l_1_subtitle',
+ 'something': 'l_1_something',
+ 'title': 'l_1_title',
+ 'title_upper': 'l_1_title_upper',
+ }
+ assert macro_sym.loads == {
+ 'l_1_subtitle': ('resolve', 'subtitle'),
+ 'l_1_something': ('resolve','something'),
+ 'l_1_title': ('param', None),
+ 'l_1_title_upper': ('resolve', 'title_upper'),
+ }
+ assert macro_sym.stores == set(['title', 'title_upper', 'subtitle'])
+ assert macro_sym.find_ref('render_title') == 'l_0_render_title'
+ assert macro_sym.dump_stores() == {
+ 'title': 'l_1_title',
+ 'title_upper': 'l_1_title_upper',
+ 'subtitle': 'l_1_subtitle',
+ 'render_title': 'l_0_render_title',
+ }
+
+ body_sym = symbols_for_node(body_block)
+ assert body_sym.refs == {
+ 'item': 'l_0_item',
+ 'seq': 'l_0_seq',
+ 'render_title': 'l_0_render_title',
+ }
+ assert body_sym.loads == {
+ 'l_0_item': ('resolve', 'item'),
+ 'l_0_seq': ('resolve', 'seq'),
+ 'l_0_render_title': ('resolve', 'render_title'),
+ }
+ assert body_sym.stores == set([])
+
+ for_sym = symbols_for_node(for_loop, body_sym)
+ assert for_sym.refs == {
+ 'item': 'l_1_item',
+ }
+ assert for_sym.loads == {
+ 'l_1_item': ('undefined', None),
+ }
+ assert for_sym.stores == set(['item'])
+ assert for_sym.dump_stores() == {
+ 'item': 'l_1_item',
+ }
+
+
+def test_if_branching_stores():
+ tmpl = nodes.Template([
+ nodes.If(nodes.Name('expression', 'load'), [
+ nodes.Assign(nodes.Name('variable', 'store'),
+ nodes.Const(42))], [])])
+
+ sym = symbols_for_node(tmpl)
+ assert sym.refs == {
+ 'variable': 'l_0_variable',
+ 'expression': 'l_0_expression'
+ }
+ assert sym.stores == set(['variable'])
+ assert sym.loads == {
+ 'l_0_variable': ('resolve', 'variable'),
+ 'l_0_expression': ('resolve', 'expression')
+ }
+ assert sym.dump_stores() == {
+ 'variable': 'l_0_variable',
+ }
+
+
+def test_if_branching_stores_undefined():
+ tmpl = nodes.Template([
+ nodes.Assign(nodes.Name('variable', 'store'), nodes.Const(23)),
+ nodes.If(nodes.Name('expression', 'load'), [
+ nodes.Assign(nodes.Name('variable', 'store'),
+ nodes.Const(42))], [])])
+
+ sym = symbols_for_node(tmpl)
+ assert sym.refs == {
+ 'variable': 'l_0_variable',
+ 'expression': 'l_0_expression'
+ }
+ assert sym.stores == set(['variable'])
+ assert sym.loads == {
+ 'l_0_variable': ('undefined', None),
+ 'l_0_expression': ('resolve', 'expression')
+ }
+ assert sym.dump_stores() == {
+ 'variable': 'l_0_variable',
+ }
+
+
+def test_if_branching_multi_scope():
+ for_loop = nodes.For(nodes.Name('item', 'store'), nodes.Name('seq', 'load'), [
+ nodes.If(nodes.Name('expression', 'load'), [
+ nodes.Assign(nodes.Name('x', 'store'), nodes.Const(42))], []),
+ nodes.Include(nodes.Const('helper.html'), True, False)
+ ], [], None, False)
+
+ tmpl = nodes.Template([
+ nodes.Assign(nodes.Name('x', 'store'), nodes.Const(23)),
+ for_loop
+ ])
+
+ tmpl_sym = symbols_for_node(tmpl)
+ for_sym = symbols_for_node(for_loop, tmpl_sym)
+ assert for_sym.stores == set(['item', 'x'])
+ assert for_sym.loads == {
+ 'l_1_x': ('alias', 'l_0_x'),
+ 'l_1_item': ('undefined', None),
+ 'l_1_expression': ('resolve', 'expression'),
+ }