import socket
import threading
import time
+from contextlib import contextmanager
from asyncio import staggered, taskgroups, base_events, tasks
from unittest.mock import ANY
from test.support import (
PROFILING_MODE_ALL = 3
# Thread status flags
-THREAD_STATUS_HAS_GIL = (1 << 0)
-THREAD_STATUS_ON_CPU = (1 << 1)
-THREAD_STATUS_UNKNOWN = (1 << 2)
+THREAD_STATUS_HAS_GIL = 1 << 0
+THREAD_STATUS_ON_CPU = 1 << 1
+THREAD_STATUS_UNKNOWN = 1 << 2
+
+# Maximum number of retry attempts for operations that may fail transiently
+MAX_TRIES = 10
try:
from concurrent import interpreters
)
+# ============================================================================
+# Module-level helper functions
+# ============================================================================
+
+
def _make_test_script(script_dir, script_basename, source):
    """Write *source* to a script file and return its path.

    Invalidates the import caches afterwards so the freshly written file
    is visible to any subsequent import-machinery lookups.
    """
    script_path = make_script(script_dir, script_basename, source)
    importlib.invalidate_caches()
    return script_path
def _create_server_socket(port, backlog=1):
    """Return a listening TCP socket bound to localhost:*port* for tests."""
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # SO_REUSEADDR lets successive test runs rebind the same port quickly.
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(("localhost", port))
    listener.settimeout(SHORT_TIMEOUT)
    listener.listen(backlog)
    return listener
+
+
def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
    """
    Wait until every expected signal has been seen on *sock*.

    Args:
        sock: Connected socket to read from
        expected_signals: Single bytes object or list of bytes objects to wait for
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed before signal received or timeout
    """
    # Normalize a single pattern into a list so the loop below is uniform.
    if isinstance(expected_signals, bytes):
        expected_signals = [expected_signals]

    sock.settimeout(timeout)
    received = b""

    # Keep reading until every pattern appears somewhere in the buffer.
    while not all(sig in received for sig in expected_signals):
        try:
            data = sock.recv(4096)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for signals. "
                f"Expected: {expected_signals}, Got: {received[-200:]!r}"
            )
        if not data:
            # EOF - connection closed
            raise RuntimeError(
                f"Connection closed before receiving expected signals. "
                f"Expected: {expected_signals}, Got: {received[-200:]!r}"
            )
        received += data

    return received
+
+
def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT):
    """
    Wait for N occurrences of *signal_pattern* on *sock*.

    Args:
        sock: Connected socket to read from
        signal_pattern: bytes pattern to count (e.g. b"ready")
        count: Number of occurrences expected
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed or timeout before receiving all signals
    """
    sock.settimeout(timeout)
    received = b""
    seen = 0

    while seen < count:
        try:
            data = sock.recv(4096)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for {count} signals (found {seen}). "
                f"Last 200 bytes: {received[-200:]!r}"
            )
        if not data:
            raise RuntimeError(
                f"Connection closed after {seen}/{count} signals. "
                f"Last 200 bytes: {received[-200:]!r}"
            )
        received += data
        # Recount over the whole buffer so patterns split across recv()
        # boundaries are still found.
        seen = received.count(signal_pattern)

    return received
+
+
@contextmanager
def _managed_subprocess(args, timeout=SHORT_TIMEOUT):
    """
    Context manager that guarantees subprocess teardown.

    Tries a graceful terminate() first; escalates to kill() if the process
    does not exit within *timeout*. All cleanup errors are swallowed.
    """
    proc = subprocess.Popen(args)
    try:
        yield proc
    finally:
        def _wait_quietly():
            # Returns True once the process has exited within *timeout*.
            try:
                proc.wait(timeout=timeout)
                return True
            except subprocess.TimeoutExpired:
                return False

        try:
            proc.terminate()
            if not _wait_quietly():
                proc.kill()
                _wait_quietly()  # Process refuses to die, nothing more we can do
        except OSError:
            pass  # Process already dead
