]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Begin to fix type confusion in HTTPHeaders. 1699/head
authorBen Darnell <ben@bendarnell.com>
Sat, 23 Apr 2016 23:34:23 +0000 (19:34 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 23 Apr 2016 23:34:23 +0000 (19:34 -0400)
Values in HTTPHeaders were typed inconsistently: the class itself
assumed that values were of type str, while many callers assumed they
were bytes. Since headers are practically restricted to ascii, this
would work on python 2 but fail when certain combinations were
encountered on python 3 (notably the combination of GzipContentEncoding
with multiple preexisting Vary headers).

This commit adds a test and fixes the bug, and also adds enough type
annotations that mypy was able to report errors in the old incorrect
code.

Fixes #1670

tornado/escape.py
tornado/http1connection.py
tornado/httputil.py
tornado/test/web_test.py
tornado/util.py
tornado/web.py

index e6636f203831aca90b790fb3504826d5e47d7b0d..7a3b0e03495b927bc667b44f2df3c393fa5b987f 100644 (file)
@@ -37,6 +37,11 @@ else:
     import htmlentitydefs
     import urllib as urllib_parse
 
+try:
+    import typing  # noqa
+except ImportError:
+    pass
+
 
 _XHTML_ESCAPE_RE = re.compile('[&<>"\']')
 _XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',
@@ -180,6 +185,7 @@ _UTF8_TYPES = (bytes, type(None))
 
 
 def utf8(value):
+    # type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
     """Converts a string argument to a byte string.
 
     If the argument is already a byte string or None, it is returned unchanged.
index b04cff1322b270df8aa8ab3da1766453aa97cdf6..8194f91436beca1853a7ac435ec936f8760736c2 100644 (file)
@@ -30,7 +30,7 @@ from tornado import httputil
 from tornado import iostream
 from tornado.log import gen_log, app_log
 from tornado import stack_context
-from tornado.util import GzipDecompressor
+from tornado.util import GzipDecompressor, PY3
 
 
 class _QuietException(Exception):
@@ -372,7 +372,14 @@ class HTTP1Connection(httputil.HTTPConnection):
             self._expected_content_remaining = int(headers['Content-Length'])
         else:
             self._expected_content_remaining = None
-        lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
+        # TODO: headers are supposed to be of type str, but we still have some
+        # cases that let bytes slip through. Remove these native_str calls when those
+        # are fixed.
+        header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
+        if PY3:
+            lines.extend(l.encode('latin1') for l in header_lines)
+        else:
+            lines.extend(header_lines)
         for line in lines:
             if b'\n' in line:
                 raise ValueError('Newline in header: ' + repr(line))
index d0901565a3c9a40a753bb5d53280eca7f83371fe..866681adf531ab0c8bccda12760639476cb75d37 100644 (file)
@@ -59,6 +59,12 @@ except ImportError:
     # on the class definition itself; must go through an assignment.
     SSLError = _SSLError  # type: ignore
 
+try:
+    import typing
+except ImportError:
+    pass
+
+
 # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
 # terminator and ignore any preceding CR.
 _CRLF_RE = re.compile(r'\r?\n')
@@ -124,8 +130,8 @@ class HTTPHeaders(collections.MutableMapping):
     Set-Cookie: C=D
     """
     def __init__(self, *args, **kwargs):
-        self._dict = {}
-        self._as_list = {}
+        self._dict = {}  # type: typing.Dict[str, str]
+        self._as_list = {}  # type: typing.Dict[str, typing.List[str]]
         self._last_key = None
         if (len(args) == 1 and len(kwargs) == 0 and
                 isinstance(args[0], HTTPHeaders)):
@@ -139,6 +145,7 @@ class HTTPHeaders(collections.MutableMapping):
     # new public methods
 
     def add(self, name, value):
+        # type: (str, str) -> None
         """Adds a new value for the given key."""
         norm_name = _normalized_headers[name]
         self._last_key = norm_name
@@ -155,6 +162,7 @@ class HTTPHeaders(collections.MutableMapping):
         return self._as_list.get(norm_name, [])
 
     def get_all(self):
+        # type: () -> typing.Iterable[typing.Tuple[str, str]]
         """Returns an iterable of all (name, value) pairs.
 
         If a header has multiple values, multiple pairs will be
@@ -203,6 +211,7 @@ class HTTPHeaders(collections.MutableMapping):
         self._as_list[norm_name] = [value]
 
     def __getitem__(self, name):
+        # type: (str) -> str
         return self._dict[_normalized_headers[name]]
 
     def __delitem__(self, name):
index fac23a21fd82ec9e65062312bd29c3f74b40a3a6..7e417854983c23440d0d18e1f76dcb742ebc6f47 100644 (file)
@@ -1506,8 +1506,8 @@ class ErrorHandlerXSRFTest(WebTestCase):
 class GzipTestCase(SimpleHandlerTestCase):
     class Handler(RequestHandler):
         def get(self):
-            if self.get_argument('vary', None):
-                self.set_header('Vary', self.get_argument('vary'))
+            for v in self.get_arguments('vary'):
+                self.add_header('Vary', v)
             # Must write at least MIN_LENGTH bytes to activate compression.
             self.write('hello world' + ('!' * GZipContentEncoding.MIN_LENGTH))
 
@@ -1516,8 +1516,7 @@ class GzipTestCase(SimpleHandlerTestCase):
             gzip=True,
             static_path=os.path.join(os.path.dirname(__file__), 'static'))
 
-    def test_gzip(self):
-        response = self.fetch('/')
+    def assert_compressed(self, response):
         # simple_httpclient renames the content-encoding header;
         # curl_httpclient doesn't.
         self.assertEqual(
@@ -1525,17 +1524,18 @@ class GzipTestCase(SimpleHandlerTestCase):
                 'Content-Encoding',
                 response.headers.get('X-Consumed-Content-Encoding')),
             'gzip')
+
+
+    def test_gzip(self):
+        response = self.fetch('/')
+        self.assert_compressed(response)
         self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
 
     def test_gzip_static(self):
         # The streaming responses in StaticFileHandler have subtle
         # interactions with the gzip output so test this case separately.
         response = self.fetch('/robots.txt')
-        self.assertEqual(
-            response.headers.get(
-                'Content-Encoding',
-                response.headers.get('X-Consumed-Content-Encoding')),
-            'gzip')
+        self.assert_compressed(response)
         self.assertEqual(response.headers['Vary'], 'Accept-Encoding')
 
     def test_gzip_not_requested(self):
@@ -1545,9 +1545,16 @@ class GzipTestCase(SimpleHandlerTestCase):
 
     def test_vary_already_present(self):
         response = self.fetch('/?vary=Accept-Language')
-        self.assertEqual(response.headers['Vary'],
-                         'Accept-Language, Accept-Encoding')
-
+        self.assert_compressed(response)
+        self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
+                         ['Accept-Language', 'Accept-Encoding'])
+
+    def test_vary_already_present_multiple(self):
+        # Regression test for https://github.com/tornadoweb/tornado/issues/1670
+        response = self.fetch('/?vary=Accept-Language&vary=Cookie')
+        self.assert_compressed(response)
+        self.assertEqual([s.strip() for s in response.headers['Vary'].split(',')],
+                         ['Accept-Language', 'Cookie', 'Accept-Encoding'])
 
 @wsgi_safe
 class PathArgsInPrepareTest(WebTestCase):
index 4283d4e86aa11388bccef436e57239d297c452ed..d49a84f42b6e2ef5b1cfc3ba059ecb43a359a6c2 100644 (file)
@@ -33,12 +33,13 @@ else:
 # Aliases for types that are spelled differently in different Python
 # versions. bytes_type is deprecated and no longer used in Tornado
 # itself but is left in case anyone outside Tornado is using it.
-unicode_type = type(u'')
 bytes_type = bytes
 if PY3:
+    unicode_type = str
     basestring_type = str
 else:
-    # The name basestring doesn't exist in py3 so silence flake8.
+    # The names unicode and basestring don't exist in py3 so silence flake8.
+    unicode_type = unicode  # noqa
     basestring_type = basestring  # noqa
 
 
index 8f2acfcc93f23c235489191a0a1b4f987b647109..c9ff2b2d0da18d3dbe5ae898667ca54ece09d1e7 100644 (file)
@@ -104,6 +104,11 @@ else:
 
 try:
     import typing  # noqa
+
+    # The following types are accepted by RequestHandler.set_header
+    # and related methods.
+    _HeaderTypes = typing.Union[bytes, unicode_type,
+                                numbers.Integral, datetime.datetime]
 except ImportError:
     pass
 
@@ -164,6 +169,7 @@ class RequestHandler(object):
         self._auto_finish = True
         self._transforms = None  # will be set in _execute
         self._prepared_future = None
+        self._headers = None  # type: httputil.HTTPHeaders
         self.path_args = None
         self.path_kwargs = None
         self.ui = ObjectDict((n, self._ui_method(m)) for n, m in
@@ -318,6 +324,7 @@ class RequestHandler(object):
         return self._status_code
 
     def set_header(self, name, value):
+        # type: (str, _HeaderTypes) -> None
         """Sets the given response header name and value.
 
         If a datetime is given, we automatically format it according to the
@@ -327,6 +334,7 @@ class RequestHandler(object):
         self._headers[name] = self._convert_header_value(value)
 
     def add_header(self, name, value):
+        # type: (str, _HeaderTypes) -> None
         """Adds the given response header and value.
 
         Unlike `set_header`, `add_header` may be called multiple times
@@ -343,13 +351,25 @@ class RequestHandler(object):
         if name in self._headers:
             del self._headers[name]
 
-    _INVALID_HEADER_CHAR_RE = re.compile(br"[\x00-\x1f]")
+    _INVALID_HEADER_CHAR_RE = re.compile(r"[\x00-\x1f]")
 
     def _convert_header_value(self, value):
-        if isinstance(value, bytes):
-            pass
-        elif isinstance(value, unicode_type):
-            value = value.encode('utf-8')
+        # type: (_HeaderTypes) -> str
+
+        # Convert the input value to a str. This type check is a bit
+        # subtle: The bytes case only executes on python 3, and the
+        # unicode case only executes on python 2, because the other
+        # cases are covered by the first match for str.
+        if isinstance(value, str):
+            retval = value
+        elif isinstance(value, bytes):  # py3
+            # Non-ascii characters in headers are not well supported,
+            # but if you pass bytes, use latin1 so they pass through as-is.
+            retval = value.decode('latin1')
+        elif isinstance(value, unicode_type):  # py2
+            # TODO: This is inconsistent with the use of latin1 above,
+            # but it's been that way for a long time. Should it change?
+            retval = escape.utf8(value)
         elif isinstance(value, numbers.Integral):
             # return immediately since we know the converted value will be safe
             return str(value)
@@ -359,9 +379,9 @@ class RequestHandler(object):
             raise TypeError("Unsupported header value %r" % value)
         # If \n is allowed into the header, it is possible to inject
         # additional headers or split the request.
-        if RequestHandler._INVALID_HEADER_CHAR_RE.search(value):
-            raise ValueError("Unsafe header value %r", value)
-        return value
+        if RequestHandler._INVALID_HEADER_CHAR_RE.search(retval):
+            raise ValueError("Unsafe header value %r", retval)
+        return retval
 
     _ARG_DEFAULT = object()
 
@@ -2696,6 +2716,7 @@ class OutputTransform(object):
         pass
 
     def transform_first_chunk(self, status_code, headers, chunk, finishing):
+        # type: (int, httputil.HTTPHeaders, bytes, bool) -> typing.Tuple[int, httputil.HTTPHeaders, bytes]
         return status_code, headers, chunk
 
     def transform_chunk(self, chunk, finishing):
@@ -2736,10 +2757,12 @@ class GZipContentEncoding(OutputTransform):
         return ctype.startswith('text/') or ctype in self.CONTENT_TYPES
 
     def transform_first_chunk(self, status_code, headers, chunk, finishing):
+        # type: (int, httputil.HTTPHeaders, bytes, bool) -> typing.Tuple[int, httputil.HTTPHeaders, bytes]
+        # TODO: can/should this type be inherited from the superclass?
         if 'Vary' in headers:
-            headers['Vary'] += b', Accept-Encoding'
+            headers['Vary'] += ', Accept-Encoding'
         else:
-            headers['Vary'] = b'Accept-Encoding'
+            headers['Vary'] = 'Accept-Encoding'
         if self._gzipping:
             ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
             self._gzipping = self._compressible_type(ctype) and \