Skip to content

Commit fa236b6

Browse files
authored
refactor: Builder System (#120)
This PR aims to simplify the logic of all builders and provide a unified abstraction. It also provides a simpler abstraction for Runnable.

In addition, it:

1. Renames CUDABuilder to TorchBuilder and adds CPP support to it.
2. Enhances the CI unit tests to reduce execution time.
3. Removes get_builder_registry and provides the class-level singleton accessor BuilderRegistry.get_instance() instead.
4. Provides a hash function for Solution.
5. Provides utilities for tests.

Signed-off-by: Ubospica <[email protected]>

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added TorchBuilder for compiling CUDA solutions via PyTorch extensions.
  * Introduced Solution hashing for deterministic caching and entry-point extraction utilities.
* **Bug Fixes**
  * Improved runtime rank validation and error handling with better error messages.
  * Removed legacy CUDA builder; replaced with TorchBuilder.
* **Improvements**
  * Simplified builder registry access pattern.
  * Enhanced metadata tracking and cleanup workflows.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Ubospica <[email protected]>
1 parent 2a7154c commit fa236b6

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+2070
-1416
lines changed

examples/fi_gqa_e2e_example.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import flashinfer
2+
import torch
3+
14
from flashinfer_bench.apply import ApplyConfig, enable_apply
25
from flashinfer_bench.data import (
36
BuildSpec,
@@ -16,11 +19,11 @@
1619
Workload,
1720
)
1821

19-
ts = TraceSet.from_path("../flashinfer-trace") # path to your flashinfer-trace
22+
trace_set = TraceSet.from_path("../flashinfer-trace") # path to your flashinfer-trace
2023

2124
# Add a pseudo solution with an unrealistically large speedup so it gets selected
2225
def_name = "gqa_paged_prefill_causal_h32_kv8_d128_ps1"
23-
if def_name in ts.definitions:
26+
if def_name in trace_set.definitions:
2427
pseudo_code = (
2528
"import torch\n\n"
2629
"def run(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale):\n"
@@ -36,7 +39,7 @@
3639
definition=def_name,
3740
author="pseudo",
3841
spec=BuildSpec(
39-
language=SupportedLanguages.PYTHON, target_hardware=["gpu"], entry_point="main.py::run"
42+
language=SupportedLanguages.PYTHON, target_hardware=["cuda"], entry_point="main.py::run"
4043
),
4144
sources=[SourceFile(path="main.py", content=pseudo_code)],
4245
description="Pseudo solution that prints greeting and returns zero tensors.",
@@ -77,17 +80,13 @@
7780
),
7881
)
7982

80-
ts.solutions.setdefault(def_name, []).append(pseudo_sol)
81-
ts.traces.setdefault(def_name, []).append(pseudo_trace)
83+
trace_set.solutions.setdefault(def_name, []).append(pseudo_sol)
84+
trace_set.traces.setdefault(def_name, []).append(pseudo_trace)
8285

8386
# Enable apply against the in-memory augmented trace set
84-
enable_apply(ts, ApplyConfig(on_miss_policy="use_def_best"))
85-
86-
import flashinfer
87+
enable_apply(trace_set, ApplyConfig(on_miss_policy="use_def_best"))
8788

8889
# FlashInfer official example
89-
import torch
90-
9190
num_layers = 32
9291
num_qo_heads = 32
9392
num_kv_heads = 8

