
Commit ea72b85

docs: Clarify and enhance BYO Kernel guide (#85)
1 parent cea26dc commit ea72b85

File tree: 1 file changed (+51 −52 lines)

docs/tutorials/bring_your_own_kernel.mdx (51 additions, 52 deletions)
@@ -1,8 +1,8 @@
 # Bring Your Own Kernel to FlashInfer-Bench

-This guide gives instructions on how to add Definitions, Solutions, capture Workloads, and record Evaluations by walking through each **component of the Trace**, with an end-to-end apply at runtime flow.
+This guide gives instructions on how to add Definitions and Solutions, capture Workloads, and record Evaluations by walking through each **component of the Trace**, with an end-to-end "apply at runtime" flow.

-A **Trace** is an atomic, immutable record of a single benchmark run. It links a specific `Solution` to a specific `Definition`, fixes the exact `workload` (shapes + data), and stores the complete `evaluation`. A folder of Definitions, Solutions, and Traces is your benchmark database.
+A **Trace** is an atomic, immutable record of a single benchmark run. It links a specific `Solution` to a specific `Definition`, fixes the exact `workload` (input shapes + input data), and stores the complete `evaluation`. A folder of Definitions, Solutions, and Traces is your benchmark database.

 ## Trace Schema (top level)

@@ -13,17 +13,19 @@ A **Trace** is an atomic, immutable record of a single benchmark run. It links a
 | `workload` | object | Yes | Concrete shapes and input data used for this run. |
 | `evaluation` | object | Yes | Results, logs, and environment snapshot. |

+More details about the schema are in [FlashInfer Trace Schema](https://bench.flashinfer.ai/docs/flashinfer_trace/flashinfer_trace).
+
 ## Component 1: `definition`

-**What it is.** The operator’s contract: axes (const/var), inputs/outputs, constraints, and a correct (not necessarily fast) `reference`.
+**What it is:** The operator’s contract: axes (const/var), inputs/outputs, constraints, and a correct (not necessarily fast) `reference`.

-**Identity rule.** Two kernels are the **same Definition** iff:
+**Identity rule:** Two kernels are under the same Definition iff:

-* They have the **same axes**,
-* Each axis has the **same role** (`const` vs `var`),
-* All `const` axes have the **same values**.
+* They have the same axes,
+* Each axis has the same role (`const` vs `var`),
+* All `const` axes have the same values.

-**How to add a new kernel Definition.**
+**How to add a new kernel Definition:**

 1. Refer to the schema, choose a `name` (`<type>_<stage>_<axis tokens>`) and `type`; write a clear `description` and helpful `tags`.
 2. Specify `axes` with `type: const|var` (+ `value` for const).
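
To make the identity rule and the steps above concrete, here is a minimal Definition sketch. It is illustrative only: the axis names (`num_qo_heads`, `num_kv_heads`, `head_dim`, `batch_size`) and the exact field spellings are assumptions; the [FlashInfer Trace Schema](https://bench.flashinfer.ai/docs/flashinfer_trace/flashinfer_trace) linked above is authoritative.

```python
# Hypothetical Definition sketch (illustrative field spellings).
definition = {
    "name": "gqa_paged_decode_h32_kv4_d128_ps1",  # <type>_<stage>_<axis tokens>
    "type": "gqa",
    "description": "Paged GQA decode: 32 query heads, 4 KV heads, head_dim 128, page_size 1.",
    "tags": ["attention", "decode", "paged"],
    "axes": {
        "num_qo_heads": {"type": "const", "value": 32},  # const axes fix a value
        "num_kv_heads": {"type": "const", "value": 4},
        "head_dim": {"type": "const", "value": 128},
        "batch_size": {"type": "var"},  # var axes are bound per workload
    },
}
# Two kernels fall under this Definition iff they have exactly these axes,
# with the same const/var roles and the same const values.
```
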
@@ -34,58 +36,58 @@ A **Trace** is an atomic, immutable record of a single benchmark run. It links a

 ## Component 2: `solution`

-**What it is.** A concrete implementation of a Definition’s interface (Triton/CUDA/CUTLASS/PyTorch, etc.) plus metadata including target archs, libraries, author (human or LLM).
+**What it is:** A concrete implementation of a Definition’s interface (Triton/CUDA/CUTLASS/PyTorch, etc.) plus metadata including target archs, libraries, and author (human or LLM).

-**Interface.** Your function must take the Definition’s `inputs` and **return** the tuple of `outputs`.
+**Interface:** Your function must take the Definition’s `inputs` and return the tuple of `outputs`.

-**How to add a Solution.**
+**How to add a Solution:**

 1. Add the implementation of the kernel (matching signature).
 2. Provide metadata co-located with the code, according to the schema.
-3. Add unit tests vs `reference` across a representative shapes.
-
-See agent.md (to be added) for our methods to generate Solutions with LLMs.
+3. Add unit tests vs `reference` across representative shapes.

4949
## Component 3: `workload`
5050

51-
**What it is.** The concrete axes + input data that instantiate a Definition for one run.
51+
**What it is:** The concrete axes + input data that instantiate a Definition for one run.
5252

5353
| Field | Description |
5454
| -------- | --------------------------------------------- |
55-
| `axes` | Map of **var** axis → concrete int value. |
56-
| `inputs` | Map of **input name****actual input**. |
55+
| `axes` | Map of var axis → concrete int value. |
56+
| `inputs` | Map of input name → actual input. |
5757

58-
**How to capture workloads.**
58+
**How to capture workloads:**
5959

60-
### **Env-vars (zero-code):**
60+
### Env-vars (zero-code)
6161

62-
1. **Choose an output dataset root** (optional):
62+
1. Choose an output dataset root (optional):
6363

6464
```bash
65-
export FLASHINFER_BENCH_DATASET_PATH=/root/flashinfer-trace
66-
# defaults to ./flashinfer-trace if unset
65+
export FIB_DATASET_PATH=/root/flashinfer-trace
66+
# defaults to `~/.cache/flashinfer_bench/dataset` if unset
6767
```
6868

69-
2. **Enable tracing and run your engine or script:**
69+
2. Enable tracing and run your engine or script:
7070

7171
```bash
72-
export FLASHINFER_BENCH_ENABLE_TRACING=1
72+
export FIB_ENABLE_TRACING=1
7373
python run_engine.py # your serving or batch script
7474
```
7575

76-
By default, all kernels with a matching **Definition** are traced.
76+
By default, all kernels specified with its [tracing config](https://github.com/flashinfer-ai/flashinfer-bench/blob/main/flashinfer_bench/tracing/builtin/configs.py) with a matching Definition are traced.
7777

78-
3. **What gets saved & where (default layout):**
78+
3. What gets saved & where (default layout):
7979
```
80-
$FLASHINFER_BENCH_DATASET_PATH/
81-
└── workloads/
82-
├── *.jsonl # workload records (FlashInfer Trace format)
83-
└── safetensors/ # tensor payloads (when dumped)
80+
$FIB_DATASET_PATH/
81+
├── workloads/
82+
│ └── <op_type>/
83+
│ └── <definition_name>.jsonl # workload records (FlashInfer Trace format)
84+
└── blob/
85+
└── workloads/ # tensor payloads (safetensors, when dumped)
8486
```
8587

86-
Writing tensors to file is **async** (background thread) to reduce runtime overhead.
88+
Writing tensors to file is async (background thread) to reduce runtime overhead.
8789

88-
### **Tracing in code (fine-grained control)**
90+
### Tracing in code (fine-grained control)
8991

9092
If you want to target a subset of kernels / customize policies:
9193

@@ -95,14 +97,14 @@ import flashinfer_bench as fib
 # 1) Pick which kernels to trace and how
 from flashinfer_bench import TracingConfig

-gqa_tracing = TracingConfig(
-    tensor_dump_policy="dump_non_float",  # keep scalar and int tensors; skip large float payloads
+gqa_paged_prefill_config = TracingConfig(
+    input_dump_policy="dump_non_float",  # keep scalar and int tensors; skip large float payloads
     filter_policy="shape_only",  # save first occurrence per input-shape signature
 )

 configs = {
-    "gqa_paged_decode_h32_kv4_d128_ps1": gqa_tracing,
-    # more kernel definitions...
+    "gqa_paged_prefill_causal_h32_kv4_d128_ps1": gqa_paged_prefill_config,
+    # more tracing config mappings...
 }

 # 2) Enable, run, then finalize
@@ -112,47 +114,44 @@ with fib.enable_tracing(dataset_path="/root/flashinfer-trace", tracing_configs=c

 **Policies you can use right away:**

-* `tensor_dump_policy`: `"dump_all"`, `"dump_none"`, `"dump_non_float"`, or a list of input names to dump.
-* `filter_policy`: `"keep_all"`, `"shape_only"`, `"keep_first_k"` (e.g., first k calls), or a custom callable `Workload -> key`.
-  These reduce disk/time while keeping representative samples.
+* `input_dump_policy`: `"dump_all"`, `"dump_none"`, `"dump_int32"`, or a list of input names to dump, like `input_dump_policy=["qo_indptr", "kv_indptr", "kv_indices", "sm_scale"]`.
+* `filter_policy`: `"keep_all"`, `"keep_first"` (keep the first k calls), `"keep_first_by_axes"`, `"keep_none"`, or a custom callable `Workload -> key`. These reduce disk/time while keeping representative samples.
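
As a sketch of the custom-callable option, the filter below keeps one workload per `batch_size` value. It assumes a `Workload` exposes the `axes` mapping shown earlier; treat it as illustrative rather than documented API surface.

```python
from flashinfer_bench import TracingConfig

# Custom filter: Workload -> key; workloads whose key was already
# seen are skipped, so one sample per batch_size survives.
def key_by_batch(workload):
    return workload.axes.get("batch_size")

dedupe_config = TracingConfig(
    input_dump_policy="dump_none",
    filter_policy=key_by_batch,
)
```
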

 ## Component 4: `evaluation`

-**What it is.** The result bundle for one `(definition, solution, workload)` run.
+**What it is:** The result bundle for one `(definition, solution, workload)` run.

-**How to benchmark to produce Evaluations.**
+**How to benchmark to produce Evaluations:**
 Run the benchmarker over your `(definition, solution, workload)` triples in the dataset:

-CLI:
+Using the CLI:
 ```bash
-flashinfer-bench run --local ./flashinfer-trace --warmup-runs 10 --iterations 50 --save-results
+flashinfer-bench run --local /path/to/flashinfer-trace
 ```

-Use Python API:
-### Prepare a `TraceSet` and Run the benchmark
+Using the Python API:

 ```python
 from flashinfer_bench.data import TraceSet
 from flashinfer_bench.bench import Benchmark
-from flashinfer_bench.bench import BenchmarkConfig

 # 1) Build TraceSet (definitions, solutions, workloads)
 trace_set = TraceSet(root="./flashinfer-trace")  # scans for definitions, solutions, workloads

 # 2) Run the benchmark
-bench = Benchmark(trace_set)
-bench.run_all(dump_traces=True)  # executes reference + solutions in parallel
+benchmark = Benchmark(trace_set, config)  # config: a BenchmarkConfig instance
+benchmark.run_all(save_results=True)
 ```

-* **Device pool.** One `MultiProcessRunner` is created per CUDA device.
-* **Concurrency.** For each definition and workload, the benchmark:
+* **Device pool:** One `MultiProcessRunner` is created per CUDA device.
+* **Concurrency:** For each definition and workload, the benchmark:

   * Picks up to `K = min(#devices, #solutions)` runners (round-robin).
   * **Reference phase:** in parallel, calls `runner.run_ref(defn, wl, config)` to build a baseline on each selected runner.

     * If a runner fails during reference, it is removed from the pool and the workload on that runner is skipped.
   * **Solutions phase:** distributes solutions round-robin across the runners that succeeded in the reference phase, calling `runner.run_solution(sol, baseline_handle, config)` in parallel.
-* **Status mapping.**
+* **Status mapping:**

   * Successful run with numerics in tolerance → `PASSED`.
   * Output shape/dtype mismatch → `INCORRECT_SHAPE` / `INCORRECT_DTYPE`.
@@ -213,7 +212,7 @@ def gemm(A, B):
 **Turn on runtime substitution:**

 ```bash
-export FLASHINFER_BENCH_ENABLE_APPLY=1
+export FIB_ENABLE_APPLY=1
 python serve_or_benchmark.py
 ```
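
A quick smoke test for substitution is to time the wrapped op with and without the flag. A minimal sketch, re-declaring a stand-in `gemm` so the snippet is self-contained (in practice you would call the tutorial's own wrapper from the section above):

```python
import time
import torch

def gemm(A, B):
    # stand-in for the tutorial's gemm wrapper; with FIB_ENABLE_APPLY=1
    # the runtime can substitute a traced Solution for the wrapped op
    return A @ B

A = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
B = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

torch.cuda.synchronize()
t0 = time.perf_counter()
C = gemm(A, B)
torch.cuda.synchronize()
print(f"gemm: {(time.perf_counter() - t0) * 1e3:.2f} ms")
```
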