Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 40 additions & 51 deletions flashinfer_bench/apply/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from flashinfer_bench.data import TraceSet
from flashinfer_bench.env import get_fib_dataset_path, get_fib_enable_apply
from flashinfer_bench.logging import get_logger
from flashinfer_bench.tracing import get_tracing_runtime

from .config import ApplyConfig
from .key import ApplyKeyBuilder, ApplyKeyFactory
Expand All @@ -23,46 +22,6 @@
logger = get_logger("ApplyRuntime")


def _init_apply_runtime_from_env() -> Optional["ApplyRuntime"]:
    """Build an ApplyRuntime from environment settings, if apply is enabled.

    Returns
    -------
    Optional[ApplyRuntime]
        None when ``get_fib_enable_apply()`` is falsy. Otherwise a runtime
        built from the trace set loaded from ``get_fib_dataset_path()`` and
        an ``ApplyConfig()`` created with no arguments.
    """
    if not get_fib_enable_apply():
        return None
    dataset = TraceSet.from_path(get_fib_dataset_path())
    return ApplyRuntime(dataset, ApplyConfig(), None)


# Process-wide runtime. Populated once, when this module is imported:
# _init_apply_runtime_from_env() yields an ApplyRuntime when apply is enabled
# through the environment and None otherwise.
_global_apply_runtime: Optional["ApplyRuntime"] = _init_apply_runtime_from_env()


def get_apply_runtime() -> Optional["ApplyRuntime"]:
    """Get the global ApplyRuntime instance.

    This is a plain accessor: it returns whatever is currently stored in the
    module-level ``_global_apply_runtime`` and never creates a runtime itself.
    That global is populated once, when this module is imported, by
    ``_init_apply_runtime_from_env()``.

    Returns
    -------
    Optional[ApplyRuntime]
        The current global runtime, or None if none is set.
    """
    return _global_apply_runtime


def set_apply_runtime(rt: Optional["ApplyRuntime"]) -> None:
    """Replace the module-wide ApplyRuntime.

    Parameters
    ----------
    rt : Optional[ApplyRuntime]
        Runtime to install as the global instance; pass None to clear it.
    """
    global _global_apply_runtime
    _global_apply_runtime = rt


class ApplyRuntime:
"""Runtime system for dispatching optimized implementations based on trace data.

Expand Down Expand Up @@ -146,16 +105,6 @@ def dispatch(
If the definition is not found and no fallback is provided, or if
no suitable implementation is available and no fallback is provided.
"""
# First try to run tracing logic in case tracing is enabled
tracing_runtime = get_tracing_runtime()
if tracing_runtime is not None:
try:
tracing_runtime.collect(def_name, runtime_kwargs)
except Exception as e:
logger.error(f"Error collecting trace for {def_name}: {e}")
pass

# Then try to run apply logic
defn = self._trace_set.definitions.get(def_name)
if defn is None:
if fallback is None:
Expand Down Expand Up @@ -198,3 +147,43 @@ def __enter__(self) -> None:
def __exit__(self, exc_type, exc, tb) -> bool:
    """Reinstate the runtime saved in ``self._prev_runtime``.

    Returns False, so an exception raised inside the ``with`` body is not
    suppressed.
    """
    previous = self._prev_runtime
    set_apply_runtime(previous)
    return False


def _init_apply_runtime_from_env() -> Optional["ApplyRuntime"]:
    """Build an ApplyRuntime from environment settings, if apply is enabled.

    Returns
    -------
    Optional[ApplyRuntime]
        None when ``get_fib_enable_apply()`` is falsy. Otherwise a runtime
        built from the trace set loaded from ``get_fib_dataset_path()`` and
        an ``ApplyConfig()`` created with no arguments.
    """
    if not get_fib_enable_apply():
        return None
    dataset = TraceSet.from_path(get_fib_dataset_path())
    return ApplyRuntime(dataset, ApplyConfig(), None)


# Process-wide runtime. Populated once, when this module is imported:
# _init_apply_runtime_from_env() yields an ApplyRuntime when apply is enabled
# through the environment and None otherwise.
_global_apply_runtime: Optional["ApplyRuntime"] = _init_apply_runtime_from_env()


def get_apply_runtime() -> Optional["ApplyRuntime"]:
    """Get the global ApplyRuntime instance.

    This is a plain accessor: it returns whatever is currently stored in the
    module-level ``_global_apply_runtime`` and never creates a runtime itself.
    That global is populated once, when this module is imported, by
    ``_init_apply_runtime_from_env()``.

    Returns
    -------
    Optional[ApplyRuntime]
        The current global runtime, or None if none is set.
    """
    return _global_apply_runtime


def set_apply_runtime(rt: Optional["ApplyRuntime"]) -> None:
    """Replace the module-wide ApplyRuntime.

    Parameters
    ----------
    rt : Optional[ApplyRuntime]
        Runtime to install as the global instance; pass None to clear it.
    """
    global _global_apply_runtime
    _global_apply_runtime = rt