examples/win_at_p.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
Notes:
2727
- If multiple runs exist for the same author within a group, we take the MIN latency for that author in that group.
2828
- By default, the baseline author ('flashinfer') is EXCLUDED from output curves; use --include-baseline to include it.
29-
3029
"""
3130

3231
import argparse

flashinfer_bench/apply/key.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class ApplyKey:
1313
axes: Tuple[Tuple[str, int], ...] = field(default_factory=tuple)
1414
# Features extracted from input tensors
15-
feats: Tuple[Tuple[str, Union[int, Union[float, bool]]], ...] = field(default_factory=tuple)
15+
feats: Tuple[Tuple[str, Union[int, float, bool]], ...] = field(default_factory=tuple)
1616

1717
def encode(self) -> str:
1818
return json.dumps(
@@ -39,10 +39,10 @@ def __eq__(self, other: Any) -> bool:
3939

4040

4141
class ApplyKeyBuilder(ABC):
42-
def __init__(self, defn: Definition) -> None:
43-
self.defn = defn
42+
def __init__(self, definition: Definition) -> None:
43+
self.definition = definition
4444
# axis -> (input, dim_idx)
45-
self._axis_proj: Dict[str, Tuple[str, int]] = self._collect_var_axis_projections(defn)
45+
self._axis_proj: Dict[str, Tuple[str, int]] = self._collect_var_axis_projections(definition)
4646

4747
@abstractmethod
4848
def build_from_runtime(self, runtime_kwargs: Dict[str, Any]) -> ApplyKey:
@@ -59,14 +59,14 @@ def features(self, runtime_kwargs: Dict[str, Any]) -> Tuple[Tuple[str, Any], ...
5959
"""Lightweight feature extraction"""
6060
...
6161

62-
def _collect_var_axis_projections(self, defn: Definition) -> Dict[str, Tuple[str, int]]:
62+
def _collect_var_axis_projections(self, definition: Definition) -> Dict[str, Tuple[str, int]]:
6363
"""
6464
Iterate over the shape of inputs, find the first occurrence of each var axis:
6565
axis_name -> (input_name, dim_idx)
6666
"""
6767
proj: Dict[str, Tuple[str, int]] = {}
68-
axis_defs = defn.axes
69-
inputs = defn.inputs
68+
axis_defs = definition.axes
69+
inputs = definition.inputs
7070

7171
for inp_name, spec in inputs.items():
7272
shape = spec.shape
@@ -82,15 +82,15 @@ def _collect_var_axis_projections(self, defn: Definition) -> Dict[str, Tuple[str
8282
var_axes = [k for k, v in axis_defs.items() if isinstance(v, AxisVar)]
8383
missing = [ax for ax in var_axes if ax not in proj]
8484
if missing:
85-
raise ValueError(f"Cannot locate var axes {missing} from inputs of '{defn.name}'")
85+
raise ValueError(f"Cannot locate var axes {missing} from inputs of '{definition.name}'")
8686
return proj
8787

8888
def _materialize_axes(self, runtime_kwargs: Dict[str, Any]) -> Dict[str, int]:
8989
axes: Dict[str, int] = {}
9090
for axis, (inp, dim_idx) in self._axis_proj.items():
9191
if inp not in runtime_kwargs:
9292
raise KeyError(
93-
f"Missing runtime input '{inp}' for axis '{axis}' in '{self.defn.name}'"
93+
f"Missing runtime input '{inp}' for axis '{axis}' in '{self.definition.name}'"
9494
)
9595
val = runtime_kwargs[inp]
9696
shape = val.shape
@@ -144,9 +144,9 @@ def for_type(cls, type_name: str) -> Type[ApplyKeyBuilder]:
144144
return cls._REGISTRY.get(type_name, AxesOnlyKeyBuilder)
145145

146146
@classmethod
147-
def specialize(cls, defn: Definition) -> ApplyKeyBuilder:
148-
builder_cls = cls.for_type(defn.op_type)
149-
return builder_cls(defn)
147+
def specialize(cls, definition: Definition) -> ApplyKeyBuilder:
148+
builder_cls = cls.for_type(definition.op_type)
149+
return builder_cls(definition)
150150

151151

152152
ApplyKeyFactory.register("gemm", GEMMKeyBuilder)

flashinfer_bench/apply/runtime.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from typing import Any, Callable, Dict, Mapping, Optional
1212

13-
from flashinfer_bench.compile import get_builder_registry
13+
from flashinfer_bench.compile import BuilderRegistry
1414
from flashinfer_bench.data import TraceSet
1515
from flashinfer_bench.env import get_fib_dataset_path, get_fib_enable_apply
1616
from flashinfer_bench.logging import get_logger
@@ -156,34 +156,34 @@ def dispatch(
156156
pass
157157

158158
# Then try to run apply logic
159-
defn = self._trace_set.definitions.get(def_name)
160-
if defn is None:
159+
definition = self._trace_set.definitions.get(def_name)
160+
if definition is None:
161161
if fallback is None:
162162
raise RuntimeError(f"Definition '{def_name}' not found and no fallback provided")
163163
return fallback(**runtime_kwargs)
164164

165165
# Build key
166-
builder = self._key_builders.get(defn.name)
166+
builder = self._key_builders.get(definition.name)
167167
if builder is None:
168-
builder = ApplyKeyFactory.specialize(defn)
169-
self._key_builders[defn.name] = builder
168+
builder = ApplyKeyFactory.specialize(definition)
169+
self._key_builders[definition.name] = builder
170170
key = builder.build_from_runtime(runtime_kwargs)
171171

172172
# Lookup solution
173173
sol_name = self._table.match_solution(def_name, key)
174174
runnable = None
175175
if sol_name:
176-
sol = self._trace_set.get_solution(sol_name)
177-
if sol:
178-
runnable = get_builder_registry().build(defn, sol)
176+
solution = self._trace_set.get_solution(sol_name)
177+
if solution:
178+
runnable = BuilderRegistry.get_instance().build(definition, solution)
179179

180180
# Miss policy
181181
if runnable is None:
182182
if self._apply_config.on_miss_policy == "use_def_best":
183183
best_sol_name = self._table.def_best.get(def_name)
184-
sol = self._trace_set.get_solution(best_sol_name)
185-
if defn and sol:
186-
runnable = get_builder_registry().build(defn, sol)
184+
solution = self._trace_set.get_solution(best_sol_name)
185+
if definition and solution:
186+
runnable = BuilderRegistry.get_instance().build(definition, solution)
187187
if runnable is not None:
188188
return runnable(**runtime_kwargs)
189189
if fallback is None:

flashinfer_bench/apply/table.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from typing import Any, Dict, List, Optional, Tuple
99

10-
from flashinfer_bench.compile import Runnable, get_builder_registry
10+
from flashinfer_bench.compile import BuilderRegistry, Runnable
1111
from flashinfer_bench.data import Trace, TraceSet
1212
from flashinfer_bench.env import get_fib_cache_path
1313

@@ -115,13 +115,13 @@ def load_or_build(cls, trace_set: TraceSet, apply_config: ApplyConfig) -> "Apply
115115
index[def_name] = bucket
116116

117117
def_best: Dict[str, str] = {}
118-
reg = get_builder_registry()
118+
reg = BuilderRegistry.get_instance()
119119

120120
for def_name, sol_name in raw["def_best"].items():
121-
defn = trace_set.definitions.get(def_name)
122-
sol = trace_set.get_solution(sol_name)
123-
if defn and sol:
124-
reg.build(defn, sol)
121+
definition = trace_set.definitions.get(def_name)
122+
solution = trace_set.get_solution(sol_name)
123+
if definition and solution:
124+
reg.build(definition, solution)
125125
def_best[def_name] = sol_name
126126

127127
table = cls(digest=digest, index=index, def_best=def_best)
@@ -169,12 +169,12 @@ def _build(cls, trace_set: TraceSet, apply_config: ApplyConfig) -> "ApplyTable":
169169
The newly built apply table.
170170
"""
171171
digest = cls._digest(trace_set, apply_config)
172-
reg = get_builder_registry()
172+
reg = BuilderRegistry.get_instance()
173173

