--- /dev/null
+from functools import wraps
+from jinja2.asyncsupport import auto_aiter
+from jinja2 import filters
+
+
+async def auto_to_seq(value):
+ seq = []
+ if hasattr(value, '__aiter__'):
+ async for item in value:
+ seq.append(item)
+ else:
+ for item in value:
+ seq.append(item)
+ return seq
+
+
+async def async_do_first(environment, seq):
+ try:
+ return await auto_aiter(seq).__anext__()
+ except StopAsyncIteration:
+ return environment.undefined('No first item, sequence was empty.')
+
+
+@wraps(filters.do_first)
+@filters.environmentfilter
+def do_first(environment, seq):
+ if environment.is_async:
+ return async_do_first(environment, seq)
+ return filters.do_first(environment, seq)
+
+
+ASYNC_FILTERS = {
+ 'first': do_first,
+}
def patch_runtime():
from jinja2.runtime import BlockReference
- BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__)
+ BlockReference.__call__ = wrap_block_reference_call(
+ BlockReference.__call__)
+
+
+def patch_filters():
+ from jinja2.filters import FILTERS
+ from jinja2.asyncfilters import ASYNC_FILTERS
+ FILTERS.update(ASYNC_FILTERS)
def patch_all():
patch_template()
patch_runtime()
+ patch_filters()
async def auto_await(value):
self.visit(node.step, frame)
def visit_Filter(self, node, frame):
+ if self.environment.is_async:
+ self.write('await auto_await(')
self.write(self.filters[node.name] + '(')
func = self.environment.filters.get(node.name)
if func is None:
self.write('concat(%s)' % frame.buffer)
self.signature(node, frame)
self.write(')')
+ if self.environment.is_async:
+ self.write(')')
def visit_Test(self, node, frame):
self.write(self.tests[node.name] + '(')
def pytest_ignore_collect(path, config):
- if path.basename == 'test_async.py' and not have_async_gen:
+ if 'async' in path.basename and not have_async_gen:
return True
return False
--- /dev/null
+import pytest
+from jinja2 import Environment
+
+
+@pytest.fixture
+def env_async():
+ return Environment(enable_async=True)
+
+
+def test_first(env_async):
+ tmpl = env_async.from_string('{{ foo|first }}')
+ out = tmpl.render(foo=list(range(10)))
+ assert out == '0'
+
+
+def test_first_aiter(env_async):
+ async def foo():
+ for x in range(10):
+ yield x
+ tmpl = env_async.from_string('{{ foo()|first }}')
+ out = tmpl.render(foo=foo)
+ assert out == '0'
@pytest.mark.filter
-class TestFilter():
+class TestFilter(object):
def test_filter_calling(self, env):
rv = env.call_filter('sum', [1, 2, 3])