Skip to content

Commit d11e917

Browse files
committed
finish
Signed-off-by: Ubospica <[email protected]>
1 parent 2a7154c commit d11e917

29 files changed

+1135
-754
lines changed

examples/win_at_p.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
Notes:
2727
- If multiple runs exist for the same author within a group, we take the MIN latency for that author in that group.
2828
- By default, the baseline author ('flashinfer') is EXCLUDED from output curves; use --include-baseline to include it.
29-
3029
"""
3130

3231
import argparse

flashinfer_bench/apply/key.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class ApplyKey:
1313
axes: Tuple[Tuple[str, int], ...] = field(default_factory=tuple)
1414
# Features extracted from input tensors
15-
feats: Tuple[Tuple[str, Union[int, Union[float, bool]]], ...] = field(default_factory=tuple)
15+
feats: Tuple[Tuple[str, Union[int, float, bool]], ...] = field(default_factory=tuple)
1616

1717
def encode(self) -> str:
1818
return json.dumps(

flashinfer_bench/apply/runtime.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from typing import Any, Callable, Dict, Mapping, Optional
1212

13-
from flashinfer_bench.compile import get_builder_registry
13+
from flashinfer_bench.compile import BuilderRegistry
1414
from flashinfer_bench.data import TraceSet
1515
from flashinfer_bench.env import get_fib_dataset_path, get_fib_enable_apply
1616
from flashinfer_bench.logging import get_logger
@@ -175,15 +175,15 @@ def dispatch(
175175
if sol_name:
176176
sol = self._trace_set.get_solution(sol_name)
177177
if sol:
178-
runnable = get_builder_registry().build(defn, sol)
178+
runnable = BuilderRegistry.get_instance().build(defn, sol)
179179

180180
# Miss policy
181181
if runnable is None:
182182
if self._apply_config.on_miss_policy == "use_def_best":
183183
best_sol_name = self._table.def_best.get(def_name)
184184
sol = self._trace_set.get_solution(best_sol_name)
185185
if defn and sol:
186-
runnable = get_builder_registry().build(defn, sol)
186+
runnable = BuilderRegistry.get_instance().build(defn, sol)
187187
if runnable is not None:
188188
return runnable(**runtime_kwargs)
189189
if fallback is None:

flashinfer_bench/apply/table.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from typing import Any, Dict, List, Optional, Tuple
99

10-
from flashinfer_bench.compile import Runnable, get_builder_registry
10+
from flashinfer_bench.compile import BuilderRegistry, Runnable
1111
from flashinfer_bench.data import Trace, TraceSet
1212
from flashinfer_bench.env import get_fib_cache_path
1313

@@ -115,7 +115,7 @@ def load_or_build(cls, trace_set: TraceSet, apply_config: ApplyConfig) -> "Apply
115115
index[def_name] = bucket
116116

117117
def_best: Dict[str, str] = {}
118-
reg = get_builder_registry()
118+
reg = BuilderRegistry.get_instance()
119119

120120
for def_name, sol_name in raw["def_best"].items():
121121
defn = trace_set.definitions.get(def_name)
@@ -169,7 +169,7 @@ def _build(cls, trace_set: TraceSet, apply_config: ApplyConfig) -> "ApplyTable":
169169
The newly built apply table.
170170
"""
171171
digest = cls._digest(trace_set, apply_config)
172-
reg = get_builder_registry()
172+
reg = BuilderRegistry.get_instance()
173173

