]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
Added tests for async functionality with imports and includes
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 14:55:54 +0000 (15:55 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 14:55:54 +0000 (15:55 +0100)
jinja2/asyncsupport.py
jinja2/compiler.py
jinja2/environment.py
tests/test_async.py

index 765c3137acad5ccd682803547b6d3358c92886ca..ca897a77ac0af3ad0f05f6c462703c229440cc00 100644 (file)
@@ -64,13 +64,6 @@ async def get_default_module_async(self):
     return rv
 
 
-@internalcode
-def get_default_module_impl(self):
-    if self.environment._async:
-        return self._get_default_module_async()
-    return self._get_default_module()
-
-
 def wrap_default_module(original_default_module):
     @internalcode
     def _get_default_module(self):
@@ -84,7 +77,7 @@ def wrap_default_module(original_default_module):
 async def make_module_async(self, vars=None, shared=False, locals=None):
     context = self.new_context(vars, shared, locals)
     body_stream = []
-    async for item in template.root_render_func(context):
+    async for item in self.root_render_func(context):
         body_stream.append(item)
     return TemplateModule(self, context, body_stream)
 
@@ -96,7 +89,6 @@ def patch_template():
     Template._get_default_module = wrap_default_module(
         Template._get_default_module)
     Template._get_default_module_async = get_default_module_async
-    Template._get_default_module_impl = get_default_module_impl
     Template.make_module_async = make_module_async
 
 
index cd5c1cec7df93e479155f207c5d7923acc55e92b..4afda9ad1b8abbb9cdb8daa038b76ac4c065c4be 100644 (file)
@@ -971,14 +971,18 @@ class CodeGenerator(NodeVisitor):
             self.writeline('else:')
             self.indent()
 
-        loop = self.environment._async and 'async for' or 'for'
         if node.with_context:
+            loop = self.environment._async and 'async for' or 'for'
             self.writeline('%s event in template.root_render_func('
                            'template.new_context(context.parent, True, '
                            'locals())):' % loop)
+        elif self.environment._async:
+            self.writeline('for event in (await '
+                           'template._get_default_module_async())'
+                           '._body_stream:')
         else:
-            self.writeline('%s event in template._get_default_module_impl()'
-                           '._body_stream:' % loop)
+            self.writeline('for event in template._get_default_module()'
+                           '._body_stream:')
 
         self.indent()
         self.simple_write('event', frame)
@@ -1002,8 +1006,10 @@ class CodeGenerator(NodeVisitor):
         if node.with_context:
             self.write('make_module%s(context.parent, True, locals())'
                        % (self.environment._async and '_async' or ''))
+        elif self.environment._async:
+            self.write('_get_default_module_async()')
         else:
-            self.write('_get_default_module_impl()')
+            self.write('_get_default_module()')
         if frame.toplevel and not node.target.startswith('_'):
             self.writeline('context.exported_vars.discard(%r)' % node.target)
         frame.assigned_names.add(node.target)
@@ -1018,8 +1024,10 @@ class CodeGenerator(NodeVisitor):
         if node.with_context:
             self.write('make_module%s(context.parent, True)'
                        % (self.environment._async and '_async' or ''))
+        elif self.environment._async:
+            self.write('_get_default_module_async()')
         else:
-            self.write('_get_default_module_impl()')
+            self.write('_get_default_module()')
 
         var_names = []
         discarded_names = []
index 1ebb4da2f14cd92601ca66365f498372a149ecd0..f91b01dcc83415ca1e6dc42383e3e59d57a8c6b5 100644 (file)
@@ -1042,10 +1042,6 @@ class Template(object):
         self._module = rv = self.make_module()
         return rv
 
-    # This is what the compiler dispatches to from generated code.  It
-    # might get swapped out by the async support
-    _get_default_module_impl = _get_default_module
-
     @property
     def module(self):
         """The template as module.  This is used for imports in the
index faaf2a2f891c42e0ea73032c6abb10c4b8e7e775..d593a2b2d3e7649a008f86dfb6f53ef3b2e14b4f 100644 (file)
@@ -1,13 +1,14 @@
 import pytest
 import asyncio
 
-from jinja2 import Template
+from jinja2 import Template, Environment, DictLoader
 from jinja2.utils import have_async_gen
+from jinja2.exceptions import TemplateNotFound, TemplatesNotFound
 
 
-def run(func):
+def run(coro):
     loop = asyncio.get_event_loop()
-    return loop.run_until_complete(func())
+    return loop.run_until_complete(coro)
 
 
 @pytest.mark.skipif(not have_async_gen, reason='No async generators')
@@ -17,7 +18,7 @@ def test_basic_async():
     async def func():
         return await t.render_async()
 
-    rv = run(func)
+    rv = run(func())
     assert rv == '[1][2][3]'
 
 
@@ -38,7 +39,7 @@ def test_await_on_calls():
             normal_func=normal_func
         )
 
-    rv = run(func)
+    rv = run(func())
     assert rv == '65'
 
 
@@ -72,7 +73,7 @@ def test_await_and_macros():
     async def func():
         return await t.render_async(async_func=async_func)
 
-    rv = run(func)
+    rv = run(func())
     assert rv == '[42][42]'
 
 
@@ -83,5 +84,139 @@ def test_async_blocks():
     async def func():
         return await t.render_async()
 
-    rv = run(func)
+    rv = run(func())
     assert rv == '<Test><Test>'
+
+
+@pytest.fixture
+def test_env_async():
+    env = Environment(loader=DictLoader(dict(
+        module='{% macro test() %}[{{ foo }}|{{ bar }}]{% endmacro %}',
+        header='[{{ foo }}|{{ 23 }}]',
+        o_printer='({{ o }})'
+    )), enable_async=True)
+    env.globals['bar'] = 23
+    return env
+
+
+@pytest.mark.imports
+class TestAsyncImports(object):
+
+    def test_context_imports(self, test_env_async):
+        t = test_env_async.from_string('{% import "module" as m %}{{ m.test() }}')
+        assert t.render(foo=42) == '[|23]'
+        t = test_env_async.from_string(
+            '{% import "module" as m without context %}{{ m.test() }}'
+        )
+        assert t.render(foo=42) == '[|23]'
+        t = test_env_async.from_string(
+            '{% import "module" as m with context %}{{ m.test() }}'
+        )
+        assert t.render(foo=42) == '[42|23]'
+        t = test_env_async.from_string('{% from "module" import test %}{{ test() }}')
+        assert t.render(foo=42) == '[|23]'
+        t = test_env_async.from_string(
+            '{% from "module" import test without context %}{{ test() }}'
+        )
+        assert t.render(foo=42) == '[|23]'
+        t = test_env_async.from_string(
+            '{% from "module" import test with context %}{{ test() }}'
+        )
+        assert t.render(foo=42) == '[42|23]'
+
+    def test_trailing_comma(self, test_env_async):
+        test_env_async.from_string('{% from "foo" import bar, baz with context %}')
+        test_env_async.from_string('{% from "foo" import bar, baz, with context %}')
+        test_env_async.from_string('{% from "foo" import bar, with context %}')
+        test_env_async.from_string('{% from "foo" import bar, with, context %}')
+        test_env_async.from_string('{% from "foo" import bar, with with context %}')
+
+    def test_exports(self, test_env_async):
+        m = run(test_env_async.from_string('''
+            {% macro toplevel() %}...{% endmacro %}
+            {% macro __private() %}...{% endmacro %}
+            {% set variable = 42 %}
+            {% for item in [1] %}
+                {% macro notthere() %}{% endmacro %}
+            {% endfor %}
+        ''')._get_default_module_async())
+        assert run(m.toplevel()) == '...'
+        assert not hasattr(m, '__missing')
+        assert m.variable == 42
+        assert not hasattr(m, 'notthere')
+
+
+@pytest.mark.imports
+@pytest.mark.includes
+class TestAsyncIncludes(object):
+
+    def test_context_include(self, test_env_async):
+        t = test_env_async.from_string('{% include "header" %}')
+        assert t.render(foo=42) == '[42|23]'
+        t = test_env_async.from_string('{% include "header" with context %}')
+        assert t.render(foo=42) == '[42|23]'
+        t = test_env_async.from_string('{% include "header" without context %}')
+        assert t.render(foo=42) == '[|23]'
+
+    def test_choice_includes(self, test_env_async):
+        t = test_env_async.from_string('{% include ["missing", "header"] %}')
+        assert t.render(foo=42) == '[42|23]'
+
+        t = test_env_async.from_string(
+            '{% include ["missing", "missing2"] ignore missing %}'
+        )
+        assert t.render(foo=42) == ''
+
+        t = test_env_async.from_string('{% include ["missing", "missing2"] %}')
+        pytest.raises(TemplateNotFound, t.render)
+        try:
+            t.render()
+        except TemplatesNotFound as e:
+            assert e.templates == ['missing', 'missing2']
+            assert e.name == 'missing2'
+        else:
+            assert False, 'thou shalt raise'
+
+        def test_includes(t, **ctx):
+            ctx['foo'] = 42
+            assert t.render(ctx) == '[42|23]'
+
+        t = test_env_async.from_string('{% include ["missing", "header"] %}')
+        test_includes(t)
+        t = test_env_async.from_string('{% include x %}')
+        test_includes(t, x=['missing', 'header'])
+        t = test_env_async.from_string('{% include [x, "header"] %}')
+        test_includes(t, x='missing')
+        t = test_env_async.from_string('{% include x %}')
+        test_includes(t, x='header')
+        t = test_env_async.from_string('{% include x %}')
+        test_includes(t, x='header')
+        t = test_env_async.from_string('{% include [x] %}')
+        test_includes(t, x='header')
+
+    def test_include_ignoring_missing(self, test_env_async):
+        t = test_env_async.from_string('{% include "missing" %}')
+        pytest.raises(TemplateNotFound, t.render)
+        for extra in '', 'with context', 'without context':
+            t = test_env_async.from_string('{% include "missing" ignore missing ' +
+                                     extra + ' %}')
+            assert t.render() == ''
+
+    def test_context_include_with_overrides(self, test_env_async):
+        env = Environment(loader=DictLoader(dict(
+            main="{% for item in [1, 2, 3] %}{% include 'item' %}{% endfor %}",
+            item="{{ item }}"
+        )))
+        assert env.get_template("main").render() == "123"
+
+    def test_unoptimized_scopes(self, test_env_async):
+        t = test_env_async.from_string("""
+            {% macro outer(o) %}
+            {% macro inner() %}
+            {% include "o_printer" %}
+            {% endmacro %}
+            {{ inner() }}
+            {% endmacro %}
+            {{ outer("FOO") }}
+        """)
+        assert t.render().strip() == '(FOO)'