from abc import ABC, abstractmethod
-
-# Thread status flags
-try:
- from _remote_debugging import THREAD_STATUS_HAS_GIL, THREAD_STATUS_ON_CPU, THREAD_STATUS_UNKNOWN, THREAD_STATUS_GIL_REQUESTED
-except ImportError:
- # Fallback for tests or when module is not available
- THREAD_STATUS_HAS_GIL = (1 << 0)
- THREAD_STATUS_ON_CPU = (1 << 1)
- THREAD_STATUS_UNKNOWN = (1 << 2)
- THREAD_STATUS_GIL_REQUESTED = (1 << 3)
+from .constants import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+ THREAD_STATUS_UNKNOWN,
+ THREAD_STATUS_GIL_REQUESTED,
+)
class Collector(ABC):
@abstractmethod
def collect(self, stack_frames):
"""Collect profiling data from stack frames."""
+ def collect_failed_sample(self):
+ """Collect data about a failed sample attempt."""
+
@abstractmethod
def export(self, filename):
"""Export collected data to a file."""
--- /dev/null
+"""Constants for the sampling profiler."""
+
+# Profiling mode constants
+PROFILING_MODE_WALL = 0
+PROFILING_MODE_CPU = 1
+PROFILING_MODE_GIL = 2
+PROFILING_MODE_ALL = 3 # Combines GIL + CPU checks
+
+# Sort mode constants
+SORT_MODE_NSAMPLES = 0
+SORT_MODE_TOTTIME = 1
+SORT_MODE_CUMTIME = 2
+SORT_MODE_SAMPLE_PCT = 3
+SORT_MODE_CUMUL_PCT = 4
+SORT_MODE_NSAMPLES_CUMUL = 5
+
+# Thread status flags
+try:
+ from _remote_debugging import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+ THREAD_STATUS_UNKNOWN,
+ THREAD_STATUS_GIL_REQUESTED,
+ )
+except ImportError:
+ # Fallback for tests or when module is not available
+ THREAD_STATUS_HAS_GIL = (1 << 0)
+ THREAD_STATUS_ON_CPU = (1 << 1)
+ THREAD_STATUS_UNKNOWN = (1 << 2)
+ THREAD_STATUS_GIL_REQUESTED = (1 << 3)
--- /dev/null
+"""Live profiling collector that displays top-like statistics using curses.
+
+ ┌─────────────────────────────┐
+ │ Target Python Process │
+ │ (being profiled) │
+ └──────────────┬──────────────┘
+ │ Stack sampling at
+ │ configured interval
+ │ (e.g., 10000µs)
+ ▼
+ ┌─────────────────────────────┐
+ │ LiveStatsCollector │
+ │ ┌───────────────────────┐ │
+ │ │ collect() │ │ Aggregates samples
+ │ │ - Iterates frames │ │ into statistics
+ │ │ - Updates counters │ │
+ │ └───────────┬───────────┘ │
+ │ │ │
+ │ ▼ │
+ │ ┌───────────────────────┐ │
+ │ │ Data Storage │ │
+ │ │ - result dict │ │ Tracks per-function:
+ │ │ - direct_calls │ │ • Direct samples
+ │ │ - cumulative_calls │ │ • Cumulative samples
+ │ └───────────┬───────────┘ │ • Derived time stats
+ │ │ │
+ │ ▼ │
+ │ ┌───────────────────────┐ │
+ │ │ Display Update │ │
+ │ │ (10Hz by default) │ │ Rate-limited refresh
+ │ └───────────┬───────────┘ │
+ └──────────────┼──────────────┘
+ │
+ ▼
+ ┌─────────────────────────────┐
+ │ DisplayInterface │
+ │ (Abstract layer) │
+ └──────────────┬──────────────┘
+ ┌───────┴────────┐
+ │ │
+ ┌──────────▼────────┐ ┌───▼──────────┐
+ │ CursesDisplay │ │ MockDisplay │
+ │ - Real terminal │ │ - Testing │
+ │ - ncurses backend │ │ - No UI │
+ └─────────┬─────────┘ └──────────────┘
+ │
+ ▼
+ ┌─────────────────────────────────────┐
+ │ Widget-Based Rendering │
+ │ ┌─────────────────────────────────┐ │
+ │ │ HeaderWidget │ │
+ │ │ • PID, uptime, time, interval │ │
+ │ │ • Sample stats & progress bar │ │
+ │ │ • Efficiency bar │ │
+ │ │ • Thread status & GC stats │ │
+ │ │ • Function summary │ │
+ │ │ • Top 3 hottest functions │ │
+ │ ├─────────────────────────────────┤ │
+ │ │ TableWidget │ │
+ │ │ • Column headers (sortable) │ │ Interactive display
+ │ │ • Stats rows (scrolling) │ │ with keyboard controls:
+ │ │ - nsamples % time │ │ s: sort, p: pause
+ │ │ - function file:line │ │ r: reset, /: filter
+ │ ├─────────────────────────────────┤ │ q: quit, h: help
+ │ │ FooterWidget │ │
+ │ │ • Legend and status │ │
+ │ │ • Filter input prompt │ │
+ │ └─────────────────────────────────┘ │
+ └─────────────────────────────────────┘
+
+Architecture:
+
+The live collector is organized into four layers. The data collection layer
+(LiveStatsCollector) aggregates stack samples into per-function statistics without
+any knowledge of how they will be presented. The display abstraction layer
+(DisplayInterface) defines rendering operations without coupling to curses or any
+specific UI framework. The widget layer (Widget, HeaderWidget, TableWidget,
+FooterWidget, HelpWidget, ProgressBarWidget) encapsulates individual UI components
+with their own rendering logic, promoting modularity and reusability. The
+presentation layer (CursesDisplay/MockDisplay) implements the actual rendering for
+terminal output and testing.
+
+The system runs two independent update loops. The sampling loop is driven by the
+profiler at the configured interval (e.g., 10000µs) and continuously collects
+stack frames and updates statistics. The display loop runs at a fixed refresh rate
+(default 10Hz) and updates the terminal independently of sampling frequency. This
+separation allows high-frequency sampling without overwhelming the terminal with
+constant redraws.
+
+Statistics are computed incrementally as samples arrive. The collector maintains
+running counters (direct calls and cumulative calls) in a dictionary keyed by
+function location. Derived metrics like time estimates and percentages are computed
+on-demand during display updates rather than being stored, which minimizes memory
+overhead as the number of tracked functions grows.
+
+User input is processed asynchronously during display updates using non-blocking I/O.
+This allows interactive controls (sorting, filtering, pausing) without interrupting
+the data collection pipeline. The collector maintains mode flags (paused,
+filter_input_mode) that affect what gets displayed but not what gets collected.
+
+"""
+
+# Re-export all public classes and constants for backward compatibility
+from .collector import LiveStatsCollector
+from .display import DisplayInterface, CursesDisplay, MockDisplay
+from .widgets import (
+ Widget,
+ ProgressBarWidget,
+ HeaderWidget,
+ TableWidget,
+ FooterWidget,
+ HelpWidget,
+)
+from .constants import (
+ MICROSECONDS_PER_SECOND,
+ DISPLAY_UPDATE_HZ,
+ DISPLAY_UPDATE_INTERVAL,
+ MIN_TERMINAL_WIDTH,
+ MIN_TERMINAL_HEIGHT,
+ WIDTH_THRESHOLD_SAMPLE_PCT,
+ WIDTH_THRESHOLD_TOTTIME,
+ WIDTH_THRESHOLD_CUMUL_PCT,
+ WIDTH_THRESHOLD_CUMTIME,
+ HEADER_LINES,
+ FOOTER_LINES,
+ SAFETY_MARGIN,
+ TOP_FUNCTIONS_DISPLAY_COUNT,
+ COL_WIDTH_NSAMPLES,
+ COL_SPACING,
+ COL_WIDTH_SAMPLE_PCT,
+ COL_WIDTH_TIME,
+ MIN_FUNC_NAME_WIDTH,
+ MAX_FUNC_NAME_WIDTH,
+ MIN_AVAILABLE_SPACE,
+ MIN_BAR_WIDTH,
+ MAX_SAMPLE_RATE_BAR_WIDTH,
+ MAX_EFFICIENCY_BAR_WIDTH,
+ MIN_SAMPLE_RATE_FOR_SCALING,
+ FINISHED_BANNER_EXTRA_LINES,
+ COLOR_PAIR_HEADER_BG,
+ COLOR_PAIR_CYAN,
+ COLOR_PAIR_YELLOW,
+ COLOR_PAIR_GREEN,
+ COLOR_PAIR_MAGENTA,
+ COLOR_PAIR_RED,
+ COLOR_PAIR_SORTED_HEADER,
+ DEFAULT_SORT_BY,
+ DEFAULT_DISPLAY_LIMIT,
+)
+
+__all__ = [
+ # Main collector
+ "LiveStatsCollector",
+ # Display interfaces
+ "DisplayInterface",
+ "CursesDisplay",
+ "MockDisplay",
+ # Widgets
+ "Widget",
+ "ProgressBarWidget",
+ "HeaderWidget",
+ "TableWidget",
+ "FooterWidget",
+ "HelpWidget",
+ # Constants
+ "MICROSECONDS_PER_SECOND",
+ "DISPLAY_UPDATE_HZ",
+ "DISPLAY_UPDATE_INTERVAL",
+ "MIN_TERMINAL_WIDTH",
+ "MIN_TERMINAL_HEIGHT",
+ "WIDTH_THRESHOLD_SAMPLE_PCT",
+ "WIDTH_THRESHOLD_TOTTIME",
+ "WIDTH_THRESHOLD_CUMUL_PCT",
+ "WIDTH_THRESHOLD_CUMTIME",
+ "HEADER_LINES",
+ "FOOTER_LINES",
+ "SAFETY_MARGIN",
+ "TOP_FUNCTIONS_DISPLAY_COUNT",
+ "COL_WIDTH_NSAMPLES",
+ "COL_SPACING",
+ "COL_WIDTH_SAMPLE_PCT",
+ "COL_WIDTH_TIME",
+ "MIN_FUNC_NAME_WIDTH",
+ "MAX_FUNC_NAME_WIDTH",
+ "MIN_AVAILABLE_SPACE",
+ "MIN_BAR_WIDTH",
+ "MAX_SAMPLE_RATE_BAR_WIDTH",
+ "MAX_EFFICIENCY_BAR_WIDTH",
+ "MIN_SAMPLE_RATE_FOR_SCALING",
+ "FINISHED_BANNER_EXTRA_LINES",
+ "COLOR_PAIR_HEADER_BG",
+ "COLOR_PAIR_CYAN",
+ "COLOR_PAIR_YELLOW",
+ "COLOR_PAIR_GREEN",
+ "COLOR_PAIR_MAGENTA",
+ "COLOR_PAIR_RED",
+ "COLOR_PAIR_SORTED_HEADER",
+ "DEFAULT_SORT_BY",
+ "DEFAULT_DISPLAY_LIMIT",
+]
--- /dev/null
+"""LiveStatsCollector - Main collector class for live profiling."""
+
+import collections
+import contextlib
+import curses
+from dataclasses import dataclass, field
+import os
+import site
+import sys
+import sysconfig
+import time
+import _colorize
+
+from ..collector import Collector
+from ..constants import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+ THREAD_STATUS_UNKNOWN,
+ THREAD_STATUS_GIL_REQUESTED,
+ PROFILING_MODE_CPU,
+ PROFILING_MODE_GIL,
+ PROFILING_MODE_WALL,
+)
+from .constants import (
+ MICROSECONDS_PER_SECOND,
+ DISPLAY_UPDATE_INTERVAL,
+ MIN_TERMINAL_WIDTH,
+ MIN_TERMINAL_HEIGHT,
+ HEADER_LINES,
+ FOOTER_LINES,
+ SAFETY_MARGIN,
+ FINISHED_BANNER_EXTRA_LINES,
+ DEFAULT_SORT_BY,
+ DEFAULT_DISPLAY_LIMIT,
+ COLOR_PAIR_HEADER_BG,
+ COLOR_PAIR_CYAN,
+ COLOR_PAIR_YELLOW,
+ COLOR_PAIR_GREEN,
+ COLOR_PAIR_MAGENTA,
+ COLOR_PAIR_RED,
+ COLOR_PAIR_SORTED_HEADER,
+)
+from .display import CursesDisplay
+from .widgets import HeaderWidget, TableWidget, FooterWidget, HelpWidget
+from .trend_tracker import TrendTracker
+
+
+@dataclass
+class ThreadData:
+ """Encapsulates all profiling data for a single thread."""
+
+ thread_id: int
+
+ # Function call statistics: {location: {direct_calls: int, cumulative_calls: int}}
+ result: dict = field(default_factory=lambda: collections.defaultdict(
+ lambda: dict(direct_calls=0, cumulative_calls=0)
+ ))
+
+ # Thread status statistics
+ has_gil: int = 0
+ on_cpu: int = 0
+ gil_requested: int = 0
+ unknown: int = 0
+ total: int = 0 # Total status samples for this thread
+
+ # Sample counts
+ sample_count: int = 0
+ gc_frame_samples: int = 0
+
+ def increment_status_flag(self, status_flags):
+ """Update status counts based on status bit flags."""
+ if status_flags & THREAD_STATUS_HAS_GIL:
+ self.has_gil += 1
+ if status_flags & THREAD_STATUS_ON_CPU:
+ self.on_cpu += 1
+ if status_flags & THREAD_STATUS_GIL_REQUESTED:
+ self.gil_requested += 1
+ if status_flags & THREAD_STATUS_UNKNOWN:
+ self.unknown += 1
+ self.total += 1
+
+ def as_status_dict(self):
+ """Return status counts as a dict for compatibility."""
+ return {
+ "has_gil": self.has_gil,
+ "on_cpu": self.on_cpu,
+ "gil_requested": self.gil_requested,
+ "unknown": self.unknown,
+ "total": self.total,
+ }
+
+
+class LiveStatsCollector(Collector):
+ """Collector that displays live top-like statistics using ncurses."""
+
+ def __init__(
+ self,
+ sample_interval_usec,
+ *,
+ skip_idle=False,
+ sort_by=DEFAULT_SORT_BY,
+ limit=DEFAULT_DISPLAY_LIMIT,
+ pid=None,
+ display=None,
+ mode=None,
+ ):
+ """
+ Initialize the live stats collector.
+
+ Args:
+ sample_interval_usec: Sampling interval in microseconds
+ skip_idle: Whether to skip idle threads
+ sort_by: Sort key ('tottime', 'nsamples', 'cumtime', 'sample_pct', 'cumul_pct')
+ limit: Maximum number of functions to display
+ pid: Process ID being profiled
+ display: DisplayInterface implementation (None means curses will be used)
+ mode: Profiling mode ('cpu', 'gil', etc.) - affects what stats are shown
+ """
+ self.result = collections.defaultdict(
+ lambda: dict(total_rec_calls=0, direct_calls=0, cumulative_calls=0)
+ )
+ self.sample_interval_usec = sample_interval_usec
+ self.sample_interval_sec = (
+ sample_interval_usec / MICROSECONDS_PER_SECOND
+ )
+ self.skip_idle = skip_idle
+ self.sort_by = sort_by
+ self.limit = limit
+ self.total_samples = 0
+ self.start_time = None
+ self.stdscr = None
+ self.display = display # DisplayInterface implementation
+ self.running = True
+ self.pid = pid
+ self.mode = mode # Profiling mode
+ self._saved_stdout = None
+ self._saved_stderr = None
+ self._devnull = None
+ self._last_display_update = None
+ self._max_sample_rate = 0 # Track maximum sample rate seen
+ self._successful_samples = 0 # Track samples that captured frames
+ self._failed_samples = 0 # Track samples that failed to capture frames
+ self._display_update_interval = DISPLAY_UPDATE_INTERVAL # Instance variable for display refresh rate
+
+ # Thread status statistics (bit flags)
+ self._thread_status_counts = {
+ "has_gil": 0,
+ "on_cpu": 0,
+ "gil_requested": 0,
+ "unknown": 0,
+ "total": 0, # Total thread count across all samples
+ }
+ self._gc_frame_samples = 0 # Track samples with GC frames
+
+ # Interactive controls state
+ self.paused = False # Pause UI updates (profiling continues)
+ self.show_help = False # Show help screen
+ self.filter_pattern = None # Glob pattern to filter functions
+ self.filter_input_mode = False # Currently entering filter text
+ self.filter_input_buffer = "" # Buffer for filter input
+ self.finished = False # Program has finished, showing final state
+ self.finish_timestamp = None # When profiling finished (for time freezing)
+ self.finish_wall_time = None # Wall clock time when profiling finished
+
+ # Thread tracking state
+ self.thread_ids = [] # List of thread IDs seen
+ self.view_mode = "ALL" # "ALL" or "PER_THREAD"
+ self.current_thread_index = (
+ 0 # Index into thread_ids when in PER_THREAD mode
+ )
+ self.per_thread_data = {} # {thread_id: ThreadData}
+
+ # Calculate common path prefixes to strip
+ self._path_prefixes = self._get_common_path_prefixes()
+
+ # Widgets (initialized when display is available)
+ self._header_widget = None
+ self._table_widget = None
+ self._footer_widget = None
+ self._help_widget = None
+
+ # Color mode
+ self._can_colorize = _colorize.can_colorize()
+
+ # Trend tracking (initialized after colors are set up)
+ self._trend_tracker = None
+
+ @property
+ def elapsed_time(self):
+ """Get the elapsed time, frozen when finished."""
+ if self.finished and self.finish_timestamp is not None:
+ return self.finish_timestamp - self.start_time
+ return time.perf_counter() - self.start_time if self.start_time else 0
+
+ @property
+ def current_time_display(self):
+ """Get the current time for display, frozen when finished."""
+ if self.finished and self.finish_wall_time is not None:
+ return time.strftime("%H:%M:%S", time.localtime(self.finish_wall_time))
+ return time.strftime("%H:%M:%S")
+
+ def _get_or_create_thread_data(self, thread_id):
+ """Get or create ThreadData for a thread ID."""
+ if thread_id not in self.per_thread_data:
+ self.per_thread_data[thread_id] = ThreadData(thread_id=thread_id)
+ return self.per_thread_data[thread_id]
+
+ def _get_current_thread_data(self):
+ """Get ThreadData for currently selected thread in PER_THREAD mode."""
+ if self.view_mode == "PER_THREAD" and self.current_thread_index < len(self.thread_ids):
+ thread_id = self.thread_ids[self.current_thread_index]
+ return self.per_thread_data.get(thread_id)
+ return None
+
+ def _get_current_result_source(self):
+ """Get result dict for current view mode (aggregated or per-thread)."""
+ if self.view_mode == "ALL":
+ return self.result
+ thread_data = self._get_current_thread_data()
+ return thread_data.result if thread_data else {}
+
+ def _get_common_path_prefixes(self):
+ """Get common path prefixes to strip from file paths."""
+ prefixes = []
+
+ # Get the actual stdlib location from the os module
+ # This works for both installed Python and development builds
+ os_module_file = os.__file__
+ if os_module_file:
+ # os.__file__ points to os.py, get its directory
+ stdlib_dir = os.path.dirname(os.path.abspath(os_module_file))
+ prefixes.append(stdlib_dir)
+
+ # Get stdlib location from sysconfig (may be different or same)
+ stdlib_path = sysconfig.get_path("stdlib")
+ if stdlib_path:
+ prefixes.append(stdlib_path)
+
+ # Get platstdlib location (platform-specific stdlib)
+ platstdlib_path = sysconfig.get_path("platstdlib")
+ if platstdlib_path:
+ prefixes.append(platstdlib_path)
+
+ # Get site-packages locations
+ for site_path in site.getsitepackages():
+ prefixes.append(site_path)
+
+ # Also check user site-packages
+ user_site = site.getusersitepackages()
+ if user_site:
+ prefixes.append(user_site)
+
+ # Remove duplicates and sort by length (longest first) to match most specific paths first
+ prefixes = list(set(prefixes))
+ prefixes.sort(key=lambda x: len(x), reverse=True)
+
+ return prefixes
+
+ def _simplify_path(self, filepath):
+ """Simplify a file path by removing common prefixes."""
+ # Try to match against known prefixes
+ for prefix_path in self._path_prefixes:
+ if filepath.startswith(prefix_path):
+ # Remove the prefix completely
+ relative = filepath[len(prefix_path) :].lstrip(os.sep)
+ return relative
+
+ # If no match, return the original path
+ return filepath
+
+ def _process_frames(self, frames, thread_id=None):
+ """Process a single thread's frame stack.
+
+ Args:
+ frames: List of frame information
+ thread_id: Thread ID for per-thread tracking (optional)
+ """
+ if not frames:
+ return
+
+ # Get per-thread data if tracking per-thread
+ thread_data = self._get_or_create_thread_data(thread_id) if thread_id is not None else None
+
+ # Process each frame in the stack to track cumulative calls
+ for frame in frames:
+ location = (frame.filename, frame.lineno, frame.funcname)
+ self.result[location]["cumulative_calls"] += 1
+ if thread_data:
+ thread_data.result[location]["cumulative_calls"] += 1
+
+ # The top frame gets counted as an inline call (directly executing)
+ top_location = (frames[0].filename, frames[0].lineno, frames[0].funcname)
+ self.result[top_location]["direct_calls"] += 1
+ if thread_data:
+ thread_data.result[top_location]["direct_calls"] += 1
+
+ def collect_failed_sample(self):
+ self._failed_samples += 1
+ self.total_samples += 1
+
+ def collect(self, stack_frames):
+ """Collect and display profiling data."""
+ if self.start_time is None:
+ self.start_time = time.perf_counter()
+ self._last_display_update = self.start_time
+
+ # Thread status counts for this sample
+ temp_status_counts = {
+ "has_gil": 0,
+ "on_cpu": 0,
+ "gil_requested": 0,
+ "unknown": 0,
+ "total": 0,
+ }
+ has_gc_frame = False
+
+ # Always collect data, even when paused
+ # Track thread status flags and GC frames
+ for interpreter_info in stack_frames:
+ threads = getattr(interpreter_info, "threads", [])
+ for thread_info in threads:
+ temp_status_counts["total"] += 1
+
+ # Track thread status using bit flags
+ status_flags = getattr(thread_info, "status", 0)
+ thread_id = getattr(thread_info, "thread_id", None)
+
+ # Update aggregated counts
+ if status_flags & THREAD_STATUS_HAS_GIL:
+ temp_status_counts["has_gil"] += 1
+ if status_flags & THREAD_STATUS_ON_CPU:
+ temp_status_counts["on_cpu"] += 1
+ if status_flags & THREAD_STATUS_GIL_REQUESTED:
+ temp_status_counts["gil_requested"] += 1
+ if status_flags & THREAD_STATUS_UNKNOWN:
+ temp_status_counts["unknown"] += 1
+
+ # Update per-thread status counts
+ if thread_id is not None:
+ thread_data = self._get_or_create_thread_data(thread_id)
+ thread_data.increment_status_flag(status_flags)
+
+ # Process frames (respecting skip_idle)
+ if self.skip_idle:
+ has_gil = bool(status_flags & THREAD_STATUS_HAS_GIL)
+ on_cpu = bool(status_flags & THREAD_STATUS_ON_CPU)
+ if not (has_gil or on_cpu):
+ continue
+
+ frames = getattr(thread_info, "frame_info", None)
+ if frames:
+ self._process_frames(frames, thread_id=thread_id)
+
+ # Track thread IDs only for threads that actually have samples
+ if (
+ thread_id is not None
+ and thread_id not in self.thread_ids
+ ):
+ self.thread_ids.append(thread_id)
+
+ # Increment per-thread sample count and check for GC frames
+ thread_has_gc_frame = False
+ for frame in frames:
+ funcname = getattr(frame, "funcname", "")
+ if "<GC>" in funcname or "gc_collect" in funcname:
+ has_gc_frame = True
+ thread_has_gc_frame = True
+ break
+
+ if thread_id is not None:
+ thread_data = self._get_or_create_thread_data(thread_id)
+ thread_data.sample_count += 1
+ if thread_has_gc_frame:
+ thread_data.gc_frame_samples += 1
+
+ # Update cumulative thread status counts
+ for key, count in temp_status_counts.items():
+ self._thread_status_counts[key] += count
+
+ if has_gc_frame:
+ self._gc_frame_samples += 1
+
+ self._successful_samples += 1
+ self.total_samples += 1
+
+ # Handle input on every sample for instant responsiveness
+ if self.display is not None:
+ self._handle_input()
+
+ # Update display at configured rate if display is initialized and not paused
+ if self.display is not None and not self.paused:
+ current_time = time.perf_counter()
+ if (
+ self._last_display_update is None
+ or (current_time - self._last_display_update)
+ >= self._display_update_interval
+ ):
+ self._update_display()
+ self._last_display_update = current_time
+
+ def _prepare_display_data(self, height):
+ """Prepare data for display rendering."""
+ elapsed = self.elapsed_time
+ stats_list = self._build_stats_list()
+
+ # Calculate available space for stats
+ # Add extra lines for finished banner when in finished state
+ extra_header_lines = (
+ FINISHED_BANNER_EXTRA_LINES if self.finished else 0
+ )
+ max_stats_lines = max(
+ 0,
+ height
+ - HEADER_LINES
+ - extra_header_lines
+ - FOOTER_LINES
+ - SAFETY_MARGIN,
+ )
+ stats_list = stats_list[:max_stats_lines]
+
+ return elapsed, stats_list
+
+ def _initialize_widgets(self, colors):
+ """Initialize widgets with display and colors."""
+ if self._header_widget is None:
+ # Initialize trend tracker with colors
+ if self._trend_tracker is None:
+ self._trend_tracker = TrendTracker(colors, enabled=True)
+
+ self._header_widget = HeaderWidget(self.display, colors, self)
+ self._table_widget = TableWidget(self.display, colors, self)
+ self._footer_widget = FooterWidget(self.display, colors, self)
+ self._help_widget = HelpWidget(self.display, colors)
+
+ def _render_display_sections(
+ self, height, width, elapsed, stats_list, colors
+ ):
+ """Render all display sections to the screen."""
+ line = 0
+ try:
+ # Initialize widgets if not already done
+ self._initialize_widgets(colors)
+
+ # Render header
+ line = self._header_widget.render(
+ line, width, elapsed=elapsed, stats_list=stats_list
+ )
+
+ # Render table
+ line = self._table_widget.render(
+ line, width, height=height, stats_list=stats_list
+ )
+
+ except curses.error:
+ pass
+
+ def _update_display(self):
+ """Update the display with current stats."""
+ try:
+ # Clear screen and get dimensions
+ self.display.clear()
+ height, width = self.display.get_dimensions()
+
+ # Check terminal size
+ if width < MIN_TERMINAL_WIDTH or height < MIN_TERMINAL_HEIGHT:
+ self._show_terminal_too_small(height, width)
+ self.display.refresh()
+ return
+
+ # Setup colors and initialize widgets (needed for both help and normal display)
+ colors = self._setup_colors()
+ self._initialize_widgets(colors)
+
+ # Show help screen if requested
+ if self.show_help:
+ self._help_widget.render(0, width, height=height)
+ self.display.refresh()
+ return
+
+ # Prepare data
+ elapsed, stats_list = self._prepare_display_data(height)
+
+ # Render all sections
+ self._render_display_sections(
+ height, width, elapsed, stats_list, colors
+ )
+
+ # Footer
+ self._footer_widget.render(height - 2, width)
+
+ # Show filter input prompt if in filter input mode
+ if self.filter_input_mode:
+ self._footer_widget.render_filter_input_prompt(
+ height - 1, width
+ )
+
+ # Refresh display
+ self.display.redraw()
+ self.display.refresh()
+
+ except Exception:
+ pass
+
+ def _cycle_sort(self, reverse=False):
+ """Cycle through different sort modes in column order.
+
+ Args:
+ reverse: If True, cycle backwards (right to left), otherwise forward (left to right)
+ """
+ sort_modes = [
+ "nsamples",
+ "sample_pct",
+ "tottime",
+ "cumul_pct",
+ "cumtime",
+ ]
+ try:
+ current_idx = sort_modes.index(self.sort_by)
+ if reverse:
+ self.sort_by = sort_modes[(current_idx - 1) % len(sort_modes)]
+ else:
+ self.sort_by = sort_modes[(current_idx + 1) % len(sort_modes)]
+ except ValueError:
+ self.sort_by = "nsamples"
+
+ def _setup_colors(self):
+ """Set up color pairs and return color attributes."""
+
+ A_BOLD = self.display.get_attr("A_BOLD")
+ A_REVERSE = self.display.get_attr("A_REVERSE")
+ A_UNDERLINE = self.display.get_attr("A_UNDERLINE")
+ A_NORMAL = self.display.get_attr("A_NORMAL")
+
+ # Check both curses color support and _colorize.can_colorize()
+ if self.display.has_colors() and self._can_colorize:
+ with contextlib.suppress(Exception):
+ # Color constants (using curses values for compatibility)
+ COLOR_CYAN = 6
+ COLOR_GREEN = 2
+ COLOR_YELLOW = 3
+ COLOR_BLACK = 0
+ COLOR_MAGENTA = 5
+ COLOR_RED = 1
+
+ # Initialize all color pairs used throughout the UI
+ self.display.init_color_pair(
+ 1, COLOR_CYAN, -1
+ ) # Data colors for stats rows
+ self.display.init_color_pair(2, COLOR_GREEN, -1)
+ self.display.init_color_pair(3, COLOR_YELLOW, -1)
+ self.display.init_color_pair(
+ COLOR_PAIR_HEADER_BG, COLOR_BLACK, COLOR_GREEN
+ )
+ self.display.init_color_pair(
+ COLOR_PAIR_CYAN, COLOR_CYAN, COLOR_BLACK
+ )
+ self.display.init_color_pair(
+ COLOR_PAIR_YELLOW, COLOR_YELLOW, COLOR_BLACK
+ )
+ self.display.init_color_pair(
+ COLOR_PAIR_GREEN, COLOR_GREEN, COLOR_BLACK
+ )
+ self.display.init_color_pair(
+ COLOR_PAIR_MAGENTA, COLOR_MAGENTA, COLOR_BLACK
+ )
+ self.display.init_color_pair(
+ COLOR_PAIR_RED, COLOR_RED, COLOR_BLACK
+ )
+ self.display.init_color_pair(
+ COLOR_PAIR_SORTED_HEADER, COLOR_BLACK, COLOR_YELLOW
+ )
+
+ return {
+ "header": self.display.get_color_pair(COLOR_PAIR_HEADER_BG)
+ | A_BOLD,
+ "cyan": self.display.get_color_pair(COLOR_PAIR_CYAN)
+ | A_BOLD,
+ "yellow": self.display.get_color_pair(COLOR_PAIR_YELLOW)
+ | A_BOLD,
+ "green": self.display.get_color_pair(COLOR_PAIR_GREEN)
+ | A_BOLD,
+ "magenta": self.display.get_color_pair(COLOR_PAIR_MAGENTA)
+ | A_BOLD,
+ "red": self.display.get_color_pair(COLOR_PAIR_RED)
+ | A_BOLD,
+ "sorted_header": self.display.get_color_pair(
+ COLOR_PAIR_SORTED_HEADER
+ )
+ | A_BOLD,
+ "normal_header": A_REVERSE | A_BOLD,
+ "color_samples": self.display.get_color_pair(1),
+ "color_file": self.display.get_color_pair(2),
+ "color_func": self.display.get_color_pair(3),
+ # Trend colors (stock-like indicators)
+ "trend_up": self.display.get_color_pair(COLOR_PAIR_GREEN) | A_BOLD,
+ "trend_down": self.display.get_color_pair(COLOR_PAIR_RED) | A_BOLD,
+ "trend_stable": A_NORMAL,
+ }
+
+ # Fallback to non-color attributes
+ return {
+ "header": A_REVERSE | A_BOLD,
+ "cyan": A_BOLD,
+ "yellow": A_BOLD,
+ "green": A_BOLD,
+ "magenta": A_BOLD,
+ "red": A_BOLD,
+ "sorted_header": A_REVERSE | A_BOLD | A_UNDERLINE,
+ "normal_header": A_REVERSE | A_BOLD,
+ "color_samples": A_NORMAL,
+ "color_file": A_NORMAL,
+ "color_func": A_NORMAL,
+ # Trend colors (fallback to bold/normal for monochrome)
+ "trend_up": A_BOLD,
+ "trend_down": A_BOLD,
+ "trend_stable": A_NORMAL,
+ }
+
+ def _build_stats_list(self):
+ """Build and sort the statistics list."""
+ stats_list = []
+ result_source = self._get_current_result_source()
+
+ for func, call_counts in result_source.items():
+ # Apply filter if set (using substring matching)
+ if self.filter_pattern:
+ filename, lineno, funcname = func
+ # Simple substring match (case-insensitive)
+ pattern_lower = self.filter_pattern.lower()
+ filename_lower = filename.lower()
+ funcname_lower = funcname.lower()
+
+ # Match if pattern is substring of filename, funcname, or combined
+ matched = (
+ pattern_lower in filename_lower
+ or pattern_lower in funcname_lower
+ or pattern_lower in f"{filename_lower}:{funcname_lower}"
+ )
+ if not matched:
+ continue
+
+ direct_calls = call_counts.get("direct_calls", 0)
+ cumulative_calls = call_counts.get("cumulative_calls", 0)
+ total_time = direct_calls * self.sample_interval_sec
+ cumulative_time = cumulative_calls * self.sample_interval_sec
+
+ # Calculate sample percentages
+ sample_pct = (direct_calls / self.total_samples * 100) if self.total_samples > 0 else 0
+ cumul_pct = (cumulative_calls / self.total_samples * 100) if self.total_samples > 0 else 0
+
+ # Calculate trends for all columns using TrendTracker
+ trends = {}
+ if self._trend_tracker is not None:
+ trends = self._trend_tracker.update_metrics(
+ func,
+ {
+ 'nsamples': direct_calls,
+ 'tottime': total_time,
+ 'cumtime': cumulative_time,
+ 'sample_pct': sample_pct,
+ 'cumul_pct': cumul_pct,
+ }
+ )
+
+ stats_list.append(
+ {
+ "func": func,
+ "direct_calls": direct_calls,
+ "cumulative_calls": cumulative_calls,
+ "total_time": total_time,
+ "cumulative_time": cumulative_time,
+ "trends": trends, # Dictionary of trends for all columns
+ }
+ )
+
+ # Sort the stats
+ if self.sort_by == "nsamples":
+ stats_list.sort(key=lambda x: x["direct_calls"], reverse=True)
+ elif self.sort_by == "tottime":
+ stats_list.sort(key=lambda x: x["total_time"], reverse=True)
+ elif self.sort_by == "cumtime":
+ stats_list.sort(key=lambda x: x["cumulative_time"], reverse=True)
+ elif self.sort_by == "sample_pct":
+ stats_list.sort(
+ key=lambda x: (x["direct_calls"] / self.total_samples * 100)
+ if self.total_samples > 0
+ else 0,
+ reverse=True,
+ )
+ elif self.sort_by == "cumul_pct":
+ stats_list.sort(
+ key=lambda x: (
+ x["cumulative_calls"] / self.total_samples * 100
+ )
+ if self.total_samples > 0
+ else 0,
+ reverse=True,
+ )
+
+ return stats_list
+
+ def reset_stats(self):
+ """Reset all collected statistics."""
+ self.result.clear()
+ self.per_thread_data.clear()
+ self.thread_ids.clear()
+ self.view_mode = "ALL"
+ self.current_thread_index = 0
+ self.total_samples = 0
+ self._successful_samples = 0
+ self._failed_samples = 0
+ self._max_sample_rate = 0
+ self._thread_status_counts = {
+ "has_gil": 0,
+ "on_cpu": 0,
+ "gil_requested": 0,
+ "unknown": 0,
+ "total": 0,
+ }
+ self._gc_frame_samples = 0
+ # Clear trend tracking
+ if self._trend_tracker is not None:
+ self._trend_tracker.clear()
+ # Reset finished state and finish timestamp
+ self.finished = False
+ self.finish_timestamp = None
+ self.finish_wall_time = None
+ self.start_time = time.perf_counter()
+ self._last_display_update = self.start_time
+
+ def mark_finished(self):
+ """Mark the profiling session as finished."""
+ self.finished = True
+ # Capture the finish timestamp to freeze all timing displays
+ self.finish_timestamp = time.perf_counter()
+ self.finish_wall_time = time.time() # Wall clock time for display
+ # Force a final display update to show the finished message
+ if self.display is not None:
+ self._update_display()
+
+ def _handle_finished_input_update(self, had_input):
+ """Update display after input when program is finished."""
+ if self.finished and had_input and self.display is not None:
+ self._update_display()
+
+ def _show_terminal_too_small(self, height, width):
+ """Display a message when terminal is too small."""
+ A_BOLD = self.display.get_attr("A_BOLD")
+ msg1 = "Terminal too small!"
+ msg2 = f"Need: {MIN_TERMINAL_WIDTH}x{MIN_TERMINAL_HEIGHT}"
+ msg3 = f"Have: {width}x{height}"
+ msg4 = "Please resize"
+
+ # Center the messages
+ if height >= 4:
+ self.display.add_str(
+ height // 2 - 2,
+ max(0, (width - len(msg1)) // 2),
+ msg1[: width - 1],
+ A_BOLD,
+ )
+ self.display.add_str(
+ height // 2 - 1,
+ max(0, (width - len(msg2)) // 2),
+ msg2[: width - 1],
+ )
+ self.display.add_str(
+ height // 2,
+ max(0, (width - len(msg3)) // 2),
+ msg3[: width - 1],
+ )
+ self.display.add_str(
+ height // 2 + 1,
+ max(0, (width - len(msg4)) // 2),
+ msg4[: width - 1],
+ )
+ elif height >= 1:
+ self.display.add_str(0, 0, msg1[: width - 1], A_BOLD)
+
+ def _show_terminal_size_warning_and_wait(self, height, width):
+ """Show terminal size warning during initialization and wait for user acknowledgment."""
+ A_BOLD = self.display.get_attr("A_BOLD")
+ A_DIM = self.display.get_attr("A_DIM")
+
+ self.display.clear()
+ msg1 = "WARNING: Terminal too small!"
+ msg2 = f"Required: {MIN_TERMINAL_WIDTH}x{MIN_TERMINAL_HEIGHT}"
+ msg3 = f"Current: {width}x{height}"
+ msg4 = "Please resize your terminal for best experience"
+ msg5 = "Press any key to continue..."
+
+ # Center the messages
+ if height >= 5:
+ self.display.add_str(
+ height // 2 - 2,
+ max(0, (width - len(msg1)) // 2),
+ msg1[: width - 1],
+ A_BOLD,
+ )
+ self.display.add_str(
+ height // 2 - 1,
+ max(0, (width - len(msg2)) // 2),
+ msg2[: width - 1],
+ )
+ self.display.add_str(
+ height // 2,
+ max(0, (width - len(msg3)) // 2),
+ msg3[: width - 1],
+ )
+ self.display.add_str(
+ height // 2 + 1,
+ max(0, (width - len(msg4)) // 2),
+ msg4[: width - 1],
+ )
+ self.display.add_str(
+ height // 2 + 3,
+ max(0, (width - len(msg5)) // 2),
+ msg5[: width - 1],
+ A_DIM,
+ )
+ elif height >= 1:
+ self.display.add_str(0, 0, msg1[: width - 1], A_BOLD)
+
+ self.display.refresh()
+ # Wait for user acknowledgment (2 seconds timeout)
+ self.display.set_nodelay(False)
+ # Note: timeout is curses-specific, skipping for now
+ self.display.get_input()
+ self.display.set_nodelay(True)
+
+ def _handle_input(self):
+ """Handle keyboard input (non-blocking)."""
+ from . import constants
+
+ self.display.set_nodelay(True)
+ ch = self.display.get_input()
+
+ # Handle filter input mode FIRST - takes precedence over all commands
+ if self.filter_input_mode:
+ if ch == 27: # ESC key
+ self.filter_input_mode = False
+ self.filter_input_buffer = ""
+ elif ch == 10 or ch == 13: # Enter key
+ self.filter_pattern = (
+ self.filter_input_buffer
+ if self.filter_input_buffer
+ else None
+ )
+ self.filter_input_mode = False
+ self.filter_input_buffer = ""
+ elif ch == 127 or ch == 263: # Backspace
+ if self.filter_input_buffer:
+ self.filter_input_buffer = self.filter_input_buffer[:-1]
+ elif ch >= 32 and ch < 127: # Printable characters
+ self.filter_input_buffer += chr(ch)
+
+ # Update display if input was processed while finished
+ self._handle_finished_input_update(ch != -1)
+ return
+
+ # Handle help toggle keys
+ if ch == ord("h") or ch == ord("H") or ch == ord("?"):
+ self.show_help = not self.show_help
+
+ # If showing help, any other key closes it
+ elif self.show_help and ch != -1:
+ self.show_help = False
+
+ # Handle regular commands
+ if ch == ord("q") or ch == ord("Q"):
+ self.running = False
+
+ elif ch == ord("s"):
+ self._cycle_sort(reverse=False)
+
+ elif ch == ord("S"):
+ self._cycle_sort(reverse=True)
+
+ elif ch == ord("p") or ch == ord("P"):
+ self.paused = not self.paused
+
+ elif ch == ord("r") or ch == ord("R"):
+ # Don't allow reset when profiling is finished
+ if not self.finished:
+ self.reset_stats()
+
+ elif ch == ord("+") or ch == ord("="):
+ # Decrease update interval (faster refresh)
+ self._display_update_interval = max(
+ 0.05, self._display_update_interval - 0.05
+ ) # Min 20Hz
+
+ elif ch == ord("-") or ch == ord("_"):
+ # Increase update interval (slower refresh)
+ self._display_update_interval = min(
+ 1.0, self._display_update_interval + 0.05
+ ) # Max 1Hz
+
+ elif ch == ord("c") or ch == ord("C"):
+ if self.filter_pattern:
+ self.filter_pattern = None
+
+ elif ch == ord("/"):
+ self.filter_input_mode = True
+ self.filter_input_buffer = self.filter_pattern or ""
+
+ elif ch == ord("t") or ch == ord("T"):
+ # Toggle between ALL and PER_THREAD modes
+ if self.view_mode == "ALL":
+ if len(self.thread_ids) > 0:
+ self.view_mode = "PER_THREAD"
+ self.current_thread_index = 0
+ else:
+ self.view_mode = "ALL"
+
+ elif ch == ord("x") or ch == ord("X"):
+ # Toggle trend colors on/off
+ if self._trend_tracker is not None:
+ self._trend_tracker.toggle()
+
+ elif ch == curses.KEY_LEFT or ch == curses.KEY_UP:
+ # Navigate to previous thread in PER_THREAD mode, or switch from ALL to PER_THREAD
+ if len(self.thread_ids) > 0:
+ if self.view_mode == "ALL":
+ self.view_mode = "PER_THREAD"
+ self.current_thread_index = 0
+ else:
+ self.current_thread_index = (
+ self.current_thread_index - 1
+ ) % len(self.thread_ids)
+
+ elif ch == curses.KEY_RIGHT or ch == curses.KEY_DOWN:
+ # Navigate to next thread in PER_THREAD mode, or switch from ALL to PER_THREAD
+ if len(self.thread_ids) > 0:
+ if self.view_mode == "ALL":
+ self.view_mode = "PER_THREAD"
+ self.current_thread_index = 0
+ else:
+ self.current_thread_index = (
+ self.current_thread_index + 1
+ ) % len(self.thread_ids)
+
+ # Update display if input was processed while finished
+ self._handle_finished_input_update(ch != -1)
+
+ def init_curses(self, stdscr):
+ """Initialize curses display and suppress stdout/stderr."""
+ self.stdscr = stdscr
+ self.display = CursesDisplay(stdscr)
+
+ # Check terminal size upfront and warn if too small
+ height, width = self.display.get_dimensions()
+
+ if width < MIN_TERMINAL_WIDTH or height < MIN_TERMINAL_HEIGHT:
+ # Show warning and wait briefly for user to see it
+ self._show_terminal_size_warning_and_wait(height, width)
+
+ curses.curs_set(0) # Hide cursor
+ stdscr.nodelay(True) # Non-blocking input
+ stdscr.scrollok(False) # Disable scrolling
+ stdscr.idlok(False) # Disable hardware insert/delete
+ stdscr.leaveok(True) # Don't care about cursor position
+
+ if curses.has_colors():
+ curses.start_color()
+ curses.use_default_colors()
+
+ # Suppress stdout and stderr to prevent interfering with curses display
+ # Use contextlib.redirect_stdout/stderr for better resource management
+ self._saved_stdout = sys.stdout
+ self._saved_stderr = sys.stderr
+ # Open devnull and ensure it's cleaned up even if an exception occurs
+ try:
+ self._devnull = open(os.devnull, "w")
+ sys.stdout = self._devnull
+ sys.stderr = self._devnull
+ except Exception:
+ # If redirection fails, restore original streams
+ sys.stdout = self._saved_stdout
+ sys.stderr = self._saved_stderr
+ raise
+
+ # Initial clear
+ self.display.clear()
+ self.display.refresh()
+
+ def cleanup_curses(self):
+ """Clean up curses display and restore stdout/stderr."""
+ # Restore stdout and stderr in reverse order
+ # Use try-finally to ensure cleanup even if restoration fails
+ try:
+ if self._saved_stdout is not None:
+ sys.stdout = self._saved_stdout
+ self._saved_stdout = None
+ if self._saved_stderr is not None:
+ sys.stderr = self._saved_stderr
+ self._saved_stderr = None
+ finally:
+ # Always close devnull, even if stdout/stderr restoration fails
+ if self._devnull is not None:
+ with contextlib.suppress(Exception):
+ self._devnull.close()
+ self._devnull = None
+
+ if self.display is not None and self.stdscr is not None:
+ with contextlib.suppress(Exception):
+ curses.curs_set(1) # Show cursor
+ self.display.set_nodelay(False)
+
+ def export(self, filename):
+ """Export is not supported in live mode."""
+ raise NotImplementedError(
+ "Export to file is not supported in live mode. "
+ "Use the live TUI to view statistics in real-time."
+ )
--- /dev/null
+"""Constants for the live profiling collector."""
+
+# Time conversion constants
+MICROSECONDS_PER_SECOND = 1_000_000
+
+# Display update constants
+DISPLAY_UPDATE_HZ = 10
+DISPLAY_UPDATE_INTERVAL = 1.0 / DISPLAY_UPDATE_HZ # 0.1 seconds
+
+# Terminal size constraints
+MIN_TERMINAL_WIDTH = 60
+MIN_TERMINAL_HEIGHT = 12
+
+# Column width thresholds
+WIDTH_THRESHOLD_SAMPLE_PCT = 80
+WIDTH_THRESHOLD_TOTTIME = 100
+WIDTH_THRESHOLD_CUMUL_PCT = 120
+WIDTH_THRESHOLD_CUMTIME = 140
+
+# Display layout constants
+HEADER_LINES = 10 # Increased to include thread status line
+FOOTER_LINES = 2
+SAFETY_MARGIN = 1
+TOP_FUNCTIONS_DISPLAY_COUNT = 3
+
+# Column widths for data display
+COL_WIDTH_NSAMPLES = 13
+COL_SPACING = 2
+COL_WIDTH_SAMPLE_PCT = 5
+COL_WIDTH_TIME = 10
+
+# Function name display
+MIN_FUNC_NAME_WIDTH = 10
+MAX_FUNC_NAME_WIDTH = 40
+MIN_AVAILABLE_SPACE = 10
+
+# Progress bar display
+MIN_BAR_WIDTH = 10
+MAX_SAMPLE_RATE_BAR_WIDTH = 30
+MAX_EFFICIENCY_BAR_WIDTH = 60
+
+# Sample rate scaling
+MIN_SAMPLE_RATE_FOR_SCALING = 100
+
+# Finished banner display
+FINISHED_BANNER_EXTRA_LINES = 3 # Blank line + banner + blank line
+
+# Color pair IDs
+COLOR_PAIR_HEADER_BG = 4
+COLOR_PAIR_CYAN = 5
+COLOR_PAIR_YELLOW = 6
+COLOR_PAIR_GREEN = 7
+COLOR_PAIR_MAGENTA = 8
+COLOR_PAIR_RED = 9
+COLOR_PAIR_SORTED_HEADER = 10
+
+# Default display settings
+DEFAULT_SORT_BY = "nsamples" # Number of samples in leaf (self time)
+DEFAULT_DISPLAY_LIMIT = 20
--- /dev/null
+"""Display interface abstractions for the live profiling collector."""
+
+import contextlib
+import curses
+from abc import ABC, abstractmethod
+
+
+class DisplayInterface(ABC):
+ """Abstract interface for display operations to enable testing."""
+
+ @abstractmethod
+ def get_dimensions(self):
+ """Get terminal dimensions as (height, width)."""
+ pass
+
+ @abstractmethod
+ def clear(self):
+ """Clear the screen."""
+ pass
+
+ @abstractmethod
+ def refresh(self):
+ """Refresh the screen to show changes."""
+ pass
+
+ @abstractmethod
+ def redraw(self):
+ """Redraw the entire window."""
+ pass
+
+ @abstractmethod
+ def add_str(self, line, col, text, attr=0):
+ """Add a string at the specified position."""
+ pass
+
+ @abstractmethod
+ def get_input(self):
+ """Get a character from input (non-blocking). Returns -1 if no input."""
+ pass
+
+ @abstractmethod
+ def set_nodelay(self, flag):
+ """Set non-blocking mode for input."""
+ pass
+
+ @abstractmethod
+ def has_colors(self):
+ """Check if terminal supports colors."""
+ pass
+
+ @abstractmethod
+ def init_color_pair(self, pair_id, fg, bg):
+ """Initialize a color pair."""
+ pass
+
+ @abstractmethod
+ def get_color_pair(self, pair_id):
+ """Get a color pair attribute."""
+ pass
+
+ @abstractmethod
+ def get_attr(self, name):
+ """Get a display attribute by name (e.g., 'A_BOLD', 'A_REVERSE')."""
+ pass
+
+
+class CursesDisplay(DisplayInterface):
+ """Real curses display implementation."""
+
+ def __init__(self, stdscr):
+ self.stdscr = stdscr
+
+ def get_dimensions(self):
+ return self.stdscr.getmaxyx()
+
+ def clear(self):
+ self.stdscr.clear()
+
+ def refresh(self):
+ self.stdscr.refresh()
+
+ def redraw(self):
+ self.stdscr.redrawwin()
+
+ def add_str(self, line, col, text, attr=0):
+ try:
+ height, width = self.get_dimensions()
+ if 0 <= line < height and 0 <= col < width:
+ max_len = width - col - 1
+ if len(text) > max_len:
+ text = text[:max_len]
+ self.stdscr.addstr(line, col, text, attr)
+ except curses.error:
+ pass
+
+ def get_input(self):
+ try:
+ return self.stdscr.getch()
+ except (KeyError, curses.error):
+ return -1
+
+ def set_nodelay(self, flag):
+ self.stdscr.nodelay(flag)
+
+ def has_colors(self):
+ return curses.has_colors()
+
+ def init_color_pair(self, pair_id, fg, bg):
+ try:
+ curses.init_pair(pair_id, fg, bg)
+ except curses.error:
+ pass
+
+ def get_color_pair(self, pair_id):
+ return curses.color_pair(pair_id)
+
+ def get_attr(self, name):
+ return getattr(curses, name, 0)
+
+
+class MockDisplay(DisplayInterface):
+ """Mock display for testing."""
+
+ def __init__(self, height=40, width=160):
+ self.height = height
+ self.width = width
+ self.buffer = {}
+ self.cleared = False
+ self.refreshed = False
+ self.redrawn = False
+ self.input_queue = []
+ self.nodelay_flag = True
+ self.colors_supported = True
+ self.color_pairs = {}
+
+ def get_dimensions(self):
+ return (self.height, self.width)
+
+ def clear(self):
+ self.buffer.clear()
+ self.cleared = True
+
+ def refresh(self):
+ self.refreshed = True
+
+ def redraw(self):
+ self.redrawn = True
+
+ def add_str(self, line, col, text, attr=0):
+ if 0 <= line < self.height and 0 <= col < self.width:
+ max_len = self.width - col - 1
+ if len(text) > max_len:
+ text = text[:max_len]
+ self.buffer[(line, col)] = (text, attr)
+
+ def get_input(self):
+ if self.input_queue:
+ return self.input_queue.pop(0)
+ return -1
+
+ def set_nodelay(self, flag):
+ self.nodelay_flag = flag
+
+ def has_colors(self):
+ return self.colors_supported
+
+ def init_color_pair(self, pair_id, fg, bg):
+ self.color_pairs[pair_id] = (fg, bg)
+
+ def get_color_pair(self, pair_id):
+ return pair_id << 8
+
+ def get_attr(self, name):
+ attrs = {
+ "A_NORMAL": 0,
+ "A_BOLD": 1 << 16,
+ "A_REVERSE": 1 << 17,
+ "A_UNDERLINE": 1 << 18,
+ "A_DIM": 1 << 19,
+ }
+ return attrs.get(name, 0)
+
+ def simulate_input(self, char):
+ """Helper method for tests to simulate keyboard input."""
+ self.input_queue.append(char)
+
+ def get_text_at(self, line, col):
+ """Helper method for tests to inspect buffer content."""
+ if (line, col) in self.buffer:
+ return self.buffer[(line, col)][0]
+ return None
+
+ def get_all_lines(self):
+ """Get all display content as a list of lines (for testing)."""
+ if not self.buffer:
+ return []
+
+ max_line = max(pos[0] for pos in self.buffer.keys())
+ lines = []
+ for line_num in range(max_line + 1):
+ line_parts = []
+ for col in range(self.width):
+ if (line_num, col) in self.buffer:
+ text, _ = self.buffer[(line_num, col)]
+ line_parts.append((col, text))
+
+ # Reconstruct line from parts
+ if line_parts:
+ line_parts.sort(key=lambda x: x[0])
+ line = ""
+ last_col = 0
+ for col, text in line_parts:
+ if col > last_col:
+ line += " " * (col - last_col)
+ line += text
+ last_col = col + len(text)
+ lines.append(line.rstrip())
+ else:
+ lines.append("")
+
+ # Remove trailing empty lines
+ while lines and not lines[-1]:
+ lines.pop()
+
+ return lines
+
+ def find_text(self, pattern):
+ """Find text matching pattern in buffer (for testing). Returns (line, col) or None."""
+ for (line, col), (text, _) in self.buffer.items():
+ if pattern in text:
+ return (line, col)
+ return None
+
+ def contains_text(self, text):
+ """Check if display contains the given text anywhere (for testing)."""
+ return self.find_text(text) is not None
--- /dev/null
+"""TrendTracker - Encapsulated trend tracking for live profiling metrics.
+
+This module provides trend tracking functionality for profiling metrics,
+calculating direction indicators (up/down/stable) and managing associated
+visual attributes like colors.
+"""
+
+import curses
+from typing import Dict, Literal, Any
+
+TrendDirection = Literal["up", "down", "stable"]
+
+
+class TrendTracker:
+ """
+ Tracks metric trends over time and provides visual indicators.
+
+ This class encapsulates all logic for:
+ - Tracking previous values of metrics
+ - Calculating trend directions (up/down/stable)
+ - Determining visual attributes (colors) for trends
+ - Managing enable/disable state
+
+ Example:
+ tracker = TrendTracker(colors_dict)
+ tracker.update("func1", "nsamples", 10)
+ trend = tracker.get_trend("func1", "nsamples")
+ color = tracker.get_color("func1", "nsamples")
+ """
+
+ # Threshold for determining if a value has changed significantly
+ CHANGE_THRESHOLD = 0.001
+
+ def __init__(self, colors: Dict[str, int], enabled: bool = True):
+ """
+ Initialize the trend tracker.
+
+ Args:
+ colors: Dictionary containing color attributes including
+ 'trend_up', 'trend_down', 'trend_stable'
+ enabled: Whether trend tracking is initially enabled
+ """
+ self._previous_values: Dict[Any, Dict[str, float]] = {}
+ self._enabled = enabled
+ self._colors = colors
+
+ @property
+ def enabled(self) -> bool:
+ """Whether trend tracking is enabled."""
+ return self._enabled
+
+ def toggle(self) -> bool:
+ """
+ Toggle trend tracking on/off.
+
+ Returns:
+ New enabled state
+ """
+ self._enabled = not self._enabled
+ return self._enabled
+
+ def set_enabled(self, enabled: bool) -> None:
+ """Set trend tracking enabled state."""
+ self._enabled = enabled
+
+ def update(self, key: Any, metric: str, value: float) -> TrendDirection:
+ """
+ Update a metric value and calculate its trend.
+
+ Args:
+ key: Identifier for the entity (e.g., function)
+ metric: Name of the metric (e.g., 'nsamples', 'tottime')
+ value: Current value of the metric
+
+ Returns:
+ Trend direction: 'up', 'down', or 'stable'
+ """
+ # Initialize storage for this key if needed
+ if key not in self._previous_values:
+ self._previous_values[key] = {}
+
+ # Get previous value, defaulting to current if not tracked yet
+ prev_value = self._previous_values[key].get(metric, value)
+
+ # Calculate trend
+ if value > prev_value + self.CHANGE_THRESHOLD:
+ trend = "up"
+ elif value < prev_value - self.CHANGE_THRESHOLD:
+ trend = "down"
+ else:
+ trend = "stable"
+
+ # Update previous value for next iteration
+ self._previous_values[key][metric] = value
+
+ return trend
+
+ def get_trend(self, key: Any, metric: str) -> TrendDirection:
+ """
+ Get the current trend for a metric without updating.
+
+ Args:
+ key: Identifier for the entity
+ metric: Name of the metric
+
+ Returns:
+ Trend direction, or 'stable' if not tracked
+ """
+ # This would require storing trends separately, which we don't do
+ # For now, return stable if not found
+ return "stable"
+
+ def get_color(self, trend: TrendDirection) -> int:
+ """
+ Get the color attribute for a trend direction.
+
+ Args:
+ trend: The trend direction
+
+ Returns:
+ Curses color attribute (or A_NORMAL if disabled)
+ """
+ if not self._enabled:
+ return curses.A_NORMAL
+
+ if trend == "up":
+ return self._colors.get("trend_up", curses.A_BOLD)
+ elif trend == "down":
+ return self._colors.get("trend_down", curses.A_BOLD)
+ else: # stable
+ return self._colors.get("trend_stable", curses.A_NORMAL)
+
+ def update_metrics(self, key: Any, metrics: Dict[str, float]) -> Dict[str, TrendDirection]:
+ """
+ Update multiple metrics at once and get their trends.
+
+ Args:
+ key: Identifier for the entity
+ metrics: Dictionary of metric_name -> value
+
+ Returns:
+ Dictionary of metric_name -> trend_direction
+ """
+ trends = {}
+ for metric, value in metrics.items():
+ trends[metric] = self.update(key, metric, value)
+ return trends
+
+ def clear(self) -> None:
+ """Clear all tracked values (useful on stats reset)."""
+ self._previous_values.clear()
+
+ def __repr__(self) -> str:
+ """String representation for debugging."""
+ status = "enabled" if self._enabled else "disabled"
+ tracked = len(self._previous_values)
+ return f"TrendTracker({status}, tracking {tracked} entities)"
--- /dev/null
+"""Widget classes for the live profiling collector UI."""
+
+import curses
+import time
+from abc import ABC, abstractmethod
+
+from .constants import (
+ TOP_FUNCTIONS_DISPLAY_COUNT,
+ MIN_FUNC_NAME_WIDTH,
+ MAX_FUNC_NAME_WIDTH,
+ WIDTH_THRESHOLD_SAMPLE_PCT,
+ WIDTH_THRESHOLD_TOTTIME,
+ WIDTH_THRESHOLD_CUMUL_PCT,
+ WIDTH_THRESHOLD_CUMTIME,
+ MICROSECONDS_PER_SECOND,
+ DISPLAY_UPDATE_INTERVAL,
+ MIN_BAR_WIDTH,
+ MAX_SAMPLE_RATE_BAR_WIDTH,
+ MAX_EFFICIENCY_BAR_WIDTH,
+ MIN_SAMPLE_RATE_FOR_SCALING,
+ FOOTER_LINES,
+ FINISHED_BANNER_EXTRA_LINES,
+)
+from ..constants import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+ THREAD_STATUS_UNKNOWN,
+ THREAD_STATUS_GIL_REQUESTED,
+ PROFILING_MODE_CPU,
+ PROFILING_MODE_GIL,
+ PROFILING_MODE_WALL,
+)
+
+
+class Widget(ABC):
+ """Base class for UI widgets."""
+
+ def __init__(self, display, colors):
+ """
+ Initialize widget.
+
+ Args:
+ display: DisplayInterface implementation
+ colors: Dictionary of color attributes
+ """
+ self.display = display
+ self.colors = colors
+
+ @abstractmethod
+ def render(self, line, width, **kwargs):
+ """
+ Render the widget starting at the given line.
+
+ Args:
+ line: Starting line number
+ width: Available width
+ **kwargs: Additional rendering parameters
+
+ Returns:
+ Next available line number after rendering
+ """
+ pass
+
+ def add_str(self, line, col, text, attr=0):
+ """Add a string to the display at the specified position."""
+ self.display.add_str(line, col, text, attr)
+
+
+class ProgressBarWidget(Widget):
+ """Reusable progress bar widget."""
+
+ def render(self, line, width, **kwargs):
+ """Render is not used for progress bars - use render_bar instead."""
+ raise NotImplementedError("Use render_bar method instead")
+
+ def render_bar(
+ self, filled, total, max_width, fill_char="█", empty_char="░"
+ ):
+ """
+ Render a progress bar and return the bar string and its length.
+
+ Args:
+ filled: Current filled amount
+ total: Total amount (max value)
+ max_width: Maximum width for the bar
+ fill_char: Character to use for filled portion
+ empty_char: Character to use for empty portion
+
+ Returns:
+ Tuple of (bar_string, bar_length)
+ """
+ bar_width = min(max_width, max_width)
+ normalized = min(filled / max(total, 1), 1.0)
+ bar_fill = int(normalized * bar_width)
+
+ bar = "["
+ for i in range(bar_width):
+ if i < bar_fill:
+ bar += fill_char
+ else:
+ bar += empty_char
+ bar += "]"
+ return bar, len(bar)
+
+
+class HeaderWidget(Widget):
+ """Widget for rendering the header section (lines 0-8)."""
+
+ def __init__(self, display, colors, collector):
+ """
+ Initialize header widget.
+
+ Args:
+ display: DisplayInterface implementation
+ colors: Dictionary of color attributes
+ collector: Reference to LiveStatsCollector for accessing stats
+ """
+ super().__init__(display, colors)
+ self.collector = collector
+ self.progress_bar = ProgressBarWidget(display, colors)
+
+ def render(self, line, width, **kwargs):
+ """
+ Render the complete header section.
+
+ Args:
+ line: Starting line number
+ width: Available width
+ kwargs: Must contain 'elapsed' key
+
+ Returns:
+ Next available line number
+ """
+ elapsed = kwargs["elapsed"]
+
+ line = self.draw_header_info(line, width, elapsed)
+ line = self.draw_sample_stats(line, width, elapsed)
+ line = self.draw_efficiency_bar(line, width)
+ line = self.draw_thread_status(line, width)
+ line = self.draw_function_stats(
+ line, width, kwargs.get("stats_list", [])
+ )
+ line = self.draw_top_functions(
+ line, width, kwargs.get("stats_list", [])
+ )
+
+ # Show prominent finished banner if profiling is complete
+ if self.collector.finished:
+ line = self.draw_finished_banner(line, width)
+
+ # Separator
+ A_DIM = self.display.get_attr("A_DIM")
+ separator = "─" * (width - 1)
+ self.add_str(line, 0, separator[: width - 1], A_DIM)
+ line += 1
+
+ return line
+
+ def format_uptime(self, elapsed):
+ """Format elapsed time as uptime string."""
+ uptime_sec = int(elapsed)
+ hours = uptime_sec // 3600
+ minutes = (uptime_sec % 3600) // 60
+ seconds = uptime_sec % 60
+ if hours > 0:
+ return f"{hours}h{minutes:02d}m{seconds:02d}s"
+ else:
+ return f"{minutes}m{seconds:02d}s"
+
+ def draw_header_info(self, line, width, elapsed):
+ """Draw the header information line with PID, uptime, time, and interval."""
+ # Draw title
+ A_BOLD = self.display.get_attr("A_BOLD")
+ title = "Tachyon Profiler"
+ self.add_str(line, 0, title, A_BOLD | self.colors["cyan"])
+ line += 1
+
+ current_time = self.collector.current_time_display
+ uptime = self.format_uptime(elapsed)
+
+ # Calculate display refresh rate
+ refresh_hz = (
+ 1.0 / self.collector._display_update_interval if self.collector._display_update_interval > 0 else 0
+ )
+
+ # Get current view mode and thread display
+ if self.collector.view_mode == "ALL":
+ thread_name = "ALL"
+ thread_color = self.colors["green"]
+ else:
+ # PER_THREAD mode
+ if self.collector.current_thread_index < len(
+ self.collector.thread_ids
+ ):
+ thread_id = self.collector.thread_ids[
+ self.collector.current_thread_index
+ ]
+ num_threads = len(self.collector.thread_ids)
+ thread_name = f"{thread_id} ({self.collector.current_thread_index + 1}/{num_threads})"
+ thread_color = self.colors["magenta"]
+ else:
+ thread_name = "ALL"
+ thread_color = self.colors["green"]
+
+ header_parts = [
+ ("PID: ", curses.A_BOLD),
+ (f"{self.collector.pid}", self.colors["cyan"]),
+ (" │ ", curses.A_DIM),
+ ("Thread: ", curses.A_BOLD),
+ (thread_name, thread_color),
+ (" │ ", curses.A_DIM),
+ ("Uptime: ", curses.A_BOLD),
+ (uptime, self.colors["green"]),
+ (" │ ", curses.A_DIM),
+ ("Time: ", curses.A_BOLD),
+ (current_time, self.colors["yellow"]),
+ (" │ ", curses.A_DIM),
+ ("Interval: ", curses.A_BOLD),
+ (
+ f"{self.collector.sample_interval_usec}µs",
+ self.colors["magenta"],
+ ),
+ (" │ ", curses.A_DIM),
+ ("Display: ", curses.A_BOLD),
+ (f"{refresh_hz:.1f}Hz", self.colors["cyan"]),
+ ]
+
+ col = 0
+ for text, attr in header_parts:
+ if col < width - 1:
+ self.add_str(line, col, text, attr)
+ col += len(text)
+ return line + 1
+
+ def format_rate_with_units(self, rate_hz):
+ """Format a rate in Hz with appropriate units (Hz, KHz, MHz)."""
+ if rate_hz >= 1_000_000:
+ return f"{rate_hz / 1_000_000:.1f}MHz"
+ elif rate_hz >= 1_000:
+ return f"{rate_hz / 1_000:.1f}KHz"
+ else:
+ return f"{rate_hz:.1f}Hz"
+
+ def draw_sample_stats(self, line, width, elapsed):
+ """Draw sample statistics with visual progress bar."""
+ sample_rate = (
+ self.collector.total_samples / elapsed if elapsed > 0 else 0
+ )
+
+ # Update max sample rate
+ if sample_rate > self.collector._max_sample_rate:
+ self.collector._max_sample_rate = sample_rate
+
+ col = 0
+ self.add_str(line, col, "Samples: ", curses.A_BOLD)
+ col += 9
+ self.add_str(
+ line,
+ col,
+ f"{self.collector.total_samples:>8}",
+ self.colors["cyan"],
+ )
+ col += 8
+ self.add_str(
+ line, col, f" total ({sample_rate:>7.1f}/s) ", curses.A_NORMAL
+ )
+ col += 23
+
+ # Draw sample rate bar
+ target_rate = (
+ MICROSECONDS_PER_SECOND / self.collector.sample_interval_usec
+ )
+
+ # Show current/target ratio with percentage
+ if sample_rate > 0 and target_rate > 0:
+ percentage = min((sample_rate / target_rate) * 100, 100)
+ current_formatted = self.format_rate_with_units(sample_rate)
+ target_formatted = self.format_rate_with_units(target_rate)
+
+ if percentage >= 99.5: # Show 100% when very close
+ rate_label = f" {current_formatted}/{target_formatted} (100%)"
+ else:
+ rate_label = f" {current_formatted}/{target_formatted} ({percentage:>4.1f}%)"
+ else:
+ target_formatted = self.format_rate_with_units(target_rate)
+ rate_label = f" target: {target_formatted}"
+
+ available_width = width - col - len(rate_label) - 3
+
+ if available_width >= MIN_BAR_WIDTH:
+ bar_width = min(MAX_SAMPLE_RATE_BAR_WIDTH, available_width)
+ # Use target rate as the reference, with a minimum for scaling
+ reference_rate = max(target_rate, MIN_SAMPLE_RATE_FOR_SCALING)
+ normalized_rate = min(sample_rate / reference_rate, 1.0)
+ bar_fill = int(normalized_rate * bar_width)
+
+ bar = "["
+ for i in range(bar_width):
+ bar += "█" if i < bar_fill else "░"
+ bar += "]"
+ self.add_str(line, col, bar, self.colors["green"])
+ col += len(bar)
+
+ if col + len(rate_label) < width - 1:
+ self.add_str(line, col + 1, rate_label, curses.A_DIM)
+ return line + 1
+
+ def draw_efficiency_bar(self, line, width):
+ """Draw sample efficiency bar showing success/failure rates."""
+ success_pct = (
+ self.collector._successful_samples
+ / max(1, self.collector.total_samples)
+ ) * 100
+ failed_pct = (
+ self.collector._failed_samples
+ / max(1, self.collector.total_samples)
+ ) * 100
+
+ col = 0
+ self.add_str(line, col, "Efficiency:", curses.A_BOLD)
+ col += 11
+
+ label = f" {success_pct:>5.2f}% good, {failed_pct:>4.2f}% failed"
+ available_width = width - col - len(label) - 3
+
+ if available_width >= MIN_BAR_WIDTH:
+ bar_width = min(MAX_EFFICIENCY_BAR_WIDTH, available_width)
+ success_fill = int(
+ (
+ self.collector._successful_samples
+ / max(1, self.collector.total_samples)
+ )
+ * bar_width
+ )
+ failed_fill = bar_width - success_fill
+
+ self.add_str(line, col, "[", curses.A_NORMAL)
+ col += 1
+ if success_fill > 0:
+ self.add_str(
+ line, col, "█" * success_fill, self.colors["green"]
+ )
+ col += success_fill
+ if failed_fill > 0:
+ self.add_str(line, col, "█" * failed_fill, self.colors["red"])
+ col += failed_fill
+ self.add_str(line, col, "]", curses.A_NORMAL)
+ col += 1
+
+ self.add_str(line, col + 1, label, curses.A_NORMAL)
+ return line + 1
+
+ def _add_percentage_stat(
+ self, line, col, value, label, color, add_separator=False
+ ):
+ """Add a percentage stat to the display.
+
+ Args:
+ line: Line number
+ col: Starting column
+ value: Percentage value
+ label: Label text
+ color: Color attribute
+ add_separator: Whether to add separator before the stat
+
+ Returns:
+ Updated column position
+ """
+ if add_separator:
+ self.add_str(line, col, " │ ", curses.A_DIM)
+ col += 3
+
+ self.add_str(line, col, f"{value:>4.1f}", color)
+ col += 4
+ self.add_str(line, col, f"% {label}", curses.A_NORMAL)
+ col += len(label) + 2
+
+ return col
+
+ def draw_thread_status(self, line, width):
+ """Draw thread status statistics and GC information."""
+ # Get status counts for current view mode
+ thread_data = self.collector._get_current_thread_data()
+ status_counts = thread_data.as_status_dict() if thread_data else self.collector._thread_status_counts
+
+ # Calculate percentages
+ total_threads = max(1, status_counts["total"])
+ pct_on_gil = (status_counts["has_gil"] / total_threads) * 100
+ pct_off_gil = 100.0 - pct_on_gil
+ pct_gil_requested = (status_counts["gil_requested"] / total_threads) * 100
+
+ # Get GC percentage based on view mode
+ if thread_data:
+ total_samples = max(1, thread_data.sample_count)
+ pct_gc = (thread_data.gc_frame_samples / total_samples) * 100
+ else:
+ total_samples = max(1, self.collector.total_samples)
+ pct_gc = (self.collector._gc_frame_samples / total_samples) * 100
+
+ col = 0
+ self.add_str(line, col, "Threads: ", curses.A_BOLD)
+ col += 11
+
+ # Show GIL stats only if mode is not GIL (GIL mode filters to only GIL holders)
+ if self.collector.mode != PROFILING_MODE_GIL:
+ col = self._add_percentage_stat(
+ line, col, pct_on_gil, "on gil", self.colors["green"]
+ )
+ col = self._add_percentage_stat(
+ line,
+ col,
+ pct_off_gil,
+ "off gil",
+ self.colors["red"],
+ add_separator=True,
+ )
+
+ # Show "waiting for gil" only if mode is not GIL
+ if self.collector.mode != PROFILING_MODE_GIL and col < width - 30:
+ col = self._add_percentage_stat(
+ line,
+ col,
+ pct_gil_requested,
+ "waiting for gil",
+ self.colors["yellow"],
+ add_separator=True,
+ )
+
+ # Always show GC stats
+ if col < width - 15:
+ col = self._add_percentage_stat(
+ line,
+ col,
+ pct_gc,
+ "GC",
+ self.colors["magenta"],
+ add_separator=(col > 11),
+ )
+
+ return line + 1
+
+ def draw_function_stats(self, line, width, stats_list):
+ """Draw function statistics summary."""
+ result_set = self.collector._get_current_result_source()
+ total_funcs = len(result_set)
+ funcs_shown = len(stats_list)
+ executing_funcs = sum(
+ 1 for f in result_set.values() if f.get("direct_calls", 0) > 0
+ )
+ stack_only = total_funcs - executing_funcs
+
+ col = 0
+ self.add_str(line, col, "Functions: ", curses.A_BOLD)
+ col += 11
+ self.add_str(line, col, f"{total_funcs:>5}", self.colors["cyan"])
+ col += 5
+ self.add_str(line, col, " total", curses.A_NORMAL)
+ col += 6
+
+ if col < width - 25:
+ self.add_str(line, col, " │ ", curses.A_DIM)
+ col += 3
+ self.add_str(
+ line, col, f"{executing_funcs:>5}", self.colors["green"]
+ )
+ col += 5
+ self.add_str(line, col, " exec", curses.A_NORMAL)
+ col += 5
+
+ if col < width - 25:
+ self.add_str(line, col, " │ ", curses.A_DIM)
+ col += 3
+ self.add_str(line, col, f"{stack_only:>5}", self.colors["yellow"])
+ col += 5
+ self.add_str(line, col, " stack", curses.A_NORMAL)
+ col += 6
+
+ if col < width - 20:
+ self.add_str(line, col, " │ ", curses.A_DIM)
+ col += 3
+ self.add_str(
+ line, col, f"{funcs_shown:>5}", self.colors["magenta"]
+ )
+ col += 5
+ self.add_str(line, col, " shown", curses.A_NORMAL)
+ return line + 1
+
+ def draw_top_functions(self, line, width, stats_list):
+ """Draw top N hottest functions."""
+ col = 0
+ self.add_str(
+ line,
+ col,
+ f"Top {TOP_FUNCTIONS_DISPLAY_COUNT}: ",
+ curses.A_BOLD,
+ )
+ col += 11
+
+ top_by_samples = sorted(
+ stats_list, key=lambda x: x["direct_calls"], reverse=True
+ )
+ emojis = ["🥇", "🥈", "🥉"]
+ medal_colors = [
+ self.colors["red"],
+ self.colors["yellow"],
+ self.colors["green"],
+ ]
+
+ displayed = 0
+ for func_data in top_by_samples:
+ if displayed >= TOP_FUNCTIONS_DISPLAY_COUNT:
+ break
+ if col >= width - 20:
+ break
+ if func_data["direct_calls"] == 0:
+ continue
+
+ func_name = func_data["func"][2]
+ func_pct = (
+ func_data["direct_calls"]
+ / max(1, self.collector.total_samples)
+ ) * 100
+
+ # Medal emoji
+ if col + 3 < width - 15:
+ self.add_str(
+ line, col, emojis[displayed] + " ", medal_colors[displayed]
+ )
+ col += 3
+
+ # Function name (truncate to fit)
+ available_for_name = width - col - 15
+ max_name_len = min(25, max(5, available_for_name))
+ if len(func_name) > max_name_len:
+ func_name = func_name[: max_name_len - 3] + "..."
+
+ if col + len(func_name) < width - 10:
+ self.add_str(line, col, func_name, medal_colors[displayed])
+ col += len(func_name)
+
+ pct_str = (
+ f" ({func_pct:.1f}%)"
+ if func_pct >= 0.1
+ else f" ({func_data['direct_calls']})"
+ )
+ self.add_str(line, col, pct_str, curses.A_DIM)
+ col += len(pct_str)
+
+ displayed += 1
+
+ if displayed < 3 and col < width - 30:
+ self.add_str(line, col, " │ ", curses.A_DIM)
+ col += 3
+
+ if displayed == 0 and col < width - 25:
+ self.add_str(line, col, "(collecting samples...)", curses.A_DIM)
+
+ return line + 1
+
+ def draw_finished_banner(self, line, width):
+ """Draw a prominent banner when profiling is finished."""
+ A_REVERSE = self.display.get_attr("A_REVERSE")
+ A_BOLD = self.display.get_attr("A_BOLD")
+
+ # Add blank line for separation
+ line += 1
+
+ # Create the banner message
+ message = " ✓ PROFILING COMPLETE - Final Results Below - Press 'q' to Quit "
+
+ # Center the message and fill the width with reverse video
+ if len(message) < width - 1:
+ padding_total = width - len(message) - 1
+ padding_left = padding_total // 2
+ padding_right = padding_total - padding_left
+ full_message = " " * padding_left + message + " " * padding_right
+ else:
+ full_message = message[: width - 1]
+
+ # Draw the banner with reverse video and bold
+ self.add_str(
+ line, 0, full_message, A_REVERSE | A_BOLD | self.colors["green"]
+ )
+ line += 1
+
+ # Add blank line for separation
+ line += 1
+
+ return line
+
+
+class TableWidget(Widget):
+ """Widget for rendering column headers and data rows."""
+
+ def __init__(self, display, colors, collector):
+ """
+ Initialize table widget.
+
+ Args:
+ display: DisplayInterface implementation
+ colors: Dictionary of color attributes
+ collector: Reference to LiveStatsCollector for accessing stats
+ """
+ super().__init__(display, colors)
+ self.collector = collector
+
+ def render(self, line, width, **kwargs):
+ """
+ Render column headers and data rows.
+
+ Args:
+ line: Starting line number
+ width: Available width
+ kwargs: Must contain 'height' and 'stats_list' keys
+
+ Returns:
+ Next available line number
+ """
+ height = kwargs["height"]
+ stats_list = kwargs["stats_list"]
+
+ # Draw column headers
+ line, show_sample_pct, show_tottime, show_cumul_pct, show_cumtime = (
+ self.draw_column_headers(line, width)
+ )
+ column_flags = (
+ show_sample_pct,
+ show_tottime,
+ show_cumul_pct,
+ show_cumtime,
+ )
+
+ # Draw data rows
+ line = self.draw_stats_rows(
+ line, height, width, stats_list, column_flags
+ )
+
+ return line
+
+ def draw_column_headers(self, line, width):
+ """Draw column headers with sort indicators."""
+ col = 0
+
+ # Determine which columns to show based on width
+ show_sample_pct = width >= WIDTH_THRESHOLD_SAMPLE_PCT
+ show_tottime = width >= WIDTH_THRESHOLD_TOTTIME
+ show_cumul_pct = width >= WIDTH_THRESHOLD_CUMUL_PCT
+ show_cumtime = width >= WIDTH_THRESHOLD_CUMTIME
+
+ sorted_header = self.colors["sorted_header"]
+ normal_header = self.colors["normal_header"]
+
+ # Determine which column is sorted
+ sort_col = {
+ "nsamples": 0,
+ "sample_pct": 1,
+ "tottime": 2,
+ "cumul_pct": 3,
+ "cumtime": 4,
+ }.get(self.collector.sort_by, -1)
+
+ # Column 0: nsamples
+ attr = sorted_header if sort_col == 0 else normal_header
+ text = f"{'▼nsamples' if sort_col == 0 else 'nsamples':>13}"
+ self.add_str(line, col, text, attr)
+ col += 15
+
+ # Column 1: sample %
+ if show_sample_pct:
+ attr = sorted_header if sort_col == 1 else normal_header
+ text = f"{'▼%' if sort_col == 1 else '%':>5}"
+ self.add_str(line, col, text, attr)
+ col += 7
+
+ # Column 2: tottime
+ if show_tottime:
+ attr = sorted_header if sort_col == 2 else normal_header
+ text = f"{'▼tottime' if sort_col == 2 else 'tottime':>10}"
+ self.add_str(line, col, text, attr)
+ col += 12
+
+ # Column 3: cumul %
+ if show_cumul_pct:
+ attr = sorted_header if sort_col == 3 else normal_header
+ text = f"{'▼%' if sort_col == 3 else '%':>5}"
+ self.add_str(line, col, text, attr)
+ col += 7
+
+ # Column 4: cumtime
+ if show_cumtime:
+ attr = sorted_header if sort_col == 4 else normal_header
+ text = f"{'▼cumtime' if sort_col == 4 else 'cumtime':>10}"
+ self.add_str(line, col, text, attr)
+ col += 12
+
+ # Remaining headers
+ if col < width - 15:
+ remaining_space = width - col - 1
+ func_width = min(
+ MAX_FUNC_NAME_WIDTH,
+ max(MIN_FUNC_NAME_WIDTH, remaining_space // 2),
+ )
+ self.add_str(
+ line, col, f"{'function':<{func_width}}", normal_header
+ )
+ col += func_width + 2
+
+ if col < width - 10:
+ self.add_str(line, col, "file:line", normal_header)
+
+ return (
+ line + 1,
+ show_sample_pct,
+ show_tottime,
+ show_cumul_pct,
+ show_cumtime,
+ )
+
+ def draw_stats_rows(self, line, height, width, stats_list, column_flags):
+ """Draw the statistics data rows."""
+ show_sample_pct, show_tottime, show_cumul_pct, show_cumtime = (
+ column_flags
+ )
+
+ # Get color attributes from the colors dict (already initialized)
+ color_samples = self.colors.get("color_samples", curses.A_NORMAL)
+ color_file = self.colors.get("color_file", curses.A_NORMAL)
+ color_func = self.colors.get("color_func", curses.A_NORMAL)
+
+ # Get trend tracker for color decisions
+ trend_tracker = self.collector._trend_tracker
+
+ for stat in stats_list:
+ if line >= height - FOOTER_LINES:
+ break
+
+ func = stat["func"]
+ direct_calls = stat["direct_calls"]
+ cumulative_calls = stat["cumulative_calls"]
+ total_time = stat["total_time"]
+ cumulative_time = stat["cumulative_time"]
+ trends = stat.get("trends", {})
+
+ sample_pct = (
+ (direct_calls / self.collector.total_samples * 100)
+ if self.collector.total_samples > 0
+ else 0
+ )
+ cum_pct = (
+ (cumulative_calls / self.collector.total_samples * 100)
+ if self.collector.total_samples > 0
+ else 0
+ )
+
+ # Helper function to get trend color for a specific column
+ def get_trend_color(column_name):
+ trend = trends.get(column_name, "stable")
+ if trend_tracker is not None:
+ return trend_tracker.get_color(trend)
+ return curses.A_NORMAL
+
+ filename, lineno, funcname = func[0], func[1], func[2]
+ samples_str = f"{direct_calls}/{cumulative_calls}"
+ col = 0
+
+ # Samples column - apply trend color based on nsamples trend
+ nsamples_color = get_trend_color("nsamples")
+ self.add_str(line, col, f"{samples_str:>13}", nsamples_color)
+ col += 15
+
+ # Sample % column
+ if show_sample_pct:
+ sample_pct_color = get_trend_color("sample_pct")
+ self.add_str(line, col, f"{sample_pct:>5.1f}", sample_pct_color)
+ col += 7
+
+ # Total time column
+ if show_tottime:
+ tottime_color = get_trend_color("tottime")
+ self.add_str(line, col, f"{total_time:>10.3f}", tottime_color)
+ col += 12
+
+ # Cumul % column
+ if show_cumul_pct:
+ cumul_pct_color = get_trend_color("cumul_pct")
+ self.add_str(line, col, f"{cum_pct:>5.1f}", cumul_pct_color)
+ col += 7
+
+ # Cumul time column
+ if show_cumtime:
+ cumtime_color = get_trend_color("cumtime")
+ self.add_str(line, col, f"{cumulative_time:>10.3f}", cumtime_color)
+ col += 12
+
+ # Function name column
+ if col < width - 15:
+ remaining_space = width - col - 1
+ func_width = min(
+ MAX_FUNC_NAME_WIDTH,
+ max(MIN_FUNC_NAME_WIDTH, remaining_space // 2),
+ )
+
+ func_display = funcname
+ if len(funcname) > func_width:
+ func_display = funcname[: func_width - 3] + "..."
+ func_display = f"{func_display:<{func_width}}"
+ self.add_str(line, col, func_display, color_func)
+ col += func_width + 2
+
+ # File:line column
+ if col < width - 10:
+ simplified_path = self.collector._simplify_path(filename)
+ file_line = f"{simplified_path}:{lineno}"
+ remaining_width = width - col - 1
+ self.add_str(
+ line, col, file_line[:remaining_width], color_file
+ )
+
+ line += 1
+
+ return line
+
+
+class FooterWidget(Widget):
+ """Widget for rendering the footer section (legend and controls)."""
+
+ def __init__(self, display, colors, collector):
+ """
+ Initialize footer widget.
+
+ Args:
+ display: DisplayInterface implementation
+ colors: Dictionary of color attributes
+ collector: Reference to LiveStatsCollector for accessing state
+ """
+ super().__init__(display, colors)
+ self.collector = collector
+
+ def render(self, line, width, **kwargs):
+ """
+ Render the footer at the specified position.
+
+ Args:
+ line: Starting line number (should be height - 2)
+ width: Available width
+
+ Returns:
+ Next available line number
+ """
+ A_DIM = self.display.get_attr("A_DIM")
+ A_BOLD = self.display.get_attr("A_BOLD")
+
+ # Legend line
+ legend = "nsamples: direct/cumulative (direct=executing, cumulative=on stack)"
+ self.add_str(line, 0, legend[: width - 1], A_DIM)
+ line += 1
+
+ # Controls line with status
+ sort_names = {
+ "tottime": "Total Time",
+ "nsamples": "Direct Samples",
+ "cumtime": "Cumulative Time",
+ "sample_pct": "Sample %",
+ "cumul_pct": "Cumulative %",
+ }
+ sort_display = sort_names.get(
+ self.collector.sort_by, self.collector.sort_by
+ )
+
+ # Build status indicators
+ status = []
+ if self.collector.finished:
+ status.append("[PROFILING FINISHED - Press 'q' to quit]")
+ elif self.collector.paused:
+ status.append("[PAUSED]")
+ if self.collector.filter_pattern:
+ status.append(
+ f"[Filter: {self.collector.filter_pattern} (c to clear)]"
+ )
+ # Show trend colors status if disabled
+ if self.collector._trend_tracker is not None and not self.collector._trend_tracker.enabled:
+ status.append("[Trend colors: OFF]")
+ status_str = " ".join(status) + " " if status else ""
+
+ if self.collector.finished:
+ footer = f"{status_str}"
+ else:
+ footer = f"{status_str}Sort: {sort_display} | 't':mode 'x':trends ←→:thread 'h':help 'q':quit"
+ self.add_str(
+ line,
+ 0,
+ footer[: width - 1],
+ A_BOLD
+ if (self.collector.paused or self.collector.finished)
+ else A_DIM,
+ )
+
+ return line + 1
+
+ def render_filter_input_prompt(self, line, width):
+ """Draw the filter input prompt at the bottom of the screen."""
+ A_BOLD = self.display.get_attr("A_BOLD")
+ A_REVERSE = self.display.get_attr("A_REVERSE")
+
+ # Draw prompt on last line
+ prompt = f"Function filter: {self.collector.filter_input_buffer}_"
+ self.add_str(line, 0, prompt[: width - 1], A_REVERSE | A_BOLD)
+
+
+class HelpWidget(Widget):
+ """Widget for rendering the help screen overlay."""
+
+ def render(self, line, width, **kwargs):
+ """
+ Render the help screen.
+
+ Args:
+ line: Starting line number (ignored, help is centered)
+ width: Available width
+ kwargs: Must contain 'height' key
+
+ Returns:
+ Next available line number (not used for overlays)
+ """
+ height = kwargs["height"]
+ A_BOLD = self.display.get_attr("A_BOLD")
+ A_NORMAL = self.display.get_attr("A_NORMAL")
+
+ help_lines = [
+ ("Tachyon Profiler - Interactive Commands", A_BOLD),
+ ("", A_NORMAL),
+ ("Navigation & Display:", A_BOLD),
+ (" s - Cycle through sort modes (forward)", A_NORMAL),
+ (" S - Cycle through sort modes (backward)", A_NORMAL),
+ (" t - Toggle view mode (ALL / per-thread)", A_NORMAL),
+ (" x - Toggle trend colors (on/off)", A_NORMAL),
+ (" ← → ↑ ↓ - Navigate threads (in per-thread mode)", A_NORMAL),
+ (" + - Faster display refresh rate", A_NORMAL),
+ (" - - Slower display refresh rate", A_NORMAL),
+ ("", A_NORMAL),
+ ("Control:", A_BOLD),
+ (" p - Freeze display (snapshot)", A_NORMAL),
+ (" r - Reset all statistics", A_NORMAL),
+ ("", A_NORMAL),
+ ("Filtering:", A_BOLD),
+ (" / - Enter function filter (substring)", A_NORMAL),
+ (" c - Clear filter", A_NORMAL),
+ (" ESC - Cancel filter input", A_NORMAL),
+ ("", A_NORMAL),
+ ("Other:", A_BOLD),
+ (" h or ? - Show/hide this help", A_NORMAL),
+ (" q - Quit profiler", A_NORMAL),
+ ("", A_NORMAL),
+ ("Press any key to close this help screen", A_BOLD),
+ ]
+
+ start_line = (height - len(help_lines)) // 2
+ for i, (text, attr) in enumerate(help_lines):
+ if start_line + i < height - 1:
+ col = 2 # Left-align with small margin
+ self.add_str(start_line + i, col, text[: width - 3], attr)
+
+ return line # Not used for overlays
from .pstats_collector import PstatsCollector
from .stack_collector import CollapsedStackCollector, FlamegraphCollector
from .gecko_collector import GeckoCollector
+from .constants import (
+ PROFILING_MODE_WALL,
+ PROFILING_MODE_CPU,
+ PROFILING_MODE_GIL,
+ PROFILING_MODE_ALL,
+ SORT_MODE_NSAMPLES,
+ SORT_MODE_TOTTIME,
+ SORT_MODE_CUMTIME,
+ SORT_MODE_SAMPLE_PCT,
+ SORT_MODE_CUMUL_PCT,
+ SORT_MODE_NSAMPLES_CUMUL,
+)
+try:
+ from .live_collector import LiveStatsCollector
+except ImportError:
+ LiveStatsCollector = None
_FREE_THREADED_BUILD = sysconfig.get_config_var("Py_GIL_DISABLED") is not None
-# Profiling mode constants
-PROFILING_MODE_WALL = 0
-PROFILING_MODE_CPU = 1
-PROFILING_MODE_GIL = 2
-PROFILING_MODE_ALL = 3 # Combines GIL + CPU checks
-
def _parse_mode(mode_string):
"""Convert mode string to mode constant."""
- --pstats: Detailed profiling statistics with sorting options
- --collapsed: Stack traces for generating flamegraphs
- --flamegraph Interactive HTML flamegraph visualization (requires web browser)
+ - --live: Live top-like statistics display using ncurses
Examples:
# Profile process 1234 for 10 seconds with default settings
# Generate a HTML flamegraph
python -m profiling.sampling --flamegraph -p 1234
+ # Display live top-like statistics (press 'q' to quit, 's' to cycle sort)
+ python -m profiling.sampling --live -p 1234
+
# Profile all threads, sort by total time
python -m profiling.sampling -a --sort-tottime -p 1234
_RECV_BUFFER_SIZE = 1024
-def _run_with_sync(original_cmd):
+def _run_with_sync(original_cmd, suppress_output=False):
"""Run a command with socket-based synchronization and return the process."""
# Create a TCP socket for synchronization with better socket options
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sync_sock:
cmd = (sys.executable, "-m", "profiling.sampling._sync_coordinator", str(sync_port), cwd) + tuple(target_args)
# Start the process with coordinator
- process = subprocess.Popen(cmd)
+ # Suppress stdout/stderr if requested (for live mode)
+ popen_kwargs = {}
+ if suppress_output:
+ popen_kwargs['stdin'] = subprocess.DEVNULL
+ popen_kwargs['stdout'] = subprocess.DEVNULL
+ popen_kwargs['stderr'] = subprocess.DEVNULL
+
+ process = subprocess.Popen(cmd, **popen_kwargs)
try:
# Wait for ready signal with timeout
last_realtime_update = start_time
while running_time < duration_sec:
+ # Check if live collector wants to stop
+ if hasattr(collector, 'running') and not collector.running:
+ break
+
current_time = time.perf_counter()
if next_time < current_time:
try:
duration_sec = current_time - start_time
break
except (RuntimeError, UnicodeDecodeError, MemoryError, OSError):
+ collector.collect_failed_sample()
errors += 1
except Exception as e:
if not self._is_process_running():
sample_rate = num_samples / running_time
error_rate = (errors / num_samples) * 100 if num_samples > 0 else 0
- print(f"Captured {num_samples} samples in {running_time:.2f} seconds")
- print(f"Sample rate: {sample_rate:.2f} samples/sec")
- print(f"Error rate: {error_rate:.2f}%")
+ # Don't print stats for live mode (curses is handling display)
+ is_live_mode = LiveStatsCollector is not None and isinstance(collector, LiveStatsCollector)
+ if not is_live_mode:
+ print(f"Captured {num_samples} samples in {running_time:.2f} seconds")
+ print(f"Sample rate: {sample_rate:.2f} samples/sec")
+ print(f"Error rate: {error_rate:.2f}%")
# Pass stats to flamegraph collector if it's the right type
if hasattr(collector, 'set_stats'):
collector.set_stats(self.sample_interval_usec, running_time, sample_rate, error_rate)
expected_samples = int(duration_sec / sample_interval_sec)
- if num_samples < expected_samples:
+ if num_samples < expected_samples and not is_live_mode:
print(
f"Warning: missed {expected_samples - num_samples} samples "
f"from the expected total of {expected_samples} "
# Gecko format never skips idle threads to show full thread states
collector = GeckoCollector(skip_idle=False)
filename = filename or f"gecko.{pid}.json"
+ case "live":
+ # Map sort value to sort_by string
+ sort_by_map = {
+ SORT_MODE_NSAMPLES: "nsamples",
+ SORT_MODE_TOTTIME: "tottime",
+ SORT_MODE_CUMTIME: "cumtime",
+ SORT_MODE_SAMPLE_PCT: "sample_pct",
+ SORT_MODE_CUMUL_PCT: "cumul_pct",
+ SORT_MODE_NSAMPLES_CUMUL: "cumul_pct",
+ }
+ sort_by = sort_by_map.get(sort, "tottime")
+ collector = LiveStatsCollector(
+ sample_interval_usec,
+ skip_idle=skip_idle,
+ sort_by=sort_by,
+ limit=limit or 20,
+ pid=pid,
+ mode=mode,
+ )
+ # Live mode is interactive, don't save file by default
+ # User can specify -o if they want to save stats
case _:
raise ValueError(f"Invalid output format: {output_format}")
- profiler.sample(collector, duration_sec)
+ # For live mode, wrap sampling in curses
+ if output_format == "live":
+ import curses
+ def curses_wrapper_func(stdscr):
+ collector.init_curses(stdscr)
+ try:
+ profiler.sample(collector, duration_sec)
+ # Mark as finished and keep the TUI running until user presses 'q'
+ collector.mark_finished()
+ # Keep processing input until user quits
+ while collector.running:
+ collector._handle_input()
+ time.sleep(0.05) # Small sleep to avoid busy waiting
+ finally:
+ collector.cleanup_curses()
+
+ try:
+ curses.wrapper(curses_wrapper_func)
+ except KeyboardInterrupt:
+ pass
+ else:
+ profiler.sample(collector, duration_sec)
if output_format == "pstats" and not filename:
stats = pstats.SampledStats(collector).strip_dirs()
print_sampled_stats(
stats, sort, limit, show_summary, sample_interval_usec
)
- else:
+ elif output_format != "live":
+ # Live mode is interactive only, no export unless filename specified
collector.export(filename)
-def _validate_collapsed_format_args(args, parser):
- # Check for incompatible pstats options
- invalid_opts = []
-
- # Get list of pstats-specific options
- pstats_options = {"sort": None, "limit": None, "no_summary": False}
+def _validate_file_output_format_args(args, parser):
+ """Validate arguments when using file-based output formats.
- # Find the default values from the argument definitions
- for action in parser._actions:
- if action.dest in pstats_options and hasattr(action, "default"):
- pstats_options[action.dest] = action.default
+ File-based formats (--collapsed, --gecko, --flamegraph) generate raw stack
+ data or visualizations, not formatted statistics, so pstats display options
+ are not applicable.
+ """
+ invalid_opts = []
- # Check if any pstats-specific options were provided by comparing with defaults
- for opt, default in pstats_options.items():
- if getattr(args, opt) != default:
- invalid_opts.append(opt.replace("no_", ""))
+ # Check if any pstats-specific sort options were provided
+ if args.sort is not None:
+ # Get the sort option name that was used
+ sort_names = {
+ SORT_MODE_NSAMPLES: "--sort-nsamples",
+ SORT_MODE_TOTTIME: "--sort-tottime",
+ SORT_MODE_CUMTIME: "--sort-cumtime",
+ SORT_MODE_SAMPLE_PCT: "--sort-sample-pct",
+ SORT_MODE_CUMUL_PCT: "--sort-cumul-pct",
+ SORT_MODE_NSAMPLES_CUMUL: "--sort-nsamples-cumul",
+ -1: "--sort-name",
+ }
+ sort_opt = sort_names.get(args.sort, "sort")
+ invalid_opts.append(sort_opt)
+
+ # Check limit option (default is 15)
+ if args.limit != 15:
+ invalid_opts.append("-l/--limit")
+
+ # Check no_summary option
+ if args.no_summary:
+ invalid_opts.append("--no-summary")
if invalid_opts:
parser.error(
- f"The following options are only valid with --pstats format: {', '.join(invalid_opts)}"
+ f"--{args.format} format is incompatible with: {', '.join(invalid_opts)}. "
+ "These options are only valid with --pstats format."
)
+ # Validate that --mode is not used with --gecko
+ if args.format == "gecko" and args.mode != "wall":
+ parser.error("--mode option is incompatible with --gecko format. Gecko format automatically uses ALL mode (GIL + CPU analysis).")
+
# Set default output filename for collapsed format only if we have a PID
# For module/script execution, this will be set later with the subprocess PID
if not args.outfile and args.pid is not None:
args.outfile = f"collapsed.{args.pid}.txt"
+def _validate_live_format_args(args, parser):
+ """Validate arguments when using --live output format.
+
+ Live mode provides an interactive TUI that is incompatible with file output
+ and certain pstats display options.
+ """
+ invalid_opts = []
+
+ # Live mode is incompatible with file output
+ if args.outfile:
+ invalid_opts.append("-o/--outfile")
+
+ # pstats-specific display options are incompatible
+ if args.no_summary:
+ invalid_opts.append("--no-summary")
+
+ if invalid_opts:
+ parser.error(
+ f"--live mode is incompatible with: {', '.join(invalid_opts)}. "
+ "Live mode provides its own interactive display."
+ )
+
+
def wait_for_process_and_sample(pid, sort_value, args):
"""Sample the process immediately since it has already signaled readiness."""
# Set default filename with subprocess PID if not already set
dest="format",
help="Generate Gecko format for Firefox Profiler",
)
+ output_format.add_argument(
+ "--live",
+ action="store_const",
+ const="live",
+ dest="format",
+ help="Display live top-like live statistics in a terminal UI",
+ )
output_group.add_argument(
"-o",
sort_group.add_argument(
"--sort-nsamples",
action="store_const",
- const=0,
+ const=SORT_MODE_NSAMPLES,
dest="sort",
- help="Sort by number of direct samples (nsamples column)",
+ help="Sort by number of direct samples (nsamples column, default)",
)
sort_group.add_argument(
"--sort-tottime",
action="store_const",
- const=1,
+ const=SORT_MODE_TOTTIME,
dest="sort",
help="Sort by total time (tottime column)",
)
sort_group.add_argument(
"--sort-cumtime",
action="store_const",
- const=2,
+ const=SORT_MODE_CUMTIME,
dest="sort",
- help="Sort by cumulative time (cumtime column, default)",
+ help="Sort by cumulative time (cumtime column)",
)
sort_group.add_argument(
"--sort-sample-pct",
action="store_const",
- const=3,
+ const=SORT_MODE_SAMPLE_PCT,
dest="sort",
help="Sort by sample percentage (sample%% column)",
)
sort_group.add_argument(
"--sort-cumul-pct",
action="store_const",
- const=4,
+ const=SORT_MODE_CUMUL_PCT,
dest="sort",
help="Sort by cumulative sample percentage (cumul%% column)",
)
sort_group.add_argument(
"--sort-nsamples-cumul",
action="store_const",
- const=5,
+ const=SORT_MODE_NSAMPLES_CUMUL,
dest="sort",
help="Sort by cumulative samples (nsamples column, cumulative part)",
)
args = parser.parse_args()
- # Validate format-specific arguments
- if args.format in ("collapsed", "gecko"):
- _validate_collapsed_format_args(args, parser)
+ # Check if live mode is available early
+ if args.format == "live" and LiveStatsCollector is None:
+ print(
+ "Error: Live mode (--live) requires the curses module, which is not available.\n",
+ file=sys.stderr
+ )
+ sys.exit(1)
- # Validate that --mode is not used with --gecko
- if args.format == "gecko" and args.mode != "wall":
- parser.error("--mode option is incompatible with --gecko format. Gecko format automatically uses ALL mode (GIL + CPU analysis).")
+ # Validate format-specific arguments
+ if args.format in ("collapsed", "gecko", "flamegraph"):
+ _validate_file_output_format_args(args, parser)
+ elif args.format == "live":
+ _validate_live_format_args(args, parser)
- sort_value = args.sort if args.sort is not None else 2
+ sort_value = args.sort if args.sort is not None else SORT_MODE_NSAMPLES
if args.module is not None and not args.module:
parser.error("argument -m/--module: expected one argument")
cmd = (sys.executable, *args.args)
# Use synchronized process startup
- process = _run_with_sync(cmd)
+ # Suppress output if using live mode
+ suppress_output = (args.format == "live")
+ process = _run_with_sync(cmd, suppress_output=suppress_output)
# Process has already signaled readiness, start sampling immediately
try:
--- /dev/null
+"""Common test helpers and mocks for live collector tests."""
+
+from profiling.sampling.constants import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+)
+
+
+class MockFrameInfo:
+ """Mock FrameInfo for testing."""
+
+ def __init__(self, filename, lineno, funcname):
+ self.filename = filename
+ self.lineno = lineno
+ self.funcname = funcname
+
+ def __repr__(self):
+ return f"MockFrameInfo(filename='{self.filename}', lineno={self.lineno}, funcname='{self.funcname}')"
+
+
+class MockThreadInfo:
+ """Mock ThreadInfo for testing."""
+
+ def __init__(self, thread_id, frame_info, status=THREAD_STATUS_HAS_GIL | THREAD_STATUS_ON_CPU):
+ self.thread_id = thread_id
+ self.frame_info = frame_info
+ self.status = status
+
+ def __repr__(self):
+ return f"MockThreadInfo(thread_id={self.thread_id}, frame_info={self.frame_info}, status={self.status})"
+
+
+class MockInterpreterInfo:
+ """Mock InterpreterInfo for testing."""
+
+ def __init__(self, interpreter_id, threads):
+ self.interpreter_id = interpreter_id
+ self.threads = threads
+
+ def __repr__(self):
+ return f"MockInterpreterInfo(interpreter_id={self.interpreter_id}, threads={self.threads})"
self._verify_coordinator_command(mock_popen, ("-m", "mymodule"))
mock_sample.assert_called_once_with(
12345,
- sort=2, # default sort (sort_value from args.sort)
+ sort=0, # default sort (sort_value from args.sort)
sample_interval_usec=100,
duration_sec=10,
filename=None,
)
mock_sample.assert_called_once_with(
12345,
- sort=2,
+ sort=0,
sample_interval_usec=100,
duration_sec=10,
filename=None,
self._verify_coordinator_command(mock_popen, ("myscript.py",))
mock_sample.assert_called_once_with(
12345,
- sort=2,
+ sort=0,
sample_interval_usec=100,
duration_sec=10,
filename=None,
# Verify profiler options were passed correctly
mock_sample.assert_called_once_with(
12345,
- sort=2, # default sort
+ sort=0, # default sort
sample_interval_usec=2000,
duration_sec=60,
filename="output.txt",
"-v",
"--output=/tmp/out",
"positional",
- )
+ ),
+ suppress_output=False
)
def test_cli_collapsed_format_validation(self):
filename=None,
all_threads=False,
limit=15,
- sort=2,
+ sort=0,
show_summary=True,
output_format="pstats",
realtime_stats=False,
--- /dev/null
+"""Core functionality tests for LiveStatsCollector.
+
+Tests for path simplification, frame processing, collect method,
+statistics building, sorting, and formatting.
+"""
+
+import os
+import unittest
+from test.support import requires
+from test.support.import_helper import import_module
+
+# Only run these tests if curses is available
+requires("curses")
+curses = import_module("curses")
+
+from profiling.sampling.live_collector import LiveStatsCollector, MockDisplay
+from profiling.sampling.constants import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+)
+from ._live_collector_helpers import (
+ MockFrameInfo,
+ MockThreadInfo,
+ MockInterpreterInfo,
+)
+
+
+class TestLiveStatsCollectorPathSimplification(unittest.TestCase):
+ """Tests for path simplification functionality."""
+
+ def test_simplify_stdlib_path(self):
+ """Test simplification of standard library paths."""
+ collector = LiveStatsCollector(1000)
+ # Get actual os module path
+ os_file = os.__file__
+ if os_file:
+ stdlib_dir = os.path.dirname(os.path.abspath(os_file))
+ test_path = os.path.join(stdlib_dir, "json", "decoder.py")
+ simplified = collector._simplify_path(test_path)
+ # Should remove the stdlib prefix
+ self.assertNotIn(stdlib_dir, simplified)
+ self.assertIn("json", simplified)
+
+ def test_simplify_unknown_path(self):
+ """Test that unknown paths are returned unchanged."""
+ collector = LiveStatsCollector(1000)
+ test_path = "/some/unknown/path/file.py"
+ simplified = collector._simplify_path(test_path)
+ self.assertEqual(simplified, test_path)
+
+
+class TestLiveStatsCollectorFrameProcessing(unittest.TestCase):
+ """Tests for frame processing functionality."""
+
+ def test_process_single_frame(self):
+ """Test processing a single frame."""
+ collector = LiveStatsCollector(1000)
+ frames = [MockFrameInfo("test.py", 10, "test_func")]
+ collector._process_frames(frames)
+
+ location = ("test.py", 10, "test_func")
+ self.assertEqual(collector.result[location]["direct_calls"], 1)
+ self.assertEqual(collector.result[location]["cumulative_calls"], 1)
+
+ def test_process_multiple_frames(self):
+ """Test processing a stack of multiple frames."""
+ collector = LiveStatsCollector(1000)
+ frames = [
+ MockFrameInfo("test.py", 10, "inner_func"),
+ MockFrameInfo("test.py", 20, "middle_func"),
+ MockFrameInfo("test.py", 30, "outer_func"),
+ ]
+ collector._process_frames(frames)
+
+ # Top frame (inner_func) should have both direct and cumulative
+ inner_loc = ("test.py", 10, "inner_func")
+ self.assertEqual(collector.result[inner_loc]["direct_calls"], 1)
+ self.assertEqual(collector.result[inner_loc]["cumulative_calls"], 1)
+
+ # Other frames should only have cumulative
+ middle_loc = ("test.py", 20, "middle_func")
+ self.assertEqual(collector.result[middle_loc]["direct_calls"], 0)
+ self.assertEqual(collector.result[middle_loc]["cumulative_calls"], 1)
+
+ outer_loc = ("test.py", 30, "outer_func")
+ self.assertEqual(collector.result[outer_loc]["direct_calls"], 0)
+ self.assertEqual(collector.result[outer_loc]["cumulative_calls"], 1)
+
+ def test_process_empty_frames(self):
+ """Test processing empty frames list."""
+ collector = LiveStatsCollector(1000)
+ collector._process_frames([])
+ # Should not raise an error and result should remain empty
+ self.assertEqual(len(collector.result), 0)
+
+ def test_process_frames_accumulation(self):
+ """Test that multiple calls accumulate correctly."""
+ collector = LiveStatsCollector(1000)
+ frames = [MockFrameInfo("test.py", 10, "test_func")]
+
+ collector._process_frames(frames)
+ collector._process_frames(frames)
+ collector._process_frames(frames)
+
+ location = ("test.py", 10, "test_func")
+ self.assertEqual(collector.result[location]["direct_calls"], 3)
+ self.assertEqual(collector.result[location]["cumulative_calls"], 3)
+
+ def test_process_frames_with_thread_id(self):
+ """Test processing frames with per-thread tracking."""
+ collector = LiveStatsCollector(1000)
+ frames = [MockFrameInfo("test.py", 10, "test_func")]
+
+ # Process frames with thread_id
+ collector._process_frames(frames, thread_id=123)
+
+ # Check aggregated result
+ location = ("test.py", 10, "test_func")
+ self.assertEqual(collector.result[location]["direct_calls"], 1)
+ self.assertEqual(collector.result[location]["cumulative_calls"], 1)
+
+ # Check per-thread result
+ self.assertIn(123, collector.per_thread_data)
+ self.assertEqual(
+ collector.per_thread_data[123].result[location]["direct_calls"], 1
+ )
+ self.assertEqual(
+ collector.per_thread_data[123].result[location]["cumulative_calls"], 1
+ )
+
+ def test_process_frames_multiple_threads(self):
+ """Test processing frames from multiple threads."""
+ collector = LiveStatsCollector(1000)
+ frames1 = [MockFrameInfo("test.py", 10, "test_func")]
+ frames2 = [MockFrameInfo("test.py", 20, "other_func")]
+
+ # Process frames from different threads
+ collector._process_frames(frames1, thread_id=123)
+ collector._process_frames(frames2, thread_id=456)
+
+ # Check that both threads have their own data
+ self.assertIn(123, collector.per_thread_data)
+ self.assertIn(456, collector.per_thread_data)
+
+ loc1 = ("test.py", 10, "test_func")
+ loc2 = ("test.py", 20, "other_func")
+
+ # Thread 123 should only have func1
+ self.assertEqual(
+ collector.per_thread_data[123].result[loc1]["direct_calls"], 1
+ )
+ self.assertNotIn(loc2, collector.per_thread_data[123].result)
+
+ # Thread 456 should only have func2
+ self.assertEqual(
+ collector.per_thread_data[456].result[loc2]["direct_calls"], 1
+ )
+ self.assertNotIn(loc1, collector.per_thread_data[456].result)
+
+
+class TestLiveStatsCollectorCollect(unittest.TestCase):
+ """Tests for the collect method."""
+
+ def test_collect_initializes_start_time(self):
+ """Test that collect initializes start_time on first call."""
+ collector = LiveStatsCollector(1000)
+ self.assertIsNone(collector.start_time)
+
+ # Create mock stack frames
+ thread_info = MockThreadInfo(123, [])
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+
+ collector.collect(stack_frames)
+ self.assertIsNotNone(collector.start_time)
+
+ def test_collect_increments_sample_count(self):
+ """Test that collect increments total_samples."""
+ collector = LiveStatsCollector(1000)
+ thread_info = MockThreadInfo(123, [])
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+
+ self.assertEqual(collector.total_samples, 0)
+ collector.collect(stack_frames)
+ self.assertEqual(collector.total_samples, 1)
+ collector.collect(stack_frames)
+ self.assertEqual(collector.total_samples, 2)
+
+ def test_collect_with_frames(self):
+ """Test collect with actual frame data."""
+ collector = LiveStatsCollector(1000)
+ frames = [MockFrameInfo("test.py", 10, "test_func")]
+ thread_info = MockThreadInfo(123, frames)
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+
+ collector.collect(stack_frames)
+
+ location = ("test.py", 10, "test_func")
+ self.assertEqual(collector.result[location]["direct_calls"], 1)
+ self.assertEqual(collector._successful_samples, 1)
+ self.assertEqual(collector._failed_samples, 0)
+
+ def test_collect_with_empty_frames(self):
+ """Test collect with empty frames."""
+ collector = LiveStatsCollector(1000)
+ thread_info = MockThreadInfo(123, [])
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+
+ collector.collect(stack_frames)
+
+ # Empty frames still count as successful since collect() was called successfully
+ self.assertEqual(collector._successful_samples, 1)
+ self.assertEqual(collector._failed_samples, 0)
+
+ def test_collect_skip_idle_threads(self):
+ """Test that idle threads are skipped when skip_idle=True."""
+ collector = LiveStatsCollector(1000, skip_idle=True)
+
+ frames = [MockFrameInfo("test.py", 10, "test_func")]
+ running_thread = MockThreadInfo(
+ 123, frames, status=THREAD_STATUS_HAS_GIL | THREAD_STATUS_ON_CPU
+ )
+ idle_thread = MockThreadInfo(124, frames, status=0) # No flags = idle
+ interpreter_info = MockInterpreterInfo(
+ 0, [running_thread, idle_thread]
+ )
+ stack_frames = [interpreter_info]
+
+ collector.collect(stack_frames)
+
+ # Only one thread should be processed
+ location = ("test.py", 10, "test_func")
+ self.assertEqual(collector.result[location]["direct_calls"], 1)
+
+ def test_collect_multiple_threads(self):
+ """Test collect with multiple threads."""
+ collector = LiveStatsCollector(1000)
+
+ frames1 = [MockFrameInfo("test1.py", 10, "func1")]
+ frames2 = [MockFrameInfo("test2.py", 20, "func2")]
+ thread1 = MockThreadInfo(123, frames1)
+ thread2 = MockThreadInfo(124, frames2)
+ interpreter_info = MockInterpreterInfo(0, [thread1, thread2])
+ stack_frames = [interpreter_info]
+
+ collector.collect(stack_frames)
+
+ loc1 = ("test1.py", 10, "func1")
+ loc2 = ("test2.py", 20, "func2")
+ self.assertEqual(collector.result[loc1]["direct_calls"], 1)
+ self.assertEqual(collector.result[loc2]["direct_calls"], 1)
+
+ # Check thread IDs are tracked
+ self.assertIn(123, collector.thread_ids)
+ self.assertIn(124, collector.thread_ids)
+
+
+class TestLiveStatsCollectorStatisticsBuilding(unittest.TestCase):
+ """Tests for statistics building and sorting."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.collector = LiveStatsCollector(1000)
+ # Add some test data
+ self.collector.result[("file1.py", 10, "func1")] = {
+ "direct_calls": 100,
+ "cumulative_calls": 150,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("file2.py", 20, "func2")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 200,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("file3.py", 30, "func3")] = {
+ "direct_calls": 75,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+ self.collector.total_samples = 300
+
+ def test_build_stats_list(self):
+ """Test that stats list is built correctly."""
+ stats_list = self.collector._build_stats_list()
+ self.assertEqual(len(stats_list), 3)
+
+ # Check that all expected keys are present
+ for stat in stats_list:
+ self.assertIn("func", stat)
+ self.assertIn("direct_calls", stat)
+ self.assertIn("cumulative_calls", stat)
+ self.assertIn("total_time", stat)
+ self.assertIn("cumulative_time", stat)
+
+ def test_sort_by_nsamples(self):
+ """Test sorting by number of samples."""
+ self.collector.sort_by = "nsamples"
+ stats_list = self.collector._build_stats_list()
+
+ # Should be sorted by direct_calls descending
+ self.assertEqual(stats_list[0]["func"][2], "func1") # 100 samples
+ self.assertEqual(stats_list[1]["func"][2], "func3") # 75 samples
+ self.assertEqual(stats_list[2]["func"][2], "func2") # 50 samples
+
+ def test_sort_by_tottime(self):
+ """Test sorting by total time."""
+ self.collector.sort_by = "tottime"
+ stats_list = self.collector._build_stats_list()
+
+ # Should be sorted by total_time descending
+ # total_time = direct_calls * sample_interval_sec
+ self.assertEqual(stats_list[0]["func"][2], "func1")
+ self.assertEqual(stats_list[1]["func"][2], "func3")
+ self.assertEqual(stats_list[2]["func"][2], "func2")
+
+ def test_sort_by_cumtime(self):
+ """Test sorting by cumulative time."""
+ self.collector.sort_by = "cumtime"
+ stats_list = self.collector._build_stats_list()
+
+ # Should be sorted by cumulative_time descending
+ self.assertEqual(stats_list[0]["func"][2], "func2") # 200 cumulative
+ self.assertEqual(stats_list[1]["func"][2], "func1") # 150 cumulative
+ self.assertEqual(stats_list[2]["func"][2], "func3") # 75 cumulative
+
+ def test_sort_by_sample_pct(self):
+ """Test sorting by sample percentage."""
+ self.collector.sort_by = "sample_pct"
+ stats_list = self.collector._build_stats_list()
+
+ # Should be sorted by percentage of direct_calls
+ self.assertEqual(stats_list[0]["func"][2], "func1") # 33.3%
+ self.assertEqual(stats_list[1]["func"][2], "func3") # 25%
+ self.assertEqual(stats_list[2]["func"][2], "func2") # 16.7%
+
+ def test_sort_by_cumul_pct(self):
+ """Test sorting by cumulative percentage."""
+ self.collector.sort_by = "cumul_pct"
+ stats_list = self.collector._build_stats_list()
+
+ # Should be sorted by percentage of cumulative_calls
+ self.assertEqual(stats_list[0]["func"][2], "func2") # 66.7%
+ self.assertEqual(stats_list[1]["func"][2], "func1") # 50%
+ self.assertEqual(stats_list[2]["func"][2], "func3") # 25%
+
+
+class TestLiveStatsCollectorSortCycle(unittest.TestCase):
+ """Tests for sort mode cycling."""
+
+ def test_cycle_sort_from_nsamples(self):
+ """Test cycling from nsamples."""
+ collector = LiveStatsCollector(1000, sort_by="nsamples")
+ collector._cycle_sort()
+ self.assertEqual(collector.sort_by, "sample_pct")
+
+ def test_cycle_sort_from_sample_pct(self):
+ """Test cycling from sample_pct."""
+ collector = LiveStatsCollector(1000, sort_by="sample_pct")
+ collector._cycle_sort()
+ self.assertEqual(collector.sort_by, "tottime")
+
+ def test_cycle_sort_from_tottime(self):
+ """Test cycling from tottime."""
+ collector = LiveStatsCollector(1000, sort_by="tottime")
+ collector._cycle_sort()
+ self.assertEqual(collector.sort_by, "cumul_pct")
+
+ def test_cycle_sort_from_cumul_pct(self):
+ """Test cycling from cumul_pct."""
+ collector = LiveStatsCollector(1000, sort_by="cumul_pct")
+ collector._cycle_sort()
+ self.assertEqual(collector.sort_by, "cumtime")
+
+ def test_cycle_sort_from_cumtime(self):
+ """Test cycling from cumtime back to nsamples."""
+ collector = LiveStatsCollector(1000, sort_by="cumtime")
+ collector._cycle_sort()
+ self.assertEqual(collector.sort_by, "nsamples")
+
+ def test_cycle_sort_invalid_mode(self):
+ """Test cycling from invalid mode resets to nsamples."""
+ collector = LiveStatsCollector(1000)
+ collector.sort_by = "invalid_mode"
+ collector._cycle_sort()
+ self.assertEqual(collector.sort_by, "nsamples")
+
+ def test_cycle_sort_backward_from_nsamples(self):
+ """Test cycling backward from nsamples goes to cumtime."""
+ collector = LiveStatsCollector(1000, sort_by="nsamples")
+ collector._cycle_sort(reverse=True)
+ self.assertEqual(collector.sort_by, "cumtime")
+
+ def test_cycle_sort_backward_from_cumtime(self):
+ """Test cycling backward from cumtime goes to cumul_pct."""
+ collector = LiveStatsCollector(1000, sort_by="cumtime")
+ collector._cycle_sort(reverse=True)
+ self.assertEqual(collector.sort_by, "cumul_pct")
+
+ def test_cycle_sort_backward_from_sample_pct(self):
+ """Test cycling backward from sample_pct goes to nsamples."""
+ collector = LiveStatsCollector(1000, sort_by="sample_pct")
+ collector._cycle_sort(reverse=True)
+ self.assertEqual(collector.sort_by, "nsamples")
+
+ def test_input_lowercase_s_cycles_forward(self):
+ """Test that lowercase 's' cycles forward."""
+ display = MockDisplay()
+ collector = LiveStatsCollector(
+ 1000, sort_by="nsamples", display=display
+ )
+
+ display.simulate_input(ord("s"))
+ collector._handle_input()
+
+ self.assertEqual(collector.sort_by, "sample_pct")
+
+ def test_input_uppercase_s_cycles_backward(self):
+ """Test that uppercase 'S' cycles backward."""
+ display = MockDisplay()
+ collector = LiveStatsCollector(
+ 1000, sort_by="nsamples", display=display
+ )
+
+ display.simulate_input(ord("S"))
+ collector._handle_input()
+
+ self.assertEqual(collector.sort_by, "cumtime")
+
+
+class TestLiveStatsCollectorFormatting(unittest.TestCase):
+ """Tests for formatting methods."""
+
+ def test_format_uptime_seconds(self):
+ """Test uptime formatting for seconds only."""
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ self.assertEqual(collector._header_widget.format_uptime(45), "0m45s")
+
+ def test_format_uptime_minutes(self):
+ """Test uptime formatting for minutes."""
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ self.assertEqual(collector._header_widget.format_uptime(125), "2m05s")
+
+ def test_format_uptime_hours(self):
+ """Test uptime formatting for hours."""
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ self.assertEqual(
+ collector._header_widget.format_uptime(3661), "1h01m01s"
+ )
+
+ def test_format_uptime_large_values(self):
+ """Test uptime formatting for large time values."""
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ self.assertEqual(
+ collector._header_widget.format_uptime(86400), "24h00m00s"
+ )
+
+ def test_format_uptime_zero(self):
+ """Test uptime formatting for zero."""
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ self.assertEqual(collector._header_widget.format_uptime(0), "0m00s")
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+"""Interactive controls tests for LiveStatsCollector.
+
+Tests for interactive controls, filtering, filter input, and thread navigation.
+"""
+
+import time
+import unittest
+from test.support import requires
+from test.support.import_helper import import_module
+
+# Only run these tests if curses is available
+requires("curses")
+curses = import_module("curses")
+
+from profiling.sampling.live_collector import LiveStatsCollector, MockDisplay
+from profiling.sampling.constants import (
+ THREAD_STATUS_HAS_GIL,
+ THREAD_STATUS_ON_CPU,
+)
+from ._live_collector_helpers import (
+ MockFrameInfo,
+ MockThreadInfo,
+ MockInterpreterInfo,
+)
+
+
+class TestLiveCollectorInteractiveControls(unittest.TestCase):
+ """Tests for interactive control features."""
+
+ def setUp(self):
+ """Set up collector with mock display."""
+ self.display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.display
+ )
+ self.collector.start_time = time.perf_counter()
+ # Set a consistent display update interval for tests
+ self.collector._display_update_interval = 0.1
+
+ def tearDown(self):
+ """Clean up after test."""
+ pass
+
+ def test_pause_functionality(self):
+ """Test pause/resume functionality."""
+ self.assertFalse(self.collector.paused)
+
+ # Simulate 'p' key press
+ self.display.simulate_input(ord("p"))
+ self.collector._handle_input()
+
+ self.assertTrue(self.collector.paused)
+
+ # Press 'p' again to resume
+ self.display.simulate_input(ord("p"))
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.paused)
+
+ def test_pause_stops_ui_updates(self):
+ """Test that pausing stops UI updates but profiling continues."""
+ # Add some data
+ self.collector.total_samples = 10
+ self.collector.result[("test.py", 1, "func")] = {
+ "direct_calls": 5,
+ "cumulative_calls": 10,
+ "total_rec_calls": 0,
+ }
+
+ # Pause
+ self.collector.paused = True
+
+ # Simulate a collect call (profiling continues)
+ thread_info = MockThreadInfo(123, [])
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+
+ initial_samples = self.collector.total_samples
+ self.collector.collect(stack_frames)
+
+ # Samples should still increment
+ self.assertEqual(self.collector.total_samples, initial_samples + 1)
+
+ # But display should not have been updated (buffer stays clear)
+ self.display.cleared = False
+ self.collector.collect(stack_frames)
+ self.assertFalse(
+ self.display.cleared, "Display should not update when paused"
+ )
+
+ def test_reset_stats(self):
+ """Test reset statistics functionality."""
+ # Add some stats
+ self.collector.total_samples = 100
+ self.collector._successful_samples = 90
+ self.collector._failed_samples = 10
+ self.collector.result[("test.py", 1, "func")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+
+ # Reset
+ self.collector.reset_stats()
+
+ self.assertEqual(self.collector.total_samples, 0)
+ self.assertEqual(self.collector._successful_samples, 0)
+ self.assertEqual(self.collector._failed_samples, 0)
+ self.assertEqual(len(self.collector.result), 0)
+
+ def test_increase_refresh_rate(self):
+ """Test increasing refresh rate (faster updates)."""
+ initial_interval = self.collector._display_update_interval
+
+ # Simulate '+' key press (faster = smaller interval)
+ self.display.simulate_input(ord("+"))
+ self.collector._handle_input()
+
+ self.assertLess(self.collector._display_update_interval, initial_interval)
+
+ def test_decrease_refresh_rate(self):
+ """Test decreasing refresh rate (slower updates)."""
+ initial_interval = self.collector._display_update_interval
+
+ # Simulate '-' key press (slower = larger interval)
+ self.display.simulate_input(ord("-"))
+ self.collector._handle_input()
+
+ self.assertGreater(self.collector._display_update_interval, initial_interval)
+
+ def test_refresh_rate_minimum(self):
+ """Test that refresh rate has a minimum (max speed)."""
+ self.collector._display_update_interval = 0.05 # Set to minimum
+
+ # Try to go faster
+ self.display.simulate_input(ord("+"))
+ self.collector._handle_input()
+
+ # Should stay at minimum
+ self.assertEqual(self.collector._display_update_interval, 0.05)
+
+ def test_refresh_rate_maximum(self):
+ """Test that refresh rate has a maximum (min speed)."""
+ self.collector._display_update_interval = 1.0 # Set to maximum
+
+ # Try to go slower
+ self.display.simulate_input(ord("-"))
+ self.collector._handle_input()
+
+ # Should stay at maximum
+ self.assertEqual(self.collector._display_update_interval, 1.0)
+
+ def test_help_toggle(self):
+ """Test help screen toggle."""
+ self.assertFalse(self.collector.show_help)
+
+ # Show help
+ self.display.simulate_input(ord("h"))
+ self.collector._handle_input()
+
+ self.assertTrue(self.collector.show_help)
+
+ # Pressing any key closes help
+ self.display.simulate_input(ord("x"))
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.show_help)
+
+ def test_help_with_question_mark(self):
+ """Test help screen with '?' key."""
+ self.display.simulate_input(ord("?"))
+ self.collector._handle_input()
+
+ self.assertTrue(self.collector.show_help)
+
+ def test_filter_clear(self):
+ """Test clearing filter."""
+ self.collector.filter_pattern = "test"
+
+ # Clear filter
+ self.display.simulate_input(ord("c"))
+ self.collector._handle_input()
+
+ self.assertIsNone(self.collector.filter_pattern)
+
+ def test_filter_clear_when_none(self):
+ """Test clearing filter when no filter is set."""
+ self.assertIsNone(self.collector.filter_pattern)
+
+ # Should not crash
+ self.display.simulate_input(ord("c"))
+ self.collector._handle_input()
+
+ self.assertIsNone(self.collector.filter_pattern)
+
+ def test_paused_status_in_footer(self):
+ """Test that paused status appears in footer."""
+ self.collector.total_samples = 10
+ self.collector.paused = True
+
+ self.collector._update_display()
+
+ # Check that PAUSED appears in display
+ self.assertTrue(self.display.contains_text("PAUSED"))
+
+ def test_filter_status_in_footer(self):
+ """Test that filter status appears in footer."""
+ self.collector.total_samples = 10
+ self.collector.filter_pattern = "mytest"
+
+ self.collector._update_display()
+
+ # Check that filter info appears
+ self.assertTrue(self.display.contains_text("Filter"))
+
+ def test_help_screen_display(self):
+ """Test that help screen is displayed."""
+ self.collector.show_help = True
+
+ self.collector._update_display()
+
+ # Check for help content
+ self.assertTrue(self.display.contains_text("Interactive Commands"))
+
+ def test_pause_uppercase(self):
+ """Test pause with uppercase 'P' key."""
+ self.assertFalse(self.collector.paused)
+
+ self.display.simulate_input(ord("P"))
+ self.collector._handle_input()
+
+ self.assertTrue(self.collector.paused)
+
+ def test_help_uppercase(self):
+ """Test help with uppercase 'H' key."""
+ self.assertFalse(self.collector.show_help)
+
+ self.display.simulate_input(ord("H"))
+ self.collector._handle_input()
+
+ self.assertTrue(self.collector.show_help)
+
+ def test_reset_lowercase(self):
+ """Test reset with lowercase 'r' key."""
+ # Add some stats
+ self.collector.total_samples = 100
+ self.collector.result[("test.py", 1, "func")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+
+ self.display.simulate_input(ord("r"))
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.total_samples, 0)
+ self.assertEqual(len(self.collector.result), 0)
+
+ def test_reset_uppercase(self):
+ """Test reset with uppercase 'R' key."""
+ self.collector.total_samples = 100
+
+ self.display.simulate_input(ord("R"))
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.total_samples, 0)
+
+ def test_filter_clear_uppercase(self):
+ """Test clearing filter with uppercase 'C' key."""
+ self.collector.filter_pattern = "test"
+
+ self.display.simulate_input(ord("C"))
+ self.collector._handle_input()
+
+ self.assertIsNone(self.collector.filter_pattern)
+
+ def test_increase_refresh_rate_with_equals(self):
+ """Test increasing refresh rate with '=' key."""
+ initial_interval = self.collector._display_update_interval
+
+ # Simulate '=' key press (alternative to '+')
+ self.display.simulate_input(ord("="))
+ self.collector._handle_input()
+
+ self.assertLess(self.collector._display_update_interval, initial_interval)
+
+ def test_decrease_refresh_rate_with_underscore(self):
+ """Test decreasing refresh rate with '_' key."""
+ initial_interval = self.collector._display_update_interval
+
+ # Simulate '_' key press (alternative to '-')
+ self.display.simulate_input(ord("_"))
+ self.collector._handle_input()
+
+ self.assertGreater(self.collector._display_update_interval, initial_interval)
+
+ def test_finished_state_displays_banner(self):
+ """Test that finished state shows prominent banner."""
+ # Add some sample data
+ thread_info = MockThreadInfo(
+ 123,
+ [
+ MockFrameInfo("test.py", 10, "work"),
+ MockFrameInfo("test.py", 20, "main"),
+ ],
+ )
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+ self.collector.collect(stack_frames)
+
+ # Mark as finished
+ self.collector.mark_finished()
+
+ # Check that finished flag is set
+ self.assertTrue(self.collector.finished)
+
+ # Check that the banner message is displayed
+ self.assertTrue(self.display.contains_text("PROFILING COMPLETE"))
+ self.assertTrue(self.display.contains_text("Press 'q' to Quit"))
+
+ def test_finished_state_allows_ui_controls(self):
+ """Test that finished state allows UI controls but prioritizes quit."""
+ self.collector.finished = True
+ self.collector.running = True
+
+ # Try pressing 's' (sort) - should work and trigger display update
+ original_sort = self.collector.sort_by
+ self.display.simulate_input(ord("s"))
+ self.collector._handle_input()
+ self.assertTrue(self.collector.running) # Still running
+ self.assertNotEqual(self.collector.sort_by, original_sort) # Sort changed
+
+ # Try pressing 'p' (pause) - should work
+ self.display.simulate_input(ord("p"))
+ self.collector._handle_input()
+ self.assertTrue(self.collector.running) # Still running
+ self.assertTrue(self.collector.paused) # Now paused
+
+ # Try pressing 'r' (reset) - should be ignored when finished
+ self.collector.total_samples = 100
+ self.display.simulate_input(ord("r"))
+ self.collector._handle_input()
+ self.assertTrue(self.collector.running) # Still running
+ self.assertEqual(self.collector.total_samples, 100) # NOT reset when finished
+
+ # Press 'q' - should stop
+ self.display.simulate_input(ord("q"))
+ self.collector._handle_input()
+ self.assertFalse(self.collector.running) # Stopped
+
+ def test_finished_state_footer_message(self):
+ """Test that footer shows appropriate message when finished."""
+ # Add some sample data
+ thread_info = MockThreadInfo(
+ 123,
+ [
+ MockFrameInfo("test.py", 10, "work"),
+ MockFrameInfo("test.py", 20, "main"),
+ ],
+ )
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+ self.collector.collect(stack_frames)
+
+ # Mark as finished
+ self.collector.mark_finished()
+
+ # Check that footer contains finished message
+ self.assertTrue(self.display.contains_text("PROFILING FINISHED"))
+
+ def test_finished_state_freezes_time(self):
+ """Test that time displays are frozen when finished."""
+ import time as time_module
+
+ # Set up collector with known start time
+ self.collector.start_time = time_module.perf_counter() - 10.0 # 10 seconds ago
+
+ # Mark as finished - this should freeze the time
+ self.collector.mark_finished()
+
+ # Get the frozen elapsed time
+ frozen_elapsed = self.collector.elapsed_time
+ frozen_time_display = self.collector.current_time_display
+
+ # Wait a bit to ensure time would advance
+ time_module.sleep(0.1)
+
+ # Time should remain frozen
+ self.assertEqual(self.collector.elapsed_time, frozen_elapsed)
+ self.assertEqual(self.collector.current_time_display, frozen_time_display)
+
+ # Verify finish timestamp was set
+ self.assertIsNotNone(self.collector.finish_timestamp)
+
+ # Reset should clear the frozen state
+ self.collector.reset_stats()
+ self.assertFalse(self.collector.finished)
+ self.assertIsNone(self.collector.finish_timestamp)
+
+
+class TestLiveCollectorFiltering(unittest.TestCase):
+ """Tests for filtering functionality."""
+
+ def setUp(self):
+ """Set up collector with test data."""
+ self.display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.display
+ )
+ self.collector.start_time = time.perf_counter()
+ self.collector.total_samples = 100
+
+ # Add test data
+ self.collector.result[("app/models.py", 10, "save")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("app/views.py", 20, "render")] = {
+ "direct_calls": 30,
+ "cumulative_calls": 40,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("lib/utils.py", 30, "helper")] = {
+ "direct_calls": 20,
+ "cumulative_calls": 25,
+ "total_rec_calls": 0,
+ }
+
+ def test_filter_by_filename(self):
+ """Test filtering by filename pattern."""
+ self.collector.filter_pattern = "models"
+
+ stats_list = self.collector._build_stats_list()
+
+ # Only models.py should be included
+ self.assertEqual(len(stats_list), 1)
+ self.assertIn("models.py", stats_list[0]["func"][0])
+
+ def test_filter_by_function_name(self):
+ """Test filtering by function name."""
+ self.collector.filter_pattern = "render"
+
+ stats_list = self.collector._build_stats_list()
+
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "render")
+
+ def test_filter_case_insensitive(self):
+ """Test that filtering is case-insensitive."""
+ self.collector.filter_pattern = "MODELS"
+
+ stats_list = self.collector._build_stats_list()
+
+ # Should still match models.py
+ self.assertEqual(len(stats_list), 1)
+
+ def test_filter_substring_matching(self):
+ """Test substring filtering."""
+ self.collector.filter_pattern = "app/"
+
+ stats_list = self.collector._build_stats_list()
+
+ # Should match both app files
+ self.assertEqual(len(stats_list), 2)
+
+ def test_no_filter(self):
+ """Test with no filter applied."""
+ self.collector.filter_pattern = None
+
+ stats_list = self.collector._build_stats_list()
+
+ # All items should be included
+ self.assertEqual(len(stats_list), 3)
+
+ def test_filter_partial_function_name(self):
+ """Test filtering by partial function name."""
+ self.collector.filter_pattern = "save"
+
+ stats_list = self.collector._build_stats_list()
+
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "save")
+
+ def test_filter_combined_filename_funcname(self):
+ """Test filtering matches filename:funcname pattern."""
+ self.collector.filter_pattern = "views.py:render"
+
+ stats_list = self.collector._build_stats_list()
+
+ # Should match the combined pattern
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "render")
+
+ def test_filter_no_matches(self):
+ """Test filter that matches nothing."""
+ self.collector.filter_pattern = "nonexistent"
+
+ stats_list = self.collector._build_stats_list()
+
+ self.assertEqual(len(stats_list), 0)
+
+
+class TestLiveCollectorFilterInput(unittest.TestCase):
+ """Tests for filter input mode."""
+
+ def setUp(self):
+ """Set up collector with mock display."""
+ self.display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.display
+ )
+ self.collector.start_time = time.perf_counter()
+
+ def test_enter_filter_mode(self):
+ """Test entering filter input mode."""
+ self.assertFalse(self.collector.filter_input_mode)
+
+ # Press '/' to enter filter mode
+ self.display.simulate_input(ord("/"))
+ self.collector._handle_input()
+
+ self.assertTrue(self.collector.filter_input_mode)
+
+ def test_filter_input_typing(self):
+ """Test typing characters in filter input mode."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = ""
+
+ # Type 't', 'e', 's', 't'
+ for ch in "test":
+ self.display.simulate_input(ord(ch))
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.filter_input_buffer, "test")
+
+ def test_filter_input_backspace(self):
+ """Test backspace in filter input mode."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "test"
+
+ # Press backspace (127)
+ self.display.simulate_input(127)
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.filter_input_buffer, "tes")
+
+ def test_filter_input_backspace_alt(self):
+ """Test alternative backspace key (263) in filter input mode."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "test"
+
+ # Press backspace (263)
+ self.display.simulate_input(263)
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.filter_input_buffer, "tes")
+
+ def test_filter_input_backspace_empty(self):
+ """Test backspace on empty buffer."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = ""
+
+ # Press backspace - should not crash
+ self.display.simulate_input(127)
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.filter_input_buffer, "")
+
+ def test_filter_input_enter_applies_filter(self):
+ """Test pressing Enter applies the filter."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "myfilter"
+
+ # Press Enter (10)
+ self.display.simulate_input(10)
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.filter_input_mode)
+ self.assertEqual(self.collector.filter_pattern, "myfilter")
+ self.assertEqual(self.collector.filter_input_buffer, "")
+
+ def test_filter_input_enter_alt(self):
+ """Test alternative Enter key (13) applies filter."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "myfilter"
+
+ # Press Enter (13)
+ self.display.simulate_input(13)
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.filter_input_mode)
+ self.assertEqual(self.collector.filter_pattern, "myfilter")
+
+ def test_filter_input_enter_empty_clears_filter(self):
+ """Test pressing Enter with empty buffer clears filter."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = ""
+ self.collector.filter_pattern = "oldfilter"
+
+ # Press Enter
+ self.display.simulate_input(10)
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.filter_input_mode)
+ self.assertIsNone(self.collector.filter_pattern)
+
+ def test_filter_input_escape_cancels(self):
+ """Test pressing ESC cancels filter input."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "newfilter"
+ self.collector.filter_pattern = "oldfilter"
+
+ # Press ESC (27)
+ self.display.simulate_input(27)
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.filter_input_mode)
+ self.assertEqual(
+ self.collector.filter_pattern, "oldfilter"
+ ) # Unchanged
+ self.assertEqual(self.collector.filter_input_buffer, "")
+
+ def test_filter_input_start_with_existing_filter(self):
+ """Test entering filter mode with existing filter pre-fills buffer."""
+ self.collector.filter_pattern = "existing"
+
+ # Enter filter mode
+ self.display.simulate_input(ord("/"))
+ self.collector._handle_input()
+
+ # Buffer should be pre-filled with existing pattern
+ self.assertEqual(self.collector.filter_input_buffer, "existing")
+
+ def test_filter_input_start_without_filter(self):
+ """Test entering filter mode with no existing filter."""
+ self.collector.filter_pattern = None
+
+ # Enter filter mode
+ self.display.simulate_input(ord("/"))
+ self.collector._handle_input()
+
+ # Buffer should be empty
+ self.assertEqual(self.collector.filter_input_buffer, "")
+
+ def test_filter_input_mode_blocks_other_commands(self):
+ """Test that filter input mode blocks other commands."""
+ self.collector.filter_input_mode = True
+ initial_sort = self.collector.sort_by
+
+ # Try to press 's' (sort) - should be captured as input
+ self.display.simulate_input(ord("s"))
+ self.collector._handle_input()
+
+ # Sort should not change, 's' should be in buffer
+ self.assertEqual(self.collector.sort_by, initial_sort)
+ self.assertEqual(self.collector.filter_input_buffer, "s")
+
+ def test_filter_input_non_printable_ignored(self):
+ """Test that non-printable characters are ignored."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "test"
+
+ # Try to input a control character (< 32)
+ self.display.simulate_input(1) # Ctrl-A
+ self.collector._handle_input()
+
+ # Buffer should be unchanged
+ self.assertEqual(self.collector.filter_input_buffer, "test")
+
+ def test_filter_input_high_ascii_ignored(self):
+ """Test that high ASCII characters (>= 127, except backspace) are ignored."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "test"
+
+ # Try to input high ASCII (128)
+ self.display.simulate_input(128)
+ self.collector._handle_input()
+
+ # Buffer should be unchanged
+ self.assertEqual(self.collector.filter_input_buffer, "test")
+
+ def test_filter_prompt_displayed(self):
+ """Test that filter prompt is displayed when in input mode."""
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = "myfilter"
+ self.collector.total_samples = 10
+
+ self.collector._update_display()
+
+ # Should show the filter prompt
+ self.assertTrue(self.display.contains_text("Function filter"))
+ self.assertTrue(self.display.contains_text("myfilter"))
+
+
+if __name__ == "__main__":
+ unittest.main()
+
+
+class TestLiveCollectorThreadNavigation(unittest.TestCase):
+ """Tests for thread navigation functionality."""
+
+ def setUp(self):
+ """Set up collector with mock display and multiple threads."""
+ self.mock_display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.mock_display
+ )
+ self.collector.start_time = time.perf_counter()
+
+ # Simulate data from multiple threads
+ frames1 = [MockFrameInfo("file1.py", 10, "func1")]
+ frames2 = [MockFrameInfo("file2.py", 20, "func2")]
+ frames3 = [MockFrameInfo("file3.py", 30, "func3")]
+
+ thread1 = MockThreadInfo(111, frames1)
+ thread2 = MockThreadInfo(222, frames2)
+ thread3 = MockThreadInfo(333, frames3)
+
+ interpreter_info = MockInterpreterInfo(0, [thread1, thread2, thread3])
+ stack_frames = [interpreter_info]
+
+ # Collect data to populate thread IDs
+ self.collector.collect(stack_frames)
+
+ def test_initial_view_mode_is_all(self):
+ """Test that collector starts in ALL mode."""
+ self.assertEqual(self.collector.view_mode, "ALL")
+ self.assertEqual(self.collector.current_thread_index, 0)
+
+ def test_thread_ids_are_tracked(self):
+ """Test that thread IDs are tracked during collection."""
+ self.assertIn(111, self.collector.thread_ids)
+ self.assertIn(222, self.collector.thread_ids)
+ self.assertIn(333, self.collector.thread_ids)
+ self.assertEqual(len(self.collector.thread_ids), 3)
+
+ def test_toggle_to_per_thread_mode(self):
+ """Test toggling from ALL to PER_THREAD mode with 't' key."""
+ self.assertEqual(self.collector.view_mode, "ALL")
+
+ self.mock_display.simulate_input(ord("t"))
+ self.collector._handle_input()
+
+ self.assertEqual(self.collector.view_mode, "PER_THREAD")
+ self.assertEqual(self.collector.current_thread_index, 0)
+
+ def test_toggle_back_to_all_mode(self):
+ """Test toggling back from PER_THREAD to ALL mode."""
+ # Switch to PER_THREAD
+ self.mock_display.simulate_input(ord("t"))
+ self.collector._handle_input()
+ self.assertEqual(self.collector.view_mode, "PER_THREAD")
+
+ # Switch back to ALL
+ self.mock_display.simulate_input(ord("T"))
+ self.collector._handle_input()
+ self.assertEqual(self.collector.view_mode, "ALL")
+
+ def test_arrow_right_navigates_threads_in_per_thread_mode(self):
+ """Test that arrow keys navigate threads in PER_THREAD mode."""
+ # Switch to PER_THREAD mode
+ self.mock_display.simulate_input(ord("t"))
+ self.collector._handle_input()
+
+ # Navigate forward
+ self.assertEqual(self.collector.current_thread_index, 0)
+
+ self.mock_display.simulate_input(curses.KEY_RIGHT)
+ self.collector._handle_input()
+ self.assertEqual(self.collector.current_thread_index, 1)
+
+ self.mock_display.simulate_input(curses.KEY_RIGHT)
+ self.collector._handle_input()
+ self.assertEqual(self.collector.current_thread_index, 2)
+
+ def test_arrow_left_navigates_threads_backward(self):
+ """Test that left arrow navigates threads backward."""
+ # Switch to PER_THREAD mode
+ self.mock_display.simulate_input(ord("t"))
+ self.collector._handle_input()
+
+ # Navigate backward (should wrap around)
+ self.mock_display.simulate_input(curses.KEY_LEFT)
+ self.collector._handle_input()
+ self.assertEqual(
+ self.collector.current_thread_index, 2
+ ) # Wrapped to last
+
+ self.mock_display.simulate_input(curses.KEY_LEFT)
+ self.collector._handle_input()
+ self.assertEqual(self.collector.current_thread_index, 1)
+
+ def test_arrow_down_navigates_like_right(self):
+ """Test that down arrow works like right arrow."""
+ # Switch to PER_THREAD mode
+ self.mock_display.simulate_input(ord("t"))
+ self.collector._handle_input()
+
+ self.mock_display.simulate_input(curses.KEY_DOWN)
+ self.collector._handle_input()
+ self.assertEqual(self.collector.current_thread_index, 1)
+
+ def test_arrow_up_navigates_like_left(self):
+ """Test that up arrow works like left arrow."""
+ # Switch to PER_THREAD mode
+ self.mock_display.simulate_input(ord("t"))
+ self.collector._handle_input()
+
+ self.mock_display.simulate_input(curses.KEY_UP)
+ self.collector._handle_input()
+ self.assertEqual(self.collector.current_thread_index, 2) # Wrapped
+
+ def test_arrow_keys_switch_to_per_thread_mode(self):
+ """Test that arrow keys switch from ALL mode to PER_THREAD mode."""
+ self.assertEqual(self.collector.view_mode, "ALL")
+
+ self.mock_display.simulate_input(curses.KEY_RIGHT)
+ self.collector._handle_input()
+ self.assertEqual(self.collector.view_mode, "PER_THREAD")
+ self.assertEqual(self.collector.current_thread_index, 0)
+
+ def test_stats_list_in_all_mode(self):
+ """Test that stats list uses aggregated data in ALL mode."""
+ stats_list = self.collector._build_stats_list()
+
+ # Should have all 3 functions
+ self.assertEqual(len(stats_list), 3)
+ func_names = {stat["func"][2] for stat in stats_list}
+ self.assertEqual(func_names, {"func1", "func2", "func3"})
+
+ def test_stats_list_in_per_thread_mode(self):
+ """Test that stats list filters by thread in PER_THREAD mode."""
+ # Switch to PER_THREAD mode
+ self.collector.view_mode = "PER_THREAD"
+ self.collector.current_thread_index = 0 # First thread (111)
+
+ stats_list = self.collector._build_stats_list()
+
+ # Should only have func1 from thread 111
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "func1")
+
+ def test_stats_list_switches_with_thread_navigation(self):
+ """Test that stats list updates when navigating threads."""
+ self.collector.view_mode = "PER_THREAD"
+
+ # Thread 0 (111) -> func1
+ self.collector.current_thread_index = 0
+ stats_list = self.collector._build_stats_list()
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "func1")
+
+ # Thread 1 (222) -> func2
+ self.collector.current_thread_index = 1
+ stats_list = self.collector._build_stats_list()
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "func2")
+
+ # Thread 2 (333) -> func3
+ self.collector.current_thread_index = 2
+ stats_list = self.collector._build_stats_list()
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], "func3")
+
+ def test_reset_stats_clears_thread_data(self):
+ """Test that reset_stats clears thread tracking data."""
+ self.assertGreater(len(self.collector.thread_ids), 0)
+ self.assertGreater(len(self.collector.per_thread_data), 0)
+
+ self.collector.reset_stats()
+
+ self.assertEqual(len(self.collector.thread_ids), 0)
+ self.assertEqual(len(self.collector.per_thread_data), 0)
+ self.assertEqual(self.collector.view_mode, "ALL")
+ self.assertEqual(self.collector.current_thread_index, 0)
+
+ def test_toggle_with_no_threads_stays_in_all_mode(self):
+ """Test that toggle does nothing when no threads exist."""
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ self.assertEqual(len(collector.thread_ids), 0)
+
+ collector.display.simulate_input(ord("t"))
+ collector._handle_input()
+
+ # Should remain in ALL mode since no threads
+ self.assertEqual(collector.view_mode, "ALL")
+
+ def test_per_thread_data_isolation(self):
+ """Test that per-thread data is properly isolated."""
+ # Check that each thread has its own isolated data
+ self.assertIn(111, self.collector.per_thread_data)
+ self.assertIn(222, self.collector.per_thread_data)
+ self.assertIn(333, self.collector.per_thread_data)
+
+ # Thread 111 should only have func1
+ thread1_funcs = list(self.collector.per_thread_data[111].result.keys())
+ self.assertEqual(len(thread1_funcs), 1)
+ self.assertEqual(thread1_funcs[0][2], "func1")
+
+ # Thread 222 should only have func2
+ thread2_funcs = list(self.collector.per_thread_data[222].result.keys())
+ self.assertEqual(len(thread2_funcs), 1)
+ self.assertEqual(thread2_funcs[0][2], "func2")
+
+ def test_aggregated_data_sums_all_threads(self):
+ """Test that ALL mode shows aggregated data from all threads."""
+ # All three functions should be in the aggregated result
+ self.assertEqual(len(self.collector.result), 3)
+
+ # Each function should have 1 direct call
+ for func_location, counts in self.collector.result.items():
+ self.assertEqual(counts["direct_calls"], 1)
+
+ def test_per_thread_status_tracking(self):
+ """Test that per-thread status statistics are tracked."""
+ # Each thread should have status counts
+ self.assertIn(111, self.collector.per_thread_data)
+ self.assertIn(222, self.collector.per_thread_data)
+ self.assertIn(333, self.collector.per_thread_data)
+
+ # Each thread should have the expected attributes
+ for thread_id in [111, 222, 333]:
+ thread_data = self.collector.per_thread_data[thread_id]
+ self.assertIsNotNone(thread_data.has_gil)
+ self.assertIsNotNone(thread_data.on_cpu)
+ self.assertIsNotNone(thread_data.gil_requested)
+ self.assertIsNotNone(thread_data.unknown)
+ self.assertIsNotNone(thread_data.total)
+ # Each thread was sampled once
+ self.assertEqual(thread_data.total, 1)
+
+ def test_reset_stats_clears_thread_status(self):
+ """Test that reset_stats clears per-thread status data."""
+ self.assertGreater(len(self.collector.per_thread_data), 0)
+
+ self.collector.reset_stats()
+
+ self.assertEqual(len(self.collector.per_thread_data), 0)
+
+ def test_per_thread_sample_counts(self):
+ """Test that per-thread sample counts are tracked correctly."""
+ # Each thread should have exactly 1 sample (we collected once)
+ for thread_id in [111, 222, 333]:
+ self.assertIn(thread_id, self.collector.per_thread_data)
+ self.assertEqual(self.collector.per_thread_data[thread_id].sample_count, 1)
+
+ def test_per_thread_gc_samples(self):
+ """Test that per-thread GC samples are tracked correctly."""
+ # Initially no threads have GC frames
+ for thread_id in [111, 222, 333]:
+ self.assertIn(thread_id, self.collector.per_thread_data)
+ self.assertEqual(
+ self.collector.per_thread_data[thread_id].gc_frame_samples, 0
+ )
+
+ # Now collect a sample with a GC frame in thread 222
+ gc_frames = [MockFrameInfo("gc.py", 100, "gc_collect")]
+ thread_with_gc = MockThreadInfo(222, gc_frames)
+ interpreter_info = MockInterpreterInfo(0, [thread_with_gc])
+ stack_frames = [interpreter_info]
+
+ self.collector.collect(stack_frames)
+
+ # Thread 222 should now have 1 GC sample
+ self.assertEqual(self.collector.per_thread_data[222].gc_frame_samples, 1)
+ # Other threads should still have 0
+ self.assertEqual(self.collector.per_thread_data[111].gc_frame_samples, 0)
+ self.assertEqual(self.collector.per_thread_data[333].gc_frame_samples, 0)
+
+ def test_only_threads_with_frames_are_tracked(self):
+ """Test that only threads with actual frame data are added to thread_ids."""
+ # Create a new collector
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+
+ # Create threads: one with frames, one without
+ frames = [MockFrameInfo("test.py", 10, "test_func")]
+ thread_with_frames = MockThreadInfo(111, frames)
+ thread_without_frames = MockThreadInfo(222, None) # No frames
+ interpreter_info = MockInterpreterInfo(
+ 0, [thread_with_frames, thread_without_frames]
+ )
+ stack_frames = [interpreter_info]
+
+ collector.collect(stack_frames)
+
+ # Only thread 111 should be tracked (it has frames)
+ self.assertIn(111, collector.thread_ids)
+ self.assertNotIn(222, collector.thread_ids)
+
+ def test_per_thread_status_isolation(self):
+ """Test that per-thread status counts are isolated per thread."""
+ # Create threads with different status flags
+
+ frames1 = [MockFrameInfo("file1.py", 10, "func1")]
+ frames2 = [MockFrameInfo("file2.py", 20, "func2")]
+
+ # Thread 444: has GIL but not on CPU
+ thread1 = MockThreadInfo(444, frames1, status=THREAD_STATUS_HAS_GIL)
+ # Thread 555: on CPU but not has GIL
+ thread2 = MockThreadInfo(555, frames2, status=THREAD_STATUS_ON_CPU)
+
+ interpreter_info = MockInterpreterInfo(0, [thread1, thread2])
+ stack_frames = [interpreter_info]
+
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ collector.collect(stack_frames)
+
+ # Check thread 444 status
+ self.assertEqual(collector.per_thread_data[444].has_gil, 1)
+ self.assertEqual(collector.per_thread_data[444].on_cpu, 0)
+
+ # Check thread 555 status
+ self.assertEqual(collector.per_thread_data[555].has_gil, 0)
+ self.assertEqual(collector.per_thread_data[555].on_cpu, 1)
+
+ def test_display_uses_per_thread_stats_in_per_thread_mode(self):
+ """Test that display widget uses per-thread stats when in PER_THREAD mode."""
+
+ # Create collector with mock display
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ collector.start_time = time.perf_counter()
+
+ # Create 2 threads with different characteristics
+ # Thread 111: always has GIL (10 samples)
+ # Thread 222: never has GIL (10 samples)
+ for _ in range(10):
+ frames1 = [MockFrameInfo("file1.py", 10, "func1")]
+ frames2 = [MockFrameInfo("file2.py", 20, "func2")]
+ thread1 = MockThreadInfo(
+ 111, frames1, status=THREAD_STATUS_HAS_GIL
+ )
+ thread2 = MockThreadInfo(222, frames2, status=0) # No flags
+ interpreter_info = MockInterpreterInfo(0, [thread1, thread2])
+ collector.collect([interpreter_info])
+
+ # In ALL mode, should show mixed stats (50% on GIL, 50% off GIL)
+ self.assertEqual(collector.view_mode, "ALL")
+ total_has_gil = collector._thread_status_counts["has_gil"]
+ total_threads = collector._thread_status_counts["total"]
+ self.assertEqual(total_has_gil, 10) # Only thread 111 has GIL
+ self.assertEqual(total_threads, 20) # 10 samples * 2 threads
+
+ # Switch to PER_THREAD mode and select thread 111
+ collector.view_mode = "PER_THREAD"
+ collector.current_thread_index = 0 # Thread 111
+
+ # Thread 111 should show 100% on GIL
+ thread_111_data = collector.per_thread_data[111]
+ self.assertEqual(thread_111_data.has_gil, 10)
+ self.assertEqual(thread_111_data.total, 10)
+
+ # Switch to thread 222
+ collector.current_thread_index = 1 # Thread 222
+
+ # Thread 222 should show 0% on GIL
+ thread_222_data = collector.per_thread_data[222]
+ self.assertEqual(thread_222_data.has_gil, 0)
+ self.assertEqual(thread_222_data.total, 10)
+
+ def test_display_uses_per_thread_gc_stats_in_per_thread_mode(self):
+ """Test that GC percentage uses per-thread data in PER_THREAD mode."""
+ # Create collector with mock display
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ collector.start_time = time.perf_counter()
+
+ # Thread 111: 5 samples, 2 with GC
+ # Thread 222: 5 samples, 0 with GC
+ for i in range(5):
+ if i < 2:
+ # First 2 samples for thread 111 have GC
+ frames1 = [MockFrameInfo("gc.py", 100, "gc_collect")]
+ else:
+ frames1 = [MockFrameInfo("file1.py", 10, "func1")]
+
+ frames2 = [MockFrameInfo("file2.py", 20, "func2")] # No GC
+
+ thread1 = MockThreadInfo(111, frames1)
+ thread2 = MockThreadInfo(222, frames2)
+ interpreter_info = MockInterpreterInfo(0, [thread1, thread2])
+ collector.collect([interpreter_info])
+
+ # Check aggregated GC stats (ALL mode)
+ # 2 GC samples out of 10 total = 20%
+ self.assertEqual(collector._gc_frame_samples, 2)
+ self.assertEqual(collector.total_samples, 5) # 5 collect() calls
+
+ # Check per-thread GC stats
+ # Thread 111: 2 GC samples out of 5 = 40%
+ self.assertEqual(collector.per_thread_data[111].gc_frame_samples, 2)
+ self.assertEqual(collector.per_thread_data[111].sample_count, 5)
+
+ # Thread 222: 0 GC samples out of 5 = 0%
+ self.assertEqual(collector.per_thread_data[222].gc_frame_samples, 0)
+ self.assertEqual(collector.per_thread_data[222].sample_count, 5)
+
+ # Now verify the display would use the correct stats
+ collector.view_mode = "PER_THREAD"
+
+ # For thread 111
+ collector.current_thread_index = 0
+ thread_id = collector.thread_ids[0]
+ self.assertEqual(thread_id, 111)
+ thread_gc_pct = (
+ collector.per_thread_data[111].gc_frame_samples
+ / collector.per_thread_data[111].sample_count
+ ) * 100
+ self.assertEqual(thread_gc_pct, 40.0)
+
+ # For thread 222
+ collector.current_thread_index = 1
+ thread_id = collector.thread_ids[1]
+ self.assertEqual(thread_id, 222)
+ thread_gc_pct = (
+ collector.per_thread_data[222].gc_frame_samples
+ / collector.per_thread_data[222].sample_count
+ ) * 100
+ self.assertEqual(thread_gc_pct, 0.0)
+
+ def test_function_counts_are_per_thread_in_per_thread_mode(self):
+ """Test that function counts (total/exec/stack) are per-thread in PER_THREAD mode."""
+ # Create collector with mock display
+ collector = LiveStatsCollector(1000, display=MockDisplay())
+ collector.start_time = time.perf_counter()
+
+ # Thread 111: calls func1, func2, func3 (3 functions)
+ # Thread 222: calls func4, func5 (2 functions)
+ frames1 = [
+ MockFrameInfo("file1.py", 10, "func1"),
+ MockFrameInfo("file1.py", 20, "func2"),
+ MockFrameInfo("file1.py", 30, "func3"),
+ ]
+ frames2 = [
+ MockFrameInfo("file2.py", 40, "func4"),
+ MockFrameInfo("file2.py", 50, "func5"),
+ ]
+
+ thread1 = MockThreadInfo(111, frames1)
+ thread2 = MockThreadInfo(222, frames2)
+ interpreter_info = MockInterpreterInfo(0, [thread1, thread2])
+ collector.collect([interpreter_info])
+
+ # In ALL mode, should have 5 total functions
+ self.assertEqual(len(collector.result), 5)
+
+ # In PER_THREAD mode for thread 111, should have 3 functions
+ collector.view_mode = "PER_THREAD"
+ collector.current_thread_index = 0 # Thread 111
+ thread_111_result = collector.per_thread_data[111].result
+ self.assertEqual(len(thread_111_result), 3)
+
+ # Verify the functions are the right ones
+ thread_111_funcs = {loc[2] for loc in thread_111_result.keys()}
+ self.assertEqual(thread_111_funcs, {"func1", "func2", "func3"})
+
+ # In PER_THREAD mode for thread 222, should have 2 functions
+ collector.current_thread_index = 1 # Thread 222
+ thread_222_result = collector.per_thread_data[222].result
+ self.assertEqual(len(thread_222_result), 2)
+
+ # Verify the functions are the right ones
+ thread_222_funcs = {loc[2] for loc in thread_222_result.keys()}
+ self.assertEqual(thread_222_funcs, {"func4", "func5"})
+
+
+class TestLiveCollectorNewFeatures(unittest.TestCase):
+ """Tests for new features added to live collector."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.display = MockDisplay()
+ self.collector = LiveStatsCollector(1000, display=self.display)
+ self.collector.start_time = time.perf_counter()
+
+ def test_filter_input_takes_precedence_over_commands(self):
+ """Test that filter input mode blocks command keys like 'h' and 'p'."""
+ # Enter filter input mode
+ self.collector.filter_input_mode = True
+ self.collector.filter_input_buffer = ""
+
+ # Press 'h' - should add to filter buffer, not show help
+ self.display.simulate_input(ord("h"))
+ self.collector._handle_input()
+
+ self.assertFalse(self.collector.show_help) # Help not triggered
+ self.assertEqual(self.collector.filter_input_buffer, "h") # Added to filter
+ self.assertTrue(self.collector.filter_input_mode) # Still in filter mode
+
+ def test_reset_blocked_when_finished(self):
+ """Test that reset command is blocked when profiling is finished."""
+ # Set up some sample data and mark as finished
+ self.collector.total_samples = 100
+ self.collector.finished = True
+
+ # Press 'r' for reset
+ self.display.simulate_input(ord("r"))
+ self.collector._handle_input()
+
+ # Should NOT have been reset
+ self.assertEqual(self.collector.total_samples, 100)
+ self.assertTrue(self.collector.finished)
+
+ def test_time_display_fix_when_finished(self):
+ """Test that time display shows correct frozen time when finished."""
+ import time as time_module
+
+ # Mark as finished to freeze time
+ self.collector.mark_finished()
+
+ # Should have set both timestamps correctly
+ self.assertIsNotNone(self.collector.finish_timestamp)
+ self.assertIsNotNone(self.collector.finish_wall_time)
+
+ # Get the frozen time display
+ frozen_time = self.collector.current_time_display
+
+ # Wait a bit
+ time_module.sleep(0.1)
+
+ # Should still show the same frozen time (not jump to wrong time)
+ self.assertEqual(self.collector.current_time_display, frozen_time)
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+"""UI and display tests for LiveStatsCollector.
+
+Tests for MockDisplay, curses integration, display methods,
+edge cases, update display, and display helpers.
+"""
+
+import sys
+import time
+import unittest
+from unittest import mock
+from test.support import requires
+from test.support.import_helper import import_module
+
+# Only run these tests if curses is available
+requires("curses")
+curses = import_module("curses")
+
+from profiling.sampling.live_collector import LiveStatsCollector, MockDisplay
+from ._live_collector_helpers import (
+ MockThreadInfo,
+ MockInterpreterInfo,
+)
+
+
+class TestLiveStatsCollectorWithMockDisplay(unittest.TestCase):
+ """Tests for display functionality using MockDisplay."""
+
+ def setUp(self):
+ """Set up collector with mock display."""
+ self.mock_display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.mock_display
+ )
+ self.collector.start_time = time.perf_counter()
+
+ def test_update_display_with_mock(self):
+ """Test that update_display works with MockDisplay."""
+ self.collector.total_samples = 100
+ self.collector.result[("test.py", 10, "test_func")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+
+ self.collector._update_display()
+
+ # Verify display operations were called
+ self.assertTrue(self.mock_display.cleared)
+ self.assertTrue(self.mock_display.refreshed)
+ self.assertTrue(self.mock_display.redrawn)
+
+ # Verify some content was written
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ def test_handle_input_quit(self):
+ """Test that 'q' input stops the collector."""
+ self.mock_display.simulate_input(ord("q"))
+ self.collector._handle_input()
+ self.assertFalse(self.collector.running)
+
+ def test_handle_input_sort_cycle(self):
+ """Test that 's' input cycles sort mode."""
+ self.collector.sort_by = "tottime"
+ self.mock_display.simulate_input(ord("s"))
+ self.collector._handle_input()
+ self.assertEqual(self.collector.sort_by, "cumul_pct")
+
+ def test_draw_methods_with_mock_display(self):
+ """Test that draw methods write to mock display."""
+ self.collector.total_samples = 500
+ self.collector._successful_samples = 450
+ self.collector._failed_samples = 50
+
+ colors = self.collector._setup_colors()
+ self.collector._initialize_widgets(colors)
+
+ # Test individual widget methods
+ line = self.collector._header_widget.draw_header_info(0, 160, 100.5)
+ self.assertEqual(line, 2) # Title + header info line
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ # Clear buffer and test next method
+ self.mock_display.buffer.clear()
+ line = self.collector._header_widget.draw_sample_stats(0, 160, 10.0)
+ self.assertEqual(line, 1)
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ def test_terminal_too_small_message(self):
+ """Test terminal too small warning."""
+ small_display = MockDisplay(height=10, width=50)
+ self.collector.display = small_display
+
+ self.collector._show_terminal_too_small(10, 50)
+
+ # Should have written warning message
+ text = small_display.get_text_at(3, 15) # Approximate center
+ self.assertIsNotNone(text)
+
+ def test_full_display_rendering_with_data(self):
+ """Test complete display rendering with realistic data."""
+ # Add multiple functions with different call counts
+ self.collector.total_samples = 1000
+ self.collector._successful_samples = 950
+ self.collector._failed_samples = 50
+
+ self.collector.result[("app.py", 10, "main")] = {
+ "direct_calls": 100,
+ "cumulative_calls": 500,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("utils.py", 20, "helper")] = {
+ "direct_calls": 300,
+ "cumulative_calls": 400,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("db.py", 30, "query")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 100,
+ "total_rec_calls": 0,
+ }
+
+ self.collector._update_display()
+
+ # Verify the display has content
+ self.assertGreater(len(self.mock_display.buffer), 10)
+
+ # Verify PID is shown
+ found_pid = False
+ for (line, col), (text, attr) in self.mock_display.buffer.items():
+ if "12345" in text:
+ found_pid = True
+ break
+ self.assertTrue(found_pid, "PID should be displayed")
+
+ def test_efficiency_bar_visualization(self):
+ """Test that efficiency bar shows correct proportions."""
+ self.collector.total_samples = 100
+ self.collector._successful_samples = 75
+ self.collector._failed_samples = 25
+
+ colors = self.collector._setup_colors()
+ self.collector._initialize_widgets(colors)
+ self.collector._header_widget.draw_efficiency_bar(0, 160)
+
+ # Check that something was drawn to the display
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ def test_stats_display_with_different_sort_modes(self):
+ """Test that stats are displayed correctly with different sort modes."""
+ self.collector.total_samples = 100
+ self.collector.result[("a.py", 1, "func_a")] = {
+ "direct_calls": 10,
+ "cumulative_calls": 20,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("b.py", 2, "func_b")] = {
+ "direct_calls": 30,
+ "cumulative_calls": 40,
+ "total_rec_calls": 0,
+ }
+
+ # Test each sort mode
+ for sort_mode in [
+ "nsamples",
+ "tottime",
+ "cumtime",
+ "sample_pct",
+ "cumul_pct",
+ ]:
+ self.mock_display.buffer.clear()
+ self.collector.sort_by = sort_mode
+
+ stats_list = self.collector._build_stats_list()
+ self.assertEqual(len(stats_list), 2)
+
+ # Verify sorting worked (func_b should be first for most modes)
+ if sort_mode in ["nsamples", "tottime", "sample_pct"]:
+ self.assertEqual(stats_list[0]["func"][2], "func_b")
+
+ def test_narrow_terminal_column_hiding(self):
+ """Test that columns are hidden on narrow terminals."""
+ narrow_display = MockDisplay(height=40, width=70)
+ collector = LiveStatsCollector(1000, pid=12345, display=narrow_display)
+ collector.start_time = time.perf_counter()
+
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ line, show_sample_pct, show_tottime, show_cumul_pct, show_cumtime = (
+ collector._table_widget.draw_column_headers(0, 70)
+ )
+
+ # On narrow terminal, some columns should be hidden
+ self.assertFalse(
+ show_cumul_pct or show_cumtime,
+ "Some columns should be hidden on narrow terminal",
+ )
+
+ def test_very_narrow_terminal_minimal_columns(self):
+ """Test minimal display on very narrow terminal."""
+ very_narrow = MockDisplay(height=40, width=60)
+ collector = LiveStatsCollector(1000, pid=12345, display=very_narrow)
+ collector.start_time = time.perf_counter()
+
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+ line, show_sample_pct, show_tottime, show_cumul_pct, show_cumtime = (
+ collector._table_widget.draw_column_headers(0, 60)
+ )
+
+ # Very narrow should hide even more columns
+ self.assertFalse(
+ show_sample_pct,
+ "Sample % should be hidden on very narrow terminal",
+ )
+
+ def test_display_updates_only_at_interval(self):
+ """Test that display updates respect the update interval."""
+ # Create collector with display
+ collector = LiveStatsCollector(1000, display=self.mock_display)
+
+ # Simulate multiple rapid collections
+ thread_info = MockThreadInfo(123, [])
+ interpreter_info = MockInterpreterInfo(0, [thread_info])
+ stack_frames = [interpreter_info]
+
+ # First collect should update display
+ collector.collect(stack_frames)
+ first_cleared = self.mock_display.cleared
+
+ # Reset flags
+ self.mock_display.cleared = False
+ self.mock_display.refreshed = False
+
+ # Immediate second collect should NOT update display (too soon)
+ collector.collect(stack_frames)
+ self.assertFalse(
+ self.mock_display.cleared,
+ "Display should not update too frequently",
+ )
+
+ def test_top_functions_display(self):
+ """Test that top functions are highlighted correctly."""
+ self.collector.total_samples = 1000
+
+ # Create functions with different sample counts
+ for i in range(10):
+ self.collector.result[(f"file{i}.py", i * 10, f"func{i}")] = {
+ "direct_calls": (10 - i) * 10, # Decreasing counts
+ "cumulative_calls": (10 - i) * 20,
+ "total_rec_calls": 0,
+ }
+
+ colors = self.collector._setup_colors()
+ self.collector._initialize_widgets(colors)
+ stats_list = self.collector._build_stats_list()
+
+ self.collector._header_widget.draw_top_functions(0, 160, stats_list)
+
+ # Top functions section should have written something
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+
+class TestLiveStatsCollectorCursesIntegration(unittest.TestCase):
+ """Tests for curses-related functionality using mocks."""
+
+ def setUp(self):
+ """Set up mock curses screen."""
+ self.mock_stdscr = mock.MagicMock()
+ self.mock_stdscr.getmaxyx.return_value = (40, 160) # height, width
+ self.mock_stdscr.getch.return_value = -1 # No input
+ # Save original stdout/stderr
+ self._orig_stdout = sys.stdout
+ self._orig_stderr = sys.stderr
+
+ def tearDown(self):
+ """Restore stdout/stderr if changed."""
+ sys.stdout = self._orig_stdout
+ sys.stderr = self._orig_stderr
+
+ def test_init_curses(self):
+ """Test curses initialization."""
+ collector = LiveStatsCollector(1000)
+
+ with (
+ mock.patch("curses.curs_set"),
+ mock.patch("curses.has_colors", return_value=True),
+ mock.patch("curses.start_color"),
+ mock.patch("curses.use_default_colors"),
+ mock.patch("builtins.open", mock.mock_open()) as mock_open_func,
+ ):
+ collector.init_curses(self.mock_stdscr)
+
+ self.assertIsNotNone(collector.stdscr)
+ self.mock_stdscr.nodelay.assert_called_with(True)
+ self.mock_stdscr.scrollok.assert_called_with(False)
+
+ # Clean up properly
+ if collector._devnull:
+ collector._devnull.close()
+ collector._saved_stdout = None
+ collector._saved_stderr = None
+
+ def test_cleanup_curses(self):
+ """Test curses cleanup."""
+ mock_display = MockDisplay()
+ collector = LiveStatsCollector(1000, display=mock_display)
+ collector.stdscr = self.mock_stdscr
+
+ # Mock devnull file to avoid resource warnings
+ mock_devnull = mock.MagicMock()
+ mock_saved_stdout = mock.MagicMock()
+ mock_saved_stderr = mock.MagicMock()
+
+ collector._devnull = mock_devnull
+ collector._saved_stdout = mock_saved_stdout
+ collector._saved_stderr = mock_saved_stderr
+
+ with mock.patch("curses.curs_set"):
+ collector.cleanup_curses()
+
+ mock_devnull.close.assert_called_once()
+ # Verify stdout/stderr were set back to the saved values
+ self.assertEqual(sys.stdout, mock_saved_stdout)
+ self.assertEqual(sys.stderr, mock_saved_stderr)
+ # Verify the saved values were cleared
+ self.assertIsNone(collector._saved_stdout)
+ self.assertIsNone(collector._saved_stderr)
+ self.assertIsNone(collector._devnull)
+
+ def test_add_str_with_mock_display(self):
+ """Test safe_addstr with MockDisplay."""
+ mock_display = MockDisplay(height=40, width=160)
+ collector = LiveStatsCollector(1000, display=mock_display)
+ colors = collector._setup_colors()
+ collector._initialize_widgets(colors)
+
+ collector._header_widget.add_str(5, 10, "Test", 0)
+ # Verify it was added to the buffer
+ self.assertIn((5, 10), mock_display.buffer)
+
+ def test_setup_colors_with_color_support(self):
+ """Test color setup when colors are supported."""
+ mock_display = MockDisplay(height=40, width=160)
+ mock_display.colors_supported = True
+ collector = LiveStatsCollector(1000, display=mock_display)
+
+ colors = collector._setup_colors()
+
+ self.assertIn("header", colors)
+ self.assertIn("cyan", colors)
+ self.assertIn("yellow", colors)
+ self.assertIn("green", colors)
+ self.assertIn("magenta", colors)
+ self.assertIn("red", colors)
+
+ def test_setup_colors_without_color_support(self):
+ """Test color setup when colors are not supported."""
+ mock_display = MockDisplay(height=40, width=160)
+ mock_display.colors_supported = False
+ collector = LiveStatsCollector(1000, display=mock_display)
+
+ colors = collector._setup_colors()
+
+ # Should still have all keys but with fallback values
+ self.assertIn("header", colors)
+ self.assertIn("cyan", colors)
+
+ def test_handle_input_quit(self):
+ """Test handling 'q' key to quit."""
+ mock_display = MockDisplay()
+ mock_display.simulate_input(ord("q"))
+ collector = LiveStatsCollector(1000, display=mock_display)
+
+ self.assertTrue(collector.running)
+ collector._handle_input()
+ self.assertFalse(collector.running)
+
+ def test_handle_input_quit_uppercase(self):
+ """Test handling 'Q' key to quit."""
+ mock_display = MockDisplay()
+ mock_display.simulate_input(ord("Q"))
+ collector = LiveStatsCollector(1000, display=mock_display)
+
+ self.assertTrue(collector.running)
+ collector._handle_input()
+ self.assertFalse(collector.running)
+
+ def test_handle_input_cycle_sort(self):
+ """Test handling 's' key to cycle sort."""
+ mock_display = MockDisplay()
+ mock_display.simulate_input(ord("s"))
+ collector = LiveStatsCollector(
+ 1000, sort_by="nsamples", display=mock_display
+ )
+
+ collector._handle_input()
+ self.assertEqual(collector.sort_by, "sample_pct")
+
+ def test_handle_input_cycle_sort_uppercase(self):
+ """Test handling 'S' key to cycle sort backward."""
+ mock_display = MockDisplay()
+ mock_display.simulate_input(ord("S"))
+ collector = LiveStatsCollector(
+ 1000, sort_by="nsamples", display=mock_display
+ )
+
+ collector._handle_input()
+ self.assertEqual(collector.sort_by, "cumtime")
+
+ def test_handle_input_no_key(self):
+ """Test handling when no key is pressed."""
+ mock_display = MockDisplay()
+ collector = LiveStatsCollector(1000, display=mock_display)
+
+ collector._handle_input()
+ # Should not change state
+ self.assertTrue(collector.running)
+
+
+class TestLiveStatsCollectorDisplayMethods(unittest.TestCase):
+ """Tests for display-related methods."""
+
+ def setUp(self):
+ """Set up collector with mock display."""
+ self.mock_display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.mock_display
+ )
+ self.collector.start_time = time.perf_counter()
+
+ def test_show_terminal_too_small(self):
+ """Test terminal too small message display."""
+ self.collector._show_terminal_too_small(10, 50)
+ # Should have written some content to the display buffer
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ def test_draw_header_info(self):
+ """Test drawing header information."""
+ colors = {
+ "cyan": curses.A_BOLD,
+ "green": curses.A_BOLD,
+ "yellow": curses.A_BOLD,
+ "magenta": curses.A_BOLD,
+ }
+ self.collector._initialize_widgets(colors)
+
+ line = self.collector._header_widget.draw_header_info(0, 160, 100.5)
+ self.assertEqual(line, 2) # Title + header info line
+
+ def test_draw_sample_stats(self):
+ """Test drawing sample statistics."""
+ self.collector.total_samples = 1000
+ colors = {"cyan": curses.A_BOLD, "green": curses.A_BOLD}
+ self.collector._initialize_widgets(colors)
+
+ line = self.collector._header_widget.draw_sample_stats(0, 160, 10.0)
+ self.assertEqual(line, 1)
+ self.assertGreater(self.collector._max_sample_rate, 0)
+
+ def test_progress_bar_uses_target_rate(self):
+ """Test that progress bar uses target rate instead of max rate."""
+ # Set up collector with specific sampling interval
+ collector = LiveStatsCollector(
+ 10000, pid=12345, display=self.mock_display
+ ) # 10ms = 100Hz target
+ collector.start_time = time.perf_counter()
+ collector.total_samples = 500
+ collector._max_sample_rate = (
+ 150 # Higher than target to test we don't use this
+ )
+
+ colors = {"cyan": curses.A_BOLD, "green": curses.A_BOLD}
+ collector._initialize_widgets(colors)
+
+ # Clear the display buffer to capture only our progress bar content
+ self.mock_display.buffer.clear()
+
+ # Draw sample stats with a known elapsed time that gives us a specific sample rate
+ elapsed = 10.0 # 500 samples in 10 seconds = 50 samples/second
+ line = collector._header_widget.draw_sample_stats(0, 160, elapsed)
+
+ # Verify display was updated
+ self.assertEqual(line, 1)
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ # Verify the label shows current/target format with units instead of "max"
+ found_current_target_label = False
+ found_max_label = False
+ for (line_num, col), (text, attr) in self.mock_display.buffer.items():
+ # Should show "50.0Hz/100.0Hz (50.0%)" since we're at 50% of target (50/100)
+ if "50.0Hz/100.0Hz" in text and "50.0%" in text:
+ found_current_target_label = True
+ if "max:" in text:
+ found_max_label = True
+
+ self.assertTrue(
+ found_current_target_label,
+ "Should display current/target rate with percentage",
+ )
+ self.assertFalse(found_max_label, "Should not display max rate label")
+
+ def test_progress_bar_different_intervals(self):
+ """Test that progress bar adapts to different sampling intervals."""
+ test_cases = [
+ (
+ 1000,
+ "1.0KHz",
+ "100.0Hz",
+ ), # 1ms interval -> 1000Hz target (1.0KHz), 100Hz current
+ (
+ 5000,
+ "200.0Hz",
+ "100.0Hz",
+ ), # 5ms interval -> 200Hz target, 100Hz current
+ (
+ 20000,
+ "50.0Hz",
+ "100.0Hz",
+ ), # 20ms interval -> 50Hz target, 100Hz current
+ (
+ 100000,
+ "10.0Hz",
+ "100.0Hz",
+ ), # 100ms interval -> 10Hz target, 100Hz current
+ ]
+
+ for (
+ interval_usec,
+ expected_target_formatted,
+ expected_current_formatted,
+ ) in test_cases:
+ with self.subTest(interval=interval_usec):
+ collector = LiveStatsCollector(
+ interval_usec, display=MockDisplay()
+ )
+ collector.start_time = time.perf_counter()
+ collector.total_samples = 100
+
+ colors = {"cyan": curses.A_BOLD, "green": curses.A_BOLD}
+ collector._initialize_widgets(colors)
+
+ # Clear buffer
+ collector.display.buffer.clear()
+
+ # Draw with 1 second elapsed time (gives us current rate of 100Hz)
+ collector._header_widget.draw_sample_stats(0, 160, 1.0)
+
+ # Check that the current/target format appears in the display with proper units
+ found_current_target_format = False
+ for (line_num, col), (
+ text,
+ attr,
+ ) in collector.display.buffer.items():
+ # Looking for format like "100.0Hz/1.0KHz" or "100.0Hz/200.0Hz"
+ expected_format = f"{expected_current_formatted}/{expected_target_formatted}"
+ if expected_format in text and "%" in text:
+ found_current_target_format = True
+ break
+
+ self.assertTrue(
+ found_current_target_format,
+ f"Should display current/target rate format with units for {interval_usec}µs interval",
+ )
+
+ def test_draw_efficiency_bar(self):
+ """Test drawing efficiency bar."""
+ self.collector._successful_samples = 900
+ self.collector._failed_samples = 100
+ self.collector.total_samples = 1000
+ colors = {"green": curses.A_BOLD, "red": curses.A_BOLD}
+ self.collector._initialize_widgets(colors)
+
+ line = self.collector._header_widget.draw_efficiency_bar(0, 160)
+ self.assertEqual(line, 1)
+
+ def test_draw_function_stats(self):
+ """Test drawing function statistics."""
+ self.collector.result[("test.py", 10, "func1")] = {
+ "direct_calls": 100,
+ "cumulative_calls": 150,
+ "total_rec_calls": 0,
+ }
+ self.collector.result[("test.py", 20, "func2")] = {
+ "direct_calls": 0,
+ "cumulative_calls": 50,
+ "total_rec_calls": 0,
+ }
+
+ stats_list = self.collector._build_stats_list()
+ colors = {
+ "cyan": curses.A_BOLD,
+ "green": curses.A_BOLD,
+ "yellow": curses.A_BOLD,
+ "magenta": curses.A_BOLD,
+ }
+ self.collector._initialize_widgets(colors)
+
+ line = self.collector._header_widget.draw_function_stats(
+ 0, 160, stats_list
+ )
+ self.assertEqual(line, 1)
+
+ def test_draw_top_functions(self):
+ """Test drawing top functions."""
+ self.collector.total_samples = 300
+ self.collector.result[("test.py", 10, "hot_func")] = {
+ "direct_calls": 100,
+ "cumulative_calls": 150,
+ "total_rec_calls": 0,
+ }
+
+ stats_list = self.collector._build_stats_list()
+ colors = {
+ "red": curses.A_BOLD,
+ "yellow": curses.A_BOLD,
+ "green": curses.A_BOLD,
+ }
+ self.collector._initialize_widgets(colors)
+
+ line = self.collector._header_widget.draw_top_functions(
+ 0, 160, stats_list
+ )
+ self.assertEqual(line, 1)
+
+ def test_draw_column_headers(self):
+ """Test drawing column headers."""
+ colors = {
+ "sorted_header": curses.A_BOLD,
+ "normal_header": curses.A_NORMAL,
+ }
+ self.collector._initialize_widgets(colors)
+
+ (
+ line,
+ show_sample_pct,
+ show_tottime,
+ show_cumul_pct,
+ show_cumtime,
+ ) = self.collector._table_widget.draw_column_headers(0, 160)
+ self.assertEqual(line, 1)
+ self.assertTrue(show_sample_pct)
+ self.assertTrue(show_tottime)
+ self.assertTrue(show_cumul_pct)
+ self.assertTrue(show_cumtime)
+
+ def test_draw_column_headers_narrow_terminal(self):
+ """Test column headers adapt to narrow terminal."""
+ colors = {
+ "sorted_header": curses.A_BOLD,
+ "normal_header": curses.A_NORMAL,
+ }
+ self.collector._initialize_widgets(colors)
+
+ (
+ line,
+ show_sample_pct,
+ show_tottime,
+ show_cumul_pct,
+ show_cumtime,
+ ) = self.collector._table_widget.draw_column_headers(0, 70)
+ self.assertEqual(line, 1)
+ # Some columns should be hidden on narrow terminal
+ self.assertFalse(show_cumul_pct)
+
+ def test_draw_footer(self):
+ """Test drawing footer."""
+ colors = self.collector._setup_colors()
+ self.collector._initialize_widgets(colors)
+ self.collector._footer_widget.render(38, 160)
+ # Should have written some content to the display buffer
+ self.assertGreater(len(self.mock_display.buffer), 0)
+
+ def test_draw_progress_bar(self):
+ """Test progress bar drawing."""
+ colors = self.collector._setup_colors()
+ self.collector._initialize_widgets(colors)
+ bar, length = self.collector._header_widget.progress_bar.render_bar(
+ 50, 100, 30
+ )
+
+ self.assertIn("[", bar)
+ self.assertIn("]", bar)
+ self.assertGreater(length, 0)
+ # Should be roughly 50% filled
+ self.assertIn("█", bar)
+ self.assertIn("░", bar)
+
+
+class TestLiveStatsCollectorEdgeCases(unittest.TestCase):
+ """Tests for edge cases and error handling."""
+
+ def test_very_long_function_name(self):
+ """Test handling of very long function names."""
+ collector = LiveStatsCollector(1000)
+ long_name = "x" * 200
+ collector.result[("test.py", 10, long_name)] = {
+ "direct_calls": 10,
+ "cumulative_calls": 20,
+ "total_rec_calls": 0,
+ }
+
+ stats_list = collector._build_stats_list()
+ self.assertEqual(len(stats_list), 1)
+ self.assertEqual(stats_list[0]["func"][2], long_name)
+
+
+class TestLiveStatsCollectorUpdateDisplay(unittest.TestCase):
+ """Tests for the _update_display method."""
+
+ def setUp(self):
+ """Set up collector with mock display."""
+ self.mock_display = MockDisplay(height=40, width=160)
+ self.collector = LiveStatsCollector(
+ 1000, pid=12345, display=self.mock_display
+ )
+ self.collector.start_time = time.perf_counter()
+
+ def test_update_display_terminal_too_small(self):
+ """Test update_display when terminal is too small."""
+ small_display = MockDisplay(height=10, width=50)
+ self.collector.display = small_display
+
+ with mock.patch.object(
+ self.collector, "_show_terminal_too_small"
+ ) as mock_show:
+ self.collector._update_display()
+ mock_show.assert_called_once()
+
+ def test_update_display_normal(self):
+ """Test normal update_display operation."""
+ self.collector.total_samples = 100
+ self.collector._successful_samples = 90
+ self.collector._failed_samples = 10
+ self.collector.result[("test.py", 10, "func")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+
+ self.collector._update_display()
+
+ self.assertTrue(self.mock_display.cleared)
+ self.assertTrue(self.mock_display.refreshed)
+
+ def test_update_display_handles_exception(self):
+ """Test that update_display handles exceptions gracefully."""
+ # Make one of the methods raise an exception
+ with mock.patch.object(
+ self.collector,
+ "_prepare_display_data",
+ side_effect=Exception("Test error"),
+ ):
+ # Should not raise an exception (it catches and logs via trace_exception)
+ try:
+ self.collector._update_display()
+ except Exception:
+ self.fail(
+ "_update_display should handle exceptions gracefully"
+ )
+
+
+class TestLiveCollectorWithMockDisplayHelpers(unittest.TestCase):
+ """Tests using the new MockDisplay helper methods."""
+
+ def test_verify_pid_display_with_contains(self):
+ """Test verifying PID is displayed using contains_text helper."""
+ display = MockDisplay(height=40, width=160)
+ collector = LiveStatsCollector(1000, pid=99999, display=display)
+ collector.start_time = time.perf_counter()
+ collector.total_samples = 10
+
+ collector._update_display()
+
+ # Use the helper method
+ self.assertTrue(
+ display.contains_text("99999"), "PID should be visible in display"
+ )
+
+ def test_verify_function_names_displayed(self):
+ """Test verifying function names appear in display."""
+ display = MockDisplay(height=40, width=160)
+ collector = LiveStatsCollector(1000, pid=12345, display=display)
+ collector.start_time = time.perf_counter()
+
+ collector.total_samples = 100
+ collector.result[("mymodule.py", 42, "my_special_function")] = {
+ "direct_calls": 50,
+ "cumulative_calls": 75,
+ "total_rec_calls": 0,
+ }
+
+ collector._update_display()
+
+ # Verify function name appears
+ self.assertTrue(
+ display.contains_text("my_special_function"),
+ "Function name should be visible",
+ )
+
+ def test_get_all_lines_full_display(self):
+ """Test getting all lines from a full display render."""
+ display = MockDisplay(height=40, width=160)
+ collector = LiveStatsCollector(1000, pid=12345, display=display)
+ collector.start_time = time.perf_counter()
+ collector.total_samples = 100
+
+ collector._update_display()
+
+ lines = display.get_all_lines()
+
+ # Should have multiple lines of content
+ self.assertGreater(len(lines), 5)
+
+ # Should have header content
+ self.assertTrue(any("PID" in line for line in lines))
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+"""Simple unit tests for TrendTracker."""
+
+import unittest
+from test.support import requires
+from test.support.import_helper import import_module
+
+# Only run these tests if curses is available
+requires("curses")
+curses = import_module("curses")
+
+from profiling.sampling.live_collector.trend_tracker import TrendTracker
+
+
+class TestTrendTracker(unittest.TestCase):
+ """Tests for TrendTracker class."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.colors = {
+ "trend_up": curses.A_BOLD,
+ "trend_down": curses.A_REVERSE,
+ "trend_stable": curses.A_NORMAL,
+ }
+
+ def test_basic_trend_detection(self):
+ """Test basic up/down/stable trend detection."""
+ tracker = TrendTracker(self.colors, enabled=True)
+
+ # First value is always stable
+ self.assertEqual(tracker.update("func1", "nsamples", 10), "stable")
+
+ # Increasing value
+ self.assertEqual(tracker.update("func1", "nsamples", 20), "up")
+
+ # Decreasing value
+ self.assertEqual(tracker.update("func1", "nsamples", 15), "down")
+
+ # Small change (within threshold) is stable
+ self.assertEqual(tracker.update("func1", "nsamples", 15.0001), "stable")
+
+ def test_multiple_metrics(self):
+ """Test tracking multiple metrics simultaneously."""
+ tracker = TrendTracker(self.colors, enabled=True)
+
+ trends = tracker.update_metrics("func1", {
+ "nsamples": 10,
+ "tottime": 5.0,
+ })
+
+ self.assertEqual(trends["nsamples"], "stable")
+ self.assertEqual(trends["tottime"], "stable")
+
+ # Update with changes
+ trends = tracker.update_metrics("func1", {
+ "nsamples": 15,
+ "tottime": 3.0,
+ })
+
+ self.assertEqual(trends["nsamples"], "up")
+ self.assertEqual(trends["tottime"], "down")
+
+ def test_toggle_enabled(self):
+ """Test enable/disable toggle."""
+ tracker = TrendTracker(self.colors, enabled=True)
+ self.assertTrue(tracker.enabled)
+
+ tracker.toggle()
+ self.assertFalse(tracker.enabled)
+
+ # When disabled, should return A_NORMAL
+ self.assertEqual(tracker.get_color("up"), curses.A_NORMAL)
+
+ def test_get_color(self):
+ """Test color selection for trends."""
+ tracker = TrendTracker(self.colors, enabled=True)
+
+ self.assertEqual(tracker.get_color("up"), curses.A_BOLD)
+ self.assertEqual(tracker.get_color("down"), curses.A_REVERSE)
+ self.assertEqual(tracker.get_color("stable"), curses.A_NORMAL)
+
+ def test_clear(self):
+ """Test clearing tracked values."""
+ tracker = TrendTracker(self.colors, enabled=True)
+
+ # Add some data
+ tracker.update("func1", "nsamples", 10)
+ tracker.update("func1", "nsamples", 20)
+
+ # Clear
+ tracker.clear()
+
+ # After clear, first update should be stable
+ self.assertEqual(tracker.update("func1", "nsamples", 30), "stable")
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+Add a new ``--live`` mode to the tachyon profiler in
+:mod:`profiling.sampling` module. This mode consist of a live TUI that
+displays real-time profiling statistics as the target application runs,
+similar to ``top``. Patch by Pablo Galindo
// Check CPU status
long pthread_id = GET_MEMBER(long, ts, unwinder->debug_offsets.thread_state.thread_id);
- int cpu_status = get_thread_status(unwinder, tid, pthread_id);
+
+ // Optimization: only check CPU status if needed by mode because it's expensive
+ int cpu_status = -1;
+ if (unwinder->mode == PROFILING_MODE_CPU || unwinder->mode == PROFILING_MODE_ALL) {
+ cpu_status = get_thread_status(unwinder, tid, pthread_id);
+ }
+
if (cpu_status == -1) {
status_flags |= THREAD_STATUS_UNKNOWN;
} else if (cpu_status == THREAD_STATE_RUNNING) {