git.ipfire.org Git - thirdparty/tornado.git/commitdiff
http1connection: Make content-length parsing more strict
author: Ben Darnell <ben@bendarnell.com>
Wed, 9 Aug 2023 01:55:02 +0000 (21:55 -0400)
committer: Ben Darnell <ben@bendarnell.com>
Wed, 9 Aug 2023 01:55:02 +0000 (21:55 -0400)
Content-Length and chunk size parsing now strictly matches the RFCs.
We previously used Python's int() function, which accepts leading
plus signs and internal underscores; neither is allowed by the
HTTP RFCs. (int() also accepts minus signs, but these are less problematic
in this context since they'd result in errors elsewhere.)

It is important to fix this because when combined with certain proxies,
the lax parsing could result in a request smuggling vulnerability (if
both Tornado and the proxy accepted an invalid content-length but
interpreted it differently). This is known to occur with old versions
of haproxy, although the current version of haproxy is unaffected.

tornado/http1connection.py
tornado/test/httpserver_test.py

index 5ca91688878b1a2d9f811497dbc0920796d70826..ca50e8ff556d36a87adacf2252297e7176ef7c88 100644 (file)
@@ -442,7 +442,7 @@ class HTTP1Connection(httputil.HTTPConnection):
         ):
             self._expected_content_remaining = 0
         elif "Content-Length" in headers:
-            self._expected_content_remaining = int(headers["Content-Length"])
+            self._expected_content_remaining = parse_int(headers["Content-Length"])
         else:
             self._expected_content_remaining = None
         # TODO: headers are supposed to be of type str, but we still have some
@@ -618,7 +618,7 @@ class HTTP1Connection(httputil.HTTPConnection):
                 headers["Content-Length"] = pieces[0]
 
             try:
-                content_length = int(headers["Content-Length"])  # type: Optional[int]
+                content_length: Optional[int] = parse_int(headers["Content-Length"])
             except ValueError:
                 # Handles non-integer Content-Length value.
                 raise httputil.HTTPInputError(
@@ -668,7 +668,10 @@ class HTTP1Connection(httputil.HTTPConnection):
         total_size = 0
         while True:
             chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64)
-            chunk_len = int(chunk_len_str.strip(), 16)
+            try:
+                chunk_len = parse_hex_int(native_str(chunk_len_str[:-2]))
+            except ValueError:
+                raise httputil.HTTPInputError("invalid chunk size")
             if chunk_len == 0:
                 crlf = await self.stream.read_bytes(2)
                 if crlf != b"\r\n":
@@ -842,3 +845,21 @@ class HTTP1ServerConnection(object):
                 await asyncio.sleep(0)
         finally:
             delegate.on_close(self)
+
+
+DIGITS = re.compile(r"[0-9]+")
+HEXDIGITS = re.compile(r"[0-9a-fA-F]+")
+
+
+def parse_int(s: str) -> int:
+    """Parse a non-negative integer from a string."""
+    if DIGITS.fullmatch(s) is None:
+        raise ValueError("not an integer: %r" % s)
+    return int(s)
+
+
+def parse_hex_int(s: str) -> int:
+    """Parse a non-negative hexadecimal integer from a string."""
+    if HEXDIGITS.fullmatch(s) is None:
+        raise ValueError("not a hexadecimal integer: %r" % s)
+    return int(s, 16)
index cd0a0e100483427fc9cb3634ddf78d1ff442105e..db91d62daae03d4bb3419b8c34b443ccdc414fb4 100644 (file)
@@ -18,7 +18,7 @@ from tornado.httputil import (
 )
 from tornado.iostream import IOStream
 from tornado.locks import Event
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_options_to_context
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import (
@@ -41,6 +41,7 @@ import socket
 import ssl
 import sys
 import tempfile
+import textwrap
 import unittest
 import urllib.parse
 from io import BytesIO
@@ -118,7 +119,7 @@ class SSLTestMixin(object):
     def get_ssl_options(self):
         return dict(
             ssl_version=self.get_ssl_version(),
-            **AsyncHTTPSTestCase.default_ssl_options()
+            **AsyncHTTPSTestCase.default_ssl_options(),
         )
 
     def get_ssl_version(self):
@@ -558,23 +559,59 @@ bar
         )
         self.assertEqual(json_decode(response), {"foo": ["bar"]})
 
