]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-116738: Make grp module thread-safe (#135434)
authorAlper <alperyoney@fb.com>
Mon, 14 Jul 2025 18:18:41 +0000 (11:18 -0700)
committerGitHub <noreply@github.com>
Mon, 14 Jul 2025 18:18:41 +0000 (11:18 -0700)
Make grp module methods getgrgid() and getgrnam() thread-safe when the GIL is disabled and getgrgid_r()/getgrnam_r() C APIs are not available.
---------

Co-authored-by: Kumar Aditya <kumaraditya@python.org>
Doc/library/test.rst
Lib/test/support/threading_helper.py
Lib/test/test_free_threading/test_grp.py [new file with mode: 0644]
Lib/test/test_free_threading/test_heapq.py
Misc/NEWS.d/next/Core_and_Builtins/2025-06-12-00-03-34.gh-issue-116738.iBBAdo.rst [new file with mode: 0644]
Modules/grpmodule.c
Tools/c-analyzer/cpython/ignored.tsv

index 0aae14c15a610484c1454cc3f3bcedf646ff93b5..9fdb21b8badbbf9193e21545ee371842360d4b5d 100644 (file)
@@ -1384,6 +1384,13 @@ The :mod:`test.support.threading_helper` module provides support for threading t
    .. versionadded:: 3.8
 
 
+.. function:: run_concurrently(worker_func, nthreads, args=(), kwargs={})
+
+    Run the worker function concurrently in multiple threads.
+    Re-raises an exception if any thread raises one, after all threads have
+    finished.
+
+
 :mod:`test.support.os_helper` --- Utilities for os tests
 ========================================================================
 
index afa25a76f63829555028aba818a7daa39c9b5d47..3e04c344a0d66fc41174e9c1a07a0518c6807ac6 100644 (file)
@@ -248,3 +248,27 @@ def requires_working_threading(*, module=False):
             raise unittest.SkipTest(msg)
     else:
         return unittest.skipUnless(can_start_thread, msg)
+
+
+def run_concurrently(worker_func, nthreads, args=(), kwargs={}):
+    """
+    Run the worker function concurrently in multiple threads.
+    """
+    barrier = threading.Barrier(nthreads)
+
+    def wrapper_func(*args, **kwargs):
+        # Wait for all threads to reach this point before proceeding.
+        barrier.wait()
+        worker_func(*args, **kwargs)
+
+    with catch_threading_exception() as cm:
+        workers = [
+            threading.Thread(target=wrapper_func, args=args, kwargs=kwargs)
+            for _ in range(nthreads)
+        ]
+        with start_threads(workers):
+            pass
+
+        # If a worker thread raises an exception, re-raise it.
+        if cm.exc_value is not None:
+            raise cm.exc_value
diff --git a/Lib/test/test_free_threading/test_grp.py b/Lib/test/test_free_threading/test_grp.py
new file mode 100644 (file)
index 0000000..1a47a97
--- /dev/null
@@ -0,0 +1,35 @@
+import unittest
+
+from test.support import import_helper, threading_helper
+from test.support.threading_helper import run_concurrently
+
+grp = import_helper.import_module("grp")
+
+from test import test_grp
+
+
+NTHREADS = 10
+
+
+@threading_helper.requires_working_threading()
+class TestGrp(unittest.TestCase):
+    def setUp(self):
+        self.test_grp = test_grp.GroupDatabaseTestCase()
+
+    def test_racing_test_values(self):
+        # test_grp.test_values() calls grp.getgrall() and checks the entries
+        run_concurrently(
+            worker_func=self.test_grp.test_values, nthreads=NTHREADS
+        )
+
+    def test_racing_test_values_extended(self):
+        # test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
+        # grp.getgrnam() and checks the entries
+        run_concurrently(
+            worker_func=self.test_grp.test_values_extended,
+            nthreads=NTHREADS,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
index ee7adfb2b78d834c56665e72d8a83f2a5e862548..d771333ffcc9e0f892b953d25bd94197db814e86 100644 (file)
@@ -3,10 +3,11 @@ import unittest
 import heapq
 
 from enum import Enum
-from threading import Thread, Barrier, Lock
+from threading import Barrier, Lock
 from random import shuffle, randint
 
 from test.support import threading_helper
+from test.support.threading_helper import run_concurrently
 from test import test_heapq
 
 
@@ -28,8 +29,8 @@ class TestHeapq(unittest.TestCase):
         heap = list(range(OBJECT_COUNT))
         shuffle(heap)
 
-        self.run_concurrently(
-            worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
+        run_concurrently(
+            worker_func=heapq.heapify, nthreads=NTHREADS, args=(heap,)
         )
         self.test_heapq.check_invariant(heap)
 
@@ -40,8 +41,8 @@ class TestHeapq(unittest.TestCase):
             for item in reversed(range(OBJECT_COUNT)):
                 heapq.heappush(heap, item)
 
-        self.run_concurrently(
-            worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
+        run_concurrently(
+            worker_func=heappush_func, nthreads=NTHREADS, args=(heap,)
         )
         self.test_heapq.check_invariant(heap)
 
@@ -61,10 +62,10 @@ class TestHeapq(unittest.TestCase):
             # Each local list should be sorted
             self.assertTrue(self.is_sorted_ascending(local_list))
 
-        self.run_concurrently(
+        run_concurrently(
             worker_func=heappop_func,
-            args=(heap, per_thread_pop_count),
             nthreads=NTHREADS,
+            args=(heap, per_thread_pop_count),
         )
         self.assertEqual(len(heap), 0)
 
@@ -77,10 +78,10 @@ class TestHeapq(unittest.TestCase):
                 popped_item = heapq.heappushpop(heap, item)
                 self.assertTrue(popped_item <= item)
 
-        self.run_concurrently(
+        run_concurrently(
             worker_func=heappushpop_func,
-            args=(heap, pushpop_items),
             nthreads=NTHREADS,
+            args=(heap, pushpop_items),
         )
         self.assertEqual(len(heap), OBJECT_COUNT)
         self.test_heapq.check_invariant(heap)
@@ -93,10 +94,10 @@ class TestHeapq(unittest.TestCase):
             for item in replace_items:
                 heapq.heapreplace(heap, item)
 
-        self.run_concurrently(
+        run_concurrently(
             worker_func=heapreplace_func,
-            args=(heap, replace_items),
             nthreads=NTHREADS,
+            args=(heap, replace_items),
         )
         self.assertEqual(len(heap), OBJECT_COUNT)
         self.test_heapq.check_invariant(heap)
@@ -105,8 +106,8 @@ class TestHeapq(unittest.TestCase):
         max_heap = list(range(OBJECT_COUNT))
         shuffle(max_heap)
 
-        self.run_concurrently(
-            worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
+        run_concurrently(
+            worker_func=heapq.heapify_max, nthreads=NTHREADS, args=(max_heap,)
         )
         self.test_heapq.check_max_invariant(max_heap)
 
@@ -117,8 +118,8 @@ class TestHeapq(unittest.TestCase):
             for item in range(OBJECT_COUNT):
                 heapq.heappush_max(max_heap, item)
 
-        self.run_concurrently(
-            worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
+        run_concurrently(
+            worker_func=heappush_max_func, nthreads=NTHREADS, args=(max_heap,)
         )
         self.test_heapq.check_max_invariant(max_heap)
 
@@ -138,10 +139,10 @@ class TestHeapq(unittest.TestCase):
             # Each local list should be sorted
             self.assertTrue(self.is_sorted_descending(local_list))
 
-        self.run_concurrently(
+        run_concurrently(
             worker_func=heappop_max_func,
-            args=(max_heap, per_thread_pop_count),
             nthreads=NTHREADS,
+            args=(max_heap, per_thread_pop_count),
         )
         self.assertEqual(len(max_heap), 0)
 
@@ -154,10 +155,10 @@ class TestHeapq(unittest.TestCase):
                 popped_item = heapq.heappushpop_max(max_heap, item)
                 self.assertTrue(popped_item >= item)
 
-        self.run_concurrently(
+        run_concurrently(
             worker_func=heappushpop_max_func,
-            args=(max_heap, pushpop_items),
             nthreads=NTHREADS,
+            args=(max_heap, pushpop_items),
         )
         self.assertEqual(len(max_heap), OBJECT_COUNT)
         self.test_heapq.check_max_invariant(max_heap)
@@ -170,10 +171,10 @@ class TestHeapq(unittest.TestCase):
             for item in replace_items:
                 heapq.heapreplace_max(max_heap, item)
 
-        self.run_concurrently(
+        run_concurrently(
             worker_func=heapreplace_max_func,
-            args=(max_heap, replace_items),
             nthreads=NTHREADS,
+            args=(max_heap, replace_items),
         )
         self.assertEqual(len(max_heap), OBJECT_COUNT)
         self.test_heapq.check_max_invariant(max_heap)
@@ -203,7 +204,7 @@ class TestHeapq(unittest.TestCase):
                     except IndexError:
                         pass
 
-        self.run_concurrently(worker, (), n_threads * 2)
+        run_concurrently(worker, n_threads * 2)
 
     @staticmethod
     def is_sorted_ascending(lst):
@@ -241,27 +242,6 @@ class TestHeapq(unittest.TestCase):
         """
         return [randint(-a, b) for _ in range(size)]
 
-    def run_concurrently(self, worker_func, args, nthreads):
-        """
-        Run the worker function concurrently in multiple threads.
-        """
-        barrier = Barrier(nthreads)
-
-        def wrapper_func(*args):
-            # Wait for all threads to reach this point before proceeding.
-            barrier.wait()
-            worker_func(*args)
-
-        with threading_helper.catch_threading_exception() as cm:
-            workers = (
-                Thread(target=wrapper_func, args=args) for _ in range(nthreads)
-            )
-            with threading_helper.start_threads(workers):
-                pass
-
-            # Worker threads should not raise any exceptions
-            self.assertIsNone(cm.exc_value)
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-06-12-00-03-34.gh-issue-116738.iBBAdo.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-06-12-00-03-34.gh-issue-116738.iBBAdo.rst
new file mode 100644 (file)
index 0000000..2a1ed29
--- /dev/null
@@ -0,0 +1 @@
+Make functions in :mod:`grp` thread-safe on the :term:`free threaded <free threading>` build.
index 29da9936b65504f38e49f5242eb4ee90dcd9c9b1..652958618a2c4c8b8e38f36ca3572883c4dfe66c 100644 (file)
@@ -55,6 +55,11 @@ get_grp_state(PyObject *module)
 
 static struct PyModuleDef grpmodule;
 
+/* Mutex to protect calls to getgrgid(), getgrnam(), and getgrent().
+ * These functions return pointer to static data structure, which
+ * may be overwritten by any subsequent calls. */
+static PyMutex group_db_mutex = {0};
+
 #define DEFAULT_BUFFER_SIZE 1024
 
 static PyObject *
@@ -168,9 +173,15 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
 
     Py_END_ALLOW_THREADS
 #else
+    PyMutex_Lock(&group_db_mutex);
+    // The getgrgid() function need not be thread-safe.
+    // https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrgid.html
     p = getgrgid(gid);
 #endif
     if (p == NULL) {
+#ifndef HAVE_GETGRGID_R
+        PyMutex_Unlock(&group_db_mutex);
+#endif
         PyMem_RawFree(buf);
         if (nomem == 1) {
             return PyErr_NoMemory();
@@ -185,6 +196,8 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
     retval = mkgrent(module, p);
 #ifdef HAVE_GETGRGID_R
     PyMem_RawFree(buf);
+#else
+    PyMutex_Unlock(&group_db_mutex);
 #endif
     return retval;
 }
@@ -249,9 +262,15 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
 
     Py_END_ALLOW_THREADS
 #else
+    PyMutex_Lock(&group_db_mutex);
+    // The getgrnam() function need not be thread-safe.
+    // https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrnam.html
     p = getgrnam(name_chars);
 #endif
     if (p == NULL) {
+#ifndef HAVE_GETGRNAM_R
+        PyMutex_Unlock(&group_db_mutex);
+#endif
         if (nomem == 1) {
             PyErr_NoMemory();
         }
@@ -261,6 +280,9 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
         goto out;
     }
     retval = mkgrent(module, p);
+#ifndef HAVE_GETGRNAM_R
+    PyMutex_Unlock(&group_db_mutex);
+#endif
 out:
     PyMem_RawFree(buf);
     Py_DECREF(bytes);
@@ -285,8 +307,7 @@ grp_getgrall_impl(PyObject *module)
         return NULL;
     }
 
-    static PyMutex getgrall_mutex = {0};
-    PyMutex_Lock(&getgrall_mutex);
+    PyMutex_Lock(&group_db_mutex);
     setgrent();
 
     struct group *p;
@@ -306,7 +327,7 @@ grp_getgrall_impl(PyObject *module)
 
 done:
     endgrent();
-    PyMutex_Unlock(&getgrall_mutex);
+    PyMutex_Unlock(&group_db_mutex);
     return d;
 }
 
index 15b18f5286b399964728e01537987607b6fe6603..64a9f11a944176d82294c2ea150d8d30be71a5d8 100644 (file)
@@ -167,6 +167,7 @@ Python/sysmodule.c  -       _preinit_xoptions       -
 # XXX need race protection?
 Modules/faulthandler.c faulthandler_dump_traceback     reentrant       -
 Modules/faulthandler.c faulthandler_dump_c_stack       reentrant       -
+Modules/grpmodule.c    -       group_db_mutex  -
 Python/pylifecycle.c   _Py_FatalErrorFormat    reentrant       -
 Python/pylifecycle.c   fatal_error     reentrant       -