]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Added new-style id tracking code
authorArmin Ronacher <armin.ronacher@active-4.com>
Mon, 2 Jan 2017 11:09:45 +0000 (12:09 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Tue, 3 Jan 2017 22:45:29 +0000 (23:45 +0100)
jinja2/idtracking.py [new file with mode: 0644]
tests/test_idtracking.py [new file with mode: 0644]

diff --git a/jinja2/idtracking.py b/jinja2/idtracking.py
new file mode 100644 (file)
index 0000000..be96707
--- /dev/null
@@ -0,0 +1,216 @@
+from jinja2.visitor import NodeVisitor
+from jinja2._compat import iteritems
+
+
+VAR_LOAD_PARAMETER = 'param'
+VAR_LOAD_RESOLVE = 'resolve'
+VAR_LOAD_ALIAS = 'alias'
+VAR_LOAD_UNDEFINED = 'undefined'
+
+
+def find_symbols(nodes, parent_symbols=None):
+    sym = Symbols(parent=parent_symbols)
+    visitor = FrameSymbolVisitor(sym)
+    for node in nodes:
+        visitor.visit(node)
+    return sym
+
+
+def symbols_for_node(node, parent_symbols=None):
+    sym = Symbols(parent=parent_symbols)
+    sym.analyze_node(node)
+    return sym
+
+
+class Symbols(object):
+
+    def __init__(self, parent=None):
+        if parent is None:
+            self.level = 0
+        else:
+            self.level = parent.level + 1
+        self.parent = parent
+        self.refs = {}
+        self.loads = {}
+        self.stores = set()
+
+    def analyze_node(self, node):
+        visitor = RootVisitor(self)
+        visitor.visit(node)
+
+    def _define_ref(self, name, load=None):
+        ident = 'l_%d_%s' % (self.level, name)
+        self.refs[name] = ident
+        if load is not None:
+            self.loads[ident] = load
+        return ident
+
+    def find_ref(self, name):
+        if name in self.refs:
+            return self.refs[name]
+        if self.parent is not None:
+            return self.parent.find_ref(name)
+
+    def copy(self):
+        rv = object.__new__(self.__class__)
+        rv.__dict__.update(self.__dict__)
+        rv.refs = self.refs.copy()
+        rv.loads = self.loads.copy()
+        rv.stores = self.stores.copy()
+        return rv
+
+    def store(self, name):
+        # We already have that name locally, so we can just bail
+        if name not in self.refs:
+            self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))
+        self.stores.add(name)
+
+    def declare_parameter(self, name):
+        self.stores.add(name)
+        return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))
+
+    def load(self, name):
+        target = self.find_ref(name)
+        if target is None:
+            self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))
+
+    def branch_update(self, branch_symbols):
+        stores = {}
+        for branch in branch_symbols:
+            for target in branch.stores:
+                if target in self.stores:
+                    continue
+                stores[target] = stores.get(target, 0) + 1
+
+        for sym in branch_symbols:
+            self.refs.update(sym.refs)
+            self.loads.update(sym.loads)
+            self.stores.update(sym.stores)
+
+        for name, branch_count in iteritems(stores):
+            if branch_count == len(branch_symbols):
+                continue
+            target = self.find_ref(name)
+            assert target is not None, 'should not happen'
+
+            if self.parent is not None:
+                outer_target = self.parent.find_ref(name)
+                if outer_target is not None:
+                    self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
+                    continue
+            self.loads[target] = (VAR_LOAD_RESOLVE, name)
+
+    def dump_stores(self):
+        rv = {}
+        node = self
+        while node is not None:
+            for name in node.stores:
+                if name not in rv:
+                    rv[name] = self.find_ref(name)
+            node = node.parent
+        return rv
+
+
+class RootVisitor(NodeVisitor):
+
+    def __init__(self, symbols):
+        self.sym_visitor = FrameSymbolVisitor(symbols)
+
+    def _simple_visit(self, node):
+        for child in node.iter_child_nodes():
+            self.sym_visitor.visit(child)
+
+    visit_Template = visit_Block = visit_Macro = visit_FilterBlock = \
+        visit_Scope = visit_If = visit_ScopedEvalContextModifier = \
+        _simple_visit
+
+    def visit_AssignBlock(self, node):
+        for child in self.body:
+            self.sym_visitor.visit(child)
+
+    def visit_CallBlock(self, node):
+        for child in node.iter_child_nodes(exclude=('call',)):
+            self.sym_visitor.visit(child)
+
+    def visit_For(self, node):
+        self.sym_visitor.visit(node.target, store_as_param=True)
+        for child in node.iter_child_nodes(exclude=('iter', 'target')):
+            self.sym_visitor.visit(child)
+
+    def generic_visit(self, node, *args, **kwargs):
+        raise NotImplementedError('Cannot find symbols for %r' %
+                                  node.__class__.__name__)
+
+
+class FrameSymbolVisitor(NodeVisitor):
+    """A visitor for `Frame.inspect`."""
+
+    def __init__(self, symbols):
+        self.symbols = symbols
+
+    def visit_Name(self, node, store_as_param=False, **kwargs):
+        """All assignments to names go through this function."""
+        if store_as_param or node.ctx == 'param':
+            self.symbols.declare_parameter(node.name)
+        elif node.ctx == 'store':
+            self.symbols.store(node.name)
+        elif node.ctx == 'load':
+            self.symbols.load(node.name)
+
+    def visit_If(self, node, **kwargs):
+        self.visit(node.test, **kwargs)
+
+        original_symbols = self.symbols
+
+        def inner_visit(nodes):
+            self.symbols = rv = original_symbols.copy()
+            for subnode in nodes:
+                self.visit(subnode, **kwargs)
+            self.symbols = original_symbols
+            return rv
+
+        body_symbols = inner_visit(node.body)
+        else_symbols = inner_visit(node.else_ or ())
+
+        self.symbols.branch_update([body_symbols, else_symbols])
+
+    def visit_Macro(self, node, **kwargs):
+        self.symbols.store(node.name)
+
+    def visit_Import(self, node, **kwargs):
+        self.generic_visit(node, **kwargs)
+        self.symbols.store(node.target)
+
+    def visit_FromImport(self, node, **kwargs):
+        self.generic_visit(node, **kwargs)
+        for name in node.names:
+            if isinstance(name, tuple):
+                self.symbols.store(name[1])
+            else:
+                self.symbols.store(name)
+
+    def visit_Assign(self, node, **kwargs):
+        """Visit assignments in the correct order."""
+        self.visit(node.node, **kwargs)
+        self.visit(node.target, **kwargs)
+
+    def visit_For(self, node, **kwargs):
+        """Visiting stops at for blocks.  However the block sequence
+        is visited as part of the outer scope.
+        """
+        self.visit(node.iter, **kwargs)
+
+    def visit_CallBlock(self, node, **kwargs):
+        self.visit(node.call, **kwargs)
+
+    def visit_FilterBlock(self, node, **kwargs):
+        self.visit(node.filter, **kwargs)
+
+    def visit_AssignBlock(self, node, **kwargs):
+        """Stop visiting at block assigns."""
+
+    def visit_Scope(self, node, **kwargs):
+        """Stop visiting at scopes."""
+
+    def visit_Block(self, node, **kwargs):
+        """Stop visiting at blocks."""
diff --git a/tests/test_idtracking.py b/tests/test_idtracking.py
new file mode 100644 (file)
index 0000000..758b4a8
--- /dev/null
@@ -0,0 +1,218 @@
+import pytest
+
+from jinja2 import nodes
+from jinja2.idtracking import symbols_for_node
+
+
+def test_basics():
+    for_loop = nodes.For(
+        nodes.Name('foo', 'store'),
+        nodes.Name('seq', 'load'),
+        [nodes.Output([nodes.Name('foo', 'load')])],
+        [], None, False)
+    tmpl = nodes.Template([
+        nodes.Assign(
+            nodes.Name('foo', 'store'),
+            nodes.Name('bar', 'load')),
+        for_loop])
+
+    sym = symbols_for_node(tmpl)
+    assert sym.refs == {
+        'foo': 'l_0_foo',
+        'bar': 'l_0_bar',
+        'seq': 'l_0_seq',
+    }
+    assert sym.loads == {
+        'l_0_foo': ('undefined', None),
+        'l_0_bar': ('resolve', 'bar'),
+        'l_0_seq': ('resolve', 'seq'),
+    }
+
+    sym = symbols_for_node(for_loop, sym)
+    assert sym.refs == {
+        'foo': 'l_1_foo',
+    }
+    assert sym.loads == {
+        'l_1_foo': ('undefined', None),
+    }
+
+
+def test_complex():
+    title_block = nodes.Block('title', [
+        nodes.Output([nodes.TemplateData(u'Page Title')])
+    ], False)
+
+    render_title_macro = nodes.Macro('render_title', [nodes.Name('title', 'param')], [], [
+        nodes.Output([
+            nodes.TemplateData(u'\n  <div class="title">\n    <h1>'),
+            nodes.Name('title', 'load'),
+            nodes.TemplateData(u'</h1>\n    <p>'),
+            nodes.Name('subtitle', 'load'),
+            nodes.TemplateData(u'</p>\n    ')]),
+        nodes.Assign(
+            nodes.Name('subtitle', 'store'), nodes.Const('something else')),
+        nodes.Output([
+            nodes.TemplateData(u'\n    <p>'),
+            nodes.Name('subtitle', 'load'),
+            nodes.TemplateData(u'</p>\n  </div>\n'),
+            nodes.If(
+                nodes.Name('something', 'load'), [
+                    nodes.Assign(nodes.Name('title_upper', 'store'),
+                                 nodes.Filter(nodes.Name('title', 'load'),
+                                              'upper', [], [], None, None)),
+                    nodes.Output([
+                        nodes.Name('title_upper', 'load'),
+                        nodes.Call(nodes.Name('render_title', 'load'), [
+                            nodes.Const('Aha')], [], None, None)])], [])])])
+
+    for_loop = nodes.For(
+        nodes.Name('item', 'store'),
+        nodes.Name('seq', 'load'), [
+            nodes.Output([
+                nodes.TemplateData(u'\n    <li>'),
+                nodes.Name('item', 'load'),
+                nodes.TemplateData(u'</li>\n    <span>')]),
+            nodes.Include(nodes.Const('helper.html'), True, False),
+            nodes.Output([
+                nodes.TemplateData(u'</span>\n  ')])], [], None, False)
+
+    body_block = nodes.Block('body', [
+        nodes.Output([
+            nodes.TemplateData(u'\n  '),
+            nodes.Call(nodes.Name('render_title', 'load'), [
+                nodes.Name('item', 'load')], [], None, None),
+            nodes.TemplateData(u'\n  <ul>\n  ')]),
+        for_loop,
+        nodes.Output([nodes.TemplateData(u'\n  </ul>\n')])],
+        False)
+
+    tmpl = nodes.Template([
+        nodes.Extends(nodes.Const('layout.html')),
+        title_block,
+        render_title_macro,
+        body_block,
+    ])
+
+    tmpl_sym = symbols_for_node(tmpl)
+    assert tmpl_sym.refs == {
+        'render_title': 'l_0_render_title',
+    }
+    assert tmpl_sym.loads == {
+        'l_0_render_title': ('undefined', None),
+    }
+    assert tmpl_sym.stores == set(['render_title'])
+    assert tmpl_sym.dump_stores() == {
+        'render_title': 'l_0_render_title',
+    }
+
+    macro_sym = symbols_for_node(render_title_macro, tmpl_sym)
+    assert macro_sym.refs == {
+        'subtitle': 'l_1_subtitle',
+        'something': 'l_1_something',
+        'title': 'l_1_title',
+        'title_upper': 'l_1_title_upper',
+    }
+    assert macro_sym.loads == {
+        'l_1_subtitle': ('resolve', 'subtitle'),
+        'l_1_something': ('resolve','something'),
+        'l_1_title': ('param', None),
+        'l_1_title_upper': ('resolve', 'title_upper'),
+    }
+    assert macro_sym.stores == set(['title', 'title_upper', 'subtitle'])
+    assert macro_sym.find_ref('render_title') == 'l_0_render_title'
+    assert macro_sym.dump_stores() == {
+        'title': 'l_1_title',
+        'title_upper': 'l_1_title_upper',
+        'subtitle': 'l_1_subtitle',
+        'render_title': 'l_0_render_title',
+    }
+
+    body_sym = symbols_for_node(body_block)
+    assert body_sym.refs == {
+        'item': 'l_0_item',
+        'seq': 'l_0_seq',
+        'render_title': 'l_0_render_title',
+    }
+    assert body_sym.loads == {
+        'l_0_item': ('resolve', 'item'),
+        'l_0_seq': ('resolve', 'seq'),
+        'l_0_render_title': ('resolve', 'render_title'),
+    }
+    assert body_sym.stores == set([])
+
+    for_sym = symbols_for_node(for_loop, body_sym)
+    assert for_sym.refs == {
+        'item': 'l_1_item',
+    }
+    assert for_sym.loads == {
+        'l_1_item': ('undefined', None),
+    }
+    assert for_sym.stores == set(['item'])
+    assert for_sym.dump_stores() == {
+        'item': 'l_1_item',
+    }
+
+
+def test_if_branching_stores():
+    tmpl = nodes.Template([
+        nodes.If(nodes.Name('expression', 'load'), [
+            nodes.Assign(nodes.Name('variable', 'store'),
+                         nodes.Const(42))], [])])
+
+    sym = symbols_for_node(tmpl)
+    assert sym.refs == {
+        'variable': 'l_0_variable',
+        'expression': 'l_0_expression'
+    }
+    assert sym.stores == set(['variable'])
+    assert sym.loads == {
+        'l_0_variable': ('resolve', 'variable'),
+        'l_0_expression': ('resolve', 'expression')
+    }
+    assert sym.dump_stores() == {
+        'variable': 'l_0_variable',
+    }
+
+
+def test_if_branching_stores_undefined():
+    tmpl = nodes.Template([
+        nodes.Assign(nodes.Name('variable', 'store'), nodes.Const(23)),
+        nodes.If(nodes.Name('expression', 'load'), [
+            nodes.Assign(nodes.Name('variable', 'store'),
+                         nodes.Const(42))], [])])
+
+    sym = symbols_for_node(tmpl)
+    assert sym.refs == {
+        'variable': 'l_0_variable',
+        'expression': 'l_0_expression'
+    }
+    assert sym.stores == set(['variable'])
+    assert sym.loads == {
+        'l_0_variable': ('undefined', None),
+        'l_0_expression': ('resolve', 'expression')
+    }
+    assert sym.dump_stores() == {
+        'variable': 'l_0_variable',
+    }
+
+
+def test_if_branching_multi_scope():
+    for_loop = nodes.For(nodes.Name('item', 'store'), nodes.Name('seq', 'load'), [
+        nodes.If(nodes.Name('expression', 'load'), [
+            nodes.Assign(nodes.Name('x', 'store'), nodes.Const(42))], []),
+        nodes.Include(nodes.Const('helper.html'), True, False)
+    ], [], None, False)
+
+    tmpl = nodes.Template([
+        nodes.Assign(nodes.Name('x', 'store'), nodes.Const(23)),
+        for_loop
+    ])
+
+    tmpl_sym = symbols_for_node(tmpl)
+    for_sym = symbols_for_node(for_loop, tmpl_sym)
+    assert for_sym.stores == set(['item', 'x'])
+    assert for_sym.loads == {
+        'l_1_x': ('alias', 'l_0_x'),
+        'l_1_item': ('undefined', None),
+        'l_1_expression': ('resolve', 'expression'),
+    }