]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
First pass at template autoescaping
authorBen Darnell <ben@bendarnell.com>
Mon, 30 May 2011 21:58:00 +0000 (14:58 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 30 May 2011 21:58:00 +0000 (14:58 -0700)
tornado/template.py
tornado/test/template_test.py

index 06c10ce30b3bb76203514440203f3eb897565d09..ee4cc9a621e2326d0c9f63f52e4c7affe81c0277 100644 (file)
@@ -90,6 +90,9 @@ import re
 from tornado import escape
 from tornado.util import bytes_type
 
+_DEFAULT_AUTOESCAPE = None
+_UNSET = object()
+
 class Template(object):
     """A compiled template.
 
@@ -97,13 +100,19 @@ class Template(object):
     the template from variables with generate().
     """
     def __init__(self, template_string, name="<string>", loader=None,
-                 compress_whitespace=None):
+                 compress_whitespace=None, autoescape=_UNSET):
         self.name = name
         if compress_whitespace is None:
             compress_whitespace = name.endswith(".html") or \
                 name.endswith(".js")
+        if autoescape is not _UNSET:
+            self.autoescape = autoescape
+        elif loader:
+            self.autoescape = loader.autoescape
+        else:
+            self.autoescape = _DEFAULT_AUTOESCAPE
         reader = _TemplateReader(name, escape.native_str(template_string))
-        self.file = _File(_parse(reader))
+        self.file = _File(_parse(reader, self))
         self.code = self._generate_python(loader, compress_whitespace)
         try:
             self.compiled = compile(self.code, self.name, "exec")
@@ -138,6 +147,7 @@ class Template(object):
     def _generate_python(self, loader, compress_whitespace):
         buffer = cStringIO.StringIO()
         try:
+            # named_blocks maps from names to _NamedBlock objects
             named_blocks = {}
             ancestors = self._get_ancestors(loader)
             ancestors.reverse()
@@ -164,8 +174,17 @@ class Template(object):
 
 
 class BaseLoader(object):
-    def __init__(self, root_directory):
+    def __init__(self, root_directory, autoescape=_DEFAULT_AUTOESCAPE):
+        """Creates a template loader.
+
+        root_directory may be the empty string if this loader does not
+        use the filesystem.
+
+        autoescape must be either None or a string naming a function
+        in the template namespace, such as "xhtml_escape".
+        """
         self.root = os.path.abspath(root_directory)
+        self.autoescape = autoescape
         self.templates = {}
 
     def reset(self):
@@ -198,6 +217,9 @@ class Loader(BaseLoader):
     {% extends %} and {% include %}. Loader caches all templates after
     they are loaded the first time.
     """
+    def __init__(self, root_directory, **kwargs):
+        super(Loader, self).__init__(root_directory, **kwargs)
+
     def _create_template(self, name):
         path = os.path.join(self.root, name)
         f = open(path, "r")
@@ -208,8 +230,8 @@ class Loader(BaseLoader):
 
 class DictLoader(BaseLoader):
     """A template loader that loads from a dictionary."""
-    def __init__(self, dict):
-        super(DictLoader, self).__init__("")
+    def __init__(self, dict, **kwargs):
+        super(DictLoader, self).__init__("", **kwargs)
         self.dict = dict
 
     def _create_template(self, name):
@@ -257,18 +279,23 @@ class _ChunkList(_Node):
 
 
 class _NamedBlock(_Node):
-    def __init__(self, name, body=None):
+    def __init__(self, name, body, template):
         self.name = name
         self.body = body
+        self.template = template
 
     def each_child(self):
         return (self.body,)
 
     def generate(self, writer):
-        writer.named_blocks[self.name].generate(writer)
+        block = writer.named_blocks[self.name]
+        old = writer.current_template
+        writer.current_template = block.template
+        block.body.generate(writer)
+        writer.current_template = old
 
     def find_named_blocks(self, loader, named_blocks):
-        named_blocks[self.name] = self.body
+        named_blocks[self.name] = self
         _Node.find_named_blocks(self, loader, named_blocks)
 
 
@@ -351,8 +378,14 @@ class _Expression(_Node):
     def generate(self, writer):
         writer.write_line("_tmp = %s" % self.expression)
         writer.write_line("if isinstance(_tmp, _string_types):"
-                          " _buffer.append(_utf8(_tmp))")
-        writer.write_line("else: _buffer.append(_utf8(str(_tmp)))")
+                          " _tmp = _utf8(_tmp)")
+        writer.write_line("else: _tmp = _utf8(str(_tmp))")
+        if writer.current_template.autoescape is not None:
+            # In python3 functions like xhtml_escape return unicode,
+            # so we have to convert to utf8 again.
+            writer.write_line("_tmp = _utf8(%s(_tmp))" %
+                              writer.current_template.autoescape)
+        writer.write_line("_buffer.append(_tmp)")
 
 
 class _Text(_Node):
@@ -470,7 +503,7 @@ def _format_code(code):
     return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
 
 
-def _parse(reader, in_block=None):
+def _parse(reader, template, in_block=None):
     body = _ChunkList([])
     while True:
         # Find next template directive
@@ -554,7 +587,7 @@ def _parse(reader, in_block=None):
             return body
 
         elif operator in ("extends", "include", "set", "import", "from",
-                          "comment"):
+                          "comment", "autoescape"):
             if operator == "comment":
                 continue
             if operator == "extends":
@@ -575,12 +608,17 @@ def _parse(reader, in_block=None):
                 if not suffix:
                     raise ParseError("set missing statement on line %d" % line)
                 block = _Statement(suffix)
+            elif operator == "autoescape":
+                fn = suffix.strip()
+                if fn == "None": fn = None
+                template.autoescape = fn
+                continue
             body.chunks.append(block)
             continue
 
         elif operator in ("apply", "block", "try", "if", "for", "while"):
             # parse inner body recursively
-            block_body = _parse(reader, operator)
+            block_body = _parse(reader, template, operator)
             if operator == "apply":
                 if not suffix:
                     raise ParseError("apply missing method name on line %d" % line)
@@ -588,7 +626,7 @@ def _parse(reader, in_block=None):
             elif operator == "block":
                 if not suffix:
                     raise ParseError("block missing name on line %d" % line)
-                block = _NamedBlock(suffix, block_body)
+                block = _NamedBlock(suffix, block_body, template)
             else:
                 block = _ControlBlock(contents, block_body)
             body.chunks.append(block)
index 893711acfe9974786110f33519d29abb39f236b8..a7cef48e954c934b5bd62bf0247af214297a3d52 100644 (file)
@@ -49,3 +49,92 @@ class TemplateTest(LogTrapTestCase):
                 })
         self.assertEqual(loader.load("a/1.html").generate(),
                          b("ok"))
+
+class AutoEscapeTest(LogTrapTestCase):
+    def setUp(self):
+        self.templates = {
+            "escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
+            "unescaped.html": "{% autoescape None %}{{ name }}",
+            "default.html": "{{ name }}",
+
+            "include.html": """\
+escaped: {% include 'escaped.html' %}
+unescaped: {% include 'unescaped.html' %}
+default: {% include 'default.html' %}
+""",
+
+            "escaped_block.html": """\
+{% autoescape xhtml_escape %}\
+{% block name %}base: {{ name }}{% end %}""",
+            "unescaped_block.html": """\
+{% autoescape None %}\
+{% block name %}base: {{ name }}{% end %}""",
+
+            # Extend a base template with different autoescape policy,
+            # with and without overriding the base's blocks
+            "escaped_extends_unescaped.html": """\
+{% autoescape xhtml_escape %}\
+{% extends "unescaped_block.html" %}""",
+            "escaped_overrides_unescaped.html": """\
+{% autoescape xhtml_escape %}\
+{% extends "unescaped_block.html" %}\
+{% block name %}extended: {{ name }}{% end %}""",
+            "unescaped_extends_escaped.html": """\
+{% autoescape None %}\
+{% extends "escaped_block.html" %}""",
+            "unescaped_overrides_escaped.html": """\
+{% autoescape None %}\
+{% extends "escaped_block.html" %}\
+{% block name %}extended: {{ name }}{% end %}""",
+            }
+    
+    def test_default_off(self):
+        loader = DictLoader(self.templates, autoescape=None)
+        name = "Bobby <table>s"
+        self.assertEqual(loader.load("escaped.html").generate(name=name),
+                         b("Bobby &lt;table&gt;s"))
+        self.assertEqual(loader.load("unescaped.html").generate(name=name),
+                         b("Bobby <table>s"))
+        self.assertEqual(loader.load("default.html").generate(name=name),
+                         b("Bobby <table>s"))
+
+        self.assertEqual(loader.load("include.html").generate(name=name),
+                         b("escaped: Bobby &lt;table&gt;s\n"
+                           "unescaped: Bobby <table>s\n"
+                           "default: Bobby <table>s\n"))
+        
+    def test_default_on(self):
+        loader = DictLoader(self.templates, autoescape="xhtml_escape")
+        name = "Bobby <table>s"
+        self.assertEqual(loader.load("escaped.html").generate(name=name),
+                         b("Bobby &lt;table&gt;s"))
+        self.assertEqual(loader.load("unescaped.html").generate(name=name),
+                         b("Bobby <table>s"))
+        self.assertEqual(loader.load("default.html").generate(name=name),
+                         b("Bobby &lt;table&gt;s"))
+        
+        self.assertEqual(loader.load("include.html").generate(name=name),
+                         b("escaped: Bobby &lt;table&gt;s\n"
+                           "unescaped: Bobby <table>s\n"
+                           "default: Bobby &lt;table&gt;s\n"))
+
+    def test_unextended_block(self):
+        loader = DictLoader(self.templates)
+        name = "<script>"
+        self.assertEqual(loader.load("escaped_block.html").generate(name=name),
+                         b("base: &lt;script&gt;"))
+        self.assertEqual(loader.load("unescaped_block.html").generate(name=name),
+                         b("base: <script>"))
+
+    def test_extended_block(self):
+        loader = DictLoader(self.templates)
+        def render(name): return loader.load(name).generate(name="<script>")
+        self.assertEqual(render("escaped_extends_unescaped.html"),
+                         b("base: <script>"))
+        self.assertEqual(render("escaped_overrides_unescaped.html"),
+                         b("extended: &lt;script&gt;"))
+
+        self.assertEqual(render("unescaped_extends_escaped.html"),
+                         b("base: &lt;script&gt;"))
+        self.assertEqual(render("unescaped_overrides_escaped.html"),
+                         b("extended: <script>"))