]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
don't prefetch next item in loop context
authorDavid Lord <davidism@gmail.com>
Thu, 7 Nov 2019 21:35:57 +0000 (13:35 -0800)
committerDavid Lord <davidism@gmail.com>
Thu, 7 Nov 2019 21:35:57 +0000 (13:35 -0800)
jinja2/asyncsupport.py
jinja2/compiler.py
jinja2/runtime.py

index 53ad192de26a3d969d6fe5d3621fee98ca2d8acd..7d457e3d1865caadc4e8380cea8b659de2c7564a 100644 (file)
@@ -9,14 +9,17 @@
     :copyright: (c) 2017 by the Jinja Team.
     :license: BSD, see LICENSE for more details.
 """
-import sys
 import asyncio
 import inspect
+import sys
 from functools import update_wrapper
 
-from jinja2.utils import concat, internalcode, Markup
 from jinja2.environment import TemplateModule
-from jinja2.runtime import LoopContextBase, _last_iteration
+from jinja2.runtime import LoopContext
+from jinja2.utils import concat
+from jinja2.utils import internalcode
+from jinja2.utils import Markup
+from jinja2.utils import missing
 
 
 async def concat_async(async_gen):
@@ -187,73 +190,80 @@ async def auto_aiter(iterable):
         yield item
 
 
-class AsyncLoopContext(LoopContextBase):
-
-    def __init__(self, async_iterator, undefined, after, length, recurse=None,
-                 depth0=0):
-        LoopContextBase.__init__(self, undefined, recurse, depth0)
-        self._async_iterator = async_iterator
-        self._after = after
-        self._length = length
+class AsyncLoopContext(LoopContext):
+    _to_iterator = staticmethod(auto_aiter)
 
     @property
-    def length(self):
-        if self._length is None:
-            raise TypeError('Loop length for some iterators cannot be '
-                            'lazily calculated in async mode')
+    async def length(self):
+        if self._length is not None:
+            return self._length
+
+        try:
+            self._length = len(self._iterable)
+        except TypeError:
+            iterable = [x async for x in self._iterator]
+            self._iterator = self._to_iterator(iterable)
+            self._length = len(iterable) + self.index + (self._after is not missing)
+
         return self._length
 
-    def __aiter__(self):
-        return AsyncLoopContextIterator(self)
+    @property
+    async def revindex0(self):
+        return await self.length - self.index
+
+    @property
+    async def revindex(self):
+        return await self.length - self.index0
+
+    async def _peek_next(self):
+        if self._after is not missing:
+            return self._after
+
+        try:
+            self._after = await self._iterator.__anext__()
+        except StopAsyncIteration:
+            self._after = missing
+
+        return self._after
 
+    @property
+    async def last(self):
+        return await self._peek_next() is missing
 
-class AsyncLoopContextIterator(object):
-    __slots__ = ('context',)
+    @property
+    async def nextitem(self):
+        rv = await self._peek_next()
 
-    def __init__(self, context):
-        self.context = context
+        if rv is missing:
+            return self._undefined("there is no next item")
+
+        return rv
 
     def __aiter__(self):
         return self
 
     async def __anext__(self):
-        ctx = self.context
-        ctx.index0 += 1
-        if ctx._after is _last_iteration:
-            raise StopAsyncIteration()
-        ctx._before = ctx._current
-        ctx._current = ctx._after
-        try:
-            ctx._after = await ctx._async_iterator.__anext__()
-        except StopAsyncIteration:
-            ctx._after = _last_iteration
-        return ctx._current, ctx
+        if self._after is not missing:
+            rv = self._after
+            self._after = missing
+        else:
+            rv = await self._iterator.__anext__()
+
+        self.index0 += 1
+        self._before = self._current
+        self._current = rv
+        return rv, self
 
 
 async def make_async_loop_context(iterable, undefined, recurse=None, depth0=0):
-    # Length is more complicated and less efficient in async mode.  The
-    # reason for this is that we cannot know if length will be used
-    # upfront but because length is a property we cannot lazily execute it
-    # later.  This means that we need to buffer it up and measure :(
-    #
-    # We however only do this for actual iterators, not for async
-    # iterators as blocking here does not seem like the best idea in the
-    # world.
-    try:
-        length = len(iterable)
-    except (TypeError, AttributeError):
-        if not hasattr(iterable, '__aiter__'):
-            iterable = tuple(iterable)
-            length = len(iterable)
-        else:
-            length = None
-    async_iterator = auto_aiter(iterable)
-    try:
-        after = await async_iterator.__anext__()
-    except StopAsyncIteration:
-        after = _last_iteration
-    return AsyncLoopContext(async_iterator, undefined, after, length, recurse,
-                            depth0)
+    import warnings
+    warnings.warn(
+        "This template must be recompiled with at least Jinja 2.11, or"
+        " it will fail in 3.0.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return AsyncLoopContext(iterable, undefined, recurse, depth0)
 
 
 patch_all()
index 488ef0a375d2dcc287cb6ecc0555e7aed1c6ecef..00b29b8ef18e3527ca50365471db8d04437c859d 100644 (file)
@@ -705,7 +705,7 @@ class CodeGenerator(NodeVisitor):
 
         if self.environment.is_async:
             self.writeline('from jinja2.asyncsupport import auto_await, '
-                           'auto_aiter, make_async_loop_context')
+                           'auto_aiter, AsyncLoopContext')
 
         # if we want a deferred initialization we cannot move the
         # environment into a local name
@@ -1095,7 +1095,7 @@ class CodeGenerator(NodeVisitor):
         self.visit(node.target, loop_frame)
         if extended_loop:
             if self.environment.is_async:
-                self.write(', %s in await make_async_loop_context(' % loop_ref)
+                self.write(', %s in AsyncLoopContext(' % loop_ref)
             else:
                 self.write(', %s in LoopContext(' % loop_ref)
         else:
index ff12dedad36c3e3dba1d1bd31785ddfba29c4ccf..135ff27b1485f2bf3595f032c3ce626bf8bbdf75 100644 (file)
@@ -343,134 +343,197 @@ class BlockReference(object):
         return rv
 
 
-class LoopContextBase(object):
-    """A loop context for dynamic iteration."""
+@implements_iterator
+class LoopContext:
+    """A wrapper iterable for dynamic ``for`` loops, with information
+    about the loop and iteration.
+    """
+
+    #: Current iteration of the loop, starting at 0.
+    index0 = -1
 
-    _before = _first_iteration
-    _current = _first_iteration
-    _after = _last_iteration
     _length = None
+    _after = missing
+    _current = missing
+    _before = missing
+    _last_changed_value = missing
 
-    def __init__(self, undefined, recurse=None, depth0=0):
+    def __init__(self, iterable, undefined, recurse=None, depth0=0):
+        """
+        :param iterable: Iterable to wrap.
+        :param undefined: :class:`Undefined` class to use for next and
+            previous items.
+        :param recurse: The function to render the loop body when the
+            loop is marked recursive.
+        :param depth0: Incremented when looping recursively.
+        """
+        self._iterable = iterable
+        self._iterator = self._to_iterator(iterable)
         self._undefined = undefined
         self._recurse = recurse
-        self.index0 = -1
+        #: How many levels deep a recursive loop currently is, starting at 0.
         self.depth0 = depth0
-        self._last_checked_value = missing
 
-    def cycle(self, *args):
-        """Cycles among the arguments with the current loop index."""
-        if not args:
-            raise TypeError('no items for cycling given')
-        return args[self.index0 % len(args)]
+    @staticmethod
+    def _to_iterator(iterable):
+        return iter(iterable)
 
-    def changed(self, *value):
-        """Checks whether the value has changed since the last call."""
-        if self._last_checked_value != value:
-            self._last_checked_value = value
-            return True
-        return False
+    @property
+    def length(self):
+        """Length of the iterable.
 
