
Conversation

@damien-lejeune
Contributor

Proposed changes

Implementation of multi-reduce ops, in both a threadwise and a multiblock (aka blockwise) fashion. It migrates most of the features already present in the old CK, with a few exceptions:

  • The alpha and beta terms have not been added. The examples in the old CK use the values 1 and 0, respectively. As this feature is very easy to implement, I suggest we wait until they are actually needed before bringing them in.

Some notable limitations are also worth pointing out:

  • The multi-block version makes use of atomic operations when performing the inter-block reduction/update. While this works for atomic add over a reasonable collection of types and thread tile sizes, the other atomic operations (such as MAX) are much more limited (e.g. no fp16, tile size of 1). While these limitations could be lifted, I suggest making that part of another PR, if necessary.
  • Unit testing. While some tests and examples are present, it is worth noting that the GPU and CPU reference outputs show a discrepancy that makes them difficult to match. The nature of the reduction (stochastic ordering, atomic operation behaviour, down-casting) causes small errors to accumulate on the GPU side compared to the deterministic, sequential execution of the CPU reference; testing large reduction sizes makes this very apparent. Unfortunately I haven't been able to mitigate these issues much: a generous error tolerance (absolute tolerance of 0.1, relative tolerance of 0.01) has been used, along with small to moderate input sizes. A sketch of the sequential reference semantics is given right after this list.
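To make the CPU-vs-GPU comparison concrete, here is a rough CPU-style sketch of what a 2D multi-reduce computes; the function name, layout and choice of operations are illustrative only, not the actual CK_TILE reference API. Unlike the GPU kernels, the reference is deterministic and sequential, which is why small floating-point differences accumulate on the GPU side.

#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

// Illustrative host reference: reduce each row of a row-major M x N input
// with several operations at once (here: add and max), applied independently
// along the reduce dimension in a deterministic, sequential order.
void reference_multi_reduce2d(const std::vector<float>& x,
                              std::size_t M,
                              std::size_t N,
                              std::vector<float>& y_add,
                              std::vector<float>& y_max)
{
    y_add.assign(M, 0.0f);
    y_max.assign(M, -std::numeric_limits<float>::infinity());
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            const float v = x[m * N + n];     // an elementwise op would be applied here
            y_add[m] += v;                    // reduction #1: add
            y_max[m] = std::max(y_max[m], v); // reduction #2: max
        }
    }
}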

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

@aosewski requested a review from Copilot on November 4, 2025 09:50
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds support for multi-reduction kernels to the CK_TILE library, enabling multiple reduction operations to be performed simultaneously on tensors. The implementation includes both threadwise and multiblock reduction variants, with supporting infrastructure for code generation, testing, and examples.

  • Implements MultiReduceThreadWise and MultiReduceMultiblock kernels for GPU reduction operations
  • Adds a CMake-based code generation system that creates test instances from JSON configurations
  • Provides comprehensive test coverage with both threadwise and multiblock implementations

Reviewed Changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 7 comments.

Summary per file:

  • tile_engine/ops/reduce/reduce_instance_builder.py: Code generator for test instances using configuration-driven approach
  • tile_engine/ops/reduce/reduce_config.py: Configuration loader for reduction kernel parameters
  • tile_engine/ops/reduce/CMakeLists.txt: Build system integration with Python code generation
  • tile_engine/ops/CMakeLists.txt: Added reduce subdirectory to build
  • test/ck_tile/reduce/test_multi_reduce2d_*: Test infrastructure for both threadwise and multiblock kernels
  • include/ck_tile/ops/reduce/kernel/multi_reduce2d_*: Core kernel implementations
  • include/ck_tile/host/reference/reference_reduce.hpp: Reference implementations for validation
  • include/ck_tile/ops/reduce.hpp: Updated API surface with new kernel includes
  • example/ck_tile/05_reduce/multiple_reduce_*.cpp: Example applications demonstrating usage


Collaborator

@aosewski left a comment

Let's have a bit more offline discussion about the overall design.

Comment on lines -168 to -169
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");

Collaborator

What problems did you have with that?

Contributor Author

In the multi-reduce case we load the input tile and then immediately cast it to the compute type, because we apply an elementwise operation on it right after. The XDistributedTensor is therefore no longer of XDataType.
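Roughly, using the same names as the kernel excerpt further down in this thread (a sketch of the flow, not the exact code):

auto x = load_tile(x_window);                    // distributed tensor holding XDataType
auto x_compute = cast_tile<ComputeDataType>(x);  // now holds ComputeDataType
tile_elementwise_inout(elementwise_op, x_compute, x_compute);
// the reduction then consumes x_compute, whose DataType is no longer XDataType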

Comment on lines +449 to +460
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
{
y = ck_tile::type_convert<raw_t<Y>>(x);
}
/* otherwise (r-value or const) → do nothing */
}

template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = ck_tile::type_convert<raw_t<Y>>(x);
}
Collaborator

This is a bit confusing in the context of the above overload for universal references... Is the above version correct at all??? @ThruptiRajLakshmanaGowda @ThomasNing

Contributor Author

I'm happy you commented on this :-) I added this overload to support the base/simple case that "feels" like it should be supported: taking in only an l-value for Y as input. The main issue here is an interface mismatch: the other elementwise operations handle the base case naturally, but the passthrough does not. I think a good compromise could be to keep this overload and use it inside the universal-reference one in place of y = ck_tile::type_convert<raw_t<Y>>(x);. What do you think?
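For instance, the compromise could look roughly like this standalone sketch (outside of CK_TILE; raw_t is simplified here and type_convert is replaced by static_cast, so this is not the actual implementation):

#include <type_traits>

// Sketch: the forwarding-reference overload reuses the plain l-value overload
// instead of duplicating the conversion.
template <typename T>
using raw_t = std::remove_reference_t<T>;

struct passthrough_sketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = static_cast<std::remove_cv_t<raw_t<Y>>>(x); // ck_tile::type_convert in the real code
    }

    template <typename Y, typename X>
    void operator()(Y&& y, const X& x) const
    {
        if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
        {
            (*this)(y, x); // non-const l-value: delegate to the overload above
        }
        /* otherwise (r-value or const) -> do nothing */
    }
};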

}

CK_TILE_HOST_DEVICE static void CalculateBlockGroupParams(const int reduce_total_length,
[[maybe_unused]] int K_BlockTileSize,
Collaborator

If it's unused, why add it here?

Contributor Author

My apologies, it should have been removed prior to creating the PR. This has now been fixed. Thanks for noticing it.

Comment on lines 56 to 59
if(num_block_tile_iterations == 0)
{
num_block_tile_iterations = 1;
}
Collaborator

num_block_tile_iterations would equal zero only if reduce_total_length is zero. In that case you could set block_group_size to 0 and return; it would be equal to zero anyway from the expression below.

Contributor Author

This has been handled with a warning in the examples and an early exit in the kernel.
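The guard in question might look roughly like this (hypothetical, not the actual code of the fix):

if(reduce_total_length == 0)
{
    return; // nothing to reduce; the outputs keep their initial values
}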

Comment on lines 69 to 71
static constexpr index_t CalculateInputVectorSize()
{
using S = typename Problem::BlockShape;
Collaborator

I'm now wondering whether we want to expose ThreadTile_N through Problem and Shape so it can be used as the vector size... Maybe we could instead deduce ThreadTile_N (the number of elements per thread in the N dimension) from the other parameters.
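For illustration, such a deduction could look roughly like this (the BlockShape member names are assumptions based on the tile sequences quoted later in this thread, not the actual Problem definition):

// Hypothetical: derive the per-thread extent along N from the block tile
// and the warp layout instead of passing it in explicitly.
static constexpr index_t CalculateThreadTileN()
{
    using S = typename Problem::BlockShape;
    return S::Block_N / (S::WarpPerBlock_N * S::ThreadPerWarp_N);
}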


if(pass_op)
{
std::cout << "" << std::endl;
Collaborator

Please add some more information ;)

Contributor Author

fixed

using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<128, 128>;
using WarpTile = ck_tile::sequence<32, 128>;
using Vector = ck_tile::sequence<8, 8>;
Collaborator

This is ThreadTile
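Presumably something along these lines (a guess at the rename, not the actual change):

using ThreadTile = ck_tile::sequence<8, 8>; // per-thread elements in (M, N)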

Contributor Author

fixed

Comment on lines 230 to 242
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto x = load_tile(x_window);

// Apply the elementwise operation before the reduction
auto x_compute = cast_tile<ComputeDataType>(x);

tile_elementwise_inout(elementwise_ops.get(number<i>{}), x_compute, x_compute);

block_reduce2d(x_compute, y_compute, reduce_ops.get(number<i>{}));

move_tile_window(x_window, {0, S::Block_N});
}
Collaborator

Is it possible to run multiple reductions at the same time? In the current version I see you're loading the data multiple times.
I assume that for multiple reductions running in parallel you might need a multiple of all the resources. Maybe we could have some heuristic for the maximum number of reductions, after which we fall back to sequential execution (as it is right now).

Contributor Author

Good point, it would indeed make sense to parallelize across the reductions too. However, I think it would add a bit of complexity. From what I can see we would need threads to specialize in handling a specific primary operation. For this we would need to identify the thread id and the thread tile (i.e. data position) and apply a modulo on the number of operations (or the maximum number of operations to run in parallel), then run sequentially over the remaining operations if needed. I think this could work in principle: a bit of arithmetic to have the threads figure out which operations to run. However, I think it could take a bit of time to test it (and get it right). But I think it's a nice idea! A rough sketch of that index arithmetic is given below.
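A purely illustrative sketch of the mapping (not part of this PR; the thread id and operation count are assumed inputs):

// Hypothetical thread-to-operation mapping: each thread picks a primary
// reduction from its id and loops sequentially when there are more
// operations than can run in parallel.
inline void run_ops_for_thread_sketch(int thread_id, int num_reduce_ops)
{
    constexpr int max_parallel_ops = 2; // heuristic cap, as suggested above
    const int primary_op = thread_id % max_parallel_ops;
    for(int op = primary_op; op < num_reduce_ops; op += max_parallel_ops)
    {
        // ... load this thread's tile once and apply reduction `op` ...
    }
}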

Comment on lines 260 to 265
// 3. Atomic operation between the register tile and DRAM
auto atomic_ops =
interblock_reduce_ops.get(number<i>{})
.template GetAtomic<YDataType, y_thread_buf.N>(); // TODO: check if we
// need YDataType
atomic_ops(p_y_tile, y_thread_buf);
Collaborator

Why not just use the update_tile API?

Contributor Author

Thanks, I missed that API. It actually made things much simpler. Fixed in 814ab44

return false;
}

if(input_strides.at(number<input_strides.size() - 1>{}) != 1)
Collaborator

Why's that? You don't have to do vectorized reads only on the rightmost (innermost) dimension.

Contributor Author

When we create the input tensor descriptor we guarantee the last dimension has stride one:

auto desc = make_naive_tensor_descriptor(
            input_shape, input_strides, number<x_tensor_vector_size>{}, number<1>{});

Do you mean we should not offer this guarantee, or that we should relax it to something else?

@damien-lejeune marked this pull request as ready for review on November 20, 2025 19:37
