diff --git a/examples/kernel_generator/kernel_generator.py b/examples/kernel_generator/kernel_generator.py
index 914e6b4b..6e8296ff 100644
--- a/examples/kernel_generator/kernel_generator.py
+++ b/examples/kernel_generator/kernel_generator.py
@@ -1,7 +1,8 @@
+import asyncio
 import os
 import random
 import re
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import openai
 from kernel_generator_prompts import get_optimization_prompt, get_prompt
@@ -56,7 +57,7 @@ def __init__(
         if base_url is not None:
             client_kwargs["base_url"] = base_url
 
-        self.client = openai.OpenAI(**client_kwargs)
+        self.client = openai.AsyncOpenAI(**client_kwargs)
 
     def _get_supported_language(self) -> SupportedLanguages:
         language_map = {
@@ -71,7 +72,12 @@ def _get_supported_language(self) -> SupportedLanguages:
         return SupportedLanguages.PYTHON
 
     def generate(
-        self, traceset: TraceSet, definition: Definition, max_opt_rounds: int = 10
+        self,
+        traceset: TraceSet,
+        definition: Definition,
+        gen_rounds: int = 10,
+        beam: bool = False,
+        beam_width: int = 3,
     ) -> Solution:
         """
         Generate an optimized solution through iterative improvement using flashinfer-bench feedback.
@@ -79,7 +85,9 @@ def generate(
         Args:
            traceset: The TraceSet containing workloads for evaluation
            definition: The workload definition to implement the kernel for
-            max_opt_rounds: Maximum number of optimization rounds (default: 10)
+            gen_rounds: Number of generation rounds to run (or search depth if beam=True)
+            beam: Whether to use beam search (default: False, since it is more expensive to run)
+            beam_width: Number of candidates to maintain in beam search (default: 3)
 
         Returns:
            Solution: a solution dataclass containing the optimized kernel code
@@ -94,61 +102,269 @@ def generate(
         print(f"Generating optimized solution for {definition.name}")
         print(f"Using workload {selected_workload.workload.uuid} for optimization feedback")
+
+        if beam:
+            return self._beam_search_generate(
+                traceset, definition, selected_workload, gen_rounds, beam_width
+            )
+        else:
+            return asyncio.run(
+                self._sequential_generate_async(traceset, definition, selected_workload, gen_rounds)
+            )
+
+    async def _sequential_generate_async(
+        self, traceset: TraceSet, definition: Definition, selected_workload, gen_rounds: int
+    ) -> Solution:
         prompt = get_prompt(self.language, definition, self.target_gpu)
-        code_result = self._generate_code_from_prompt(prompt)
+        code_result = await self._generate_code_from_prompt(prompt)
         current_code = code_result["cleaned"]
         current_raw_code = code_result["raw"]
 
-        for round_num in range(1, max_opt_rounds + 1):
-            print(f"\n=== Optimization Round {round_num}/{max_opt_rounds} ===")
+        passing_solutions: List[Tuple[Solution, Trace]] = []
+        last_solution = None
+        last_trace = None
+
+        for round_num in range(1, gen_rounds + 1):
+            print(f"\nGeneration Round {round_num}/{gen_rounds}")
 
             solution = self._create_solution_from_code(current_code, definition, round_num)
+            last_solution = solution
+
+            traces = self._evaluate_solutions(traceset, definition, [solution], selected_workload)
+            trace = traces[0] if traces else None
 
+            if trace:
+                last_trace = trace
+                evaluation = trace.evaluation
+                print(f"Evaluation status: {evaluation.status.value}")
+
+                if evaluation.status == EvaluationStatus.PASSED:
+                    speedup = evaluation.performance.speedup_factor
+                    print(f"Solution PASSED! Speedup: {speedup:.2f}x")
+                    passing_solutions.append((solution, trace))
+                else:
+                    print(f"Solution failed with {evaluation.status.value}")
+                    if evaluation.log:
+                        print("Error details:")
+                        print(evaluation.log)
+
+            if round_num < gen_rounds:
+                best_trace = self._get_best_trace(passing_solutions)
+                opt_trace = best_trace if best_trace else last_trace
+
+                if opt_trace:
+                    optimization_prompt = get_optimization_prompt(
+                        self.language, definition, opt_trace, current_raw_code, self.target_gpu
+                    )
+                else:
+                    optimization_prompt = get_prompt(self.language, definition, self.target_gpu)
+
+                print(f"Generating code for round {round_num + 1}...")
+                code_result = await self._generate_code_from_prompt(optimization_prompt)
+                current_code = code_result["cleaned"]
+                current_raw_code = code_result["raw"]
 
-            temp_traceset = TraceSet(
-                root=traceset.root,
-                definitions={definition.name: definition},
-                solutions={definition.name: [solution]},
-                workloads={definition.name: [selected_workload]},
-                traces={definition.name: []},
+        return self._select_best_solution(passing_solutions, last_solution)
+
+    def _beam_search_generate(
+        self,
+        traceset: TraceSet,
+        definition: Definition,
+        selected_workload,
+        depth: int,
+        beam_width: int,
+    ) -> Solution:
+        print(f"Starting beam search with width={beam_width}, depth={depth}")
+        return asyncio.run(
+            self._beam_search_generate_async(
+                traceset, definition, selected_workload, depth, beam_width
             )
+        )
 
-            print(f"Evaluating solution...")
-            benchmark = Benchmark(temp_traceset, BenchmarkConfig())
-            result_traceset = benchmark.run_all()
+    async def _beam_search_generate_async(
+        self,
+        traceset: TraceSet,
+        definition: Definition,
+        selected_workload,
+        depth: int,
+        beam_width: int,
+    ) -> Solution:
+        passing_solutions: List[Tuple[Solution, Trace]] = []
 
-            traces = result_traceset.traces.get(definition.name, [])
-            if not traces:
-                print("No evaluation traces found, stopping optimization")
+        prompt = get_prompt(self.language, definition, self.target_gpu)
+
+        print(f"\nBeam Level 0: Generating {beam_width} initial candidates...")
+        code_results = await asyncio.gather(
+            *[self._generate_code_from_prompt(prompt) for _ in range(beam_width)]
+        )
+
+        initial_candidates = [
+            {"code": code_result["cleaned"], "raw_code": code_result["raw"], "round_num": 0}
+            for code_result in code_results
+        ]
+
+        # Create all solutions
+        solutions = [
+            self._create_solution_from_code(candidate["code"], definition, 0, candidate_idx=i)
+            for i, candidate in enumerate(initial_candidates)
+        ]
+
+        print(f"Evaluating {len(solutions)} candidates...")
+        traces = self._evaluate_solutions(traceset, definition, solutions, selected_workload)
+
+        beam = []
+        for i, (candidate, solution, trace) in enumerate(
+            zip(initial_candidates, solutions, traces)
+        ):
+            if trace:
+                evaluation = trace.evaluation
+                speedup = (
+                    evaluation.performance.speedup_factor
+                    if evaluation.status == EvaluationStatus.PASSED
+                    else 0.0
+                )
+                print(f"Candidate {i+1}: {evaluation.status.value}, speedup={speedup:.2f}x")
+
+                if evaluation.status == EvaluationStatus.PASSED:
+                    passing_solutions.append((solution, trace))
+
+                beam.append(
+                    {
+                        "solution": solution,
+                        "trace": trace,
+                        "code": candidate["code"],
+                        "raw_code": candidate["raw_code"],
+                        "speedup": speedup,
+                        "round_num": 0,
+                    }
+                )
+
+        beam.sort(key=lambda x: x["speedup"], reverse=True)
+        beam = beam[:beam_width]
+        last_solution = beam[0]["solution"] if beam else None
+
+        for level in range(1, depth + 1):
+            print(f"\nBeam Level {level}/{depth}: Expanding {len(beam)} candidates...")
+
+            # Generate optimization prompts for all beam items
+            prompts = [
+                get_optimization_prompt(
+                    self.language,
+                    definition,
+                    beam_item["trace"],
+                    beam_item["raw_code"],
+                    self.target_gpu,
+                )
+                for beam_item in beam
+            ]
+
+            # Generate all candidates in parallel
+            code_results = await asyncio.gather(
+                *[self._generate_code_from_prompt(prompt) for prompt in prompts]
+            )
+
+            # Create all solutions
+            solutions = [
+                self._create_solution_from_code(
+                    code_result["cleaned"], definition, level, candidate_idx=i
+                )
+                for i, code_result in enumerate(code_results)
+            ]
+
+            print(f"Evaluating {len(solutions)} expanded candidates...")
+            traces = self._evaluate_solutions(traceset, definition, solutions, selected_workload)
+
+            new_candidates = []
+            for beam_idx, (code_result, solution, trace) in enumerate(
+                zip(code_results, solutions, traces)
+            ):
+                if trace:
+                    evaluation = trace.evaluation
+                    speedup = (
+                        evaluation.performance.speedup_factor
+                        if evaluation.status == EvaluationStatus.PASSED
+                        else 0.0
+                    )
+                    print(
+                        f"  Candidate {beam_idx+1}: {evaluation.status.value}, speedup={speedup:.2f}x"
+                    )
+
+                    if evaluation.status == EvaluationStatus.PASSED:
+                        passing_solutions.append((solution, trace))
+
+                    new_candidates.append(
+                        {
+                            "solution": solution,
+                            "trace": trace,
+                            "code": code_result["cleaned"],
+                            "raw_code": code_result["raw"],
+                            "speedup": speedup,
+                            "round_num": level,
+                        }
+                    )
+
+            if new_candidates:
+                new_candidates.sort(key=lambda x: x["speedup"], reverse=True)
+                beam = new_candidates[:beam_width]
+                last_solution = beam[0]["solution"]
+                print(f"Beam level {level} complete. Top speedup: {beam[0]['speedup']:.2f}x")
+            else:
+                print(f"No valid candidates at level {level}, stopping beam search")
                 break
 
-            trace = traces[0]  # Should be only one trace
-            evaluation = trace.evaluation
+        print(f"\nBeam search complete. Found {len(passing_solutions)} passing solutions.")
+        return self._select_best_solution(passing_solutions, last_solution)
 
-            print(f"Evaluation status: {evaluation.status.value}")
+    def _evaluate_solutions(
+        self,
+        traceset: TraceSet,
+        definition: Definition,
+        solutions: List[Solution],
+        selected_workload,
+    ) -> List[Optional[Trace]]:
+        if not solutions:
+            return []
+
+        temp_traceset = TraceSet(
+            root=traceset.root,
+            definitions={definition.name: definition},
+            solutions={definition.name: solutions},
+            workloads={definition.name: [selected_workload]},
+            traces={definition.name: []},
+        )
 
-            if evaluation.status == EvaluationStatus.PASSED:
-                print(f"Solution PASSED! Speedup: {evaluation.performance.speedup_factor:.2f}x")
-                return solution
+        benchmark = Benchmark(temp_traceset, BenchmarkConfig())
+        result_traceset = benchmark.run_all()
 
-            if round_num == max_opt_rounds:
-                print(f"Reached maximum rounds ({max_opt_rounds}), returning current solution")
-                return solution
+        traces = result_traceset.traces.get(definition.name, [])
 
-            print(
-                f"Solution failed with {evaluation.status.value}, extracting feedback for next round..."
-            )
-            if evaluation.log:
-                print("Error details:")
-                print(evaluation.log)
+        trace_map = {trace.solution: trace for trace in traces}
+        return [trace_map.get(sol.name) for sol in solutions]
 
-            optimization_prompt = get_optimization_prompt(
-                self.language, definition, trace, current_raw_code, self.target_gpu
-            )
+    def _get_best_trace(self, passing_solutions: List[Tuple[Solution, Trace]]) -> Optional[Trace]:
+        if not passing_solutions:
+            return None
+
+        best_solution_trace = max(
+            passing_solutions, key=lambda st: st[1].evaluation.performance.speedup_factor
+        )
+        return best_solution_trace[1]
 
-            print(f"Generating optimized code for round {round_num + 1}...")
-            code_result = self._generate_code_from_prompt(optimization_prompt)
-            current_code = code_result["cleaned"]
-            current_raw_code = code_result["raw"]
+    def _select_best_solution(
+        self, passing_solutions: List[Tuple[Solution, Trace]], fallback_solution: Optional[Solution]
+    ) -> Solution:
+        if passing_solutions:
+            best_solution_trace = max(
+                passing_solutions, key=lambda st: st[1].evaluation.performance.speedup_factor
+            )
+            best_solution = best_solution_trace[0]
+            best_speedup = best_solution_trace[1].evaluation.performance.speedup_factor
+            print(f"\nReturning best solution with speedup: {best_speedup:.2f}x")
+            return best_solution
+        elif fallback_solution:
+            print("\nNo passing solutions found, returning last generated solution")
+            return fallback_solution
+        else:
+            raise ValueError("No solutions generated")
 
     def _parse_xml_files(self, code: str) -> Dict[str, str]:
         files = {}
@@ -211,15 +427,16 @@ def _clean_generated_code(self, code: str) -> str:
 
         return code
 
-    def _generate_code_from_prompt(self, prompt: str):
+    async def _generate_code_from_prompt(self, prompt: str):
+        """Generate code from a prompt using the async API."""
        try:
             if self.model_name.startswith("gpt-5") or self.model_name.startswith("o3"):
-                response = self.client.responses.create(
+                response = await self.client.responses.create(
                     model=self.model_name, input=prompt, reasoning={"effort": self.reasoning_effort}
                 )
                 generated_code = response.output_text.strip()
-            else:  # We use the completions api for OpenAI SDK compatible models
-                response = self.client.chat.completions.create(
+            else:
+                response = await self.client.chat.completions.create(
                     model=self.model_name, messages=[{"role": "user", "content": prompt}]
                 )
                 generated_code = response.choices[0].message.content.strip()
@@ -232,18 +449,16 @@ def _generate_code_from_prompt(self, prompt: str):
             print(f"Error while generating code: {e}")
             raise
 
-    def _create_solution_from_code(self, code, definition: Definition, round_num: int) -> Solution:
+    def _create_solution_from_code(
+        self, code, definition: Definition, round_num: int, candidate_idx: int = 0
+    ) -> Solution:
         # Include reasoning effort in name and description for GPT-5 models
         if self.model_name.startswith("gpt-5") or self.model_name.startswith("o3"):
-            solution_name = f"{self.model_name}_{definition.name}_{self.language}_optimized_r{round_num}_{self.reasoning_effort}"
-            solution_description = f"{self.model_name} optimized kernel for {definition.name} (round {round_num}, reasoning effort: {self.reasoning_effort})"
+            solution_name = f"{self.model_name}_{definition.name}_{self.language}_optimized_r{round_num}_c{candidate_idx}_{self.reasoning_effort}"
+            solution_description = f"{self.model_name} optimized kernel for {definition.name} (round {round_num}, candidate {candidate_idx}, reasoning effort: {self.reasoning_effort})"
         else:
-            solution_name = (
-                f"{self.model_name}_{definition.name}_{self.language}_optimized_r{round_num}"
-            )
-            solution_description = (
-                f"{self.model_name} optimized kernel for {definition.name} (round {round_num})"
-            )
+            solution_name = f"{self.model_name}_{definition.name}_{self.language}_optimized_r{round_num}_c{candidate_idx}"
+            solution_description = f"{self.model_name} optimized kernel for {definition.name} (round {round_num}, candidate {candidate_idx})"
 
         # Handle different code formats based on language
         if self.language.lower() == "cuda" and isinstance(code, dict):
diff --git a/examples/kernel_generator/kernel_generator_example.py b/examples/kernel_generator/kernel_generator_example.py
index d3371803..78589954 100644
--- a/examples/kernel_generator/kernel_generator_example.py
+++ b/examples/kernel_generator/kernel_generator_example.py
@@ -19,21 +19,29 @@ def main():
     """
     Generate optimized solutions for all definitions in the traceset.
     """
-    model_name = "gpt-5-2025-08-07"  # Choose model here
-    language = "triton"
-    target_gpu = "B200"
+    # TODO: select model, language, target gpu, definition
+    model_name = "gpt-5-2025-08-07"  # Choose the author model
+    language = "triton"  # Target solution language
+    target_gpu = "B200"  # Choose solution target GPU
+    definition = ""  # Leave empty to generate solutions for all definitions
 
     # TODO: adjust local path to traceset
-    traceset_path = "/home/akj2/flashinfer-trace"
+    traceset_path = "/path/to/flashinfer-trace"
 
     print(f"Loading TraceSet from: {traceset_path}")
     traceset = TraceSet.from_path(traceset_path)
 
-    # all_definitions = list(traceset.definitions.keys())
-    # Filter for rmsnorm definitions only
-    all_definitions = [name for name in traceset.definitions.keys() if "rmsnorm" in name.lower()]
+    all_definitions = list(traceset.definitions.keys())
 
-    print(f"All definitions found: {len(all_definitions)}")
+    if definition:
+        if definition in all_definitions:
+            all_definitions = [definition]
+            print(f"Generating solution for {definition}")
+        else:
+            print(f"Definition '{definition}' not found in traceset")
+            return
+
+    print(f"Found {len(all_definitions)} definitions to generate solutions for")
 
     api_key = os.getenv("LLM_API_KEY")
     base_url = os.getenv("BASE_URL")
@@ -84,7 +92,10 @@ def main():
         solution = generator.generate(
             traceset=traceset,
             definition=definition,
-            max_opt_rounds=10,  # For our baseline, we used 10 rounds
+            gen_rounds=10,  # For our baseline, we used 10 rounds
+            # TODO: uncomment below to use beam search
+            # beam=True,
+            # beam_width=3,
         )
 
         print(f"Successfully generated solution for {definition_name}")
diff --git a/examples/kernel_generator/kernel_generator_prompts.py b/examples/kernel_generator/kernel_generator_prompts.py
index ab797846..79d2756e 100644
--- a/examples/kernel_generator/kernel_generator_prompts.py
+++ b/examples/kernel_generator/kernel_generator_prompts.py
@@ -95,7 +95,6 @@ def _format_trace_logs(trace: Trace) -> str:
 IMPORTANT: Use only valid Python/Triton syntax:
 - NO hexadecimal float literals (0x1.234p5) - use decimal equivalents
 - NO C/CUDA specific syntax - this is Python/Triton code
-- Use math.log(2), math.pi, math.e instead of hex literals
 - All code must be valid Python that passes ast.parse()
 - Expose a "run" entry point function that can be called to execute the kernel
 
@@ -147,7 +146,6 @@ def _format_trace_logs(trace: Trace) -> str:
 IMPORTANT: Use only valid Python/Triton syntax:
 - NO hexadecimal float literals (0x1.234p5) - use decimal equivalents
 - NO C/CUDA specific syntax - this is Python/Triton code
-- Use math.log(2), math.pi, math.e instead of hex literals
 - All code must be valid Python that passes ast.parse()
 - Expose a "run" entry point function that can be called to execute the kernel
 
@@ -184,7 +182,6 @@ def _format_trace_logs(trace: Trace) -> str:
 - Use the definition's tensor shapes, dtypes, and axes information to guide memory access patterns and optimization strategies
 - Optimize for {target_gpu} GPU characteristics (memory hierarchy, compute units, etc.)
 - For fixed axis values, optimize specifically for those constants rather than general cases
-- You may use 3rd party libraries (cuBLAS, cuDNN, CUTLASS) when beneficial, but custom implementations often perform better for specialized kernels with known axis constraints
 
 IMPORTANT: Generate code in XML format with exactly 3 files with these strict names:
 
@@ -209,10 +206,6 @@ def _format_trace_logs(trace: Trace) -> str:
 - Entry point function named "run" that can be called to execute the implementation
 - Handle both args and kwargs properly
 - Move CPU data to GPU, execute kernels, and return results to CPU
-- Include PyTorch C++ extension bindings using PYBIND11_MODULE
-- The "run" function must be exposed to Python through the binding
-- Include proper tensor type conversion between PyTorch tensors and CUDA pointers
-- Include all necessary PyTorch headers: #include <torch/extension.h>
 
 Code Generation Guidelines:
 
@@ -223,8 +216,6 @@ def _format_trace_logs(trace: Trace) -> str:
 - Implement proper error checking with cudaGetLastError()
 - Use appropriate grid and block dimensions for the problem size
 - Leverage constant memory for frequently accessed read-only data
-- Use PyTorch tensor API (torch::Tensor) for all tensor arguments in the "run" function
-- Convert PyTorch tensors to CUDA pointers using .data_ptr() or similar methods
 - Ensure proper CUDA stream synchronization and error handling
 
 Generate the implementation:"""
@@ -252,7 +243,7 @@ def _format_trace_logs(trace: Trace) -> str:
   - Tune block sizes and grid dimensions for maximum occupancy
   - Utilize shared memory effectively to reduce global memory transactions
   - Optimize register usage and minimize divergent branches
-  - Consider using specialized libraries (cuBLAS, cuDNN, CUTLASS) where beneficial
+  - Consider using specialized libraries (such as CUTLASS) where beneficial
   - Leverage constant axis values for compile-time optimizations
 
 Requirements for the optimized implementation:
@@ -286,10 +277,6 @@ def _format_trace_logs(trace: Trace) -> str:
 - Entry point function named "run" that can be called to execute the implementation
 - Handle both args and kwargs properly
 - Move CPU data to GPU, execute kernels, and return results to CPU
-- MUST include PyTorch C++ extension bindings using PYBIND11_MODULE
-- The "run" function must be exposed to Python through the binding
-- Include proper tensor type conversion between PyTorch tensors and CUDA pointers
-- Include all necessary PyTorch headers: #include <torch/extension.h>
 
 Code Generation Guidelines:
 
@@ -300,12 +287,37 @@ def _format_trace_logs(trace: Trace) -> str:
 - Implement proper error checking with cudaGetLastError()
 - Use appropriate grid and block dimensions for the problem size
 - Leverage constant memory for frequently accessed read-only data
-- Use PyTorch tensor API (torch::Tensor) for all tensor arguments in the "run" function
-- Convert PyTorch tensors to CUDA pointers using .data_ptr() or similar methods
 - Ensure proper CUDA stream synchronization and error handling
 
 Generate the corrected and optimized implementation:"""
 
+
+TORCH_BINDINGS_PROMPT = """
+Use PyTorch (torch) for your generated kernel host function and bindings
+
+Requirements:
+- Include all necessary headers (torch/extension.h, kernel.h, etc.)
+- Implement the "run" function that:
+  * Takes torch::Tensor arguments
+  * Validates tensor properties (device, dtype, shape)
+  * Extracts raw pointers using .data_ptr()
+  * Calls the CUDA kernel with appropriate launch configuration
+  * Returns results as torch::Tensor
+- Use PYBIND11_MODULE to bind the "run" function:
+  * PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{
+  * m.def("run", &run, "Kernel execution function");
+  * }}
+- Handle both positional args and kwargs properly
+- Include proper error messages for invalid inputs
+
+- Use torch::Tensor for all tensor arguments
+- Use .device().is_cuda() to check if tensors are on GPU
+- Use .dtype() to validate tensor data types
+- Use .sizes() or .size(dim) to get tensor dimensions
+- Use .data_ptr() or .data_ptr<T>() to get raw pointers
+- Call cudaDeviceSynchronize() or cudaGetLastError() for error checking
+- Return torch::Tensor from the run function
+- Handle exceptions gracefully with proper error messages"""
+
 
 def get_prompt(language: str, definition: Definition, target_gpu: str = "H100") -> str:
     prompts = {"triton": TRITON_PROMPT, "python": PYTHON_PROMPT, "cuda": CUDA_PROMPT}
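
Usage note: with this patch, beam search is opt-in via generate(). A minimal sketch of both modes is shown below; it assumes a generator, traceset, and definition already constructed as in kernel_generator_example.py (those objects are not part of this patch), and uses only the parameters introduced above.

    # Sequential refinement (default): one candidate per round, optimized with
    # flashinfer-bench feedback from the previous best or last trace.
    solution = generator.generate(traceset=traceset, definition=definition, gen_rounds=10)

    # Beam search: generates beam_width candidates per level in parallel and keeps
    # the fastest passing ones; gen_rounds acts as the search depth here.
    solution = generator.generate(
        traceset=traceset,
        definition=definition,
        gen_rounds=5,
        beam=True,
        beam_width=3,
    )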