]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
First pass on implementing async default module
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 14:17:10 +0000 (15:17 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 28 Dec 2016 14:17:10 +0000 (15:17 +0100)
jinja2/asyncsupport.py
jinja2/environment.py

index 5c2f300dbdf2b8e684b3a45c7e77d375adee6f94..e42a7131085b5ab91583327f98138c7ed9c9bfd5 100644 (file)
@@ -3,6 +3,7 @@ import asyncio
 import inspect
 
 from jinja2.utils import concat, internalcode, concat, Markup
+from jinja2.environment import TemplateModule
 
 
 async def concat_async(async_gen):
@@ -55,10 +56,40 @@ def wrap_block_reference_call(original_call):
     return __call__
 
 
+@internalcode
+async def get_default_module_async(self):
+    if self._module is not None:
+        return self._module
+    self._module = rv = await self.make_module_async()
+    return rv
+
+
+def wrap_default_module(original_default_module):
+    @internalcode
+    def _get_default_module(self):
+        if self.environment._async:
+            raise RuntimeError('Template module attribute is unavailable '
+                               'in async mode')
+        return original_default_module(self)
+    return _get_default_module
+
+
+async def make_module_async(self, vars=None, shared=False, locals=None):
+    context = self.new_context(vars, shared, locals)
+    body_stream = []
+    async for item in template.root_render_func(context):
+        body_stream.append(item)
+    return TemplateModule(self, context, body_stream)
+
+
 def patch_template():
     from jinja2 import Template
     Template.render_async = render_async
     Template.render = wrap_render_func(Template.render)
+    Template._get_default_module = wrap_default_module(
+        Template._get_default_module)
+    Template._get_default_module_async = get_default_module_async
+    Template.make_module_async = make_module_async
 
 
 def patch_runtime():
index 6c9fc193eaf6650973f1d95ee3ac33386b343606..c7dc642286b85a4871cbe55d2aab1ed7d6de136e 100644 (file)
@@ -1035,6 +1035,7 @@ class Template(object):
         """
         return TemplateModule(self, self.new_context(vars, shared, locals))
 
+    @internalcode
     def _get_default_module(self):
         if self._module is not None:
             return self._module
@@ -1092,8 +1093,13 @@ class TemplateModule(object):
     converting it into an unicode- or bytestrings renders the contents.
     """
 
-    def __init__(self, template, context):
-        self._body_stream = list(template.root_render_func(context))
+    def __init__(self, template, context, body_stream=None):
+        if body_stream is None:
+            if context.environment._async:
+                raise RuntimeError('Async mode requires a body stream '
+                                   'to be passed in.')
+            body_stream = list(template.root_render_func(context))
+        self._body_stream = body_stream
         self.__dict__.update(context.get_exported())
         self.__name__ = template.name