]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
First pass on async support for filters
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 22:18:49 +0000 (23:18 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 22:18:51 +0000 (23:18 +0100)
jinja2/asyncfilters.py [new file with mode: 0644]
jinja2/asyncsupport.py
jinja2/compiler.py
tests/conftest.py
tests/test_asyncfilters.py [new file with mode: 0644]
tests/test_filters.py

diff --git a/jinja2/asyncfilters.py b/jinja2/asyncfilters.py
new file mode 100644 (file)
index 0000000..58a44a3
--- /dev/null
@@ -0,0 +1,34 @@
+from functools import wraps
+from jinja2.asyncsupport import auto_aiter
+from jinja2 import filters
+
+
+async def auto_to_seq(value):
+    seq = []
+    if hasattr(value, '__aiter__'):
+        async for item in value:
+            seq.append(item)
+    else:
+        for item in value:
+            seq.append(item)
+    return seq
+
+
+async def async_do_first(environment, seq):
+    try:
+        return await auto_aiter(seq).__anext__()
+    except StopAsyncIteration:
+        return environment.undefined('No first item, sequence was empty.')
+
+
+@wraps(filters.do_first)
+@filters.environmentfilter
+def do_first(environment, seq):
+    if environment.is_async:
+        return async_do_first(environment, seq)
+    return filters.do_first(environment, seq)
+
+
+ASYNC_FILTERS = {
+    'first': do_first,
+}
index 1cdda8332d2684537efd3762abe253108e34f9b5..33a1a071876f35db88bbddd51faf41df6d77753d 100644 (file)
@@ -139,12 +139,20 @@ def patch_template():
 
 def patch_runtime():
     from jinja2.runtime import BlockReference
-    BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__)
+    BlockReference.__call__ = wrap_block_reference_call(
+        BlockReference.__call__)
+
+
+def patch_filters():
+    from jinja2.filters import FILTERS
+    from jinja2.asyncfilters import ASYNC_FILTERS
+    FILTERS.update(ASYNC_FILTERS)
 
 
 def patch_all():
     patch_template()
     patch_runtime()
+    patch_filters()
 
 
 async def auto_await(value):
index c01d5b1aa2bf4802b51d27e188818757af2b9f89..056d60eaa021ae4839ee78514d6c612d418d20b9 100644 (file)
@@ -1615,6 +1615,8 @@ class CodeGenerator(NodeVisitor):
             self.visit(node.step, frame)
 
     def visit_Filter(self, node, frame):
+        if self.environment.is_async:
+            self.write('await auto_await(')
         self.write(self.filters[node.name] + '(')
         func = self.environment.filters.get(node.name)
         if func is None:
@@ -1640,6 +1642,8 @@ class CodeGenerator(NodeVisitor):
             self.write('concat(%s)' % frame.buffer)
         self.signature(node, frame)
         self.write(')')
+        if self.environment.is_async:
+            self.write(')')
 
     def visit_Test(self, node, frame):
         self.write(self.tests[node.name] + '(')
index f1ae10ff8fcec6ccbfbf3e56cc67f64b73405ca1..eaae2b0c263dd70284b7bda5a1a994037febb4b9 100644 (file)
@@ -21,7 +21,7 @@ from jinja2 import Environment
 
 
 def pytest_ignore_collect(path, config):
-    if path.basename == 'test_async.py' and not have_async_gen:
+    if 'async' in path.basename and not have_async_gen:
         return True
     return False
 
diff --git a/tests/test_asyncfilters.py b/tests/test_asyncfilters.py
new file mode 100644 (file)
index 0000000..75ee4a3
--- /dev/null
@@ -0,0 +1,22 @@
+import pytest
+from jinja2 import Environment
+
+
+@pytest.fixture
+def env_async():
+    return Environment(enable_async=True)
+
+
+def test_first(env_async):
+    tmpl = env_async.from_string('{{ foo|first }}')
+    out = tmpl.render(foo=list(range(10)))
+    assert out == '0'
+
+
+def test_first_aiter(env_async):
+    async def foo():
+        for x in range(10):
+            yield x
+    tmpl = env_async.from_string('{{ foo()|first }}')
+    out = tmpl.render(foo=foo)
+    assert out == '0'
index f8fdad8a001fe0ba6d6a63291fc5ed4d5ecdec5c..59999ee4c665afdc1988fbf20bc8ae3ea5e32beb 100644 (file)
@@ -14,7 +14,7 @@ from jinja2._compat import text_type, implements_to_string
 
 
 @pytest.mark.filter
-class TestFilter():
+class TestFilter(object):
 
     def test_filter_calling(self, env):
         rv = env.call_filter('sum', [1, 2, 3])