11 changes: 1 addition & 10 deletions flashinfer_bench/compile/builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import hashlib
import os
import re
import tempfile
Expand Down Expand Up @@ -31,18 +30,10 @@ def write_sources_to_temp(base: str, sources: list[SourceFile], pkg: Optional[st


def create_pkg_name(sol: Solution, prefix: str = "") -> str:
# Normalize the solution name
s = re.sub(r"[^0-9a-zA-Z_]", "_", sol.name)
if not s or s[0].isdigit():
s = "_" + s

# Hash the sources
h = hashlib.sha1()
for src in sol.sources:
h.update(src.path.encode())
h.update(src.content.encode())

return prefix + s + "_" + h.hexdigest()[:6]
return prefix + s


class BuildError(RuntimeError):
Expand Down
2 changes: 1 addition & 1 deletion flashinfer_bench/compile/builders/cuda_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None):
CUDA_DEPS = {
"cublas": ("nvidia.cublas", ["cublas", "cublasLt"]),
"cudnn": ("nvidia.cudnn", ["cudnn"]),
"cutlass": ("flashinfer_bench._deps.cutlass", None), # Header-only
"cutlass": ("flashinfer_bench.thirdparty.cutlass", None), # Header-only
}


Expand Down
9 changes: 8 additions & 1 deletion flashinfer_bench/compile/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@ def __init__(self, builders: Tuple[Builder, ...]) -> None:
if not builders:
raise ValueError("BuilderRegistry requires at least one builder")
self._builders: Tuple[Builder, ...] = builders
# Cache: solution_name -> builder to avoid repeated can_build checks
self._solution_to_builder: dict[str, Builder] = {}

def clear(self) -> None:
    """Reset every builder's cache and drop the solution-to-builder cache.

    Each builder's ``clear_cache()`` is attempted independently; an exception
    raised by one builder is ignored so the remaining builders still run.
    """
    for builder in self._builders:
        try:
            builder.clear_cache()
        except Exception:
            # Best effort: a failing builder must not block the others.
            continue
    self._solution_to_builder.clear()

def build(self, defn: Definition, sol: Solution) -> Runnable:
    """Build ``sol`` for ``defn`` with the first builder that accepts it.

    The builder chosen for a solution name is remembered in
    ``self._solution_to_builder``, so later calls with the same name skip the
    ``can_build`` checks and go straight to that builder.

    Raises
    ------
    BuildError
        If no registered builder reports that it can build ``sol``.
    """
    chosen = self._solution_to_builder.get(sol.name)
    if chosen is None:
        # No cached choice yet: take the first builder, in registration
        # order, whose can_build() accepts this solution.
        chosen = next((b for b in self._builders if b.can_build(sol)), None)
        if chosen is None:
            raise BuildError(f"No registered builder can build solution '{sol.name}'")
        self._solution_to_builder[sol.name] = chosen
    return chosen.build(defn, sol)

Expand Down
22 changes: 14 additions & 8 deletions flashinfer_bench/data/trace_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class TraceSet:
definition."""
traces: Dict[str, List[Trace]] = field(default_factory=dict)
"""The traces in the database. Map from definition name to all traces for that definition."""
_solution_by_name: Dict[str, Solution] = field(default_factory=dict, init=False, repr=False)
"""Fast lookup index: solution name -> Solution object. Automatically maintained."""

def __post_init__(self):
    """Populate ``_solution_by_name`` from the solutions given at construction.

    Raises
    ------
    ValueError
        If two solutions, under the same or different definitions, share a name.
    """
    index = self._solution_by_name
    every_solution = (
        solution for group in self.solutions.values() for solution in group
    )
    for solution in every_solution:
        if solution.name in index:
            raise ValueError(f"Duplicate solution name found: {solution.name}")
        index[solution.name] = solution

@property
def definitions_path(self) -> Path:
Expand Down Expand Up @@ -131,6 +141,7 @@ def from_path(cls: type["TraceSet"], path: Optional[str] = None) -> "TraceSet":
raise ValueError(f"Duplicate solution name: {s.name}")
seen_solutions.add(s.name)
trace_set.solutions.setdefault(s.definition, []).append(s)
trace_set._solution_by_name[s.name] = s

for p in sorted((trace_set.workloads_path.rglob("*.jsonl"))):
for t in load_jsonl_file(Trace, p):
Expand Down Expand Up @@ -174,9 +185,8 @@ def to_dict(self) -> Dict[str, Any]:
def get_solution(self, name: str) -> Optional[Solution]:
"""Get a solution by name from all loaded solutions.

Searches across all solutions in the TraceSet to find one with the specified name.
Since solution names are unique across the entire dataset, this returns at most
one solution.
Uses an O(1) index lookup for fast retrieval. Since solution names are unique
across the entire dataset, this returns at most one solution.

Parameters
----------
Expand All @@ -188,11 +198,7 @@ def get_solution(self, name: str) -> Optional[Solution]:
Optional[Solution]
The solution with the given name, or None if not found.
"""
for solution_list in self.solutions.values():
for solution in solution_list:
if solution.name == name:
return solution
return None
return self._solution_by_name.get(name)

def filter_traces(self, def_name: str, atol: float = 1e-2, rtol: float = 1e-2) -> List[Trace]:
"""Filter traces for a definition based on error bounds.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ where = ["."]
include = ["flashinfer_bench*"]
[tool.setuptools.package-data]
"flashinfer_bench" = ["py.typed"]
"flashinfer_bench._deps.cutlass" = ["include/**"]
"flashinfer_bench.thirdparty.cutlass" = ["include/**"]

[tool.black]
line-length = 100
Expand Down