git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-118209: Add Windows structured exception handling to mmap module (GH-118213)
author Dobatymo <Dobatymo@users.noreply.github.com>
Fri, 10 May 2024 09:47:30 +0000 (17:47 +0800)
committer GitHub <noreply@github.com>
Fri, 10 May 2024 09:47:30 +0000 (10:47 +0100)
Doc/whatsnew/3.13.rst
Lib/test/test_mmap.py
Misc/NEWS.d/next/Windows/2024-04-24-05-16-32.gh-issue-118209.Ryyzlz.rst [new file with mode: 0644]
Modules/mmapmodule.c

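diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst
index 9dab458b21009351cb773d7fbb1f43ca83c37dd6..37c857dd8197e50a60668015a2f37f797b4b318d 100644 (file)
--- a/Doc/whatsnew/3.13.rst
+++ b/Doc/whatsnew/3.13.rst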
@@ -796,6 +796,9 @@ mmap
 * :class:`mmap.mmap` now has a *trackfd* parameter on Unix; if it is ``False``,
   the file descriptor specified by *fileno* will not be duplicated.
   (Contributed by Zackery Spytz and Petr Viktorin in :gh:`78502`.)
+* :class:`mmap.mmap` is now protected from crashing on Windows when the mapped memory
+  is inaccessible due to file system errors or access violations.
+  (Contributed by Jannis Weigend in :gh:`118209`.)
 
 opcode
 ------
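diff --git a/Lib/test/test_mmap.py b/Lib/test/test_mmap.py
index ee86227e026b67804616de466d69010131945ab9..a1cf5384ada5b5f1530a26d335e0b07b11010894 100644 (file)
--- a/Lib/test/test_mmap.py
+++ b/Lib/test/test_mmap.py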
@@ -3,6 +3,7 @@ from test.support import (
 )
 from test.support.import_helper import import_module
 from test.support.os_helper import TESTFN, unlink
+from test.support.script_helper import assert_python_ok
 import unittest
 import errno
 import os
@@ -12,6 +13,7 @@ import random
 import socket
 import string
 import sys
+import textwrap
 import weakref
 
 # Skip test if we can't import mmap.
@@ -1058,6 +1060,81 @@ class MmapTests(unittest.TestCase):
                 with self.assertRaisesRegex(ValueError, "mmap closed or invalid"):
                     m.write_byte(X())
 
+    @unittest.skipUnless(os.name == 'nt', 'requires Windows')
+    @unittest.skipUnless(hasattr(mmap.mmap, '_protect'), 'test needs debug build')
+    def test_access_violations(self):
+        from test.support.os_helper import TESTFN
+
+        code = textwrap.dedent("""
+            import faulthandler
+            import mmap
+            import os
+            import sys
+            from contextlib import suppress
+
+            # Prevent logging access violations to stderr.
+            faulthandler.disable()
+
+            PAGESIZE = mmap.PAGESIZE
+            PAGE_NOACCESS = 0x01
+
+            with open(sys.argv[1], 'bw+') as f:
+                f.write(b'A'* PAGESIZE)
+                f.flush()
+
+                m = mmap.mmap(f.fileno(), PAGESIZE)
+                m._protect(PAGE_NOACCESS, 0, PAGESIZE)
+                with suppress(OSError):
+                    m.read(PAGESIZE)
+                    assert False, 'mmap.read() did not raise'
+                with suppress(OSError):
+                    m.read_byte()
+                    assert False, 'mmap.read_byte() did not raise'
+                with suppress(OSError):
+                    m.readline()
+                    assert False, 'mmap.readline() did not raise'
+                with suppress(OSError):
+                    m.write(b'A'* PAGESIZE)
+                    assert False, 'mmap.write() did not raise'
+                with suppress(OSError):
+                    m.write_byte(0)
+                    assert False, 'mmap.write_byte() did not raise'
+                with suppress(OSError):
+                    m[0]  # test mmap_subscript
+                    assert False, 'mmap.__getitem__() did not raise'
+                with suppress(OSError):
+                    m[0:10]  # test mmap_subscript
+                    assert False, 'mmap.__getitem__() did not raise'
+                with suppress(OSError):
+                    m[0:10:2]  # test mmap_subscript
+                    assert False, 'mmap.__getitem__() did not raise'
+                with suppress(OSError):
+                    m[0] = 1
+                    assert False, 'mmap.__setitem__() did not raise'
+                with suppress(OSError):
+                    m[0:10] = b'A'* 10
+                    assert False, 'mmap.__setitem__() did not raise'
+                with suppress(OSError):
+                    m[0:10:2] = b'A'* 5
+                    assert False, 'mmap.__setitem__() did not raise'
+                with suppress(OSError):
+                    m.move(0, 10, 1)
+                    assert False, 'mmap.move() did not raise'
+                with suppress(OSError):
+                    list(m)  # test mmap_item
+                    assert False, 'mmap.__getitem__() did not raise'
+                with suppress(OSError):
+                    m.find(b'A')
+                    assert False, 'mmap.find() did not raise'
+                with suppress(OSError):
+                    m.rfind(b'A')
+                    assert False, 'mmap.rfind() did not raise'
+        """)
+        rt, stdout, stderr = assert_python_ok("-c", code, TESTFN)
+        self.assertEqual(stdout.strip(), b'')
+        self.assertEqual(stderr.strip(), b'')
+
+
 class LargeMmapTests(unittest.TestCase):
 
     def setUp(self):
diff --git a/Misc/NEWS.d/next/Windows/2024-04-24-05-16-32.gh-issue-118209.Ryyzlz.rst b/Misc/NEWS.d/next/Windows/2024-04-24-05-16-32.gh-issue-118209.Ryyzlz.rst
new file mode 100644 (file)
index 0000000..da70b25
--- /dev/null
+++ b/Misc/NEWS.d/next/Windows/2024-04-24-05-16-32.gh-issue-118209.Ryyzlz.rst
@@ -0,0 +1,2 @@
+Avoid crashing in :mod:`mmap` on Windows when the mapped memory is inaccessible
+due to file system errors or access violations.
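diff --git a/Modules/mmapmodule.c b/Modules/mmapmodule.c
index dfc16ff437034941dc417425132b3fc0f904a117..99a85e9e49ad4775107f3a72478467661b2970a4 100644 (file)
--- a/Modules/mmapmodule.c
+++ b/Modules/mmapmodule.c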
@@ -41,6 +41,7 @@
 
 #ifdef MS_WINDOWS
 #include <windows.h>
+#include <ntsecapi.h> // LsaNtStatusToWinError
 static int
 my_getpagesize(void)
 {
@@ -255,6 +256,208 @@ do {                                                                    \
 } while (0)
 #endif /* UNIX */
 
+#if defined(MS_WINDOWS) && !defined(DONT_USE_SEH)
+static DWORD
+filter_page_exception(EXCEPTION_POINTERS *ptrs, EXCEPTION_RECORD *record)
+{
+    *record = *ptrs->ExceptionRecord;
+    if (record->ExceptionCode == EXCEPTION_IN_PAGE_ERROR ||
+        record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION)
+    {
+        return EXCEPTION_EXECUTE_HANDLER;
+    }
+    return EXCEPTION_CONTINUE_SEARCH;
+}
+
+static DWORD
+filter_page_exception_method(mmap_object *self, EXCEPTION_POINTERS *ptrs,
+                             EXCEPTION_RECORD *record)
+{
+    *record = *ptrs->ExceptionRecord;
+    if (record->ExceptionCode == EXCEPTION_IN_PAGE_ERROR ||
+        record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION)
+    {
+
+        ULONG_PTR address = record->ExceptionInformation[1];
+        if (address >= (ULONG_PTR) self->data &&
+            address < (ULONG_PTR) self->data + (ULONG_PTR) self->size)
+        {
+            return EXCEPTION_EXECUTE_HANDLER;
+        }
+    }
+    return EXCEPTION_CONTINUE_SEARCH;
+}
+#endif
+
+#if defined(MS_WINDOWS) && !defined(DONT_USE_SEH)
+#define HANDLE_INVALID_MEM(sourcecode)                                     \
+do {                                                                       \
+    EXCEPTION_RECORD record;                                               \
+    __try {                                                                \
+        sourcecode                                                         \
+    }                                                                      \
+    __except (filter_page_exception(GetExceptionInformation(), &record)) { \
+        assert(record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR ||          \
+               record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION);        \
+        if (record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR) {             \
+            NTSTATUS status = (NTSTATUS) record.ExceptionInformation[2];   \
+            ULONG code = LsaNtStatusToWinError(status);                    \
+            PyErr_SetFromWindowsErr(code);                                 \
+        }                                                                  \
+        else if (record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {     \
+            PyErr_SetFromWindowsErr(ERROR_NOACCESS);                       \
+        }                                                                  \
+        return -1;                                                         \
+    }                                                                      \
+} while (0)
+#else
+#define HANDLE_INVALID_MEM(sourcecode)                                     \
+do {                                                                       \
+    sourcecode                                                             \
+} while (0)
+#endif
+
+#if defined(MS_WINDOWS) && !defined(DONT_USE_SEH)
+#define HANDLE_INVALID_MEM_METHOD(self, sourcecode)                           \
+do {                                                                          \
+    EXCEPTION_RECORD record;                                                  \
+    __try {                                                                   \
+        sourcecode                                                            \
+    }                                                                         \
+    __except (filter_page_exception_method(self, GetExceptionInformation(),   \
+                                           &record)) {                        \
+        assert(record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR ||             \
+               record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION);           \
+        if (record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR) {                \
+            NTSTATUS status = (NTSTATUS) record.ExceptionInformation[2];      \
+            ULONG code = LsaNtStatusToWinError(status);                       \
+            PyErr_SetFromWindowsErr(code);                                    \
+        }                                                                     \
+        else if (record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {        \
+            PyErr_SetFromWindowsErr(ERROR_NOACCESS);                          \
+        }                                                                     \
+        return -1;                                                            \
+    }                                                                         \
+} while (0)
+#else
+#define HANDLE_INVALID_MEM_METHOD(self, sourcecode)                           \
+do {                                                                          \
+    sourcecode                                                                \
+} while (0)
+#endif
+
+int
+safe_memcpy(void *dest, const void *src, size_t count)
+{
+    HANDLE_INVALID_MEM(
+        memcpy(dest, src, count);
+    );
+    return 0;
+}
+
+int
+safe_byte_copy(char *dest, const char *src)
+{
+    HANDLE_INVALID_MEM(
+        *dest = *src;
+    );
+    return 0;
+}
+
+int
+safe_memchr(char **out, const void *ptr, int ch, size_t count)
+{
+    HANDLE_INVALID_MEM(
+        *out = (char *) memchr(ptr, ch, count);
+    );
+    return 0;
+}
+
+int
+safe_memmove(void *dest, const void *src, size_t count)
+{
+    HANDLE_INVALID_MEM(
+        memmove(dest, src, count);
+    );
+    return 0;
+}
+
+int
+safe_copy_from_slice(char *dest, const char *src, Py_ssize_t start,
+                     Py_ssize_t step, Py_ssize_t slicelen)
+{
+    HANDLE_INVALID_MEM(
+        size_t cur;
+        Py_ssize_t i;
+        for (cur = start, i = 0; i < slicelen; cur += step, i++) {
+            dest[cur] = src[i];
+        }
+    );
+    return 0;
+}
+
+int
+safe_copy_to_slice(char *dest, const char *src, Py_ssize_t start,
+                   Py_ssize_t step, Py_ssize_t slicelen)
+{
+    HANDLE_INVALID_MEM(
+        size_t cur;
+        Py_ssize_t i;
+        for (cur = start, i = 0; i < slicelen; cur += step, i++) {
+            dest[i] = src[cur];
+        }
+    );
+    return 0;
+}
+
+
+int
+_safe_PyBytes_Find(Py_ssize_t *out, mmap_object *self, const char *haystack,
+                   Py_ssize_t len_haystack, const char *needle,
+                   Py_ssize_t len_needle, Py_ssize_t offset)
+{
+    HANDLE_INVALID_MEM_METHOD(self,
+        *out = _PyBytes_Find(haystack, len_haystack, needle, len_needle, offset);
+    );
+    return 0;
+}
+
+int
+_safe_PyBytes_ReverseFind(Py_ssize_t *out, mmap_object *self,
+                          const char *haystack, Py_ssize_t len_haystack,
+                          const char *needle, Py_ssize_t len_needle,
+                          Py_ssize_t offset)
+{
+    HANDLE_INVALID_MEM_METHOD(self,
+        *out = _PyBytes_ReverseFind(haystack, len_haystack, needle, len_needle,
+                                    offset);
+    );
+    return 0;
+}
+
+PyObject *
+_safe_PyBytes_FromStringAndSize(char *start, size_t num_bytes) {
+    if (num_bytes == 1) {
+        char dest;
+        if (safe_byte_copy(&dest, start) < 0) {
+            return NULL;
+        }
+        else {
+            return PyBytes_FromStringAndSize(&dest, 1);
+        }
+    }
+    else {
+        PyObject *result = PyBytes_FromStringAndSize(NULL, num_bytes);
+        if (result == NULL) {
+            return NULL;
+        }
+        if (safe_memcpy(PyBytes_AS_STRING(result), start, num_bytes) < 0) {
+            Py_CLEAR(result);
+        }
+        return result;
+    }
+}
+
 static PyObject *
 mmap_read_byte_method(mmap_object *self,
                       PyObject *Py_UNUSED(ignored))
@@ -264,7 +467,12 @@ mmap_read_byte_method(mmap_object *self,
         PyErr_SetString(PyExc_ValueError, "read byte out of range");
         return NULL;
     }
-    return PyLong_FromLong((unsigned char)self->data[self->pos++]);
+    char dest;
+    if (safe_byte_copy(&dest, self->data + self->pos) < 0) {
+        return NULL;
+    }
+    self->pos++;
+    return PyLong_FromLong((unsigned char) dest);
 }
 
 static PyObject *
@@ -273,7 +481,6 @@ mmap_read_line_method(mmap_object *self,
 {
     Py_ssize_t remaining;
     char *start, *eol;
-    PyObject *result;
 
     CHECK_VALID(NULL);
 
@@ -281,13 +488,20 @@ mmap_read_line_method(mmap_object *self,
     if (!remaining)
         return PyBytes_FromString("");
     start = self->data + self->pos;
-    eol = memchr(start, '\n', remaining);
+
+    if (safe_memchr(&eol, start, '\n', remaining) < 0) {
+        return NULL;
+    }
+
     if (!eol)
         eol = self->data + self->size;
     else
         ++eol; /* advance past newline */
-    result = PyBytes_FromStringAndSize(start, (eol - start));
-    self->pos += (eol - start);
+
+    PyObject *result = _safe_PyBytes_FromStringAndSize(start, eol - start);
+    if (result != NULL) {
+        self->pos += (eol - start);
+    }
     return result;
 }
 
@@ -296,7 +510,6 @@ mmap_read_method(mmap_object *self,
                  PyObject *args)
 {
     Py_ssize_t num_bytes = PY_SSIZE_T_MAX, remaining;
-    PyObject *result;
 
     CHECK_VALID(NULL);
     if (!PyArg_ParseTuple(args, "|O&:read", _Py_convert_optional_to_ssize_t, &num_bytes))
@@ -307,8 +520,12 @@ mmap_read_method(mmap_object *self,
     remaining = (self->pos < self->size) ? self->size - self->pos : 0;
     if (num_bytes < 0 || num_bytes > remaining)
         num_bytes = remaining;
-    result = PyBytes_FromStringAndSize(&self->data[self->pos], num_bytes);
-    self->pos += num_bytes;
+
+    PyObject *result = _safe_PyBytes_FromStringAndSize(self->data + self->pos,
+                                                       num_bytes);
+    if (result != NULL) {
+        self->pos += num_bytes;
+    }
     return result;
 }
 
@@ -341,25 +558,38 @@ mmap_gfind(mmap_object *self,
         else if (end > self->size)
             end = self->size;
 
-        Py_ssize_t res;
+        Py_ssize_t index;
+        PyObject *result;
         CHECK_VALID_OR_RELEASE(NULL, view);
         if (end < start) {
-            res = -1;
+            result = PyLong_FromSsize_t(-1);
         }
         else if (reverse) {
             assert(0 <= start && start <= end && end <= self->size);
-            res = _PyBytes_ReverseFind(
+            if (_safe_PyBytes_ReverseFind(&index, self,
                 self->data + start, end - start,
-                view.buf, view.len, start);
+                view.buf, view.len, start) < 0)
+            {
+                result = NULL;
+            }
+            else {
+                result = PyLong_FromSsize_t(index);
+            }
         }
         else {
             assert(0 <= start && start <= end && end <= self->size);
-            res = _PyBytes_Find(
+            if (_safe_PyBytes_Find(&index, self,
                 self->data + start, end - start,
-                view.buf, view.len, start);
+                view.buf, view.len, start) < 0)
+            {
+                result = NULL;
+            }
+            else {
+                result = PyLong_FromSsize_t(index);
+            }
         }
         PyBuffer_Release(&view);
-        return PyLong_FromSsize_t(res);
+        return result;
     }
 }
 
@@ -432,10 +662,16 @@ mmap_write_method(mmap_object *self,
     }
 
     CHECK_VALID_OR_RELEASE(NULL, data);
-    memcpy(&self->data[self->pos], data.buf, data.len);
-    self->pos += data.len;
+    PyObject *result;
+    if (safe_memcpy(self->data + self->pos, data.buf, data.len) < 0) {
+        result = NULL;
+    }
+    else {
+        self->pos += data.len;
+        result = PyLong_FromSsize_t(data.len);
+    }
     PyBuffer_Release(&data);
-    return PyLong_FromSsize_t(data.len);
+    return result;
 }
 
 static PyObject *
@@ -452,14 +688,16 @@ mmap_write_byte_method(mmap_object *self,
         return NULL;
 
     CHECK_VALID(NULL);
-    if (self->pos < self->size) {
-        self->data[self->pos++] = value;
-        Py_RETURN_NONE;
-    }
-    else {
+    if (self->pos >= self->size) {
         PyErr_SetString(PyExc_ValueError, "write byte out of range");
         return NULL;
     }
+
+    if (safe_byte_copy(self->data + self->pos, &value) < 0) {
+        return NULL;
+    }
+    self->pos++;
+    Py_RETURN_NONE;
 }
 
 static PyObject *
@@ -763,8 +1001,9 @@ mmap_move_method(mmap_object *self, PyObject *args)
             goto bounds;
 
         CHECK_VALID(NULL);
-        memmove(&self->data[dest], &self->data[src], cnt);
-
+        if (safe_memmove(self->data + dest, self->data + src, cnt) < 0) {
+            return NULL;
+        };
         Py_RETURN_NONE;
 
       bounds:
@@ -855,6 +1094,29 @@ mmap__sizeof__method(mmap_object *self, void *Py_UNUSED(ignored))
 }
 #endif
 
+#if defined(MS_WINDOWS) && defined(Py_DEBUG)
+static PyObject *
+mmap_protect_method(mmap_object *self, PyObject *args) {
+    DWORD flNewProtect, flOldProtect;
+    Py_ssize_t start, length;
+
+    CHECK_VALID(NULL);
+
+    if (!PyArg_ParseTuple(args, "Inn:protect", &flNewProtect, &start, &length)) {
+        return NULL;
+    }
+
+    if (!VirtualProtect((void *) (self->data + start), length, flNewProtect,
+                        &flOldProtect))
+    {
+        PyErr_SetFromWindowsErr(GetLastError());
+        return NULL;
+    }
+
+    Py_RETURN_NONE;
+}
+#endif
+
 #ifdef HAVE_MADVISE
 static PyObject *
 mmap_madvise_method(mmap_object *self, PyObject *args)
@@ -924,7 +1186,10 @@ static struct PyMethodDef mmap_object_methods[] = {
     {"__exit__",        (PyCFunction) mmap__exit__method,       METH_VARARGS},
 #ifdef MS_WINDOWS
     {"__sizeof__",      (PyCFunction) mmap__sizeof__method,     METH_NOARGS},
-#endif
+#ifdef Py_DEBUG
+    {"_protect",        (PyCFunction) mmap_protect_method,      METH_VARARGS},
+#endif // Py_DEBUG
+#endif // MS_WINDOWS
     {NULL,         NULL}       /* sentinel */
 };
 
@@ -968,7 +1233,12 @@ mmap_item(mmap_object *self, Py_ssize_t i)
         PyErr_SetString(PyExc_IndexError, "mmap index out of range");
         return NULL;
     }
-    return PyBytes_FromStringAndSize(self->data + i, 1);
+
+    char dest;
+    if (safe_byte_copy(&dest, self->data + i) < 0) {
+        return NULL;
+    }
+    return PyBytes_FromStringAndSize(&dest, 1);
 }
 
 static PyObject *
@@ -987,7 +1257,12 @@ mmap_subscript(mmap_object *self, PyObject *item)
             return NULL;
         }
         CHECK_VALID(NULL);
-        return PyLong_FromLong(Py_CHARMASK(self->data[i]));
+
+        char dest;
+        if (safe_byte_copy(&dest, self->data + i) < 0) {
+            return NULL;
+        }
+        return PyLong_FromLong(Py_CHARMASK(dest));
     }
     else if (PySlice_Check(item)) {
         Py_ssize_t start, stop, step, slicelen;
@@ -1001,23 +1276,22 @@ mmap_subscript(mmap_object *self, PyObject *item)
         if (slicelen <= 0)
             return PyBytes_FromStringAndSize("", 0);
         else if (step == 1)
-            return PyBytes_FromStringAndSize(self->data + start,
-                                              slicelen);
+            return _safe_PyBytes_FromStringAndSize(self->data + start, slicelen);
         else {
             char *result_buf = (char *)PyMem_Malloc(slicelen);
-            size_t cur;
-            Py_ssize_t i;
             PyObject *result;
 
             if (result_buf == NULL)
                 return PyErr_NoMemory();
 
-            for (cur = start, i = 0; i < slicelen;
-                 cur += step, i++) {
-                result_buf[i] = self->data[cur];
+            if (safe_copy_to_slice(result_buf, self->data, start, step,
+                                   slicelen) < 0)
+            {
+                result = NULL;
+            }
+            else {
+                result = PyBytes_FromStringAndSize(result_buf, slicelen);
             }
-            result = PyBytes_FromStringAndSize(result_buf,
-                                                slicelen);
             PyMem_Free(result_buf);
             return result;
         }
@@ -1052,7 +1326,10 @@ mmap_ass_item(mmap_object *self, Py_ssize_t i, PyObject *v)
     if (!is_writable(self))
         return -1;
     buf = PyBytes_AsString(v);
-    self->data[i] = buf[0];
+
+    if (safe_byte_copy(self->data + i, buf) < 0) {
+        return -1;
+    }
     return 0;
 }
 
@@ -1097,7 +1374,11 @@ mmap_ass_subscript(mmap_object *self, PyObject *item, PyObject *value)
             return -1;
         }
         CHECK_VALID(-1);
-        self->data[i] = (char) v;
+
+        char v_char = (char) v;
+        if (safe_byte_copy(self->data + i, &v_char) < 0) {
+            return -1;
+        }
         return 0;
     }
     else if (PySlice_Check(item)) {
@@ -1123,24 +1404,23 @@ mmap_ass_subscript(mmap_object *self, PyObject *item, PyObject *value)
         }
 
         CHECK_VALID_OR_RELEASE(-1, vbuf);
+        int result = 0;
         if (slicelen == 0) {
         }
         else if (step == 1) {
-            memcpy(self->data + start, vbuf.buf, slicelen);
+            if (safe_memcpy(self->data + start, vbuf.buf, slicelen) < 0) {
+                result = -1;
+            }
         }
         else {
-            size_t cur;
-            Py_ssize_t i;
-
-            for (cur = start, i = 0;
-                 i < slicelen;
-                 cur += step, i++)
+            if (safe_copy_from_slice(self->data, (char *)vbuf.buf, start, step,
+                                     slicelen) < 0)
             {
-                self->data[cur] = ((char *)vbuf.buf)[i];
+                result = -1;
             }
         }
         PyBuffer_Release(&vbuf);
-        return 0;
+        return result;
     }
     else {
         PyErr_SetString(PyExc_TypeError,