+
+
+def _cleanup_sockets(*sockets):
+ """Safely close multiple sockets, ignoring errors."""
+ for sock in sockets:
+ if sock is not None:
+ try:
+ sock.close()
+ except OSError:
+ pass
+
+
+# ============================================================================
+# Decorators and skip conditions
+# ============================================================================
+
skip_if_not_supported = unittest.skipIf(
(
sys.platform != "darwin"
def requires_subinterpreters(meth):
    """Decorator to skip a test if subinterpreters are not supported."""
    # `interpreters` is None when the concurrent.interpreters import failed.
    decorator = unittest.skipIf(interpreters is None, "subinterpreters required")
    return decorator(meth)
+
+
+# ============================================================================
+# Simple wrapper functions for RemoteUnwinder
+# ============================================================================
+
+# Errors that can occur transiently when reading process memory without synchronization
+RETRIABLE_ERRORS = (
+ "Task list appears corrupted",
+ "Invalid linked list structure reading remote memory",
+ "Unknown error reading memory",
+ "Unhandled frame owner",
+ "Failed to parse initial frame",
+ "Failed to process frame chain",
+ "Failed to unwind stack",
+)
+
+
+def _is_retriable_error(exc):
+ """Check if an exception is a transient error that should be retried."""
+ msg = str(exc)
+ return any(msg.startswith(err) or err in msg for err in RETRIABLE_ERRORS)
def get_stack_trace(pid):
    """Sample all-thread stack traces from *pid*, retrying transient failures.

    Retriable errors (see RETRIABLE_ERRORS) are retried until SHORT_TIMEOUT
    elapses; any other RuntimeError propagates immediately.
    """
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
            return unwinder.get_stack_trace()
        except RuntimeError as exc:
            if not _is_retriable_error(exc):
                raise
    raise RuntimeError("Failed to get stack trace after retries")
def get_async_stack_trace(pid):
    """Fetch the async stack trace of *pid*, retrying transient read errors."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            return RemoteUnwinder(pid, debug=True).get_async_stack_trace()
        except RuntimeError as exc:
            # Only retry errors known to be caused by unsynchronized reads.
            if not _is_retriable_error(exc):
                raise
    raise RuntimeError("Failed to get async stack trace after retries")
def get_all_awaited_by(pid):
    """Fetch the awaited_by task graph of *pid*, retrying transient errors."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            return RemoteUnwinder(pid, debug=True).get_all_awaited_by()
        except RuntimeError as exc:
            # Only retry errors known to be caused by unsynchronized reads.
            if not _is_retriable_error(exc):
                raise
    raise RuntimeError("Failed to get all awaited_by after retries")
+
+
+# ============================================================================
+# Base test class with shared infrastructure
+# ============================================================================
-class TestGetStackTrace(unittest.TestCase):
class RemoteInspectionTestBase(unittest.TestCase):
    """Shared infrastructure for tests that inspect a remote process."""

    maxDiff = None

    def _run_script_and_get_trace(
        self,
        script,
        trace_func,
        wait_for_signals=None,
        port=None,
        backlog=1,
    ):
        """
        Common pattern: run a script, wait for signals, get trace.

        Args:
            script: Script content (formatted with port if {port} present)
            trace_func: Function called with the child pid to get a trace
            wait_for_signals: Signal(s) to wait for before getting trace
            port: Port to use (auto-selected if None)
            backlog: Socket listen backlog

        Returns:
            tuple: (trace_result, script_name)
        """
        if port is None:
            port = find_unused_port()

        # Substitute the port placeholder if the script uses one.
        if "{port}" in script or "{{port}}" in script:
            script = script.replace("{{port}}", "{port}").format(port=port)

        with os_helper.temp_dir() as tmp_dir:
            pkg_dir = os.path.join(tmp_dir, "script_pkg")
            os.mkdir(pkg_dir)

            listener = _create_server_socket(port, backlog)
            script_name = _make_test_script(pkg_dir, "script", script)
            conn = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as proc:
                    conn, _ = listener.accept()
                    # The listening socket is no longer needed once the
                    # child has connected.
                    listener.close()
                    listener = None

                    if wait_for_signals:
                        _wait_for_signal(conn, wait_for_signals)

                    try:
                        trace = trace_func(proc.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )
                    return trace, script_name
            finally:
                _cleanup_sockets(conn, listener)

    def _find_frame_in_trace(self, stack_trace, predicate):
        """
        Return the first frame matching *predicate*, or None.

        Args:
            stack_trace: List of InterpreterInfo objects
            predicate: Function(frame) -> bool
        """
        all_frames = (
            frame
            for interp in stack_trace
            for thread in interp.threads
            for frame in thread.frame_info
        )
        return next((f for f in all_frames if predicate(f)), None)

    def _find_thread_by_id(self, stack_trace, thread_id):
        """Return the thread with the given native thread ID, or None."""
        all_threads = (
            thread for interp in stack_trace for thread in interp.threads
        )
        return next(
            (t for t in all_threads if t.thread_id == thread_id), None
        )

    def _find_thread_with_frame(self, stack_trace, frame_predicate):
        """Return the first thread containing a matching frame, or None."""
        all_threads = (
            thread for interp in stack_trace for thread in interp.threads
        )
        return next(
            (
                t
                for t in all_threads
                if any(frame_predicate(f) for f in t.frame_info)
            ),
            None,
        )

    def _get_thread_statuses(self, stack_trace):
        """Return a thread_id -> status mapping for every thread seen."""
        return {
            thread.thread_id: thread.status
            for interp in stack_trace
            for thread in interp.threads
        }

    def _get_task_id_map(self, stack_trace):
        """Return a task_id -> task mapping from an async stack trace."""
        return {t.task_id: t for t in stack_trace[0].awaited_by}

    def _get_awaited_by_relationships(self, stack_trace):
        """Return a task-name -> set-of-awaiting-task-names mapping."""
        by_id = self._get_task_id_map(stack_trace)
        relationships = {}
        for task in stack_trace[0].awaited_by:
            relationships[task.task_name] = {
                by_id[awaited.task_name].task_name
                for awaited in task.awaited_by
            }
        return relationships

    def _extract_coroutine_stacks(self, stack_trace):
        """Return task-name -> sorted tuple-ized coroutine call stacks."""
        stacks = {}
        for task in stack_trace[0].awaited_by:
            stacks[task.task_name] = sorted(
                tuple(tuple(frame) for frame in coro.call_stack)
                for coro in task.coroutine_stack
            )
        return stacks
+
+
+# ============================================================================
+# Test classes
+# ============================================================================
+
+
+class TestGetStackTrace(RemoteInspectionTestBase):
@skip_if_not_supported
@unittest.skipIf(
sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
"Test only runs on Linux with process_vm_readv support",
)
def test_remote_stack_trace(self):
- # Spawn a process with some realistic Python code
port = find_unused_port()
script = textwrap.dedent(
f"""\
import time, sys, socket, threading
- # Connect to the test process
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
foo()
def foo():
- sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number
+ sock.sendall(b"ready:thread\\n"); time.sleep(10_000)
t = threading.Thread(target=bar)
t.start()
- sock.sendall(b"ready:main\\n"); t.join() # same line number
+ sock.sendall(b"ready:main\\n"); t.join()
"""
)
- stack_trace = None
+
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
-
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_socket = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = b""
- while (
- b"ready:main" not in response
- or b"ready:thread" not in response
- ):
- response += client_socket.recv(1024)
- stack_trace = get_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
+
+ _wait_for_signal(
+ client_socket, [b"ready:main", b"ready:thread"]
+ )
+
+ try:
+ stack_trace = get_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+
+ thread_expected_stack_trace = [
+ FrameInfo([script_name, 15, "foo"]),
+ FrameInfo([script_name, 12, "baz"]),
+ FrameInfo([script_name, 9, "bar"]),
+ FrameInfo([threading.__file__, ANY, "Thread.run"]),
+ FrameInfo(
+ [
+ threading.__file__,
+ ANY,
+ "Thread._bootstrap_inner",
+ ]
+ ),
+ FrameInfo(
+ [threading.__file__, ANY, "Thread._bootstrap"]
+ ),
+ ]
+
+ # Find expected thread stack
+ found_thread = self._find_thread_with_frame(
+ stack_trace,
+ lambda f: f.funcname == "foo" and f.lineno == 15,
+ )
+ self.assertIsNotNone(
+ found_thread, "Expected thread stack trace not found"
+ )
+ self.assertEqual(
+ found_thread.frame_info, thread_expected_stack_trace
+ )
+
+ # Check main thread
+ main_frame = FrameInfo([script_name, 19, "<module>"])
+ found_main = self._find_frame_in_trace(
+ stack_trace, lambda f: f == main_frame
+ )
+ self.assertIsNotNone(
+ found_main, "Main thread stack trace not found"
+ )
finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
-
- thread_expected_stack_trace = [
- FrameInfo([script_name, 15, "foo"]),
- FrameInfo([script_name, 12, "baz"]),
- FrameInfo([script_name, 9, "bar"]),
- FrameInfo([threading.__file__, ANY, "Thread.run"]),
- FrameInfo([threading.__file__, ANY, "Thread._bootstrap_inner"]),
- FrameInfo([threading.__file__, ANY, "Thread._bootstrap"]),
- ]
- # Is possible that there are more threads, so we check that the
- # expected stack traces are in the result (looking at you Windows!)
- found_expected_stack = False
- for interpreter_info in stack_trace:
- for thread_info in interpreter_info.threads:
- if thread_info.frame_info == thread_expected_stack_trace:
- found_expected_stack = True
- break
- if found_expected_stack:
- break
- self.assertTrue(found_expected_stack, "Expected thread stack trace not found")
-
- # Check that the main thread stack trace is in the result
- frame = FrameInfo([script_name, 19, "<module>"])
- main_thread_found = False
- for interpreter_info in stack_trace:
- for thread_info in interpreter_info.threads:
- if frame in thread_info.frame_info:
- main_thread_found = True
- break
- if main_thread_found:
- break
- self.assertTrue(main_thread_found, "Main thread stack trace not found in result")
+ _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
"Test only runs on Linux with process_vm_readv support",
)
def test_async_remote_stack_trace(self):
- # Spawn a process with some realistic Python code
port = find_unused_port()
script = textwrap.dedent(
f"""\
import time
import sys
import socket
- # Connect to the test process
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
def c5():
- sock.sendall(b"ready"); time.sleep(10_000) # same line number
+ sock.sendall(b"ready"); time.sleep(10_000)
async def c4():
await asyncio.sleep(0)
asyncio.run(main(), loop_factory={{TASK_FACTORY}})
"""
)
- stack_trace = None
+
for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
with (
self.subTest(task_factory_variant=task_factory_variant),
):
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- server_socket = socket.socket(
- socket.AF_INET, socket.SOCK_STREAM
- )
- server_socket.setsockopt(
- socket.SOL_SOCKET, socket.SO_REUSEADDR, 1
- )
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
+
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(
script_dir,
"script",
script.format(TASK_FACTORY=task_factory_variant),
)
client_socket = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
- stack_trace = get_async_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
-
- # First check all the tasks are present
- tasks_names = [
- task.task_name for task in stack_trace[0].awaited_by
- ]
- for task_name in ["c2_root", "sub_main_1", "sub_main_2"]:
- self.assertIn(task_name, tasks_names)
-
- # Now ensure that the awaited_by_relationships are correct
- id_to_task = {
- task.task_id: task for task in stack_trace[0].awaited_by
- }
- task_name_to_awaited_by = {
- task.task_name: set(
- id_to_task[awaited.task_name].task_name
- for awaited in task.awaited_by
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- task_name_to_awaited_by,
- {
- "c2_root": {"Task-1", "sub_main_1", "sub_main_2"},
- "Task-1": set(),
- "sub_main_1": {"Task-1"},
- "sub_main_2": {"Task-1"},
- },
- )
+ with _managed_subprocess(
+ [sys.executable, script_name]
+ ) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
- # Now ensure that the coroutine stacks are correct
- coroutine_stacks = {
- task.task_name: sorted(
- tuple(tuple(frame) for frame in coro.call_stack)
- for coro in task.coroutine_stack
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- coroutine_stacks,
- {
- "Task-1": [
- (
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
- ),
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
- ),
- tuple([script_name, 26, "main"]),
- )
- ],
- "c2_root": [
- (
- tuple([script_name, 10, "c5"]),
- tuple([script_name, 14, "c4"]),
- tuple([script_name, 17, "c3"]),
- tuple([script_name, 20, "c2"]),
+ response = _wait_for_signal(client_socket, b"ready")
+ self.assertIn(b"ready", response)
+
+ try:
+ stack_trace = get_async_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
)
- ],
- "sub_main_1": [(tuple([script_name, 23, "c1"]),)],
- "sub_main_2": [(tuple([script_name, 23, "c1"]),)],
- },
- )
- # Now ensure the coroutine stacks for the awaited_by relationships are correct.
- awaited_by_coroutine_stacks = {
- task.task_name: sorted(
- (
- id_to_task[coro.task_name].task_name,
- tuple(tuple(frame) for frame in coro.call_stack),
+ # Check all tasks are present
+ tasks_names = [
+ task.task_name
+ for task in stack_trace[0].awaited_by
+ ]
+ for task_name in [
+ "c2_root",
+ "sub_main_1",
+ "sub_main_2",
+ ]:
+ self.assertIn(task_name, tasks_names)
+
+ # Check awaited_by relationships
+ relationships = self._get_awaited_by_relationships(
+ stack_trace
)
- for coro in task.awaited_by
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- awaited_by_coroutine_stacks,
- {
- "Task-1": [],
- "c2_root": [
- (
- "Task-1",
- (
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
- ),
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
- ),
- tuple([script_name, 26, "main"]),
- ),
- ),
- ("sub_main_1", (tuple([script_name, 23, "c1"]),)),
- ("sub_main_2", (tuple([script_name, 23, "c1"]),)),
- ],
- "sub_main_1": [
- (
- "Task-1",
+ self.assertEqual(
+ relationships,
+ {
+ "c2_root": {
+ "Task-1",
+ "sub_main_1",
+ "sub_main_2",
+ },
+ "Task-1": set(),
+ "sub_main_1": {"Task-1"},
+ "sub_main_2": {"Task-1"},
+ },
+ )
+
+ # Check coroutine stacks
+ coroutine_stacks = self._extract_coroutine_stacks(
+ stack_trace
+ )
+ self.assertEqual(
+ coroutine_stacks,
+ {
+ "Task-1": [
+ (
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ tuple([script_name, 26, "main"]),
+ )
+ ],
+ "c2_root": [
+ (
+ tuple([script_name, 10, "c5"]),
+ tuple([script_name, 14, "c4"]),
+ tuple([script_name, 17, "c3"]),
+ tuple([script_name, 20, "c2"]),
+ )
+ ],
+ "sub_main_1": [
+ (tuple([script_name, 23, "c1"]),)
+ ],
+ "sub_main_2": [
+ (tuple([script_name, 23, "c1"]),)
+ ],
+ },
+ )
+
+ # Check awaited_by coroutine stacks
+ id_to_task = self._get_task_id_map(stack_trace)
+ awaited_by_coroutine_stacks = {
+ task.task_name: sorted(
(
+ id_to_task[coro.task_name].task_name,
tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
- ),
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
+ tuple(frame)
+ for frame in coro.call_stack
),
- tuple([script_name, 26, "main"]),
- ),
+ )
+ for coro in task.awaited_by
)
- ],
- "sub_main_2": [
- (
- "Task-1",
- (
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
+ for task in stack_trace[0].awaited_by
+ }
+ self.assertEqual(
+ awaited_by_coroutine_stacks,
+ {
+ "Task-1": [],
+ "c2_root": [
+ (
+ "Task-1",
+ (
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ tuple([script_name, 26, "main"]),
+ ),
),
- tuple(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
+ (
+ "sub_main_1",
+ (tuple([script_name, 23, "c1"]),),
),
- tuple([script_name, 26, "main"]),
- ),
- )
- ],
- },
- )
+ (
+ "sub_main_2",
+ (tuple([script_name, 23, "c1"]),),
+ ),
+ ],
+ "sub_main_1": [
+ (
+ "Task-1",
+ (
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ tuple([script_name, 26, "main"]),
+ ),
+ )
+ ],
+ "sub_main_2": [
+ (
+ "Task-1",
+ (
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ tuple(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ tuple([script_name, 26, "main"]),
+ ),
+ )
+ ],
+ },
+ )
+ finally:
+ _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
"Test only runs on Linux with process_vm_readv support",
)
def test_asyncgen_remote_stack_trace(self):
- # Spawn a process with some realistic Python code
port = find_unused_port()
script = textwrap.dedent(
f"""\
import time
import sys
import socket
- # Connect to the test process
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
async def gen_nested_call():
- sock.sendall(b"ready"); time.sleep(10_000) # same line number
+ sock.sendall(b"ready"); time.sleep(10_000)
async def gen():
for num in range(2):
asyncio.run(main())
"""
)
- stack_trace = None
+
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
+
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_socket = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
- stack_trace = get_async_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
- # For this simple asyncgen test, we only expect one task with the full coroutine stack
- self.assertEqual(len(stack_trace[0].awaited_by), 1)
- task = stack_trace[0].awaited_by[0]
- self.assertEqual(task.task_name, "Task-1")
+ response = _wait_for_signal(client_socket, b"ready")
+ self.assertIn(b"ready", response)
- # Check the coroutine stack - based on actual output, only shows main
- coroutine_stack = sorted(
- tuple(tuple(frame) for frame in coro.call_stack)
- for coro in task.coroutine_stack
- )
- self.assertEqual(
- coroutine_stack,
- [
- (
- tuple([script_name, 10, "gen_nested_call"]),
- tuple([script_name, 16, "gen"]),
- tuple([script_name, 19, "main"]),
+ try:
+ stack_trace = get_async_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+
+ # For this simple asyncgen test, we only expect one task
+ self.assertEqual(len(stack_trace[0].awaited_by), 1)
+ task = stack_trace[0].awaited_by[0]
+ self.assertEqual(task.task_name, "Task-1")
+
+ # Check the coroutine stack
+ coroutine_stack = sorted(
+ tuple(tuple(frame) for frame in coro.call_stack)
+ for coro in task.coroutine_stack
+ )
+ self.assertEqual(
+ coroutine_stack,
+ [
+ (
+ tuple([script_name, 10, "gen_nested_call"]),
+ tuple([script_name, 16, "gen"]),
+ tuple([script_name, 19, "main"]),
+ )
+ ],
)
- ],
- )
- # No awaited_by relationships expected for this simple case
- self.assertEqual(task.awaited_by, [])
+ # No awaited_by relationships expected
+ self.assertEqual(task.awaited_by, [])
+ finally:
+ _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
"Test only runs on Linux with process_vm_readv support",
)
def test_async_gather_remote_stack_trace(self):
- # Spawn a process with some realistic Python code
port = find_unused_port()
script = textwrap.dedent(
f"""\
import time
import sys
import socket
- # Connect to the test process
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready"); time.sleep(10_000) # same line number
+ sock.sendall(b"ready"); time.sleep(10_000)
async def c1():
await asyncio.sleep(0)
asyncio.run(main())
"""
)
- stack_trace = None
+
with os_helper.temp_dir() as work_dir:
- script_dir = os.path.join(work_dir, "script_pkg")
- os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
- script_name = _make_test_script(script_dir, "script", script)
- client_socket = None
- try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
- stack_trace = get_async_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
-
- # First check all the tasks are present
- tasks_names = [
- task.task_name for task in stack_trace[0].awaited_by
- ]
- for task_name in ["Task-1", "Task-2"]:
- self.assertIn(task_name, tasks_names)
-
- # Now ensure that the awaited_by_relationships are correct
- id_to_task = {
- task.task_id: task for task in stack_trace[0].awaited_by
- }
- task_name_to_awaited_by = {
- task.task_name: set(
- id_to_task[awaited.task_name].task_name
- for awaited in task.awaited_by
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- task_name_to_awaited_by,
- {
- "Task-1": set(),
- "Task-2": {"Task-1"},
- },
- )
+ script_dir = os.path.join(work_dir, "script_pkg")
+ os.mkdir(script_dir)
- # Now ensure that the coroutine stacks are correct
- coroutine_stacks = {
- task.task_name: sorted(
- tuple(tuple(frame) for frame in coro.call_stack)
- for coro in task.coroutine_stack
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- coroutine_stacks,
- {
- "Task-1": [(tuple([script_name, 21, "main"]),)],
- "Task-2": [
- (
- tuple([script_name, 11, "deep"]),
- tuple([script_name, 15, "c1"]),
+ server_socket = _create_server_socket(port)
+ script_name = _make_test_script(script_dir, "script", script)
+ client_socket = None
+
+ try:
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
+
+ response = _wait_for_signal(client_socket, b"ready")
+ self.assertIn(b"ready", response)
+
+ try:
+ stack_trace = get_async_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
)
- ],
- },
- )
- # Now ensure the coroutine stacks for the awaited_by relationships are correct.
- awaited_by_coroutine_stacks = {
- task.task_name: sorted(
- (
- id_to_task[coro.task_name].task_name,
- tuple(tuple(frame) for frame in coro.call_stack),
+ # Check all tasks are present
+ tasks_names = [
+ task.task_name for task in stack_trace[0].awaited_by
+ ]
+ for task_name in ["Task-1", "Task-2"]:
+ self.assertIn(task_name, tasks_names)
+
+ # Check awaited_by relationships
+ relationships = self._get_awaited_by_relationships(
+ stack_trace
)
- for coro in task.awaited_by
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- awaited_by_coroutine_stacks,
- {
- "Task-1": [],
- "Task-2": [
- ("Task-1", (tuple([script_name, 21, "main"]),))
- ],
- },
- )
+ self.assertEqual(
+ relationships,
+ {
+ "Task-1": set(),
+ "Task-2": {"Task-1"},
+ },
+ )
+
+ # Check coroutine stacks
+ coroutine_stacks = self._extract_coroutine_stacks(
+ stack_trace
+ )
+ self.assertEqual(
+ coroutine_stacks,
+ {
+ "Task-1": [(tuple([script_name, 21, "main"]),)],
+ "Task-2": [
+ (
+ tuple([script_name, 11, "deep"]),
+ tuple([script_name, 15, "c1"]),
+ )
+ ],
+ },
+ )
+
+ # Check awaited_by coroutine stacks
+ id_to_task = self._get_task_id_map(stack_trace)
+ awaited_by_coroutine_stacks = {
+ task.task_name: sorted(
+ (
+ id_to_task[coro.task_name].task_name,
+ tuple(
+ tuple(frame) for frame in coro.call_stack
+ ),
+ )
+ for coro in task.awaited_by
+ )
+ for task in stack_trace[0].awaited_by
+ }
+ self.assertEqual(
+ awaited_by_coroutine_stacks,
+ {
+ "Task-1": [],
+ "Task-2": [
+ ("Task-1", (tuple([script_name, 21, "main"]),))
+ ],
+ },
+ )
+ finally:
+ _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
"Test only runs on Linux with process_vm_readv support",
)
def test_async_staggered_race_remote_stack_trace(self):
- # Spawn a process with some realistic Python code
port = find_unused_port()
script = textwrap.dedent(
f"""\
import time
import sys
import socket
- # Connect to the test process
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
async def deep():
await asyncio.sleep(0)
- sock.sendall(b"ready"); time.sleep(10_000) # same line number
+ sock.sendall(b"ready"); time.sleep(10_000)
async def c1():
await asyncio.sleep(0)
asyncio.run(main())
"""
)
- stack_trace = None
+
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
+
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_socket = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = client_socket.recv(1024)
- self.assertEqual(response, b"ready")
- stack_trace = get_async_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
-
- # First check all the tasks are present
- tasks_names = [
- task.task_name for task in stack_trace[0].awaited_by
- ]
- for task_name in ["Task-1", "Task-2"]:
- self.assertIn(task_name, tasks_names)
-
- # Now ensure that the awaited_by_relationships are correct
- id_to_task = {
- task.task_id: task for task in stack_trace[0].awaited_by
- }
- task_name_to_awaited_by = {
- task.task_name: set(
- id_to_task[awaited.task_name].task_name
- for awaited in task.awaited_by
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- task_name_to_awaited_by,
- {
- "Task-1": set(),
- "Task-2": {"Task-1"},
- },
- )
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
- # Now ensure that the coroutine stacks are correct
- coroutine_stacks = {
- task.task_name: sorted(
- tuple(tuple(frame) for frame in coro.call_stack)
- for coro in task.coroutine_stack
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- coroutine_stacks,
- {
- "Task-1": [
- (
- tuple([staggered.__file__, ANY, "staggered_race"]),
- tuple([script_name, 21, "main"]),
- )
- ],
- "Task-2": [
- (
- tuple([script_name, 11, "deep"]),
- tuple([script_name, 15, "c1"]),
- tuple(
- [
- staggered.__file__,
- ANY,
- "staggered_race.<locals>.run_one_coro",
- ]
- ),
+ response = _wait_for_signal(client_socket, b"ready")
+ self.assertIn(b"ready", response)
+
+ try:
+ stack_trace = get_async_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
)
- ],
- },
- )
- # Now ensure the coroutine stacks for the awaited_by relationships are correct.
- awaited_by_coroutine_stacks = {
- task.task_name: sorted(
- (
- id_to_task[coro.task_name].task_name,
- tuple(tuple(frame) for frame in coro.call_stack),
+ # Check all tasks are present
+ tasks_names = [
+ task.task_name for task in stack_trace[0].awaited_by
+ ]
+ for task_name in ["Task-1", "Task-2"]:
+ self.assertIn(task_name, tasks_names)
+
+ # Check awaited_by relationships
+ relationships = self._get_awaited_by_relationships(
+ stack_trace
)
- for coro in task.awaited_by
- )
- for task in stack_trace[0].awaited_by
- }
- self.assertEqual(
- awaited_by_coroutine_stacks,
- {
- "Task-1": [],
- "Task-2": [
- (
- "Task-1",
+ self.assertEqual(
+ relationships,
+ {
+ "Task-1": set(),
+ "Task-2": {"Task-1"},
+ },
+ )
+
+ # Check coroutine stacks
+ coroutine_stacks = self._extract_coroutine_stacks(
+ stack_trace
+ )
+ self.assertEqual(
+ coroutine_stacks,
+ {
+ "Task-1": [
+ (
+ tuple(
+ [
+ staggered.__file__,
+ ANY,
+ "staggered_race",
+ ]
+ ),
+ tuple([script_name, 21, "main"]),
+ )
+ ],
+ "Task-2": [
+ (
+ tuple([script_name, 11, "deep"]),
+ tuple([script_name, 15, "c1"]),
+ tuple(
+ [
+ staggered.__file__,
+ ANY,
+ "staggered_race.<locals>.run_one_coro",
+ ]
+ ),
+ )
+ ],
+ },
+ )
+
+ # Check awaited_by coroutine stacks
+ id_to_task = self._get_task_id_map(stack_trace)
+ awaited_by_coroutine_stacks = {
+ task.task_name: sorted(
(
+ id_to_task[coro.task_name].task_name,
tuple(
- [staggered.__file__, ANY, "staggered_race"]
+ tuple(frame) for frame in coro.call_stack
),
- tuple([script_name, 21, "main"]),
- ),
+ )
+ for coro in task.awaited_by
)
- ],
- },
- )
+ for task in stack_trace[0].awaited_by
+ }
+ self.assertEqual(
+ awaited_by_coroutine_stacks,
+ {
+ "Task-1": [],
+ "Task-2": [
+ (
+ "Task-1",
+ (
+ tuple(
+ [
+ staggered.__file__,
+ ANY,
+ "staggered_race",
+ ]
+ ),
+ tuple([script_name, 21, "main"]),
+ ),
+ )
+ ],
+ },
+ )
+ finally:
+ _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
"Test only runs on Linux with process_vm_readv support",
)
def test_async_global_awaited_by(self):
+ # Reduced from 1000 to 100 to avoid file descriptor exhaustion
+ # when running tests in parallel (e.g., -j 20)
+ NUM_TASKS = 100
+
port = find_unused_port()
script = textwrap.dedent(
f"""\
PORT = socket_helper.find_unused_port()
connections = 0
- # Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
assert message == data.decode()
writer.close()
await writer.wait_closed()
- # Signal we are ready to sleep
sock.sendall(b"ready")
await asyncio.sleep(SHORT_TIMEOUT)
async def echo_client_spam(server):
async with asyncio.TaskGroup() as tg:
- while connections < 1000:
+ while connections < {NUM_TASKS}:
msg = list(ascii_lowercase + digits)
random.shuffle(msg)
tg.create_task(echo_client("".join(msg)))
await asyncio.sleep(0)
- # at least a 1000 tasks created. Each task will signal
- # when is ready to avoid the race caused by the fact that
- # tasks are waited on tg.__exit__ and we cannot signal when
- # that happens otherwise
- # at this point all client tasks completed without assertion errors
- # let's wrap up the test
server.close()
await server.wait_closed()
asyncio.run(main())
"""
)
- stack_trace = None
+
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
+
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_socket = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- for _ in range(1000):
- expected_response = b"ready"
- response = client_socket.recv(len(expected_response))
- self.assertEqual(response, expected_response)
- for _ in busy_retry(SHORT_TIMEOUT):
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
+
+ # Wait for NUM_TASKS "ready" signals
+ try:
+ _wait_for_n_signals(client_socket, b"ready", NUM_TASKS)
+ except RuntimeError as e:
+ self.fail(str(e))
+
try:
all_awaited_by = get_all_awaited_by(p.pid)
- except RuntimeError as re:
- # This call reads a linked list in another process with
- # no synchronization. That occasionally leads to invalid
- # reads. Here we avoid making the test flaky.
- msg = str(re)
- if msg.startswith("Task list appears corrupted"):
- continue
- elif msg.startswith(
- "Invalid linked list structure reading remote memory"
- ):
- continue
- elif msg.startswith("Unknown error reading memory"):
- continue
- elif msg.startswith("Unhandled frame owner"):
- continue
- raise # Unrecognized exception, safest not to ignore it
- else:
- break
- # expected: a list of two elements: 1 thread, 1 interp
- self.assertEqual(len(all_awaited_by), 2)
- # expected: a tuple with the thread ID and the awaited_by list
- self.assertEqual(len(all_awaited_by[0]), 2)
- # expected: no tasks in the fallback per-interp task list
- self.assertEqual(all_awaited_by[1], (0, []))
- entries = all_awaited_by[0][1]
- # expected: at least 1000 pending tasks
- self.assertGreaterEqual(len(entries), 1000)
- # the first three tasks stem from the code structure
- main_stack = [
- FrameInfo([taskgroups.__file__, ANY, "TaskGroup._aexit"]),
- FrameInfo(
- [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
- ),
- FrameInfo([script_name, 60, "main"]),
- ]
- self.assertIn(
- TaskInfo(
- [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
- ),
- entries,
- )
- self.assertIn(
- TaskInfo(
- [
- ANY,
- "server task",
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+
+ # Expected: a list of two elements: 1 thread, 1 interp
+ self.assertEqual(len(all_awaited_by), 2)
+ # Expected: a tuple with the thread ID and the awaited_by list
+ self.assertEqual(len(all_awaited_by[0]), 2)
+ # Expected: no tasks in the fallback per-interp task list
+ self.assertEqual(all_awaited_by[1], (0, []))
+
+ entries = all_awaited_by[0][1]
+ # Expected: at least NUM_TASKS pending tasks
+ self.assertGreaterEqual(len(entries), NUM_TASKS)
+
+ # Check the main task structure
+ main_stack = [
+ FrameInfo(
+ [taskgroups.__file__, ANY, "TaskGroup._aexit"]
+ ),
+ FrameInfo(
+ [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
+ ),
+ FrameInfo([script_name, 52, "main"]),
+ ]
+ self.assertIn(
+ TaskInfo(
+ [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
+ ),
+ entries,
+ )
+ self.assertIn(
+ TaskInfo(
[
- CoroInfo(
- [
+ ANY,
+ "server task",
+ [
+ CoroInfo(
[
- FrameInfo(
- [
- base_events.__file__,
- ANY,
- "Server.serve_forever",
- ]
- )
- ],
- ANY,
- ]
- )
- ],
- [
- CoroInfo(
- [
+ [
+ FrameInfo(
+ [
+ base_events.__file__,
+ ANY,
+ "Server.serve_forever",
+ ]
+ )
+ ],
+ ANY,
+ ]
+ )
+ ],
+ [
+ CoroInfo(
[
- FrameInfo(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
- ),
- FrameInfo(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
- ),
- FrameInfo(
- [script_name, ANY, "main"]
- ),
- ],
- ANY,
- ]
- )
- ],
- ]
- ),
- entries,
- )
- self.assertIn(
- TaskInfo(
- [
- ANY,
- "Task-4",
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [script_name, ANY, "main"]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ ]
+ ),
+ entries,
+ )
+ self.assertIn(
+ TaskInfo(
[
- CoroInfo(
- [
+ ANY,
+ "Task-4",
+ [
+ CoroInfo(
[
- FrameInfo(
- [tasks.__file__, ANY, "sleep"]
- ),
- FrameInfo(
- [
- script_name,
- 38,
- "echo_client",
- ]
- ),
- ],
- ANY,
- ]
- )
- ],
- [
- CoroInfo(
- [
+ [
+ FrameInfo(
+ [
+ tasks.__file__,
+ ANY,
+ "sleep",
+ ]
+ ),
+ FrameInfo(
+ [
+ script_name,
+ 36,
+ "echo_client",
+ ]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ [
+ CoroInfo(
[
- FrameInfo(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
- ),
- FrameInfo(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
- ),
- FrameInfo(
- [
- script_name,
- 41,
- "echo_client_spam",
- ]
- ),
- ],
- ANY,
- ]
- )
- ],
- ]
- ),
- entries,
- )
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [
+ script_name,
+ 39,
+ "echo_client_spam",
+ ]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ],
+ ]
+ ),
+ entries,
+ )
- expected_awaited_by = [
- CoroInfo(
- [
+ expected_awaited_by = [
+ CoroInfo(
[
- FrameInfo(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup._aexit",
- ]
- ),
- FrameInfo(
- [
- taskgroups.__file__,
- ANY,
- "TaskGroup.__aexit__",
- ]
- ),
- FrameInfo(
- [script_name, 41, "echo_client_spam"]
- ),
- ],
- ANY,
- ]
- )
- ]
- tasks_with_awaited = [
- task
- for task in entries
- if task.awaited_by == expected_awaited_by
- ]
- self.assertGreaterEqual(len(tasks_with_awaited), 1000)
-
- # the final task will have some random number, but it should for
- # sure be one of the echo client spam horde (In windows this is not true
- # for some reason)
- if sys.platform != "win32":
- self.assertEqual(
- tasks_with_awaited[-1].awaited_by,
- entries[-1].awaited_by,
- )
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
+ [
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup._aexit",
+ ]
+ ),
+ FrameInfo(
+ [
+ taskgroups.__file__,
+ ANY,
+ "TaskGroup.__aexit__",
+ ]
+ ),
+ FrameInfo(
+ [script_name, 39, "echo_client_spam"]
+ ),
+ ],
+ ANY,
+ ]
+ )
+ ]
+ tasks_with_awaited = [
+ task
+ for task in entries
+ if task.awaited_by == expected_awaited_by
+ ]
+ self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS)
+
+ # Final task should be from echo client spam (not on Windows)
+ if sys.platform != "win32":
+ self.assertEqual(
+ tasks_with_awaited[-1].awaited_by,
+ entries[-1].awaited_by,
+ )
finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
+ _cleanup_sockets(client_socket, server_socket)
@skip_if_not_supported
@unittest.skipIf(
)
def test_self_trace(self):
stack_trace = get_stack_trace(os.getpid())
- # Is possible that there are more threads, so we check that the
- # expected stack traces are in the result (looking at you Windows!)
- this_tread_stack = None
- # New format: [InterpreterInfo(interpreter_id, [ThreadInfo(...)])]
+
+ this_thread_stack = None
for interpreter_info in stack_trace:
for thread_info in interpreter_info.threads:
if thread_info.thread_id == threading.get_native_id():
- this_tread_stack = thread_info.frame_info
+ this_thread_stack = thread_info.frame_info
break
- if this_tread_stack:
+ if this_thread_stack:
break
- self.assertIsNotNone(this_tread_stack)
+
+ self.assertIsNotNone(this_thread_stack)
self.assertEqual(
- this_tread_stack[:2],
+ this_thread_stack[:2],
[
FrameInfo(
[
__file__,
- get_stack_trace.__code__.co_firstlineno + 2,
+ get_stack_trace.__code__.co_firstlineno + 4,
"get_stack_trace",
]
),
)
@requires_subinterpreters
def test_subinterpreter_stack_trace(self):
- # Test that subinterpreters are correctly handled
port = find_unused_port()
- # Calculate subinterpreter code separately and pickle it to avoid f-string issues
import pickle
- subinterp_code = textwrap.dedent(f'''
+
+ subinterp_code = textwrap.dedent(f"""
import socket
import time
nested_func()
sub_worker()
- ''').strip()
+ """).strip()
- # Pickle the subinterpreter code
pickled_code = pickle.dumps(subinterp_code)
script = textwrap.dedent(
import socket
import threading
- # Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
def main_worker():
- # Function running in main interpreter
sock.sendall(b"ready:main\\n")
time.sleep(10_000)
def run_subinterp():
- # Create and run subinterpreter
subinterp = interpreters.create()
-
import pickle
pickled_code = {pickled_code!r}
subinterp_code = pickle.loads(pickled_code)
subinterp.exec(subinterp_code)
- # Start subinterpreter in thread
sub_thread = threading.Thread(target=run_subinterp)
sub_thread.start()
- # Start main thread work
main_thread = threading.Thread(target=main_worker)
main_thread.start()
- # Keep main thread alive
main_thread.join()
sub_thread.join()
"""
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
-
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_sockets = []
- try:
- p = subprocess.Popen([sys.executable, script_name])
- # Accept connections from both main and subinterpreter
- responses = set()
- while len(responses) < 2: # Wait for both "ready:main" and "ready:sub"
- try:
- client_socket, _ = server_socket.accept()
- client_sockets.append(client_socket)
+ try:
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ # Accept connections from both main and subinterpreter
+ responses = set()
+ while len(responses) < 2:
+ try:
+ client_socket, _ = server_socket.accept()
+ client_sockets.append(client_socket)
+ response = client_socket.recv(1024)
+ if b"ready:main" in response:
+ responses.add("main")
+ if b"ready:sub" in response:
+ responses.add("sub")
+ except socket.timeout:
+ break
- # Read the response from this connection
- response = client_socket.recv(1024)
- if b"ready:main" in response:
- responses.add("main")
- if b"ready:sub" in response:
- responses.add("sub")
- except socket.timeout:
- break
+ server_socket.close()
+ server_socket = None
- server_socket.close()
- stack_trace = get_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- for client_socket in client_sockets:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
+ try:
+ stack_trace = get_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
- # Verify we have multiple interpreters
- self.assertGreaterEqual(len(stack_trace), 1, "Should have at least one interpreter")
+ # Verify we have at least one interpreter
+ self.assertGreaterEqual(len(stack_trace), 1)
- # Look for main interpreter (ID 0) and subinterpreter (ID > 0)
- main_interp = None
- sub_interp = None
+ # Look for main interpreter (ID 0) and subinterpreter (ID > 0)
+ main_interp = None
+ sub_interp = None
+ for interpreter_info in stack_trace:
+ if interpreter_info.interpreter_id == 0:
+ main_interp = interpreter_info
+ elif interpreter_info.interpreter_id > 0:
+ sub_interp = interpreter_info
- for interpreter_info in stack_trace:
- if interpreter_info.interpreter_id == 0:
- main_interp = interpreter_info
- elif interpreter_info.interpreter_id > 0:
- sub_interp = interpreter_info
+ self.assertIsNotNone(
+ main_interp, "Main interpreter should be present"
+ )
- self.assertIsNotNone(main_interp, "Main interpreter should be present")
+ # Check main interpreter has expected stack trace
+ main_found = self._find_frame_in_trace(
+ [main_interp], lambda f: f.funcname == "main_worker"
+ )
+ self.assertIsNotNone(
+ main_found,
+ "Main interpreter should have main_worker in stack",
+ )
- # Check main interpreter has expected stack trace
- main_found = False
- for thread_info in main_interp.threads:
- for frame in thread_info.frame_info:
- if frame.funcname == "main_worker":
- main_found = True
- break
- if main_found:
- break
- self.assertTrue(main_found, "Main interpreter should have main_worker in stack")
-
- # If subinterpreter is present, check its stack trace
- if sub_interp:
- sub_found = False
- for thread_info in sub_interp.threads:
- for frame in thread_info.frame_info:
- if frame.funcname in ("sub_worker", "nested_func"):
- sub_found = True
- break
- if sub_found:
- break
- self.assertTrue(sub_found, "Subinterpreter should have sub_worker or nested_func in stack")
+ # If subinterpreter is present, check its stack trace
+ if sub_interp:
+ sub_found = self._find_frame_in_trace(
+ [sub_interp],
+ lambda f: f.funcname
+ in ("sub_worker", "nested_func"),
+ )
+ self.assertIsNotNone(
+ sub_found,
+ "Subinterpreter should have sub_worker or nested_func in stack",
+ )
+ finally:
+ _cleanup_sockets(*client_sockets, server_socket)
@skip_if_not_supported
@unittest.skipIf(
)
@requires_subinterpreters
def test_multiple_subinterpreters_with_threads(self):
- # Test multiple subinterpreters, each with multiple threads
port = find_unused_port()
- # Calculate subinterpreter codes separately and pickle them
import pickle
- # Code for first subinterpreter with 2 threads
- subinterp1_code = textwrap.dedent(f'''
+ subinterp1_code = textwrap.dedent(f"""
import socket
import time
import threading
t2.start()
t1.join()
t2.join()
- ''').strip()
+ """).strip()
- # Code for second subinterpreter with 2 threads
- subinterp2_code = textwrap.dedent(f'''
+ subinterp2_code = textwrap.dedent(f"""
import socket
import time
import threading
t2.start()
t1.join()
t2.join()
- ''').strip()
+ """).strip()
- # Pickle the subinterpreter codes
pickled_code1 = pickle.dumps(subinterp1_code)
pickled_code2 = pickle.dumps(subinterp2_code)
import socket
import threading
- # Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
def main_worker():
- # Function running in main interpreter
sock.sendall(b"ready:main\\n")
time.sleep(10_000)
def run_subinterp1():
- # Create and run first subinterpreter
subinterp = interpreters.create()
-
import pickle
pickled_code = {pickled_code1!r}
subinterp_code = pickle.loads(pickled_code)
subinterp.exec(subinterp_code)
def run_subinterp2():
- # Create and run second subinterpreter
subinterp = interpreters.create()
-
import pickle
pickled_code = {pickled_code2!r}
subinterp_code = pickle.loads(pickled_code)
subinterp.exec(subinterp_code)
- # Start subinterpreters in threads
sub1_thread = threading.Thread(target=run_subinterp1)
sub2_thread = threading.Thread(target=run_subinterp2)
sub1_thread.start()
sub2_thread.start()
- # Start main thread work
main_thread = threading.Thread(target=main_worker)
main_thread.start()
- # Keep main thread alive
main_thread.join()
sub1_thread.join()
sub2_thread.join()
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(5) # Allow multiple connections
-
+ server_socket = _create_server_socket(port, backlog=5)
script_name = _make_test_script(script_dir, "script", script)
client_sockets = []
+
try:
- p = subprocess.Popen([sys.executable, script_name])
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ # Accept connections from main and all subinterpreter threads
+ expected_responses = {
+ "ready:main",
+ "ready:sub1-t1",
+ "ready:sub1-t2",
+ "ready:sub2-t1",
+ "ready:sub2-t2",
+ }
+ responses = set()
+
+ while len(responses) < 5:
+ try:
+ client_socket, _ = server_socket.accept()
+ client_sockets.append(client_socket)
+ response = client_socket.recv(1024)
+ response_str = response.decode().strip()
+ if response_str in expected_responses:
+ responses.add(response_str)
+ except socket.timeout:
+ break
- # Accept connections from main and all subinterpreter threads
- expected_responses = {"ready:main", "ready:sub1-t1", "ready:sub1-t2", "ready:sub2-t1", "ready:sub2-t2"}
- responses = set()
+ server_socket.close()
+ server_socket = None
- while len(responses) < 5: # Wait for all 5 ready signals
try:
- client_socket, _ = server_socket.accept()
- client_sockets.append(client_socket)
-
- # Read the response from this connection
- response = client_socket.recv(1024)
- response_str = response.decode().strip()
- if response_str in expected_responses:
- responses.add(response_str)
- except socket.timeout:
- break
-
- server_socket.close()
- stack_trace = get_stack_trace(p.pid)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- for client_socket in client_sockets:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
-
- # Verify we have multiple interpreters
- self.assertGreaterEqual(len(stack_trace), 2, "Should have at least two interpreters")
-
- # Count interpreters by ID
- interpreter_ids = {interp.interpreter_id for interp in stack_trace}
- self.assertIn(0, interpreter_ids, "Main interpreter should be present")
- self.assertGreaterEqual(len(interpreter_ids), 3, "Should have main + at least 2 subinterpreters")
+ stack_trace = get_stack_trace(p.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
- # Count total threads across all interpreters
- total_threads = sum(len(interp.threads) for interp in stack_trace)
- self.assertGreaterEqual(total_threads, 5, "Should have at least 5 threads total")
+ # Verify we have multiple interpreters
+ self.assertGreaterEqual(len(stack_trace), 2)
+
+ # Count interpreters by ID
+ interpreter_ids = {
+ interp.interpreter_id for interp in stack_trace
+ }
+ self.assertIn(
+ 0,
+ interpreter_ids,
+ "Main interpreter should be present",
+ )
+ self.assertGreaterEqual(len(interpreter_ids), 3)
- # Look for expected function names in stack traces
- all_funcnames = set()
- for interpreter_info in stack_trace:
- for thread_info in interpreter_info.threads:
- for frame in thread_info.frame_info:
- all_funcnames.add(frame.funcname)
+ # Count total threads
+ total_threads = sum(
+ len(interp.threads) for interp in stack_trace
+ )
+ self.assertGreaterEqual(total_threads, 5)
- # Should find functions from different interpreters and threads
- expected_funcs = {"main_worker", "worker1", "worker2", "nested_func"}
- found_funcs = expected_funcs.intersection(all_funcnames)
- self.assertGreater(len(found_funcs), 0, f"Should find some expected functions, got: {all_funcnames}")
+ # Look for expected function names
+ all_funcnames = set()
+ for interpreter_info in stack_trace:
+ for thread_info in interpreter_info.threads:
+ for frame in thread_info.frame_info:
+ all_funcnames.add(frame.funcname)
+
+ expected_funcs = {
+ "main_worker",
+ "worker1",
+ "worker2",
+ "nested_func",
+ }
+ found_funcs = expected_funcs.intersection(all_funcnames)
+ self.assertGreater(len(found_funcs), 0)
+ finally:
+ _cleanup_sockets(*client_sockets, server_socket)
@skip_if_not_supported
@unittest.skipIf(
)
@requires_gil_enabled("Free threaded builds don't have an 'active thread'")
def test_only_active_thread(self):
- # Test that only_active_thread parameter works correctly
port = find_unused_port()
script = textwrap.dedent(
f"""\
import time, sys, socket, threading
- # Connect to the test process
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
def worker_thread(name, barrier, ready_event):
- barrier.wait() # Synchronize thread start
- ready_event.wait() # Wait for main thread signal
- # Sleep to keep thread alive
+ barrier.wait()
+ ready_event.wait()
time.sleep(10_000)
def main_work():
- # Do busy work to hold the GIL
sock.sendall(b"working\\n")
count = 0
while count < 100000000:
count += 1
if count % 10000000 == 0:
- pass # Keep main thread busy
+ pass
sock.sendall(b"done\\n")
- # Create synchronization primitives
num_threads = 3
- barrier = threading.Barrier(num_threads + 1) # +1 for main thread
+ barrier = threading.Barrier(num_threads + 1)
ready_event = threading.Event()
- # Start worker threads
threads = []
for i in range(num_threads):
t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
t.start()
threads.append(t)
- # Wait for all threads to be ready
barrier.wait()
-
- # Signal ready to parent process
sock.sendall(b"ready\\n")
-
- # Signal threads to start waiting
ready_event.set()
-
- # Now do busy work to hold the GIL
main_work()
"""
)
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- # Create a socket server to communicate with the target process
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
-
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_socket = None
- try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
-
- # Wait for ready signal
- response = b""
- while b"ready" not in response:
- response += client_socket.recv(1024)
-
- # Wait for the main thread to start its busy work
- while b"working" not in response:
- response += client_socket.recv(1024)
-
- # Get stack trace with all threads
- unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
- for _ in range(10):
- # Wait for the main thread to start its busy work
- all_traces = unwinder_all.get_stack_trace()
- found = False
- # New format: [InterpreterInfo(interpreter_id, [ThreadInfo(...)])]
- for interpreter_info in all_traces:
- for thread_info in interpreter_info.threads:
- if not thread_info.frame_info:
- continue
- current_frame = thread_info.frame_info[0]
- if (
- current_frame.funcname == "main_work"
- and current_frame.lineno > 15
- ):
- found = True
- break
- if found:
- break
- if found:
- break
- # Give a bit of time to take the next sample
- time.sleep(0.1)
- else:
- self.fail(
- "Main thread did not start its busy work on time"
- )
+ try:
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
- # Get stack trace with only GIL holder
- unwinder_gil = RemoteUnwinder(p.pid, only_active_thread=True)
- gil_traces = unwinder_gil.get_stack_trace()
+ # Wait for ready and working signals
+ _wait_for_signal(client_socket, [b"ready", b"working"])
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
- finally:
- if client_socket is not None:
- client_socket.close()
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
+ try:
+ # Get stack trace with all threads
+ unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
+ for _ in range(MAX_TRIES):
+ all_traces = unwinder_all.get_stack_trace()
+ found = self._find_frame_in_trace(
+ all_traces,
+ lambda f: f.funcname == "main_work"
+ and f.lineno > 12,
+ )
+ if found:
+ break
+ time.sleep(0.1)
+ else:
+ self.fail(
+ "Main thread did not start its busy work on time"
+ )
- # Count total threads across all interpreters in all_traces
- total_threads = sum(len(interpreter_info.threads) for interpreter_info in all_traces)
- self.assertGreater(
- total_threads, 1, "Should have multiple threads"
- )
+ # Get stack trace with only GIL holder
+ unwinder_gil = RemoteUnwinder(
+ p.pid, only_active_thread=True
+ )
+ gil_traces = unwinder_gil.get_stack_trace()
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
- # Count total threads across all interpreters in gil_traces
- total_gil_threads = sum(len(interpreter_info.threads) for interpreter_info in gil_traces)
- self.assertEqual(
- total_gil_threads, 1, "Should have exactly one GIL holder"
- )
+ # Count threads
+ total_threads = sum(
+ len(interp.threads) for interp in all_traces
+ )
+ self.assertGreater(total_threads, 1)
- # Get the GIL holder thread ID
- gil_thread_id = None
- for interpreter_info in gil_traces:
- if interpreter_info.threads:
- gil_thread_id = interpreter_info.threads[0].thread_id
- break
+ total_gil_threads = sum(
+ len(interp.threads) for interp in gil_traces
+ )
+ self.assertEqual(total_gil_threads, 1)
+
+ # Get the GIL holder thread ID
+ gil_thread_id = None
+ for interpreter_info in gil_traces:
+ if interpreter_info.threads:
+ gil_thread_id = interpreter_info.threads[
+ 0
+ ].thread_id
+ break
- # Get all thread IDs from all_traces
- all_thread_ids = []
- for interpreter_info in all_traces:
- for thread_info in interpreter_info.threads:
- all_thread_ids.append(thread_info.thread_id)
+ # Get all thread IDs
+ all_thread_ids = []
+ for interpreter_info in all_traces:
+ for thread_info in interpreter_info.threads:
+ all_thread_ids.append(thread_info.thread_id)
- self.assertIn(
- gil_thread_id,
- all_thread_ids,
- "GIL holder should be among all threads",
- )
+ self.assertIn(gil_thread_id, all_thread_ids)
+ finally:
+ _cleanup_sockets(client_socket, server_socket)
class TestUnsupportedPlatformHandling(unittest.TestCase):
sys.platform in ("linux", "darwin", "win32"),
"Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
)
- @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception")
+ @unittest.skipIf(
+ sys.platform == "android", "Android raises Linux-specific exception"
+ )
def test_unsupported_platform_error(self):
with self.assertRaises(RuntimeError) as cm:
RemoteUnwinder(os.getpid())
self.assertIn(
"Reading the PyRuntime section is not supported on this platform",
- str(cm.exception)
+ str(cm.exception),
)
-class TestDetectionOfThreadStatus(unittest.TestCase):
- @unittest.skipIf(
- sys.platform not in ("linux", "darwin", "win32"),
- "Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
- )
- @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception")
- def test_thread_status_detection(self):
+
+class TestDetectionOfThreadStatus(RemoteInspectionTestBase):
+ def _run_thread_status_test(self, mode, check_condition):
+ """
+ Common pattern for thread status detection tests.
+
+ Args:
+ mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.)
+ check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool
+ """
port = find_unused_port()
script = textwrap.dedent(
f"""\
sock.close()
"""
)
+
with os_helper.temp_dir() as work_dir:
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
- script_name = _make_test_script(script_dir, "thread_status_script", script)
+ server_socket = _create_server_socket(port)
+ script_name = _make_test_script(
+ script_dir, "thread_status_script", script
+ )
client_socket = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = b""
- sleeper_tid = None
- busy_tid = None
- while True:
- chunk = client_socket.recv(1024)
- response += chunk
- if b"ready:main" in response and b"ready:sleeper" in response and b"ready:busy" in response:
- # Parse TIDs from the response
- for line in response.split(b"\n"):
- if line.startswith(b"ready:sleeper:"):
- try:
- sleeper_tid = int(line.split(b":")[-1])
- except Exception:
- pass
- elif line.startswith(b"ready:busy:"):
- try:
- busy_tid = int(line.split(b":")[-1])
- except Exception:
- pass
- break
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
- attempts = 10
- statuses = {}
- try:
- unwinder = RemoteUnwinder(p.pid, all_threads=True, mode=PROFILING_MODE_CPU,
- skip_non_matching_threads=False)
- for _ in range(attempts):
- traces = unwinder.get_stack_trace()
- # Find threads and their statuses
- statuses = {}
- for interpreter_info in traces:
- for thread_info in interpreter_info.threads:
- statuses[thread_info.thread_id] = thread_info.status
-
- # Check if sleeper thread is off CPU and busy thread is on CPU
- # In the new flags system:
- # - sleeper should NOT have ON_CPU flag (off CPU)
- # - busy should have ON_CPU flag
- if (sleeper_tid in statuses and
- busy_tid in statuses and
- not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU) and
- (statuses[busy_tid] & THREAD_STATUS_ON_CPU)):
- break
- time.sleep(0.5) # Give a bit of time to let threads settle
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
+ # Wait for all ready signals and parse TIDs
+ response = _wait_for_signal(
+ client_socket,
+ [b"ready:main", b"ready:sleeper", b"ready:busy"],
+ )
+
+ sleeper_tid = None
+ busy_tid = None
+ for line in response.split(b"\n"):
+ if line.startswith(b"ready:sleeper:"):
+ try:
+ sleeper_tid = int(line.split(b":")[-1])
+ except (ValueError, IndexError):
+ pass
+ elif line.startswith(b"ready:busy:"):
+ try:
+ busy_tid = int(line.split(b":")[-1])
+ except (ValueError, IndexError):
+ pass
+
+ self.assertIsNotNone(
+ sleeper_tid, "Sleeper thread id not received"
)
+ self.assertIsNotNone(
+ busy_tid, "Busy thread id not received"
+ )
+
+ # Sample until we see expected thread states
+ statuses = {}
+ try:
+ unwinder = RemoteUnwinder(
+ p.pid,
+ all_threads=True,
+ mode=mode,
+ skip_non_matching_threads=False,
+ )
+ for _ in range(MAX_TRIES):
+ traces = unwinder.get_stack_trace()
+ statuses = self._get_thread_statuses(traces)
- self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received")
- self.assertIsNotNone(busy_tid, "Busy thread id not received")
- self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads")
- self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads")
- self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, "Sleeper thread should be off CPU")
- self.assertTrue(statuses[busy_tid] & THREAD_STATUS_ON_CPU, "Busy thread should be on CPU")
+ if check_condition(
+ statuses, sleeper_tid, busy_tid
+ ):
+ break
+ time.sleep(0.5)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+ return statuses, sleeper_tid, busy_tid
finally:
- if client_socket is not None:
- client_socket.close()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
+ _cleanup_sockets(client_socket, server_socket)
@unittest.skipIf(
sys.platform not in ("linux", "darwin", "win32"),
- "Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
+ "Test only runs on supported platforms (Linux, macOS, or Windows)",
)
- @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception")
- def test_thread_status_gil_detection(self):
- port = find_unused_port()
- script = textwrap.dedent(
- f"""\
- import time, sys, socket, threading
- import os
-
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.connect(('localhost', {port}))
-
- def sleeper():
- tid = threading.get_native_id()
- sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode())
- time.sleep(10000)
-
- def busy():
- tid = threading.get_native_id()
- sock.sendall(f'ready:busy:{{tid}}\\n'.encode())
- x = 0
- while True:
- x = x + 1
- time.sleep(0.5)
+ @unittest.skipIf(
+ sys.platform == "android", "Android raises Linux-specific exception"
+ )
+ def test_thread_status_detection(self):
+ def check_cpu_status(statuses, sleeper_tid, busy_tid):
+ return (
+ sleeper_tid in statuses
+ and busy_tid in statuses
+ and not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU)
+ and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
+ )
- t1 = threading.Thread(target=sleeper)
- t2 = threading.Thread(target=busy)
- t1.start()
- t2.start()
- sock.sendall(b'ready:main\\n')
- t1.join()
- t2.join()
- sock.close()
- """
+ statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
+ PROFILING_MODE_CPU, check_cpu_status
)
- with os_helper.temp_dir() as work_dir:
- script_dir = os.path.join(work_dir, "script_pkg")
- os.mkdir(script_dir)
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
- script_name = _make_test_script(script_dir, "thread_status_script", script)
- client_socket = None
- try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
- response = b""
- sleeper_tid = None
- busy_tid = None
- while True:
- chunk = client_socket.recv(1024)
- response += chunk
- if b"ready:main" in response and b"ready:sleeper" in response and b"ready:busy" in response:
- # Parse TIDs from the response
- for line in response.split(b"\n"):
- if line.startswith(b"ready:sleeper:"):
- try:
- sleeper_tid = int(line.split(b":")[-1])
- except Exception:
- pass
- elif line.startswith(b"ready:busy:"):
- try:
- busy_tid = int(line.split(b":")[-1])
- except Exception:
- pass
- break
+ self.assertIn(sleeper_tid, statuses)
+ self.assertIn(busy_tid, statuses)
+ self.assertFalse(
+ statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
+ "Sleeper thread should be off CPU",
+ )
+ self.assertTrue(
+ statuses[busy_tid] & THREAD_STATUS_ON_CPU,
+ "Busy thread should be on CPU",
+ )
- attempts = 10
- statuses = {}
- try:
- unwinder = RemoteUnwinder(p.pid, all_threads=True, mode=PROFILING_MODE_GIL,
- skip_non_matching_threads=False)
- for _ in range(attempts):
- traces = unwinder.get_stack_trace()
- # Find threads and their statuses
- statuses = {}
- for interpreter_info in traces:
- for thread_info in interpreter_info.threads:
- statuses[thread_info.thread_id] = thread_info.status
-
- # Check if sleeper thread doesn't have GIL and busy thread has GIL
- # In the new flags system:
- # - sleeper should NOT have HAS_GIL flag (waiting for GIL)
- # - busy should have HAS_GIL flag
- if (sleeper_tid in statuses and
- busy_tid in statuses and
- not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL) and
- (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)):
- break
- time.sleep(0.5) # Give a bit of time to let threads settle
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
+ @unittest.skipIf(
+ sys.platform not in ("linux", "darwin", "win32"),
+ "Test only runs on supported platforms (Linux, macOS, or Windows)",
+ )
+ @unittest.skipIf(
+ sys.platform == "android", "Android raises Linux-specific exception"
+ )
+ def test_thread_status_gil_detection(self):
+ def check_gil_status(statuses, sleeper_tid, busy_tid):
+ return (
+ sleeper_tid in statuses
+ and busy_tid in statuses
+ and not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL)
+ and (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)
+ )
- self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received")
- self.assertIsNotNone(busy_tid, "Busy thread id not received")
- self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads")
- self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads")
- self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, "Sleeper thread should not have GIL")
- self.assertTrue(statuses[busy_tid] & THREAD_STATUS_HAS_GIL, "Busy thread should have GIL")
+ statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
+ PROFILING_MODE_GIL, check_gil_status
+ )
- finally:
- if client_socket is not None:
- client_socket.close()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
+ self.assertIn(sleeper_tid, statuses)
+ self.assertIn(busy_tid, statuses)
+ self.assertFalse(
+ statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
+ "Sleeper thread should not have GIL",
+ )
+ self.assertTrue(
+ statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
+ "Busy thread should have GIL",
+ )
@unittest.skipIf(
sys.platform not in ("linux", "darwin", "win32"),
"Test only runs on supported platforms (Linux, macOS, or Windows)",
)
- @unittest.skipIf(sys.platform == "android", "Android raises Linux-specific exception")
+ @unittest.skipIf(
+ sys.platform == "android", "Android raises Linux-specific exception"
+ )
def test_thread_status_all_mode_detection(self):
port = find_unused_port()
script = textwrap.dedent(
with os_helper.temp_dir() as tmp_dir:
script_file = make_script(tmp_dir, "script", script)
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.listen(2)
- server_socket.settimeout(SHORT_TIMEOUT)
-
- p = subprocess.Popen(
- [sys.executable, script_file],
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
-
+ server_socket = _create_server_socket(port, backlog=2)
client_sockets = []
- try:
- sleeper_tid = None
- busy_tid = None
- # Receive thread IDs from the child process
- for _ in range(2):
- client_socket, _ = server_socket.accept()
- client_sockets.append(client_socket)
- line = client_socket.recv(1024)
- if line:
- if line.startswith(b"sleeper:"):
- try:
- sleeper_tid = int(line.split(b":")[-1])
- except Exception:
- pass
- elif line.startswith(b"busy:"):
- try:
- busy_tid = int(line.split(b":")[-1])
- except Exception:
- pass
+ try:
+ with _managed_subprocess(
+ [sys.executable, script_file],
+ ) as p:
+ sleeper_tid = None
+ busy_tid = None
+
+ # Receive thread IDs from the child process
+ for _ in range(2):
+ client_socket, _ = server_socket.accept()
+ client_sockets.append(client_socket)
+ line = client_socket.recv(1024)
+ if line:
+ if line.startswith(b"sleeper:"):
+ try:
+ sleeper_tid = int(line.split(b":")[-1])
+ except (ValueError, IndexError):
+ pass
+ elif line.startswith(b"busy:"):
+ try:
+ busy_tid = int(line.split(b":")[-1])
+ except (ValueError, IndexError):
+ pass
- server_socket.close()
+ server_socket.close()
+ server_socket = None
- attempts = 10
- statuses = {}
- try:
- unwinder = RemoteUnwinder(p.pid, all_threads=True, mode=PROFILING_MODE_ALL,
- skip_non_matching_threads=False)
- for _ in range(attempts):
- traces = unwinder.get_stack_trace()
- # Find threads and their statuses
- statuses = {}
- for interpreter_info in traces:
- for thread_info in interpreter_info.threads:
- statuses[thread_info.thread_id] = thread_info.status
-
- # Check ALL mode provides both GIL and CPU info
- # - sleeper should NOT have ON_CPU and NOT have HAS_GIL
- # - busy should have ON_CPU and have HAS_GIL
- if (sleeper_tid in statuses and
- busy_tid in statuses and
- not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU) and
- not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL) and
- (statuses[busy_tid] & THREAD_STATUS_ON_CPU) and
- (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)):
- break
- time.sleep(0.5)
- except PermissionError:
- self.skipTest(
- "Insufficient permissions to read the stack trace"
- )
+ statuses = {}
+ try:
+ unwinder = RemoteUnwinder(
+ p.pid,
+ all_threads=True,
+ mode=PROFILING_MODE_ALL,
+ skip_non_matching_threads=False,
+ )
+ for _ in range(MAX_TRIES):
+ traces = unwinder.get_stack_trace()
+ statuses = self._get_thread_statuses(traces)
- self.assertIsNotNone(sleeper_tid, "Sleeper thread id not received")
- self.assertIsNotNone(busy_tid, "Busy thread id not received")
- self.assertIn(sleeper_tid, statuses, "Sleeper tid not found in sampled threads")
- self.assertIn(busy_tid, statuses, "Busy tid not found in sampled threads")
+ # Check ALL mode provides both GIL and CPU info
+ if (
+ sleeper_tid in statuses
+ and busy_tid in statuses
+ and not (
+ statuses[sleeper_tid]
+ & THREAD_STATUS_ON_CPU
+ )
+ and not (
+ statuses[sleeper_tid]
+ & THREAD_STATUS_HAS_GIL
+ )
+ and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
+ and (
+ statuses[busy_tid] & THREAD_STATUS_HAS_GIL
+ )
+ ):
+ break
+ time.sleep(0.5)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
- # Sleeper thread: off CPU, no GIL
- self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_ON_CPU, "Sleeper should be off CPU")
- self.assertFalse(statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL, "Sleeper should not have GIL")
+ self.assertIsNotNone(
+ sleeper_tid, "Sleeper thread id not received"
+ )
+ self.assertIsNotNone(
+ busy_tid, "Busy thread id not received"
+ )
+ self.assertIn(sleeper_tid, statuses)
+ self.assertIn(busy_tid, statuses)
- # Busy thread: on CPU, has GIL
- self.assertTrue(statuses[busy_tid] & THREAD_STATUS_ON_CPU, "Busy should be on CPU")
- self.assertTrue(statuses[busy_tid] & THREAD_STATUS_HAS_GIL, "Busy should have GIL")
+ # Sleeper: off CPU, no GIL
+ self.assertFalse(
+ statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
+ "Sleeper should be off CPU",
+ )
+ self.assertFalse(
+ statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
+ "Sleeper should not have GIL",
+ )
+ # Busy: on CPU, has GIL
+ self.assertTrue(
+ statuses[busy_tid] & THREAD_STATUS_ON_CPU,
+ "Busy should be on CPU",
+ )
+ self.assertTrue(
+ statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
+ "Busy should have GIL",
+ )
finally:
- for client_socket in client_sockets:
- client_socket.close()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
- p.stdout.close()
- p.stderr.close()
+ _cleanup_sockets(*client_sockets, server_socket)
-class TestFrameCaching(unittest.TestCase):
+class TestFrameCaching(RemoteInspectionTestBase):
"""Test that frame caching produces correct results.
Uses socket-based synchronization for deterministic testing.
All tests verify cache reuse via object identity checks (assertIs).
"""
- maxDiff = None
- MAX_TRIES = 10
-
- @contextlib.contextmanager
+ @contextmanager
def _target_process(self, script_body):
"""Context manager for running a target process with socket sync."""
port = find_unused_port()
script_dir = os.path.join(work_dir, "script_pkg")
os.mkdir(script_dir)
- server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- server_socket.bind(("localhost", port))
- server_socket.settimeout(SHORT_TIMEOUT)
- server_socket.listen(1)
-
+ server_socket = _create_server_socket(port)
script_name = _make_test_script(script_dir, "script", script)
client_socket = None
- p = None
+
try:
- p = subprocess.Popen([sys.executable, script_name])
- client_socket, _ = server_socket.accept()
- server_socket.close()
+ with _managed_subprocess([sys.executable, script_name]) as p:
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ server_socket = None
- def make_unwinder(cache_frames=True):
- return RemoteUnwinder(p.pid, all_threads=True, cache_frames=cache_frames)
+ def make_unwinder(cache_frames=True):
+ return RemoteUnwinder(
+ p.pid, all_threads=True, cache_frames=cache_frames
+ )
- yield p, client_socket, make_unwinder
+ yield p, client_socket, make_unwinder
except PermissionError:
- self.skipTest("Insufficient permissions to read the stack trace")
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
finally:
- if client_socket:
- client_socket.close()
- if p:
- p.kill()
- p.terminate()
- p.wait(timeout=SHORT_TIMEOUT)
-
- def _wait_for_signal(self, client_socket, signal):
- """Block until signal received from target."""
- response = b""
- while signal not in response:
- chunk = client_socket.recv(64)
- if not chunk:
- break
- response += chunk
- return response
-
- def _get_frames(self, unwinder, required_funcs):
- """Sample and return frame_info list for thread containing required_funcs."""
- traces = unwinder.get_stack_trace()
- for interp in traces:
- for thread in interp.threads:
- funcs = [f.funcname for f in thread.frame_info]
- if required_funcs.issubset(set(funcs)):
- return thread.frame_info
+ _cleanup_sockets(client_socket, server_socket)
+
+ def _get_frames_with_retry(self, unwinder, required_funcs):
+ """Get frames containing required_funcs, with retry for transient errors."""
+ for _ in range(MAX_TRIES):
+ try:
+ traces = unwinder.get_stack_trace()
+ for interp in traces:
+ for thread in interp.threads:
+ funcs = {f.funcname for f in thread.frame_info}
+ if required_funcs.issubset(funcs):
+ return thread.frame_info
+ except RuntimeError as e:
+ if _is_retriable_error(e):
+ pass
+ else:
+ raise
+ time.sleep(0.1)
return None
- def _sample_frames(self, client_socket, unwinder, wait_signal, send_ack, required_funcs, expected_frames=1):
- """Wait for signal, sample frames, send ack. Returns frame_info list."""
- self._wait_for_signal(client_socket, wait_signal)
- # Give at least MAX_TRIES tries for the process to arrive to a steady state
- for _ in range(self.MAX_TRIES):
- frames = self._get_frames(unwinder, required_funcs)
+ def _sample_frames(
+ self,
+ client_socket,
+ unwinder,
+ wait_signal,
+ send_ack,
+ required_funcs,
+ expected_frames=1,
+ ):
+ """Wait for signal, sample frames with retry until required funcs present, send ack."""
+ _wait_for_signal(client_socket, wait_signal)
+ frames = None
+ for _ in range(MAX_TRIES):
+ frames = self._get_frames_with_retry(unwinder, required_funcs)
if frames and len(frames) >= expected_frames:
break
time.sleep(0.1)
level1()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
expected = {"level1", "level2", "level3"}
- frames1 = self._sample_frames(client_socket, unwinder, b"sync1", b"ack", expected)
- frames2 = self._sample_frames(client_socket, unwinder, b"sync2", b"ack", expected)
- frames3 = self._sample_frames(client_socket, unwinder, b"sync3", b"done", expected)
+ frames1 = self._sample_frames(
+ client_socket, unwinder, b"sync1", b"ack", expected
+ )
+ frames2 = self._sample_frames(
+ client_socket, unwinder, b"sync2", b"ack", expected
+ )
+ frames3 = self._sample_frames(
+ client_socket, unwinder, b"sync3", b"done", expected
+ )
self.assertIsNotNone(frames1)
self.assertIsNotNone(frames2)
# Parent frames (index 1+) must be identical objects (cache reuse)
for i in range(1, len(frames1)):
f1, f2, f3 = frames1[i], frames2[i], frames3[i]
- self.assertIs(f1, f2, f"Frame {i}: samples 1-2 must be same object")
- self.assertIs(f2, f3, f"Frame {i}: samples 2-3 must be same object")
+ self.assertIs(
+ f1, f2, f"Frame {i}: samples 1-2 must be same object"
+ )
+ self.assertIs(
+ f2, f3, f"Frame {i}: samples 2-3 must be same object"
+ )
@skip_if_not_supported
@unittest.skipIf(
outer()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
- frames_a = self._sample_frames(client_socket, unwinder, b"line_a", b"ack", {"inner"})
- frames_b = self._sample_frames(client_socket, unwinder, b"line_b", b"ack", {"inner"})
- frames_c = self._sample_frames(client_socket, unwinder, b"line_c", b"ack", {"inner"})
- frames_d = self._sample_frames(client_socket, unwinder, b"line_d", b"done", {"inner"})
+ frames_a = self._sample_frames(
+ client_socket, unwinder, b"line_a", b"ack", {"inner"}
+ )
+ frames_b = self._sample_frames(
+ client_socket, unwinder, b"line_b", b"ack", {"inner"}
+ )
+ frames_c = self._sample_frames(
+ client_socket, unwinder, b"line_c", b"ack", {"inner"}
+ )
+ frames_d = self._sample_frames(
+ client_socket, unwinder, b"line_d", b"done", {"inner"}
+ )
self.assertIsNotNone(frames_a)
self.assertIsNotNone(frames_b)
self.assertEqual(inner_d.funcname, "inner")
# Line numbers must be different and increasing (execution moves forward)
- self.assertLess(inner_a.lineno, inner_b.lineno,
- "Line B should be after line A")
- self.assertLess(inner_b.lineno, inner_c.lineno,
- "Line C should be after line B")
- self.assertLess(inner_c.lineno, inner_d.lineno,
- "Line D should be after line C")
+ self.assertLess(
+ inner_a.lineno, inner_b.lineno, "Line B should be after line A"
+ )
+ self.assertLess(
+ inner_b.lineno, inner_c.lineno, "Line C should be after line B"
+ )
+ self.assertLess(
+ inner_c.lineno, inner_d.lineno, "Line D should be after line C"
+ )
@skip_if_not_supported
@unittest.skipIf(
outer()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
frames_deep = self._sample_frames(
- client_socket, unwinder, b"at_inner", b"ack", {"inner", "outer"})
+ client_socket,
+ unwinder,
+ b"at_inner",
+ b"ack",
+ {"inner", "outer"},
+ )
frames_shallow = self._sample_frames(
- client_socket, unwinder, b"at_outer", b"done", {"outer"})
+ client_socket, unwinder, b"at_outer", b"done", {"outer"}
+ )
self.assertIsNotNone(frames_deep)
self.assertIsNotNone(frames_shallow)
top()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
frames_before = self._sample_frames(
- client_socket, unwinder, b"at_middle", b"ack", {"middle", "top"})
+ client_socket,
+ unwinder,
+ b"at_middle",
+ b"ack",
+ {"middle", "top"},
+ )
frames_after = self._sample_frames(
- client_socket, unwinder, b"at_deeper", b"done", {"deeper", "middle", "top"})
+ client_socket,
+ unwinder,
+ b"at_deeper",
+ b"done",
+ {"deeper", "middle", "top"},
+ )
self.assertIsNotNone(frames_before)
self.assertIsNotNone(frames_after)
func_a()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
# Sample at C: stack is A→B→C
frames_c = self._sample_frames(
- client_socket, unwinder, b"at_c", b"ack", {"func_a", "func_b", "func_c"})
+ client_socket,
+ unwinder,
+ b"at_c",
+ b"ack",
+ {"func_a", "func_b", "func_c"},
+ )
# Sample at D: stack is A→B→D (C returned, D called)
frames_d = self._sample_frames(
- client_socket, unwinder, b"at_d", b"done", {"func_a", "func_b", "func_d"})
+ client_socket,
+ unwinder,
+ b"at_d",
+ b"done",
+ {"func_a", "func_b", "func_d"},
+ )
self.assertIsNotNone(frames_c)
self.assertIsNotNone(frames_d)
self.assertIsNotNone(frame_b_in_d)
# The bottom frames (A, B) should be the SAME objects (cache reuse)
- self.assertIs(frame_a_in_c, frame_a_in_d, "func_a frame should be reused from cache")
- self.assertIs(frame_b_in_c, frame_b_in_d, "func_b frame should be reused from cache")
+ self.assertIs(
+ frame_a_in_c,
+ frame_a_in_d,
+ "func_a frame should be reused from cache",
+ )
+ self.assertIs(
+ frame_b_in_c,
+ frame_b_in_d,
+ "func_b frame should be reused from cache",
+ )
@skip_if_not_supported
@unittest.skipIf(
recurse(5)
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
frames1 = self._sample_frames(
- client_socket, unwinder, b"sync1", b"ack", {"recurse"})
+ client_socket, unwinder, b"sync1", b"ack", {"recurse"}
+ )
frames2 = self._sample_frames(
- client_socket, unwinder, b"sync2", b"done", {"recurse"})
+ client_socket, unwinder, b"sync2", b"done", {"recurse"}
+ )
self.assertIsNotNone(frames1)
self.assertIsNotNone(frames2)
# Parent frames (index 1+) should be identical objects (cache reuse)
for i in range(1, len(frames1)):
- self.assertIs(frames1[i], frames2[i],
- f"Frame {i}: recursive frames must be same object")
+ self.assertIs(
+ frames1[i],
+ frames2[i],
+ f"Frame {i}: recursive frames must be same object",
+ )
@skip_if_not_supported
@unittest.skipIf(
level1()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
- self._wait_for_signal(client_socket, b"ready")
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
+ _wait_for_signal(client_socket, b"ready")
# Sample with cache
unwinder_cache = make_unwinder(cache_frames=True)
- frames_cached = self._get_frames(unwinder_cache, {"level1", "level2", "level3"})
+ frames_cached = self._get_frames_with_retry(
+ unwinder_cache, {"level1", "level2", "level3"}
+ )
# Sample without cache
unwinder_no_cache = make_unwinder(cache_frames=False)
- frames_no_cache = self._get_frames(unwinder_no_cache, {"level1", "level2", "level3"})
+ frames_no_cache = self._get_frames_with_retry(
+ unwinder_no_cache, {"level1", "level2", "level3"}
+ )
client_socket.sendall(b"done")
t2.join()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder = make_unwinder(cache_frames=True)
buffer = b""
# Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated)
self.assertIn("blech1", t1_blech)
self.assertIn("foo1", t1_blech)
- self.assertNotIn("bar1", t1_blech, "Cache not invalidated: bar1 still present")
- self.assertNotIn("baz1", t1_blech, "Cache not invalidated: baz1 still present")
+ self.assertNotIn(
+ "bar1", t1_blech, "Cache not invalidated: bar1 still present"
+ )
+ self.assertNotIn(
+ "baz1", t1_blech, "Cache not invalidated: baz1 still present"
+ )
# No cross-contamination
self.assertNotIn("blech2", t1_blech)
# Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated)
self.assertIn("blech2", t2_blech)
self.assertIn("foo2", t2_blech)
- self.assertNotIn("bar2", t2_blech, "Cache not invalidated: bar2 still present")
- self.assertNotIn("baz2", t2_blech, "Cache not invalidated: baz2 still present")
+ self.assertNotIn(
+ "bar2", t2_blech, "Cache not invalidated: bar2 still present"
+ )
+ self.assertNotIn(
+ "baz2", t2_blech, "Cache not invalidated: baz2 still present"
+ )
# No cross-contamination
self.assertNotIn("blech1", t2_blech)
level1()
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
expected = {"level1", "level2", "level3", "level4"}
# First unwinder samples - this sets last_profiled_frame in target
unwinder1 = make_unwinder(cache_frames=True)
- frames1 = self._sample_frames(client_socket, unwinder1, b"sync1", b"ack", expected)
+ frames1 = self._sample_frames(
+ client_socket, unwinder1, b"sync1", b"ack", expected
+ )
# Create NEW unwinder (empty cache) and sample
# The target still has last_profiled_frame set from unwinder1
unwinder2 = make_unwinder(cache_frames=True)
- frames2 = self._sample_frames(client_socket, unwinder2, b"sync2", b"done", expected)
+ frames2 = self._sample_frames(
+ client_socket, unwinder2, b"sync2", b"done", expected
+ )
self.assertIsNotNone(frames1)
self.assertIsNotNone(frames2)
self.assertIn(level, funcs2, f"{level} missing from second sample")
# Should have same stack depth
- self.assertEqual(len(frames1), len(frames2),
- "New unwinder should return complete stack despite stale last_profiled_frame")
+ self.assertEqual(
+ len(frames1),
+ len(frames2),
+ "New unwinder should return complete stack despite stale last_profiled_frame",
+ )
@skip_if_not_supported
@unittest.skipIf(
recurse({depth})
"""
- with self._target_process(script_body) as (p, client_socket, make_unwinder):
+ with self._target_process(script_body) as (
+ p,
+ client_socket,
+ make_unwinder,
+ ):
unwinder_cache = make_unwinder(cache_frames=True)
unwinder_no_cache = make_unwinder(cache_frames=False)
frames_cached = self._sample_frames(
- client_socket, unwinder_cache, b"ready", b"ack", {"recurse"}, expected_frames=1102
+ client_socket,
+ unwinder_cache,
+ b"ready",
+ b"ack",
+ {"recurse"},
+ expected_frames=1102,
)
# Sample again with no cache for comparison
frames_no_cache = self._sample_frames(
- client_socket, unwinder_no_cache, b"ready2", b"done", {"recurse"}, expected_frames=1102
+ client_socket,
+ unwinder_no_cache,
+ b"ready2",
+ b"done",
+ {"recurse"},
+ expected_frames=1102,
)
self.assertIsNotNone(frames_cached)
cached_count = [f.funcname for f in frames_cached].count("recurse")
no_cache_count = [f.funcname for f in frames_no_cache].count("recurse")
- self.assertGreater(cached_count, 1000, "Should have >1000 recurse frames")
- self.assertGreater(no_cache_count, 1000, "Should have >1000 recurse frames")
+ self.assertGreater(
+ cached_count, 1000, "Should have >1000 recurse frames"
+ )
+ self.assertGreater(
+ no_cache_count, 1000, "Should have >1000 recurse frames"
+ )
# Both modes should produce same frame count
- self.assertEqual(len(frames_cached), len(frames_no_cache),
- "Cache exhaustion should not affect stack completeness")
+ self.assertEqual(
+ len(frames_cached),
+ len(frames_no_cache),
+ "Cache exhaustion should not affect stack completeness",
+ )
@skip_if_not_supported
@unittest.skipIf(
with self._target_process(script_body) as (p, client_socket, _):
unwinder = RemoteUnwinder(p.pid, all_threads=True, stats=True)
- self._wait_for_signal(client_socket, b"ready")
+ _wait_for_signal(client_socket, b"ready")
# Take a sample
unwinder.get_stack_trace()
# Verify expected keys exist
expected_keys = [
- 'total_samples', 'frame_cache_hits', 'frame_cache_misses',
- 'frame_cache_partial_hits', 'frames_read_from_cache',
- 'frames_read_from_memory', 'frame_cache_hit_rate'
+ "total_samples",
+ "frame_cache_hits",
+ "frame_cache_misses",
+ "frame_cache_partial_hits",
+ "frames_read_from_cache",
+ "frames_read_from_memory",
+ "frame_cache_hit_rate",
]
for key in expected_keys:
self.assertIn(key, stats)
- self.assertEqual(stats['total_samples'], 1)
+ self.assertEqual(stats["total_samples"], 1)
@skip_if_not_supported
@unittest.skipIf(
"""
with self._target_process(script_body) as (p, client_socket, _):
- unwinder = RemoteUnwinder(p.pid, all_threads=True) # stats=False by default
- self._wait_for_signal(client_socket, b"ready")
+ unwinder = RemoteUnwinder(
+ p.pid, all_threads=True
+ ) # stats=False by default
+ _wait_for_signal(client_socket, b"ready")
with self.assertRaises(RuntimeError):
unwinder.get_stats()