174174
index: Dict[str, Dict[ApplyKey, str]] = {}
175175
def_best: Dict[str, Runnable] = {}
@@ -267,7 +267,7 @@ def _prewarm_aot(cls, trace_set: TraceSet, config: ApplyConfig, table: "ApplyTab
267267
"""
268268
if not (config.aot_ratio and config.aot_ratio > 0.0):
269269
return
270-
reg = get_builder_registry()
270+
reg = BuilderRegistry.get_instance()
271271

272272
for def_name, bucket in table.index.items():
273273
if not bucket:

flashinfer_bench/bench/benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from typing import List
55

6-
from flashinfer_bench.compile import get_builder_registry
6+
from flashinfer_bench.compile import BuilderRegistry
77
from flashinfer_bench.data import EvaluationStatus, Trace, TraceSet
88
from flashinfer_bench.logging import get_logger
99

@@ -46,7 +46,7 @@ def __init__(self, trace_set: TraceSet, config: BenchmarkConfig = None) -> None:
4646
self._runner = PersistentRunner(logger, self._config.log_dir)
4747

4848
# Setup registry
49-
self._registry = get_builder_registry()
49+
self._registry = BuilderRegistry.get_instance()
5050

5151
def get_trace_set(self) -> TraceSet:
5252
"""Get the TraceSet associated with this benchmark.

flashinfer_bench/bench/evaluators/default.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
make_eval,
1717
normalize_outputs,
1818
)
19-
from flashinfer_bench.compile.registry import get_builder_registry
20-
from flashinfer_bench.compile.runnable import Runnable
19+
from flashinfer_bench.compile import BuilderRegistry, Runnable
2120
from flashinfer_bench.data.definition import Definition
2221
from flashinfer_bench.data.trace import (
2322
Correctness,
@@ -46,7 +45,7 @@ def build_baseline(
4645
traceset_root: Optional[Path] = None,
4746
) -> DeviceBaseline:
4847
output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
49-
ref_runnable = get_builder_registry().build_reference(defn)
48+
ref_runnable = BuilderRegistry.get_instance().build_reference(defn)
5049
loaded_stensors = (
5150
load_safetensors(defn, workload, traceset_root)
5251
if any(d.type == "safetensors" for d in workload.inputs.values())

flashinfer_bench/bench/evaluators/sampling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
make_eval,
1919
normalize_outputs,
2020
)
21-
from flashinfer_bench.compile.registry import get_builder_registry
22-
from flashinfer_bench.compile.runnable import Runnable
21+
from flashinfer_bench.compile import BuilderRegistry, Runnable
2322
from flashinfer_bench.data.definition import Definition
2423
from flashinfer_bench.data.trace import Correctness, Evaluation, EvaluationStatus, Workload
2524
from flashinfer_bench.utils import dtype_str_to_torch_dtype
@@ -42,7 +41,7 @@ def build_baseline(
4241
device: str,
4342
traceset_root: Optional[Path] = None,
4443
) -> DeviceBaseline:
45-
ref_runnable = get_builder_registry().build_reference(defn)
44+
ref_runnable = BuilderRegistry.get_instance().build_reference(defn)
4645
loaded_stensors = (
4746
load_safetensors(defn, workload, traceset_root)
4847
if any(d.type == "safetensors" for d in workload.inputs.values())

flashinfer_bench/bench/runner/isolated_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from flashinfer_bench.bench.config import BenchmarkConfig
1515
from flashinfer_bench.bench.evaluators import resolve_evaluator
1616
from flashinfer_bench.bench.utils import make_eval
17-
from flashinfer_bench.compile import Runnable, get_builder_registry
17+
from flashinfer_bench.compile import BuilderRegistry, Runnable
1818
from flashinfer_bench.data import Definition, Evaluation, EvaluationStatus, Solution, Workload
1919
from flashinfer_bench.logging import get_logger
2020
from flashinfer_bench.utils import redirect_stdio_to_file
@@ -230,7 +230,7 @@ def _solution_worker_main(
230230
redirect_stdio_to_file(log_path)
231231
try:
232232
torch.cuda.set_device(int(device.split(":")[1]))
233-
registry = get_builder_registry()
233+
registry = BuilderRegistry.get_instance()
234234

235235
# Handshake
236236
conn.send({"cmd": "READY"})

flashinfer_bench/bench/runner/persistent_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from flashinfer_bench.bench.config import BenchmarkConfig
1717
from flashinfer_bench.bench.evaluators import resolve_evaluator
1818
from flashinfer_bench.bench.utils import make_eval
19-
from flashinfer_bench.compile import BuildError, get_builder_registry
19+
from flashinfer_bench.compile import BuilderRegistry, BuildError
2020
from flashinfer_bench.data import Definition, Evaluation, EvaluationStatus, Solution, Workload
2121
from flashinfer_bench.logging import get_logger
2222
from flashinfer_bench.utils import redirect_stdio_to_file
@@ -65,7 +65,7 @@ def __init__(self, device: str, log_dir: str = "/tmp/flashinfer_bench") -> None:
6565
self._device = device
6666
self._log_dir = log_dir
6767
self._baselines: Dict[BaselineHandle, DeviceBaseline] = {}
68-
self._registry = get_builder_registry()
68+
self._registry = BuilderRegistry.get_instance()
6969

7070
# Solution failure tracking
7171
self._failure_records: Dict[str, SolutionFailureRecord] = {}
@@ -648,7 +648,7 @@ def _persistent_worker_main(conn: mp.connection.Connection, device: str, log_dir
648648
"""
649649
try:
650650
torch.cuda.set_device(int(device.split(":")[1]))
651-
registry = get_builder_registry()
651+
registry = BuilderRegistry.get_instance()
652652

653653
conn.send({"cmd": WorkerResponse.READY.value})
654654

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
"""Compiler subsystem package.
22
3-
Exports common builder types for convenience.
3+
This package provides the infrastructure for building solutions into executable runnables.
4+
It includes:
5+
- Builder: Abstract base class for different language/build system implementations
6+
- BuilderRegistry: Central registry for managing and dispatching builders
7+
- Runnable: Executable wrapper around compiled solutions
8+
- RunnableMetadata: Metadata about build process and source
9+
10+
The typical workflow is:
11+
1. Get the singleton registry: registry = BuilderRegistry.get_instance()
12+
2. Build a solution: runnable = registry.build(definition, solution)
13+
3. Execute: result = runnable(**inputs)
414
"""
515

616
from .builder import Builder, BuildError
7-
from .registry import BuilderRegistry, get_builder_registry
8-
from .runnable import Runnable
17+
from .registry import BuilderRegistry
18+
from .runnable import Runnable, RunnableMetadata
919

10-
__all__ = ["Builder", "BuildError", "BuilderRegistry", "Runnable", "get_builder_registry"]
20+
__all__ = ["Builder", "BuildError", "BuilderRegistry", "Runnable", "RunnableMetadata"]

0 commit comments

Comments
 (0)