Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,19 @@ endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set(EXE_NAME tile_example_gemm_quant)
add_executable(${EXE_NAME} EXCLUDE_FROM_ALL
gemm_quant.cpp
gemm_aquant_quantgrouped.cpp
gemm_bquant_quantgrouped_prefill_bf8i4.cpp
gemm_bquant_quantgrouped_prefill_fp8i4.cpp
gemm_bquant_quantgrouped_prefill_bf8.cpp
gemm_bquant_quantgrouped_prefill_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb_prefill.cpp
gemm_quant_rowcol.cpp
gemm_quant_tensor.cpp
)
target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()
42 changes: 25 additions & 17 deletions example/ck_tile/38_block_scale_gemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,31 @@ This will result in an executable `build/bin/tile_example_gemm_quant_basic`
## example
```
args:
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-a_layout Tensor A data layout (default: R)
-b_layout Tensor B data layout (default: C)
-c_layout Tensor C data layout (default: R)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8 (default:fp8)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
-h Print help message (default:false)
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-a_layout A tensor data layout - Row or Column (default:R)
-b_layout B tensor data layout - Row or Column (default:C)
-bq_layout Bq tensor data layout - Row or Column (default:C)
-c_layout C tensor data layout - Row or Column (default:R)
-stride_a Tensor A stride (default:0)
-stride_q Tensor AQ stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k SplitK value (default:1)
-device Device id that will be used to run the kernel (default:0)
-init 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache Flush cache before running the kernel (default:true)
-rotating_count Rotating count (default:1000)
-quant_mode Choose aquant, bquant, tensor or rowcol (default:bquant)
-preshuffleb Enable preshuffle of tensor B (default:false)
-group_size Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128 (default:1x1x128)
```

User need to select correct mapping of config for each quant mode:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigQuant<T>;

void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings({"fp8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8", "aquant", "1x1x128"})] = [](const ck_tile::ArgParser&
arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8i4", "aquant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8i4", "aquant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::AQuantGrouped>(arg_parser);
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_bf8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_bf8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"bf8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

void bquant_quantgrouped_fp8i4_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings({"fp8i4", "bquant", "non-preshuffleb", "1x64x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigPreshuffleB_Bquant_prefill<T>;

void bquant_quantgrouped_preshuffleb_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
lut[hash_multiple_strings(
{"fp8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings(
{"bf8", "bquant", "preshuffleb", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8i4", "bquant", "preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"bf8i4", "bquant", "preshuffleb", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
}
Loading