diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index b1ae9369a2..932acb72fd 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -6,8 +6,19 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) - target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + set(EXE_NAME tile_example_gemm_quant) + add_executable(${EXE_NAME} EXCLUDE_FROM_ALL + gemm_quant.cpp + gemm_aquant_quantgrouped.cpp + gemm_bquant_quantgrouped_prefill_bf8i4.cpp + gemm_bquant_quantgrouped_prefill_fp8i4.cpp + gemm_bquant_quantgrouped_prefill_bf8.cpp + gemm_bquant_quantgrouped_prefill_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_prefill.cpp + gemm_quant_rowcol.cpp + gemm_quant_tensor.cpp + ) + target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 496697ca32..64ecebd15a 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic` ## example ``` args: - -b batch size (default:1) - -m m dimension (default:1024) - -n n dimension (default:2048) - -k k dimension (default:64) - -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: C) - -c_layout Tensor C data layout (default: R) - -stride_a Tensor A stride (default:0) - -stride_b Tensor B stride (default:0) - -stride_c Tensor C stride (default:0) - -v 0. No validation, 1. Validation on CPU, 2. 
Validation on GPU (default:1) - -e Absolute error tolerance (default:1e-5) - -prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8) - -warmup number of iterations before benchmark the kernel (default:10) - -repeat number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -quant_mode Which quant method to use (aquant, bquant, tensor, rowcol) + -h Print help message (default:false) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -a_layout A tensor data layout - Row or Column (default:R) + -b_layout B tensor data layout - Row or Column (default:C) + -bq_layout Bq tensor data layout - Row or Column (default:C) + -c_layout C tensor data layout - Row or Column (default:R) + -stride_a Tensor A stride (default:0) + -stride_q Tensor AQ stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) + -prec Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8) + -warmup Number of iterations before benchmarking the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:1000) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k SplitK value (default:1) + -device Device id that will be used to run the kernel (default:0) + -init 0:random, 1:linear, 2:constant(1) (default:0) + -flush_cache Flush cache before running the kernel (default:true) +-rotating_count Rotating count (default:1000) + -quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant) + -preshuffleb Enable preshuffle of tensor B (default:false) + -group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128) ``` User need to select correct mapping of config for each quant mode: diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp new file mode 100644 index 0000000000..3786230ff0 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuant; + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser& + arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::AQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp new file mode 100644 index 0000000000..cb9f8b62cf --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp new file mode 100644 index 0000000000..33ae3bc4a9 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_bf8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp new file mode 100644 index 0000000000..526c35b081 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = + decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp new file mode 100644 index 0000000000..4b2a8efb14 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_prefill_fp8i4.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigBQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut) +{ + using TypeConfig = decltype(GemmQuantTypeConfig{}); +#ifndef CK_GFX950_SUPPORT + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +#endif + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp new file mode 100644 index 0000000000..d9591bb588 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_prefill.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill; + +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut) +{ + using QuantGroupSize = ck_tile::QuantGroupShape>; + lut[hash_multiple_strings( + {"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp new file mode 100644 index 0000000000..a35f867f5d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "gemm_utils.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("h", "false", "Print help message") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row or Column") + .insert("b_layout", "C", "B tensor data layout - Row or Column") + .insert("bq_layout", "C", "Bq tensor data layout - Row or Column") + .insert("c_layout", "R", "C tensor data layout - Row or Column") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_q", "0", "Tensor AQ stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0: No validation, 1: Validation on CPU, 2: Validation on GPU") + .insert("prec", + "fp8", + "Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " + "or bf8i4") + .insert("warmup", "50", "Number of iterations before benchmarking the kernel") + .insert("repeat", "1000", "Number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "SplitK value") + .insert("device", "0", "Device id that will be used to run the kernel") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "Flush cache before running the kernel") + .insert("rotating_count", "1000", "Rotating count") + .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") + .insert("preshuffleb", "false", "Enable preshuffle of tensor B") + .insert("group_size", + "1x1x128", + "Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +auto gen_lut_key(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string quant_mode = arg_parser.get_str("quant_mode"); + + std::vector params = {data_type, quant_mode}; + + if(quant_mode == "bquant") + { + std::string preshuffleb = + arg_parser.get_bool("preshuffleb") ? 
"preshuffleb" : "non-preshuffleb"; + params.push_back(preshuffleb); + } + if(quant_mode != "rowcol" && quant_mode != "tensor") + { + // NOTE: rowcol and tensor pipeline do not use group size + std::string group_size_str = arg_parser.get_str("group_size"); + params.push_back(group_size_str); + } + + return hash_multiple_strings(params); +} + +void aquant_quantgrouped_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_fp8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_bf8i4_instance_factory( + std::unordered_map>& lut); +void bquant_quantgrouped_preshuffleb_instance_factory( + std::unordered_map>& lut); +void quant_rowcol_instance_factory( + std::unordered_map>& lut); +void quant_tensor_instance_factory( + std::unordered_map>& lut); + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result || arg_parser.get_bool("h")) + { + arg_parser.print(); + return -1; + } + + auto device_id = arg_parser.get_int("device"); + std::cout << "Device ID: " << device_id << std::endl; + ck_tile::hip_check_error(hipSetDevice(device_id)); + + std::unordered_map> lut; + aquant_quantgrouped_instance_factory(lut); + bquant_quantgrouped_fp8_instance_factory(lut); + bquant_quantgrouped_bf8_instance_factory(lut); + bquant_quantgrouped_fp8i4_instance_factory(lut); + bquant_quantgrouped_bf8i4_instance_factory(lut); + bquant_quantgrouped_preshuffleb_instance_factory(lut); + quant_rowcol_instance_factory(lut); + quant_tensor_instance_factory(lut); + + auto key = gen_lut_key(arg_parser); + + if(lut.find(key) != lut.end()) + { + return lut[key](arg_parser); + } + else + { + std::cerr + << "Error: Combination of prec, quant_mode, preshuffleb, and group_size not supported." 
+ << std::endl; + return -1; + } +} diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp deleted file mode 100644 index d605a2b780..0000000000 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ /dev/null @@ -1,428 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -// This example demonstrates 2D block scale quantization (N×K) for BQuant -// using non-preshuffled configuration. -// NOTE: Once more 2d support is ready, we can migrate all 2d quant types to this example -// This is currently done separately to avoid too verbose dispatching. - -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" -#include "gemm_utils.hpp" - -template -float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - static_assert(std::is_same_v); - using ComputeDataType = std::conditional_t; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile:: - sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using GemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - // This example only supports BQuant (no AQuant) - // For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3 - using BaseGemmPipeline = std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>; - - const ck_tile::index_t K_split = - (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = 
BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr bool transpose_c = false; - - // row-col and tensor quants use the regular pipeline, A/B quants use their own - using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmRowColTensorQuantPipelineProblem, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>>; - - using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant, - ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrMem, // memory pipeline hardcoded - // for aquant - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - GemmConfig::TiledMMAPermuteN>>; - using Kernel = - ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << PipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - float ave_time = 0; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper - rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString( - hipMemsetAsync(args.c_ptr, - 0, - args.M * args.N * sizeof(typename TypeConfig::CDataType), - s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -#include "run_gemm_quant_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if((QuantMode 
== ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) && - GemmConfig::PreshuffleB) - { - throw std::runtime_error( - "Preshuffling weight matrix is not supported for AQuant or RowColQuant"); - } - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return 0; -} - -// Forward declaration for dispatch function -template