]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
speedups: validate mask length
authorBen Darnell <ben@bendarnell.com>
Wed, 27 May 2026 01:17:59 +0000 (21:17 -0400)
committerBen Darnell <ben@bendarnell.com>
Wed, 27 May 2026 16:09:29 +0000 (12:09 -0400)
The lack of this check permitted a read of up to 3 bytes past the end
of the string in some cases.

tornado/speedups.c
tornado/test/websocket_test.py
tornado/util.py

index 992c29c15b7042ce99f9ec2a6c2e440682f93388..2c791c646bf69e988c708a656fe53f80da3abe10 100644 (file)
@@ -2,63 +2,76 @@
 #include <Python.h>
 #include <stdint.h>
 
-static PyObject* websocket_mask(PyObject* self, PyObject* args) {
-    const char* mask;
+static PyObject *websocket_mask(PyObject *self, PyObject *args)
+{
+    const char *mask;
     Py_ssize_t mask_len;
     uint32_t uint32_mask;
     uint64_t uint64_mask;
-    const chardata;
+    const char *data;
     Py_ssize_t data_len;
     Py_ssize_t i;
-    PyObjectresult;
-    charbuf;
+    PyObject *result;
+    char *buf;
 
-    if (!PyArg_ParseTuple(args, "s#s#", &mask, &mask_len, &data, &data_len)) {
+    if (!PyArg_ParseTuple(args, "s#s#", &mask, &mask_len, &data, &data_len))
+    {
         return NULL;
     }
 
-    uint32_mask = ((uint32_t*)mask)[0];
+    if (mask_len != 4)
+    {
+        PyErr_SetString(PyExc_ValueError, "mask must be 4 bytes");
+        return NULL;
+    }
+
+    uint32_mask = ((uint32_t *)mask)[0];
 
     result = PyBytes_FromStringAndSize(NULL, data_len);
-    if (!result) {
+    if (!result)
+    {
         return NULL;
     }
     buf = PyBytes_AsString(result);
 
-    if (sizeof(size_t) >= 8) {
+    if (sizeof(size_t) >= 8)
+    {
         uint64_mask = uint32_mask;
         uint64_mask = (uint64_mask << 32) | uint32_mask;
 
-        while (data_len >= 8) {
-            ((uint64_t*)buf)[0] = ((uint64_t*)data)[0] ^ uint64_mask;
+        while (data_len >= 8)
+        {
+            ((uint64_t *)buf)[0] = ((uint64_t *)data)[0] ^ uint64_mask;
             data += 8;
             buf += 8;
             data_len -= 8;
         }
     }
 
-    while (data_len >= 4) {
-        ((uint32_t*)buf)[0] = ((uint32_t*)data)[0] ^ uint32_mask;
+    while (data_len >= 4)
+    {
+        ((uint32_t *)buf)[0] = ((uint32_t *)data)[0] ^ uint32_mask;
         data += 4;
         buf += 4;
         data_len -= 4;
     }
 
-    for (i = 0; i < data_len; i++) {
+    for (i = 0; i < data_len; i++)
+    {
         buf[i] = data[i] ^ mask[i];
     }
 
     return result;
 }
 
-static int speedups_exec(PyObject *module) {
+static int speedups_exec(PyObject *module)
+{
     return 0;
 }
 
 static PyMethodDef methods[] = {
-    {"websocket_mask",  websocket_mask, METH_VARARGS, ""},
-    {NULL, NULL, 0, NULL}
-};
+    {"websocket_mask", websocket_mask, METH_VARARGS, ""},
+    {NULL, NULL, 0, NULL}};
 
 static PyModuleDef_Slot slots[] = {
     {Py_mod_exec, speedups_exec},
@@ -68,19 +81,19 @@ static PyModuleDef_Slot slots[] = {
 #if (!defined(Py_LIMITED_API) && PY_VERSION_HEX >= 0x030d0000) || Py_LIMITED_API >= 0x030d0000
     {Py_mod_gil, Py_MOD_GIL_NOT_USED},
 #endif
-    {0, NULL}
-};
+    {0, NULL}};
 
 static struct PyModuleDef speedupsmodule = {
-   PyModuleDef_HEAD_INIT,
-   "speedups",
-   NULL,
-   0,
-   methods,
-   slots,
+    PyModuleDef_HEAD_INIT,
+    "speedups",
+    NULL,
+    0,
+    methods,
+    slots,
 };
 
 PyMODINIT_FUNC
-PyInit_speedups(void) {
+PyInit_speedups(void)
+{
     return PyModuleDef_Init(&speedupsmodule);
 }
index 94ccac57afba8ccb7a32063543f09a37a7be4a68..fc45688c9ee42a6567f74d9294c7f969efe3061f 100644 (file)
@@ -794,6 +794,13 @@ class MaskFunctionMixin(unittest.TestCase):
             b"\xff\xfa\xff\xff\xfb\xfe",
         )
 
+    def test_length_validation(self: typing.Any):
+        # Test all lengths of mask that are not 4 bytes.
+        for mask in (b"", b"a", b"ab", b"abc", b"abcde", b"abcdef"):
+            with self.subTest(mask=mask):
+                with self.assertRaises(ValueError):
+                    self.mask(mask, b"data asdf")
+
 
 class PythonMaskFunctionTest(MaskFunctionMixin):
     def mask(self, mask, data):
index 810732a67f2c9589d272f685df9b6fdfcbcf5924..37b595e0b856cf6cdfb01d14e1d8e3ac12223ce9 100644 (file)
@@ -408,6 +408,8 @@ def _websocket_mask_python(mask: bytes, data: bytes) -> bytes:
 
     This pure-python implementation may be replaced by an optimized version when available.
     """
+    if len(mask) != 4:
+        raise ValueError("mask must be 4 bytes")
     mask_arr = array.array("B", mask)
     unmasked_arr = array.array("B", data)
     for i in range(len(data)):