Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:

def save_configs(
configs: Dict[int, BenchmarkConfig],
filename: str,
) -> None:
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")


def get_filename(
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
Expand Down Expand Up @@ -404,10 +413,7 @@ def save_configs(
per_channel_quant,
)

print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
return filename


def main(args: argparse.Namespace):
Expand Down Expand Up @@ -541,7 +547,22 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")

filename = get_filename(
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
)
Comment on lines +551 to +562
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To align with the suggested changes to the get_filename function signature, the unused arguments hidden_size and topk should be removed from this call.

        filename = get_filename(
            E,
            shard_intermediate_size,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            per_channel_quant,
            block_shape,
        )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to make this change if desired, I just kept it close to the original

print(
f"Start tuning over {len(search_space)} configurations to create {filename}..."
)

start = time.perf_counter()
configs = _distribute(
Expand Down Expand Up @@ -569,16 +590,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
}
save_configs(
best_configs,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
per_channel_quant,
block_shape,
filename,
)
end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds")
Expand Down
Loading