#include <Python.h>
#include <stdint.h>
-static PyObject* websocket_mask(PyObject* self, PyObject* args) {
- const char* mask;
+static PyObject *websocket_mask(PyObject *self, PyObject *args)
+{
+ const char *mask;
Py_ssize_t mask_len;
uint32_t uint32_mask;
uint64_t uint64_mask;
- const char* data;
+ const char *data;
Py_ssize_t data_len;
Py_ssize_t i;
- PyObject* result;
- char* buf;
+ PyObject *result;
+ char *buf;
- if (!PyArg_ParseTuple(args, "s#s#", &mask, &mask_len, &data, &data_len)) {
+ if (!PyArg_ParseTuple(args, "s#s#", &mask, &mask_len, &data, &data_len))
+ {
return NULL;
}
- uint32_mask = ((uint32_t*)mask)[0];
+ if (mask_len != 4)
+ {
+ PyErr_SetString(PyExc_ValueError, "mask must be 4 bytes");
+ return NULL;
+ }
+
+ uint32_mask = ((uint32_t *)mask)[0];
result = PyBytes_FromStringAndSize(NULL, data_len);
- if (!result) {
+ if (!result)
+ {
return NULL;
}
buf = PyBytes_AsString(result);
- if (sizeof(size_t) >= 8) {
+ if (sizeof(size_t) >= 8)
+ {
uint64_mask = uint32_mask;
uint64_mask = (uint64_mask << 32) | uint32_mask;
- while (data_len >= 8) {
- ((uint64_t*)buf)[0] = ((uint64_t*)data)[0] ^ uint64_mask;
+ while (data_len >= 8)
+ {
+ ((uint64_t *)buf)[0] = ((uint64_t *)data)[0] ^ uint64_mask;
data += 8;
buf += 8;
data_len -= 8;
}
}
- while (data_len >= 4) {
- ((uint32_t*)buf)[0] = ((uint32_t*)data)[0] ^ uint32_mask;
+ while (data_len >= 4)
+ {
+ ((uint32_t *)buf)[0] = ((uint32_t *)data)[0] ^ uint32_mask;
data += 4;
buf += 4;
data_len -= 4;
}
- for (i = 0; i < data_len; i++) {
+ for (i = 0; i < data_len; i++)
+ {
buf[i] = data[i] ^ mask[i];
}
return result;
}
-static int speedups_exec(PyObject *module) {
+static int speedups_exec(PyObject *module)
+{
return 0;
}
static PyMethodDef methods[] = {
- {"websocket_mask", websocket_mask, METH_VARARGS, ""},
- {NULL, NULL, 0, NULL}
-};
+ {"websocket_mask", websocket_mask, METH_VARARGS, ""},
+ {NULL, NULL, 0, NULL}};
static PyModuleDef_Slot slots[] = {
{Py_mod_exec, speedups_exec},
#if (!defined(Py_LIMITED_API) && PY_VERSION_HEX >= 0x030d0000) || Py_LIMITED_API >= 0x030d0000
{Py_mod_gil, Py_MOD_GIL_NOT_USED},
#endif
- {0, NULL}
-};
+ {0, NULL}};
static struct PyModuleDef speedupsmodule = {
- PyModuleDef_HEAD_INIT,
- "speedups",
- NULL,
- 0,
- methods,
- slots,
+ PyModuleDef_HEAD_INIT,
+ "speedups",
+ NULL,
+ 0,
+ methods,
+ slots,
};
PyMODINIT_FUNC
-PyInit_speedups(void) {
+PyInit_speedups(void)
+{
return PyModuleDef_Init(&speedupsmodule);
}
b"\xff\xfa\xff\xff\xfb\xfe",
)
+ def test_length_validation(self: typing.Any):
+ # Test all lengths of mask that are not 4 bytes.
+ for mask in (b"", b"a", b"ab", b"abc", b"abcde", b"abcdef"):
+ with self.subTest(mask=mask):
+ with self.assertRaises(ValueError):
+ self.mask(mask, b"data asdf")
+
class PythonMaskFunctionTest(MaskFunctionMixin):
def mask(self, mask, data):