raise RuntimeError("Failed to get all awaited_by after retries")
+def _get_stack_trace_with_retry(unwinder, timeout=SHORT_TIMEOUT, condition=None):
+    """Get a stack trace from an existing unwinder, retrying transient errors.
+
+    This handles the case where we want to reuse an existing RemoteUnwinder
+    instance but still handle transient failures like "Failed to parse initial
+    frame in chain" that can occur when sampling at an inopportune moment.
+
+    Args:
+        unwinder: Object whose get_stack_trace() method is sampled.
+        timeout: Value passed to busy_retry() to bound the retry loop.
+        condition: Optional callable(traces) -> bool.  When provided,
+            sampling keeps retrying until condition(traces) is true.
+
+    Returns:
+        The first sample that was obtained without a transient error and
+        that satisfies condition (any such sample when condition is None).
+
+    Raises:
+        RuntimeError: If the retry loop ends without an accepted sample.
+            When a transient error was seen, the most recent one is quoted
+            in the message and chained as the exception's __cause__.
+    """
+    last_error = None
+    # NOTE(review): presumably busy_retry is test.support.busy_retry; if it
+    # raises on timeout by default, the raises after this loop are never
+    # reached -- confirm, and consider passing error=False if they should be.
+    for _ in busy_retry(timeout):
+        try:
+            traces = unwinder.get_stack_trace()
+            if condition is None or condition(traces):
+                return traces
+            # Condition not met yet, keep retrying
+        except TRANSIENT_ERRORS as e:
+            # Remember the failure so it can be reported if we run out of time.
+            last_error = e
+    if last_error is not None:
+        # Chain the original exception so its traceback is not lost.
+        raise RuntimeError(
+            f"Failed to get stack trace after retries: {last_error}"
+        ) from last_error
+    raise RuntimeError("Condition never satisfied within timeout")
+
+
# ============================================================================
# Base test class with shared infrastructure
# ============================================================================
# Get stack trace with all threads
unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
- for _ in range(MAX_TRIES):
- all_traces = unwinder_all.get_stack_trace()
- found = self._find_frame_in_trace(
- all_traces,
- lambda f: f.funcname == "main_work"
- and f.location.lineno > 12,
- )
- if found:
- break
- time.sleep(RETRY_DELAY)
+ for _ in busy_retry(SHORT_TIMEOUT):
+ with contextlib.suppress(*TRANSIENT_ERRORS):
+ all_traces = unwinder_all.get_stack_trace()
+ found = self._find_frame_in_trace(
+ all_traces,
+ lambda f: f.funcname == "main_work"
+ and f.location.lineno > 12,
+ )
+ if found:
+ break
else:
self.fail(
"Main thread did not start its busy work on time"
unwinder_gil = RemoteUnwinder(
p.pid, only_active_thread=True
)
- gil_traces = unwinder_gil.get_stack_trace()
+ gil_traces = _get_stack_trace_with_retry(unwinder_gil)
# Count threads
total_threads = sum(
mode=mode,
skip_non_matching_threads=False,
)
- for _ in range(MAX_TRIES):
- traces = unwinder.get_stack_trace()
- statuses = self._get_thread_statuses(traces)
+ for _ in busy_retry(SHORT_TIMEOUT):
+ with contextlib.suppress(*TRANSIENT_ERRORS):
+ traces = unwinder.get_stack_trace()
+ statuses = self._get_thread_statuses(traces)
- if check_condition(
- statuses, sleeper_tid, busy_tid
- ):
- break
- time.sleep(0.5)
+ if check_condition(
+ statuses, sleeper_tid, busy_tid
+ ):
+ break
return statuses, sleeper_tid, busy_tid
finally:
mode=PROFILING_MODE_ALL,
skip_non_matching_threads=False,
)
- for _ in range(MAX_TRIES):
- traces = unwinder.get_stack_trace()
- statuses = self._get_thread_statuses(traces)
-
- # Check ALL mode provides both GIL and CPU info
- if (
- sleeper_tid in statuses
- and busy_tid in statuses
- and not (
- statuses[sleeper_tid]
- & THREAD_STATUS_ON_CPU
- )
- and not (
- statuses[sleeper_tid]
- & THREAD_STATUS_HAS_GIL
- )
- and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
- and (
- statuses[busy_tid] & THREAD_STATUS_HAS_GIL
- )
- ):
- break
- time.sleep(0.5)
+ for _ in busy_retry(SHORT_TIMEOUT):
+ with contextlib.suppress(*TRANSIENT_ERRORS):
+ traces = unwinder.get_stack_trace()
+ statuses = self._get_thread_statuses(traces)
+
+ # Check ALL mode provides both GIL and CPU info
+ if (
+ sleeper_tid in statuses
+ and busy_tid in statuses
+ and not (
+ statuses[sleeper_tid]
+ & THREAD_STATUS_ON_CPU
+ )
+ and not (
+ statuses[sleeper_tid]
+ & THREAD_STATUS_HAS_GIL
+ )
+ and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
+ and (
+ statuses[busy_tid] & THREAD_STATUS_HAS_GIL
+ )
+ ):
+ break
self.assertIsNotNone(
sleeper_tid, "Sleeper thread id not received"
mode=PROFILING_MODE_ALL,
skip_non_matching_threads=False,
)
- for _ in range(MAX_TRIES):
- traces = unwinder.get_stack_trace()
- statuses = self._get_thread_statuses(traces)
-
- if (
- exception_tid in statuses
- and normal_tid in statuses
- and (statuses[exception_tid] & THREAD_STATUS_HAS_EXCEPTION)
- and not (statuses[normal_tid] & THREAD_STATUS_HAS_EXCEPTION)
- ):
- break
- time.sleep(0.5)
+ for _ in busy_retry(SHORT_TIMEOUT):
+ with contextlib.suppress(*TRANSIENT_ERRORS):
+ traces = unwinder.get_stack_trace()
+ statuses = self._get_thread_statuses(traces)
+
+ if (
+ exception_tid in statuses
+ and normal_tid in statuses
+ and (statuses[exception_tid] & THREAD_STATUS_HAS_EXCEPTION)
+ and not (statuses[normal_tid] & THREAD_STATUS_HAS_EXCEPTION)
+ ):
+ break
self.assertIn(exception_tid, statuses)
self.assertIn(normal_tid, statuses)
mode=PROFILING_MODE_EXCEPTION,
skip_non_matching_threads=True,
)
- for _ in range(MAX_TRIES):
- traces = unwinder.get_stack_trace()
- statuses = self._get_thread_statuses(traces)
-
- if exception_tid in statuses:
- self.assertNotIn(
- normal_tid,
- statuses,
- "Normal thread should be filtered out in exception mode",
- )
- return
- time.sleep(0.5)
+ for _ in busy_retry(SHORT_TIMEOUT):
+ with contextlib.suppress(*TRANSIENT_ERRORS):
+ traces = unwinder.get_stack_trace()
+ statuses = self._get_thread_statuses(traces)
+
+ if exception_tid in statuses:
+ self.assertNotIn(
+ normal_tid,
+ statuses,
+ "Normal thread should be filtered out in exception mode",
+ )
+ return
self.fail("Never found exception thread in exception mode")
finally:
_cleanup_sockets(client_socket, server_socket)
- def _check_exception_status(self, p, thread_tid, expect_exception):
- """Helper to check if thread has expected exception status."""
+ def _check_thread_status(
+ self, p, thread_tid, condition, condition_name="condition"
+ ):
+ """Helper to check thread status with a custom condition.
+
+ This waits until we see 3 consecutive samples where the condition
+ returns True, which confirms the thread has reached and is stable
+ in the expected state. A sample where the condition returns False
+ resets the count (the thread may not have reached the expected
+ state yet).
+
+ Args:
+ p: Process object with pid attribute
+ thread_tid: Thread ID to check
+ condition: Callable(statuses, thread_tid) -> bool that returns
+ True when the thread is in the expected state
+ condition_name: Description of condition for error messages
+
+ Returns None once the third consecutive matching sample is seen.
+ If the retry loop ends first, self.fail() is called with a message
+ that includes condition_name.
+ """
unwinder = RemoteUnwinder(
p.pid,
all_threads=True,
skip_non_matching_threads=False,
)
- # Collect multiple samples for reliability
- results = []
- for _ in range(MAX_TRIES):
- try:
+ # Wait for 3 consecutive samples matching expected state
+ matching_samples = 0
+ # NOTE(review): presumably busy_retry is test.support.busy_retry; if it
+ # raises on timeout by default, the self.fail() after the loop is never
+ # reached -- confirm.
+ for _ in busy_retry(SHORT_TIMEOUT):
+ with contextlib.suppress(*TRANSIENT_ERRORS):
traces = unwinder.get_stack_trace()
- except TRANSIENT_ERRORS:
- time.sleep(RETRY_DELAY)
- continue
- statuses = self._get_thread_statuses(traces)
-
- if thread_tid in statuses:
- has_exc = bool(statuses[thread_tid] & THREAD_STATUS_HAS_EXCEPTION)
- results.append(has_exc)
+ statuses = self._get_thread_statuses(traces)
- if len(results) >= 3:
- break
+ if thread_tid in statuses:
+ if condition(statuses, thread_tid):
+ matching_samples += 1
+ if matching_samples >= 3:
+ return # Success - confirmed stable in expected state
+ else:
+ # Thread not yet in expected state, reset counter
+ matching_samples = 0
- time.sleep(RETRY_DELAY)
+ self.fail(
+ f"Thread did not stabilize in expected state "
+ f"({condition_name}) within timeout"
+ )
- # Check majority of samples match expected
- if not results:
- self.fail("Never found target thread in stack traces")
+ def _check_exception_status(self, p, thread_tid, expect_exception):
+ """Helper to check if thread has expected exception status.
+
+ Builds a condition that compares the thread's
+ THREAD_STATUS_HAS_EXCEPTION bit against expect_exception and hands
+ it to _check_thread_status(), which does the sampling.
+
+ Args:
+ p: Process object, forwarded to _check_thread_status()
+ thread_tid: Thread ID whose status is examined
+ expect_exception: Expected truth value of the HAS_EXCEPTION flag
+ """
+ # The condition is only called with a tid that is a key of statuses
+ # (see the membership check in _check_thread_status).
+ def condition(statuses, tid):
+ has_exc = bool(statuses[tid] & THREAD_STATUS_HAS_EXCEPTION)
+ return has_exc == expect_exception
- majority = sum(results) > len(results) // 2
- if expect_exception:
- self.assertTrue(
- majority,
- f"Thread should have HAS_EXCEPTION flag, got {results}"
- )
- else:
- self.assertFalse(
- majority,
- f"Thread should NOT have HAS_EXCEPTION flag, got {results}"
- )
+ self._check_thread_status(
+ p, thread_tid, condition,
+ condition_name=f"expect_exception={expect_exception}"
+ )
@unittest.skipIf(
sys.platform not in ("linux", "darwin", "win32"),
_wait_for_signal(client_socket, b"ready")
# Take a sample
- unwinder.get_stack_trace()
+ _get_stack_trace_with_retry(unwinder)
stats = unwinder.get_stats()
client_socket.sendall(b"done")