174174
index: Dict[str, Dict[ApplyKey, str]] = {}
175175
def_best: Dict[str, Runnable] = {}
176176

177-
for def_name, defn in trace_set.definitions.items():
177+
for def_name, definition in trace_set.definitions.items():
178178
per_key, ranked = cls._sweep_def(
179179
trace_set, def_name, apply_config.max_atol, apply_config.max_rtol
180180
)
@@ -189,11 +189,11 @@ def _build(cls, trace_set: TraceSet, apply_config: ApplyConfig) -> "ApplyTable":
189189
# Build def_best
190190
if ranked:
191191
best_sol_name = ranked[0][0]
192-
sol = trace_set.get_solution(best_sol_name)
193-
if sol:
192+
solution = trace_set.get_solution(best_sol_name)
193+
if solution:
194194
if apply_config.on_miss_policy == "use_def_best":
195195
# Only AOT if on_miss_policy is use_def_best
196-
reg.build(defn, sol)
196+
reg.build(definition, solution)
197197
def_best[def_name] = best_sol_name
198198

199199
return cls(digest=digest, index=index, def_best=def_best)
@@ -267,7 +267,7 @@ def _prewarm_aot(cls, trace_set: TraceSet, config: ApplyConfig, table: "ApplyTab
267267
"""
268268
if not (config.aot_ratio and config.aot_ratio > 0.0):
269269
return
270-
reg = get_builder_registry()
270+
reg = BuilderRegistry.get_instance()
271271

272272
for def_name, bucket in table.index.items():
273273
if not bucket:
@@ -277,20 +277,20 @@ def _prewarm_aot(cls, trace_set: TraceSet, config: ApplyConfig, table: "ApplyTab
277277
ranked = sorted(win_counts.items(), key=lambda kv: kv[1], reverse=True)
278278
cutoff = max(1, int(len(ranked) * config.aot_ratio))
279279

280-
defn = trace_set.definitions.get(def_name)
281-
if not defn:
280+
definition = trace_set.definitions.get(def_name)
281+
if not definition:
282282
continue
283283
for sol_name, _ in ranked[:cutoff]:
284-
sol = trace_set.get_solution(sol_name)
285-
if sol:
286-
reg.build(defn, sol)
284+
solution = trace_set.get_solution(sol_name)
285+
if solution:
286+
reg.build(definition, solution)
287287

288288
if config.on_miss_policy == "use_def_best":
289289
for def_name, sol_name in table.def_best.items():
290-
defn = trace_set.definitions.get(def_name)
291-
sol = trace_set.get_solution(sol_name)
292-
if defn and sol:
293-
reg.build(defn, sol)
290+
definition = trace_set.definitions.get(def_name)
291+
solution = trace_set.get_solution(sol_name)
292+
if definition and solution:
293+
reg.build(definition, solution)
294294

295295
@classmethod
296296
def _digest(cls, trace_set: TraceSet, config: ApplyConfig) -> str:
@@ -313,23 +313,23 @@ def _digest(cls, trace_set: TraceSet, config: ApplyConfig) -> str:
313313
SHA256 hash digest as a hexadecimal string.
314314
"""
315315
d = trace_set.to_dict()
316-
for defn in d["definitions"].values():
316+
for definition in d["definitions"].values():
317317
for drop in ("description", "tags", "reference", "constraints"):
318-
defn.pop(drop, None)
318+
definition.pop(drop, None)
319319
for sol_list in d["solutions"].values():
320-
for sol in sol_list:
321-
spec = sol.get("spec", {}) or {}
320+
for solution in sol_list:
321+
spec = solution.get("spec", {}) or {}
322322
deps = spec.get("dependencies") or []
323323
spec["dependencies"] = sorted(deps)
324324
new_sources = []
325-
for sf in sol.get("sources") or []:
325+
for sf in solution.get("sources") or []:
326326
new_sources.append(
327327
{
328328
"path": sf["path"],
329329
"sha1": hashlib.sha1(sf["content"].encode("utf-8")).hexdigest(),
330330
}
331331
)
332-
sol["sources"] = new_sources
332+
solution["sources"] = new_sources
333333
kept_traces: List[Dict[str, Any]] = []
334334
for traces in d["traces"].values():
335335
for trace in traces:

flashinfer_bench/bench/benchmark.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from typing import List
55

6-
from flashinfer_bench.compile import get_builder_registry
6+
from flashinfer_bench.compile import BuilderRegistry
77
from flashinfer_bench.data import EvaluationStatus, Trace, TraceSet
88
from flashinfer_bench.logging import get_logger
99

@@ -46,7 +46,7 @@ def __init__(self, trace_set: TraceSet, config: BenchmarkConfig = None) -> None:
4646
self._runner = PersistentRunner(logger, self._config.log_dir)
4747

4848
# Setup registry
49-
self._registry = get_builder_registry()
49+
self._registry = BuilderRegistry.get_instance()
5050

5151
def get_trace_set(self) -> TraceSet:
5252
"""Get the TraceSet associated with this benchmark.
@@ -79,8 +79,8 @@ def run_all(self, dump_traces: bool = True, resume: bool = False) -> TraceSet:
7979
definitions_to_run = self._trace_set.definitions.items()
8080
if self._config.definitions is not None:
8181
definitions_to_run = [
82-
(name, defn)
83-
for name, defn in definitions_to_run
82+
(name, definition)
83+
for name, definition in definitions_to_run
8484
if name in self._config.definitions
8585
]
8686
provided_defs = set(self._config.definitions)
@@ -89,7 +89,7 @@ def run_all(self, dump_traces: bool = True, resume: bool = False) -> TraceSet:
8989
if missing_defs:
9090
logger.warning(f"Definitions not found in trace set: {sorted(missing_defs)}")
9191

92-
for def_name, defn in definitions_to_run:
92+
for def_name, definition in definitions_to_run:
9393
sols = self._trace_set.solutions.get(def_name, [])
9494
if not sols:
9595
logger.warning(f"No solutions found for def={def_name}, skipping definition")
@@ -128,7 +128,7 @@ def run_all(self, dump_traces: bool = True, resume: bool = False) -> TraceSet:
128128

129129
try:
130130
results = self._runner.run_workload(
131-
defn, wl, sols_to_run, self._config, self._trace_set.root
131+
definition, wl, sols_to_run, self._config, self._trace_set.root
132132
)
133133
except RuntimeError as e:
134134
logger.error(f"Failed to run workload {wl.uuid}: {e}")

0 commit comments

Comments
 (0)