Skip to content

Commit 3b40ee8

Browse files
FindHao authored and meta-codesync[bot] committed
Add --kernel and --launch-id to Reproducer (#209)
Summary: This PR adds support for reproducing kernel launches by kernel name and launch ID, eliminating the need to manually find line numbers in trace files.

## Changes

- **`reproducer/cli.py`**:
  - Added `--kernel` argument (str, default=None)
  - Added `--launch-id` argument (int, default=0, 0-based)
  - Updated `--line` help text to indicate mutual exclusivity
- **`reproducer/orchestrator.py`**:
  - Extended `reproduce()` function signature:
    - `line_index: int` (required, maintains backward compatibility)
    - `out_dir: str` (required, no default value)
    - `template: str` (required, no default value)
    - `kernel_name: Optional[str] = None` (new, placed after required params)
    - `launch_id: int = 0` (new)
  - Implemented kernel lookup logic: if `kernel_name` is provided, uses `find_launch_index_by_kernel()` to find the actual `line_index`
  - Updated docstring to document support for `.ndjson`, `.ndjson.gz`, and `.bin.ndjson` formats
- **`cli.py`**:
  - Added mutual exclusivity check: error if both `--kernel` and `--line` (non-zero) are provided
  - Updated `reproduce()` call to pass new parameters using unified calling pattern
- **`tests/test_tritonparse.py`**:
  - Added helper methods to `TestTritonparseCPU` class:
    - `_get_test_ndjson_file()`: Get test file path
    - `setup_temp_reproduce_dir()`: Create temporary directory
    - `cleanup_temp_reproduce_dir()`: Cleanup temporary directory
  - Added 5 unit tests:
    - `test_reproduce_mutual_exclusivity()`: Test parameter mutual exclusivity
    - `test_reproduce_kernel_default_launch_id()`: Test default launch_id
    - `test_reproduce_kernel_launch_id()`: End-to-end integration test
    - `test_reproduce_kernel_not_found()`: Test error handling
    - `test_reproduce_launch_id_out_of_range()`: Test boundary conditions
  - Refactored tests to use helper methods, following `TestTritonparseCUDA` pattern
  - Added imports at module level: `Path` and `tritonparse.reproducer.orchestrator`

## Usage

```bash
# Existing: use line number (0-based)
tritonparseoss reproduce trace.ndjson --line 4

# NEW: use kernel name + launch id (0-based)
tritonparseoss reproduce trace.ndjson --kernel matmul_kernel --launch-id 2

# Also works with .ndjson.gz and .bin.ndjson files
tritonparseoss reproduce trace.ndjson.gz --kernel matmul_kernel --launch-id 0
```

## Testing

Tests use real data from `tests/example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz` when possible, mock data for edge cases.

## Notes

- All indices are 0-based for consistency with Python conventions
- Kernel name matching is case-sensitive (exact match only)
- Backward compatible: existing `--line` usage continues to work
- Error messages include helpful hints (valid range, similar kernel suggestions)

Pull Request resolved: #209

Reviewed By: wychi

Differential Revision: D88171118

Pulled By: FindHao

fbshipit-source-id: 3f82ddd3ee3d5298acabede98ca689ae233d2f6f
1 parent f2b3a7b commit 3b40ee8

File tree

4 files changed

+183
-5
lines changed

4 files changed

+183
-5
lines changed

tests/test_tritonparse.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
import unittest
1515
from collections import defaultdict
1616
from dataclasses import dataclass
17+
from pathlib import Path
1718
from typing import Any, Union
1819

1920
import torch
2021
import torch._inductor.config as inductor_config
2122
import triton # @manual=//triton:triton
2223
import triton.language as tl # @manual=//triton:triton
2324
import tritonparse.context_manager
25+
import tritonparse.reproducer.orchestrator
2426
import tritonparse.structured_logging
2527
import tritonparse.utils
2628
from triton import knobs # @manual=//triton:triton
@@ -138,6 +140,26 @@ def clear_all_caches(*kernels):
138140
class TestTritonparseCPU(unittest.TestCase):
139141
"""CPU-only tests (no CUDA required)"""
140142

143+
def _get_test_ndjson_file(self):
144+
"""Get the test NDJSON file path."""
145+
gz_file = (
146+
Path(__file__).parent
147+
/ "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
148+
)
149+
self.assertTrue(gz_file.exists(), f"Test file not found: {gz_file}")
150+
return gz_file
151+
152+
def setup_temp_reproduce_dir(self):
153+
"""Setup temporary directory for reproduce tests."""
154+
temp_dir = tempfile.mkdtemp()
155+
out_dir = os.path.join(temp_dir, "repro_output")
156+
return temp_dir, out_dir
157+
158+
def cleanup_temp_reproduce_dir(self, temp_dir):
    """Remove the scratch directory of a reproduce test.

    Nothing is deleted when the module-level ``TEST_KEEP_OUTPUT`` flag is
    truthy. Otherwise ``temp_dir`` is removed recursively and removal errors
    are ignored.
    """
    if TEST_KEEP_OUTPUT:
        return
    shutil.rmtree(temp_dir, ignore_errors=True)
162+
141163
def test_callsite_parsing(self):
142164
"""Test parsing of callsite locations in TTIR/TTGIR"""
143165
from tritonparse.ir_parser import extract_loc_definitions
@@ -482,6 +504,116 @@ def test_find_launch_index_out_of_range(self):
482504
self.assertIn("--launch-id 10", error_msg)
483505
self.assertIn("Valid range: 0 to 3", error_msg)
484506

507+
def test_reproduce_mutual_exclusivity(self):
    """Pin how the reproducer CLI parses --line, --kernel and --launch-id.

    argparse itself accepts --line and --kernel together; the
    mutual-exclusivity error is raised afterwards by the check in cli.py
    main() (``args.kernel and args.line != 0``). This test therefore only
    verifies the parsed values that check relies on.
    """
    import argparse

    from tritonparse.reproducer.cli import _add_reproducer_args

    # A single parser is enough: parse_args() can be called repeatedly.
    parser = argparse.ArgumentParser()
    _add_reproducer_args(parser)

    # Both --line and --kernel: parsing succeeds and keeps both values, so
    # main() is able to detect the conflict.
    args = parser.parse_args(
        ["test.ndjson", "--line", "5", "--kernel", "matmul_kernel"]
    )
    self.assertEqual(args.kernel, "matmul_kernel")
    self.assertEqual(args.line, 5)

    # Only --kernel: --line keeps its default of 0, which main() allows
    # alongside --kernel, and --launch-id defaults to the first launch.
    args = parser.parse_args(["test.ndjson", "--kernel", "matmul_kernel"])
    self.assertEqual(args.kernel, "matmul_kernel")
    self.assertEqual(args.line, 0)
    self.assertEqual(args.launch_id, 0)

    # Only --line: no kernel lookup is requested.
    args = parser.parse_args(["test.ndjson", "--line", "5"])
    self.assertEqual(args.line, 5)
    self.assertIsNone(args.kernel)
538+
539+
def test_reproduce_kernel_launch_id(self):
    """End-to-end: reproduce a launch selected by kernel name and launch id.

    Runs ``reproduce()`` for the first ``fused_op_kernel`` launch in the
    bundled trace, then checks the returned mapping, the generated files
    and that the script mentions the kernel.
    """
    trace_path = self._get_test_ndjson_file()
    temp_dir, out_dir = self.setup_temp_reproduce_dir()

    try:
        # line_index is only a placeholder; it is recalculated from
        # kernel_name / launch_id inside reproduce().
        reproduce_kwargs = {
            "input_path": str(trace_path),
            "line_index": 0,
            "out_dir": out_dir,
            "template": "example",
            "kernel_name": "fused_op_kernel",
            "launch_id": 0,
        }
        result = tritonparse.reproducer.orchestrator.reproduce(**reproduce_kwargs)

        # The result must expose the expected keys ...
        for expected_key in ("kernel", "repro_script", "repro_context"):
            self.assertIn(expected_key, result)
        # ... and both generated artifacts must exist on disk.
        for artifact_key in ("repro_script", "repro_context"):
            self.assertTrue(os.path.exists(result[artifact_key]))

        # The generated script should reference the requested kernel.
        script_path = Path(result["repro_script"])
        self.assertIn("fused_op_kernel", script_path.read_text())
    finally:
        self.cleanup_temp_reproduce_dir(temp_dir)
568+
569+
def test_reproduce_kernel_not_found(self):
    """reproduce() must raise ValueError naming a kernel that is not in the trace."""
    trace_path = self._get_test_ndjson_file()
    temp_dir, out_dir = self.setup_temp_reproduce_dir()

    try:
        # line_index is only a placeholder; it is recalculated from
        # kernel_name inside reproduce().
        reproduce_kwargs = {
            "input_path": str(trace_path),
            "line_index": 0,
            "out_dir": out_dir,
            "template": "example",
            "kernel_name": "nonexistent_kernel",
            "launch_id": 0,
        }
        with self.assertRaises(ValueError) as cm:
            tritonparse.reproducer.orchestrator.reproduce(**reproduce_kwargs)

        message = str(cm.exception)
        for fragment in ("not found", "nonexistent_kernel"):
            self.assertIn(fragment, message)
    finally:
        self.cleanup_temp_reproduce_dir(temp_dir)
591+
592+
def test_reproduce_launch_id_out_of_range(self):
    """reproduce() must raise ValueError when launch_id exceeds the kernel's launches."""
    trace_path = self._get_test_ndjson_file()
    temp_dir, out_dir = self.setup_temp_reproduce_dir()

    try:
        # fused_op_kernel has only 4 launches (0-3) in the fixture, so
        # launch_id=10 is out of range. line_index is only a placeholder.
        reproduce_kwargs = {
            "input_path": str(trace_path),
            "line_index": 0,
            "out_dir": out_dir,
            "template": "example",
            "kernel_name": "fused_op_kernel",
            "launch_id": 10,
        }
        with self.assertRaises(ValueError) as cm:
            tritonparse.reproducer.orchestrator.reproduce(**reproduce_kwargs)

        message = str(cm.exception)
        expected_fragments = (
            "has only 4 launches",
            "--launch-id 10",
            "Valid range: 0 to 3",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, message)
    finally:
        self.cleanup_temp_reproduce_dir(temp_dir)
616+
485617

486618
class TestTritonparseCUDA(unittest.TestCase):
487619
"""CUDA tests (require GPU)"""

tritonparse/cli.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def main():
6868
}
6969
unified_parse(**parse_args)
7070
elif args.func == "reproduce":
71+
# Check mutual exclusivity between --line and --kernel/--launch-id
72+
if args.kernel and args.line != 0:
73+
repro_parser.error("--line and --kernel/--launch-id are mutually exclusive")
74+
7175
replacer = None
7276
if args.use_fbcode:
7377
from tritonparse.fb.reproducer.replacer import FBCodePlaceholderReplacer
@@ -77,9 +81,11 @@ def main():
7781

