--- /dev/null
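+"""Sample a running process's stack frames and generate profiling data.
+
+Two output formats are supported: pstats tables with sorting options, and
+collapsed stack traces for generating flamegraphs. Typical invocation:
+
+    python -m profile.sample 1234
+"""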
+import argparse
+import _remote_debugging
+import os
+import pstats
+import statistics
+import sys
+import sysconfig
+import time
+from collections import deque
+from _colorize import ANSIColors
+
+from .pstats_collector import PstatsCollector
+from .stack_collector import CollapsedStackCollector
+
+FREE_THREADED_BUILD = sysconfig.get_config_var("Py_GIL_DISABLED") is not None
+
+class SampleProfiler:
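+    """Periodically capture stack traces from another running process.
+
+    Unwinding is done with _remote_debugging.RemoteUnwinder, so the target
+    process keeps running; each captured stack is handed to a collector
+    such as PstatsCollector or CollapsedStackCollector.
+    """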
+ def __init__(self, pid, sample_interval_usec, all_threads):
+ self.pid = pid
+ self.sample_interval_usec = sample_interval_usec
+ self.all_threads = all_threads
+ if FREE_THREADED_BUILD:
+ self.unwinder = _remote_debugging.RemoteUnwinder(
+ self.pid, all_threads=self.all_threads
+ )
+ else:
+            only_active_thread = bool(self.all_threads)
+            self.unwinder = _remote_debugging.RemoteUnwinder(
+                self.pid, only_active_thread=only_active_thread
+ )
+ # Track sample intervals and total sample count
+ self.sample_intervals = deque(maxlen=100)
+ self.total_samples = 0
+ self.realtime_stats = False
+
+ def sample(self, collector, duration_sec=10):
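+        """Collect samples for duration_sec, polling in a busy loop.
+
+        Each time the next scheduled sample time passes, one stack trace is
+        captured and passed to the collector. The loop ends early if the
+        target process disappears; transient unwind failures are counted
+        and reported as an error rate.
+        """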
+ sample_interval_sec = self.sample_interval_usec / 1_000_000
+ running_time = 0
+ num_samples = 0
+ errors = 0
+ start_time = next_time = time.perf_counter()
+ last_sample_time = start_time
+ realtime_update_interval = 1.0 # Update every second
+ last_realtime_update = start_time
+
+ while running_time < duration_sec:
+ current_time = time.perf_counter()
+ if next_time < current_time:
+ try:
+ stack_frames = self.unwinder.get_stack_trace()
+ collector.collect(stack_frames)
+ except ProcessLookupError:
+ break
+ except (RuntimeError, UnicodeDecodeError, MemoryError, OSError):
+ errors += 1
+ except Exception as e:
+ if not self._is_process_running():
+ break
+ raise e from None
+
+ # Track actual sampling intervals for real-time stats
+ if num_samples > 0:
+ actual_interval = current_time - last_sample_time
+ self.sample_intervals.append(
+ 1.0 / actual_interval
+ ) # Convert to Hz
+ self.total_samples += 1
+
+ # Print real-time statistics if enabled
+ if (
+ self.realtime_stats
+ and (current_time - last_realtime_update)
+ >= realtime_update_interval
+ ):
+ self._print_realtime_stats()
+ last_realtime_update = current_time
+
+ last_sample_time = current_time
+ num_samples += 1
+ next_time += sample_interval_sec
+
+ running_time = time.perf_counter() - start_time
+
+ # Clear real-time stats line if it was being displayed
+ if self.realtime_stats and len(self.sample_intervals) > 0:
+ print() # Add newline after real-time stats
+
+        print(f"Captured {num_samples} samples in {running_time:.2f} seconds")
+        # Guard the rate reports: the target can exit before the first sample
+        if num_samples > 0 and running_time > 0:
+            print(f"Sample rate: {num_samples / running_time:.2f} samples/sec")
+            print(f"Error rate: {(errors / num_samples) * 100:.2f}%")
+
+ expected_samples = int(duration_sec / sample_interval_sec)
+ if num_samples < expected_samples:
+ print(
+ f"Warning: missed {expected_samples - num_samples} samples "
+ f"from the expected total of {expected_samples} "
+ f"({(expected_samples - num_samples) / expected_samples * 100:.2f}%)"
+ )
+
+ def _is_process_running(self):
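+        """Best-effort check that the target process is still alive.
+
+        On Linux and macOS, os.kill(pid, 0) probes the pid without
+        signaling it; on Windows, constructing a RemoteUnwinder serves
+        as the probe.
+        """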
+ if sys.platform == "linux" or sys.platform == "darwin":
+ try:
+ os.kill(self.pid, 0)
+ return True
+ except ProcessLookupError:
+ return False
+ elif sys.platform == "win32":
+ try:
+ _remote_debugging.RemoteUnwinder(self.pid)
+ except Exception:
+ return False
+ return True
+ else:
+ raise ValueError(f"Unsupported platform: {sys.platform}")
+
+ def _print_realtime_stats(self):
+ """Print real-time sampling statistics."""
+ if len(self.sample_intervals) < 2:
+ return
+
+ # Calculate statistics on the Hz values (deque automatically maintains rolling window)
+ hz_values = list(self.sample_intervals)
+ mean_hz = statistics.mean(hz_values)
+ min_hz = min(hz_values)
+ max_hz = max(hz_values)
+
+ # Calculate microseconds per sample for all metrics (1/Hz * 1,000,000)
+ mean_us_per_sample = (1.0 / mean_hz) * 1_000_000 if mean_hz > 0 else 0
+ min_us_per_sample = (
+ (1.0 / max_hz) * 1_000_000 if max_hz > 0 else 0
+ ) # Min time = Max Hz
+ max_us_per_sample = (
+ (1.0 / min_hz) * 1_000_000 if min_hz > 0 else 0
+ ) # Max time = Min Hz
+
+ # Clear line and print stats
+ print(
+ f"\r\033[K{ANSIColors.BOLD_BLUE}Real-time sampling stats:{ANSIColors.RESET} "
+ f"{ANSIColors.YELLOW}Mean: {mean_hz:.1f}Hz ({mean_us_per_sample:.2f}µs){ANSIColors.RESET} "
+ f"{ANSIColors.GREEN}Min: {min_hz:.1f}Hz ({max_us_per_sample:.2f}µs){ANSIColors.RESET} "
+ f"{ANSIColors.RED}Max: {max_hz:.1f}Hz ({min_us_per_sample:.2f}µs){ANSIColors.RESET} "
+ f"{ANSIColors.CYAN}Samples: {self.total_samples}{ANSIColors.RESET}",
+ end="",
+ flush=True,
+ )
+
+
+def _determine_best_unit(max_value):
+ """Determine the best unit (s, ms, μs) and scale factor for a maximum value."""
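+    # e.g. 2.0 -> ("s", 1.0); 0.5 -> ("ms", 1000.0); 0.0002 -> ("μs", 1000000.0)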
+ if max_value >= 1.0:
+ return "s", 1.0
+ elif max_value >= 0.001:
+ return "ms", 1000.0
+ else:
+ return "μs", 1000000.0
+
+
+def print_sampled_stats(
+ stats, sort=-1, limit=None, show_summary=True, sample_interval_usec=100
+):
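+    """Render sampled stats as an aligned, colorized table.
+
+    `sort` selects the key: -1 function name, 0 direct samples, 1 total
+    time, 2 cumulative time, 3 sample%, 4 cumul%, 5 cumulative samples.
+    `limit` caps the number of rows and `show_summary` appends aggregated
+    highlights (hot spots, indirect calls, call magnification).
+    """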
+ # Get the stats data
+ stats_list = []
+ for func, (
+ direct_calls,
+ cumulative_calls,
+ total_time,
+ cumulative_time,
+ callers,
+ ) in stats.stats.items():
+ stats_list.append(
+ (
+ func,
+ direct_calls,
+ cumulative_calls,
+ total_time,
+ cumulative_time,
+ callers,
+ )
+ )
+
+ # Calculate total samples for percentage calculations (using direct_calls)
+ total_samples = sum(
+ direct_calls for _, direct_calls, _, _, _, _ in stats_list
+ )
+
+ # Sort based on the requested field
+ sort_field = sort
+ if sort_field == -1: # stdname
+ stats_list.sort(key=lambda x: str(x[0]))
+ elif sort_field == 0: # nsamples (direct samples)
+ stats_list.sort(key=lambda x: x[1], reverse=True) # direct_calls
+ elif sort_field == 1: # tottime
+ stats_list.sort(key=lambda x: x[3], reverse=True) # total_time
+ elif sort_field == 2: # cumtime
+ stats_list.sort(key=lambda x: x[4], reverse=True) # cumulative_time
+ elif sort_field == 3: # sample%
+ stats_list.sort(
+ key=lambda x: (x[1] / total_samples * 100)
+ if total_samples > 0
+ else 0,
+ reverse=True, # direct_calls percentage
+ )
+ elif sort_field == 4: # cumul%
+ stats_list.sort(
+ key=lambda x: (x[2] / total_samples * 100)
+ if total_samples > 0
+ else 0,
+ reverse=True, # cumulative_calls percentage
+ )
+ elif sort_field == 5: # nsamples (cumulative samples)
+ stats_list.sort(key=lambda x: x[2], reverse=True) # cumulative_calls
+
+ # Apply limit if specified
+ if limit is not None:
+ stats_list = stats_list[:limit]
+
+ # Determine the best unit for time columns based on maximum values
+ max_total_time = max(
+ (total_time for _, _, _, total_time, _, _ in stats_list), default=0
+ )
+ max_cumulative_time = max(
+ (cumulative_time for _, _, _, _, cumulative_time, _ in stats_list),
+ default=0,
+ )
+
+ total_time_unit, total_time_scale = _determine_best_unit(max_total_time)
+ cumulative_time_unit, cumulative_time_scale = _determine_best_unit(
+ max_cumulative_time
+ )
+
+ # Define column widths for consistent alignment
+ col_widths = {
+        "nsamples": 15,  # "nsamples" column (direct/cumulative format)
+ "sample_pct": 8, # "sample%" column
+ "tottime": max(12, len(f"tottime ({total_time_unit})")),
+ "cum_pct": 8, # "cumul%" column
+ "cumtime": max(12, len(f"cumtime ({cumulative_time_unit})")),
+ }
+
+ # Print header with colors and proper alignment
+ print(f"{ANSIColors.BOLD_BLUE}Profile Stats:{ANSIColors.RESET}")
+
+ header_nsamples = f"{ANSIColors.BOLD_BLUE}{'nsamples':>{col_widths['nsamples']}}{ANSIColors.RESET}"
+ header_sample_pct = f"{ANSIColors.BOLD_BLUE}{'sample%':>{col_widths['sample_pct']}}{ANSIColors.RESET}"
+ header_tottime = f"{ANSIColors.BOLD_BLUE}{f'tottime ({total_time_unit})':>{col_widths['tottime']}}{ANSIColors.RESET}"
+ header_cum_pct = f"{ANSIColors.BOLD_BLUE}{'cumul%':>{col_widths['cum_pct']}}{ANSIColors.RESET}"
+ header_cumtime = f"{ANSIColors.BOLD_BLUE}{f'cumtime ({cumulative_time_unit})':>{col_widths['cumtime']}}{ANSIColors.RESET}"
+ header_filename = (
+ f"{ANSIColors.BOLD_BLUE}filename:lineno(function){ANSIColors.RESET}"
+ )
+
+ print(
+ f"{header_nsamples} {header_sample_pct} {header_tottime} {header_cum_pct} {header_cumtime} {header_filename}"
+ )
+
+ # Print each line with proper alignment
+ for (
+ func,
+ direct_calls,
+ cumulative_calls,
+ total_time,
+ cumulative_time,
+ callers,
+ ) in stats_list:
+ # Calculate percentages
+ sample_pct = (
+ (direct_calls / total_samples * 100) if total_samples > 0 else 0
+ )
+ cum_pct = (
+ (cumulative_calls / total_samples * 100)
+ if total_samples > 0
+ else 0
+ )
+
+ # Format values with proper alignment - always use A/B format
+ nsamples_str = f"{direct_calls}/{cumulative_calls}"
+ nsamples_str = f"{nsamples_str:>{col_widths['nsamples']}}"
+ sample_pct_str = f"{sample_pct:{col_widths['sample_pct']}.1f}"
+ tottime = f"{total_time * total_time_scale:{col_widths['tottime']}.3f}"
+ cum_pct_str = f"{cum_pct:{col_widths['cum_pct']}.1f}"
+ cumtime = f"{cumulative_time * cumulative_time_scale:{col_widths['cumtime']}.3f}"
+
+ # Format the function name with colors
+ func_name = (
+ f"{ANSIColors.GREEN}{func[0]}{ANSIColors.RESET}:"
+ f"{ANSIColors.YELLOW}{func[1]}{ANSIColors.RESET}("
+ f"{ANSIColors.CYAN}{func[2]}{ANSIColors.RESET})"
+ )
+
+ # Print the formatted line with consistent spacing
+ print(
+ f"{nsamples_str} {sample_pct_str} {tottime} {cum_pct_str} {cumtime} {func_name}"
+ )
+
+ # Print legend
+ print(f"\n{ANSIColors.BOLD_BLUE}Legend:{ANSIColors.RESET}")
+ print(
+ f" {ANSIColors.YELLOW}nsamples{ANSIColors.RESET}: Direct/Cumulative samples (direct executing / on call stack)"
+ )
+ print(
+ f" {ANSIColors.YELLOW}sample%{ANSIColors.RESET}: Percentage of total samples this function was directly executing"
+ )
+ print(
+ f" {ANSIColors.YELLOW}tottime{ANSIColors.RESET}: Estimated total time spent directly in this function"
+ )
+ print(
+ f" {ANSIColors.YELLOW}cumul%{ANSIColors.RESET}: Percentage of total samples when this function was on the call stack"
+ )
+ print(
+ f" {ANSIColors.YELLOW}cumtime{ANSIColors.RESET}: Estimated cumulative time (including time in called functions)"
+ )
+ print(
+ f" {ANSIColors.YELLOW}filename:lineno(function){ANSIColors.RESET}: Function location and name"
+ )
+
+ def _format_func_name(func):
+ """Format function name with colors."""
+ return (
+ f"{ANSIColors.GREEN}{func[0]}{ANSIColors.RESET}:"
+ f"{ANSIColors.YELLOW}{func[1]}{ANSIColors.RESET}("
+ f"{ANSIColors.CYAN}{func[2]}{ANSIColors.RESET})"
+ )
+
+ def _print_top_functions(stats_list, title, key_func, format_line, n=3):
+ """Print top N functions sorted by key_func with formatted output."""
+ print(f"\n{ANSIColors.BOLD_BLUE}{title}:{ANSIColors.RESET}")
+ sorted_stats = sorted(stats_list, key=key_func, reverse=True)
+ for stat in sorted_stats[:n]:
+ if line := format_line(stat):
+ print(f" {line}")
+
+ # Print summary of interesting functions if enabled
+ if show_summary and stats_list:
+ print(
+ f"\n{ANSIColors.BOLD_BLUE}Summary of Interesting Functions:{ANSIColors.RESET}"
+ )
+
+ # Aggregate stats by fully qualified function name (ignoring line numbers)
+ func_aggregated = {}
+ for (
+ func,
+ direct_calls,
+ cumulative_calls,
+ total_time,
+ cumulative_time,
+ callers,
+ ) in stats_list:
+ # Use filename:function_name as the key to get fully qualified name
+ qualified_name = f"{func[0]}:{func[2]}"
+ if qualified_name not in func_aggregated:
+ func_aggregated[qualified_name] = [
+ 0,
+ 0,
+ 0,
+ 0,
+ ] # direct_calls, cumulative_calls, total_time, cumulative_time
+ func_aggregated[qualified_name][0] += direct_calls
+ func_aggregated[qualified_name][1] += cumulative_calls
+ func_aggregated[qualified_name][2] += total_time
+ func_aggregated[qualified_name][3] += cumulative_time
+
+ # Convert aggregated data back to list format for processing
+ aggregated_stats = []
+ for qualified_name, (
+ prim_calls,
+ total_calls,
+ total_time,
+ cumulative_time,
+ ) in func_aggregated.items():
+ # Parse the qualified name back to filename and function name
+ if ":" in qualified_name:
+ filename, func_name = qualified_name.rsplit(":", 1)
+ else:
+ filename, func_name = "", qualified_name
+ # Create a dummy func tuple with filename and function name for display
+ dummy_func = (filename, "", func_name)
+ aggregated_stats.append(
+ (
+ dummy_func,
+ prim_calls,
+ total_calls,
+ total_time,
+ cumulative_time,
+ {},
+ )
+ )
+
+ # Determine best units for summary metrics
+ max_total_time = max(
+ (total_time for _, _, _, total_time, _, _ in aggregated_stats),
+ default=0,
+ )
+ max_cumulative_time = max(
+ (
+ cumulative_time
+ for _, _, _, _, cumulative_time, _ in aggregated_stats
+ ),
+ default=0,
+ )
+
+ total_unit, total_scale = _determine_best_unit(max_total_time)
+ cumulative_unit, cumulative_scale = _determine_best_unit(
+ max_cumulative_time
+ )
+
+ # Functions with highest direct/cumulative ratio (hot spots)
+ def format_hotspots(stat):
+ func, direct_calls, cumulative_calls, total_time, _, _ = stat
+ if direct_calls > 0 and cumulative_calls > 0:
+ ratio = direct_calls / cumulative_calls
+ direct_pct = (
+ (direct_calls / total_samples * 100)
+ if total_samples > 0
+ else 0
+ )
+ return (
+ f"{ratio:.3f} direct/cumulative ratio, "
+ f"{direct_pct:.1f}% direct samples: {_format_func_name(func)}"
+ )
+ return None
+
+ _print_top_functions(
+ aggregated_stats,
+ "Functions with Highest Direct/Cumulative Ratio (Hot Spots)",
+ key_func=lambda x: (x[1] / x[2]) if x[2] > 0 else 0,
+ format_line=format_hotspots,
+ )
+
+ # Functions with highest call frequency (cumulative/direct difference)
+ def format_call_frequency(stat):
+ func, direct_calls, cumulative_calls, total_time, _, _ = stat
+ if cumulative_calls > direct_calls:
+ call_frequency = cumulative_calls - direct_calls
+ cum_pct = (
+ (cumulative_calls / total_samples * 100)
+ if total_samples > 0
+ else 0
+ )
+ return (
+ f"{call_frequency:d} indirect calls, "
+ f"{cum_pct:.1f}% total stack presence: {_format_func_name(func)}"
+ )
+ return None
+
+ _print_top_functions(
+ aggregated_stats,
+ "Functions with Highest Call Frequency (Indirect Calls)",
+ key_func=lambda x: x[2] - x[1], # Sort by (cumulative - direct)
+ format_line=format_call_frequency,
+ )
+
+ # Functions with highest cumulative-to-direct multiplier (call magnification)
+ def format_call_magnification(stat):
+ func, direct_calls, cumulative_calls, total_time, _, _ = stat
+ if direct_calls > 0 and cumulative_calls > direct_calls:
+ multiplier = cumulative_calls / direct_calls
+ indirect_calls = cumulative_calls - direct_calls
+ return (
+ f"{multiplier:.1f}x call magnification, "
+ f"{indirect_calls:d} indirect calls from {direct_calls:d} direct: {_format_func_name(func)}"
+ )
+ return None
+
+ _print_top_functions(
+ aggregated_stats,
+ "Functions with Highest Call Magnification (Cumulative/Direct)",
+ key_func=lambda x: (x[2] / x[1])
+ if x[1] > 0
+ else 0, # Sort by cumulative/direct ratio
+ format_line=format_call_magnification,
+ )
+
+
+def sample(
+ pid,
+ *,
+ sort=2,
+ sample_interval_usec=100,
+ duration_sec=10,
+ filename=None,
+ all_threads=False,
+ limit=None,
+ show_summary=True,
+ output_format="pstats",
+ realtime_stats=False,
+):
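+    """Sample process `pid` for `duration_sec` and emit the results.
+
+    pstats output is printed to stdout unless `filename` is given;
+    collapsed output is always written to a file (default
+    collapsed.<pid>.txt). For example:
+    sample(1234, duration_sec=5, output_format="collapsed").
+    """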
+ profiler = SampleProfiler(
+ pid, sample_interval_usec, all_threads=all_threads
+ )
+ profiler.realtime_stats = realtime_stats
+
+ collector = None
+ match output_format:
+ case "pstats":
+ collector = PstatsCollector(sample_interval_usec)
+ case "collapsed":
+ collector = CollapsedStackCollector()
+ filename = filename or f"collapsed.{pid}.txt"
+ case _:
+ raise ValueError(f"Invalid output format: {output_format}")
+
+ profiler.sample(collector, duration_sec)
+
+ if output_format == "pstats" and not filename:
+ stats = pstats.SampledStats(collector).strip_dirs()
+ print_sampled_stats(
+ stats, sort, limit, show_summary, sample_interval_usec
+ )
+ else:
+ collector.export(filename)
+
+
+def _validate_collapsed_format_args(args, parser):
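+    """Reject pstats-only options (sort, limit, no-summary) when the
+    collapsed format is selected, and default the output filename to
+    collapsed.<pid>.txt."""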
+ # Check for incompatible pstats options
+ invalid_opts = []
+
+ # Get list of pstats-specific options
+ pstats_options = {"sort": None, "limit": None, "no_summary": False}
+
+ # Find the default values from the argument definitions
+ for action in parser._actions:
+ if action.dest in pstats_options and hasattr(action, "default"):
+ pstats_options[action.dest] = action.default
+
+ # Check if any pstats-specific options were provided by comparing with defaults
+ for opt, default in pstats_options.items():
+ if getattr(args, opt) != default:
+ invalid_opts.append(opt.replace("no_", ""))
+
+ if invalid_opts:
+ parser.error(
+ f"The following options are only valid with --pstats format: {', '.join(invalid_opts)}"
+ )
+
+ # Set default output filename for collapsed format
+ if not args.outfile:
+ args.outfile = f"collapsed.{args.pid}.txt"
+
+
+def main():
+ # Create the main parser
+ parser = argparse.ArgumentParser(
+ description=(
+ "Sample a process's stack frames and generate profiling data.\n"
+ "Supports two output formats:\n"
+ " - pstats: Detailed profiling statistics with sorting options\n"
+ " - collapsed: Stack traces for generating flamegraphs\n"
+ "\n"
+ "Examples:\n"
+ " # Profile process 1234 for 10 seconds with default settings\n"
+ " python -m profile.sample 1234\n"
+ "\n"
+ " # Profile with custom interval and duration, save to file\n"
+ " python -m profile.sample -i 50 -d 30 -o profile.stats 1234\n"
+ "\n"
+ " # Generate collapsed stacks for flamegraph\n"
+ " python -m profile.sample --collapsed 1234\n"
+ "\n"
+ " # Profile all threads, sort by total time\n"
+ " python -m profile.sample -a --sort-tottime 1234\n"
+ "\n"
+ " # Profile for 1 minute with 1ms sampling interval\n"
+ " python -m profile.sample -i 1000 -d 60 1234\n"
+ "\n"
+ " # Show only top 20 functions sorted by direct samples\n"
+ " python -m profile.sample --sort-nsamples -l 20 1234\n"
+ "\n"
+ " # Profile all threads and save collapsed stacks\n"
+ " python -m profile.sample -a --collapsed -o stacks.txt 1234\n"
+ "\n"
+ " # Profile with real-time sampling statistics\n"
+ " python -m profile.sample --realtime-stats 1234\n"
+ "\n"
+ " # Sort by sample percentage to find most sampled functions\n"
+ " python -m profile.sample --sort-sample-pct 1234\n"
+ "\n"
+ " # Sort by cumulative samples to find functions most on call stack\n"
+ " python -m profile.sample --sort-nsamples-cumul 1234"
+ ),
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ # Required arguments
+ parser.add_argument("pid", type=int, help="Process ID to sample")
+
+ # Sampling options
+ sampling_group = parser.add_argument_group("Sampling configuration")
+ sampling_group.add_argument(
+ "-i",
+ "--interval",
+ type=int,
+ default=100,
+ help="Sampling interval in microseconds (default: 100)",
+ )
+ sampling_group.add_argument(
+ "-d",
+ "--duration",
+ type=int,
+ default=10,
+ help="Sampling duration in seconds (default: 10)",
+ )
+ sampling_group.add_argument(
+ "-a",
+ "--all-threads",
+ action="store_true",
+ help="Sample all threads in the process instead of just the main thread",
+ )
+ sampling_group.add_argument(
+ "--realtime-stats",
+ action="store_true",
+ default=False,
+        help="Print real-time sampling statistics (mean/min/max rate in Hz and per-sample time) during profiling",
+ )
+
+ # Output format selection
+ output_group = parser.add_argument_group("Output options")
+ output_format = output_group.add_mutually_exclusive_group()
+ output_format.add_argument(
+ "--pstats",
+ action="store_const",
+ const="pstats",
+ dest="format",
+ default="pstats",
+ help="Generate pstats output (default)",
+ )
+ output_format.add_argument(
+ "--collapsed",
+ action="store_const",
+ const="collapsed",
+ dest="format",
+ help="Generate collapsed stack traces for flamegraphs",
+ )
+
+ output_group.add_argument(
+ "-o",
+ "--outfile",
+ help="Save output to a file (if omitted, prints to stdout for pstats, "
+ "or saves to collapsed.<pid>.txt for collapsed format)",
+ )
+
+ # pstats-specific options
+ pstats_group = parser.add_argument_group("pstats format options")
+ sort_group = pstats_group.add_mutually_exclusive_group()
+ sort_group.add_argument(
+ "--sort-nsamples",
+ action="store_const",
+ const=0,
+ dest="sort",
+ help="Sort by number of direct samples (nsamples column)",
+ )
+ sort_group.add_argument(
+ "--sort-tottime",
+ action="store_const",
+ const=1,
+ dest="sort",
+ help="Sort by total time (tottime column)",
+ )
+ sort_group.add_argument(
+ "--sort-cumtime",
+ action="store_const",
+ const=2,
+ dest="sort",
+ help="Sort by cumulative time (cumtime column, default)",
+ )
+ sort_group.add_argument(
+ "--sort-sample-pct",
+ action="store_const",
+ const=3,
+ dest="sort",
+ help="Sort by sample percentage (sample%% column)",
+ )
+ sort_group.add_argument(
+ "--sort-cumul-pct",
+ action="store_const",
+ const=4,
+ dest="sort",
+ help="Sort by cumulative sample percentage (cumul%% column)",
+ )
+ sort_group.add_argument(
+ "--sort-nsamples-cumul",
+ action="store_const",
+ const=5,
+ dest="sort",
+ help="Sort by cumulative samples (nsamples column, cumulative part)",
+ )
+ sort_group.add_argument(
+ "--sort-name",
+ action="store_const",
+ const=-1,
+ dest="sort",
+ help="Sort by function name",
+ )
+
+ pstats_group.add_argument(
+ "-l",
+ "--limit",
+ type=int,
+ help="Limit the number of rows in the output",
+ default=15,
+ )
+ pstats_group.add_argument(
+ "--no-summary",
+ action="store_true",
+ help="Disable the summary section in the output",
+ )
+
+ args = parser.parse_args()
+
+ # Validate format-specific arguments
+ if args.format == "collapsed":
+ _validate_collapsed_format_args(args, parser)
+
+ sort_value = args.sort if args.sort is not None else 2
+
+ sample(
+ args.pid,
+ sample_interval_usec=args.interval,
+ duration_sec=args.duration,
+ filename=args.outfile,
+ all_threads=args.all_threads,
+ limit=args.limit,
+ sort=sort_value,
+ show_summary=not args.no_summary,
+ output_format=args.format,
+ realtime_stats=args.realtime_stats,
+ )
+
+
+if __name__ == "__main__":
+ main()
--- /dev/null
+"""Tests for the sampling profiler (profile.sample)."""
+
+import contextlib
+import io
+import marshal
+import os
+import socket
+import subprocess
+import sys
+import tempfile
+import unittest
+from unittest import mock
+
+from profile.pstats_collector import PstatsCollector
+from profile.stack_collector import (
+ CollapsedStackCollector,
+)
+
+from test.support.os_helper import unlink
+from test.support import force_not_colorized_test_class, SHORT_TIMEOUT
+from test.support.socket_helper import find_unused_port
+from test.support import requires_subprocess
+
+PROCESS_VM_READV_SUPPORTED = False
+
+try:
+ from _remote_debugging import PROCESS_VM_READV_SUPPORTED
+ import _remote_debugging
+except ImportError:
+ raise unittest.SkipTest(
+ "Test only runs when _remote_debugging is available"
+ )
+else:
+ import profile.sample
+ from profile.sample import SampleProfiler
+
+
+class MockFrameInfo:
+ """Mock FrameInfo for testing since the real one isn't accessible."""
+
+ def __init__(self, filename, lineno, funcname):
+ self.filename = filename
+ self.lineno = lineno
+ self.funcname = funcname
+
+ def __repr__(self):
+ return f"MockFrameInfo(filename='{self.filename}', lineno={self.lineno}, funcname='{self.funcname}')"
+
+
+skip_if_not_supported = unittest.skipIf(
+ (
+ sys.platform != "darwin"
+ and sys.platform != "linux"
+ and sys.platform != "win32"
+ ),
+    "Test only runs on Linux, Windows, and macOS",
+)
+
+
+@contextlib.contextmanager
+def test_subprocess(script):
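+    """Run `script` in a subprocess, yielding it once it is ready.
+
+    The script is prefixed with code that connects back over a TCP
+    socket, so the test only proceeds after the target has started.
+    """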
+ # Find an unused port for socket communication
+ port = find_unused_port()
+
+ # Inject socket connection code at the beginning of the script
+ socket_code = f'''
+import socket
+_test_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+_test_sock.connect(('localhost', {port}))
+_test_sock.sendall(b"ready")
+'''
+
+ # Combine socket code with user script
+ full_script = socket_code + script
+
+ # Create server socket to wait for process to be ready
+ server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ server_socket.bind(("localhost", port))
+ server_socket.settimeout(SHORT_TIMEOUT)
+ server_socket.listen(1)
+
+ proc = subprocess.Popen(
+ [sys.executable, "-c", full_script],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ )
+
+ client_socket = None
+ try:
+ # Wait for process to connect and send ready signal
+ client_socket, _ = server_socket.accept()
+ server_socket.close()
+ response = client_socket.recv(1024)
+ if response != b"ready":
+ raise RuntimeError(f"Unexpected response from subprocess: {response}")
+
+ yield proc
+ finally:
+ if client_socket is not None:
+ client_socket.close()
+ if proc.poll() is None:
+ proc.kill()
+ proc.wait()
+
+
+def close_and_unlink(file):
+ file.close()
+ unlink(file.name)
+
+
+class TestSampleProfilerComponents(unittest.TestCase):
+ """Unit tests for individual profiler components."""
+
+ def test_mock_frame_info_with_empty_and_unicode_values(self):
+ """Test MockFrameInfo handles empty strings, unicode characters, and very long names correctly."""
+ # Test with empty strings
+ frame = MockFrameInfo("", 0, "")
+ self.assertEqual(frame.filename, "")
+ self.assertEqual(frame.lineno, 0)
+ self.assertEqual(frame.funcname, "")
+ self.assertIn("filename=''", repr(frame))
+
+ # Test with unicode characters
+ frame = MockFrameInfo("文件.py", 42, "函数名")
+ self.assertEqual(frame.filename, "文件.py")
+ self.assertEqual(frame.funcname, "函数名")
+
+ # Test with very long names
+ long_filename = "x" * 1000 + ".py"
+ long_funcname = "func_" + "x" * 1000
+ frame = MockFrameInfo(long_filename, 999999, long_funcname)
+ self.assertEqual(frame.filename, long_filename)
+ self.assertEqual(frame.lineno, 999999)
+ self.assertEqual(frame.funcname, long_funcname)
+
+ def test_pstats_collector_with_extreme_intervals_and_empty_data(self):
+ """Test PstatsCollector handles zero/large intervals, empty frames, None thread IDs, and duplicate frames."""
+ # Test with zero interval
+ collector = PstatsCollector(sample_interval_usec=0)
+ self.assertEqual(collector.sample_interval_usec, 0)
+
+ # Test with very large interval
+ collector = PstatsCollector(sample_interval_usec=1000000000)
+ self.assertEqual(collector.sample_interval_usec, 1000000000)
+
+ # Test collecting empty frames list
+ collector = PstatsCollector(sample_interval_usec=1000)
+ collector.collect([])
+ self.assertEqual(len(collector.result), 0)
+
+ # Test collecting frames with None thread id
+ test_frames = [(None, [MockFrameInfo("file.py", 10, "func")])]
+ collector.collect(test_frames)
+ # Should still process the frames
+ self.assertEqual(len(collector.result), 1)
+
+ # Test collecting duplicate frames in same sample
+ test_frames = [
+ (
+ 1,
+ [
+ MockFrameInfo("file.py", 10, "func1"),
+ MockFrameInfo("file.py", 10, "func1"), # Duplicate
+ ],
+ )
+ ]
+ collector = PstatsCollector(sample_interval_usec=1000)
+ collector.collect(test_frames)
+ # Should count both occurrences
+ self.assertEqual(
+ collector.result[("file.py", 10, "func1")]["cumulative_calls"], 2
+ )
+
+ def test_pstats_collector_single_frame_stacks(self):
+ """Test PstatsCollector with single-frame call stacks to trigger len(frames) <= 1 branch."""
+ collector = PstatsCollector(sample_interval_usec=1000)
+
+ # Test with exactly one frame (should trigger the <= 1 condition)
+ single_frame = [(1, [MockFrameInfo("single.py", 10, "single_func")])]
+ collector.collect(single_frame)
+
+        # Should record the single frame with a direct call
+ self.assertEqual(len(collector.result), 1)
+ single_key = ("single.py", 10, "single_func")
+ self.assertIn(single_key, collector.result)
+ self.assertEqual(collector.result[single_key]["direct_calls"], 1)
+ self.assertEqual(collector.result[single_key]["cumulative_calls"], 1)
+
+ # Test with empty frames (should also trigger <= 1 condition)
+ empty_frames = [(1, [])]
+ collector.collect(empty_frames)
+
+ # Should not add any new entries
+ self.assertEqual(
+ len(collector.result), 1
+ ) # Still just the single frame
+
+ # Test mixed single and multi-frame stacks
+ mixed_frames = [
+ (
+ 1,
+ [MockFrameInfo("single2.py", 20, "single_func2")],
+ ), # Single frame
+ (
+ 2,
+ [ # Multi-frame stack
+ MockFrameInfo("multi.py", 30, "multi_func1"),
+ MockFrameInfo("multi.py", 40, "multi_func2"),
+ ],
+ ),
+ ]
+ collector.collect(mixed_frames)
+
+ # Should have recorded all functions
+ self.assertEqual(
+ len(collector.result), 4
+ ) # single + single2 + multi1 + multi2
+
+ # Verify single frame handling
+ single2_key = ("single2.py", 20, "single_func2")
+ self.assertIn(single2_key, collector.result)
+ self.assertEqual(collector.result[single2_key]["direct_calls"], 1)
+ self.assertEqual(collector.result[single2_key]["cumulative_calls"], 1)
+
+ # Verify multi-frame handling still works
+ multi1_key = ("multi.py", 30, "multi_func1")
+ multi2_key = ("multi.py", 40, "multi_func2")
+ self.assertIn(multi1_key, collector.result)
+ self.assertIn(multi2_key, collector.result)
+ self.assertEqual(collector.result[multi1_key]["direct_calls"], 1)
+ self.assertEqual(
+ collector.result[multi2_key]["cumulative_calls"], 1
+ ) # Called from multi1
+
+ def test_collapsed_stack_collector_with_empty_and_deep_stacks(self):
+ """Test CollapsedStackCollector handles empty frames, single-frame stacks, and very deep call stacks."""
+ collector = CollapsedStackCollector()
+
+ # Test with empty frames
+ collector.collect([])
+ self.assertEqual(len(collector.call_trees), 0)
+
+ # Test with single frame stack
+ test_frames = [(1, [("file.py", 10, "func")])]
+ collector.collect(test_frames)
+ self.assertEqual(len(collector.call_trees), 1)
+ self.assertEqual(collector.call_trees[0], [("file.py", 10, "func")])
+
+ # Test with very deep stack
+ deep_stack = [(f"file{i}.py", i, f"func{i}") for i in range(100)]
+ test_frames = [(1, deep_stack)]
+ collector = CollapsedStackCollector()
+ collector.collect(test_frames)
+ self.assertEqual(len(collector.call_trees[0]), 100)
+ # Check it's properly reversed
+ self.assertEqual(
+ collector.call_trees[0][0], ("file99.py", 99, "func99")
+ )
+ self.assertEqual(collector.call_trees[0][-1], ("file0.py", 0, "func0"))
+
+ def test_pstats_collector_basic(self):
+ """Test basic PstatsCollector functionality."""
+ collector = PstatsCollector(sample_interval_usec=1000)
+
+ # Test empty state
+ self.assertEqual(len(collector.result), 0)
+ self.assertEqual(len(collector.stats), 0)
+
+ # Test collecting sample data
+ test_frames = [
+ (
+ 1,
+ [
+ MockFrameInfo("file.py", 10, "func1"),
+ MockFrameInfo("file.py", 20, "func2"),
+ ],
+ )
+ ]
+ collector.collect(test_frames)
+
+ # Should have recorded calls for both functions
+ self.assertEqual(len(collector.result), 2)
+ self.assertIn(("file.py", 10, "func1"), collector.result)
+ self.assertIn(("file.py", 20, "func2"), collector.result)
+
+ # Top-level function should have direct call
+ self.assertEqual(
+ collector.result[("file.py", 10, "func1")]["direct_calls"], 1
+ )
+ self.assertEqual(
+ collector.result[("file.py", 10, "func1")]["cumulative_calls"], 1
+ )
+
+ # Calling function should have cumulative call but no direct calls
+ self.assertEqual(
+ collector.result[("file.py", 20, "func2")]["cumulative_calls"], 1
+ )
+ self.assertEqual(
+ collector.result[("file.py", 20, "func2")]["direct_calls"], 0
+ )
+
+ def test_pstats_collector_create_stats(self):
+ """Test PstatsCollector stats creation."""
+ collector = PstatsCollector(
+ sample_interval_usec=1000000
+ ) # 1 second intervals
+
+ test_frames = [
+ (
+ 1,
+ [
+ MockFrameInfo("file.py", 10, "func1"),
+ MockFrameInfo("file.py", 20, "func2"),
+ ],
+ )
+ ]
+ collector.collect(test_frames)
+ collector.collect(test_frames) # Collect twice
+
+ collector.create_stats()
+
+ # Check stats format: (direct_calls, cumulative_calls, tt, ct, callers)
+ func1_stats = collector.stats[("file.py", 10, "func1")]
+ self.assertEqual(func1_stats[0], 2) # direct_calls (top of stack)
+ self.assertEqual(func1_stats[1], 2) # cumulative_calls
+ self.assertEqual(
+ func1_stats[2], 2.0
+ ) # tt (total time - 2 samples * 1 sec)
+ self.assertEqual(func1_stats[3], 2.0) # ct (cumulative time)
+
+ func2_stats = collector.stats[("file.py", 20, "func2")]
+ self.assertEqual(
+ func2_stats[0], 0
+ ) # direct_calls (never top of stack)
+ self.assertEqual(
+ func2_stats[1], 2
+ ) # cumulative_calls (appears in stack)
+ self.assertEqual(func2_stats[2], 0.0) # tt (no direct calls)
+ self.assertEqual(func2_stats[3], 2.0) # ct (cumulative time)
+
+ def test_collapsed_stack_collector_basic(self):
+ collector = CollapsedStackCollector()
+
+ # Test empty state
+ self.assertEqual(len(collector.call_trees), 0)
+ self.assertEqual(len(collector.function_samples), 0)
+
+ # Test collecting sample data
+ test_frames = [
+ (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])
+ ]
+ collector.collect(test_frames)
+
+ # Should store call tree (reversed)
+ self.assertEqual(len(collector.call_trees), 1)
+ expected_tree = [("file.py", 20, "func2"), ("file.py", 10, "func1")]
+ self.assertEqual(collector.call_trees[0], expected_tree)
+
+ # Should count function samples
+ self.assertEqual(
+ collector.function_samples[("file.py", 10, "func1")], 1
+ )
+ self.assertEqual(
+ collector.function_samples[("file.py", 20, "func2")], 1
+ )
+
+ def test_collapsed_stack_collector_export(self):
+ collapsed_out = tempfile.NamedTemporaryFile(delete=False)
+ self.addCleanup(close_and_unlink, collapsed_out)
+
+ collector = CollapsedStackCollector()
+
+ test_frames1 = [
+ (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])
+ ]
+ test_frames2 = [
+ (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])
+ ] # Same stack
+ test_frames3 = [(1, [("other.py", 5, "other_func")])]
+
+ collector.collect(test_frames1)
+ collector.collect(test_frames2)
+ collector.collect(test_frames3)
+
+ collector.export(collapsed_out.name)
+ # Check file contents
+ with open(collapsed_out.name, "r") as f:
+ content = f.read()
+
+ lines = content.strip().split("\n")
+ self.assertEqual(len(lines), 2) # Two unique stacks
+
+ # Check collapsed format: file:func:line;file:func:line count
+ stack1_expected = "file.py:func2:20;file.py:func1:10 2"
+ stack2_expected = "other.py:other_func:5 1"
+
+ self.assertIn(stack1_expected, lines)
+ self.assertIn(stack2_expected, lines)
+
+ def test_pstats_collector_export(self):
+ collector = PstatsCollector(
+ sample_interval_usec=1000000
+ ) # 1 second intervals
+
+ test_frames1 = [
+ (
+ 1,
+ [
+ MockFrameInfo("file.py", 10, "func1"),
+ MockFrameInfo("file.py", 20, "func2"),
+ ],
+ )
+ ]
+ test_frames2 = [
+ (
+ 1,
+ [
+ MockFrameInfo("file.py", 10, "func1"),
+ MockFrameInfo("file.py", 20, "func2"),
+ ],
+ )
+ ] # Same stack
+ test_frames3 = [(1, [MockFrameInfo("other.py", 5, "other_func")])]
+
+ collector.collect(test_frames1)
+ collector.collect(test_frames2)
+ collector.collect(test_frames3)
+
+ pstats_out = tempfile.NamedTemporaryFile(
+ suffix=".pstats", delete=False
+ )
+ self.addCleanup(close_and_unlink, pstats_out)
+ collector.export(pstats_out.name)
+
+ # Check file can be loaded with marshal
+ with open(pstats_out.name, "rb") as f:
+ stats_data = marshal.load(f)
+
+ # Should be a dictionary with the sampled marker
+ self.assertIsInstance(stats_data, dict)
+ self.assertIn(("__sampled__",), stats_data)
+ self.assertTrue(stats_data[("__sampled__",)])
+
+ # Should have function data
+ function_entries = [
+ k for k in stats_data.keys() if k != ("__sampled__",)
+ ]
+ self.assertGreater(len(function_entries), 0)
+
+ # Check specific function stats format: (cc, nc, tt, ct, callers)
+ func1_key = ("file.py", 10, "func1")
+ func2_key = ("file.py", 20, "func2")
+ other_key = ("other.py", 5, "other_func")
+
+ self.assertIn(func1_key, stats_data)
+ self.assertIn(func2_key, stats_data)
+ self.assertIn(other_key, stats_data)
+
+ # Check func1 stats (should have 2 samples)
+ func1_stats = stats_data[func1_key]
+        self.assertEqual(func1_stats[0], 2)  # cc (direct calls)
+        self.assertEqual(func1_stats[1], 2)  # nc (cumulative calls)
+ self.assertEqual(func1_stats[2], 2.0) # tt (total time)
+ self.assertEqual(func1_stats[3], 2.0) # ct (cumulative time)
+
+
+class TestSampleProfiler(unittest.TestCase):
+ """Test the SampleProfiler class."""
+
+ def test_sample_profiler_initialization(self):
+ """Test SampleProfiler initialization with various parameters."""
+ from profile.sample import SampleProfiler
+
+ # Mock RemoteUnwinder to avoid permission issues
+ with mock.patch(
+ "_remote_debugging.RemoteUnwinder"
+ ) as mock_unwinder_class:
+ mock_unwinder_class.return_value = mock.MagicMock()
+
+ # Test basic initialization
+ profiler = SampleProfiler(
+ pid=12345, sample_interval_usec=1000, all_threads=False
+ )
+ self.assertEqual(profiler.pid, 12345)
+ self.assertEqual(profiler.sample_interval_usec, 1000)
+ self.assertEqual(profiler.all_threads, False)
+
+ # Test with all_threads=True
+ profiler = SampleProfiler(
+ pid=54321, sample_interval_usec=5000, all_threads=True
+ )
+ self.assertEqual(profiler.pid, 54321)
+ self.assertEqual(profiler.sample_interval_usec, 5000)
+ self.assertEqual(profiler.all_threads, True)
+
+ def test_sample_profiler_sample_method_timing(self):
+ """Test that the sample method respects duration and handles timing correctly."""
+ from profile.sample import SampleProfiler
+
+ # Mock the unwinder to avoid needing a real process
+ mock_unwinder = mock.MagicMock()
+ mock_unwinder.get_stack_trace.return_value = [
+ (
+ 1,
+ [
+ mock.MagicMock(
+ filename="test.py", lineno=10, funcname="test_func"
+ )
+ ],
+ )
+ ]
+
+ with mock.patch(
+ "_remote_debugging.RemoteUnwinder"
+ ) as mock_unwinder_class:
+ mock_unwinder_class.return_value = mock_unwinder
+
+ profiler = SampleProfiler(
+ pid=12345, sample_interval_usec=100000, all_threads=False
+ ) # 100ms interval
+
+ # Mock collector
+ mock_collector = mock.MagicMock()
+
+ # Mock time to control the sampling loop
+ start_time = 1000.0
+ times = [
+ start_time + i * 0.1 for i in range(12)
+ ] # 0, 0.1, 0.2, ..., 1.1 seconds
+
+ with mock.patch("time.perf_counter", side_effect=times):
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ profiler.sample(mock_collector, duration_sec=1)
+
+ result = output.getvalue()
+
+                # Should have captured about 5 samples: the mocked clock
+                # advances 0.2s per loop iteration against the 0.1s interval
+ self.assertIn("Captured", result)
+ self.assertIn("samples", result)
+
+ # Verify collector was called multiple times
+ self.assertGreaterEqual(mock_collector.collect.call_count, 5)
+ self.assertLessEqual(mock_collector.collect.call_count, 11)
+
+ def test_sample_profiler_error_handling(self):
+ """Test that the sample method handles errors gracefully."""
+ from profile.sample import SampleProfiler
+
+ # Mock unwinder that raises errors
+ mock_unwinder = mock.MagicMock()
+ error_sequence = [
+ RuntimeError("Process died"),
+ [
+ (
+ 1,
+ [
+ mock.MagicMock(
+ filename="test.py", lineno=10, funcname="test_func"
+ )
+ ],
+ )
+ ],
+ UnicodeDecodeError("utf-8", b"", 0, 1, "invalid"),
+ [
+ (
+ 1,
+ [
+ mock.MagicMock(
+ filename="test.py",
+ lineno=20,
+ funcname="test_func2",
+ )
+ ],
+ )
+ ],
+ OSError("Permission denied"),
+ ]
+ mock_unwinder.get_stack_trace.side_effect = error_sequence
+
+ with mock.patch(
+ "_remote_debugging.RemoteUnwinder"
+ ) as mock_unwinder_class:
+ mock_unwinder_class.return_value = mock_unwinder
+
+ profiler = SampleProfiler(
+ pid=12345, sample_interval_usec=10000, all_threads=False
+ )
+
+ mock_collector = mock.MagicMock()
+
+            # Control timing so the loop runs exactly three sampling iterations
+ times = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
+
+ with mock.patch("time.perf_counter", side_effect=times):
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ profiler.sample(mock_collector, duration_sec=0.05)
+
+ result = output.getvalue()
+
+ # Should report error rate
+ self.assertIn("Error rate:", result)
+ self.assertIn("%", result)
+
+ # Collector should have been called only for successful samples (should be > 0)
+ self.assertGreater(mock_collector.collect.call_count, 0)
+ self.assertLessEqual(mock_collector.collect.call_count, 3)
+
+ def test_sample_profiler_missed_samples_warning(self):
+ """Test that the profiler warns about missed samples when sampling is too slow."""
+ from profile.sample import SampleProfiler
+
+ mock_unwinder = mock.MagicMock()
+ mock_unwinder.get_stack_trace.return_value = [
+ (
+ 1,
+ [
+ mock.MagicMock(
+ filename="test.py", lineno=10, funcname="test_func"
+ )
+ ],
+ )
+ ]
+
+ with mock.patch(
+ "_remote_debugging.RemoteUnwinder"
+ ) as mock_unwinder_class:
+ mock_unwinder_class.return_value = mock_unwinder
+
+ # Use very short interval that we'll miss
+ profiler = SampleProfiler(
+ pid=12345, sample_interval_usec=1000, all_threads=False
+ ) # 1ms interval
+
+ mock_collector = mock.MagicMock()
+
+ # Simulate slow sampling where we miss many samples
+ times = [
+ 0.0,
+ 0.1,
+ 0.2,
+ 0.3,
+ 0.4,
+ 0.5,
+ 0.6,
+ 0.7,
+ ] # Extra time points to avoid StopIteration
+
+ with mock.patch("time.perf_counter", side_effect=times):
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ profiler.sample(mock_collector, duration_sec=0.5)
+
+ result = output.getvalue()
+
+ # Should warn about missed samples
+ self.assertIn("Warning: missed", result)
+ self.assertIn("samples from the expected total", result)
+
+
+@force_not_colorized_test_class
+class TestPrintSampledStats(unittest.TestCase):
+ """Test the print_sampled_stats function."""
+
+ def setUp(self):
+ """Set up test data."""
+ # Mock stats data
+ self.mock_stats = mock.MagicMock()
+ self.mock_stats.stats = {
+ ("file1.py", 10, "func1"): (
+ 100,
+ 100,
+ 0.5,
+ 0.5,
+ {},
+ ), # cc, nc, tt, ct, callers
+ ("file2.py", 20, "func2"): (50, 50, 0.25, 0.3, {}),
+ ("file3.py", 30, "func3"): (200, 200, 1.5, 2.0, {}),
+ ("file4.py", 40, "func4"): (
+ 10,
+ 10,
+ 0.001,
+ 0.001,
+ {},
+ ), # millisecond range
+ ("file5.py", 50, "func5"): (
+ 5,
+ 5,
+ 0.000001,
+ 0.000002,
+ {},
+ ), # microsecond range
+ }
+
+ def test_print_sampled_stats_basic(self):
+ """Test basic print_sampled_stats functionality."""
+ from profile.sample import print_sampled_stats
+
+ # Capture output
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(self.mock_stats, sample_interval_usec=100)
+
+ result = output.getvalue()
+
+ # Check header is present
+ self.assertIn("Profile Stats:", result)
+ self.assertIn("nsamples", result)
+ self.assertIn("tottime", result)
+ self.assertIn("cumtime", result)
+
+ # Check functions are present
+ self.assertIn("func1", result)
+ self.assertIn("func2", result)
+ self.assertIn("func3", result)
+
+ def test_print_sampled_stats_sorting(self):
+ """Test different sorting options."""
+ from profile.sample import print_sampled_stats
+
+ # Test sort by calls
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats, sort=0, sample_interval_usec=100
+ )
+
+ result = output.getvalue()
+ lines = result.strip().split("\n")
+
+ # Find the data lines (skip header)
+ data_lines = [l for l in lines if "file" in l and ".py" in l]
+ # func3 should be first (200 calls)
+ self.assertIn("func3", data_lines[0])
+
+ # Test sort by time
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats, sort=1, sample_interval_usec=100
+ )
+
+ result = output.getvalue()
+ lines = result.strip().split("\n")
+
+ data_lines = [l for l in lines if "file" in l and ".py" in l]
+ # func3 should be first (1.5s time)
+ self.assertIn("func3", data_lines[0])
+
+ def test_print_sampled_stats_limit(self):
+ """Test limiting output rows."""
+ from profile.sample import print_sampled_stats
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats, limit=2, sample_interval_usec=100
+ )
+
+ result = output.getvalue()
+
+ # Count function entries in the main stats section (not in summary)
+ lines = result.split("\n")
+ # Find where the main stats section ends (before summary)
+ main_section_lines = []
+ for line in lines:
+ if "Summary of Interesting Functions:" in line:
+ break
+ main_section_lines.append(line)
+
+ # Count function entries only in main section
+ func_count = sum(
+ 1
+ for line in main_section_lines
+ if "func" in line and ".py" in line
+ )
+ self.assertEqual(func_count, 2)
+
+ def test_print_sampled_stats_time_units(self):
+ """Test proper time unit selection."""
+ from profile.sample import print_sampled_stats
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(self.mock_stats, sample_interval_usec=100)
+
+ result = output.getvalue()
+
+ # Should use seconds for the header since max time is > 1s
+ self.assertIn("tottime (s)", result)
+ self.assertIn("cumtime (s)", result)
+
+ # Test with only microsecond-range times
+ micro_stats = mock.MagicMock()
+ micro_stats.stats = {
+ ("file1.py", 10, "func1"): (100, 100, 0.000005, 0.000010, {}),
+ }
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(micro_stats, sample_interval_usec=100)
+
+ result = output.getvalue()
+
+ # Should use microseconds
+ self.assertIn("tottime (μs)", result)
+ self.assertIn("cumtime (μs)", result)
+
+ def test_print_sampled_stats_summary(self):
+ """Test summary section generation."""
+ from profile.sample import print_sampled_stats
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats,
+ show_summary=True,
+ sample_interval_usec=100,
+ )
+
+ result = output.getvalue()
+
+ # Check summary sections are present
+ self.assertIn("Summary of Interesting Functions:", result)
+ self.assertIn(
+ "Functions with Highest Direct/Cumulative Ratio (Hot Spots):",
+ result,
+ )
+ self.assertIn(
+ "Functions with Highest Call Frequency (Indirect Calls):", result
+ )
+ self.assertIn(
+ "Functions with Highest Call Magnification (Cumulative/Direct):",
+ result,
+ )
+
+ def test_print_sampled_stats_no_summary(self):
+ """Test disabling summary output."""
+ from profile.sample import print_sampled_stats
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats,
+ show_summary=False,
+ sample_interval_usec=100,
+ )
+
+ result = output.getvalue()
+
+ # Summary should not be present
+ self.assertNotIn("Summary of Interesting Functions:", result)
+
+ def test_print_sampled_stats_empty_stats(self):
+ """Test with empty stats."""
+ from profile.sample import print_sampled_stats
+
+ empty_stats = mock.MagicMock()
+ empty_stats.stats = {}
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(empty_stats, sample_interval_usec=100)
+
+ result = output.getvalue()
+
+ # Should still print header
+ self.assertIn("Profile Stats:", result)
+
+ def test_print_sampled_stats_sample_percentage_sorting(self):
+ """Test sample percentage sorting options."""
+ from profile.sample import print_sampled_stats
+
+ # Add a function with high sample percentage (more direct calls than func3's 200)
+ self.mock_stats.stats[("expensive.py", 60, "expensive_func")] = (
+ 300, # direct calls (higher than func3's 200)
+ 300, # cumulative calls
+ 1.0, # total time
+ 1.0, # cumulative time
+ {},
+ )
+
+ # Test sort by sample percentage
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats, sort=3, sample_interval_usec=100
+ ) # sample percentage
+
+ result = output.getvalue()
+ lines = result.strip().split("\n")
+
+ data_lines = [l for l in lines if ".py" in l and "func" in l]
+ # expensive_func should be first (highest sample percentage)
+ self.assertIn("expensive_func", data_lines[0])
+
+ def test_print_sampled_stats_with_recursive_calls(self):
+ """Test print_sampled_stats with recursive calls where nc != cc."""
+ from profile.sample import print_sampled_stats
+
+ # Create stats with recursive calls (nc != cc)
+ recursive_stats = mock.MagicMock()
+ recursive_stats.stats = {
+ # (direct_calls, cumulative_calls, tt, ct, callers) - recursive function
+ ("recursive.py", 10, "factorial"): (
+ 5, # direct_calls
+ 10, # cumulative_calls (appears more times in stack due to recursion)
+ 0.5,
+ 0.6,
+ {},
+ ),
+ ("normal.py", 20, "normal_func"): (
+ 3, # direct_calls
+ 3, # cumulative_calls (same as direct for non-recursive)
+ 0.2,
+ 0.2,
+ {},
+ ),
+ }
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(recursive_stats, sample_interval_usec=100)
+
+ result = output.getvalue()
+
+        # nsamples always uses the direct/cumulative format
+        self.assertIn("5/10", result)  # recursive: 5 direct, 10 cumulative
+        self.assertIn("3/3", result)  # non-recursive: direct == cumulative
+ self.assertIn("factorial", result)
+ self.assertIn("normal_func", result)
+
+ def test_print_sampled_stats_with_zero_call_counts(self):
+ """Test print_sampled_stats with zero call counts to trigger division protection."""
+ from profile.sample import print_sampled_stats
+
+ # Create stats with zero call counts
+ zero_stats = mock.MagicMock()
+ zero_stats.stats = {
+ ("file.py", 10, "zero_calls"): (0, 0, 0.0, 0.0, {}), # Zero calls
+ ("file.py", 20, "normal_func"): (
+ 5,
+ 5,
+ 0.1,
+ 0.1,
+ {},
+ ), # Normal function
+ }
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(zero_stats, sample_interval_usec=100)
+
+ result = output.getvalue()
+
+ # Should handle zero call counts gracefully
+        self.assertIn("zero_calls", result)
+ self.assertIn("normal_func", result)
+
+ def test_print_sampled_stats_sort_by_name(self):
+ """Test sort by function name option."""
+ from profile.sample import print_sampled_stats
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ self.mock_stats, sort=-1, sample_interval_usec=100
+ ) # sort by name
+
+ result = output.getvalue()
+ lines = result.strip().split("\n")
+
+ # Find the data lines (skip header and summary)
+ # Data lines start with whitespace and numbers, and contain filename:lineno(function)
+ data_lines = []
+ for line in lines:
+ # Skip header lines and summary sections
+            if (
+                line.startswith(" ")
+                and "(" in line
+                and ")" in line
+                # Skip summary lines that start with times
+                and not line.startswith(" 1.")
+                and not line.startswith(" 0.")
+                # Skip other summary lines
+                and "per call" not in line
+                and "calls" not in line
+                and "total time" not in line
+                and "cumulative time" not in line
+            ):
+ data_lines.append(line)
+
+ # Extract just the function names for comparison
+ func_names = []
+ import re
+
+ for line in data_lines:
+ # Function name is between the last ( and ), accounting for ANSI color codes
+ match = re.search(r"\(([^)]+)\)$", line)
+ if match:
+ func_name = match.group(1)
+ # Remove ANSI color codes
+ func_name = re.sub(r"\x1b\[[0-9;]*m", "", func_name)
+ func_names.append(func_name)
+
+ # Verify we extracted function names and they are sorted
+ self.assertGreater(
+ len(func_names), 0, "Should have extracted some function names"
+ )
+ self.assertEqual(
+ func_names,
+ sorted(func_names),
+ f"Function names {func_names} should be sorted alphabetically",
+ )
+
+ def test_print_sampled_stats_with_zero_time_functions(self):
+ """Test summary sections with functions that have zero time."""
+ from profile.sample import print_sampled_stats
+
+ # Create stats with zero-time functions
+ zero_time_stats = mock.MagicMock()
+ zero_time_stats.stats = {
+ ("file1.py", 10, "zero_time_func"): (
+ 5,
+ 5,
+ 0.0,
+ 0.0,
+ {},
+ ), # Zero time
+ ("file2.py", 20, "normal_func"): (
+ 3,
+ 3,
+ 0.1,
+ 0.1,
+ {},
+ ), # Normal time
+ }
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ zero_time_stats,
+ show_summary=True,
+ sample_interval_usec=100,
+ )
+
+ result = output.getvalue()
+
+ # Should handle zero-time functions gracefully in summary
+ self.assertIn("Summary of Interesting Functions:", result)
+ self.assertIn("zero_time_func", result)
+ self.assertIn("normal_func", result)
+
+ def test_print_sampled_stats_with_malformed_qualified_names(self):
+ """Test summary generation with function names that don't contain colons."""
+ from profile.sample import print_sampled_stats
+
+ # Create stats with function names that would create malformed qualified names
+ malformed_stats = mock.MagicMock()
+ malformed_stats.stats = {
+ # Function name without clear module separation
+ ("no_colon_func", 10, "func"): (3, 3, 0.1, 0.1, {}),
+ ("", 20, "empty_filename_func"): (2, 2, 0.05, 0.05, {}),
+ ("normal.py", 30, "normal_func"): (5, 5, 0.2, 0.2, {}),
+ }
+
+ with io.StringIO() as output:
+ with mock.patch("sys.stdout", output):
+ print_sampled_stats(
+ malformed_stats,
+ show_summary=True,
+ sample_interval_usec=100,
+ )
+
+ result = output.getvalue()
+
+ # Should handle malformed names gracefully in summary aggregation
+ self.assertIn("Summary of Interesting Functions:", result)
+ # All function names should appear somewhere in the output
+ self.assertIn("func", result)
+ self.assertIn("empty_filename_func", result)
+ self.assertIn("normal_func", result)
+
+ def test_print_sampled_stats_with_recursive_call_stats_creation(self):
+ """Test create_stats with recursive call data to trigger total_rec_calls branch."""
+ collector = PstatsCollector(sample_interval_usec=1000000) # 1 second
+
+ # Simulate recursive function data where total_rec_calls would be set
+ # We need to manually manipulate the collector result to test this branch
+ collector.result = {
+ ("recursive.py", 10, "factorial"): {
+ "total_rec_calls": 3, # Non-zero recursive calls
+ "direct_calls": 5,
+ "cumulative_calls": 10,
+ },
+ ("normal.py", 20, "normal_func"): {
+ "total_rec_calls": 0, # Zero recursive calls
+ "direct_calls": 2,
+ "cumulative_calls": 5,
+ },
+ }
+
+ collector.create_stats()
+
+ # Check that recursive calls are handled differently from non-recursive
+ factorial_stats = collector.stats[("recursive.py", 10, "factorial")]
+ normal_stats = collector.stats[("normal.py", 20, "normal_func")]
+
+ # factorial should use cumulative_calls (10) as nc
+ self.assertEqual(
+ factorial_stats[1], 10
+ ) # nc should be cumulative_calls
+ self.assertEqual(factorial_stats[0], 5) # cc should be direct_calls
+
+ # normal_func should use cumulative_calls as nc
+ self.assertEqual(normal_stats[1], 5) # nc should be cumulative_calls
+ self.assertEqual(normal_stats[0], 2) # cc should be direct_calls
+
+
+@skip_if_not_supported
+@unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+)
+class TestRecursiveFunctionProfiling(unittest.TestCase):
+ """Test profiling of recursive functions and complex call patterns."""
+
+ def test_recursive_function_call_counting(self):
+ """Test that recursive function calls are counted correctly."""
+ collector = PstatsCollector(sample_interval_usec=1000)
+
+ # Simulate a recursive call pattern: fibonacci(5) calling itself
+ recursive_frames = [
+ (
+ 1,
+ [ # First sample: deep in recursion
+ MockFrameInfo("fib.py", 10, "fibonacci"),
+ MockFrameInfo("fib.py", 10, "fibonacci"), # recursive call
+ MockFrameInfo(
+ "fib.py", 10, "fibonacci"
+ ), # deeper recursion
+ MockFrameInfo("fib.py", 10, "fibonacci"), # even deeper
+ MockFrameInfo("main.py", 5, "main"), # main caller
+ ],
+ ),
+ (
+ 1,
+ [ # Second sample: different recursion depth
+ MockFrameInfo("fib.py", 10, "fibonacci"),
+ MockFrameInfo("fib.py", 10, "fibonacci"), # recursive call
+ MockFrameInfo("main.py", 5, "main"), # main caller
+ ],
+ ),
+ (
+ 1,
+ [ # Third sample: back to deeper recursion
+ MockFrameInfo("fib.py", 10, "fibonacci"),
+ MockFrameInfo("fib.py", 10, "fibonacci"),
+ MockFrameInfo("fib.py", 10, "fibonacci"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ ]
+
+ for frames in recursive_frames:
+ collector.collect([frames])
+
+ collector.create_stats()
+
+ # Check that recursive calls are counted properly
+ fib_key = ("fib.py", 10, "fibonacci")
+ main_key = ("main.py", 5, "main")
+
+ self.assertIn(fib_key, collector.stats)
+ self.assertIn(main_key, collector.stats)
+
+ # Fibonacci should have many calls due to recursion
+ fib_stats = collector.stats[fib_key]
+ direct_calls, cumulative_calls, tt, ct, callers = fib_stats
+
+ # Should have recorded multiple calls (9 total appearances in samples)
+ self.assertEqual(cumulative_calls, 9)
+ self.assertGreater(tt, 0) # Should have some total time
+ self.assertGreater(ct, 0) # Should have some cumulative time
+
+ # Main should have fewer calls
+ main_stats = collector.stats[main_key]
+ main_direct_calls, main_cumulative_calls = main_stats[0], main_stats[1]
+ self.assertEqual(main_direct_calls, 0) # Never directly executing
+ self.assertEqual(main_cumulative_calls, 3) # Appears in all 3 samples
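+
+        # Note: a sampling profiler attributes time in units of whole
+        # samples, so the tt and ct checked above are estimates presumably
+        # derived from appearance counts and the configured
+        # sample_interval_usec, not measured call durations.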
+
+ def test_nested_function_hierarchy(self):
+ """Test profiling of deeply nested function calls."""
+ collector = PstatsCollector(sample_interval_usec=1000)
+
+ # Simulate a deep call hierarchy
+ deep_call_frames = [
+ (
+ 1,
+ [
+ MockFrameInfo("level1.py", 10, "level1_func"),
+ MockFrameInfo("level2.py", 20, "level2_func"),
+ MockFrameInfo("level3.py", 30, "level3_func"),
+ MockFrameInfo("level4.py", 40, "level4_func"),
+ MockFrameInfo("level5.py", 50, "level5_func"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ (
+ 1,
+ [ # Same hierarchy sampled again
+ MockFrameInfo("level1.py", 10, "level1_func"),
+ MockFrameInfo("level2.py", 20, "level2_func"),
+ MockFrameInfo("level3.py", 30, "level3_func"),
+ MockFrameInfo("level4.py", 40, "level4_func"),
+ MockFrameInfo("level5.py", 50, "level5_func"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ ]
+
+ for frames in deep_call_frames:
+ collector.collect([frames])
+
+ collector.create_stats()
+
+ # All levels should be recorded
+ for level in range(1, 6):
+ key = (f"level{level}.py", level * 10, f"level{level}_func")
+ self.assertIn(key, collector.stats)
+
+ stats = collector.stats[key]
+ direct_calls, cumulative_calls, tt, ct, callers = stats
+
+ # Each level should appear in stack twice (2 samples)
+ self.assertEqual(cumulative_calls, 2)
+
+ # Only level1 (deepest) should have direct calls
+ if level == 1:
+ self.assertEqual(direct_calls, 2)
+ else:
+ self.assertEqual(direct_calls, 0)
+
+            # An inner frame's cumulative time can be no greater than its
+            # caller's, since a caller's ct also covers samples spent in
+            # its other callees
+            if level == 1:  # Innermost frame, the one directly executing
+                self.assertGreater(ct, 0)
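+
+        # The mock frame lists above are ordered innermost-first, which is
+        # why level1_func (index 0) is the only function counted as
+        # directly executing while its callers accrue cumulative samples
+        # only.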
+
+ def test_alternating_call_patterns(self):
+ """Test profiling with alternating call patterns."""
+ collector = PstatsCollector(sample_interval_usec=1000)
+
+ # Simulate alternating execution paths
+ pattern_frames = [
+ # Pattern A: path through func_a
+ (
+ 1,
+ [
+ MockFrameInfo("module.py", 10, "func_a"),
+ MockFrameInfo("module.py", 30, "shared_func"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ # Pattern B: path through func_b
+ (
+ 1,
+ [
+ MockFrameInfo("module.py", 20, "func_b"),
+ MockFrameInfo("module.py", 30, "shared_func"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ # Pattern A again
+ (
+ 1,
+ [
+ MockFrameInfo("module.py", 10, "func_a"),
+ MockFrameInfo("module.py", 30, "shared_func"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ # Pattern B again
+ (
+ 1,
+ [
+ MockFrameInfo("module.py", 20, "func_b"),
+ MockFrameInfo("module.py", 30, "shared_func"),
+ MockFrameInfo("main.py", 5, "main"),
+ ],
+ ),
+ ]
+
+ for frames in pattern_frames:
+ collector.collect([frames])
+
+ collector.create_stats()
+
+ # Check that both paths are recorded equally
+ func_a_key = ("module.py", 10, "func_a")
+ func_b_key = ("module.py", 20, "func_b")
+ shared_key = ("module.py", 30, "shared_func")
+ main_key = ("main.py", 5, "main")
+
+ # func_a and func_b should each be directly executing twice
+ self.assertEqual(collector.stats[func_a_key][0], 2) # direct_calls
+ self.assertEqual(collector.stats[func_a_key][1], 2) # cumulative_calls
+ self.assertEqual(collector.stats[func_b_key][0], 2) # direct_calls
+ self.assertEqual(collector.stats[func_b_key][1], 2) # cumulative_calls
+
+ # shared_func should appear in all samples (4 times) but never directly executing
+ self.assertEqual(collector.stats[shared_key][0], 0) # direct_calls
+ self.assertEqual(collector.stats[shared_key][1], 4) # cumulative_calls
+
+ # main should appear in all samples but never directly executing
+ self.assertEqual(collector.stats[main_key][0], 0) # direct_calls
+ self.assertEqual(collector.stats[main_key][1], 4) # cumulative_calls
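+
+        # A compact mental model for these assertions: direct_calls counts
+        # the samples where a function sat at the top of the stack, while
+        # cumulative_calls counts the samples where it appeared anywhere
+        # on the stack.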
+
+ def test_collapsed_stack_with_recursion(self):
+ """Test collapsed stack collector with recursive patterns."""
+ collector = CollapsedStackCollector()
+
+ # Recursive call pattern
+ recursive_frames = [
+ (
+ 1,
+ [
+ ("factorial.py", 10, "factorial"),
+ ("factorial.py", 10, "factorial"), # recursive
+ ("factorial.py", 10, "factorial"), # deeper
+ ("main.py", 5, "main"),
+ ],
+ ),
+ (
+ 1,
+ [
+ ("factorial.py", 10, "factorial"),
+ ("factorial.py", 10, "factorial"), # different depth
+ ("main.py", 5, "main"),
+ ],
+ ),
+ ]
+
+ for frames in recursive_frames:
+ collector.collect([frames])
+
+ # Should capture both call trees
+ self.assertEqual(len(collector.call_trees), 2)
+
+        # The two call trees should have different lengths because the
+        # samples captured different recursion depths
+        tree1 = collector.call_trees[0]
+        tree2 = collector.call_trees[1]
+        self.assertNotEqual(len(tree1), len(tree2))
+
+ # Both should contain factorial calls
+ self.assertTrue(any("factorial" in str(frame) for frame in tree1))
+ self.assertTrue(any("factorial" in str(frame) for frame in tree2))
+
+ # Function samples should count all occurrences
+ factorial_key = ("factorial.py", 10, "factorial")
+ main_key = ("main.py", 5, "main")
+
+ # factorial appears 5 times total (3 + 2)
+ self.assertEqual(collector.function_samples[factorial_key], 5)
+ # main appears 2 times total
+ self.assertEqual(collector.function_samples[main_key], 2)
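+
+        # Unlike PstatsCollector, the collapsed collector keeps each
+        # sampled stack intact as a call tree instead of aggregating
+        # per-function tuples; the length and membership checks above rely
+        # on that.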
+
+
+@requires_subprocess()
+@skip_if_not_supported
+class TestSampleProfilerIntegration(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.test_script = '''
+import time
+
+def slow_fibonacci(n):
+ """Recursive fibonacci - should show up prominently in profiler."""
+ if n <= 1:
+ return n
+ return slow_fibonacci(n-1) + slow_fibonacci(n-2)
+
+def cpu_intensive_work():
+ """CPU intensive work that should show in profiler."""
+ result = 0
+ for i in range(10000):
+ result += i * i
+ if i % 100 == 0:
+ result = result % 1000000
+ return result
+
+def medium_computation():
+ """Medium complexity function."""
+ result = 0
+ for i in range(100):
+ result += i * i
+ return result
+
+def fast_loop():
+ """Fast simple loop."""
+ total = 0
+ for i in range(50):
+ total += i
+ return total
+
+def nested_calls():
+ """Test nested function calls."""
+ def level1():
+ def level2():
+ return medium_computation()
+ return level2()
+ return level1()
+
+def main_loop():
+ """Main test loop with different execution paths."""
+ iteration = 0
+
+ while True:
+ iteration += 1
+
+ # Different execution paths - focus on CPU intensive work
+ if iteration % 3 == 0:
+ # Very CPU intensive
+ result = cpu_intensive_work()
+ elif iteration % 5 == 0:
+ # Expensive recursive operation
+ result = slow_fibonacci(12)
+ else:
+ # Medium operation
+ result = nested_calls()
+
+ # No sleep - keep CPU busy
+
+if __name__ == "__main__":
+ main_loop()
+'''
+
+ def test_sampling_basic_functionality(self):
+ with (
+ test_subprocess(self.test_script) as proc,
+ io.StringIO() as captured_output,
+ mock.patch("sys.stdout", captured_output),
+ ):
+ try:
+ profile.sample.sample(
+ proc.pid,
+ duration_sec=2,
+ sample_interval_usec=1000, # 1ms
+ show_summary=False,
+ )
+ except PermissionError:
+ self.skipTest("Insufficient permissions for remote profiling")
+
+ output = captured_output.getvalue()
+
+ # Basic checks on output
+ self.assertIn("Captured", output)
+ self.assertIn("samples", output)
+ self.assertIn("Profile Stats", output)
+
+ # Should see some of our test functions
+ self.assertIn("slow_fibonacci", output)
+
+ def test_sampling_with_pstats_export(self):
+ pstats_out = tempfile.NamedTemporaryFile(
+ suffix=".pstats", delete=False
+ )
+ self.addCleanup(close_and_unlink, pstats_out)
+
+ with test_subprocess(self.test_script) as proc:
+ # Suppress profiler output when testing file export
+ with (
+ io.StringIO() as captured_output,
+ mock.patch("sys.stdout", captured_output),
+ ):
+ try:
+ profile.sample.sample(
+ proc.pid,
+ duration_sec=1,
+ filename=pstats_out.name,
+ sample_interval_usec=10000,
+ )
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions for remote profiling"
+ )
+
+ # Verify file was created and contains valid data
+ self.assertTrue(os.path.exists(pstats_out.name))
+ self.assertGreater(os.path.getsize(pstats_out.name), 0)
+
+ # Try to load the stats file
+ with open(pstats_out.name, "rb") as f:
+ stats_data = marshal.load(f)
+
+ # Should be a dictionary with the sampled marker
+ self.assertIsInstance(stats_data, dict)
+ self.assertIn(("__sampled__",), stats_data)
+ self.assertTrue(stats_data[("__sampled__",)])
+
+ # Should have some function data
+ function_entries = [
+ k for k in stats_data.keys() if k != ("__sampled__",)
+ ]
+ self.assertGreater(len(function_entries), 0)
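+
+        # The ("__sampled__",) pseudo-entry marks the marshalled dump as
+        # coming from the sampling profiler, so consumers can distinguish
+        # it from deterministic cProfile output.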
+
+ def test_sampling_with_collapsed_export(self):
+ collapsed_file = tempfile.NamedTemporaryFile(
+ suffix=".txt", delete=False
+ )
+ self.addCleanup(close_and_unlink, collapsed_file)
+
+        with test_subprocess(self.test_script) as proc:
+ # Suppress profiler output when testing file export
+ with (
+ io.StringIO() as captured_output,
+ mock.patch("sys.stdout", captured_output),
+ ):
+ try:
+ profile.sample.sample(
+ proc.pid,
+ duration_sec=1,
+ filename=collapsed_file.name,
+ output_format="collapsed",
+ sample_interval_usec=10000,
+ )
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions for remote profiling"
+ )
+
+ # Verify file was created and contains valid data
+ self.assertTrue(os.path.exists(collapsed_file.name))
+ self.assertGreater(os.path.getsize(collapsed_file.name), 0)
+
+ # Check file format
+ with open(collapsed_file.name, "r") as f:
+ content = f.read()
+
+ lines = content.strip().split("\n")
+ self.assertGreater(len(lines), 0)
+
+ # Each line should have format: stack_trace count
+ for line in lines:
+ parts = line.rsplit(" ", 1)
+ self.assertEqual(len(parts), 2)
+
+ stack_trace, count_str = parts
+ self.assertGreater(len(stack_trace), 0)
+ self.assertTrue(count_str.isdigit())
+ self.assertGreater(int(count_str), 0)
+
+ # Stack trace should contain semicolon-separated entries
+ if ";" in stack_trace:
+ stack_parts = stack_trace.split(";")
+ for part in stack_parts:
+ # Each part should be file:function:line
+ self.assertIn(":", part)
+
+ def test_sampling_all_threads(self):
+ with (
+ test_subprocess(self.test_script) as proc,
+ # Suppress profiler output
+ io.StringIO() as captured_output,
+ mock.patch("sys.stdout", captured_output),
+ ):
+ try:
+ profile.sample.sample(
+ proc.pid,
+ duration_sec=1,
+ all_threads=True,
+ sample_interval_usec=10000,
+ show_summary=False,
+ )
+ except PermissionError:
+ self.skipTest("Insufficient permissions for remote profiling")
+
+ # Just verify that sampling completed without error
+ # We're not testing output format here
+
+
+@skip_if_not_supported
+@unittest.skipIf(
+ sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
+ "Test only runs on Linux with process_vm_readv support",
+)
+class TestSampleProfilerErrorHandling(unittest.TestCase):
+ def test_invalid_pid(self):
+ with self.assertRaises((OSError, RuntimeError)):
+ profile.sample.sample(-1, duration_sec=1)
+
+ def test_process_dies_during_sampling(self):
+ with test_subprocess("import time; time.sleep(0.5); exit()") as proc:
+ with (
+ io.StringIO() as captured_output,
+ mock.patch("sys.stdout", captured_output),
+ ):
+ try:
+ profile.sample.sample(
+ proc.pid,
+ duration_sec=2, # Longer than process lifetime
+ sample_interval_usec=50000,
+ )
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions for remote profiling"
+ )
+
+ output = captured_output.getvalue()
+
+ self.assertIn("Error rate", output)
+
+ def test_invalid_output_format(self):
+ with self.assertRaises(ValueError):
+ profile.sample.sample(
+ os.getpid(),
+ duration_sec=1,
+ output_format="invalid_format",
+ )
+
+ def test_invalid_output_format_with_mocked_profiler(self):
+ """Test invalid output format with proper mocking to avoid permission issues."""
+ with mock.patch(
+ "profile.sample.SampleProfiler"
+ ) as mock_profiler_class:
+ mock_profiler = mock.MagicMock()
+ mock_profiler_class.return_value = mock_profiler
+
+ with self.assertRaises(ValueError) as cm:
+ profile.sample.sample(
+ 12345,
+ duration_sec=1,
+ output_format="unknown_format",
+ )
+
+ # Should raise ValueError with the invalid format name
+ self.assertIn(
+ "Invalid output format: unknown_format", str(cm.exception)
+ )
+
+ def test_is_process_running(self):
+ with test_subprocess("import time; time.sleep(1000)") as proc:
+ try:
+                profiler = SampleProfiler(
+                    pid=proc.pid,
+                    sample_interval_usec=1000,
+                    all_threads=False,
+                )
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+ self.assertTrue(profiler._is_process_running())
+ self.assertIsNotNone(profiler.unwinder.get_stack_trace())
+ proc.kill()
+ proc.wait()
+            # macOS raises ValueError here; Linux and Windows raise
+            # ProcessLookupError
+            self.assertRaises(
+                (ValueError, ProcessLookupError),
+                profiler.unwinder.get_stack_trace,
+            )
+
+            # Once the process has been killed and reaped, it should no
+            # longer be reported as running
+            self.assertFalse(profiler._is_process_running())
+            self.assertRaises(
+                (ValueError, ProcessLookupError),
+                profiler.unwinder.get_stack_trace,
+            )
+
+ @unittest.skipUnless(sys.platform == "linux", "Only valid on Linux")
+ def test_esrch_signal_handling(self):
+ with test_subprocess("import time; time.sleep(1000)") as proc:
+ try:
+ unwinder = _remote_debugging.RemoteUnwinder(proc.pid)
+ except PermissionError:
+ self.skipTest(
+ "Insufficient permissions to read the stack trace"
+ )
+ initial_trace = unwinder.get_stack_trace()
+ self.assertIsNotNone(initial_trace)
+
+ proc.kill()
+
+ # Wait for the process to die and try to get another trace
+ proc.wait()
+
+ with self.assertRaises(ProcessLookupError):
+ unwinder.get_stack_trace()
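+
+            # ESRCH ("no such process") is surfaced by Python as
+            # ProcessLookupError, hence the expected exception above.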
+
+
+class TestSampleProfilerCLI(unittest.TestCase):
+ def test_cli_collapsed_format_validation(self):
+ """Test that CLI properly validates incompatible options with collapsed format."""
+ test_cases = [
+ # Test sort options are invalid with collapsed
+ (
+ ["profile.sample", "--collapsed", "--sort-nsamples", "12345"],
+ "sort",
+ ),
+ (
+ ["profile.sample", "--collapsed", "--sort-tottime", "12345"],
+ "sort",
+ ),
+ (
+ [
+ "profile.sample",
+ "--collapsed",
+ "--sort-cumtime",
+ "12345",
+ ],
+ "sort",
+ ),
+ (
+ [
+ "profile.sample",
+ "--collapsed",
+ "--sort-sample-pct",
+ "12345",
+ ],
+ "sort",
+ ),
+ (
+ [
+ "profile.sample",
+ "--collapsed",
+ "--sort-cumul-pct",
+ "12345",
+ ],
+ "sort",
+ ),
+ (
+ ["profile.sample", "--collapsed", "--sort-name", "12345"],
+ "sort",
+ ),
+ # Test limit option is invalid with collapsed
+ (["profile.sample", "--collapsed", "-l", "20", "12345"], "limit"),
+ (
+ ["profile.sample", "--collapsed", "--limit", "20", "12345"],
+ "limit",
+ ),
+ # Test no-summary option is invalid with collapsed
+ (
+ ["profile.sample", "--collapsed", "--no-summary", "12345"],
+ "summary",
+ ),
+ ]
+
+ for test_args, expected_error_keyword in test_cases:
+ with (
+ mock.patch("sys.argv", test_args),
+ mock.patch("sys.stderr", io.StringIO()) as mock_stderr,
+ self.assertRaises(SystemExit) as cm,
+ ):
+ profile.sample.main()
+
+ self.assertEqual(cm.exception.code, 2) # argparse error code
+ error_msg = mock_stderr.getvalue()
+ self.assertIn("error:", error_msg)
+ self.assertIn("--pstats format", error_msg)
+
+ def test_cli_default_collapsed_filename(self):
+ """Test that collapsed format gets a default filename when not specified."""
+ test_args = ["profile.sample", "--collapsed", "12345"]
+
+ with (
+ mock.patch("sys.argv", test_args),
+ mock.patch("profile.sample.sample") as mock_sample,
+ ):
+ profile.sample.main()
+
+ # Check that filename was set to default collapsed format
+ mock_sample.assert_called_once()
+ call_args = mock_sample.call_args[1]
+ self.assertEqual(call_args["output_format"], "collapsed")
+ self.assertEqual(call_args["filename"], "collapsed.12345.txt")
+
+ def test_cli_custom_output_filenames(self):
+ """Test custom output filenames for both formats."""
+ test_cases = [
+ (
+ ["profile.sample", "--pstats", "-o", "custom.pstats", "12345"],
+ "custom.pstats",
+ "pstats",
+ ),
+ (
+ ["profile.sample", "--collapsed", "-o", "custom.txt", "12345"],
+ "custom.txt",
+ "collapsed",
+ ),
+ ]
+
+ for test_args, expected_filename, expected_format in test_cases:
+ with (
+ mock.patch("sys.argv", test_args),
+ mock.patch("profile.sample.sample") as mock_sample,
+ ):
+ profile.sample.main()
+
+ mock_sample.assert_called_once()
+ call_args = mock_sample.call_args[1]
+ self.assertEqual(call_args["filename"], expected_filename)
+ self.assertEqual(call_args["output_format"], expected_format)
+
+ def test_cli_missing_required_arguments(self):
+ """Test that CLI requires PID argument."""
+ with (
+ mock.patch("sys.argv", ["profile.sample"]),
+ mock.patch("sys.stderr", io.StringIO()),
+ ):
+ with self.assertRaises(SystemExit):
+ profile.sample.main()
+
+ def test_cli_mutually_exclusive_format_options(self):
+ """Test that pstats and collapsed options are mutually exclusive."""
+ with (
+ mock.patch(
+ "sys.argv",
+ ["profile.sample", "--pstats", "--collapsed", "12345"],
+ ),
+ mock.patch("sys.stderr", io.StringIO()),
+ ):
+ with self.assertRaises(SystemExit):
+ profile.sample.main()
+
+ def test_argument_parsing_basic(self):
+ test_args = ["profile.sample", "12345"]
+
+ with (
+ mock.patch("sys.argv", test_args),
+ mock.patch("profile.sample.sample") as mock_sample,
+ ):
+ profile.sample.main()
+
+ mock_sample.assert_called_once_with(
+ 12345,
+ sample_interval_usec=100,
+ duration_sec=10,
+ filename=None,
+ all_threads=False,
+ limit=15,
+ sort=2,
+ show_summary=True,
+ output_format="pstats",
+ realtime_stats=False,
+ )
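+
+        # Sanity check on those defaults: 100 usec between samples is a
+        # nominal rate of 1_000_000 / 100 == 10_000 samples/sec, so the
+        # default 10-second run targets a nominal 100_000 samples.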
+
+ def test_sort_options(self):
+ sort_options = [
+ ("--sort-nsamples", 0),
+ ("--sort-tottime", 1),
+ ("--sort-cumtime", 2),
+ ("--sort-sample-pct", 3),
+ ("--sort-cumul-pct", 4),
+ ("--sort-name", -1),
+ ]
+
+ for option, expected_sort_value in sort_options:
+ test_args = ["profile.sample", option, "12345"]
+
+ with (
+ mock.patch("sys.argv", test_args),
+ mock.patch("profile.sample.sample") as mock_sample,
+ ):
+ profile.sample.main()
+
+ mock_sample.assert_called_once()
+ call_args = mock_sample.call_args[1]
+ self.assertEqual(
+ call_args["sort"],
+ expected_sort_value,
+ )
+ mock_sample.reset_mock()
+
+
+if __name__ == "__main__":
+ unittest.main()