def visit_ContextReference(self, node, frame):
self.write('context')
+ def visit_DerivedContextReference(self, node, frame):
+ self.write(self.derive_context(frame))
+
def visit_Continue(self, node, frame):
self.writeline('continue', node)
class ExampleExtension(Extension):
tags = set(['test'])
ext_attr = 42
+ context_reference_node_cls = nodes.ContextReference
def parse(self, parser):
return nodes.Output([self.call_method('_dump', [
nodes.EnvironmentAttribute('sandboxed'),
self.attr('ext_attr'),
nodes.ImportedName(__name__ + '.importable_object'),
- nodes.ContextReference()
+ self.context_reference_node_cls()
])]).set_lineno(next(parser.stream).lineno)
def _dump(self, sandboxed, ext_attr, imported_object, context):
- return '%s|%s|%s|%s' % (
+ return '%s|%s|%s|%s|%s' % (
sandboxed,
ext_attr,
imported_object,
- context.blocks
+ context.blocks,
+ context.get('test_var')
)
+class DerivedExampleExtension(ExampleExtension):
+ context_reference_node_cls = nodes.DerivedContextReference
+
+
class PreprocessorExtension(Extension):
def preprocess(self, source, name, filename=None):
def test_extension_nodes(self):
env = Environment(extensions=[ExampleExtension])
tmpl = env.from_string('{% test %}')
- assert tmpl.render() == 'False|42|23|{}'
+ assert tmpl.render() == 'False|42|23|{}|None'
+
+ def test_contextreference_node_passes_context(self):
+ env = Environment(extensions=[ExampleExtension])
+ tmpl = env.from_string('{% set test_var="test_content" %}{% test %}')
+ assert tmpl.render() == 'False|42|23|{}|test_content'
+
+ def test_contextreference_node_can_pass_locals(self):
+ env = Environment(extensions=[DerivedExampleExtension])
+ tmpl = env.from_string(
+ '{% for test_var in ["test_content"] %}{% test %}{% endfor %}')
+ assert tmpl.render() == 'False|42|23|{}|test_content'
def test_identifier(self):
assert ExampleExtension.identifier == __name__ + '.ExampleExtension'