Commits (34)
e6a6c71
WIP
Oct 13, 2025
5828e0b
Add Unit tests for the Multi Reduction Kernel
Oct 14, 2025
9cc3192
clang format
Oct 14, 2025
22d602e
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Oct 14, 2025
352f8e9
Rename multiblock to threadwise
Oct 15, 2025
0e44893
Multiblock WIP
Oct 22, 2025
4eaceb5
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Oct 22, 2025
c752be3
Fix multi reduce multi block unit tests
Oct 24, 2025
9a59853
Multi Reduce Tile Engine: WIP
Oct 24, 2025
3ac37c2
refactoring + try addressing precision error
Nov 2, 2025
31dec00
Fix multiops examples
Nov 3, 2025
bfa229b
Cleanup
Nov 3, 2025
3a05ab6
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Nov 3, 2025
444b7b8
Clean up tile engine's reduce op
Nov 3, 2025
839498c
Update changelog
Nov 3, 2025
3a23e5a
Fix remod/clang
Nov 4, 2025
1a55a95
Fix dates
Nov 4, 2025
2603089
Fix documentation & missing file
Nov 4, 2025
ab0d475
Fix comments
Nov 4, 2025
816e82b
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
aosewski Nov 6, 2025
814ab44
Use the update_tile api in the multi-block kernel
Nov 18, 2025
d8467b2
Unify threadwise/multiblock into a single kernel + default multiblock…
Nov 20, 2025
54a24ec
Add TileParitioner
Nov 20, 2025
2b4b305
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Nov 20, 2025
3e63723
Cleanup
Nov 20, 2025
b7aede3
Add warning when no data to process, in the example
Nov 20, 2025
24df254
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
aosewski Nov 26, 2025
55cfaa6
Refactoring Reduce kernel Tile Partioner + cleanup
Dec 1, 2025
ad470fe
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Dec 1, 2025
fb332e5
Merge remote-tracking branch 'origin/dlejeune/ck_tile_2d_multiple_red…
Dec 1, 2025
defbaa4
Move the tile partioner to its own file
Dec 8, 2025
6877153
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Dec 8, 2025
25cf5ae
Add missing includes
Dec 8, 2025
ed80a47
Merge branch 'develop' into dlejeune/ck_tile_2d_multiple_reductions
Dec 8, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added pooling kernel in CK_TILE
* Added top-k sigmoid kernel in CK_TILE
* Added the blockscale 2D support for CK_TILE GEMM.
* Added reduce and multi-reduction kernels in CK_TILE

### Changed

18 changes: 17 additions & 1 deletion example/ck_tile/05_reduce/CMakeLists.txt
@@ -12,8 +12,24 @@ list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo

target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})

# Multi Reduce Threadwise Example
set(EXAMPLE_MULTI_REDUCE "tile_example_multi_reduce_threadwise")
add_executable(${EXAMPLE_MULTI_REDUCE} EXCLUDE_FROM_ALL multiple_reduce_threadwise.cpp)
target_include_directories(${EXAMPLE_MULTI_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_MULTI_REDUCE} PRIVATE ${EXAMPLE_MULTI_REDUCE_COMPILE_OPTIONS})

# Multi Reduce Multiblock Example
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE "tile_example_multi_reduce_multiblock")
add_executable(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} EXCLUDE_FROM_ALL multiple_reduce_multiblock.cpp)
target_include_directories(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
set(EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS)
list(APPEND EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_MULTI_REDUCE_BLOCKWISE} PRIVATE ${EXAMPLE_MULTI_REDUCE_BLOCKWISE_COMPILE_OPTIONS})
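# Note (illustrative): both example targets are EXCLUDE_FROM_ALL, so they are not built by default;
# request them explicitly, e.g. `cmake --build . --target tile_example_multi_reduce_multiblock`.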

# TODO: we have to turn off this global property, otherwise the progress bar generated
# by cmake prints too many files and fails with "execvp: /bin/sh: Argument list too long".
# However, the property is global, so turning it off may affect other targets.
# TODO: consider generating a makefile ourselves
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
251 changes: 251 additions & 0 deletions example/ck_tile/05_reduce/multiple_reduce_multiblock.cpp
@@ -0,0 +1,251 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>

template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};

template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("n", "32", "n dimension")
.insert("h", "19", "h dimension")
.insert("w", "7", "w dimension")
.insert("c", "512", "c dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "multi_reduce_multiblock.json", "json file name to dump results");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
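// Illustrative usage (assuming the usual ck_tile example CLI syntax of -key=value):
//   ./tile_example_multi_reduce_multiblock -n=32 -h=19 -w=7 -c=512 -prec=fp16 -v=1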

template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType;
using ComputeDataType = float;
using YDataType = DataType;

ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t H = arg_parser.get_int("h");
ck_tile::index_t W = arg_parser.get_int("w");
ck_tile::index_t C = arg_parser.get_int("c");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");

std::vector<ck_tile::index_t> problem_shape = {N, H, W, C};
std::vector<ck_tile::index_t> strides(4);
strides[0] = H * W * C;
strides[1] = W * C;
strides[2] = C;
strides[3] = 1;

// Define the reduction specification:
constexpr auto kept_dim = ck_tile::sequence<0, 3>{};    // Which dimensions to keep (N, C)
constexpr auto reduce_dims = ck_tile::sequence<1, 2>{}; // Which dimensions to reduce (H, W)
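// With the default problem size (n=32, h=19, w=7, c=512) each reduction therefore maps the
// 32 x 19 x 7 x 512 input to a 32 x 512 output, reducing 19 * 7 = 133 elements per output value.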

ck_tile::HostTensor<XDataType> x_host(problem_shape, strides);
ck_tile::HostTensor<YDataType> y_host_add_ref({N, C}, {C, 1});
ck_tile::HostTensor<YDataType> y_host_max_ref({N, C}, {C, 1});
auto y_host_ref_tuple = ck_tile::make_tuple(y_host_add_ref, y_host_max_ref);

ck_tile::HostTensor<YDataType> y_host_add_dev({N, C}, {C, 1});
ck_tile::HostTensor<YDataType> y_host_max_dev({N, C}, {C, 1});
auto y_host_dev_tuple = ck_tile::make_tuple(y_host_add_dev, y_host_max_dev);

const auto number_operations = y_host_dev_tuple.size();

std::vector<YDataType> h(number_operations * N * C);

auto y_buf_size = number_operations *
y_host_dev_tuple.at(ck_tile::number<0>{}).get_element_space_size_in_bytes();
ck_tile::DeviceMem y_buf(y_buf_size);

const auto output_tensor_offset = N * C;
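// Both reduction outputs share a single device buffer: the result of operation i starts at
// element offset i * N * C (i.e. i * output_tensor_offset), which is how the validation code
// below unpacks them.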

const ck_tile::index_t reduce_total_length = H * W;

// Operations: one performs a plain sum reduction, the other computes the mean of squares.
// For the mean of squares:
// 1. The element-wise operation squares each element before the reduction.
// 2. The reduction operation sums the squared elements.
// 3. The accumulator element-wise operation divides the partial result by the total number of
//    reduced elements (intra-block operation).
// 4. The partial results are then combined across blocks with the inter-block reduction, a sum.
// (A scalar sketch of the resulting math follows the op definitions below.)
// reductions
auto reduce_ops = ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{});
// Element-wise ops applied to each input element before the reduction
auto elementwise_ops = ck_tile::make_tuple(ck_tile::element_wise::PassThrough{},
ck_tile::element_wise::UnarySquare{});
// Accumulator element-wise ops applied to the intra-block reduction result
auto accumulator_elementwise_ops = ck_tile::make_tuple(
ck_tile::element_wise::PassThrough{},
ck_tile::element_wise::UnaryDivide{reduce_total_length});
// Inter-block reduction ops
auto inter_block_reduce_ops =
ck_tile::make_tuple(ck_tile::ReduceOp::Add{}, ck_tile::ReduceOp::Add{});
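// Illustrative scalar sketch (not part of the kernel) of what the two operations compute for a
// fixed (n, c):
//   out_add[n][c]    = sum_{h,w} x[n][h][w][c]
//   out_meansq[n][c] = ( sum_{h,w} x[n][h][w][c]^2 ) / (H * W)
// The kernel arrives at the same values by accumulating per-block partials and combining them
// with the inter-block Add reductions defined above.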

ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);

ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());

x_buf.ToDevice(x_host.data());

using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<128, 128>;
using WarpTile = ck_tile::sequence<32, 128>;
using Vector = ck_tile::sequence<8, 8>;
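// A plausible reading of these tile parameters (an assumption based on the 2D kept x reduce
// problem shape): 4 x 1 warps per block, a 128 x 128 block tile (kept dims x reduce dim),
// a 32 x 128 warp tile, and 8-wide vector accesses along each dimension.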

constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kept_dim_len_prod = N * C;

using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, Vector>;
using Problem = ck_tile::
Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, decltype(reduce_ops)>;

using Kernel = ck_tile::MultiReduceMultiblock<Problem>;

// Determine the block group size for the multi-block reduction.
// block_group_size is the number of blocks that cooperate on one reduction (it depends on the
// input size); for efficiency it is capped at a maximum of 128. If that is not sufficient to
// cover the whole reduction, each thread processes multiple thread tiles,
// num_block_tile_iterations times (see the worked example after the call below).
int K_BlockTileSize = BlockTile::at(1);
int num_block_tile_iterations;
int block_group_size;

Kernel::CalculateBlockGroupParams(
reduce_total_length, K_BlockTileSize, num_block_tile_iterations, block_group_size);
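// Worked example with the defaults (illustrative; the exact split is decided by
// CalculateBlockGroupParams): reduce_total_length = 19 * 7 = 133 and K_BlockTileSize = 128, so
// one would expect block_group_size = 2 and num_block_tile_iterations = 1, i.e. two blocks per
// reduction; the grid size computed below is then scaled by that factor.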

const ck_tile::index_t kBlockSize = Kernel::BlockSize();
ck_tile::index_t kGridSize =
((kept_dim_len_prod + Shape::Block_M - 1) / Shape::Block_M) * block_group_size;

std::cout << "Block group size: " << block_group_size
<< ", Num block tile iterations: " << num_block_tile_iterations
<< ", Reduce total length: " << reduce_total_length << std::endl;
std::cout << "grid size " << kGridSize << ", block size " << kBlockSize << std::endl;

// Create input tensor shape and strides
auto input_shape =
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
auto input_strides = ck_tile::make_tuple(strides[0], strides[1], strides[2], strides[3]);

if(!Kernel::IsSupportedArgument(
C, input_strides)) // output tensor's contiguous dimension and input strides
{
throw std::runtime_error("Wrong! Arguments not supported!\n");
}

// Initialize the output data with the identity value of each reduce op
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
constexpr auto op = reduce_ops.at(i);
const auto identity_val = op.template GetIdentityValue<YDataType>();
const auto output_number_elements = N * C;
std::fill(h.begin() + i * output_number_elements,
h.begin() + (i + 1) * output_number_elements,
identity_val);
});

auto clear_output_buffer = [&]() { y_buf.ToDevice(h.data()); };
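// Note: for ReduceOp::Add the identity value is 0, so both output slots start out zero-filled;
// clear_output_buffer re-uploads this state so the inter-block updates accumulate from identity
// values rather than stale results (assuming the timing helper invokes it before each launch).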

float ave_time = launch_kernel_time_mask(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
clear_output_buffer,
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
input_shape,
input_strides,
kept_dim,
reduce_dims,
output_tensor_offset,
elementwise_ops,
accumulator_elementwise_ops,
inter_block_reduce_ops)

);

std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;

float gb_per_sec = num_btype / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;

bool pass = true;

if(do_validation)
{
// reference
ck_tile::reference_multiple_reduce_multiblock<XDataType, ComputeDataType, YDataType>(
x_host,
y_host_ref_tuple,
reduce_ops,
kept_dim,
reduce_dims,
elementwise_ops,
accumulator_elementwise_ops,
inter_block_reduce_ops,
block_group_size);
std::cout << "Read " << y_buf_size / 10 << " Bytes from the device" << std::endl;

// Transfer data from device and check error for each operation
y_buf.FromDevice(h.data());
ck_tile::static_for<0, number_operations, 1>{}([&](auto i) {
std::memcpy(y_host_dev_tuple.get(ck_tile::number<i>{}).data(),
h.data() + i * output_tensor_offset,
output_tensor_offset * sizeof(YDataType));
std::cout << "Checking operation " << i << ": " << std::endl;

bool pass_op = ck_tile::check_err(y_host_dev_tuple.get(ck_tile::number<i>{}),
y_host_ref_tuple.get(ck_tile::number<i>{}));

if(pass_op)
{
std::cout << "✅" << std::endl;
}
pass &= pass_op;
});

std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}

return pass;
}

int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;

const std::string data_type = arg_parser.get_str("prec");

if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
}