Skip to content

Commit f2b3a7b

Browse files
FindHao and meta-codesync[bot]
authored and committed
Create info/ Module (Core Query Layer) (#208)
Summary: This PR creates the `info/` module as the core query layer for kernel information from NDJSON trace files. This is an internal infrastructure PR that will be used by PR4 (reproduce `--kernel`/`--launch-id`) and PR5 (info CLI). ## Changes - **`tritonparse/info/__init__.py`**: Module initialization, exports core functions - **`tritonparse/info/kernel_query.py`**: Core query functions: - `KernelSummary` dataclass: kernel name, hash, total launches - `LaunchInfo` dataclass: launch ID, line index, grid (for PR5) - `list_kernels(events)`: List all kernels with their launch counts - `find_launch_index_by_kernel(events, kernel_name, launch_id)`: Find 0-based line index for a kernel's N-th launch - **`tests/test_tritonparse.py`**: Added 6 unit tests in `TestTritonparseCPU`: - `test_list_kernels_empty()`: Empty events list - `test_list_kernels_single()`: Single kernel with multiple launches - `test_list_kernels_multiple()`: Multiple different kernels - `test_find_launch_index_valid()`: Valid kernel name and launch_id - `test_find_launch_index_kernel_not_found()`: Raises ValueError when kernel not found - `test_find_launch_index_out_of_range()`: Raises ValueError with valid range hint ## Testing Tests use real data from `tests/example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz` when possible, mock data for edge cases (empty list, non-existent kernel, out-of-range launch_id). ## Notes - This PR does NOT create CLI. It's purely internal infrastructure. - All indices are 0-based for consistency with Python conventions. - Kernel name matching is case-sensitive (exact match only). Pull Request resolved: #208 Reviewed By: wychi Differential Revision: D88171102 Pulled By: FindHao fbshipit-source-id: 7c2ac5d74551a1d2b0667dd8b4734972298d7a62
1 parent 24e002e commit f2b3a7b

File tree

3 files changed

+277
-0
lines changed

3 files changed

+277
-0
lines changed

tests/test_tritonparse.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,152 @@ def test_load_ndjson_gzip_support(self):
336336

337337
print(f"✓ Successfully loaded {len(events)} events from .ndjson.gz file")
338338

339+
def test_list_kernels_empty(self):
    """An empty event list must produce an empty list of kernel summaries."""
    from tritonparse.info.kernel_query import list_kernels

    # With no events there is nothing to aggregate.
    self.assertEqual(list_kernels([]), [])
346+
347+
def test_list_kernels_single(self):
    """After dropping other kernels' launches, exactly one kernel is listed."""
    from pathlib import Path

    from tritonparse.info.kernel_query import list_kernels
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace shipped alongside the tests.
    tests_dir = Path(__file__).parent
    gz_file = (
        tests_dir
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )

    def keep(event):
        # Non-launch events stay in the list so list_kernels has to skip them
        # on its own; launch events survive only for fused_op_kernel.
        if event.get("event_type") != "launch":
            return True
        launched = event.get("compilation_metadata", {}).get("name")
        return launched == "fused_op_kernel"

    filtered_events = [event for event in load_ndjson(gz_file) if keep(event)]

    summaries = list_kernels(filtered_events)
    self.assertEqual(len(summaries), 1)

    only_kernel = summaries[0]
    self.assertEqual(only_kernel.name, "fused_op_kernel")
    self.assertEqual(only_kernel.total_launches, 4)
376+
377+
def test_list_kernels_multiple(self):
    """The full example trace lists both kernels, name-sorted, with their counts."""
    from pathlib import Path

    from tritonparse.info.kernel_query import list_kernels
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace shipped alongside the tests.
    tests_dir = Path(__file__).parent
    events = load_ndjson(
        tests_dir
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )

    summaries = list_kernels(events)
    self.assertEqual(len(summaries), 2)

    # Output order is expected to be alphabetical by kernel name.
    self.assertEqual(
        [summary.name for summary in summaries],
        ["fused_op_kernel", "matmul_kernel"],
    )

    # Per-kernel launch totals.
    by_name = {summary.name: summary for summary in summaries}
    self.assertEqual(by_name["matmul_kernel"].total_launches, 1553)
    self.assertEqual(by_name["fused_op_kernel"].total_launches, 4)
402+
403+
def test_find_launch_index_valid(self):
    """Valid (kernel, launch_id) pairs resolve to a launch event of that kernel."""
    from pathlib import Path

    from tritonparse.info.kernel_query import find_launch_index_by_kernel
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace shipped alongside the tests.
    tests_dir = Path(__file__).parent
    events = load_ndjson(
        tests_dir
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )

    # (kernel name, 0-based launch id): first and second launch of
    # fused_op_kernel, then the first launch of matmul_kernel.
    cases = (
        ("fused_op_kernel", 0),
        ("fused_op_kernel", 1),
        ("matmul_kernel", 0),
    )
    for expected_name, launch_id in cases:
        index = find_launch_index_by_kernel(events, expected_name, launch_id)
        located = events[index]
        self.assertEqual(located.get("event_type"), "launch")
        self.assertEqual(
            located.get("compilation_metadata", {}).get("name"),
            expected_name,
        )
440+
441+
def test_find_launch_index_kernel_not_found(self):
    """An unknown kernel name raises ValueError that names the kernel."""
    from pathlib import Path

    from tritonparse.info.kernel_query import find_launch_index_by_kernel
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace shipped alongside the tests.
    tests_dir = Path(__file__).parent
    events = load_ndjson(
        tests_dir
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )

    with self.assertRaises(ValueError) as caught:
        find_launch_index_by_kernel(events, "nonexistent_kernel", 0)

    # The message must say what went wrong and for which kernel.
    message = str(caught.exception)
    for fragment in ("not found", "nonexistent_kernel"):
        self.assertIn(fragment, message)
461+
462+
def test_find_launch_index_out_of_range(self):
    """A launch_id past the last launch raises ValueError with a range hint."""
    from pathlib import Path

    from tritonparse.info.kernel_query import find_launch_index_by_kernel
    from tritonparse.tools.prettify_ndjson import load_ndjson

    # Real trace shipped alongside the tests.
    tests_dir = Path(__file__).parent
    events = load_ndjson(
        tests_dir
        / "example_output/parsed_output_complex/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"
    )

    # fused_op_kernel is launched 4 times (ids 0-3), so id 10 is invalid.
    with self.assertRaises(ValueError) as caught:
        find_launch_index_by_kernel(events, "fused_op_kernel", 10)

    # The message reports the count, echoes the request and gives the range.
    message = str(caught.exception)
    for fragment in (
        "has only 4 launches",
        "--launch-id 10",
        "Valid range: 0 to 3",
    ):
        self.assertIn(fragment, message)
484+
339485

340486
class TestTritonparseCUDA(unittest.TestCase):
341487
"""CUDA tests (require GPU)"""

tritonparse/info/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
"""
4+
Info module for querying kernel information from NDJSON trace files.
5+
6+
This module provides core query functions for kernel information:
7+
- Listing all kernels with their launch counts
8+
- Finding launch events by kernel name and launch ID
9+
- Querying launch information for specific kernels
10+
"""
11+
12+
from tritonparse.info.kernel_query import (
13+
find_launch_index_by_kernel,
14+
KernelSummary,
15+
LaunchInfo,
16+
list_kernels,
17+
)
18+
19+
__all__ = [
20+
"KernelSummary",
21+
"LaunchInfo",
22+
"list_kernels",
23+
"find_launch_index_by_kernel",
24+
]

tritonparse/info/kernel_query.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
"""
4+
Core query functions for kernel information from NDJSON trace files.
5+
6+
This module provides functions to query kernel launch information from parsed
7+
event lists. It supports both raw log files and parsed ndjson files (with launch_diff events).
8+
"""
9+
10+
from collections import defaultdict
11+
from dataclasses import dataclass
12+
from typing import Any, Dict, List
13+
14+
15+
@dataclass
class KernelSummary:
    """Aggregate record for one kernel: its name, its hash and its launch count."""

    # Kernel name, read from the launch event's compilation metadata.
    name: str
    # Kernel hash from the same metadata; empty string when none was recorded.
    hash: str
    # Number of launch events counted for this kernel name.
    total_launches: int
22+
23+
24+
@dataclass
class LaunchInfo:
    """Record describing a single launch of a kernel."""

    # 0-based ordinal of this launch among the launches of the same kernel.
    launch_id: int
    # 0-based position of the launch event in the events list.
    line_index: int
    # Launch grid as a list of ints (presumably the grid dimensions; nothing
    # in this module populates it, so confirm against the producer).
    grid: List[int]
31+
32+
33+
def list_kernels(events: List[Dict[str, Any]]) -> "List[KernelSummary]":
    """
    List all kernels with their launch counts.

    Only events whose ``event_type`` is ``"launch"`` are considered. The kernel
    name and hash are read from the event's ``compilation_metadata`` mapping;
    launches without a name (or without usable metadata) are ignored. When a
    kernel is launched several times, the hash of its last launch is reported.

    Args:
        events: List of parsed event dictionaries from NDJSON file

    Returns:
        List of KernelSummary objects, sorted by kernel name
    """
    # Per-kernel accumulator: last seen hash plus running launch count.
    kernel_counts: Dict[str, Dict[str, Any]] = defaultdict(
        lambda: {"hash": "", "count": 0}
    )

    for event in events:
        if event.get("event_type") != "launch":
            continue

        # dict.get() only falls back to its default when the key is absent.
        # A launch serialized with "compilation_metadata": null comes back as
        # None, so substitute an empty mapping instead of crashing on .get().
        comp_meta = event.get("compilation_metadata") or {}
        kernel_name = comp_meta.get("name")
        if not kernel_name:
            continue

        entry = kernel_counts[kernel_name]
        entry["hash"] = comp_meta.get("hash", "")
        entry["count"] += 1

    # Sort by kernel name for consistent output.
    return sorted(
        (
            KernelSummary(name=name, hash=info["hash"], total_launches=info["count"])
            for name, info in kernel_counts.items()
        ),
        key=lambda summary: summary.name,
    )
70+
71+
72+
def find_launch_index_by_kernel(
    events: List[Dict[str, Any]], kernel_name: str, launch_id: int
) -> int:
    """
    Find the 0-based line index for a kernel's N-th launch.

    Walks ``events`` in order, looking only at events whose ``event_type`` is
    ``"launch"`` and whose ``compilation_metadata`` name equals ``kernel_name``.
    Launch events with missing or null metadata are skipped.

    Args:
        events: List of parsed event dictionaries
        kernel_name: Exact kernel name to match (case-sensitive)
        launch_id: 0-based launch index for the kernel

    Returns:
        0-based line index (index in events list)

    Raises:
        ValueError: If kernel not found or launch_id out of range
    """
    # Number of launches of kernel_name seen so far.
    count = 0
    for i, event in enumerate(events):
        if event.get("event_type") != "launch":
            continue

        # dict.get() only falls back to its default when the key is absent.
        # A launch serialized with "compilation_metadata": null comes back as
        # None, so substitute an empty mapping instead of crashing on .get().
        comp_meta = event.get("compilation_metadata") or {}
        if comp_meta.get("name") != kernel_name:
            continue

        if count == launch_id:
            return i
        count += 1

    # Falling out of the loop means the requested launch does not exist.
    if count == 0:
        raise ValueError(f"Kernel '{kernel_name}' not found")
    raise ValueError(
        f"Kernel '{kernel_name}' has only {count} launches, "
        f"but --launch-id {launch_id} was requested. Valid range: 0 to {count - 1}"
    )

0 commit comments

Comments
 (0)