]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
async support doesn't require patching 1392/head
authorDavid Lord <davidism@gmail.com>
Sat, 10 Apr 2021 23:12:25 +0000 (16:12 -0700)
committerDavid Lord <davidism@gmail.com>
Sat, 10 Apr 2021 23:12:25 +0000 (16:12 -0700)
14 files changed:
CHANGES.rst
docs/intro.rst
src/jinja2/async_utils.py [new file with mode: 0644]
src/jinja2/asyncfilters.py [deleted file]
src/jinja2/asyncsupport.py [deleted file]
src/jinja2/compiler.py
src/jinja2/environment.py
src/jinja2/filters.py
src/jinja2/nativetypes.py
src/jinja2/runtime.py
src/jinja2/utils.py
tests/conftest.py
tests/test_async.py
tests/test_async_filters.py [moved from tests/test_asyncfilters.py with 99% similarity]

index 1c58908c0ff643d4e91948285429bcc49ce67434..e822841f52a4da39fadb74ab196c767931b0ace0 100644 (file)
@@ -70,6 +70,10 @@ Unreleased
     -   ``pass_environment`` replaces ``environmentfunction`` and
         ``environmentfilter``.
 
+-   Async support no longer requires Jinja to patch itself. It must
+    still be enabled with ``Environment(enable_async=True)``.
+    :issue:`1390`
+
 
 Version 2.11.3
 --------------
index 25c2b580dd663df1e477bd7f9bed5bd45a336bf4..56446a202d7d4197837c40ff9473ff88364ed3c8 100644 (file)
@@ -12,8 +12,8 @@ It includes:
 -   HTML templates can use autoescaping to prevent XSS from untrusted
     user input.
 -   A sandboxed environment can safely render untrusted templates.
--   AsyncIO support for generating templates and calling async
-    functions.
+-   Async support for generating templates that automatically handle
+    sync and async functions without extra syntax.
 -   I18N support with Babel.
 -   Templates are compiled to optimized Python code just-in-time and
     cached, or can be compiled ahead-of-time.
