_PROCESS_KILL_TIMEOUT_SEC = 2.0
_READY_MESSAGE = b"ready"
_RECV_BUFFER_SIZE = 1024
+_BINARY_PROFILE_HEADER_SIZE = 64
+_BINARY_PROFILE_MAGICS = (b"HCAT", b"TACH")
# Format configuration
FORMAT_EXTENSIONS = {
print(f"Warning: Could not open browser: {e}", file=sys.stderr)
+def _validate_replay_input_file(filename):
+ """Validate that the replay input looks like a sampling binary profile."""
+ try:
+ with open(filename, "rb") as file:
+ header = file.read(_BINARY_PROFILE_HEADER_SIZE)
+ except OSError as exc:
+ sys.exit(f"Error: Could not read input file {filename}: {exc}")
+
+ if (
+ len(header) < _BINARY_PROFILE_HEADER_SIZE
+ or header[:4] not in _BINARY_PROFILE_MAGICS
+ ):
+ sys.exit(
+ "Error: Input file is not a binary sampling profile. "
+ "The replay command only accepts files created with --binary"
+ )
+
+
+def _replay_with_reader(args, reader):
+ """Replay samples from an open binary reader."""
+ info = reader.get_info()
+ interval = info['sample_interval_us']
+
+ print(f"Replaying {info['sample_count']} samples from {args.input_file}")
+ print(f" Sample interval: {interval} us")
+ print(
+ " Compression: "
+ f"{'zstd' if info.get('compression_type', 0) == 1 else 'none'}"
+ )
+
+ collector = _create_collector(
+ args.format, interval, skip_idle=False,
+ diff_baseline=args.diff_baseline
+ )
+
+ def progress_callback(current, total):
+ if total > 0:
+ pct = current / total
+ bar_width = 40
+ filled = int(bar_width * pct)
+ bar = '█' * filled + '░' * (bar_width - filled)
+ print(
+ f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})",
+ end="",
+ flush=True,
+ )
+
+ count = reader.replay_samples(collector, progress_callback)
+ print()
+
+ if args.format == "pstats":
+ if args.outfile:
+ collector.export(args.outfile)
+ else:
+ sort_choice = (
+ args.sort if args.sort is not None else "nsamples"
+ )
+ limit = args.limit if args.limit is not None else 15
+ sort_mode = _sort_to_mode(sort_choice)
+ collector.print_stats(
+ sort_mode, limit, not args.no_summary,
+ PROFILING_MODE_WALL
+ )
+ else:
+ filename = (
+ args.outfile
+ or _generate_output_filename(args.format, os.getpid())
+ )
+ collector.export(filename)
+
+ # Auto-open browser for HTML output if --browser flag is set
+ if (
+ args.format in (
+ 'flamegraph', 'diff_flamegraph', 'heatmap'
+ )
+ and getattr(args, 'browser', False)
+ ):
+ _open_in_browser(filename)
+
+ print(f"Replayed {count} samples")
+
+
def _handle_output(collector, args, pid, mode):
"""Handle output for the collector based on format and arguments.
if not os.path.exists(args.input_file):
sys.exit(f"Error: Input file not found: {args.input_file}")
- with BinaryReader(args.input_file) as reader:
- info = reader.get_info()
- interval = info['sample_interval_us']
+ _validate_replay_input_file(args.input_file)
- print(f"Replaying {info['sample_count']} samples from {args.input_file}")
- print(f" Sample interval: {interval} us")
- print(f" Compression: {'zstd' if info.get('compression_type', 0) == 1 else 'none'}")
-
- collector = _create_collector(
- args.format, interval, skip_idle=False,
- diff_baseline=args.diff_baseline
- )
-
- def progress_callback(current, total):
- if total > 0:
- pct = current / total
- bar_width = 40
- filled = int(bar_width * pct)
- bar = '█' * filled + '░' * (bar_width - filled)
- print(f"\r [{bar}] {pct*100:5.1f}% ({current:,}/{total:,})", end="", flush=True)
-
- count = reader.replay_samples(collector, progress_callback)
- print()
-
- if args.format == "pstats":
- if args.outfile:
- collector.export(args.outfile)
- else:
- sort_choice = args.sort if args.sort is not None else "nsamples"
- limit = args.limit if args.limit is not None else 15
- sort_mode = _sort_to_mode(sort_choice)
- collector.print_stats(sort_mode, limit, not args.no_summary, PROFILING_MODE_WALL)
- else:
- filename = args.outfile or _generate_output_filename(args.format, os.getpid())
- collector.export(filename)
-
- # Auto-open browser for HTML output if --browser flag is set
- if args.format in ('flamegraph', 'diff_flamegraph', 'heatmap') and getattr(args, 'browser', False):
- _open_in_browser(filename)
-
- print(f"Replayed {count} samples")
+ try:
+ with BinaryReader(args.input_file) as reader:
+ _replay_with_reader(args, reader)
+ except (OSError, ValueError) as exc:
+ sys.exit(f"Error: {exc}")
if __name__ == "__main__":
"""Tests for sampling profiler CLI argument parsing and functionality."""
import io
+import os
import subprocess
import sys
+import tempfile
import unittest
from unittest import mock
main()
self.assertIn(fake_pid, str(cm.exception))
+
+ def test_cli_replay_rejects_non_binary_profile(self):
+ with tempfile.TemporaryDirectory() as tempdir:
+ profile = os.path.join(tempdir, "output.prof")
+ with open(profile, "wb") as file:
+ file.write(b"not a binary sampling profile")
+
+ with mock.patch("sys.argv", ["profiling.sampling.cli", "replay", profile]):
+ with self.assertRaises(SystemExit) as cm:
+ main()
+
+ error = str(cm.exception)
+ self.assertIn("not a binary sampling profile", error)
+ self.assertIn("--binary", error)
+
+ def test_cli_replay_reader_errors_exit_cleanly(self):
+ with tempfile.TemporaryDirectory() as tempdir:
+ profile = os.path.join(tempdir, "output.bin")
+ with open(profile, "wb") as file:
+ file.write(b"HCAT" + (b"\0" * 60))
+
+ with (
+ mock.patch("sys.argv", ["profiling.sampling.cli", "replay", profile]),
+ mock.patch(
+ "profiling.sampling.cli.BinaryReader",
+ side_effect=ValueError("Unsupported format version 2"),
+ ),
+ ):
+ with self.assertRaises(SystemExit) as cm:
+ main()
+
+ self.assertEqual(
+ str(cm.exception),
+ "Error: Unsupported format version 2",
+ )