From: Armin Ronacher Date: Wed, 28 Dec 2016 19:06:34 +0000 (+0100) Subject: Added support for async loop context X-Git-Tag: 2.9~64 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d4e14fc4d20b19841c1074d3c28a3f059424354f;p=thirdparty%2Fjinja.git Added support for async loop context --- diff --git a/jinja2/asyncsupport.py b/jinja2/asyncsupport.py index 0a19b6ae..0c4fecf6 100644 --- a/jinja2/asyncsupport.py +++ b/jinja2/asyncsupport.py @@ -4,6 +4,7 @@ import inspect from jinja2.utils import concat, internalcode, concat, Markup from jinja2.environment import TemplateModule +from jinja2.runtime import LoopContextBase, _last_iteration async def concat_async(async_gen): @@ -144,3 +145,45 @@ async def auto_iter(iterable): return for item in iterable: yield item + + +class AsyncLoopContext(LoopContextBase): + + def __init__(self, async_iterator, iterable, after, recurse=None, depth0=0): + self._async_iterator = async_iterator + LoopContextBase.__init__(self, iterable, recurse, depth0) + self._after = after + + def __aiter__(self): + return AsyncLoopContextIterator(self) + + +class AsyncLoopContextIterator(object): + __slots__ = ('context',) + + def __init__(self, context): + self.context = context + + def __aiter__(self): + return self + + async def __anext__(self): + ctx = self.context + ctx.index0 += 1 + if ctx._after is _last_iteration: + raise StopAsyncIteration() + next_elem = ctx._after + try: + ctx._after = await ctx._async_iterator.__anext__() + except StopAsyncIteration: + ctx._after = _last_iteration + return next_elem, ctx + + +async def make_async_loop_context(iterable, recurse=None, depth0=0): + async_iterator = auto_iter(iterable) + try: + after = await async_iterator.__anext__() + except StopAsyncIteration: + after = _last_iteration + return AsyncLoopContext(async_iterator, iterable, after, recurse, depth0) diff --git a/jinja2/compiler.py b/jinja2/compiler.py index f2714c67..667fec52 100644 --- a/jinja2/compiler.py +++ b/jinja2/compiler.py @@ -783,7 +783,8 @@ class CodeGenerator(NodeVisitor): self.writeline('dummy = lambda *x: None') if self.environment._async: - self.writeline('from jinja2.asyncsupport import auto_await, auto_iter') + self.writeline('from jinja2.asyncsupport import auto_await, ' + 'auto_iter, make_async_loop_context') # if we want a deferred initialization we cannot move the # environment into a local name @@ -1132,7 +1133,13 @@ class CodeGenerator(NodeVisitor): self.writeline(self.environment._async and 'async for ' or 'for ', node) self.visit(node.target, loop_frame) - self.write(extended_loop and ', l_loop in LoopContext(' or ' in ') + if extended_loop: + if self.environment._async: + self.write(', l_loop in await make_async_loop_context(') + else: + self.write(', l_loop in LoopContext(') + else: + self.write(' in ') # if we have an extened loop and a node test, we filter in the # "outer frame". @@ -1158,10 +1165,10 @@ class CodeGenerator(NodeVisitor): elif node.recursive: self.write('reciter') else: - if self.environment._async: + if self.environment._async and not extended_loop: self.write('auto_iter(') self.visit(node.iter, loop_frame) - if self.environment._async: + if self.environment._async and not extended_loop: self.write(')') if node.recursive: diff --git a/jinja2/runtime.py b/jinja2/runtime.py index 685a12da..e0df9b71 100644 --- a/jinja2/runtime.py +++ b/jinja2/runtime.py @@ -280,13 +280,14 @@ class BlockReference(object): return rv -class LoopContext(object): +class LoopContextBase(object): """A loop context for dynamic iteration.""" + _after = _last_iteration + _length = None + def __init__(self, iterable, recurse=None, depth0=0): - self._iterator = iter(iterable) self._recurse = recurse - self._after = self._safe_next() self.index0 = -1 self.depth0 = depth0 @@ -315,15 +316,6 @@ class LoopContext(object): def __len__(self): return self.length - def __iter__(self): - return LoopContextIterator(self) - - def _safe_next(self): - try: - return next(self._iterator) - except StopIteration: - return _last_iteration - @internalcode def loop(self, iterable): if self._recurse is None: @@ -357,6 +349,23 @@ class LoopContext(object): ) +class LoopContext(LoopContextBase): + + def __init__(self, iterable, recurse=None, depth0=0): + self._iterator = iter(iterable) + LoopContextBase.__init__(self, iterable, recurse, depth0) + self._after = self._safe_next() + + def __iter__(self): + return LoopContextIterator(self) + + def _safe_next(self): + try: + return next(self._iterator) + except StopIteration: + return _last_iteration + + @implements_iterator class LoopContextIterator(object): """The iterator for a loop context.""" diff --git a/tests/test_async.py b/tests/test_async.py index 44463d10..fff732be 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -97,7 +97,7 @@ def test_async_generate(): @pytest.mark.skipif(not have_async_gen, reason='No async generators') -def test_async_iteration_in_tmeplates(): +def test_async_iteration_in_templates(): t = Template('{% for x in rng %}{{ x }}{% endfor %}', enable_async=True) async def async_iterator(): @@ -107,6 +107,17 @@ def test_async_iteration_in_tmeplates(): assert rv == ['1', '2', '3'] +@pytest.mark.skipif(not have_async_gen, reason='No async generators') +def test_async_iteration_in_templates_extended(): + t = Template('{% for x in rng %}{{ loop.index0 }}/{{ x }}{% endfor %}', + enable_async=True) + async def async_iterator(): + for item in [1, 2, 3]: + yield item + rv = list(t.generate(rng=async_iterator())) + assert rv == ['0/1', '1/2', '2/3'] + + @pytest.fixture def test_env_async(): env = Environment(loader=DictLoader(dict(