7882
reproduce(
7983
input_path=args.input,
80-
line_index=args.line,
84+
line_index=args.line if not args.kernel else 0,
8185
out_dir=args.out_dir,
8286
template=args.template,
87+
kernel_name=args.kernel,
88+
launch_id=args.launch_id if args.kernel else 0,
8389
kernel_import=args.kernel_import,
8490
replacer=replacer,
8591
)

tritonparse/reproducer/cli.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,26 @@ def _add_reproducer_args(parser: argparse.ArgumentParser) -> None:
1414
default=0,
1515
help=(
1616
"The line index (0-based) of the launch event in the input file to reproduce. "
17-
"Defaults to 0 (first launch event)."
17+
"Defaults to 0 (first launch event). Mutually exclusive with --kernel/--launch-id."
18+
),
19+
)
20+
parser.add_argument(
21+
"--kernel",
22+
type=str,
23+
default=None,
24+
help=(
25+
"Kernel name (exact match, case-sensitive) to reproduce. "
26+
"Use with --launch-id to specify which launch of the kernel. "
27+
"Mutually exclusive with --line."
28+
),
29+
)
30+
parser.add_argument(
31+
"--launch-id",
32+
type=int,
33+
default=0,
34+
help=(
35+
"0-based launch index for the kernel specified by --kernel. "
36+
"Defaults to 0 (first launch). Only used when --kernel is provided."
1837
),
1938
)
2039
parser.add_argument(

tritonparse/reproducer/orchestrator.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Optional
55

6+
from tritonparse.info.kernel_query import find_launch_index_by_kernel
67
from tritonparse.reproducer.ingestion.ndjson import build_context_bundle
78
from tritonparse.reproducer.placeholder_replacer import (
89
DefaultPlaceholderReplacer,
@@ -20,24 +21,44 @@ def reproduce(
2021
line_index: int,
2122
out_dir: str,
2223
template: str,
24+
kernel_name: Optional[str] = None,
25+
launch_id: int = 0,
2326
replacer: Optional[PlaceholderReplacer] = None,
2427
kernel_import: KernelImportMode = KernelImportMode.DEFAULT,
2528
) -> dict[str, str]:
2629
"""
2730
Generate a reproducer script from NDJSON trace file.
2831
32+
Must provide either line_index OR (kernel_name + launch_id), not both.
33+
If kernel_name is provided, the line_index parameter will be ignored and
34+
recalculated from the kernel lookup.
35+
2936
Args:
30-
input_path: Path to the NDJSON trace file.
31-
line_index: 0-based index of the launch event to reproduce in the events list.
37+
input_path: Path to ndjson file. Supports uncompressed (.ndjson),
38+
gzip compressed (.ndjson.gz), and gzip member concatenation (.bin.ndjson) formats.
39+
line_index: 0-based index in events list. Ignored if kernel_name is provided.
3240
out_dir: Output directory for reproducer files.
3341
template: Template name to use for the reproducer.
42+
kernel_name: Exact kernel name to match (case-sensitive). If provided, line_index will be recalculated.
43+
launch_id: 0-based launch index for the kernel (default: 0, first launch).
3444
replacer: Optional custom PlaceholderReplacer instance. If None, uses DefaultPlaceholderReplacer.
3545
kernel_import: Kernel import mode (DEFAULT or COPY).
3646
"""
37-
logger.debug(f"Building bundle from {input_path} at line {line_index}")
3847
events = load_ndjson(Path(input_path))
3948
logger.debug(f"Loaded {len(events)} events")
4049

50+
# If kernel_name is provided, lookup the actual line_index (overrides the parameter)
51+
if kernel_name is not None:
52+
logger.debug(
53+
f"Looking up kernel '{kernel_name}' launch_id={launch_id} in {input_path}"
54+
)
55+
line_index = find_launch_index_by_kernel(events, kernel_name, launch_id)
56+
logger.debug(
57+
f"Found kernel '{kernel_name}' launch_id={launch_id} at line {line_index}"
58+
)
59+
60+
logger.debug(f"Building bundle from {input_path} at line {line_index}")
61+
4162
# Build context bundle from the specified launch event
4263
context_bundle = build_context_bundle(events, line_index)
4364
logger.debug(

0 commit comments

Comments
 (0)