]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-109974: Fix more threading lock_tests race conditions (#110089)
authorVictor Stinner <vstinner@python.org>
Fri, 29 Sep 2023 12:21:18 +0000 (14:21 +0200)
committerGitHub <noreply@github.com>
Fri, 29 Sep 2023 12:21:18 +0000 (12:21 +0000)
* Add context manager on Bunch class.
* Bunch now catchs exceptions on executed functions and re-raise them
  at __exit__() as an ExceptionGroup.
* Rewrite BarrierProxy.test_default_timeout(). Use a single thread.
  Only check that barrier.wait() blocks for at least default timeout
  seconds.
* test_with(): inline _with() function.

Lib/test/lock_tests.py
Lib/test/test_importlib/test_locks.py

index cbaae3afd6dde353752b5a2babcbaae123187b44..024c6debcd4a5472444b050ebc166334381d6fd7 100644 (file)
@@ -39,40 +39,54 @@ class Bunch(object):
         self.nthread = nthread
         self.started = []
         self.finished = []
+        self.exceptions = []
         self._can_exit = not wait_before_exit
-        self.wait_thread = threading_helper.wait_threads_exit()
-        self.wait_thread.__enter__()
+        self._wait_thread = None
 
-        def task():
-            tid = threading.get_ident()
-            self.started.append(tid)
-            try:
-                func()
-            finally:
-                self.finished.append(tid)
-                for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-                    if self._can_exit:
-                        break
+    def task(self):
+        tid = threading.get_ident()
+        self.started.append(tid)
+        try:
+            self.func()
+        except BaseException as exc:
+            self.exceptions.append(exc)
+        finally:
+            self.finished.append(tid)
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if self._can_exit:
+                    break
+
+    def __enter__(self):
+        self._wait_thread = threading_helper.wait_threads_exit(support.SHORT_TIMEOUT)
+        self._wait_thread.__enter__()
 
         try:
-            for i in range(nthread):
-                start_new_thread(task, ())
+            for _ in range(self.nthread):
+                start_new_thread(self.task, ())
         except:
             self._can_exit = True
             raise
 
-    def wait_for_started(self):
         for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
             if len(self.started) >= self.nthread:
                 break
 
-    def wait_for_finished(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
         for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
             if len(self.finished) >= self.nthread:
                 break
 
-        # Wait for threads exit
-        self.wait_thread.__exit__(None, None, None)
+        # Wait until threads completely exit according to _thread._count()
+        self._wait_thread.__exit__(None, None, None)
+
+        # Break reference cycle
+        exceptions = self.exceptions
+        self.exceptions = None
+        if exceptions:
+            raise ExceptionGroup(f"{self.func} threads raised exceptions",
+                                 exceptions)
 
     def do_finish(self):
         self._can_exit = True
@@ -143,7 +157,8 @@ class BaseLockTests(BaseTestCase):
         result = []
         def f():
             result.append(lock.acquire(False))
-        Bunch(f, 1).wait_for_finished()
+        with Bunch(f, 1):
+            pass
         self.assertFalse(result[0])
         lock.release()
 
@@ -154,33 +169,45 @@ class BaseLockTests(BaseTestCase):
             lock.acquire()
             lock.release()
 
-        # Threads block on lock.acquire()
         N = 5
-        b = Bunch(f, N)
-        b.wait_for_started()
-        wait_threads_blocked(N)
-        self.assertEqual(len(b.finished), 0)
+        with Bunch(f, N) as bunch:
+            # Threads block on lock.acquire()
+            wait_threads_blocked(N)
+            self.assertEqual(len(bunch.finished), 0)
 
-        # Threads unblocked
-        lock.release()
-        b.wait_for_finished()
-        self.assertEqual(len(b.finished), N)
+            # Threads unblocked
+            lock.release()
+
+        self.assertEqual(len(bunch.finished), N)
 
     def test_with(self):
         lock = self.locktype()
         def f():
             lock.acquire()
             lock.release()
-        def _with(err=None):
+
+        def with_lock(err=None):
             with lock:
                 if err is not None:
                     raise err
-        _with()
-        # Check the lock is unacquired
-        Bunch(f, 1).wait_for_finished()
-        self.assertRaises(TypeError, _with, TypeError)
-        # Check the lock is unacquired
-        Bunch(f, 1).wait_for_finished()
+
+        # Acquire the lock, do nothing, with releases the lock
+        with lock:
+            pass
+
+        # Check that the lock is unacquired
+        with Bunch(f, 1):
+            pass
+
+        # Acquire the lock, raise an exception, with releases the lock
+        with self.assertRaises(TypeError):
+            with lock:
+                raise TypeError
+
+        # Check that the lock is unacquired even if after an exception
+        # was raised in the previous "with lock:" block
+        with Bunch(f, 1):
+            pass
 
     def test_thread_leak(self):
         # The lock shouldn't leak a Thread instance when used from a foreign
@@ -192,7 +219,8 @@ class BaseLockTests(BaseTestCase):
 
         # We run many threads in the hope that existing threads ids won't
         # be recycled.
-        Bunch(f, 15).wait_for_finished()
+        with Bunch(f, 15):
+            pass
 
     def test_timeout(self):
         lock = self.locktype()
@@ -216,7 +244,8 @@ class BaseLockTests(BaseTestCase):
             results.append(lock.acquire(timeout=0.5))
             t2 = time.monotonic()
             results.append(t2 - t1)
-        Bunch(f, 1).wait_for_finished()
+        with Bunch(f, 1):
+            pass
         self.assertFalse(results[0])
         self.assertTimeout(results[1], 0.5)
 
@@ -264,8 +293,8 @@ class LockTests(BaseLockTests):
         lock.acquire()
         def f():
             lock.release()
-        b = Bunch(f, 1)
-        b.wait_for_finished()
+        with Bunch(f, 1):
+            pass
         lock.acquire()
         lock.release()
 
@@ -376,12 +405,12 @@ class RLockTests(BaseLockTests):
         lock = self.locktype()
         def f():
             lock.acquire()
-        b = Bunch(f, 1, True)
-        try:
-            self.assertRaises(RuntimeError, lock.release)
-        finally:
-            b.do_finish()
-        b.wait_for_finished()
+
+        with Bunch(f, 1, True) as bunch:
+            try:
+                self.assertRaises(RuntimeError, lock.release)
+            finally:
+                bunch.do_finish()
 
     def test__is_owned(self):
         lock = self.locktype()
@@ -393,7 +422,8 @@ class RLockTests(BaseLockTests):
         result = []
         def f():
             result.append(lock._is_owned())
-        Bunch(f, 1).wait_for_finished()
+        with Bunch(f, 1):
+            pass
         self.assertFalse(result[0])
         lock.release()
         self.assertTrue(lock._is_owned())
@@ -427,15 +457,14 @@ class EventTests(BaseTestCase):
             results1.append(evt.wait())
             results2.append(evt.wait())
 
-        # Threads blocked on first evt.wait()
-        b = Bunch(f, N)
-        b.wait_for_started()
-        wait_threads_blocked(N)
-        self.assertEqual(len(results1), 0)
+        with Bunch(f, N):
+            # Threads blocked on first evt.wait()
+            wait_threads_blocked(N)
+            self.assertEqual(len(results1), 0)
+
+            # Threads unblocked
+            evt.set()
 
-        # Threads unblocked
-        evt.set()
-        b.wait_for_finished()
         self.assertEqual(results1, [True] * N)
         self.assertEqual(results2, [True] * N)
 
@@ -458,16 +487,22 @@ class EventTests(BaseTestCase):
             r = evt.wait(0.5)
             t2 = time.monotonic()
             results2.append((r, t2 - t1))
-        Bunch(f, N).wait_for_finished()
+
+        with Bunch(f, N):
+            pass
+
         self.assertEqual(results1, [False] * N)
         for r, dt in results2:
             self.assertFalse(r)
             self.assertTimeout(dt, 0.5)
+
         # The event is set
         results1 = []
         results2 = []
         evt.set()
-        Bunch(f, N).wait_for_finished()
+        with Bunch(f, N):
+            pass
+
         self.assertEqual(results1, [True] * N)
         for r, dt in results2:
             self.assertTrue(r)
@@ -480,16 +515,15 @@ class EventTests(BaseTestCase):
         def f():
             results.append(event.wait(support.LONG_TIMEOUT))
 
-        # Threads blocked on event.wait()
         N = 5
-        b = Bunch(f, N)
-        b.wait_for_started()
-        wait_threads_blocked(N)
-
-        # Threads unblocked
-        event.set()
-        event.clear()
-        b.wait_for_finished()
+        with Bunch(f, N):
+            # Threads blocked on event.wait()
+            wait_threads_blocked(N)
+
+            # Threads unblocked
+            event.set()
+            event.clear()
+
         self.assertEqual(results, [True] * N)
 
     @requires_fork
@@ -573,73 +607,71 @@ class ConditionTests(BaseTestCase):
             results2.append((result, phase_num))
 
         N = 5
-        b = Bunch(f, N)
-        b.wait_for_started()
-        # first wait, to ensure all workers settle into cond.wait() before
-        # we continue. See issues #8799 and #30727.
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if len(ready) >= N:
-                break
+        with Bunch(f, N):
+            # first wait, to ensure all workers settle into cond.wait() before
+            # we continue. See issues #8799 and #30727.
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if len(ready) >= N:
+                    break
 
-        ready.clear()
-        self.assertEqual(results1, [])
+            ready.clear()
+            self.assertEqual(results1, [])
 
-        # Notify 3 threads at first
-        count1 = 3
-        cond.acquire()
-        cond.notify(count1)
-        wait_threads_blocked(count1)
+            # Notify 3 threads at first
+            count1 = 3
+            cond.acquire()
+            cond.notify(count1)
+            wait_threads_blocked(count1)
 
-        # Phase 1
-        phase_num = 1
-        cond.release()
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if len(results1) >= count1:
-                break
+            # Phase 1
+            phase_num = 1
+            cond.release()
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if len(results1) >= count1:
+                    break
 
-        self.assertEqual(results1, [(True, 1)] * count1)
-        self.assertEqual(results2, [])
+            self.assertEqual(results1, [(True, 1)] * count1)
+            self.assertEqual(results2, [])
 
-        # Wait until awaken workers are blocked on cond.wait()
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if len(ready) >= count1 :
-                break
+            # Wait until awaken workers are blocked on cond.wait()
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if len(ready) >= count1 :
+                    break
 
-        # Notify 5 threads: they might be in their first or second wait
-        cond.acquire()
-        cond.notify(5)
-        wait_threads_blocked(N)
+            # Notify 5 threads: they might be in their first or second wait
+            cond.acquire()
+            cond.notify(5)
+            wait_threads_blocked(N)
 
-        # Phase 2
-        phase_num = 2
-        cond.release()
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if len(results1) + len(results2) >= (N + count1):
-                break
+            # Phase 2
+            phase_num = 2
+            cond.release()
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if len(results1) + len(results2) >= (N + count1):
+                    break
 
-        count2 = N - count1
-        self.assertEqual(results1, [(True, 1)] * count1 + [(True, 2)] * count2)
-        self.assertEqual(results2, [(True, 2)] * count1)
+            count2 = N - count1
+            self.assertEqual(results1, [(True, 1)] * count1 + [(True, 2)] * count2)
+            self.assertEqual(results2, [(True, 2)] * count1)
 
-        # Make sure all workers settle into cond.wait()
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if len(ready) >= N:
-                break
+            # Make sure all workers settle into cond.wait()
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if len(ready) >= N:
+                    break
 
-        # Notify all threads: they are all in their second wait
-        cond.acquire()
-        cond.notify_all()
-        wait_threads_blocked(N)
+            # Notify all threads: they are all in their second wait
+            cond.acquire()
+            cond.notify_all()
+            wait_threads_blocked(N)
 
-        # Phase 3
-        phase_num = 3
-        cond.release()
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if len(results2) >= N:
-                break
-        self.assertEqual(results1, [(True, 1)] * count1 + [(True, 2)] * count2)
-        self.assertEqual(results2, [(True, 2)] * count1 + [(True, 3)] * count2)
-        b.wait_for_finished()
+            # Phase 3
+            phase_num = 3
+            cond.release()
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if len(results2) >= N:
+                    break
+            self.assertEqual(results1, [(True, 1)] * count1 + [(True, 2)] * count2)
+            self.assertEqual(results2, [(True, 2)] * count1 + [(True, 3)] * count2)
 
     def test_notify(self):
         cond = self.condtype()
@@ -660,7 +692,8 @@ class ConditionTests(BaseTestCase):
             results.append((t2 - t1, result))
 
         N = 5
-        Bunch(f, N).wait_for_finished()
+        with Bunch(f, N):
+            pass
         self.assertEqual(len(results), N)
 
         for dt, result in results:
@@ -680,14 +713,13 @@ class ConditionTests(BaseTestCase):
                 result = cond.wait_for(lambda: state == 4)
                 self.assertTrue(result)
                 self.assertEqual(state, 4)
-        b = Bunch(f, 1)
-        b.wait_for_started()
-        for i in range(4):
-            time.sleep(0.010)
-            with cond:
-                state += 1
-                cond.notify()
-        b.wait_for_finished()
+
+        with Bunch(f, 1):
+            for i in range(4):
+                time.sleep(0.010)
+                with cond:
+                    state += 1
+                    cond.notify()
 
     def test_waitfor_timeout(self):
         cond = self.condtype()
@@ -702,16 +734,14 @@ class ConditionTests(BaseTestCase):
                 self.assertTimeout(dt, 0.1)
                 success.append(None)
 
-        b = Bunch(f, 1)
-        b.wait_for_started()
-        # Only increment 3 times, so state == 4 is never reached.
-        for i in range(3):
-            time.sleep(0.010)
-            with cond:
-                state += 1
-                cond.notify()
+        with Bunch(f, 1):
+            # Only increment 3 times, so state == 4 is never reached.
+            for i in range(3):
+                time.sleep(0.010)
+                with cond:
+                    state += 1
+                    cond.notify()
 
-        b.wait_for_finished()
         self.assertEqual(len(success), 1)
 
 
@@ -761,38 +791,37 @@ class BaseSemaphoreTests(BaseTestCase):
                 if len(results1) + len(results2) >= count:
                     break
 
-        # Phase 0
         N = 10
-        b = Bunch(func, N)
-        b.wait_for_started()
-        count1 = sem_value - 1
-        wait_count(count1)
-        self.assertEqual(results1 + results2, [0] * count1)
-
-        # Phase 1
-        phase_num = 1
-        for i in range(sem_value):
-            sem.release()
-        count2 = sem_value
-        wait_count(count1 + count2)
-        self.assertEqual(sorted(results1 + results2),
-                         [0] * count1 + [1] * count2)
-
-        # Phase 2
-        phase_num = 2
-        count3 = (sem_value - 1)
-        for i in range(count3):
+        with Bunch(func, N):
+            # Phase 0
+            count1 = sem_value - 1
+            wait_count(count1)
+            self.assertEqual(results1 + results2, [0] * count1)
+
+            # Phase 1
+            phase_num = 1
+            for i in range(sem_value):
+                sem.release()
+            count2 = sem_value
+            wait_count(count1 + count2)
+            self.assertEqual(sorted(results1 + results2),
+                             [0] * count1 + [1] * count2)
+
+            # Phase 2
+            phase_num = 2
+            count3 = (sem_value - 1)
+            for i in range(count3):
+                sem.release()
+            wait_count(count1 + count2 + count3)
+            self.assertEqual(sorted(results1 + results2),
+                             [0] * count1 + [1] * count2 + [2] * count3)
+            # The semaphore is still locked
+            self.assertFalse(sem.acquire(False))
+
+            # Final release, to let the last thread finish
+            count4 = 1
             sem.release()
-        wait_count(count1 + count2 + count3)
-        self.assertEqual(sorted(results1 + results2),
-                         [0] * count1 + [1] * count2 + [2] * count3)
-        # The semaphore is still locked
-        self.assertFalse(sem.acquire(False))
 
-        # Final release, to let the last thread finish
-        count4 = 1
-        sem.release()
-        b.wait_for_finished()
         self.assertEqual(sem_results,
                          [True] * (count1 + count2 + count3 + count4))
 
@@ -816,34 +845,32 @@ class BaseSemaphoreTests(BaseTestCase):
                 if len(results1) + len(results2) >= count:
                     break
 
-        # Phase 0
-        b = Bunch(func, 10)
-        b.wait_for_started()
-        count1 = sem_value - 1
-        wait_count(count1)
-        self.assertEqual(results1 + results2, [0] * count1)
-
-        # Phase 1
-        phase_num = 1
-        count2 = sem_value
-        sem.release(count2)
-        wait_count(count1 + count2)
-        self.assertEqual(sorted(results1 + results2),
-                         [0] * count1 + [1] * count2)
-
-        # Phase 2
-        phase_num = 2
-        count3 = sem_value - 1
-        sem.release(count3)
-        wait_count(count1 + count2 + count3)
-        self.assertEqual(sorted(results1 + results2),
-                         [0] * count1 + [1] * count2 + [2] * count3)
-        # The semaphore is still locked
-        self.assertFalse(sem.acquire(False))
-
-        # Final release, to let the last thread finish
-        sem.release()
-        b.wait_for_finished()
+        with Bunch(func, 10):
+            # Phase 0
+            count1 = sem_value - 1
+            wait_count(count1)
+            self.assertEqual(results1 + results2, [0] * count1)
+
+            # Phase 1
+            phase_num = 1
+            count2 = sem_value
+            sem.release(count2)
+            wait_count(count1 + count2)
+            self.assertEqual(sorted(results1 + results2),
+                             [0] * count1 + [1] * count2)
+
+            # Phase 2
+            phase_num = 2
+            count3 = sem_value - 1
+            sem.release(count3)
+            wait_count(count1 + count2 + count3)
+            self.assertEqual(sorted(results1 + results2),
+                             [0] * count1 + [1] * count2 + [2] * count3)
+            # The semaphore is still locked
+            self.assertFalse(sem.acquire(False))
+
+            # Final release, to let the last thread finish
+            sem.release()
 
     def test_try_acquire(self):
         sem = self.semtype(2)
@@ -860,7 +887,8 @@ class BaseSemaphoreTests(BaseTestCase):
         def f():
             results.append(sem.acquire(False))
             results.append(sem.acquire(False))
-        Bunch(f, 5).wait_for_finished()
+        with Bunch(f, 5):
+            pass
         # There can be a thread switch between acquiring the semaphore and
         # appending the result, therefore results will not necessarily be
         # ordered.
@@ -887,15 +915,13 @@ class BaseSemaphoreTests(BaseTestCase):
             sem.acquire()
             sem.release()
 
-        # Thread blocked on sem.acquire()
-        b = Bunch(f, 1)
-        b.wait_for_started()
-        wait_threads_blocked(1)
-        self.assertFalse(b.finished)
+        with Bunch(f, 1) as bunch:
+            # Thread blocked on sem.acquire()
+            wait_threads_blocked(1)
+            self.assertFalse(bunch.finished)
 
-        # Thread unblocked
-        sem.release()
-        b.wait_for_finished()
+            # Thread unblocked
+            sem.release()
 
     def test_with(self):
         sem = self.semtype(2)
@@ -971,9 +997,8 @@ class BarrierTests(BaseTestCase):
         self.barrier.abort()
 
     def run_threads(self, f):
-        b = Bunch(f, self.N-1)
-        f()
-        b.wait_for_finished()
+        with Bunch(f, self.N):
+            pass
 
     def multipass(self, results, n):
         m = self.barrier.parties
@@ -1126,27 +1151,27 @@ class BarrierTests(BaseTestCase):
             i = self.barrier.wait()
             if i == self.N // 2:
                 # One thread is late!
-                time.sleep(1.0)
+                time.sleep(self.defaultTimeout / 2)
             # Default timeout is 2.0, so this is shorter.
             self.assertRaises(threading.BrokenBarrierError,
-                              self.barrier.wait, 0.5)
+                              self.barrier.wait, self.defaultTimeout / 4)
         self.run_threads(f)
 
     def test_default_timeout(self):
         """
         Test the barrier's default timeout
         """
-        # gh-109401: Barrier timeout should be long enough
-        # to create 4 threads on a slow CI.
-        timeout = 1.0
-        barrier = self.barriertype(self.N, timeout=timeout)
+        timeout = 0.100
+        barrier = self.barriertype(2, timeout=timeout)
         def f():
-            i = barrier.wait()
-            if i == self.N // 2:
-                # One thread is later than the default timeout.
-                time.sleep(timeout * 2)
-            self.assertRaises(threading.BrokenBarrierError, barrier.wait)
-        self.run_threads(f)
+            self.assertRaises(threading.BrokenBarrierError,
+                              barrier.wait)
+
+        start_time = time.monotonic()
+        with Bunch(f, 1):
+            pass
+        dt = time.monotonic() - start_time
+        self.assertGreaterEqual(dt, timeout)
 
     def test_single_thread(self):
         b = self.barriertype(1)
@@ -1160,19 +1185,18 @@ class BarrierTests(BaseTestCase):
         def f():
             barrier.wait(timeout)
 
-        # Threads blocked on barrier.wait()
         N = 2
-        bunch = Bunch(f, N)
-        bunch.wait_for_started()
-        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
-            if barrier.n_waiting >= N:
-                break
-        self.assertRegex(repr(barrier),
-                         r"<\w+\.Barrier at .*: waiters=2/3>")
+        with Bunch(f, N):
+            # Threads blocked on barrier.wait()
+            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
+                if barrier.n_waiting >= N:
+                    break
+            self.assertRegex(repr(barrier),
+                             r"<\w+\.Barrier at .*: waiters=2/3>")
+
+            # Threads unblocked
+            barrier.wait(timeout)
 
-        # Threads unblocked
-        barrier.wait(timeout)
-        bunch.wait_for_finished()
         self.assertRegex(repr(barrier),
                          r"<\w+\.Barrier at .*: waiters=0/3>")
 
index 7091c36aaaf7613f86a55862b7309ddaad5c6d33..befac5d62b0abf1ef607b880f5ab1373c8339acc 100644 (file)
@@ -93,7 +93,8 @@ class DeadlockAvoidanceTests:
                 b.release()
             if ra:
                 a.release()
-        lock_tests.Bunch(f, NTHREADS).wait_for_finished()
+        with lock_tests.Bunch(f, NTHREADS):
+            pass
         self.assertEqual(len(results), NTHREADS)
         return results