):
self._expected_content_remaining = 0
elif "Content-Length" in headers:
- self._expected_content_remaining = int(headers["Content-Length"])
+ self._expected_content_remaining = parse_int(headers["Content-Length"])
else:
self._expected_content_remaining = None
# TODO: headers are supposed to be of type str, but we still have some
headers["Content-Length"] = pieces[0]
try:
- content_length = int(headers["Content-Length"]) # type: Optional[int]
+ content_length: Optional[int] = parse_int(headers["Content-Length"])
except ValueError:
# Handles non-integer Content-Length value.
raise httputil.HTTPInputError(
total_size = 0
while True:
chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64)
- chunk_len = int(chunk_len_str.strip(), 16)
+ try:
+ chunk_len = parse_hex_int(native_str(chunk_len_str[:-2]))
+ except ValueError:
+ raise httputil.HTTPInputError("invalid chunk size")
if chunk_len == 0:
crlf = await self.stream.read_bytes(2)
if crlf != b"\r\n":
await asyncio.sleep(0)
finally:
delegate.on_close(self)
+
+
+DIGITS = re.compile(r"[0-9]+")
+HEXDIGITS = re.compile(r"[0-9a-fA-F]+")
+
+
+def parse_int(s: str) -> int:
+ """Parse a non-negative integer from a string."""
+ if DIGITS.fullmatch(s) is None:
+ raise ValueError("not an integer: %r" % s)
+ return int(s)
+
+
+def parse_hex_int(s: str) -> int:
+ """Parse a non-negative hexadecimal integer from a string."""
+ if HEXDIGITS.fullmatch(s) is None:
+ raise ValueError("not a hexadecimal integer: %r" % s)
+ return int(s, 16)
)
from tornado.iostream import IOStream
from tornado.locks import Event
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import (
import ssl
import sys
import tempfile
+import textwrap
import unittest
import urllib.parse
from io import BytesIO
def get_ssl_options(self):
return dict(
ssl_version=self.get_ssl_version(),
- **AsyncHTTPSTestCase.default_ssl_options()
+ **AsyncHTTPSTestCase.default_ssl_options(),
)
def get_ssl_version(self):
)
self.assertEqual(json_decode(response), {"foo": ["bar"]})
- @gen_test
- def test_invalid_content_length(self):
- with ExpectLog(
- gen_log, ".*Only integer Content-Length is allowed", level=logging.INFO
- ):
- self.stream.write(
- b"""\
+ def test_chunked_request_body_invalid_size(self):
+ # Only hex digits are allowed in chunk sizes. Python's int() function
+ # also accepts underscores, so make sure we reject them here.
+ self.stream.write(
+ b"""\
POST /echo HTTP/1.1
-Content-Length: foo
+Transfer-Encoding: chunked
-bar
+1_a
+1234567890abcdef1234567890
+0
""".replace(
- b"\n", b"\r\n"
- )
+ b"\n", b"\r\n"
)
- yield self.stream.read_until_close()
+ )
+ start_line, headers, response = self.io_loop.run_sync(
+ lambda: read_stream_body(self.stream)
+ )
+ self.assertEqual(400, start_line.code)
+
+ @gen_test
+ def test_invalid_content_length(self):
+ # HTTP only allows decimal digits in content-length. Make sure we don't
+ # accept anything else, with special attention to things accepted by the
+ # python int() function (leading plus signs and internal underscores).
+ test_cases = [
+ ("alphabetic", "foo"),
+ ("leading plus", "+10"),
+ ("internal underscore", "1_0"),
+ ]
+ for name, value in test_cases:
+ with self.subTest(name=name), closing(IOStream(socket.socket())) as stream:
+ with ExpectLog(
+ gen_log,
+ ".*Only integer Content-Length is allowed",
+ level=logging.INFO,
+ ):
+ yield stream.connect(("127.0.0.1", self.get_http_port()))
+ stream.write(
+ utf8(
+ textwrap.dedent(
+ f"""\
+ POST /echo HTTP/1.1
+ Content-Length: {value}
+ Connection: close
+
+ 1234567890
+ """
+ ).replace("\n", "\r\n")
+ )
+ )
+ yield stream.read_until_close()
class XHeaderTest(HandlerBaseTestCase):
)
+class InvalidOutputContentLengthTest(AsyncHTTPTestCase):
+ class MessageDelegate(HTTPMessageDelegate):
+ def __init__(self, connection):
+ self.connection = connection
+
+ def headers_received(self, start_line, headers):
+ content_lengths = {
+ "normal": "10",
+ "alphabetic": "foo",
+ "leading plus": "+10",
+ "underscore": "1_0",
+ }
+ self.connection.write_headers(
+ ResponseStartLine("HTTP/1.1", 200, "OK"),
+ HTTPHeaders({"Content-Length": content_lengths[headers["x-test"]]}),
+ )
+ self.connection.write(b"1234567890")
+ self.connection.finish()
+
+ def get_app(self):
+ class App(HTTPServerConnectionDelegate):
+ def start_request(self, server_conn, request_conn):
+ return InvalidOutputContentLengthTest.MessageDelegate(request_conn)
+
+ return App()
+
+ def test_invalid_output_content_length(self):
+ with self.subTest("normal"):
+ response = self.fetch("/", method="GET", headers={"x-test": "normal"})
+ response.rethrow()
+ self.assertEqual(response.body, b"1234567890")
+ for test in ["alphabetic", "leading plus", "underscore"]:
+ with self.subTest(test):
+ # This log matching could be tighter but I think I'm already
+ # over-testing here.
+ with ExpectLog(app_log, "Uncaught exception"):
+ with self.assertRaises(HTTPError):
+ self.fetch("/", method="GET", headers={"x-test": test})
+
+
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
return Application([("/", HelloWorldRequestHandler)])