# Bring Your Own Kernel to FlashInfer-Bench
This guide explains how to add Definitions and Solutions, capture Workloads, and record Evaluations by walking through each **component of the Trace**, with an end-to-end "apply at runtime" flow.
A **Trace** is an atomic, immutable record of a single benchmark run. It links a specific `Solution` to a specific `Definition`, fixes the exact `workload` (input shapes + input data), and stores the complete `evaluation`. A folder of Definitions, Solutions, and Traces is your benchmark database.
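To make this concrete, here is a minimal sketch of a Trace record as a Python dict. The four top-level keys follow the schema below; the nested values are illustrative placeholders, not the exact field layout.

```python
# Hypothetical Trace record (field contents are illustrative placeholders):
trace = {
    "definition": "gemm_prefill_example",  # name of the Definition being benchmarked
    "solution": "gemm_prefill_triton_v1",  # name of the Solution under test
    "workload": {"axes": {"m": 4096}},     # concrete shapes + input data for this run
    "evaluation": {"status": "PASSED"},    # complete result bundle for this run
}
```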
## Trace Schema (top level)
| Field | Type | Required | Description |
| --- | --- | --- | --- |
| `workload` | object | Yes | Concrete shapes and input data used for this run. |

More details about the schema are in [FlashInfer Trace Schema](https://bench.flashinfer.ai/docs/flashinfer_trace/flashinfer_trace).
## Component 1: `definition`
**What it is:** The operator’s contract: axes (const/var), inputs/outputs, constraints, and a correct (not necessarily fast) `reference`.
**Identity rule:** Two kernels belong to the same Definition iff:
* They have the same axes,
* Each axis has the same role (`const` vs `var`),
* All `const` axes have the same values (see the example below).
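For instance, a sketch with hypothetical axis names (not taken from any shipped Definition):

```python
# Kernels A and B: same axes, same roles, same const values -> same Definition.
kernel_a_axes = {
    "head_dim": {"type": "const", "value": 128},
    "batch_size": {"type": "var"},
}
kernel_b_axes = dict(kernel_a_axes)

# Kernel C: same axes and roles, but a different const value -> a separate Definition.
kernel_c_axes = {
    "head_dim": {"type": "const", "value": 64},
    "batch_size": {"type": "var"},
}
```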
**How to add a new kernel Definition:**
1. Refer to the schema, choose a `name` (`<type>_<stage>_<axis tokens>`) and `type`; write a clear `description` and helpful `tags`.
2. Specify `axes` with `type: const|var` (plus `value` for const), as in the sketch below.
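A minimal sketch of a Definition record, using only the fields named above; the names and values are illustrative:

```python
# Hypothetical Definition fragment (illustrative names and values):
definition = {
    "name": "gemm_prefill_example",
    "type": "gemm",
    "description": "Example GEMM definition (illustrative only).",
    "tags": ["example"],
    "axes": {
        "head_dim": {"type": "const", "value": 128},  # fixed for this Definition
        "batch_size": {"type": "var"},                # varies per workload
    },
}
```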
## Component 2: `solution`
**What it is:** A concrete implementation of a Definition’s interface (Triton, CUDA, CUTLASS, PyTorch, etc.) plus metadata including target architectures, libraries, and author (human or LLM).
**Interface:** Your function must take the Definition’s `inputs` and return the tuple of `outputs`.
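As a sketch, assuming a Definition whose inputs are two matrices and whose single output is their product (the function and argument names are illustrative):

```python
import torch

def gemm_example(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor]:
    """Toy Solution: parameters mirror the Definition's `inputs`, in order."""
    out = a @ b    # a correct (not necessarily fast) implementation
    return (out,)  # always return the tuple of `outputs`
```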
**How to add a Solution:**
1. Add the kernel implementation (matching the Definition’s signature).
2. Provide metadata co-located with the code, according to the schema.
3. Add unit tests against `reference` across representative shapes (a sketch follows this list).
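A minimal unit-test sketch, assuming the toy `gemm_example` above and a `reference` implemented in plain PyTorch (both hypothetical):

```python
import torch

def reference(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor]:
    return (a @ b,)

def test_gemm_example_matches_reference():
    for m, k, n in [(1, 64, 64), (128, 256, 512)]:  # representative shapes
        a, b = torch.randn(m, k), torch.randn(k, n)
        (got,), (want,) = gemm_example(a, b), reference(a, b)
        torch.testing.assert_close(got, want)
```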
## Component 3: `workload`
**What it is:** The concrete axes + input data that instantiate a Definition for one run.

**How to capture Workloads:**

1. Set the dataset path (optional):

```bash
export FIB_DATASET_PATH=/path/to/flashinfer-trace  # defaults to `~/.cache/flashinfer_bench/dataset` if unset
```
2. Enable tracing and run your engine or script:
```bash
export FIB_ENABLE_TRACING=1
python run_engine.py  # your serving or batch script
```
By default, all kernels that have a built-in [tracing config](https://github.com/flashinfer-ai/flashinfer-bench/blob/main/flashinfer_bench/tracing/builtin/configs.py) and a matching Definition are traced.
3. What gets saved & where (default layout):
```
$FIB_DATASET_PATH/
├── workloads/
│   └── <op_type>/
│       └── <definition_name>.jsonl   # workload records (FlashInfer Trace format)
└── blob/
    └── workloads/                    # tensor payloads (safetensors, when dumped)
```
Writing tensors to file is async (background thread) to reduce runtime overhead.
### Tracing in code (fine-grained control)
If you want to target only a subset of kernels or customize tracing policies:
```python
import flashinfer_bench as fib

# 1) Pick which kernels to trace and how
from flashinfer_bench import TracingConfig

gqa_paged_prefill_config = TracingConfig(
    input_dump_policy="dump_non_float",  # keep scalar and int tensors; skip large float payloads
    filter_policy="shape_only",          # save first occurrence per input-shape signature
)

# 2) Enable tracing while your engine or script runs
# (the `tracing_configs` value is truncated in this excerpt)
with fib.enable_tracing(dataset_path="/root/flashinfer-trace", tracing_configs=...):
    ...
```
**Policies you can use right away:**
* `input_dump_policy`: `"dump_all"`, `"dump_none"`, `"dump_int32"`, or a list of input names to dump, e.g. `input_dump_policy=["qo_indptr", "kv_indptr", "kv_indices", "sm_scale"]`.
* `filter_policy`: `"keep_all"`, `"keep_first"` (e.g., first k calls), `"keep_first_by_axes"`, `"keep_none"`, or a custom callable `Workload -> key` (sketched below). These reduce disk/time while keeping representative samples.
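For the custom callable form, a sketch under the assumption that workloads returning the same key are deduplicated; the `axes` attribute here is illustrative, not a confirmed field name:

```python
def one_per_batch_size(workload) -> str:
    """Hypothetical `Workload -> key` filter: keep one sample per batch size."""
    return str(workload.axes.get("batch_size"))  # assumed field; adjust to the real Workload layout

my_config = TracingConfig(filter_policy=one_per_batch_size)
```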
## Component 4: `evaluation`
**What it is:** The result bundle for one `(definition, solution, workload)` run.
**How to benchmark to produce Evaluations:**

Run the benchmarker over your `(definition, solution, workload)` triples in the dataset:
Using the CLI:
```bash
flashinfer-bench run --local /path/to/flashinfer-trace
```
Using the Python API:
```python
from flashinfer_bench.data import TraceSet
from flashinfer_bench.bench import Benchmark

# 1) Load the dataset
trace_set = TraceSet(root="./flashinfer-trace")  # scans for definitions, solutions, workloads

# 2) Run the benchmark (`config` construction is elided in this excerpt)
benchmark = Benchmark(trace_set, config)
benchmark.run_all(save_results=True)
```
* **Device pool:** One `MultiProcessRunner` is created per CUDA device.
* **Concurrency:** For each definition and workload, the benchmark:
  * Picks up to `K = min(#devices, #solutions)` runners (round-robin).
  * **Reference phase:** in parallel, calls `runner.run_ref(defn, wl, config)` to build a baseline on each selected runner.
  * If a runner fails during reference, it is removed from the pool and the workload on that runner is skipped.
  * **Solutions phase:** distributes solutions round-robin across the runners that succeeded in the reference phase, calling `runner.run_solution(sol, baseline_handle, config)` in parallel.
* **Status mapping:**
  * A successful run with numerics within tolerance → `PASSED`.