Merged
31 commits
b8a888d
Add FP8 support for allreduce
seagater Oct 8, 2025
d94fca5
Add fp8 for execution kernel
seagater Oct 8, 2025
e18bdc1
Add missing float16 type in mscclpp-test
seagater Oct 8, 2025
ac8247d
Update the count calculation for different data types in mscclpp-test
seagater Oct 8, 2025
40317c9
Fix issue in mpirun with fp8 for allreduce
seagater Oct 9, 2025
359b320
Add missing add_vectors and add_verctors_helper for fp8; Fix issue of…
seagater Oct 10, 2025
6b611b6
Use compatible types __fp8_e4m4 and __fp8_e5m2 to replace __nv_fp8_*
seagater Oct 15, 2025
aa15aab
Use __hadd, __hadd2, __hmin and __hmin2 operations for FP8
seagater Oct 15, 2025
7083845
Add type __fp8x4_e4m3 and __fp8x4_e5m2; use vectorized operator for 4…
seagater Oct 15, 2025
45a2d86
Update vectorized operators fp8 in execution_kernel
seagater Oct 17, 2025
6c760d6
Update operations for fp8 on HIP platform
seagater Oct 18, 2025
183fd11
Add missing UseClip = true; Move the definition of add_elements for _…
seagater Oct 18, 2025
fa1b8e7
Revert the changes for fp8 support in mscclpp-test
seagater Oct 20, 2025
6e10921
Remove duplicate definition
seagater Oct 20, 2025
eff0e3d
Add helper function add_fp8x4_hip to handle different scalar type __f…
seagater Oct 21, 2025
2f8bb59
Remove the macro __HIP_FP8_TYPES_EXIST__
seagater Oct 22, 2025
cad9434
Replace __CUDA_FP8_TYPES_EXIST__ with __FP8_TYPES_EXIST__
seagater Oct 22, 2025
91557c2
Throw warning for AllReduce with FP8 for data > 64K; Solve clang-form…
seagater Oct 22, 2025
e389df9
Merge branch 'main' into qinghuazhou/allreduce_fp8
Binyang2014 Oct 23, 2025
cdbdb90
update
Binyang2014 Oct 24, 2025
cde8ac8
fix example
Binyang2014 Oct 24, 2025
461311a
Merge branch 'main' into qinghuazhou/allreduce_fp8
Binyang2014 Oct 24, 2025
7fac820
add comment
Binyang2014 Oct 24, 2025
6e7ab36
Replace float2_t with float2
seagater Oct 24, 2025
d9acf0f
Update the syntax of value retrival for float2; skip the check for __…
seagater Oct 24, 2025
5248fc5
Update clip function for fp8_e5m2 to prevent infinites
seagater Oct 24, 2025
3a18ff2
Update clang-format
seagater Oct 24, 2025
2075915
Use HIP_VERSION_MAJOR to check the FP8 support
seagater Oct 25, 2025
8bb5dd3
Remove macros in nccl.h; Update data type names ncclFloat8e4m3 and nc…
seagater Oct 25, 2025
729635a
Merge branch 'main' into qinghuazhou/allreduce_fp8
seagater Oct 27, 2025
237c0b0
Merge branch 'main' into qinghuazhou/allreduce_fp8
seagater Oct 27, 2025
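
Several of the commits above (aa15aab, 7083845, 183fd11, 5248fc5) deal with how two FP8 values are added without overflowing the narrow format: accumulate in half precision, clip to the FP8 range, then convert back. The sketch below only illustrates that idea; the helper name and the use of the stock CUDA types (rather than the PR's portable __fp8_e4m3 / __fp8_e5m2 aliases) are assumptions of this example, not code from the PR.

```cuda
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Illustrative only: add two e4m3 values in half precision and clamp to the
// e4m3 finite range (+/-448) before narrowing, so the result cannot overflow.
// For e5m2 the same pattern would clamp to +/-57344, which also keeps the
// narrowing conversion from producing infinities (e5m2, unlike e4m3, has inf).
__device__ __forceinline__ __nv_fp8_e4m3 add_fp8_e4m3_clipped(__nv_fp8_e4m3 a,
                                                              __nv_fp8_e4m3 b) {
  __half sum = __hadd(__half(a), __half(b));  // accumulate in fp16
  const __half lo = __float2half(-448.0f);
  const __half hi = __float2half(448.0f);
  sum = __hmax(lo, __hmin(sum, hi));          // clip, as in the UseClip = true path
  return __nv_fp8_e4m3(sum);                  // narrow back to e4m3
}
```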
6 changes: 3 additions & 3 deletions apps/nccl/include/mscclpp/nccl.h
@@ -248,7 +248,7 @@ typedef enum {
ncclFloat = 7,
ncclFloat64 = 8,
ncclDouble = 8,
#if defined(__CUDA_BF16_TYPES_EXIST__) && defined(__CUDA_FP8_TYPES_EXIST__)
#if defined(__CUDA_BF16_TYPES_EXIST__) && defined(__FP8_TYPES_EXIST__)
ncclBfloat16 = 9,
ncclFp8E4M3 = 10,
ncclFp8E5M2 = 11,
@@ -282,11 +282,11 @@ static inline size_t ncclTypeSize(ncclDataType_t type) {
case ncclBfloat16:
return 2;
#endif // defined(__CUDA_BF16_TYPES_EXIST__)
#if defined(__CUDA_FP8_TYPES_EXIST__)
#if defined(__FP8_TYPES_EXIST__)
case ncclFp8E4M3:
case ncclFp8E5M2:
return 1;
#endif // defined(__CUDA_FP8_TYPES_EXIST__)
#endif // defined(__FP8_TYPES_EXIST__)
case ncclNumTypes:
return 0;
}
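
With ncclFp8E4M3 / ncclFp8E5M2 and their one-byte ncclTypeSize entry now guarded by __FP8_TYPES_EXIST__, an FP8 allreduce is requested like any other dtype. A minimal caller-side sketch, assuming the include path implied by the file layout above and omitting communicator/stream setup:

```cpp
#include <cstddef>
#include <cuda_runtime.h>
#include <mscclpp/nccl.h>  // header shown above (apps/nccl/include/mscclpp/nccl.h)

// In-place sum-allreduce over `count` e4m3 elements; each element is one byte,
// matching the ncclTypeSize() case added in the diff above.
ncclResult_t allreduceFp8Sum(void* devBuf, size_t count, ncclComm_t comm,
                             cudaStream_t stream) {
  return ncclAllReduce(devBuf, devBuf, count, ncclFp8E4M3, ncclSum, comm, stream);
}
```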
68 changes: 47 additions & 21 deletions apps/nccl/src/allreduce.cu
@@ -71,13 +71,20 @@ struct NvlsAdapter {
mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t, int rank, int nRanksPerNode, int, size_t nelems,
cudaStream_t stream, uint32_t*, uint32_t*, uint32_t*, uint32_t) {
using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
int nBlocks = nRanksPerNode;
int nThreadsPerBlock = 1024;
allreduce9<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels,
channelInOffset, channelOutOffset, nelems * sizeof(T), rank,
nRanksPerNode);
return cudaGetLastError();
#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
return cudaErrorNotSupported;
} else
#endif
{
using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
int nBlocks = nRanksPerNode;
int nThreadsPerBlock = 1024;
allreduce9<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((ChannelType*)memoryChannels, nvlsChannels,
nvlsOutChannels, channelInOffset, channelOutOffset,
nelems * sizeof(T), rank, nRanksPerNode);
return cudaGetLastError();
}
}
};

@@ -88,21 +95,28 @@ struct NvlsWithCopyAdapter {
mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
int rank, int nRanksPerNode, int, size_t nelems, cudaStream_t stream, uint32_t*, uint32_t*,
uint32_t*, uint32_t) {
using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
if (sizeof(T) * nelems < (1 << 24)) {
int nBlocks = nRanksPerNode * 4;
int nThreadsPerBlock = 1024;
allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, nelems * sizeof(T), scratchBufferSize,
rank, nRanksPerNode);
} else {
int nBlocks = nRanksPerNode * 5;
int nThreadsPerBlock = 1024;
allreduce11<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, nelems * sizeof(T), scratchBufferSize,
rank, nRanksPerNode);
#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
return cudaErrorNotSupported;
} else
#endif
{
using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
if (sizeof(T) * nelems < (1 << 24)) {
int nBlocks = nRanksPerNode * 4;
int nThreadsPerBlock = 1024;
allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, nelems * sizeof(T), scratchBufferSize,
rank, nRanksPerNode);
} else {
int nBlocks = nRanksPerNode * 5;
int nThreadsPerBlock = 1024;
allreduce11<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
nvlsChannels, nelems * sizeof(T), scratchBufferSize,
rank, nRanksPerNode);
}
return cudaGetLastError();
}
return cudaGetLastError();
}
};

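Both NVLS adapters above now report cudaErrorNotSupported instead of launching when T is an FP8 type, since NVLS reductions do not yet support FP8. How the rest of mscclpp reacts to that code is not shown in this diff; the sketch below only illustrates the contract a caller would follow, and both function names are stand-ins rather than symbols from the PR.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in declarations: "try the NVLS path" vs. "use the non-NVLS
// memory-channel path". Neither name exists in the PR.
cudaError_t tryNvlsAllreduce();
cudaError_t runMemoryChannelAllreduce();

cudaError_t allreduceWithFallback() {
  cudaError_t err = tryNvlsAllreduce();  // cudaErrorNotSupported for FP8 element types
  if (err == cudaErrorNotSupported) {
    err = runMemoryChannelAllreduce();   // assumed fallback; not shown in this diff
  }
  if (err != cudaSuccess) {
    std::fprintf(stderr, "allreduce failed: %s\n", cudaGetErrorString(err));
  }
  return err;
}
```
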
@@ -154,6 +168,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == ncclBfloat16) {
return Adapter<SUM, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == ncclFp8E4M3) {
return Adapter<SUM, __fp8_e4m3>::call;
} else if (dtype == ncclFp8E5M2) {
return Adapter<SUM, __fp8_e5m2>::call;
#endif
} else if (dtype == ncclInt32 || dtype == ncclUint32) {
return Adapter<SUM, int>::call;
@@ -168,6 +188,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) {
#if defined(__CUDA_BF16_TYPES_EXIST__)
} else if (dtype == ncclBfloat16) {
return Adapter<MIN, __bfloat16>::call;
#endif
#if defined(__FP8_TYPES_EXIST__)
} else if (dtype == ncclFp8E4M3) {
return Adapter<MIN, __fp8_e4m3>::call;
} else if (dtype == ncclFp8E5M2) {
return Adapter<MIN, __fp8_e5m2>::call;
#endif
} else if (dtype == ncclInt32 || dtype == ncclUint32) {
return Adapter<MIN, int>::call;
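
The new dispatch branches are compiled only under __FP8_TYPES_EXIST__, the macro that (per commit cad9434) replaces the CUDA-only __CUDA_FP8_TYPES_EXIST__ so one check covers both CUDA and HIP builds. Code that includes this nccl.h can gate its own FP8 usage on the same macro; that the macro is visible to including translation units is an assumption of this sketch, and the fp16 fallback is illustrative, not from the PR.

```cpp
#include <mscclpp/nccl.h>

// Prefer e4m3 when the toolchain exposes FP8 types, otherwise fall back to a
// wider dtype.
static ncclDataType_t pickReductionDtype() {
#if defined(__FP8_TYPES_EXIST__)
  return ncclFp8E4M3;
#else
  return ncclFloat16;
#endif
}
```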