From b3d8e214658a644c9fb80eee348e040856972bb7 Mon Sep 17 00:00:00 2001 From: YiyanZhai Date: Mon, 1 Dec 2025 13:45:21 -0500 Subject: [PATCH 1/4] reduce apply overhead --- flashinfer_bench/apply/runtime.py | 80 ++++++++++++++-------------- flashinfer_bench/compile/builder.py | 11 +--- flashinfer_bench/compile/registry.py | 9 +++- flashinfer_bench/data/trace_set.py | 13 ++--- 4 files changed, 54 insertions(+), 59 deletions(-) diff --git a/flashinfer_bench/apply/runtime.py b/flashinfer_bench/apply/runtime.py index a027dd6d..38963acb 100644 --- a/flashinfer_bench/apply/runtime.py +++ b/flashinfer_bench/apply/runtime.py @@ -23,46 +23,6 @@ logger = get_logger("ApplyRuntime") -def _init_apply_runtime_from_env() -> Optional["ApplyRuntime"]: - """Initialize the global runtime from environment variables if configured.""" - fib_enable_apply = get_fib_enable_apply() - if not fib_enable_apply: - return - fib_dataset_path = get_fib_dataset_path() - trace_set = TraceSet.from_path(fib_dataset_path) - apply_config = ApplyConfig() - return ApplyRuntime(trace_set, apply_config, None) - - -_global_apply_runtime: Optional["ApplyRuntime"] = _init_apply_runtime_from_env() - - -def get_apply_runtime() -> Optional["ApplyRuntime"]: - """Get the global ApplyRuntime instance. - - Returns the singleton runtime instance, initializing it from environment - variables if it hasn't been created yet. - - Returns - ------- - Optional[ApplyRuntime] - The global runtime instance, or None if not initialized. - """ - return _global_apply_runtime - - -def set_apply_runtime(rt: Optional["ApplyRuntime"]) -> None: - """Set the global ApplyRuntime instance. - - Parameters - ---------- - rt : Optional[ApplyRuntime] - The runtime instance to set, or None to clear the global runtime. - """ - global _global_apply_runtime - _global_apply_runtime = rt - - class ApplyRuntime: """Runtime system for dispatching optimized implementations based on trace data. @@ -198,3 +158,43 @@ def __enter__(self) -> None: def __exit__(self, exc_type, exc, tb) -> bool: set_apply_runtime(self._prev_runtime) return False + + +def _init_apply_runtime_from_env() -> Optional["ApplyRuntime"]: + """Initialize the global runtime from environment variables if configured.""" + fib_enable_apply = get_fib_enable_apply() + if not fib_enable_apply: + return + fib_dataset_path = get_fib_dataset_path() + trace_set = TraceSet.from_path(fib_dataset_path) + apply_config = ApplyConfig() + return ApplyRuntime(trace_set, apply_config, None) + + +_global_apply_runtime: Optional["ApplyRuntime"] = _init_apply_runtime_from_env() + + +def get_apply_runtime() -> Optional["ApplyRuntime"]: + """Get the global ApplyRuntime instance. + + Returns the singleton runtime instance, initializing it from environment + variables if it hasn't been created yet. + + Returns + ------- + Optional[ApplyRuntime] + The global runtime instance, or None if not initialized. + """ + return _global_apply_runtime + + +def set_apply_runtime(rt: Optional["ApplyRuntime"]) -> None: + """Set the global ApplyRuntime instance. + + Parameters + ---------- + rt : Optional[ApplyRuntime] + The runtime instance to set, or None to clear the global runtime. + """ + global _global_apply_runtime + _global_apply_runtime = rt diff --git a/flashinfer_bench/compile/builder.py b/flashinfer_bench/compile/builder.py index 5f4ebc40..fd956a42 100644 --- a/flashinfer_bench/compile/builder.py +++ b/flashinfer_bench/compile/builder.py @@ -1,6 +1,5 @@ from __future__ import annotations -import hashlib import os import re import tempfile @@ -31,18 +30,10 @@ def write_sources_to_temp(base: str, sources: list[SourceFile], pkg: Optional[st def create_pkg_name(sol: Solution, prefix: str = "") -> str: - # Normalize the solution name s = re.sub(r"[^0-9a-zA-Z_]", "_", sol.name) if not s or s[0].isdigit(): s = "_" + s - - # Hash the sources - h = hashlib.sha1() - for src in sol.sources: - h.update(src.path.encode()) - h.update(src.content.encode()) - - return prefix + s + "_" + h.hexdigest()[:6] + return prefix + s class BuildError(RuntimeError): diff --git a/flashinfer_bench/compile/registry.py b/flashinfer_bench/compile/registry.py index 0e108e77..bbd22c99 100644 --- a/flashinfer_bench/compile/registry.py +++ b/flashinfer_bench/compile/registry.py @@ -15,6 +15,8 @@ def __init__(self, builders: Tuple[Builder, ...]) -> None: if not builders: raise ValueError("BuilderRegistry requires at least one builder") self._builders: Tuple[Builder, ...] = builders + # Cache: solution_name -> builder to avoid repeated can_build checks + self._solution_to_builder: dict[str, Builder] = {} def clear(self) -> None: for b in self._builders: @@ -22,11 +24,16 @@ def clear(self) -> None: b.clear_cache() except Exception: pass + self._solution_to_builder.clear() def build(self, defn: Definition, sol: Solution) -> Runnable: + builder = self._solution_to_builder.get(sol.name) + if builder is not None: + return builder.build(defn, sol) + for builder in self._builders: - # Choose the first if builder.can_build(sol): + self._solution_to_builder[sol.name] = builder return builder.build(defn, sol) raise BuildError(f"No registered builder can build solution '{sol.name}'") diff --git a/flashinfer_bench/data/trace_set.py b/flashinfer_bench/data/trace_set.py index 0da875c5..ee91da53 100644 --- a/flashinfer_bench/data/trace_set.py +++ b/flashinfer_bench/data/trace_set.py @@ -44,6 +44,8 @@ class TraceSet: definition.""" traces: Dict[str, List[Trace]] = field(default_factory=dict) """The traces in the database. Map from definition name to all traces for that definition.""" + _solution_by_name: Dict[str, Solution] = field(default_factory=dict, init=False, repr=False) + """Fast lookup index: solution name -> Solution object. Automatically maintained.""" @property def definitions_path(self) -> Path: @@ -131,6 +133,7 @@ def from_path(cls: type["TraceSet"], path: Optional[str] = None) -> "TraceSet": raise ValueError(f"Duplicate solution name: {s.name}") seen_solutions.add(s.name) trace_set.solutions.setdefault(s.definition, []).append(s) + trace_set._solution_by_name[s.name] = s for p in sorted((trace_set.workloads_path.rglob("*.jsonl"))): for t in load_jsonl_file(Trace, p): @@ -174,9 +177,7 @@ def to_dict(self) -> Dict[str, Any]: def get_solution(self, name: str) -> Optional[Solution]: """Get a solution by name from all loaded solutions. - Searches across all solutions in the TraceSet to find one with the specified name. - Since solution names are unique across the entire dataset, this returns at most - one solution. + Uses an O(1) index lookup for fast retrieval. Parameters ---------- @@ -188,11 +189,7 @@ def get_solution(self, name: str) -> Optional[Solution]: Optional[Solution] The solution with the given name, or None if not found. """ - for solution_list in self.solutions.values(): - for solution in solution_list: - if solution.name == name: - return solution - return None + return self._solution_by_name.get(name) def filter_traces(self, def_name: str, atol: float = 1e-2, rtol: float = 1e-2) -> List[Trace]: """Filter traces for a definition based on error bounds. From 1c0d75b58cf601bb36222efa2dfedfee40329cf8 Mon Sep 17 00:00:00 2001 From: YiyanZhai Date: Mon, 1 Dec 2025 14:04:31 -0500 Subject: [PATCH 2/4] debug --- flashinfer_bench/apply/runtime.py | 11 ----------- flashinfer_bench/compile/builders/cuda_builder.py | 2 +- flashinfer_bench/data/trace_set.py | 9 ++++++++- pyproject.toml | 2 +- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/flashinfer_bench/apply/runtime.py b/flashinfer_bench/apply/runtime.py index 38963acb..4d2f5d8d 100644 --- a/flashinfer_bench/apply/runtime.py +++ b/flashinfer_bench/apply/runtime.py @@ -14,7 +14,6 @@ from flashinfer_bench.data import TraceSet from flashinfer_bench.env import get_fib_dataset_path, get_fib_enable_apply from flashinfer_bench.logging import get_logger -from flashinfer_bench.tracing import get_tracing_runtime from .config import ApplyConfig from .key import ApplyKeyBuilder, ApplyKeyFactory @@ -106,16 +105,6 @@ def dispatch( If the definition is not found and no fallback is provided, or if no suitable implementation is available and no fallback is provided. """ - # First try to run tracing logic in case tracing is enabled - tracing_runtime = get_tracing_runtime() - if tracing_runtime is not None: - try: - tracing_runtime.collect(def_name, runtime_kwargs) - except Exception as e: - logger.error(f"Error collecting trace for {def_name}: {e}") - pass - - # Then try to run apply logic defn = self._trace_set.definitions.get(def_name) if defn is None: if fallback is None: diff --git a/flashinfer_bench/compile/builders/cuda_builder.py b/flashinfer_bench/compile/builders/cuda_builder.py index 84bc0dab..0189dc5a 100644 --- a/flashinfer_bench/compile/builders/cuda_builder.py +++ b/flashinfer_bench/compile/builders/cuda_builder.py @@ -70,7 +70,7 @@ def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None): CUDA_DEPS = { "cublas": ("nvidia.cublas", ["cublas", "cublasLt"]), "cudnn": ("nvidia.cudnn", ["cudnn"]), - "cutlass": ("flashinfer_bench._deps.cutlass", None), # Header-only + "cutlass": ("flashinfer_bench.thirdparty.cutlass", None), # Header-only } diff --git a/flashinfer_bench/data/trace_set.py b/flashinfer_bench/data/trace_set.py index ee91da53..af78679f 100644 --- a/flashinfer_bench/data/trace_set.py +++ b/flashinfer_bench/data/trace_set.py @@ -47,6 +47,12 @@ class TraceSet: _solution_by_name: Dict[str, Solution] = field(default_factory=dict, init=False, repr=False) """Fast lookup index: solution name -> Solution object. Automatically maintained.""" + def __post_init__(self): + """Initialize the _solution_by_name index from existing solutions.""" + for solutions_list in self.solutions.values(): + for solution in solutions_list: + self._solution_by_name[solution.name] = solution + @property def definitions_path(self) -> Path: if self.root is None: @@ -177,7 +183,8 @@ def to_dict(self) -> Dict[str, Any]: def get_solution(self, name: str) -> Optional[Solution]: """Get a solution by name from all loaded solutions. - Uses an O(1) index lookup for fast retrieval. + Uses an O(1) index lookup for fast retrieval. Since solution names are unique + across the entire dataset, this returns at most one solution. Parameters ---------- diff --git a/pyproject.toml b/pyproject.toml index b68fe093..88c2257e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ where = ["."] include = ["flashinfer_bench*"] [tool.setuptools.package-data] "flashinfer_bench" = ["py.typed"] -"flashinfer_bench._deps.cutlass" = ["include/**"] +"flashinfer_bench.thirdparty.cutlass" = ["include/**"] [tool.black] line-length = 100 From d100efa2050d631b18a0789b0082504497bf3962 Mon Sep 17 00:00:00 2001 From: Yiyan Zhai <98248913+YiyanZhai@users.noreply.github.com> Date: Mon, 1 Dec 2025 14:23:39 -0500 Subject: [PATCH 3/4] Update flashinfer_bench/data/trace_set.py a check for duplicates is added Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- flashinfer_bench/data/trace_set.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flashinfer_bench/data/trace_set.py b/flashinfer_bench/data/trace_set.py index af78679f..21759da7 100644 --- a/flashinfer_bench/data/trace_set.py +++ b/flashinfer_bench/data/trace_set.py @@ -51,6 +51,8 @@ def __post_init__(self): """Initialize the _solution_by_name index from existing solutions.""" for solutions_list in self.solutions.values(): for solution in solutions_list: + if solution.name in self._solution_by_name: + raise ValueError(f"Duplicate solution name found: {solution.name}") self._solution_by_name[solution.name] = solution @property From e5cf5216a78be80e3a8c6979116b10c5695e872d Mon Sep 17 00:00:00 2001 From: YiyanZhai Date: Mon, 1 Dec 2025 14:24:53 -0500 Subject: [PATCH 4/4] linting --- flashinfer_bench/compile/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer_bench/compile/registry.py b/flashinfer_bench/compile/registry.py index bbd22c99..d2fb4d7a 100644 --- a/flashinfer_bench/compile/registry.py +++ b/flashinfer_bench/compile/registry.py @@ -30,7 +30,7 @@ def build(self, defn: Definition, sol: Solution) -> Runnable: builder = self._solution_to_builder.get(sol.name) if builder is not None: return builder.build(defn, sol) - + for builder in self._builders: if builder.can_build(sol): self._solution_to_builder[sol.name] = builder