-    first = property(lambda x: x.index0 == 0)
-    last = property(lambda x: x._after is _last_iteration)
-    index = property(lambda x: x.index0 + 1)
-    revindex = property(lambda x: x.length - x.index0)
-    revindex0 = property(lambda x: x.length - x.index)
-    depth = property(lambda x: x.depth0 + 1)
+        If the iterable is a generator or otherwise does not have a
+        size, it is eagerly evaluated to get a size.
+        """
+        if self._length is not None:
+            return self._length
 
-    @property
-    def previtem(self):
-        if self._before is _first_iteration:
-            return self._undefined('there is no previous item')
-        return self._before
+        try:
+            self._length = len(self._iterable)
+        except TypeError:
+            iterable = list(self._iterator)
+            self._iterator = self._to_iterator(iterable)
+            self._length = len(iterable) + self.index + (self._after is not missing)
 
-    @property
-    def nextitem(self):
-        if self._after is _last_iteration:
-            return self._undefined('there is no next item')
-        return self._after
+        return self._length
 
     def __len__(self):
         return self.length
 
-    @internalcode
-    def loop(self, iterable):
-        if self._recurse is None:
-            raise TypeError('Tried to call non recursive loop.  Maybe you '
-                            "forgot the 'recursive' modifier.")
-        return self._recurse(iterable, self._recurse, self.depth0 + 1)
+    @property
+    def depth(self):
+        """How many levels deep a recursive loop currently is, starting at 1."""
+        return self.depth0 + 1
 
-    # a nifty trick to enhance the error message if someone tried to call
-    # the loop without or with too many arguments.
-    __call__ = loop
-    del loop
+    @property
+    def index(self):
+        """Current iteration of the loop, starting at 1."""
+        return self.index0 + 1
 
-    def __repr__(self):
-        return '<%s %r/%r>' % (
-            self.__class__.__name__,
-            self.index,
-            self.length
-        )
+    @property
+    def revindex0(self):
+        """Number of iterations from the end of the loop, ending at 0.
 
