]> git.ipfire.org Git - thirdparty/jinja.git/commitdiff
extract common code for import/from nodes
authorDavid Lord <davidism@gmail.com>
Fri, 9 Apr 2021 23:49:55 +0000 (16:49 -0700)
committerDavid Lord <davidism@gmail.com>
Sat, 10 Apr 2021 15:58:43 +0000 (08:58 -0700)
src/jinja2/compiler.py

index 38d4fd39070703cda77d83fc59880f7c3b7e260f..7a15d8074d9544eb4610fc5747c6745b2dfd5a2e 100644 (file)
@@ -977,15 +977,11 @@ class CodeGenerator(NodeVisitor):
         if node.ignore_missing:
             self.outdent()
 
-    def visit_Import(self, node, frame):
-        """Visit regular imports."""
-        self.writeline(f"{frame.symbols.ref(node.target)} = ", node)
-        if frame.toplevel:
-            self.write(f"context.vars[{node.target!r}] = ")
-
+    def _import_common(self, node, frame):
         self.write(f"{self.choose_async('await ')}environment.get_template(")
         self.visit(node.template, frame)
         self.write(f", {self.name!r}).")
+
         if node.with_context:
             f_name = f"make_module{self.choose_async('_async')}"
             self.write(
@@ -995,26 +991,23 @@ class CodeGenerator(NodeVisitor):
             self.write("_get_default_module_async()")
         else:
             self.write("_get_default_module(context)")
+
+    def visit_Import(self, node, frame):
+        """Visit regular imports."""
+        self.writeline(f"{frame.symbols.ref(node.target)} = ", node)
+        if frame.toplevel:
+            self.write(f"context.vars[{node.target!r}] = ")
+
+        self._import_common(node, frame)
+
         if frame.toplevel and not node.target.startswith("_"):
             self.writeline(f"context.exported_vars.discard({node.target!r})")
 
     def visit_FromImport(self, node, frame):
         """Visit named imports."""
         self.newline(node)
-        prefix = self.choose_async("await ")
-        self.write(f"included_template = {prefix}environment.get_template(")
-        self.visit(node.template, frame)
-        self.write(f", {self.name!r}).")
-        if node.with_context:
-            f_name = f"make_module{self.choose_async('_async')}"
-            self.write(
-                f"{f_name}(context.get_all(), True, {self.dump_local_context(frame)})"
-            )
-        elif self.environment.is_async:
-            self.write("_get_default_module_async()")
-        else:
-            self.write("_get_default_module(context)")
-
+        self.write("included_template = ")
+        self._import_common(node, frame)
         var_names = []
         discarded_names = []
         for name in node.names: