Files
isaki001 de76d7f649 Add Compilation Flag for enabling/disabling clipping, and tune number of blocks for mscclpp allreduce8 (#1607)
* mscclpp patch: apply clip patch and raise allreduce8 threads per block from 512 to 1024

* add compilation flag for enabling/disabling clipping in mscclpp

* change flag name for consistency, set flag to OFF

* add compilation flag in rccl for enabling clipping in mscclpp

* set 1024 threads for mscclpp allreduce8 only for bfloat16

* fix improper description for ENABLE_MSCCLPP_CLIP flag

* Revert "Merge branch 'clip-patch' of https://github.com/isaki001/rccl into clip-patch"

This reverts commit 6e31857a9db98314b8a748eb024f2c3699ebe2d5, reversing
changes made to 193f4caa8ffa78b4e056893212fd8344aa14e937.

* update clip remove-clip.patch for rebase

[ROCm/rccl commit: 8145c4f3b8]
2025-04-30 16:42:28 -05:00

55 lines
2.0 KiB
Diff

diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index fac105a..9ef93ce 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -71,17 +71,29 @@ __forceinline__ __device__ __bfloat162 clip(__bfloat162 val) {
template <typename T>
__forceinline__ __device__ T add_elements(T a, T b) {
- return clip(a + b);
+ #ifdef MSCCLPP_CLIP_ENABLED
+ return clip(a + b);
+ #else
+ return a + b;
+ #endif
}
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
- return clip(__hadd2(a, b));
+ #ifdef MSCCLPP_CLIP_ENABLED
+ return clip(__hadd2(a, b));
+ #else
+ return __hadd2(a, b);
+ #endif
}
template <>
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
- return clip(__hadd2(a, b));
+ #ifdef MSCCLPP_CLIP_ENABLED
+ return clip(__hadd2(a, b));
+ #else
+ return __hadd2(a, b);
+ #endif
}
template <typename T>
@@ -558,7 +570,7 @@ __global__ void __launch_bounds__(512, 1)
template <typename T>
-__global__ void __launch_bounds__(512, 1)
+__global__ void __launch_bounds__(1024, 1)
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
@@ -1045,6 +1057,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
allreduce8Mod<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
channelScratchOffset, rank, nRanksPerNode, worldSize, nelems);
} else {
+ nThreadsPerBlock = std::is_same<T,__bfloat16>::value ? 1024 : nThreadsPerBlock;
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems);