]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
template: Add type annotations
authorBen Darnell <ben@bendarnell.com>
Sat, 29 Sep 2018 03:40:13 +0000 (23:40 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 29 Sep 2018 03:40:13 +0000 (23:40 -0400)
setup.cfg
tornado/template.py

index 36b6401022e922b0e775ee8e7720b373e5544659..b34641c1d737558f60d8bdbe7653c9659ae61eb6 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,7 +7,7 @@ python_version = 3.5
 [mypy-tornado.*,tornado.platform.*]
 disallow_untyped_defs = True
 
-[mypy-tornado.auth,tornado.routing,tornado.template,tornado.web,tornado.websocket,tornado.wsgi]
+[mypy-tornado.auth,tornado.routing,tornado.web,tornado.websocket,tornado.wsgi]
 disallow_untyped_defs = False
 
 # It's generally too tedious to require type annotations in tests, but
index ac5dcf8885120b63730d7aeaf90a1d1f8051f4af..6e3c1552f3105e08447187c2ad38c6f4358f6ae4 100644 (file)
@@ -207,11 +207,22 @@ from tornado import escape
 from tornado.log import app_log
 from tornado.util import ObjectDict, exec_in, unicode_type
 
+from typing import Any, Union, Callable, List, Dict, Iterable, Optional, TextIO, ContextManager
+import typing
+if typing.TYPE_CHECKING:
+    from typing import Tuple  # noqa: F401
+
 _DEFAULT_AUTOESCAPE = "xhtml_escape"
-_UNSET = object()
 
 
-def filter_whitespace(mode, text):
+class _UnsetMarker:
+    pass
+
+
+_UNSET = _UnsetMarker()
+
+
+def filter_whitespace(mode: str, text: str) -> str:
     """Transform whitespace in ``text`` according to ``mode``.
 
     Available modes are:
@@ -245,9 +256,10 @@ class Template(object):
     # note that the constructor's signature is not extracted with
     # autodoc because _UNSET looks like garbage.  When changing
     # this signature update website/sphinx/template.rst too.
-    def __init__(self, template_string, name="<string>", loader=None,
-                 compress_whitespace=_UNSET, autoescape=_UNSET,
-                 whitespace=None):
+    def __init__(self, template_string: Union[str, bytes], name: str="<string>",
+                 loader: 'BaseLoader'=None, compress_whitespace: Union[bool, _UnsetMarker]=_UNSET,
+                 autoescape: Union[str, _UnsetMarker]=_UNSET,
+                 whitespace: str=None) -> None:
         """Construct a Template.
 
         :arg str template_string: the contents of the template file.
@@ -283,10 +295,11 @@ class Template(object):
                 else:
                     whitespace = "all"
         # Validate the whitespace setting.
+        assert whitespace is not None
         filter_whitespace(whitespace, '')
 
-        if autoescape is not _UNSET:
-            self.autoescape = autoescape
+        if not isinstance(autoescape, _UnsetMarker):
+            self.autoescape = autoescape  # type: Optional[str]
         elif loader:
             self.autoescape = loader.autoescape
         else:
@@ -312,7 +325,7 @@ class Template(object):
             app_log.error("%s code:\n%s", self.name, formatted_code)
             raise
 
-    def generate(self, **kwargs):
+    def generate(self, **kwargs: Any) -> bytes:
         """Generate this template with the given arguments."""
         namespace = {
             "escape": escape.xhtml_escape,
@@ -332,18 +345,18 @@ class Template(object):
         namespace.update(self.namespace)
         namespace.update(kwargs)
         exec_in(self.compiled, namespace)
-        execute = namespace["_tt_execute"]
+        execute = typing.cast(Callable[[], bytes], namespace["_tt_execute"])
         # Clear the traceback module's cache of source data now that
         # we've generated a new template (mainly for this module's
         # unittests, where different tests reuse the same name).
         linecache.clearcache()
         return execute()
 
-    def _generate_python(self, loader):
+    def _generate_python(self, loader: Optional['BaseLoader']) -> str:
         buffer = StringIO()
         try:
             # named_blocks maps from names to _NamedBlock objects
-            named_blocks = {}
+            named_blocks = {}  # type: Dict[str, _NamedBlock]
             ancestors = self._get_ancestors(loader)
             ancestors.reverse()
             for ancestor in ancestors:
@@ -355,7 +368,7 @@ class Template(object):
         finally:
             buffer.close()
 
-    def _get_ancestors(self, loader):
+    def _get_ancestors(self, loader: Optional['BaseLoader']) -> List['_File']:
         ancestors = [self.file]
         for chunk in self.file.body.chunks:
             if isinstance(chunk, _ExtendsBlock):
@@ -374,8 +387,9 @@ class BaseLoader(object):
     ``{% extends %}`` and ``{% include %}``. The loader caches all
     templates after they are loaded the first time.
     """
-    def __init__(self, autoescape=_DEFAULT_AUTOESCAPE, namespace=None,
-                 whitespace=None):
+    def __init__(self, autoescape: str=_DEFAULT_AUTOESCAPE,
+                 namespace: Dict[str, Any]=None,
+                 whitespace: str=None) -> None:
         """Construct a template loader.
 
         :arg str autoescape: The name of a function in the template
@@ -394,7 +408,7 @@ class BaseLoader(object):
         self.autoescape = autoescape
         self.namespace = namespace or {}
         self.whitespace = whitespace
-        self.templates = {}
+        self.templates = {}  # type: Dict[str, Template]
         # self.lock protects self.templates.  It's a reentrant lock
         # because templates may load other templates via `include` or
         # `extends`.  Note that thanks to the GIL this code would be safe
@@ -402,16 +416,16 @@ class BaseLoader(object):
         # threads tried to compile the same template simultaneously.
         self.lock = threading.RLock()
 
-    def reset(self):
+    def reset(self) -> None:
         """Resets the cache of compiled templates."""
         with self.lock:
             self.templates = {}
 
-    def resolve_path(self, name, parent_path=None):
+    def resolve_path(self, name: str, parent_path: str=None) -> str:
         """Converts a possibly-relative path to absolute (used internally)."""
         raise NotImplementedError()
 
-    def load(self, name, parent_path=None):
+    def load(self, name: str, parent_path: str=None) -> Template:
         """Loads a template."""
         name = self.resolve_path(name, parent_path=parent_path)
         with self.lock:
@@ -419,18 +433,18 @@ class BaseLoader(object):
                 self.templates[name] = self._create_template(name)
             return self.templates[name]
 
-    def _create_template(self, name):
+    def _create_template(self, name: str) -> Template:
         raise NotImplementedError()
 
 
 class Loader(BaseLoader):
     """A template loader that loads from a single root directory.
     """
-    def __init__(self, root_directory, **kwargs):
+    def __init__(self, root_directory: str, **kwargs: Any) -> None:
         super(Loader, self).__init__(**kwargs)
         self.root = os.path.abspath(root_directory)
 
-    def resolve_path(self, name, parent_path=None):
+    def resolve_path(self, name: str, parent_path: str=None) -> str:
         if parent_path and not parent_path.startswith("<") and \
             not parent_path.startswith("/") and \
                 not name.startswith("/"):
@@ -441,7 +455,7 @@ class Loader(BaseLoader):
                 name = relative_path[len(self.root) + 1:]
         return name
 
-    def _create_template(self, name):
+    def _create_template(self, name: str) -> Template:
         path = os.path.join(self.root, name)
         with open(path, "rb") as f:
             template = Template(f.read(), name=name, loader=self)
@@ -450,11 +464,11 @@ class Loader(BaseLoader):
 
 class DictLoader(BaseLoader):
     """A template loader that loads from a dictionary."""
-    def __init__(self, dict, **kwargs):
+    def __init__(self, dict: Dict[str, str], **kwargs: Any) -> None:
         super(DictLoader, self).__init__(**kwargs)
         self.dict = dict
 
-    def resolve_path(self, name, parent_path=None):
+    def resolve_path(self, name: str, parent_path: str=None) -> str:
         if parent_path and not parent_path.startswith("<") and \
             not parent_path.startswith("/") and \
                 not name.startswith("/"):
@@ -462,29 +476,30 @@ class DictLoader(BaseLoader):
             name = posixpath.normpath(posixpath.join(file_dir, name))
         return name
 
-    def _create_template(self, name):
+    def _create_template(self, name: str) -> Template:
         return Template(self.dict[name], name=name, loader=self)
 
 
 class _Node(object):
-    def each_child(self):
+    def each_child(self) -> Iterable['_Node']:
         return ()
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         raise NotImplementedError()
 
-    def find_named_blocks(self, loader, named_blocks):
+    def find_named_blocks(self, loader: Optional[BaseLoader],
+                          named_blocks: Dict[str, '_NamedBlock']) -> None:
         for child in self.each_child():
             child.find_named_blocks(loader, named_blocks)
 
 
 class _File(_Node):
-    def __init__(self, template, body):
+    def __init__(self, template: Template, body: '_ChunkList') -> None:
         self.template = template
         self.body = body
         self.line = 0
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         writer.write_line("def _tt_execute():", self.line)
         with writer.indent():
             writer.write_line("_tt_buffer = []", self.line)
@@ -492,73 +507,77 @@ class _File(_Node):
             self.body.generate(writer)
             writer.write_line("return _tt_utf8('').join(_tt_buffer)", self.line)
 
-    def each_child(self):
+    def each_child(self) -> Iterable['_Node']:
         return (self.body,)
 
 
 class _ChunkList(_Node):
-    def __init__(self, chunks):
+    def __init__(self, chunks: List[_Node]) -> None:
         self.chunks = chunks
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         for chunk in self.chunks:
             chunk.generate(writer)
 
-    def each_child(self):
+    def each_child(self) -> Iterable['_Node']:
         return self.chunks
 
 
 class _NamedBlock(_Node):
-    def __init__(self, name, body, template, line):
+    def __init__(self, name: str, body: _Node, template: Template, line: int) -> None:
         self.name = name
         self.body = body
         self.template = template
         self.line = line
 
-    def each_child(self):
+    def each_child(self) -> Iterable['_Node']:
         return (self.body,)
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         block = writer.named_blocks[self.name]
         with writer.include(block.template, self.line):
             block.body.generate(writer)
 
-    def find_named_blocks(self, loader, named_blocks):
+    def find_named_blocks(self, loader: Optional[BaseLoader],
+                          named_blocks: Dict[str, '_NamedBlock']) -> None:
         named_blocks[self.name] = self
         _Node.find_named_blocks(self, loader, named_blocks)
 
 
 class _ExtendsBlock(_Node):
-    def __init__(self, name):
+    def __init__(self, name: str) -> None:
         self.name = name
 
 
 class _IncludeBlock(_Node):
-    def __init__(self, name, reader, line):
+    def __init__(self, name: str, reader: '_TemplateReader', line: int) -> None:
         self.name = name
         self.template_name = reader.name
         self.line = line
 
-    def find_named_blocks(self, loader, named_blocks):
+    def find_named_blocks(self, loader: Optional[BaseLoader],
+                          named_blocks: Dict[str, _NamedBlock]) -> None:
+        assert loader is not None
         included = loader.load(self.name, self.template_name)
         included.file.find_named_blocks(loader, named_blocks)
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
+        assert writer.loader is not None
         included = writer.loader.load(self.name, self.template_name)
         with writer.include(included, self.line):
             included.file.body.generate(writer)
 
 
 class _ApplyBlock(_Node):
-    def __init__(self, method, line, body=None):
+    def __init__(self, method: str, line: int, body: _Node) -> None:
         self.method = method
         self.line = line
         self.body = body
 
-    def each_child(self):
+    def each_child(self) -> Iterable['_Node']:
         return (self.body,)
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         method_name = "_tt_apply%d" % writer.apply_counter
         writer.apply_counter += 1
         writer.write_line("def %s():" % method_name, self.line)
@@ -572,15 +591,15 @@ class _ApplyBlock(_Node):
 
 
 class _ControlBlock(_Node):
-    def __init__(self, statement, line, body=None):
+    def __init__(self, statement: str, line: int, body: _Node) -> None:
         self.statement = statement
         self.line = line
         self.body = body
 
-    def each_child(self):
+    def each_child(self) -> Iterable[_Node]:
         return (self.body,)
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         writer.write_line("%s:" % self.statement, self.line)
         with writer.indent():
             self.body.generate(writer)
@@ -589,32 +608,32 @@ class _ControlBlock(_Node):
 
 
 class _IntermediateControlBlock(_Node):
-    def __init__(self, statement, line):
+    def __init__(self, statement: str, line: int) -> None:
         self.statement = statement
         self.line = line
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         # In case the previous block was empty
         writer.write_line("pass", self.line)
         writer.write_line("%s:" % self.statement, self.line, writer.indent_size() - 1)
 
 
 class _Statement(_Node):
-    def __init__(self, statement, line):
+    def __init__(self, statement: str, line: int) -> None:
         self.statement = statement
         self.line = line
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         writer.write_line(self.statement, self.line)
 
 
 class _Expression(_Node):
-    def __init__(self, expression, line, raw=False):
+    def __init__(self, expression: str, line: int, raw: bool=False) -> None:
         self.expression = expression
         self.line = line
         self.raw = raw
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         writer.write_line("_tt_tmp = %s" % self.expression, self.line)
         writer.write_line("if isinstance(_tt_tmp, _tt_string_types):"
                           " _tt_tmp = _tt_utf8(_tt_tmp)", self.line)
@@ -628,18 +647,18 @@ class _Expression(_Node):
 
 
 class _Module(_Expression):
-    def __init__(self, expression, line):
+    def __init__(self, expression: str, line: int) -> None:
         super(_Module, self).__init__("_tt_modules." + expression, line,
                                       raw=True)
 
 
 class _Text(_Node):
-    def __init__(self, value, line, whitespace):
+    def __init__(self, value: str, line: int, whitespace: str) -> None:
         self.value = value
         self.line = line
         self.whitespace = whitespace
 
-    def generate(self, writer):
+    def generate(self, writer: '_CodeWriter') -> None:
         value = self.value
 
         # Compress whitespace if requested, with a crude heuristic to avoid
@@ -660,56 +679,57 @@ class ParseError(Exception):
     .. versionchanged:: 4.3
        Added ``filename`` and ``lineno`` attributes.
     """
-    def __init__(self, message, filename=None, lineno=0):
+    def __init__(self, message: str, filename: str=None, lineno: int=0) -> None:
         self.message = message
         # The names "filename" and "lineno" are chosen for consistency
         # with python SyntaxError.
         self.filename = filename
         self.lineno = lineno
 
-    def __str__(self):
+    def __str__(self) -> str:
         return '%s at %s:%d' % (self.message, self.filename, self.lineno)
 
 
 class _CodeWriter(object):
-    def __init__(self, file, named_blocks, loader, current_template):
+    def __init__(self, file: TextIO, named_blocks: Dict[str, _NamedBlock],
+                 loader: Optional[BaseLoader], current_template: Template) -> None:
         self.file = file
         self.named_blocks = named_blocks
         self.loader = loader
         self.current_template = current_template
         self.apply_counter = 0
-        self.include_stack = []
+        self.include_stack = []  # type: List[Tuple[Template, int]]
         self._indent = 0
 
-    def indent_size(self):
+    def indent_size(self) -> int:
         return self._indent
 
-    def indent(self):
+    def indent(self) -> ContextManager:
         class Indenter(object):
-            def __enter__(_):
+            def __enter__(_) -> '_CodeWriter':
                 self._indent += 1
                 return self
 
-            def __exit__(_, *args):
+            def __exit__(_, *args: Any) -> None:
                 assert self._indent > 0
                 self._indent -= 1
 
         return Indenter()
 
-    def include(self, template, line):
+    def include(self, template: Template, line: int) -> ContextManager:
         self.include_stack.append((self.current_template, line))
         self.current_template = template
 
         class IncludeTemplate(object):
-            def __enter__(_):
+            def __enter__(_) -> '_CodeWriter':
                 return self
 
-            def __exit__(_, *args):
+            def __exit__(_, *args: Any) -> None:
                 self.current_template = self.include_stack.pop()[0]
 
         return IncludeTemplate()
 
-    def write_line(self, line, line_number, indent=None):
+    def write_line(self, line: str, line_number: int, indent: int=None) -> None:
         if indent is None:
             indent = self._indent
         line_comment = '  # %s:%d' % (self.current_template.name, line_number)
@@ -721,14 +741,14 @@ class _CodeWriter(object):
 
 
 class _TemplateReader(object):
-    def __init__(self, name, text, whitespace):
+    def __init__(self, name: str, text: str, whitespace: str) -> None:
         self.name = name
         self.text = text
         self.whitespace = whitespace
         self.line = 1
         self.pos = 0
 
-    def find(self, needle, start=0, end=None):
+    def find(self, needle: str, start: int=0, end: int=None) -> int:
         assert start >= 0, start
         pos = self.pos
         start += pos
@@ -742,7 +762,7 @@ class _TemplateReader(object):
             index -= pos
         return index
 
-    def consume(self, count=None):
+    def consume(self, count: int=None) -> str:
         if count is None:
             count = len(self.text) - self.pos
         newpos = self.pos + count
@@ -751,14 +771,14 @@ class _TemplateReader(object):
         self.pos = newpos
         return s
 
-    def remaining(self):
+    def remaining(self) -> int:
         return len(self.text) - self.pos
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.remaining()
 
-    def __getitem__(self, key):
-        if type(key) is slice:
+    def __getitem__(self, key: Union[int, slice]) -> str:
+        if isinstance(key, slice):
             size = len(self)
             start, stop, step = key.indices(size)
             if start is None:
@@ -773,20 +793,21 @@ class _TemplateReader(object):
         else:
             return self.text[self.pos + key]
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.text[self.pos:]
 
-    def raise_parse_error(self, msg):
+    def raise_parse_error(self, msg: str) -> None:
         raise ParseError(msg, self.name, self.line)
 
 
-def _format_code(code):
+def _format_code(code: str) -> str:
     lines = code.splitlines()
     format = "%%%dd  %%s\n" % len(repr(len(lines) + 1))
     return "".join([format % (i + 1, line) for (i, line) in enumerate(lines)])
 
 
-def _parse(reader, template, in_block=None, in_loop=None):
+def _parse(reader: _TemplateReader, template: Template,
+           in_block: str=None, in_loop: str=None) -> _ChunkList:
     body = _ChunkList([])
     while True:
         # Find next template directive
@@ -902,7 +923,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
                 suffix = suffix.strip('"').strip("'")
                 if not suffix:
                     reader.raise_parse_error("extends missing file path")
-                block = _ExtendsBlock(suffix)
+                block = _ExtendsBlock(suffix)  # type: _Node
             elif operator in ("import", "from"):
                 if not suffix:
                     reader.raise_parse_error("import missing statement")
@@ -917,7 +938,7 @@ def _parse(reader, template, in_block=None, in_loop=None):
                     reader.raise_parse_error("set missing statement")
                 block = _Statement(suffix, line)
             elif operator == "autoescape":
-                fn = suffix.strip()
+                fn = suffix.strip()  # type: Optional[str]
                 if fn == "None":
                     fn = None
                 template.autoescape = fn