de76d7f649
* mscclpp patch: apply the clip patch and raise the allreduce8 launch bound from 512 to 1024 threads per block
* add compilation flag for enabling/disabling clipping in mscclpp
* change flag name for consistency, set flag to OFF
* add compilation flag in rccl for enabling clipping in mscclpp
* launch mscclpp allreduce8 with 1024 threads per block only when the element type is bfloat16
* fix improper description for ENABLE_MSCCLPP_CLIP flag
* Revert "Merge branch 'clip-patch' of https://github.com/isaki001/rccl into clip-patch"
This reverts commit 6e31857a9db98314b8a748eb024f2c3699ebe2d5, reversing
changes made to 193f4caa8ffa78b4e056893212fd8344aa14e937.
* update clip remove-clip.patch for rebase
[ROCm/rccl commit: 8145c4f3b8]
55 lines
2.0 KiB
Diff
55 lines
2.0 KiB
Diff
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
|
|
index fac105a..9ef93ce 100644
|
|
--- a/apps/nccl/src/allreduce.hpp
|
|
+++ b/apps/nccl/src/allreduce.hpp
|
|
@@ -71,17 +71,29 @@ __forceinline__ __device__ __bfloat162 clip(__bfloat162 val) {
|
|
|
|
template <typename T>
|
|
__forceinline__ __device__ T add_elements(T a, T b) {
|
|
- return clip(a + b);
|
|
+ #ifdef MSCCLPP_CLIP_ENABLED
|
|
+ return clip(a + b);
|
|
+ #else
|
|
+ return a + b;
|
|
+ #endif
|
|
}
|
|
|
|
template <>
|
|
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
|
|
- return clip(__hadd2(a, b));
|
|
+ #ifdef MSCCLPP_CLIP_ENABLED
|
|
+ return clip(__hadd2(a, b));
|
|
+ #else
|
|
+ return __hadd2(a, b);
|
|
+ #endif
|
|
}
|
|
|
|
template <>
|
|
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
|
|
- return clip(__hadd2(a, b));
|
|
+ #ifdef MSCCLPP_CLIP_ENABLED
|
|
+ return clip(__hadd2(a, b));
|
|
+ #else
|
|
+ return __hadd2(a, b);
|
|
+ #endif
|
|
}
|
|
|
|
template <typename T>
|
|
@@ -558,7 +570,7 @@ __global__ void __launch_bounds__(512, 1)
|
|
|
|
|
|
template <typename T>
|
|
-__global__ void __launch_bounds__(512, 1)
|
|
+__global__ void __launch_bounds__(1024, 1)
|
|
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
|
|
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
|
|
size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
|
|
@@ -1045,6 +1057,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
|
|
allreduce8Mod<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
|
|
channelScratchOffset, rank, nRanksPerNode, worldSize, nelems);
|
|
} else {
|
|
+ nThreadsPerBlock = std::is_same<T,__bfloat16>::value ? 1024 : nThreadsPerBlock;
|
|
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
|
|
smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
|
worldSize, nelems);
|