diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b9ecfcef4..bc7229fc81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added pooling kernel in CK_TILE * Added top-k sigmoid kernel in CK_TILE * Added the blockscale 2D support for CK_TILE GEMM. +* Added reduce and multi reduction kernels ### Changed diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt index 2f48bb85a5..f856449d78 100644 --- a/example/ck_tile/05_reduce/CMakeLists.txt +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -12,8 +12,24 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) +# Multi Reduce Threadwise Example +set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise") +add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp) +target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS}) + +# Multi Reduce Blockwise Example +set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock") +add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp) +target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS}) + # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global # TODO: consider codegen a makefile by us -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp new file mode 100644 index 0000000000..76002c34a9 --- /dev/null +++ b/example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
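// Example usage (reviewer sketch; the exact cmake invocation, flag syntax, and binary
// location depend on the local build tree):
//
//   cmake --build . --target tile_example_multi_reduce_multiblock
//   ./tile_example_multi_reduce_multiblock -n=32 -h=19 -w=7 -c=512 -prec=fp16 -v=1
//
// The flags correspond to the arguments registered in create_args() below; -v=1 enables
// host-side validation against reference_multiple_reduce_multiblock().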
+ +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/utility/json_dump.hpp" +#include + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("n", "32", "n dimension") + .insert("h", "19", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = float; + + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Validate input dimensions + const ck_tile::index_t kept_dim_len_prod = N * C; + const ck_tile::index_t reduce_total_length = H * W; + + if(kept_dim_len_prod == 0) + { + std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C + << ", product=" << kept_dim_len_prod << ")." << std::endl; + std::cerr << "This will result in an empty output tensor." << std::endl; + return false; + } + + if(reduce_total_length == 0) + { + std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W + << ", product=" << reduce_total_length << ")." << std::endl; + std::cerr << "This will result in an empty reduction with no data to process." << std::endl; + std::cerr << "The kernel will exit early without performing any computation." << std::endl; + return false; + } + + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_add_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_ref({N, C}, {C, 1}); + auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref); + + ck_tile::HostTensor y_host_add_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_dev({N, C}, {C, 1}); + auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev); + + const auto number_operations = y_host_dev_tuple.size(); + + std::vector h(number_operations * N * C); + + auto y_buf_size = number_operations * + y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem y_buf(y_buf_size); + + const auto output_tensor_offset = N * C; + + // Operations: one doing a sum reduction, the other computing the mean square + // In the case of mean square: + // 1. 
The element wise operation squares each element before reduction + // 2. The reduction operation sum the squared element + // 3. The accumulator element wise operation divides the result by the total number of reduced + // elements (intra block operation) + // 4. The partial result is updated across blocks using inter block reduction, a sum. + auto reduce_ops = + ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions + auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnarySquare{}); // Elementwise + // ops + auto accumulator_elementwise_ops = ck_tile::make_tuple( + ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnaryDivide{ + reduce_total_length}); // Accumulator Elementwise ops on reduction, intra block + auto inter_block_reduce_ops = ck_tile::make_tuple( + ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // Inter block reduction + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockPerCu = 1; + + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile:: + Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceMultiblock; + + // Determine block group size for multi-block reduction + // block_group_size records how many blocks participate to a reduction (input data dependent) + // , for efficiency reasons this size if limited to a maximum of 128. If this is not sufficient + // to process the whole reduction, each thread will to process multiple thread tile + // a num_block_tile_iterations times + int num_block_tile_iterations; + int block_group_size; + + Kernel::CalculateBlockGroupParams( + reduce_total_length, num_block_tile_iterations, block_group_size); + + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + ck_tile::index_t kGridSize = + ((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size; + + std::cout << "Block group size: " << block_group_size + << ", Num block tile iterations: " << num_block_tile_iterations + << ", Reduce total length: " << reduce_total_length << std::endl; + std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl; + + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! 
Arguments not supported!\n"); + } + + // Init the output data with identity values respective to each reduce op + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + constexpr auto op = reduce_ops.at(i); + const auto identity_val = op.template GetIdentityValue(); + const auto output_number_elements = N * C; + std::fill(h.begin() + i * output_number_elements, + h.begin() + (i + 1) * output_number_elements, + identity_val); + }); + + auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); }; + + float ave_time = launch_kernel_time_mask( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + clear_output_buffer, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_elementwise_ops, + inter_block_reduce_ops) + + ); + + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_multiple_reduce_multiblock( + x_host, + y_host_ref_tuple, + reduce_ops, + kept_dim, + reduce_dims, + elementwise_ops, + accumulator_elementwise_ops, + inter_block_reduce_ops, + block_group_size); + std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl; + + // Transfer data from device and check error for each operation + y_buf.FromDevice(h.data()); + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(y_host_dev_tuple.get(ck_tile::number{}).data(), + h.data() + i * output_tensor_offset, + output_tensor_offset * sizeof(YDataType)); + std::cout << "Checking operation " << i << ": " << std::endl; + + bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number{}), + y_host_ref_tuple.get(ck_tile::number{})); + + if(pass_op) + { + std::cout << "✅ valid results for this operation" << std::endl; + } + pass &= pass_op; + }); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp b/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp new file mode 100644 index 0000000000..941f720272 --- /dev/null +++ b/example/ck_tile/05_reduce/multiple_reduce_threadwise.cpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
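// What this example computes (reviewer sketch): two reductions are fused over the H and
// W dimensions of an NHWC tensor, keeping N and C. For each kept (n, c) pair the result
// is equivalent to the following plain host loop:
//
//   float sum = 0.f, mean_sq = 0.f;
//   for(int h = 0; h < H; ++h)
//       for(int w = 0; w < W; ++w)
//       {
//           const float v = x[n][h][w][c];
//           sum     += v;        // op 0: Add reduction with PassThrough elementwise op
//           mean_sq += v * v;    // op 1: Add reduction with UnarySquare elementwise op
//       }
//   mean_sq /= float(H * W);     // op 1: UnaryDivide accumulator elementwise op
//
// The kernel produces both outputs in a single pass over the input.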
+ +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/utility/json_dump.hpp" +#include + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("n", "32", "n dimension") + .insert("h", "7", "h dimension") + .insert("w", "7", "w dimension") + .insert("c", "512", "c dimension") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "multi_reduce.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = DataType; + + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t H = arg_parser.get_int("h"); + ck_tile::index_t W = arg_parser.get_int("w"); + ck_tile::index_t C = arg_parser.get_int("c"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Validate input dimensions + const ck_tile::index_t kept_dim_len_prod = N * C; + const ck_tile::index_t reduce_total_length = H * W; + + if(kept_dim_len_prod == 0) + { + std::cerr << "Warning: Product of kept dimensions is zero (N=" << N << ", C=" << C + << ", product=" << kept_dim_len_prod << ")." << std::endl; + std::cerr << "This will result in an empty output tensor." << std::endl; + return false; + } + + if(reduce_total_length == 0) + { + std::cerr << "Warning: Product of reduce dimensions is zero (H=" << H << ", W=" << W + << ", product=" << reduce_total_length << ")." << std::endl; + std::cerr << "This will result in an empty reduction with no data to process." << std::endl; + std::cerr << "The kernel will exit early without performing any computation." 
<< std::endl; + return false; + } + + std::vector problem_shape = {N, H, W, C}; + std::vector strides(4); + strides[0] = H * W * C; + strides[1] = W * C; + strides[2] = C; + strides[3] = 1; + + // Define reduction specification: + constexpr auto kept_dim = ck_tile::sequence<0, 3>{}; // Which dimension to keep + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce + + ck_tile::HostTensor x_host(problem_shape, strides); + ck_tile::HostTensor y_host_add_ref({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_ref({N, C}, {C, 1}); + auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref); + + ck_tile::HostTensor y_host_add_dev({N, C}, {C, 1}); + ck_tile::HostTensor y_host_max_dev({N, C}, {C, 1}); + auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev); + + const auto number_operations = y_host_dev_tuple.size(); + + // Two operations: one do a sum reduction, the other computing the mean square + auto reduce_ops = + ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{}); // reductions ops + auto elementwise_ops = + ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnarySquare{}); // Elementwise ops + auto accumulator_elementwise_ops = + ck_tile::make_tuple(ck_tile::element_wise::PassThrough{}, + ck_tile::element_wise::UnaryDivide{ + reduce_total_length}); // Accumulator Elementiwise ops on reduction, + + auto y_buf_size = number_operations * + y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem y_buf(y_buf_size); + + const auto output_tensor_offset = N * C; + + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockPerCu = 1; + ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) / + BlockTile::at(ck_tile::number<0>{}); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::Reduce2dShape; + using Problem = ck_tile:: + Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceThreadWise; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + + // Create input tensor shape and strides + auto input_shape = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]); + + if(!Kernel::IsSupportedArgument( + C, input_strides)) // output tensor's continuous dimension and input strides + { + throw std::runtime_error("Wrong! 
Arguments not supported!\n"); + } + + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_elementwise_ops)); + + std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + std::vector h(number_operations * N * C); + + // reference + ck_tile::reference_multiple_reduce( + x_host, + y_host_ref_tuple, + reduce_ops, + kept_dim, + reduce_dims, + elementwise_ops, + accumulator_elementwise_ops); + std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl; + + // Transfer data from device and check error for each operation + y_buf.FromDevice(h.data()); + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(y_host_dev_tuple.get(ck_tile::number{}).data(), + h.data() + i * output_tensor_offset, + output_tensor_offset * sizeof(YDataType)); + pass &= ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number{}), + y_host_ref_tuple.get(ck_tile::number{})); + }); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } +} diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp index 69449711e0..5db3dcb6f4 100644 --- a/include/ck_tile/core/utility/reduce_operator.hpp +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
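// Reviewer note: GetAtomic() (added to Add below) is what the multi-block reduction
// kernel queries on its inter-block reduce ops to select the memory operation used by
// update_tile(), so only ops that define it can be used for cross-block accumulation.
// In this patch only Add provides one; a Max counterpart would look roughly like the
// sketch below, assuming memory_operation_enum exposes an atomic_max value:
//
//   CK_TILE_HOST_DEVICE static constexpr auto GetAtomic()
//   {
//       return memory_operation_enum::atomic_max;
//   }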
#pragma once #include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -34,6 +35,11 @@ struct Add return type_convert(y_ + x_); } + + CK_TILE_HOST_DEVICE static constexpr auto GetAtomic() + { + return memory_operation_enum::atomic_add; + } }; struct SquareAdd diff --git a/include/ck_tile/host/reference/reference_reduce.hpp b/include/ck_tile/host/reference/reference_reduce.hpp index 9952b7b009..632ca77fc5 100644 --- a/include/ck_tile/host/reference/reference_reduce.hpp +++ b/include/ck_tile/host/reference/reference_reduce.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/ops/elementwise.hpp" #include namespace ck_tile { @@ -108,4 +109,233 @@ CK_TILE_HOST void reference_reduce(const HostTensor& x_tensor, make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency()); } + +template containing reduce operations + typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to + // keep + typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension indices + // to reduce + typename ElementWiseOps, + typename AccElementWiseOps> +CK_TILE_HOST void reference_multiple_reduce(const HostTensor& x_tensor, + YRefTuple& y_tensor_tuple, + ReduceOps reduce_ops, + KeptDim kept_dim, + ReduceDims reduce_dims, + ElementWiseOps elementwise_ops, + AccElementWiseOps accumulator_ops) +{ + const auto& x_lengths = x_tensor.mDesc.get_lengths(); + + // Calculate total kept elements (product of all kept dimension lengths) + index_t total_kept_elements = 1; + static_for<0, kept_dim.size(), 1>{}( + [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; }); + + // Calculate total reduce elements (product of all reduce dimension lengths) + index_t total_reduce_elements = 1; + static_for<0, reduce_dims.size(), 1>{}( + [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; }); + + auto f = [&](auto linear_kept_idx) { + // Initialize accumulators for each reduction operation + auto v_acc_tuple = ck_tile::generate_tuple( + [&](auto i) { + return reduce_ops.template at().template GetIdentityValue(); + }, + number{}); + + // Convert linear kept index to multi-dimensional kept indices + std::vector kept_indices(kept_dim.size()); + index_t temp_kept = linear_kept_idx; + static_for<0, kept_dim.size(), 1>{}([&](auto i) { + constexpr auto dim_idx = kept_dim.size() - 1 - i; + constexpr auto dim = kept_dim.at(dim_idx); + const auto len = x_lengths[dim]; + kept_indices[dim_idx] = temp_kept % len; + temp_kept /= len; + }); + + for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx) + { + // Convert linear reduce index to multi-dimensional reduce indices + std::vector reduce_indices(reduce_dims.size()); + index_t temp_reduce = reduce_idx; + static_for<0, reduce_dims.size(), 1>{}([&](auto i) { + constexpr auto dim_idx = reduce_dims.size() - 1 - i; + constexpr auto dim = reduce_dims.at(dim_idx); + const auto len = x_lengths[dim]; + reduce_indices[dim_idx] = temp_reduce % len; + temp_reduce /= len; + }); + + // Build full input tensor indices by combining kept and reduce indices + std::vector full_indices(x_lengths.size(), 0); + static_for<0, kept_dim.size(), 1>{}( + [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; }); + static_for<0, reduce_dims.size(), 1>{}( + [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; }); + + // Access 
input tensor element + auto v_a = type_convert(x_tensor(full_indices)); + + // Apply each reduction operation + static_for<0, reduce_ops.size(), 1>{}([&](auto i) { + // Apply element-wise operation before reduction + elementwise_ops.at(i)(v_a, v_a); + + v_acc_tuple.template at() = + reduce_ops.template at()(v_acc_tuple.template at(), v_a); + }); + } + + static_for<0, reduce_ops.size(), 1>{}([&](auto i) { + // Apply accumulator element-wise operation after reduction + accumulator_ops.at(i)(v_acc_tuple.template at(), v_acc_tuple.template at()); + }); + + // Calculate output tensor index using kept indices + // The output tensor has the same structure as the kept dimensions + std::vector y_indices(kept_dim.size()); + static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; }); + + // Store results for each reduction operation in the output tensor + static_for<0, reduce_ops.size(), 1>{}([&](auto i) { + y_tensor_tuple.template at()(y_indices) = + type_convert(v_acc_tuple.template at()); + }); + }; + + make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency()); +} + +template containing reduce operations + typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to + // keep + typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension indices + // to reduce + typename ElementWiseOps, + typename AccElementWiseOps, + typename InterBlockReduceOps> +CK_TILE_HOST void reference_multiple_reduce_multiblock(const HostTensor& x_tensor, + YRefTuple& y_tensor_tuple, + ReduceOps reduce_ops, + KeptDim kept_dim, + ReduceDims reduce_dims, + ElementWiseOps elementwise_ops, + AccElementWiseOps accumulator_ops, + InterBlockReduceOps inter_block_reduce_ops, + ck_tile::index_t num_blocks) +{ + const auto& x_lengths = x_tensor.mDesc.get_lengths(); + + // Calculate total kept elements (product of all kept dimension lengths) + index_t total_kept_elements = 1; + static_for<0, kept_dim.size(), 1>{}( + [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; }); + + // Calculate total reduce elements (product of all reduce dimension lengths) + index_t total_reduce_elements = 1; + static_for<0, reduce_dims.size(), 1>{}( + [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; }); + + // Initialize output tensors + static_for<0, reduce_ops.size(), 1>{}([&](auto i) { + auto& y_tensor = y_tensor_tuple.template at(); + for(auto& val : y_tensor.mData) + { + val = inter_block_reduce_ops.template at().template GetIdentityValue(); + } + }); + + auto f = [&](auto linear_kept_idx) { + // Convert linear kept index to multi-dimensional kept indices + std::vector kept_indices(kept_dim.size()); + index_t temp_kept = linear_kept_idx; + static_for<0, kept_dim.size(), 1>{}([&](auto i) { + constexpr auto dim_idx = kept_dim.size() - 1 - i; + constexpr auto dim = kept_dim.at(dim_idx); + const auto len = x_lengths[dim]; + kept_indices[dim_idx] = temp_kept % len; + temp_kept /= len; + }); + + // Calculate output tensor index using kept indices + std::vector y_indices(kept_dim.size()); + static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; }); + + const auto max_element_per_block = (total_reduce_elements + num_blocks - 1) / num_blocks; + + for(index_t block_id = 0; block_id < num_blocks; ++block_id) + { + // Initialize accumulators for each reduction operation for the current block + auto v_acc_tuple = ck_tile::generate_tuple( + [&](auto i) { + return reduce_ops.template at().template 
GetIdentityValue(); + }, + number{}); + + const index_t element_offset = block_id * max_element_per_block; + const index_t element_end = + std::min(element_offset + max_element_per_block, total_reduce_elements); + + for(index_t linear_reduce_idx = element_offset; linear_reduce_idx < element_end; + ++linear_reduce_idx) + { + // Convert linear reduce index to multi-dimensional reduce indices + std::vector reduce_indices(reduce_dims.size()); + index_t temp_reduce = linear_reduce_idx; + static_for<0, reduce_dims.size(), 1>{}([&](auto i) { + constexpr auto dim_idx = reduce_dims.size() - 1 - i; + constexpr auto dim = reduce_dims.at(dim_idx); + const auto len = x_lengths[dim]; + reduce_indices[dim_idx] = temp_reduce % len; + temp_reduce /= len; + }); + + // Build full input tensor indices by combining kept and reduce indices + std::vector full_indices(x_lengths.size(), 0); + static_for<0, kept_dim.size(), 1>{}( + [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; }); + static_for<0, reduce_dims.size(), 1>{}( + [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; }); + + // Access input tensor element + const auto v_a_in = type_convert(x_tensor(full_indices)); + + // Apply each reduction operation + static_for<0, reduce_ops.size(), 1>{}([&](auto i) { + auto v_a = v_a_in; + // Apply element-wise operation before reduction + elementwise_ops.at(i)(v_a, v_a); + + v_acc_tuple.template at() = + reduce_ops.template at()(v_acc_tuple.template at(), v_a); + }); + } + + static_for<0, reduce_ops.size(), 1>{}([&](auto i) { + // Apply accumulator element-wise operation after reduction + accumulator_ops.at(i)(v_acc_tuple.template at(), v_acc_tuple.template at()); + + // Update the output tensor with the partial result from this block + auto& y_tensor = y_tensor_tuple.template at(); + auto& y_val = y_tensor(y_indices); + y_val = inter_block_reduce_ops.template at()( + y_val, type_convert(v_acc_tuple.template at())); + }); + } + }; + + make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency()); +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index f8f8059469..9402d63782 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -453,6 +453,12 @@ struct PassThrough /* otherwise (r-value or const) → do nothing */ } + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const + { + y = ck_tile::type_convert>(x); + } + template CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&...) 
const -> void { diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index d628e9c945..c7731c24f5 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -6,6 +6,9 @@ #include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce2d.hpp" #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp" +#include "ck_tile/ops/reduce/kernel/multi_reduce2d_multiblock_kernel.hpp" +#include "ck_tile/ops/reduce/kernel/multi_reduce2d_threadwise_kernel.hpp" #include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index c666608bfd..ab0a639988 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -165,8 +165,6 @@ struct BlockReduce2d template CK_TILE_DEVICE static auto MakeYBlockTile() { - static_assert(std::is_same_v, "wrong!"); - // FIXME: hard coded to reduce 2nd axis constexpr auto reduce_dims = sequence<1>{}; diff --git a/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp new file mode 100644 index 0000000000..67f923d3e1 --- /dev/null +++ b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_kernel.hpp @@ -0,0 +1,446 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" +#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" + +// Multi Reduce2d Unified Kernel: +// ======================================= +// This kernel implements multiple 2D reduction operations that reduce data along the specified +// dimensions of a matrix. 
It supports both single-block (threadwise) and multi-block + +namespace ck_tile { + +/// @brief TilePartitioner for 2D reduction operations +template +struct Reduce2dTilePartitioner +{ + using BlockShape = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockShape::Block_M; + static constexpr index_t NPerBlock = BlockShape::Block_N; + + CK_TILE_HOST_DEVICE Reduce2dTilePartitioner() noexcept = delete; + + /// @brief Construct partitioner with problem dimensions + /// @param M_ Output dimension size (kept dimension) + /// @param N_ Reduction dimension size + CK_TILE_HOST_DEVICE Reduce2dTilePartitioner(index_t M_, index_t N_) noexcept : M(M_), N(N_) {} + + /// @brief Get output tile index for threadwise reduction + /// @param block_idx Block index + /// @return M-dimension tile index + CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_idx) const noexcept -> index_t + { + return amd_wave_read_first_lane(block_idx); + } + + /// @brief Get output tile index and block local ID for multi-block reduction + /// @param block_idx Global block index + /// @param block_group_size Number of blocks per output tile + /// @return Tuple of (tile_index, local_block_id) + CK_TILE_DEVICE auto + GetOutputTileIndexMultiBlock(index_t block_idx, + index_t block_group_size) const noexcept -> tuple + { + const index_t tile_idx = amd_wave_read_first_lane(block_idx / block_group_size); + const index_t local_idx = amd_wave_read_first_lane(block_idx % block_group_size); + return make_tuple(tile_idx, local_idx); + } + + private: + index_t M; + index_t N; +}; + +template +struct MultiReduce2d +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + static constexpr bool ForceMultiBlock = ForceMultiBlock_; // false: threadwise, true: multiblock + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + + using TilePartitioner = Reduce2dTilePartitioner; + + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? kBlockSize / 2 : kBlockSize; + } + + CK_TILE_HOST_DEVICE static void CalculateBlockGroupParams(const int reduce_total_length, + int& num_block_tile_iterations, + int& block_group_size) + { + constexpr int max_block_group_size = + 128; // Maximum 128, as in CK. It balances between latency (i.e. limiting stalls when + // performing the atomic operation) and block parallelism. 
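// Worked example (reviewer sketch): with Block_N = 128 and the multiblock example's
// default reduction length H*W = 19*7 = 133, num_block_tile_iterations =
// ceil(133 / (128 * 128)) = 1 and block_group_size = ceil(133 / (128 * 1)) = 2, i.e.
// two blocks cooperate on each output tile and combine their partial results through
// the atomic inter-block reduction.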
+ + num_block_tile_iterations = + (reduce_total_length + (Problem::BlockShape::Block_N * max_block_group_size) - 1) / + (Problem::BlockShape::Block_N * max_block_group_size); + + // This should only happen if reduce_total_length is 0 (empty tensor) + if(num_block_tile_iterations == 0) + { +#ifndef __HIP_DEVICE_COMPILE__ + // Warning only on host side + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + printf("Warning: reduce_total_length is 0, there is no data to process\n"); + } +#endif + block_group_size = 1; + return; + } + + block_group_size = + (reduce_total_length + (Problem::BlockShape::Block_N * num_block_tile_iterations) - 1) / + (Problem::BlockShape::Block_N * num_block_tile_iterations); + } + + private: + // Helper function to calculate optimal vector size for input tensor + template + static constexpr index_t CalculateInputVectorSize() + { + using S = typename Problem::BlockShape; + constexpr index_t memory_vector_size = 16 / sizeof(XDataType); // Vectorization + constexpr index_t thread_tile_vector_size = + S::ThreadTile_N; // In the continuous dimension, within the tile + + constexpr auto innermost_reduce_dim = ReduceDims{}.at(number{}); + constexpr bool is_innermost_contiguous = (innermost_reduce_dim == InputShape{}.size() - 1); + + constexpr index_t stride_based_vector_size = + is_innermost_contiguous + ? ck_tile::min(memory_vector_size, thread_tile_vector_size) + : 1; // Move at "vectorization" steps if continuous otherwise 1 step + + return stride_based_vector_size; + } + + static constexpr index_t CalculateOutputVectorSize() + { + using S = typename Problem::BlockShape; + constexpr index_t memory_vector_size = 16 / sizeof(YDataType); + constexpr index_t thread_tile_vector_size = S::ThreadTile_M; + constexpr index_t vector_size = ck_tile::min(memory_vector_size, thread_tile_vector_size); + + return vector_size; + } + + public: + // Overload for threadwise version (no InterblockReduceOps parameter) + // This version uses the same reduce_ops for interblock reduction + template + CK_TILE_DEVICE void operator()(const XDataType* p_x, + YDataType* p_y_tuple, + InputShape input_shape, + InputStrides input_strides, + KeptDim kept_dim, + ReduceDims reduce_dims, + index_t output_tensor_offset, + ElementwiseOps elementwise_ops, + AccumulatorOps accumulator_ops) const + { + // For single-block case, use the same reduce ops for interblock reduction + // (though they won't be used since block_group_size will be 1) + auto reduce_ops = typename Problem::ReduceOp{}; + (*this)(p_x, + p_y_tuple, + input_shape, + input_strides, + kept_dim, + reduce_dims, + output_tensor_offset, + elementwise_ops, + accumulator_ops, + reduce_ops); // Use reduce_ops as interblock_reduce_ops + } + + // Main operator overload + template + CK_TILE_DEVICE void operator()(const XDataType* p_x, + YDataType* p_y_tuple, + InputShape input_shape, + InputStrides input_strides, + KeptDim kept_dim, + ReduceDims reduce_dims, + index_t output_tensor_offset, + ElementwiseOps elementwise_ops, + AccumulatorOps accumulator_ops, + InterblockReduceOps interblock_reduce_ops) const + { + static_assert( + ElementwiseOps::size() == Problem::ReduceOp::size() && + AccumulatorOps::size() == Problem::ReduceOp::size() && + InterblockReduceOps::size() == Problem::ReduceOp::size(), + "Error: All operations tuple size must match the number of reduction operations"); + + using S = typename Problem::BlockShape; + auto reduce_ops = typename Problem::ReduceOp{}; + + const auto number_operations = reduce_ops.size(); + + 
static_assert(number_operations > 0, + "Error: At least one reduction operation must be specified!"); + + static_assert(kept_dim.size() + reduce_dims.size() == InputShape::size(), + "Size of kept dimensions + reduced dimensions must equal input tensor rank"); + + const auto kept_lens = [&]() { + return generate_tuple([&](auto I) { return input_shape.at(number{}); }, + number{}); + }(); + const auto reduce_lens = [&]() { + return generate_tuple( + [&](auto I) { return input_shape.at(number{}); }, + number{}); + }(); + + // Calculate total reduction length + int total_reduce_len = 1; + static_for<0, reduce_lens.size(), 1>{}( + [&](auto i) { total_reduce_len *= reduce_lens.at(i); }); + + // Early exit for empty tensors (reduce_total_length == 0) + // This can happen when any dimension in reduce_lens is 0 + if(total_reduce_len == 0) + { + return; + } + + // Determine strategy: single-block or multi-block + int block_group_size = 1; + int num_n_tile_iteration = 0; + + if constexpr(ForceMultiBlock) + { + CalculateBlockGroupParams(total_reduce_len, num_n_tile_iteration, block_group_size); + } + else + { + // Single-block strategy: one block handles entire reduction + block_group_size = 1; + num_n_tile_iteration = (total_reduce_len + S::Block_N - 1) / S::Block_N; + } + + constexpr index_t output_vector_size = CalculateOutputVectorSize(); + + const auto block_global_id = get_block_id(); // Hardware block id + + // Get tile indices + index_t block_group_id, block_local_id; + if constexpr(ForceMultiBlock) + { + const auto [tile_idx, local_idx] = + TilePartitioner{total_reduce_len, total_reduce_len}.GetOutputTileIndexMultiBlock( + block_global_id, block_group_size); + block_group_id = tile_idx; + block_local_id = local_idx; + } + else + { + block_group_id = TilePartitioner{total_reduce_len, total_reduce_len}.GetOutputTileIndex( + block_global_id); + block_local_id = 0; + } + + const auto kept_merge_transform = + make_merge_transform(kept_lens); // Dimension(s) not reduced are being flattened + const auto reduce_merge_transform = + make_merge_transform(reduce_lens); // Dimension(s) to reduce are being flattened + + const auto custom_padding_values = ck_tile::apply( + [](auto... 
args) { + return ck_tile::make_tuple(args.template GetIdentityValue()...); + }, + reduce_ops); // Get the identity element for each operation + + constexpr auto x_tensor_vector_size = CalculateInputVectorSize(); + + auto desc = make_naive_tensor_descriptor( + input_shape, input_strides, number{}, number<1>{}); + + __shared__ char smem[Policy::template GetSmemSize()]; + + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + index_t m_offset = S::Block_M * block_group_id; + index_t n_offset = S::Block_N * num_n_tile_iteration * block_local_id; + + static_for<0, number_operations, 1>{}([&](auto i) { + auto buffer_view = make_buffer_view( + p_x, desc.get_element_space_size(), custom_padding_values.get(number{})); + + const auto x_tensor = + tensor_view{buffer_view, desc}; + const auto transformed_x_tensor = pad_tensor_view( + transform_tensor_view(x_tensor, + make_tuple(kept_merge_transform, reduce_merge_transform), + make_tuple(kept_dim, reduce_dims), + make_tuple(sequence<0>{}, sequence<1>{})), + make_tuple(number{}, number{}), + sequence<0, 1>{}); + + auto x_window = + make_tile_window(transformed_x_tensor, + make_tuple(number{}, number{}), + {m_offset, n_offset}, + Policy::template MakeXBlockTileDistribution()); + + using ComputeDataTensorType = decltype(cast_tile(load_tile(x_window))); + + auto y_compute = block_reduce2d.template MakeYBlockTile(); + + set_tile(y_compute, + reduce_ops.get(number{}).template GetIdentityValue()); + + // Reduction loop + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + auto x = load_tile(x_window); + auto x_compute = cast_tile(x); + + tile_elementwise_inout(elementwise_ops.get(number{}), x_compute, x_compute); + block_reduce2d(x_compute, y_compute, reduce_ops.get(number{})); + + move_tile_window(x_window, {0, S::Block_N}); + } + + block_reduce2d_sync(y_compute, reduce_ops.get(number{})); + block_reduce2d_cross_warp_sync( + y_compute, static_cast(smem), reduce_ops.get(number{})); + + // Determine if this thread should perform the output operation + // We want threads that handle the first elements in the N (reduction) dimension + const auto tile_dist = y_compute.get_tile_distribution(); + const auto ps_idx = get_partition_index(tile_dist); + const auto rs_idx = tile_dist.calculate_rs_index_from_ps_index(ps_idx); + + // Check if this thread is responsible for the first N-dimension element + // In the tile distribution, dimension 1 corresponds to the N dimension + const bool is_first_n_thread = (rs_idx[number<1>{}] == 0); + + if(is_first_n_thread) + { + tile_elementwise_inout(accumulator_ops.get(number{}), y_compute, y_compute); + + // Single-block vs multi-block output strategy + if constexpr(!ForceMultiBlock) + { + // Single-block case: direct store without atomics + auto y_tensor_view = make_naive_tensor_view( + p_y_tuple + (i * output_tensor_offset) + (S::Block_M * block_group_id), + make_tuple(S::Block_M), + make_tuple(1), + number{}, + number<1>{}); + + auto y_window = make_tile_window(y_tensor_view, + make_tuple(number{}), + {0}, + y_compute.get_tile_distribution()); + + auto y_output = cast_tile(y_compute); + store_tile(y_window, y_output); // Direct store, no atomics + } + else + { + // Multi-block case: use atomic operations for interblock reduction + constexpr auto mem_op = interblock_reduce_ops.get(number{}).GetAtomic(); + + auto y_tensor_view 
= make_naive_tensor_view( + p_y_tuple + (i * output_tensor_offset) + (S::Block_M * block_group_id), + make_tuple(S::Block_M), + make_tuple(1), + number{}, + number<1>{}); + + auto y_window = make_tile_window(y_tensor_view, + make_tuple(number{}), + {0}, + y_compute.get_tile_distribution()); + + auto y_output = cast_tile(y_compute); + update_tile(y_window, y_output); // Atomic update + } + } + }); + } + + /// @brief Validates if the given arguments are supported by the 2D multi reduction kernel. + /// + /// @param y_continous_dim Size of the continuous dimension of the output tensor. + /// Must be a multiple of ThreadTile_N for proper thread mapping. + /// + /// @param input_strides The stride configuration of the input tensor. + /// The last stride must be 1 to ensure contiguous memory access + /// and enable efficient vectorized loads. + /// + /// @return true if the arguments are supported, false otherwise. + /// Error messages are logged when CK_TILE_LOGGING is enabled. + /// + /// @note Requirements: + /// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution) + /// - input_strides[-1] == 1 (for contiguous memory access) + template + CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, + InputStrides input_strides) + { + using S = typename Problem::BlockShape; + + if(y_continous_dim % S::ThreadTile_N != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Total reduction size should be a multiple of ThreadTile_N!"); + } + return false; + } + + if(input_strides.at(number{}) != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Input tensor's last stride must be 1 to support correct vector access!"); + } + return false; + } + + return true; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/kernel/multi_reduce2d_multiblock_kernel.hpp b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_multiblock_kernel.hpp new file mode 100644 index 0000000000..9da952e0ad --- /dev/null +++ b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_multiblock_kernel.hpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "multi_reduce2d_kernel.hpp" +namespace ck_tile { +template +using MultiReduceMultiblock = MultiReduce2d; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/kernel/multi_reduce2d_threadwise_kernel.hpp b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_threadwise_kernel.hpp new file mode 100644 index 0000000000..03b4024d0b --- /dev/null +++ b/include/ck_tile/ops/reduce/kernel/multi_reduce2d_threadwise_kernel.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
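// Reviewer note: this header is the threadwise counterpart of
// multi_reduce2d_multiblock_kernel.hpp. Both are thin aliases over the unified
// MultiReduce2d kernel; MultiReduceThreadWise is expected to fix the ForceMultiBlock
// parameter to false (one block owns an entire reduction and writes with a plain
// store_tile), while MultiReduceMultiblock fixes it to true (block groups plus atomic
// update_tile on the output).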
+ +#pragma once + +#include "multi_reduce2d_kernel.hpp" +namespace ck_tile { + +template +using MultiReduceThreadWise = MultiReduce2d; + +} // namespace ck_tile diff --git a/test/ck_tile/reduce/CMakeLists.txt b/test/ck_tile/reduce/CMakeLists.txt index 0ba5974f6c..b3a77f2f38 100644 --- a/test/ck_tile/reduce/CMakeLists.txt +++ b/test/ck_tile/reduce/CMakeLists.txt @@ -1,7 +1,11 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp) + add_gtest_executable(test_ck_tile_multi_reduce2d_threadwise test_multi_reduce2d_threadwise.cpp) + add_gtest_executable(test_ck_tile_multi_reduce2d_multiblock test_multi_reduce2d_multiblock.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_reduce2d PRIVATE utility) + target_link_libraries(test_ck_tile_multi_reduce2d_threadwise PRIVATE utility) + target_link_libraries(test_ck_tile_multi_reduce2d_multiblock PRIVATE utility) endif() endif() diff --git a/test/ck_tile/reduce/test_multi_reduce2d_common.hpp b/test/ck_tile/reduce/test_multi_reduce2d_common.hpp new file mode 100644 index 0000000000..4601b91958 --- /dev/null +++ b/test/ck_tile/reduce/test_multi_reduce2d_common.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "ck_tile/ops/elementwise.hpp" +// Overload methods required for the parametrize tests + +// Overload for PassThrough (no parameter) +inline ck_tile::element_wise::PassThrough make_elementwise_op(int32_t, + ck_tile::element_wise::PassThrough) +{ + return ck_tile::element_wise::PassThrough{}; +} + +// Overload for UnaryDivide (needs parameter) +inline ck_tile::element_wise::UnaryDivide make_elementwise_op(int32_t total_reduce_elements, + ck_tile::element_wise::UnaryDivide) +{ + return ck_tile::element_wise::UnaryDivide{total_reduce_elements}; +} + +// Overload for UnarySquare (no parameter) +inline ck_tile::element_wise::UnarySquare make_elementwise_op(int32_t, + ck_tile::element_wise::UnarySquare) +{ + return ck_tile::element_wise::UnarySquare{}; +} + +template +auto make_elementwise_ops_tuple(int32_t total_reduce_elements, ck_tile::tuple) +{ + return ck_tile::make_tuple(make_elementwise_op(total_reduce_elements, Ops{})...); +} diff --git a/test/ck_tile/reduce/test_multi_reduce2d_multiblock.cpp b/test/ck_tile/reduce/test_multi_reduce2d_multiblock.cpp new file mode 100644 index 0000000000..7997895238 --- /dev/null +++ b/test/ck_tile/reduce/test_multi_reduce2d_multiblock.cpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise.hpp" + +#include "test_multi_reduce2d_multiblock_impl.hpp" + +// Shape parameters for different test configurations +using Shape1_BlockWarps = ck_tile::sequence<4, 1>; +using Shape1_BlockTile = ck_tile::sequence<128, 128>; +using Shape1_WarpTile = ck_tile::sequence<32, 128>; +using Shape1_ThreadTile = ck_tile::sequence<8, 8>; + +// Test configurations for different data types and operations +using TestConfig_F16_Add = std::tuple, + ck_tile::tuple, + ck_tile::tuple, + ck_tile::tuple, + Shape1_BlockWarps, + Shape1_BlockTile, + Shape1_WarpTile, + Shape1_ThreadTile>; + +using TestConfig_F16_Add_MeanSquare = std::tuple< + ck_tile::half_t, + float, + float, // Output and multiblock reducing buffer. 
Using float to avoid too many accumulation + // errors + ck_tile::tuple, // Intra block reductions + ck_tile::tuple, // Elementwise + // ops + ck_tile::tuple, // Accumulator Elementiwise ops, intra block + ck_tile::tuple, // Inter block reduction + Shape1_BlockWarps, + Shape1_BlockTile, + Shape1_WarpTile, + Shape1_ThreadTile>; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(TestCkTileMultiReduceMultiblock, TestTypes); + +// 2D Tests - Keep dim0, reduce dim1 +TYPED_TEST(TestCkTileMultiReduceMultiblock, Test2D_KeepDim0_ReduceDim1_64x32) +{ + this->RunTest2D_KeepDim0_ReduceDim1(64, 32); +} + +TYPED_TEST(TestCkTileMultiReduceMultiblock, Test2D_KeepDim0_ReduceDim1_1024x512) +{ + this->RunTest2D_KeepDim0_ReduceDim1(1024, 512); +} + +// 3D Tests - Keep dim0, reduce dim1,2 +TYPED_TEST(TestCkTileMultiReduceMultiblock, Test3D_KeepDim0_ReduceDim12_128x128x1) +{ + this->RunTest3D_KeepDim0_ReduceDim12(128, 128, 8); +} +// 3D Tests - Keep dim0,1, reduce dim1 +TYPED_TEST(TestCkTileMultiReduceMultiblock, Test3D_KeepDim01_ReduceDim2_512x1024x16) +{ + this->RunTest3D_KeepDim01_ReduceDim2(512, 1024, 16); +} + +// 4D Tests - Keep dim0,1, reduce dim2,3 (NCHW -> NC) +TYPED_TEST(TestCkTileMultiReduceMultiblock, Test4D_KeepDim01_ReduceDim23_32x256x16x16) +{ + this->RunTest4D_KeepDim01_ReduceDim23(32, 256, 16, 16); +} +// 4D Tests - Keep dim0,3, reduce dim1,2 (NHWC -> NC) +TYPED_TEST(TestCkTileMultiReduceMultiblock, Test4D_KeepDim03_ReduceDim12_16x32x32x128) +{ + this->RunTest4D_KeepDim03_ReduceDim12(16, 32, 32, 128); +} diff --git a/test/ck_tile/reduce/test_multi_reduce2d_multiblock_impl.hpp b/test/ck_tile/reduce/test_multi_reduce2d_multiblock_impl.hpp new file mode 100644 index 0000000000..ec3545d4f8 --- /dev/null +++ b/test/ck_tile/reduce/test_multi_reduce2d_multiblock_impl.hpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
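// Running these tests (reviewer sketch; binary location depends on the build tree):
//
//   cmake --build . --target test_ck_tile_multi_reduce2d_multiblock
//   ./test_ck_tile_multi_reduce2d_multiblock \
//       --gtest_filter='*Test2D_KeepDim0_ReduceDim1_64x32*'
//
// Each typed test instantiates this fixture with one of the TestConfig_* tuples defined
// in test_multi_reduce2d_multiblock.cpp.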
+ +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +#include "test_multi_reduce2d_common.hpp" + +template +class TestCkTileMultiReduceMultiblock : public ::testing::Test +{ + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using ComputeDataType = std::tuple_element_t<1, Tuple>; + using YDataType = std::tuple_element_t<2, Tuple>; + using ReduceOpsType = std::tuple_element_t<3, Tuple>; + using ElementwiseOpsType = std::tuple_element_t<4, Tuple>; + using AccumulatorOpsType = std::tuple_element_t<5, Tuple>; + using InterBlockReduceOpsType = std::tuple_element_t<6, Tuple>; + using BlockWarps_ = std::tuple_element_t<7, Tuple>; + using BlockTile_ = std::tuple_element_t<8, Tuple>; + using WarpTile_ = std::tuple_element_t<9, Tuple>; + using ThreadTile_ = std::tuple_element_t<10, Tuple>; + + using TestReduce2dShape = + ck_tile::Reduce2dShape; + + template + void RunGenericTest(const std::vector& input_shape, + const std::vector& input_strides, + const std::vector& output_shape, + const std::vector& output_strides, + ck_tile::index_t kept_dim_len_prod, + ck_tile::index_t total_reduce_elements, + KeptDimSeq kept_dims, + ReduceDimSeq reduce_dims) + { + static_assert( + ReduceOpsType::size() == ElementwiseOpsType::size() && + ReduceOpsType::size() == AccumulatorOpsType::size() && + ReduceOpsType::size() == InterBlockReduceOpsType::size(), + "Error: All operations tuple size must match the number of reduction operations"); + + const auto number_operations = ReduceOpsType::size(); + + ck_tile::HostTensor h_x(input_shape, input_strides); + + auto h_ys = ck_tile::generate_tuple( + [&output_shape, &output_strides](auto /*i*/) { + return ck_tile::HostTensor(output_shape, output_strides); + }, + ck_tile::number{}); + + auto h_ys_ref = ck_tile::generate_tuple( + [&output_shape, &output_strides](auto /*i*/) { + return ck_tile::HostTensor(output_shape, output_strides); + }, + ck_tile::number{}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + h_ys.template at().SetZero(); + h_ys_ref.template at().SetZero(); + }); + + auto output_number_elements = [&output_shape]() { + ck_tile::index_t prod = 1; + for(auto len : output_shape) + prod *= len; + return prod; + }(); + + auto output_buffer_size = + number_operations * h_ys.get(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_y_mem(output_buffer_size); + + std::vector h(number_operations * output_number_elements); + + // Init the output data with identity values respective to each reduce op + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + constexpr auto op = ReduceOpsType{}.at(i); + const auto identity_val = op.template GetIdentityValue(); + std::fill(h.begin() + i * output_number_elements, + h.begin() + (i + 1) * output_number_elements, + identity_val); + }); + + d_x_mem.ToDevice(h_x.data()); + d_y_mem.ToDevice(h.data()); + + using Problem = ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::MultiReduceMultiblock; + + // Launch configuration + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + ck_tile::index_t block_group_size; + int num_block_tile_iterations; + auto elementwise_ops = + make_elementwise_ops_tuple(total_reduce_elements, ElementwiseOpsType{}); + 
auto accumulator_ops = + make_elementwise_ops_tuple(total_reduce_elements, AccumulatorOpsType{}); + + Kernel::CalculateBlockGroupParams( + total_reduce_elements, num_block_tile_iterations, block_group_size); + + std::cout << "Block group size: " << block_group_size + << ", Num block tile iterations: " << num_block_tile_iterations + << ", Reduce total length: " << total_reduce_elements << std::endl; + + ck_tile::index_t kGridSize = + ((kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M) * + block_group_size; + + // Generic helper to create tuple from vector based on compile-time size + auto make_shape_tuple = [](const std::vector& vec) { + return [&vec](std::index_sequence) { + return ck_tile::make_tuple(vec[I]...); + }(std::make_index_sequence{}); + }; + + auto input_shape_tuple = make_shape_tuple.template operator()(input_shape); + auto input_strides_tuple = make_shape_tuple.template operator()(input_strides); + + if(!Kernel::IsSupportedArgument( + total_reduce_elements, + input_strides_tuple)) // output tensor's continuous dimension + { + throw std::runtime_error("Wrong! Arguments not supported!\n"); + } + + ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_y_mem.GetDeviceBuffer()), + input_shape_tuple, + input_strides_tuple, + kept_dims, + reduce_dims, + output_number_elements, + elementwise_ops, + accumulator_ops, + InterBlockReduceOpsType{})); + + // Reference computation + ck_tile::reference_multiple_reduce_multiblock( + h_x, + h_ys_ref, + ReduceOpsType{}, + kept_dims, + reduce_dims, + elementwise_ops, + accumulator_ops, + InterBlockReduceOpsType{}, + block_group_size); + + // Calculate proper error thresholds based on data types and number of accumulations + // const auto rtol = ck_tile::get_relative_threshold( + // total_reduce_elements); + // const auto atol = ck_tile::get_absolute_threshold( + // 5.0f, total_reduce_elements); + + // Unfortunately due to the non-sequenciality, down-casting on the output buffer + // and further operations on this buffer, the error is compounding at a faster + // rate than what the host reference can support. 
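// Reviewer note: the multi-block path also combines partial results with atomics, so
// the order in which blocks update the output buffer is non-deterministic and the
// device result is not expected to match the sequential host reference exactly.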
A large tolerance is then required + const auto rtol = 1e-2; + const auto atol = 1e-1; + + // Transfer data from device and check error for each operation + std::vector h_y_tmp(output_number_elements * number_operations); + d_y_mem.FromDevice(h_y_tmp.data()); + bool result = true; + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(h_ys.get(ck_tile::number{}).data(), + h_y_tmp.data() + i * output_number_elements, + output_number_elements * sizeof(YDataType)); + std::cout << "Checking errors for operation: " << i << std::endl; + result &= ck_tile::check_err(h_ys.get(ck_tile::number{}), + h_ys_ref.get(ck_tile::number{}), + "Error: Incorrect reduce results!", + rtol, + atol); + }); + + EXPECT_TRUE(result); + } + + // Convenience functions for specific dimensional patterns + void RunTest2D_KeepDim0_ReduceDim1(ck_tile::index_t dim0, ck_tile::index_t dim1) + { + constexpr auto kept_dims = ck_tile::sequence<0>{}; + constexpr auto reduce_dims = ck_tile::sequence<1>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1}; + std::vector input_strides = {dim1, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0}; + std::vector output_strides = {1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0; + ck_tile::index_t total_reduce_elements = dim1; + + RunGenericTest<2>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest3D_KeepDim0_ReduceDim12(ck_tile::index_t dim0, + ck_tile::index_t dim1, + ck_tile::index_t dim2) + { + constexpr auto kept_dims = ck_tile::sequence<0>{}; + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1, dim2}; + std::vector input_strides = {dim1 * dim2, dim2, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0}; + std::vector output_strides = {1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0; // product of kept dimensions + ck_tile::index_t total_reduce_elements = dim1 * dim2; // product of reduced dimensions + + RunGenericTest<3>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest3D_KeepDim01_ReduceDim2(ck_tile::index_t dim0, + ck_tile::index_t dim1, + ck_tile::index_t dim2) + { + constexpr auto kept_dims = ck_tile::sequence<0, 1>{}; + constexpr auto reduce_dims = ck_tile::sequence<2>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1, dim2}; + std::vector input_strides = {dim1 * dim2, dim2, 1}; + + // Output shape and strides (keep dim0, dim1) + std::vector output_shape = {dim0, dim1}; + std::vector output_strides = {dim1, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0 * dim1; // product of kept dimensions + ck_tile::index_t total_reduce_elements = dim2; // product of reduced dimensions + + RunGenericTest<3>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest4D_KeepDim01_ReduceDim23(ck_tile::index_t N, + ck_tile::index_t C, + ck_tile::index_t H, + ck_tile::index_t W) + { + constexpr auto kept_dims = ck_tile::sequence<0, 1>{}; + constexpr auto reduce_dims = ck_tile::sequence<2, 3>{}; + + // Input shape and strides + std::vector input_shape = {N, C, H, W}; + std::vector input_strides = {C * H 
* W, H * W, W, 1}; + + // Output shape and strides (keep dim0, dim1) + std::vector output_shape = {N, C}; + std::vector output_strides = {C, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = N * C; // product of kept dimensions + ck_tile::index_t total_reduce_elements = H * W; // product of reduced dimensions + + RunGenericTest<4>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest4D_KeepDim03_ReduceDim12(ck_tile::index_t N, + ck_tile::index_t H, + ck_tile::index_t W, + ck_tile::index_t C) + { + constexpr auto kept_dims = ck_tile::sequence<0, 3>{}; + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; + + // Input shape and strides + std::vector input_shape = {N, H, W, C}; + std::vector input_strides = {H * W * C, W * C, C, 1}; + + // Output shape and strides (keep dim0, dim3) + std::vector output_shape = {N, C}; + std::vector output_strides = {C, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = N * C; // product of kept dimensions + ck_tile::index_t total_reduce_elements = H * W; // product of reduced dimensions + + RunGenericTest<4>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } +}; diff --git a/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp b/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp new file mode 100644 index 0000000000..aaf40343fe --- /dev/null +++ b/test/ck_tile/reduce/test_multi_reduce2d_threadwise.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +#include "test_multi_reduce2d_threadwise_impl.hpp" + +// Shape parameters for different test configurations +using Shape1_BlockWarps = ck_tile::sequence<4, 1>; +using Shape1_BlockTile = ck_tile::sequence<128, 128>; +using Shape1_WarpTile = ck_tile::sequence<32, 128>; +using Shape1_ThreadTile = ck_tile::sequence<8, 8>; + +using Shape2_BlockWarps = ck_tile::sequence<2, 2>; // Cross-warp reduction test +using Shape2_BlockTile = ck_tile::sequence<2, 1024>; +using Shape2_WarpTile = ck_tile::sequence<1, 512>; +using Shape2_ThreadTile = ck_tile::sequence<1, 8>; + +// Test configurations for different data types and operations +using TestConfig_F16_Add = std::tuple, + ck_tile::tuple, + ck_tile::tuple, + ck_tile::tuple, + Shape1_BlockWarps, + Shape1_BlockTile, + Shape1_WarpTile, + Shape1_ThreadTile>; + +using TestConfig_F16_Add_Max = std::tuple< + ck_tile::half_t, + float, + ck_tile::half_t, + ck_tile::tuple, + ck_tile::tuple, + ck_tile::tuple, + ck_tile::tuple, + Shape1_BlockWarps, + Shape1_BlockTile, + Shape1_WarpTile, + Shape1_ThreadTile>; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(TestCkTileMultiReduceThreadwise, TestTypes); + +// 2D Tests - Keep dim0, reduce dim1 +TYPED_TEST(TestCkTileMultiReduceThreadwise, Test2D_KeepDim0_ReduceDim1_64x32) +{ + this->RunTest2D_KeepDim0_ReduceDim1(64, 32); +} + +TYPED_TEST(TestCkTileMultiReduceThreadwise, Test2D_KeepDim0_ReduceDim1_1024x512) +{ + this->RunTest2D_KeepDim0_ReduceDim1(1024, 512); +} + +// 3D Tests - Keep dim0, reduce dim1,2 +TYPED_TEST(TestCkTileMultiReduceThreadwise, Test3D_KeepDim0_ReduceDim12_128x128x1) +{ + 
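+    // Keeps dim 0 (length 128) and reduces dims 1 and 2 together (128 * 8 = 1024 elements per output).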
this->RunTest3D_KeepDim0_ReduceDim12(128, 128, 8); +} +// 3D Tests - Keep dim0,1, reduce dim1 +TYPED_TEST(TestCkTileMultiReduceThreadwise, Test3D_KeepDim01_ReduceDim2_512x1024x16) +{ + this->RunTest3D_KeepDim01_ReduceDim2(512, 512, 16); +} + +// 4D Tests - Keep dim0,1, reduce dim2,3 (NCHW -> NC) +TYPED_TEST(TestCkTileMultiReduceThreadwise, Test4D_KeepDim01_ReduceDim23_32x256x16x16) +{ + this->RunTest4D_KeepDim01_ReduceDim23(32, 256, 16, 16); +} +// 4D Tests - Keep dim0,3, reduce dim1,2 (NHWC -> NC) +TYPED_TEST(TestCkTileMultiReduceThreadwise, Test4D_KeepDim03_ReduceDim12_16x32x32x128) +{ + this->RunTest4D_KeepDim03_ReduceDim12(16, 32, 32, 128); +} diff --git a/test/ck_tile/reduce/test_multi_reduce2d_threadwise_impl.hpp b/test/ck_tile/reduce/test_multi_reduce2d_threadwise_impl.hpp new file mode 100644 index 0000000000..e50e60396e --- /dev/null +++ b/test/ck_tile/reduce/test_multi_reduce2d_threadwise_impl.hpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +#include "test_multi_reduce2d_common.hpp" + +template +class TestCkTileMultiReduceThreadwise : public ::testing::Test +{ + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using ComputeDataType = std::tuple_element_t<1, Tuple>; + using YDataType = std::tuple_element_t<2, Tuple>; + using ReduceOpsType = std::tuple_element_t<3, Tuple>; + using ElementwiseOpsType = std::tuple_element_t<4, Tuple>; + using AccumulatorOpsType = std::tuple_element_t<5, Tuple>; + using InterBlockReduceOpsType = std::tuple_element_t<6, Tuple>; + using BlockWarps_ = std::tuple_element_t<7, Tuple>; + using BlockTile_ = std::tuple_element_t<8, Tuple>; + using WarpTile_ = std::tuple_element_t<9, Tuple>; + using ThreadTile_ = std::tuple_element_t<10, Tuple>; + + using TestReduce2dShape = + ck_tile::Reduce2dShape; + + template + void RunGenericTest(const std::vector& input_shape, + const std::vector& input_strides, + const std::vector& output_shape, + const std::vector& output_strides, + ck_tile::index_t kept_dim_len_prod, + ck_tile::index_t total_reduce_elements, + KeptDimSeq kept_dims, + ReduceDimSeq reduce_dims) + { + const auto number_operations = ReduceOpsType::size(); + + ck_tile::HostTensor h_x(input_shape, input_strides); + + auto h_ys = ck_tile::generate_tuple( + [&output_shape, &output_strides](auto /*i*/) { + return ck_tile::HostTensor(output_shape, output_strides); + }, + ck_tile::number{}); + + auto h_ys_ref = ck_tile::generate_tuple( + [&output_shape, &output_strides](auto /*i*/) { + return ck_tile::HostTensor(output_shape, output_strides); + }, + ck_tile::number{}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + h_ys.template at().SetZero(); + h_ys_ref.template at().SetZero(); + }); + + auto output_number_elements = [&output_shape]() { + ck_tile::index_t prod = 1; + for(auto len : output_shape) + prod *= len; + return prod; + }(); + + auto output_buffer_size = + number_operations * h_ys.get(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_y_mem(output_buffer_size); + + d_x_mem.ToDevice(h_x.data()); + + // Problem and kernel setup + using Problem = ck_tile::Reduce2dProblem; + + using Kernel = 
ck_tile::MultiReduceThreadWise; + + // Launch configuration + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::index_t kGridSize = + (kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M; + + // Generic helper to create tuple from vector based on compile-time size + auto make_shape_tuple = [](const std::vector& vec) { + return [&vec](std::index_sequence) { + return ck_tile::make_tuple(vec[I]...); + }(std::make_index_sequence{}); + }; + + auto input_shape_tuple = make_shape_tuple.template operator()(input_shape); + auto input_strides_tuple = make_shape_tuple.template operator()(input_strides); + + if(!Kernel::IsSupportedArgument( + total_reduce_elements, + input_strides_tuple)) // output tensor's continuous dimension + { + throw std::runtime_error("Wrong! Arguments not supported!\n"); + } + + auto elementwise_ops = + make_elementwise_ops_tuple(total_reduce_elements, ElementwiseOpsType{}); + auto accumulator_ops = + make_elementwise_ops_tuple(total_reduce_elements, AccumulatorOpsType{}); + + ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, false, 0}, + ck_tile::make_kernel(Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_y_mem.GetDeviceBuffer()), + input_shape_tuple, + input_strides_tuple, + kept_dims, + reduce_dims, + output_number_elements, + elementwise_ops, + accumulator_ops)); + + // Reference computation + ck_tile::reference_multiple_reduce(h_x, + h_ys_ref, + ReduceOpsType{}, + kept_dims, + reduce_dims, + elementwise_ops, + accumulator_ops); + + // Calculate proper error thresholds based on data types and number of accumulations + // const auto rtol = ck_tile::get_relative_threshold( + // total_reduce_elements); + // const auto atol = ck_tile::get_absolute_threshold( + // 5.0f, total_reduce_elements); + + // Unfortunately due to the non-sequenciality, down-casting on the output buffer + // and further operations on this buffer, the error is compounding at a faster + // rate than what the host reference can support. 
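+        // (For example, fp16 outputs of magnitude 64-128 are spaced 2^-4 = 0.0625 apart, so even
+        // an exactly computed reduction can land more than a tight atol away from the float
+        // reference once it is stored back as fp16.)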
A large tolerance is then required + const auto rtol = 1e-2; + const auto atol = 1e-1; + + // Transfer data from device and check error for each operation + std::vector h_y_tmp(output_number_elements * number_operations); + d_y_mem.FromDevice(h_y_tmp.data()); + bool result = true; + ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + std::memcpy(h_ys.get(ck_tile::number{}).data(), + h_y_tmp.data() + i * output_number_elements, + output_number_elements * sizeof(YDataType)); + result &= ck_tile::check_err(h_ys.get(ck_tile::number{}), + h_ys_ref.get(ck_tile::number{}), + "Error: Incorrect reduce results!", + rtol, + atol); + }); + + EXPECT_TRUE(result); + } + + // Convenience functions for specific dimensional patterns + void RunTest2D_KeepDim0_ReduceDim1(ck_tile::index_t dim0, ck_tile::index_t dim1) + { + constexpr auto kept_dims = ck_tile::sequence<0>{}; + constexpr auto reduce_dims = ck_tile::sequence<1>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1}; + std::vector input_strides = {dim1, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0}; + std::vector output_strides = {1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0; + ck_tile::index_t total_reduce_elements = dim1; + + RunGenericTest<2>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest3D_KeepDim0_ReduceDim12(ck_tile::index_t dim0, + ck_tile::index_t dim1, + ck_tile::index_t dim2) + { + constexpr auto kept_dims = ck_tile::sequence<0>{}; + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1, dim2}; + std::vector input_strides = {dim1 * dim2, dim2, 1}; + + // Output shape and strides (keep dim0) + std::vector output_shape = {dim0}; + std::vector output_strides = {1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0; // product of kept dimensions + ck_tile::index_t total_reduce_elements = dim1 * dim2; // product of reduced dimensions + + RunGenericTest<3>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest3D_KeepDim01_ReduceDim2(ck_tile::index_t dim0, + ck_tile::index_t dim1, + ck_tile::index_t dim2) + { + constexpr auto kept_dims = ck_tile::sequence<0, 1>{}; + constexpr auto reduce_dims = ck_tile::sequence<2>{}; + + // Input shape and strides + std::vector input_shape = {dim0, dim1, dim2}; + std::vector input_strides = {dim1 * dim2, dim2, 1}; + + // Output shape and strides (keep dim0, dim1) + std::vector output_shape = {dim0, dim1}; + std::vector output_strides = {dim1, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = dim0 * dim1; // product of kept dimensions + ck_tile::index_t total_reduce_elements = dim2; // product of reduced dimensions + + RunGenericTest<3>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest4D_KeepDim01_ReduceDim23(ck_tile::index_t N, + ck_tile::index_t C, + ck_tile::index_t H, + ck_tile::index_t W) + { + constexpr auto kept_dims = ck_tile::sequence<0, 1>{}; + constexpr auto reduce_dims = ck_tile::sequence<2, 3>{}; + + // Input shape and strides + std::vector input_shape = {N, C, H, W}; + std::vector input_strides = {C * H * W, H * W, W, 1}; + + // Output shape and strides (keep dim0, 
dim1) + std::vector output_shape = {N, C}; + std::vector output_strides = {C, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = N * C; // product of kept dimensions + ck_tile::index_t total_reduce_elements = H * W; // product of reduced dimensions + + RunGenericTest<4>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } + + void RunTest4D_KeepDim03_ReduceDim12(ck_tile::index_t N, + ck_tile::index_t H, + ck_tile::index_t W, + ck_tile::index_t C) + { + constexpr auto kept_dims = ck_tile::sequence<0, 3>{}; + constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; + + // Input shape and strides + std::vector input_shape = {N, H, W, C}; + std::vector input_strides = {H * W * C, W * C, C, 1}; + + // Output shape and strides (keep dim0, dim3) + std::vector output_shape = {N, C}; + std::vector output_strides = {C, 1}; + + // Calculate products + ck_tile::index_t kept_dim_len_prod = N * C; // product of kept dimensions + ck_tile::index_t total_reduce_elements = H * W; // product of reduced dimensions + + RunGenericTest<4>(input_shape, + input_strides, + output_shape, + output_strides, + kept_dim_len_prod, + total_reduce_elements, + kept_dims, + reduce_dims); + } +}; diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt index db100553f3..08a98fe47c 100644 --- a/tile_engine/ops/CMakeLists.txt +++ b/tile_engine/ops/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(gemm) add_subdirectory(gemm_multi_d) -add_subdirectory(gemm_preshuffle) \ No newline at end of file +add_subdirectory(gemm_preshuffle) +add_subdirectory(reduce) \ No newline at end of file diff --git a/tile_engine/ops/reduce/CMakeLists.txt b/tile_engine/ops/reduce/CMakeLists.txt new file mode 100644 index 0000000000..60304c1e73 --- /dev/null +++ b/tile_engine/ops/reduce/CMakeLists.txt @@ -0,0 +1,123 @@ +# cmake_minimum_required(VERSION 4.2) + +# enable_testing() + +set(MULTI_REDUCE_DATATYPE "fp16" CACHE STRING "List of datatypes Multi Reduce (semicolon-separated)") +set(MULTI_REDUCE_VARIANTS "multiops_multiblock;multiops_threadwise" CACHE STRING "List of variants for Multi Reduce (semicolon-separated)") + +function(build_multi_reduce_for_datatype datatype variant) + # Filter GPU targets to only gfx942, and gfx950 + set(GPU_TARGETS "") + set(DESIRED_TARGETS "gfx942;gfx950") + set(VALID_VARIANTS "multiops_multiblock;multiops_threadwise") + + foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GPU_TARGETS ${target}) + endif() + endforeach() + + # Skip compilation if no matching targets found + if(NOT GPU_TARGETS) + message(WARNING "Skipping Tile Engine for Multi Reduction Kernel: No supported GPU targets (gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() + endif() + + message(STATUS "Building Reduction for GPU targets: ${GPU_TARGETS}") + + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${variant}") + file(MAKE_DIRECTORY "${working_path}") + + # Comment this if-else block when using user_provided_config + if(variant IN_LIST VALID_VARIANTS) + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_multi_reduce_config.json") + else() + # set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json") + message(WARNING "Unknown Multi Reduce variant: ${variant}.") + return() + endif() + + # uncomment this if you want to use user_provided_config.json + # set(json_blob 
"${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") + + # Generate kernel list + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/reduce_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --variant ${variant} + --config_json ${json_blob} + --list_blobs + --gpu_target "${GPU_TARGETS}" + RESULT_VARIABLE ret + ) + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${variant}: ${ret}") + endif() + + file(STRINGS "${working_path}/reduce_${variant}_blobs_list.txt" codegen_blobs) + + # Generate the blobs + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/reduce_instance_builder.py + --working_path "${working_path}" + --datatype ${datatype} + --config_json "${json_blob}" + --variant "${variant}" + --gen_blobs + --gpu_target "${GPU_TARGETS}" + RESULT_VARIABLE ret + ) + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate kernels for ${datatype} ${variant}: ${ret}") + endif() + + message(STATUS "Generated ${datatype} ${variant} reduction kernel blobs at: ${working_path}") + + # # Add test executables for each generated test + file(STRINGS "${working_path}/reduce_${variant}_blobs_list.txt" test_basenames) + + foreach(test_base IN LISTS test_basenames) + string(PREPEND test_base "test_") + set(test_src "${working_path}/${test_base}.cpp") + set(test_target "${test_base}") + + add_executable(${test_target} ${test_src}) + target_include_directories(${test_target} PRIVATE + "${CMAKE_SOURCE_DIR}/test/ck_tile/reduce/" + ${working_path} + ) + + target_compile_options(${test_target} PRIVATE -Wno-global-constructors -Wno-dev) + target_link_libraries(${test_target} PRIVATE gtest gtest_main) + + add_test(NAME ${test_target} COMMAND ${test_target}) + set_tests_properties(${test_target} PROPERTIES LABELS "multi_reduce") + endforeach() + add_custom_target(test_reduce_${variant}_${datatype} DEPENDS ${codegen_blobs}) + + # # Generating a single binary from all the tests (debug-only) + # set(test_srcs) + # foreach(test_base IN LISTS test_basenames) + # list(APPEND test_srcs "${working_path}/test_${test_base}.cpp") + # endforeach() + + # if(test_srcs) + # set(test_target "test_reduce_${variant}_${datatype}") + # add_executable(${test_target} ${test_srcs}) + # target_include_directories(${test_target} PRIVATE + # ${working_path} + # "${CMAKE_SOURCE_DIR}/test/ck_tile/reduce/" + # ) + # target_compile_options(${test_target} PRIVATE -Wno-global-constructors -Wno-dev) + # target_link_libraries(${test_target} PRIVATE gtest gtest_main) + # endif() + +endfunction() + +# Process each datatype in isolation +foreach(dt IN LISTS MULTI_REDUCE_DATATYPE) + foreach(l IN LISTS MULTI_REDUCE_VARIANTS) + build_multi_reduce_for_datatype(${dt} ${l}) + endforeach() +endforeach() \ No newline at end of file diff --git a/tile_engine/ops/reduce/configs/default_multi_reduce_config.json b/tile_engine/ops/reduce/configs/default_multi_reduce_config.json new file mode 100644 index 0000000000..01d29333f4 --- /dev/null +++ b/tile_engine/ops/reduce/configs/default_multi_reduce_config.json @@ -0,0 +1,51 @@ +{ + "problem" : { + }, + + "problem_size" : { + "input_shape" : [ + [128, 64, 2], + [32, 8, 64, 16] + ] + }, + + "tile_config" : { + "fixed": [ + {"tile_m": 128, "tile_n": 128, "warp_per_block_m": 4, "warp_per_block_n": 1, "warp_tile_m": 32, "warp_tile_n": 128, "thread_tile_m": 8, "thread_tile_n": 8} + ], + "combination": { + "tile_m" : { + "values" : [ + ] + }, + "tile_n" : { + "values": [ + ] + }, + "warp_per_block_m" 
: { + "values" : [ + ] + }, + "warp_per_block_n" : { + "values" : [ + ] + }, + "warp_tile_m" : { + "values" : [ + ] + }, + "warp_tile_n" : { + "values" : [ + ] + }, + "thread_tile_m" : { + "values" : [ + ] + }, + "thread_tile_n" : { + "values" : [ + ] + } + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/reduce/reduce_config.py b/tile_engine/ops/reduce/reduce_config.py new file mode 100644 index 0000000000..5df2b46d12 --- /dev/null +++ b/tile_engine/ops/reduce/reduce_config.py @@ -0,0 +1,8 @@ +import json + + +class ReduceConfig: + def __init__(self, config_json_path: str): + self.config_json_path = config_json_path + with open(config_json_path, "r") as f: + self.config_dict = json.load(f) diff --git a/tile_engine/ops/reduce/reduce_instance_builder.py b/tile_engine/ops/reduce/reduce_instance_builder.py new file mode 100644 index 0000000000..5a008bcb27 --- /dev/null +++ b/tile_engine/ops/reduce/reduce_instance_builder.py @@ -0,0 +1,168 @@ +import argparse +from pathlib import Path + +from reduce_config import ReduceConfig +from reduce_parameter import get_parameter_combinations, TYPE_MAP + + +class MultiReduceBase: + def __init__(self, working_path, gpu_target, datatype, config_json=None): + self.working_path = Path(working_path) + self.gpu_target = gpu_target + self.datatype = datatype + self.output_type = self.datatype + self.config = ReduceConfig(config_json) if config_json else None + self.name = "multiops_base" + + self.signature_test = { + 3: "Test3D_KeepDim0_ReduceDim12", + 4: "Test4D_KeepDim01_ReduceDim23", + } + self.header = "test_multi_reduce2d_multiblock_impl.hpp" + self.test_type = "TestCkTileMultiReduce2D" + + def _generate_instances(self): + if not self.config: + raise ValueError("Configuration not provided.") + + instances = [] + for params in get_parameter_combinations(self.config.config_dict): + instance = self._create_instance(params) + instances.append((instance, params)) + return instances + + def _create_instance(self, parameters): + generated_test = self._get_test(parameters) + + return generated_test + + def do_list_blobs(self): + with open( + self.working_path / Path(f"reduce_{self.name}_blobs_list.txt"), "w" + ) as f: + combos_str = [ + f"{self.name}_{params}" + for params in get_parameter_combinations(self.config.config_dict) + ] + f.write("\n".join(combos_str)) + f.write("\n") + + def do_generate_blobs(self): + instances = self._generate_instances() + for instance_code, params in instances: + blob_filename = self.working_path / Path(f"test_{self.name}_{params}.cpp") + with open(blob_filename, "w") as f: + f.write(instance_code) + + def _get_test(self, params): + dimension = len(params.input_shape) + signature = self.signature_test.get(dimension, None) + + if not signature: + raise ValueError( + f"No test signature found for input shape dimension: {dimension}" + ) + + shape_str = [str(i) for i in params.input_shape] + input_shape_arg_str = ",".join(shape_str) + input_shape_str = "x".join(shape_str) + + t = f"""#include "{self.header}" + +using Shape_BlockWarps = ck_tile::sequence<{params.warp_per_block_m}, {params.warp_per_block_n}>; +using Shape_BlockTile = ck_tile::sequence<{params.tile_m}, {params.tile_n}>; +using Shape_WarpTile = ck_tile::sequence<{params.warp_m}, {params.warp_n}>; +using Shape_ThreadTile = ck_tile::sequence<{params.thread_tile_m}, {params.thread_tile_n}>; + +using TestConfig = + std::tuple<{TYPE_MAP[self.datatype]}, + float, + {TYPE_MAP[self.output_type]}, + ck_tile::tuple, // Intra block reductions + ck_tile::tuple, // 
Elementwise ops + ck_tile::tuple, // Accumulator Elementiwise ops, intra block + ck_tile::tuple, // Inter block reduction + Shape_BlockWarps, + Shape_BlockTile, + Shape_WarpTile, + Shape_ThreadTile>; + +// Register the type(s) for the typed test suite +typedef ::testing::Types TestTypes; +TYPED_TEST_SUITE({self.test_type}, TestTypes); + +TYPED_TEST({self.test_type}, {signature}_{input_shape_str}) +{{ + this->Run{signature}({input_shape_arg_str}); +}} +""" + + return t + + +class MultiReduceThreadwiseKernelBuilder(MultiReduceBase): + def __init__(self, working_path, gpu_target, datatype, config_json=None): + super().__init__(working_path, gpu_target, datatype, config_json) + + self.name = "multiops_threadwise" + + self.header = "test_multi_reduce2d_threadwise_impl.hpp" + self.test_type = "TestCkTileMultiReduceThreadwise" + + +class MultiReduceMultiBlockKernelBuilder(MultiReduceBase): + def __init__(self, working_path, gpu_target, datatype, config_json=None): + super().__init__(working_path, gpu_target, datatype, config_json) + + self.name = "multiops_multiblock" + + self.output_type = ( + "float" # Force float to be used as the output is also used as accumulator + ) + + self.header = "test_multi_reduce2d_multiblock_impl.hpp" + self.test_type = "TestCkTileMultiReduceMultiblock" + + +def main(args): + variants = { + "multiops_threadwise": {"class": MultiReduceThreadwiseKernelBuilder}, + "multiops_multiblock": {"class": MultiReduceMultiBlockKernelBuilder}, + } + if not (args.list_blobs or args.gen_blobs): + raise ValueError("Please provide a list or generate blobs.") + + builder = variants.get(args.variant) + builder_instance = builder["class"]( + working_path=args.working_path, + gpu_target=args.gpu_target, + datatype=args.datatype, + config_json=args.config_json, + ) + + if args.list_blobs: + builder_instance.do_list_blobs() + if args.gen_blobs: + builder_instance.do_generate_blobs() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Reduce Instance Builder") + + parser.add_argument( + "--working_path", type=str, required=True, help="Working directory path" + ) + parser.add_argument("--datatype", type=str, required=True, help="Data type") + parser.add_argument( + "--variant", type=str, required=True, help="Variant: multiblock or threadwise" + ) + parser.add_argument( + "--config_json", type=str, required=True, help="Path to config JSON blob" + ) + parser.add_argument("--list_blobs", action="store_true", help="List blobs") + parser.add_argument("--gen_blobs", action="store_true", help="Generate blobs") + parser.add_argument("--gpu_target", type=str, required=True, help="GPU target") + + args = parser.parse_args() + + main(args) diff --git a/tile_engine/ops/reduce/reduce_parameter.py b/tile_engine/ops/reduce/reduce_parameter.py new file mode 100644 index 0000000000..05e799034b --- /dev/null +++ b/tile_engine/ops/reduce/reduce_parameter.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from itertools import product + +from pyparsing import List + +TYPE_MAP = {"fp16": "ck_tile::half_t", "float": "float"} + + +@dataclass +class ParametersBlockwise: + tile_m: int + tile_n: int + warp_per_block_m: int + warp_per_block_n: int + warp_m: int + warp_n: int + thread_tile_m: int + thread_tile_n: int + input_shape: List[int] + + def __str__(self): + tile_size = "x".join(str(i) for i in [self.tile_m, self.tile_n]) + warp_per_block = "x".join( + str(i) for i in [self.warp_per_block_m, self.warp_per_block_n] + ) + warp_size = "x".join(str(i) for i in [self.warp_m, 
self.warp_n])
+        thread_tile_size = "x".join(
+            str(i) for i in [self.thread_tile_m, self.thread_tile_n]
+        )
+        input_shape = "x".join(str(i) for i in self.input_shape)
+
+        return "_".join(
+            [tile_size, warp_per_block, warp_size, thread_tile_size, input_shape]
+        )
+
+
+def get_parameter_combinations(
+    config_dict: dict,
+) -> List[ParametersBlockwise]:
+    input_shape_configs = config_dict["problem_size"]["input_shape"]
+
+    fixed_configs = config_dict["tile_config"].get("fixed", None)
+
+    seen_config = set()
+
+    if fixed_configs is not None:
+        for fixed in fixed_configs:
+            tile_m_values = fixed["tile_m"]
+            tile_n_values = fixed["tile_n"]
+            warp_per_block_m_values = fixed["warp_per_block_m"]
+            warp_per_block_n_values = fixed["warp_per_block_n"]
+            warp_m_values = fixed["warp_tile_m"]
+            warp_n_values = fixed["warp_tile_n"]
+            thread_tile_m_values = fixed["thread_tile_m"]
+            thread_tile_n_values = fixed["thread_tile_n"]
+            for combo in product(
+                [tile_m_values],
+                [tile_n_values],
+                [warp_per_block_m_values],
+                [warp_per_block_n_values],
+                [warp_m_values],
+                [warp_n_values],
+                [thread_tile_m_values],
+                [thread_tile_n_values],
+                input_shape_configs,
+            ):
+                p = ParametersBlockwise(*combo)
+                if is_valid_combination(p):
+                    hashable_combo = (tuple(combo[-1]),) + combo[0:-1]
+                    seen_config.add(hashable_combo)
+                    yield p
+
+    combo_config = config_dict["tile_config"].get("combination", None)
+    if combo_config is not None:
+        tile_m_values = combo_config["tile_m"]["values"]
+        tile_n_values = combo_config["tile_n"]["values"]
+        warp_per_block_m_values = combo_config["warp_per_block_m"]["values"]
+        warp_per_block_n_values = combo_config["warp_per_block_n"]["values"]
+        warp_m_values = combo_config["warp_tile_m"]["values"]
+        warp_n_values = combo_config["warp_tile_n"]["values"]
+        thread_tile_m_values = combo_config["thread_tile_m"]["values"]
+        thread_tile_n_values = combo_config["thread_tile_n"]["values"]
+
+        for combo in product(
+            tile_m_values,
+            tile_n_values,
+            warp_per_block_m_values,
+            warp_per_block_n_values,
+            warp_m_values,
+            warp_n_values,
+            thread_tile_m_values,
+            thread_tile_n_values,
+            input_shape_configs,
+        ):
+            if combo:
+                p = ParametersBlockwise(*combo)
+                hashable_combo = (tuple(combo[-1]),) + combo[0:-1]
+                if is_valid_combination(p) and hashable_combo not in seen_config:
+                    yield p
+
+
+def is_valid_combination(p: ParametersBlockwise) -> bool:
+    # Thread tile must be at least 1
+    if p.thread_tile_m < 1 or p.thread_tile_n < 1:
+        return False
+
+    # Alignment check
+    if p.tile_m % (p.warp_per_block_m * p.warp_m) != 0:
+        return False
+    if p.tile_n % (p.warp_per_block_n * p.warp_n) != 0:
+        return False
+
+    # Reduction dimension size must be divisible by the thread tile size along N
+    if len(p.input_shape) == 4 and (
+        p.input_shape[2] * p.input_shape[3] % p.thread_tile_n != 0
+    ):
+        return False
+
+    if len(p.input_shape) == 3 and (
+        p.input_shape[1] * p.input_shape[2] % p.thread_tile_n != 0
+    ):
+        return False
+
+    return True
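+
+
+# Illustrative example (not used by the generator): with the default fixed tile config from
+# configs/default_multi_reduce_config.json (tile 128x128, warps per block 4x1, warp tile
+# 32x128, thread tile 8x8) and input shape [32, 8, 64, 16], is_valid_combination passes
+# (128 % (4 * 32) == 0, 128 % (1 * 128) == 0, 64 * 16 % 8 == 0) and str(params) yields
+# "128x128_4x1_32x128_8x8_32x8x64x16", which reduce_instance_builder.py turns into the
+# blob name "multiops_threadwise_128x128_4x1_32x128_8x8_32x8x64x16" and the test source
+# "test_multiops_threadwise_128x128_4x1_32x128_8x8_32x8x64x16.cpp".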