diff --git a/src/jinja2/async_utils.py b/src/jinja2/async_utils.py
new file mode 100644 (file)
index 0000000..cb011b2
--- /dev/null
@@ -0,0 +1,76 @@
+import inspect
+import typing as t
+from functools import wraps
+
+from .utils import _PassArg
+from .utils import pass_eval_context
+
+if t.TYPE_CHECKING:
+    V = t.TypeVar("V")
+
+
+def async_variant(normal_func):
+    def decorator(async_func):
+        pass_arg = _PassArg.from_obj(normal_func)
+        need_eval_context = pass_arg is None
+
+        if pass_arg is _PassArg.environment:
+
+            def is_async(args):
+                return args[0].is_async
+
+        else:
+
+            def is_async(args):
+                return args[0].environment.is_async
+
+        @wraps(normal_func)
+        def wrapper(*args, **kwargs):
+            b = is_async(args)
+
+            if need_eval_context:
+                args = args[1:]
+
+            if b:
+                return async_func(*args, **kwargs)
+
+            return normal_func(*args, **kwargs)
+
+        if need_eval_context:
+            wrapper = pass_eval_context(wrapper)
+
+        wrapper.jinja_async_variant = True
+        return wrapper
+
+    return decorator
+
+
+async def auto_await(value):
+    if inspect.isawaitable(value):
+        return await value
+
+    return value
+
+
+async def auto_aiter(iterable):
+    if hasattr(iterable, "__aiter__"):
+        async for item in iterable:
+            yield item
+    else:
+        for item in iterable:
+            yield item
+
+
+async def auto_to_list(
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+) -> "t.List[V]":
+    seq = []
+
+    if hasattr(value, "__aiter__"):
+        async for item in t.cast(t.AsyncIterable, value):
+            seq.append(item)
+    else:
+        for item in t.cast(t.Iterable, value):
+            seq.append(item)
+
+    return seq
diff --git a/src/jinja2/asyncfilters.py b/src/jinja2/asyncfilters.py
deleted file mode 100644 (file)
index 00cae01..0000000
+++ /dev/null
@@ -1,261 +0,0 @@
-import typing
-import typing as t
-import warnings
-from functools import wraps
-from itertools import groupby
-
-from . import filters
-from .asyncsupport import auto_aiter
-from .asyncsupport import auto_await
-from .utils import _PassArg
-from .utils import pass_eval_context
-
-if t.TYPE_CHECKING:
-    from .environment import Environment
-    from .nodes import EvalContext
-    from .runtime import Context
-    from .runtime import Undefined
-
-    V = t.TypeVar("V")
-
-
-async def auto_to_seq(
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-) -> "t.List[V]":
-    seq = []
-
-    if hasattr(value, "__aiter__"):
-        async for item in t.cast(t.AsyncIterable, value):
-            seq.append(item)
-    else:
-        for item in t.cast(t.Iterable, value):
-            seq.append(item)
-
-    return seq
-
-
-async def async_select_or_reject(
-    context: "Context",
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    args: t.Tuple,
-    kwargs: t.Dict[str, t.Any],
-    modfunc: t.Callable[[t.Any], t.Any],
-    lookup_attr: bool,
-) -> "t.AsyncIterator[V]":
-    if value:
-        func = filters.prepare_select_or_reject(
-            context, args, kwargs, modfunc, lookup_attr
-        )
-
-        async for item in auto_aiter(value):
-            if func(item):
-                yield item
-
-
-def dual_filter(normal_func, async_func):
-    pass_arg = _PassArg.from_obj(normal_func)
-    wrapper_has_eval_context = False
-
-    if pass_arg is _PassArg.environment:
-        wrapper_has_eval_context = False
-
-        def is_async(args):
-            return args[0].is_async
-
-    else:
-        wrapper_has_eval_context = pass_arg is None
-
-        def is_async(args):
-            return args[0].environment.is_async
-
-    @wraps(normal_func)
-    def wrapper(*args, **kwargs):
-        b = is_async(args)
-
-        if wrapper_has_eval_context:
-            args = args[1:]
-
-        if b:
-            return async_func(*args, **kwargs)
-
-        return normal_func(*args, **kwargs)
-
-    if wrapper_has_eval_context:
-        wrapper = pass_eval_context(wrapper)
-
-    wrapper.jinja_async_variant = True
-    return wrapper
-
-
-def async_variant(original):
-    def decorator(f):
-        return dual_filter(original, f)
-
-    return decorator
-
-
-def asyncfiltervariant(original):
-    warnings.warn(
-        "'asyncfiltervariant' is renamed to 'async_variant', the old"
-        " name will be removed in Jinja 3.1.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
-    return async_variant(original)
-
-
-@async_variant(filters.do_first)
-async def do_first(
-    environment: "Environment", seq: "t.Union[t.AsyncIterable[V], t.Iterable[V]]"
-) -> "t.Union[V, Undefined]":
-    try:
-        return t.cast("V", await auto_aiter(seq).__anext__())
-    except StopAsyncIteration:
-        return environment.undefined("No first item, sequence was empty.")
-
-
-@async_variant(filters.do_groupby)
-async def do_groupby(
-    environment: "Environment",
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    attribute: t.Union[str, int],
-    default: t.Optional[t.Any] = None,
-) -> "t.List[t.Tuple[t.Any, t.List[V]]]":
-    expr = filters.make_attrgetter(environment, attribute, default=default)
-    return [
-        filters._GroupTuple(key, await auto_to_seq(values))
-        for key, values in groupby(sorted(await auto_to_seq(value), key=expr), expr)
-    ]
-
-
-@async_variant(filters.do_join)
-async def do_join(
-    eval_ctx: "EvalContext",
-    value: t.Union[t.AsyncIterable, t.Iterable],
-    d: str = "",
-    attribute: t.Optional[t.Union[str, int]] = None,
-) -> str:
-    return filters.do_join(eval_ctx, await auto_to_seq(value), d, attribute)
-
-
-@async_variant(filters.do_list)
-async def do_list(value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]") -> "t.List[V]":
-    return await auto_to_seq(value)
-
-
-@async_variant(filters.do_reject)
-async def do_reject(
-    context: "Context",
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    *args: t.Any,
-    **kwargs: t.Any,
-) -> "t.AsyncIterator[V]":
-    return async_select_or_reject(context, value, args, kwargs, lambda x: not x, False)
-
-
-@async_variant(filters.do_rejectattr)
-async def do_rejectattr(
-    context: "Context",
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    *args: t.Any,
-    **kwargs: t.Any,
-) -> "t.AsyncIterator[V]":
-    return async_select_or_reject(context, value, args, kwargs, lambda x: not x, True)
-
-
-@async_variant(filters.do_select)
-async def do_select(
-    context: "Context",
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    *args: t.Any,
-    **kwargs: t.Any,
-) -> "t.AsyncIterator[V]":
-    return async_select_or_reject(context, value, args, kwargs, lambda x: x, False)
-
-
-@async_variant(filters.do_selectattr)
-async def do_selectattr(
-    context: "Context",
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    *args: t.Any,
-    **kwargs: t.Any,
-) -> "t.AsyncIterator[V]":
-    return async_select_or_reject(context, value, args, kwargs, lambda x: x, True)
-
-
-@typing.overload
-def do_map(
-    context: "Context",
-    value: t.Union[t.AsyncIterable, t.Iterable],
-    name: str,
-    *args: t.Any,
-    **kwargs: t.Any,
-) -> t.Iterable:
-    ...
-
-
-@typing.overload
-def do_map(
-    context: "Context",
-    value: t.Union[t.AsyncIterable, t.Iterable],
-    *,
-    attribute: str = ...,
-    default: t.Optional[t.Any] = None,
-) -> t.Iterable:
-    ...
-
-
-@async_variant(filters.do_map)
-async def do_map(context, value, *args, **kwargs):
-    if value:
-        func = filters.prepare_map(context, args, kwargs)
-
-        async for item in auto_aiter(value):
-            yield await auto_await(func(item))
-
-
-@async_variant(filters.do_sum)
-async def do_sum(
-    environment: "Environment",
-    iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    attribute: t.Optional[t.Union[str, int]] = None,
-    start: "V" = 0,  # type: ignore
-) -> "V":
-    rv = start
-
-    if attribute is not None:
-        func = filters.make_attrgetter(environment, attribute)
-    else:
-
-        def func(x):
-            return x
-
-    async for item in auto_aiter(iterable):
-        rv += func(item)
-
-    return rv
-
-
-@async_variant(filters.do_slice)
-async def do_slice(
-    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
-    slices: int,
-    fill_with: t.Optional[t.Any] = None,
-) -> "t.Iterator[t.List[V]]":
-    return filters.do_slice(await auto_to_seq(value), slices, fill_with)
-
-
-ASYNC_FILTERS = {
-    "first": do_first,
-    "groupby": do_groupby,
-    "join": do_join,
-    "list": do_list,
-    # we intentionally do not support do_last because it may not be safe in async
-    "reject": do_reject,
-    "rejectattr": do_rejectattr,
-    "map": do_map,
-    "select": do_select,
-    "selectattr": do_selectattr,
-    "sum": do_sum,
-    "slice": do_slice,
-}
diff --git a/src/jinja2/asyncsupport.py b/src/jinja2/asyncsupport.py
deleted file mode 100644 (file)
index e46a85a..0000000
+++ /dev/null
@@ -1,249 +0,0 @@
-"""The code for async support. Importing this patches Jinja."""
-import asyncio
-import inspect
-from functools import update_wrapper
-
-from markupsafe import Markup
-
-from .environment import TemplateModule
-from .runtime import LoopContext
-from .utils import concat
-from .utils import internalcode
-from .utils import missing
-
-
-async def concat_async(async_gen):
-    rv = []
-
-    async def collect():
-        async for event in async_gen:
-            rv.append(event)
-
-    await collect()
-    return concat(rv)
-
-
-async def generate_async(self, *args, **kwargs):
-    vars = dict(*args, **kwargs)
-    try:
-        async for event in self.root_render_func(self.new_context(vars)):
-            yield event
-    except Exception:
-        yield self.environment.handle_exception()
-
-
-def wrap_generate_func(original_generate):
-    def _convert_generator(self, loop, args, kwargs):
-        async_gen = self.generate_async(*args, **kwargs)
-        try:
-            while 1:
-                yield loop.run_until_complete(async_gen.__anext__())
-        except StopAsyncIteration:
-            pass
-
-    def generate(self, *args, **kwargs):
-        if not self.environment.is_async:
-            return original_generate(self, *args, **kwargs)
-        return _convert_generator(self, asyncio.get_event_loop(), args, kwargs)
-
-    return update_wrapper(generate, original_generate)
-
-
-async def render_async(self, *args, **kwargs):
-    if not self.environment.is_async:
-        raise RuntimeError("The environment was not created with async mode enabled.")
-
-    vars = dict(*args, **kwargs)
-    ctx = self.new_context(vars)
-
-    try:
-        return await concat_async(self.root_render_func(ctx))
-    except Exception:
-        return self.environment.handle_exception()
-
-
-def wrap_render_func(original_render):
-    def render(self, *args, **kwargs):
-        if not self.environment.is_async:
-            return original_render(self, *args, **kwargs)
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(self.render_async(*args, **kwargs))
-
-    return update_wrapper(render, original_render)
-
-
-def wrap_block_reference_call(original_call):
-    @internalcode
-    async def async_call(self):
-        rv = await concat_async(self._stack[self._depth](self._context))
-        if self._context.eval_ctx.autoescape:
-            rv = Markup(rv)
-        return rv
-
-    @internalcode
-    def __call__(self):
-        if not self._context.environment.is_async:
-            return original_call(self)
-        return async_call(self)
-
-    return update_wrapper(__call__, original_call)
-
-
-def wrap_macro_invoke(original_invoke):
-    @internalcode
-    async def async_invoke(self, arguments, autoescape):
-        rv = await self._func(*arguments)
-        if autoescape:
-            rv = Markup(rv)
-        return rv
-
-    @internalcode
-    def _invoke(self, arguments, autoescape):
-        if not self._environment.is_async:
-            return original_invoke(self, arguments, autoescape)
-        return async_invoke(self, arguments, autoescape)
-
-    return update_wrapper(_invoke, original_invoke)
-
-
-@internalcode
-async def get_default_module_async(self):
-    if self._module is not None:
-        return self._module
-    self._module = rv = await self.make_module_async()
-    return rv
-
-
-def wrap_default_module(original_default_module):
-    @internalcode
-    def _get_default_module(self, ctx=None):
-        if self.environment.is_async:
-            raise RuntimeError("Template module attribute is unavailable in async mode")
-        return original_default_module(self, ctx)
-
-    return _get_default_module
-
-
-async def make_module_async(self, vars=None, shared=False, locals=None):
-    context = self.new_context(vars, shared, locals)
-    body_stream = []
-    async for item in self.root_render_func(context):
-        body_stream.append(item)
-    return TemplateModule(self, context, body_stream)
-
-
-def patch_template():
-    from . import Template
-
-    Template.generate = wrap_generate_func(Template.generate)
-    Template.generate_async = update_wrapper(generate_async, Template.generate_async)
-    Template.render_async = update_wrapper(render_async, Template.render_async)
-    Template.render = wrap_render_func(Template.render)
-    Template._get_default_module = wrap_default_module(Template._get_default_module)
-    Template._get_default_module_async = get_default_module_async
-    Template.make_module_async = update_wrapper(
-        make_module_async, Template.make_module_async
-    )
-
-
-def patch_runtime():
-    from .runtime import BlockReference, Macro
-
-    BlockReference.__call__ = wrap_block_reference_call(BlockReference.__call__)
-    Macro._invoke = wrap_macro_invoke(Macro._invoke)
-
-
-def patch_filters():
-    from .filters import FILTERS
-    from .asyncfilters import ASYNC_FILTERS
-
-    FILTERS.update(ASYNC_FILTERS)
-
-
-def patch_all():
-    patch_template()
-    patch_runtime()
-    patch_filters()
-
-
-async def auto_await(value):
-    if inspect.isawaitable(value):
-        return await value
-    return value
-
-
-async def auto_aiter(iterable):
-    if hasattr(iterable, "__aiter__"):
-        async for item in iterable:
-            yield item
-        return
-    for item in iterable:
-        yield item
-
-
-class AsyncLoopContext(LoopContext):
-    _to_iterator = staticmethod(auto_aiter)
-
-    @property
-    async def length(self):
-        if self._length is not None:
-            return self._length
-
-        try:
-            self._length = len(self._iterable)
-        except TypeError:
-            iterable = [x async for x in self._iterator]
-            self._iterator = self._to_iterator(iterable)
-            self._length = len(iterable) + self.index + (self._after is not missing)
-
-        return self._length
-
-    @property
-    async def revindex0(self):
-        return await self.length - self.index
-
-    @property
-    async def revindex(self):
-        return await self.length - self.index0
-
-    async def _peek_next(self):
-        if self._after is not missing:
-            return self._after
-
-        try:
-            self._after = await self._iterator.__anext__()
-        except StopAsyncIteration:
-            self._after = missing
-
-        return self._after
-
-    @property
-    async def last(self):
-        return await self._peek_next() is missing
-
-    @property
-    async def nextitem(self):
-        rv = await self._peek_next()
-
-        if rv is missing:
-            return self._undefined("there is no next item")
-
-        return rv
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        if self._after is not missing:
-            rv = self._after
-            self._after = missing
-        else:
-            rv = await self._iterator.__anext__()
-
-        self.index0 += 1
-        self._before = self._current
-        self._current = rv
-        return rv, self
-
-
-patch_all()
index 1d73f7d43ec7f2e06221c9bf92f1dbf64f8bc264..b15fb670d2ff0dfecf96ad6e650ade2b5d95fc80 100644 (file)
@@ -727,16 +727,15 @@ class CodeGenerator(NodeVisitor):
         assert frame is None, "no root frame allowed"
         eval_ctx = EvalContext(self.environment, self.name)
 
-        from .runtime import exported
-
-        self.writeline("from __future__ import generator_stop")  # Python < 3.7
-        self.writeline("from jinja2.runtime import " + ", ".join(exported))
+        from .runtime import exported, async_exported
 
         if self.environment.is_async:
-            self.writeline(
-                "from jinja2.asyncsupport import auto_await, "
-                "auto_aiter, AsyncLoopContext"
-            )
+            exported_names = sorted(exported + async_exported)
+        else:
+            exported_names = sorted(exported)
+
+        self.writeline("from __future__ import generator_stop")  # Python < 3.7
+        self.writeline("from jinja2.runtime import " + ", ".join(exported_names))
 
         # if we want a deferred initialization we cannot move the
         # environment into a local name
index 2a64a0ab19389e77cab1110915c89eea8e3da73b..ae687385abdb6786dbde6302247be0212247d036 100644 (file)
@@ -45,7 +45,6 @@ from .runtime import Undefined
 from .utils import _PassArg
 from .utils import concat
 from .utils import consume
-from .utils import have_async_gen
 from .utils import import_string
 from .utils import internalcode
 from .utils import LRUCache
@@ -342,12 +341,7 @@ class Environment:
         # load extensions
         self.extensions = load_extensions(self, extensions)
 
-        self.enable_async = enable_async
-        self.is_async = self.enable_async and have_async_gen
-        if self.is_async:
-            # runs patch_all() to enable async support
-            from . import asyncsupport  # noqa: F401
-
+        self.is_async = enable_async
         _environment_sanity_check(self)
 
     def add_extension(self, extension):
@@ -1119,13 +1113,20 @@ class Template:
 
         This will return the rendered template as a string.
         """
-        vars = dict(*args, **kwargs)
+        if self.environment.is_async:
+            import asyncio
+
+            loop = asyncio.get_event_loop()
+            return loop.run_until_complete(self.render_async(*args, **kwargs))
+
+        ctx = self.new_context(dict(*args, **kwargs))
+
         try:
-            return concat(self.root_render_func(self.new_context(vars)))
+            return concat(self.root_render_func(ctx))
         except Exception:
             self.environment.handle_exception()
 
-    def render_async(self, *args, **kwargs):
+    async def render_async(self, *args, **kwargs):
         """This works similar to :meth:`render` but returns a coroutine
         that when awaited returns the entire rendered template string.  This
         requires the async feature to be enabled.
@@ -1134,10 +1135,17 @@ class Template:
 
             await template.render_async(knights='that say nih; asynchronously')
         """
-        # see asyncsupport for the actual implementation
-        raise NotImplementedError(
-            "This feature is not available for this version of Python"
-        )
+        if not self.environment.is_async:
+            raise RuntimeError(
+                "The environment was not created with async mode enabled."
+            )
+
+        ctx = self.new_context(dict(*args, **kwargs))
+
+        try:
+            return concat([n async for n in self.root_render_func(ctx)])
+        except Exception:
+            return self.environment.handle_exception()
 
     def stream(self, *args, **kwargs):
         """Works exactly like :meth:`generate` but returns a
@@ -1153,20 +1161,41 @@ class Template:
 
         It accepts the same arguments as :meth:`render`.
         """
-        vars = dict(*args, **kwargs)
+        if self.environment.is_async:
+            import asyncio
+
+            loop = asyncio.get_event_loop()
+            async_gen = self.generate_async(*args, **kwargs)
+
+            try:
+                while True:
+                    yield loop.run_until_complete(async_gen.__anext__())
+            except StopAsyncIteration:
+                return
+
+        ctx = self.new_context(dict(*args, **kwargs))
+
         try:
-            yield from self.root_render_func(self.new_context(vars))
+            yield from self.root_render_func(ctx)
         except Exception:
             yield self.environment.handle_exception()
 
-    def generate_async(self, *args, **kwargs):
+    async def generate_async(self, *args, **kwargs):
         """An async version of :meth:`generate`.  Works very similarly but
         returns an async iterator instead.
         """
-        # see asyncsupport for the actual implementation
-        raise NotImplementedError(
-            "This feature is not available for this version of Python"
-        )
+        if not self.environment.is_async:
+            raise RuntimeError(
+                "The environment was not created with async mode enabled."
+            )
+
+        ctx = self.new_context(dict(*args, **kwargs))
+
+        try:
+            async for event in self.root_render_func(ctx):
+                yield event
+        except Exception:
+            yield self.environment.handle_exception()
 
     def new_context(self, vars=None, shared=False, locals=None):
         """Create a new :class:`Context` for this template.  The vars
@@ -1187,42 +1216,56 @@ class Template:
         a dict which is then used as context.  The arguments are the same
         as for the :meth:`new_context` method.
         """
-        return TemplateModule(self, self.new_context(vars, shared, locals))
+        ctx = self.new_context(vars, shared, locals)
+        return TemplateModule(self, ctx)
 
-    def make_module_async(self, vars=None, shared=False, locals=None):
+    async def make_module_async(self, vars=None, shared=False, locals=None):
         """As template module creation can invoke template code for
         asynchronous executions this method must be used instead of the
         normal :meth:`make_module` one.  Likewise the module attribute
         becomes unavailable in async mode.
         """
-        # see asyncsupport for the actual implementation
-        raise NotImplementedError(
-            "This feature is not available for this version of Python"
-        )
+        ctx = self.new_context(vars, shared, locals)
+        return TemplateModule(self, ctx, [x async for x in self.root_render_func(ctx)])
 
     @internalcode
     def _get_default_module(self, ctx=None):
         """If a context is passed in, this means that the template was
-        imported.  Imported templates have access to the current template's
-        globals by default, but they can only be accessed via the context
-        during runtime.
-
-        If there are new globals, we need to create a new
-        module because the cached module is already rendered and will not have
-        access to globals from the current context.  This new module is not
-        cached as :attr:`_module` because the template can be imported elsewhere,
-        and it should have access to only the current template's globals.
+        imported. Imported templates have access to the current
+        template's globals by default, but they can only be accessed via
+        the context during runtime.
+
+        If there are new globals, we need to create a new module because
+        the cached module is already rendered and will not have access
+        to globals from the current context. This new module is not
+        cached because the template can be imported elsewhere, and it
+        should have access to only the current template's globals.
         """
+        if self.environment.is_async:
+            raise RuntimeError("Module is not available in async mode.")
+
         if ctx is not None:
-            globals = {
-                key: ctx.parent[key] for key in ctx.globals_keys - self.globals.keys()
-            }
-            if globals:
-                return self.make_module(globals)
-        if self._module is not None:
-            return self._module
-        self._module = rv = self.make_module()
-        return rv
+            keys = ctx.globals_keys - self.globals.keys()
+
+            if keys:
+                return self.make_module({k: ctx.parent[k] for k in keys})
+
+        if self._module is None:
+            self._module = self.make_module()
+
+        return self._module
+
+    async def _get_default_module_async(self, ctx=None):
+        if ctx is not None:
+            keys = ctx.globals_keys - self.globals.keys()
+
+            if keys:
+                return await self.make_module_async({k: ctx.parent[k] for k in keys})
+
+        if self._module is None:
+            self._module = await self.make_module_async()
+
+        return self._module
 
     @property
     def module(self):
index 82f2ff2137be57700ae16df83e888eda09eb893c..8aa11c2bd30c719a8110083282b273ab3b408949 100644 (file)
@@ -13,6 +13,10 @@ from markupsafe import escape
 from markupsafe import Markup
 from markupsafe import soft_str
 
+from .async_utils import async_variant
+from .async_utils import auto_aiter
+from .async_utils import auto_await
+from .async_utils import auto_to_list
 from .exceptions import FilterArgumentError
 from .runtime import Undefined
 from .utils import htmlsafe_json_dumps
@@ -550,7 +554,7 @@ def do_default(
 
 
 @pass_eval_context
-def do_join(
+def sync_do_join(
     eval_ctx: "EvalContext",
     value: t.Iterable,
     d: str = "",
@@ -607,13 +611,23 @@ def do_join(
     return soft_str(d).join(map(soft_str, value))
 
 
+@async_variant(sync_do_join)
+async def do_join(
+    eval_ctx: "EvalContext",
+    value: t.Union[t.AsyncIterable, t.Iterable],
+    d: str = "",
+    attribute: t.Optional[t.Union[str, int]] = None,
+) -> str:
+    return sync_do_join(eval_ctx, await auto_to_list(value), d, attribute)
+
+
 def do_center(value: str, width: int = 80) -> str:
     """Centers the value in a field of a given width."""
     return soft_str(value).center(width)
 
 
 @pass_environment
-def do_first(
+def sync_do_first(
     environment: "Environment", seq: "t.Iterable[V]"
 ) -> "t.Union[V, Undefined]":
     """Return the first item of a sequence."""
@@ -623,6 +637,16 @@ def do_first(
         return environment.undefined("No first item, sequence was empty.")
 
 
+@async_variant(sync_do_first)
+async def do_first(
+    environment: "Environment", seq: "t.Union[t.AsyncIterable[V], t.Iterable[V]]"
+) -> "t.Union[V, Undefined]":
+    try:
+        return t.cast("V", await auto_aiter(seq).__anext__())
+    except StopAsyncIteration:
+        return environment.undefined("No first item, sequence was empty.")
+
+
 @pass_environment
 def do_last(
     environment: "Environment", seq: "t.Reversible[V]"
@@ -642,6 +666,9 @@ def do_last(
         return environment.undefined("No last item, sequence was empty.")
 
 
+# No async do_last, it may not be safe in async mode.
+
+
 @pass_context
 def do_random(context: "Context", seq: "t.Sequence[V]") -> "t.Union[V, Undefined]":
     """Return a random item from the sequence."""
@@ -1006,7 +1033,7 @@ def do_striptags(value: "t.Union[str, HasHTML]") -> str:
     return Markup(str(value)).striptags()
 
 
-def do_slice(
+def sync_do_slice(
     value: "t.Collection[V]", slices: int, fill_with: "t.Optional[V]" = None
 ) -> "t.Iterator[t.List[V]]":
     """Slice an iterator and return a list of lists containing
@@ -1049,6 +1076,15 @@ def do_slice(
         yield tmp
 
 
+@async_variant(sync_do_slice)
+async def do_slice(
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    slices: int,
+    fill_with: t.Optional[t.Any] = None,
+) -> "t.Iterator[t.List[V]]":
+    return sync_do_slice(await auto_to_list(value), slices, fill_with)
+
+
 def do_batch(
     value: "t.Iterable[V]", linecount: int, fill_with: "t.Optional[V]" = None
 ) -> "t.Iterator[t.List[V]]":
@@ -1140,7 +1176,7 @@ class _GroupTuple(t.NamedTuple):
 
 
 @pass_environment
-def do_groupby(
+def sync_do_groupby(
     environment: "Environment",
     value: "t.Iterable[V]",
     attribute: t.Union[str, int],
@@ -1198,8 +1234,22 @@ def do_groupby(
     ]
 
 
+@async_variant(sync_do_groupby)
+async def do_groupby(
+    environment: "Environment",
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    attribute: t.Union[str, int],
+    default: t.Optional[t.Any] = None,
+) -> "t.List[t.Tuple[t.Any, t.List[V]]]":
+    expr = make_attrgetter(environment, attribute, default=default)
+    return [
+        _GroupTuple(key, await auto_to_list(values))
+        for key, values in groupby(sorted(await auto_to_list(value), key=expr), expr)
+    ]
+
+
 @pass_environment
-def do_sum(
+def sync_do_sum(
     environment: "Environment",
     iterable: "t.Iterable[V]",
     attribute: t.Optional[t.Union[str, int]] = None,
@@ -1225,13 +1275,40 @@ def do_sum(
     return sum(iterable, start)
 
 
-def do_list(value: "t.Iterable[V]") -> "t.List[V]":
+@async_variant(sync_do_sum)
+async def do_sum(
+    environment: "Environment",
+    iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    attribute: t.Optional[t.Union[str, int]] = None,
+    start: "V" = 0,  # type: ignore
+) -> "V":
+    rv = start
+
+    if attribute is not None:
+        func = make_attrgetter(environment, attribute)
+    else:
+
+        def func(x):
+            return x
+
+    async for item in auto_aiter(iterable):
+        rv += func(item)
+
+    return rv
+
+
+def sync_do_list(value: "t.Iterable[V]") -> "t.List[V]":
     """Convert the value into a list.  If it was a string the returned list
     will be a list of characters.
     """
     return list(value)
 
 
+@async_variant(sync_do_list)
+async def do_list(value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]") -> "t.List[V]":
+    return await auto_to_list(value)
+
+
 def do_mark_safe(value: str) -> Markup:
     """Mark the value as safe which means that in an environment with automatic
     escaping enabled this variable will not be escaped.
@@ -1304,14 +1381,14 @@ def do_attr(
 
 
 @typing.overload
-def do_map(
+def sync_do_map(
     context: "Context", value: t.Iterable, name: str, *args: t.Any, **kwargs: t.Any
 ) -> t.Iterable:
     ...
 
 
 @typing.overload
-def do_map(
+def sync_do_map(
     context: "Context",
     value: t.Iterable,
     *,
@@ -1322,7 +1399,7 @@ def do_map(
 
 
 @pass_context
-def do_map(context, value, *args, **kwargs):
+def sync_do_map(context, value, *args, **kwargs):
     """Applies a filter on a sequence of objects or looks up an attribute.
     This is useful when dealing with lists of objects but you are really
     only interested in a certain value of it.
@@ -1369,8 +1446,39 @@ def do_map(context, value, *args, **kwargs):
             yield func(item)
 
 
+@typing.overload
+def do_map(
+    context: "Context",
+    value: t.Union[t.AsyncIterable, t.Iterable],
+    name: str,
+    *args: t.Any,
+    **kwargs: t.Any,
+) -> t.Iterable:
+    ...
+
+
+@typing.overload
+def do_map(
+    context: "Context",
+    value: t.Union[t.AsyncIterable, t.Iterable],
+    *,
+    attribute: str = ...,
+    default: t.Optional[t.Any] = None,
+) -> t.Iterable:
+    ...
+
+
+@async_variant(sync_do_map)
+async def do_map(context, value, *args, **kwargs):
+    if value:
+        func = prepare_map(context, args, kwargs)
+
+        async for item in auto_aiter(value):
+            yield await auto_await(func(item))
+
+
 @pass_context
-def do_select(
+def sync_do_select(
     context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any
 ) -> "t.Iterator[V]":
     """Filters a sequence of objects by applying a test to each object,
@@ -1400,8 +1508,18 @@ def do_select(
     return select_or_reject(context, value, args, kwargs, lambda x: x, False)
 
 
+@async_variant(sync_do_select)
+async def do_select(
+    context: "Context",
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    *args: t.Any,
+    **kwargs: t.Any,
+) -> "t.AsyncIterator[V]":
+    return async_select_or_reject(context, value, args, kwargs, lambda x: x, False)
+
+
 @pass_context
-def do_reject(
+def sync_do_reject(
     context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any
 ) -> "t.Iterator[V]":
     """Filters a sequence of objects by applying a test to each object,
@@ -1426,8 +1544,18 @@ def do_reject(
     return select_or_reject(context, value, args, kwargs, lambda x: not x, False)
 
 
+@async_variant(sync_do_reject)
+async def do_reject(
+    context: "Context",
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    *args: t.Any,
+    **kwargs: t.Any,
+) -> "t.AsyncIterator[V]":
+    return async_select_or_reject(context, value, args, kwargs, lambda x: not x, False)
+
+
 @pass_context
-def do_selectattr(
+def sync_do_selectattr(
     context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any
 ) -> "t.Iterator[V]":
     """Filters a sequence of objects by applying a test to the specified
@@ -1456,8 +1584,18 @@ def do_selectattr(
     return select_or_reject(context, value, args, kwargs, lambda x: x, True)
 
 
+@async_variant(sync_do_selectattr)
+async def do_selectattr(
+    context: "Context",
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    *args: t.Any,
+    **kwargs: t.Any,
+) -> "t.AsyncIterator[V]":
+    return async_select_or_reject(context, value, args, kwargs, lambda x: x, True)
+
+
 @pass_context
-def do_rejectattr(
+def sync_do_rejectattr(
     context: "Context", value: "t.Iterable[V]", *args: t.Any, **kwargs: t.Any
 ) -> "t.Iterator[V]":
     """Filters a sequence of objects by applying a test to the specified
@@ -1484,6 +1622,16 @@ def do_rejectattr(
     return select_or_reject(context, value, args, kwargs, lambda x: not x, True)
 
 
+@async_variant(sync_do_rejectattr)
+async def do_rejectattr(
+    context: "Context",
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    *args: t.Any,
+    **kwargs: t.Any,
+) -> "t.AsyncIterator[V]":
+    return async_select_or_reject(context, value, args, kwargs, lambda x: not x, True)
+
+
 @pass_eval_context
 def do_tojson(
     eval_ctx: "EvalContext", value: t.Any, indent: t.Optional[int] = None
@@ -1591,6 +1739,22 @@ def select_or_reject(
                 yield item
 
 
+async def async_select_or_reject(
+    context: "Context",
+    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
+    args: t.Tuple,
+    kwargs: t.Dict[str, t.Any],
+    modfunc: t.Callable[[t.Any], t.Any],
+    lookup_attr: bool,
+) -> "t.AsyncIterator[V]":
+    if value:
+        func = prepare_select_or_reject(context, args, kwargs, modfunc, lookup_attr)
+
+        async for item in auto_aiter(value):
+            if func(item):
+                yield item
+
+
 FILTERS = {
     "abs": abs,
     "attr": do_attr,
index 8867a3165fbdaa2a321c88f77f1c5827e9626f21..6cca518c352de87add903d195c3010b8fb82b47e 100644 (file)
@@ -86,10 +86,10 @@ class NativeTemplate(Template):
         with :func:`ast.literal_eval`, the parsed value is returned.
         Otherwise, the string is returned.
         """
-        vars = dict(*args, **kwargs)
+        ctx = self.new_context(dict(*args, **kwargs))
 
         try:
-            return native_concat(self.root_render_func(self.new_context(vars)))
+            return native_concat(self.root_render_func(ctx))
         except Exception:
             return self.environment.handle_exception()
 
@@ -99,8 +99,7 @@ class NativeTemplate(Template):
                 "The environment was not created with async mode enabled."
             )
 
-        vars = dict(*args, **kwargs)
-        ctx = self.new_context(vars)
+        ctx = self.new_context(dict(*args, **kwargs))
 
         try:
             return native_concat([n async for n in self.root_render_func(ctx)])
index 3d55819408a5c06029ec46971353b60dea3ef33f..0ce493047cbee6ab2dd4f0472bfcf4e43fb0b484 100644 (file)
@@ -9,6 +9,8 @@ from markupsafe import escape  # noqa: F401
 from markupsafe import Markup
 from markupsafe import soft_str
 
+from .async_utils import auto_aiter
+from .async_utils import auto_await  # noqa: F401
 from .exceptions import TemplateNotFound  # noqa: F401
 from .exceptions import TemplateRuntimeError  # noqa: F401
 from .exceptions import UndefinedError
@@ -42,6 +44,11 @@ exported = [
     "Undefined",
     "internalcode",
 ]
+async_exported = [
+    "AsyncLoopContext",
+    "auto_aiter",
+    "auto_await",
+]
 
 
 def identity(x):
@@ -368,11 +375,25 @@ class BlockReference:
             )
         return BlockReference(self.name, self._context, self._stack, self._depth + 1)
 
+    @internalcode
+    async def _async_call(self):
+        rv = concat([x async for x in self._stack[self._depth](self._context)])
+
+        if self._context.eval_ctx.autoescape:
+            return Markup(rv)
+
+        return rv
+
     @internalcode
     def __call__(self):
+        if self._context.environment.is_async:
+            return self._async_call()
+
         rv = concat(self._stack[self._depth](self._context))
+
         if self._context.eval_ctx.autoescape:
-            rv = Markup(rv)
+            return Markup(rv)
+
         return rv
 
 
@@ -567,6 +588,73 @@ class LoopContext:
         return f"<{self.__class__.__name__} {self.index}/{self.length}>"
 
 
+class AsyncLoopContext(LoopContext):
+    @staticmethod
+    def _to_iterator(iterable):
+        return auto_aiter(iterable)
+
+    @property
+    async def length(self):
+        if self._length is not None:
+            return self._length
+
+        try:
+            self._length = len(self._iterable)
+        except TypeError:
+            iterable = [x async for x in self._iterator]
+            self._iterator = self._to_iterator(iterable)
+            self._length = len(iterable) + self.index + (self._after is not missing)
+
+        return self._length
+
+    @property
+    async def revindex0(self):
+        return await self.length - self.index
+
+    @property
+    async def revindex(self):
+        return await self.length - self.index0
+
+    async def _peek_next(self):
+        if self._after is not missing:
+            return self._after
+
+        try:
+            self._after = await self._iterator.__anext__()
+        except StopAsyncIteration:
+            self._after = missing
+
+        return self._after
+
+    @property
+    async def last(self):
+        return await self._peek_next() is missing
+
+    @property
+    async def nextitem(self):
+        rv = await self._peek_next()
+
+        if rv is missing:
+            return self._undefined("there is no next item")
+
+        return rv
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self._after is not missing:
+            rv = self._after
+            self._after = missing
+        else:
+            rv = await self._iterator.__anext__()
+
+        self.index0 += 1
+        self._before = self._current
+        self._current = rv
+        return rv, self
+
+
 class Macro:
     """Wraps a macro function."""
 
@@ -672,11 +760,23 @@ class Macro:
 
         return self._invoke(arguments, autoescape)
 
+    async def _async_invoke(self, arguments, autoescape):
+        rv = await self._func(*arguments)
+
+        if autoescape:
+            return Markup(rv)
+
+        return rv
+
     def _invoke(self, arguments, autoescape):
-        """This method is being swapped out by the async implementation."""
+        if self._environment.is_async:
+            return self._async_invoke(arguments, autoescape)
+
         rv = self._func(*arguments)
+
         if autoescape:
             rv = Markup(rv)
+
         return rv
 
     def __repr__(self):
index 80769a7329b9c2c3128e728adfb333b7506a5359..c49dbb53fc4a4e79c2714076eedc310fb926e448 100644 (file)
@@ -20,13 +20,10 @@ if t.TYPE_CHECKING:
 # special singleton representing missing values for the runtime
 missing = type("MissingType", (), {"__repr__": lambda x: "missing"})()
 
-# internal code
 internal_code: t.MutableSet[CodeType] = set()
 
 concat = "".join
 
-_slash_escape = "\\/" not in json.dumps("/")
-
 
 def pass_context(f: "F") -> "F":
     """Pass the :class:`~jinja2.runtime.Context` as the first argument
@@ -832,14 +829,6 @@ class Namespace:
         return f"<Namespace {self.__attrs!r}>"
 
 
-# does this python version support async for in and async generators?
-try:
-    exec("async def _():\n async for _ in ():\n  yield _")
-    have_async_gen = True
-except SyntaxError:
-    have_async_gen = False
-
-
 class Markup(markupsafe.Markup):
     def __init__(self, *args, **kwargs):
         warnings.warn(
index ce30d8b2938260d1f2a1d53ecb0219fefcd54530..ddcacc23c5d65e3cae99b9dfbc6f5e95a30ec118 100644 (file)
@@ -2,15 +2,8 @@ import os
 
 import pytest
 
-from jinja2 import Environment
 from jinja2 import loaders
-from jinja2.utils import have_async_gen
-
-
-def pytest_ignore_collect(path):
-    if "async" in path.basename and not have_async_gen:
-        return True
-    return False
+from jinja2.environment import Environment
 
 
 @pytest.fixture
index cd243fd8286fe88c1629acbb1f9475b64af7e6ad..f8be8dfffab41aeea496043a428a702d05d92f62 100644 (file)
@@ -6,7 +6,7 @@ from jinja2 import ChainableUndefined
 from jinja2 import DictLoader
 from jinja2 import Environment
 from jinja2 import Template
-from jinja2.asyncsupport import auto_aiter
+from jinja2.async_utils import auto_aiter
 from jinja2.exceptions import TemplateNotFound
 from jinja2.exceptions import TemplatesNotFound
 from jinja2.exceptions import UndefinedError
similarity index 99%
rename from tests/test_asyncfilters.py
rename to tests/test_async_filters.py
index f5fcbf29fd04cfd0767e4ef7910aea1bf8730908..5d4f332e5de17b82ce3fc990b2ef268ab702ae37 100644 (file)
@@ -4,7 +4,7 @@ import pytest
 from markupsafe import Markup
 
 from jinja2 import Environment
-from jinja2.asyncsupport import auto_aiter
+from jinja2.async_utils import auto_aiter
 
 
 async def make_aiter(iter):