+        Requires calculating :attr:`length`.
+        """
+        return self.length - self.index
 
-class LoopContext(LoopContextBase):
+    @property
+    def revindex(self):
+        """Number of iterations from the end of the loop, ending at 1.
 
-    def __init__(self, iterable, undefined, recurse=None, depth0=0):
-        LoopContextBase.__init__(self, undefined, recurse, depth0)
-        self._iterator = iter(iterable)
-        self._iterations_done_count = 0
-        self._length = None
-        self._after = self._safe_next()
+        Requires calculating :attr:`length`.
+        """
+        return self.length - self.index0
 
     @property
-    def length(self):
+    def first(self):
+        """Whether this is the first iteration of the loop."""
+        return self.index0 == 0
+
+    def _peek_next(self):
+        """Return the next element in the iterable, or :data:`missing`
+        if the iterable is exhausted. Only peeks one item ahead, caching
+        the result in :attr:`_last` for use in subsequent checks. The
+        cache is reset when :meth:`__next__` is called.
+        """
+        if self._after is not missing:
+            return self._after
+
+        self._after = next(self._iterator, missing)
+        return self._after
+
+    @property
+    def last(self):
+        """Whether this is the last iteration of the loop.
+
+        Causes the iterable to advance early. See
+        :func:`itertools.groupby` for issues this can cause.
+        The :func:`groupby` filter avoids that issue.
         """
-        Getting length of an iterator is a costly operation which requires extra memory
-        and traversing in linear time. So make it an on demand param that iterates from
-        the point onwards of the iterator and accounts for iterated elements.
+        return self._peek_next() is missing
+
+    @property
+    def previtem(self):
+        """The item in the previous iteration. Undefined during the
+        first iteration.
         """
-        if self._length is None:
-            # if was not possible to get the length of the iterator when
-            # the loop context was created (ie: iterating over a generator)
-            # we have to convert the iterable into a sequence and use the
-            # length of that + the number of iterations so far.
-            iterable = tuple(self._iterator)
-            self._iterator = iter(iterable)
-            self._length = len(iterable) + self._iterations_done_count
-        return self._length
+        if self.first:
+            return self._undefined("there is no previous item")
 
-    def __iter__(self):
-        return LoopContextIterator(self)
+        return self._before
 
-    def _safe_next(self):
-        try:
-            tmp = next(self._iterator)
-            self._iterations_done_count += 1
-            return tmp
-        except StopIteration:
-            return _last_iteration
+    @property
+    def nextitem(self):
+        """The item in the next iteration. Undefined during the last
+        iteration.
 
+        Causes the iterable to advance early. See
+        :func:`itertools.groupby` for issues this can cause.
+        The :func:`groupby` filter avoids that issue.
+        """
+        rv = self._peek_next()
 
-@implements_iterator
-class LoopContextIterator(object):
-    """The iterator for a loop context."""
-    __slots__ = ('context',)
+        if rv is missing:
+            return self._undefined("there is no next item")
 
-    def __init__(self, context):
-        self.context = context
+        return rv
+
+    def cycle(self, *args):
+        """Return a value from the given args, cycling through based on
+        the current :attr:`index0`.
+
+        :param args: One or more values to cycle through.
+        """
+        if not args:
+            raise TypeError("no items for cycling given")
+
+        return args[self.index0 % len(args)]
+
+    def changed(self, *value):
+        """Return ``True`` if previously called with a different value
+        (including when called for the first time).
+
+        :param value: One or more values to compare to the last call.
+        """
+        if self._last_changed_value != value:
+            self._last_changed_value = value
+            return True
+
+        return False
 
     def __iter__(self):
         return self
 
+    @internalcode
     def __next__(self):
-        ctx = self.context
-        ctx.index0 += 1
-        if ctx._after is _last_iteration:
-            raise StopIteration()
-        ctx._before = ctx._current
-        ctx._current = ctx._after
-        ctx._after = ctx._safe_next()
-        return ctx._current, ctx
+        if self._after is not missing:
+            rv = self._after
+            self._after = missing
+        else:
+            rv = next(self._iterator)
+
+        self.index0 += 1
+        self._before = self._current
+        self._current = rv
+        return rv, self
+
+    def __call__(self, iterable):
+        """When iterating over nested data, render the body of the loop
+        recursively with the given inner iterable data.
+
+        The loop must have the ``recursive`` marker for this to work.
+        """
+        if self._recurse is None:
+            raise TypeError(
+                "The loop must have the 'recursive' marker to be"
+                " called recursively."
+            )
+
+        return self._recurse(iterable, self._recurse, depth=self.depth)
+
+    def __repr__(self):
+        return "<%s %d/%d>" % (self.__class__.__name__, self.index, self.length)
 
 
 class Macro(object):