Commit 583f825

Revise 'apply' documentation for FlashInfer integration (#86)

Updated the documentation for the 'apply' feature in FlashInfer, enhancing clarity and detail on usage, runtime substitution, and custom integration patterns.

1 parent ea72b85

File tree

1 file changed

+79
-103
lines changed

1 file changed

+79
-103
lines changed

docs/tutorials/bring_your_own_kernel.mdx

Lines changed: 79 additions & 103 deletions
Original file line numberDiff line numberDiff line change
After benchmarking is done, the results can be used to rank solutions, visualize …

* Use runtime substitution to dispatch to the **best** ranked Solution for the current shapes.

## End-to-end "apply"

With `apply`, we can dynamically replace the kernels in the FlashInfer API with the best-performing ones from our traces. Because adapters are already written for FlashInfer, you can enable the integration with minimal code changes:

```bash
export FIB_ENABLE_APPLY=1
export FIB_DATASET_PATH=/path/to/flashinfer-trace
python serve_or_benchmark.py
```

At call time, `apply` looks up the Definition, matches the current workload (axes and input data properties), and dispatches to the best Solution according to our Traces (with correctness constraints and numeric tolerances enforced).

### Using adapters to support kernels that don't align with the Definition

Sometimes your production call site can't be decorated directly, for example wrappers such as `BatchPrefillWithPagedKVCacheWrapper` that keep internal state across `plan()`/`run()`. FlashInfer-Bench provides built-in adapters for common FlashInfer kernels, and you can also use the imperative `apply()` API for custom integration patterns.

#### Built-in FlashInfer Integration (Recommended)

FlashInfer-Bench automatically patches common FlashInfer kernels when apply is enabled; no manual decoration is needed.

**How it works:** When you call `enable_apply()`, FlashInfer-Bench automatically installs lightweight adapters that:

1. Intercept FlashInfer wrapper methods (`plan` and `run`)
2. Extract runtime parameters and match them to definitions
3. Dispatch to the best-performing solution from your traces
4. Fall back to the original FlashInfer implementation if no suitable solution exists
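
For example, a minimal programmatic sketch of this flow; the definition name, `ApplyConfig` tolerances, and `run_engine()` entry point are illustrative placeholders:

```python
from flashinfer_bench import enable_apply, ApplyConfig

# Optional per-definition overrides; defaults are fine for most kernels.
apply_cfgs = {
    "gqa_paged_decode_h32_kv4_d128_ps1": ApplyConfig(max_atol=1e-5, max_rtol=1e-5),
}

# Inside this context, patched FlashInfer calls dispatch to the best traced Solution.
with enable_apply(apply_configs=apply_cfgs):
    run_engine()  # placeholder for your serving or benchmarking loop
```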

**Supported kernels:**

- `flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` (page_size=1)
- `flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` (causal=True, page_size=1)
- `flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper` (causal=True)
- `flashinfer.norm.fused_add_rmsnorm`

See `flashinfer_bench/integration/flashinfer/adapters/` for the complete list and implementation details.

#### Imperative `apply()` API (Custom Integration)

For custom kernels or integration patterns not covered by the built-in adapters, use the function form of `apply`:

```python
from flashinfer_bench import apply

result = apply(
    def_name_or_resolver,  # Union[str, Callable[..., str]]
    runtime_kwargs,        # Dict[str, Any]; keys must follow the kernel definition's interface
    fallback=None,         # Optional[Callable[..., Any]]
)
```

**Parameters:**

- `def_name_or_resolver`: The kernel definition name (e.g., `"gemm_bf16"`) or a resolver function that maps runtime arguments to a definition name.
- `runtime_kwargs`: Dictionary of keyword arguments to pass to the selected kernel. Keys must match the kernel definition's interface.
- `fallback`: Optional function to invoke when no matching kernel is found in the Trace database.
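
For instance, here is a sketch of an imperative call for a hypothetical GEMM definition whose name encodes the weight shape; the definition name, argument names, and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

from flashinfer_bench import apply

A = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
B = torch.randn(14336, 4096, dtype=torch.bfloat16, device="cuda")

# Resolve the Definition from runtime shapes; fall back to a reference
# implementation when no traced Solution matches.
out = apply(
    lambda A, B: f"gemm_n_{B.shape[0]}_k_{B.shape[1]}",
    runtime_kwargs={"A": A, "B": B},
    fallback=lambda A, B: F.linear(A, B),
)
```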

#### Example: Creating custom adapters (advanced)

If you want to create reusable adapters similar to the built-in FlashInfer integrations, study the real implementations:

- `flashinfer_bench/integration/flashinfer/adapters/gqa_paged_decode.py`
- `flashinfer_bench/integration/flashinfer/adapters/rmsnorm.py`

Key pattern:

1. Use `ContextStore` to preserve state across `plan()`/`run()` calls
2. Extract parameters in the `plan` wrapper and store them in context
3. In the `run` wrapper, retrieve stored params and call `apply()` with `runtime_kwargs`
4. Provide a fallback lambda that calls the original implementation
5. Register your adapter with the `PatchManager`

Example structure from the RMSNorm adapter:

```python
import torch

from flashinfer_bench.apply import apply
from flashinfer_bench.integration.patch_manager import PatchSpec
from flashinfer_bench.integration.utils import ArgBinder


def _def_name_resolver(weight):
    # Definitions are specialized by hidden size, e.g. "fused_add_rmsnorm_h4096".
    return f"fused_add_rmsnorm_h{weight.shape[0]}"


class RMSNormAdapter:
    def targets(self):
        return [
            PatchSpec(
                path="flashinfer.norm.fused_add_rmsnorm",
                kind="function",
                name="fused_add_rmsnorm",
                ctx_key="rmsnorm",
            )
        ]

    def make_wrapper(self, spec, orig):
        binder = ArgBinder.from_callable(orig)

        def wrapper(*args, **kwargs):
            bound = binder.bind(args, kwargs)

            # Compatibility checks: only substitute when the definition applies.
            if bound["input"].dtype != torch.bfloat16:
                return orig(*args, **kwargs)

            # Map FlashInfer's argument names onto the definition's interface.
            rk = {
                "hidden_states": bound["input"],
                "residual": bound["residual"],
                "weight": bound["weight"],
            }

            return apply(
                _def_name_resolver,
                runtime_kwargs=rk,
                fallback=lambda **_: orig(*args, **kwargs),
            )

        return wrapper
```
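
Once an adapter like this is registered with the `PatchManager` and apply is enabled, existing call sites need no changes; a call such as the following is intercepted transparently (shapes and device are illustrative):

```python
import torch
import flashinfer

h = 4096
x = torch.randn(8, h, dtype=torch.bfloat16, device="cuda")
residual = torch.randn(8, h, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(h, dtype=torch.bfloat16, device="cuda")

# With apply enabled, this dispatches to the best traced Solution for
# "fused_add_rmsnorm_h4096" and falls back to FlashInfer's kernel otherwise.
flashinfer.norm.fused_add_rmsnorm(x, residual, weight)
```

Note that `fused_add_rmsnorm` updates `x` and `residual` in place, so there is no return value to capture.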