-    @gen_test
-    def test_invalid_content_length(self):
-        with ExpectLog(
-            gen_log, ".*Only integer Content-Length is allowed", level=logging.INFO
-        ):
-            self.stream.write(
-                b"""\
+    def test_chunked_request_body_invalid_size(self):
+        # Only hex digits are allowed in chunk sizes. Python's int() function
+        # also accepts underscores, so make sure we reject them here.
+        self.stream.write(
+            b"""\
 POST /echo HTTP/1.1
-Content-Length: foo
+Transfer-Encoding: chunked
 
-bar
+1_a
+1234567890abcdef1234567890
+0
 
 """.replace(
-                    b"\n", b"\r\n"
-                )
+                b"\n", b"\r\n"
             )
-            yield self.stream.read_until_close()
+        )
+        start_line, headers, response = self.io_loop.run_sync(
+            lambda: read_stream_body(self.stream)
+        )
+        self.assertEqual(400, start_line.code)
+
+    @gen_test
+    def test_invalid_content_length(self):
+        # HTTP only allows decimal digits in content-length. Make sure we don't
+        # accept anything else, with special attention to things accepted by the
+        # python int() function (leading plus signs and internal underscores).
+        test_cases = [
+            ("alphabetic", "foo"),
+            ("leading plus", "+10"),
+            ("internal underscore", "1_0"),
+        ]
+        for name, value in test_cases:
+            with self.subTest(name=name), closing(IOStream(socket.socket())) as stream:
+                with ExpectLog(
+                    gen_log,
+                    ".*Only integer Content-Length is allowed",
+                    level=logging.INFO,
+                ):
+                    yield stream.connect(("127.0.0.1", self.get_http_port()))
+                    stream.write(
+                        utf8(
+                            textwrap.dedent(
+                                f"""\
+                            POST /echo HTTP/1.1
+                            Content-Length: {value}
+                            Connection: close
+
+                            1234567890
+                            """
+                            ).replace("\n", "\r\n")
+                        )
+                    )
+                    yield stream.read_until_close()
 
 
 class XHeaderTest(HandlerBaseTestCase):
@@ -1123,6 +1160,46 @@ class StreamingChunkSizeTest(AsyncHTTPTestCase):
         )
 
 
+class InvalidOutputContentLengthTest(AsyncHTTPTestCase):
+    class MessageDelegate(HTTPMessageDelegate):
+        def __init__(self, connection):
+            self.connection = connection
+
+        def headers_received(self, start_line, headers):
+            content_lengths = {
+                "normal": "10",
+                "alphabetic": "foo",
+                "leading plus": "+10",
+                "underscore": "1_0",
+            }
+            self.connection.write_headers(
+                ResponseStartLine("HTTP/1.1", 200, "OK"),
+                HTTPHeaders({"Content-Length": content_lengths[headers["x-test"]]}),
+            )
+            self.connection.write(b"1234567890")
+            self.connection.finish()
+
+    def get_app(self):
+        class App(HTTPServerConnectionDelegate):
+            def start_request(self, server_conn, request_conn):
+                return InvalidOutputContentLengthTest.MessageDelegate(request_conn)
+
+        return App()
+
+    def test_invalid_output_content_length(self):
+        with self.subTest("normal"):
+            response = self.fetch("/", method="GET", headers={"x-test": "normal"})
+            response.rethrow()
+            self.assertEqual(response.body, b"1234567890")
+        for test in ["alphabetic", "leading plus", "underscore"]:
+            with self.subTest(test):
+                # This log matching could be tighter but I think I'm already
+                # over-testing here.
+                with ExpectLog(app_log, "Uncaught exception"):
+                    with self.assertRaises(HTTPError):
+                        self.fetch("/", method="GET", headers={"x-test": test})
+
+
 class MaxHeaderSizeTest(AsyncHTTPTestCase):
     def get_app(self):
         return Application([("/", HelloWorldRequestHandler)])