Files
rocm-systems/projects/rccl/ext-src/remove-clip.patch
T

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

55 lines
2.0 KiB
Diff

diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index fac105a..9ef93ce 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -71,17 +71,29 @@ __forceinline__ __device__ __bfloat162 clip(__bfloat162 val) {
template <typename T>
__forceinline__ __device__ T add_elements(T a, T b) {
- return clip(a + b);
+ #ifdef MSCCLPP_CLIP_ENABLED  // clip() the sum only when MSCCLPP_CLIP_ENABLED is defined
+ return clip(a + b);
+ #else  // MSCCLPP_CLIP_ENABLED not defined: return the raw, unclipped sum
+ return a + b;
+ #endif  // MSCCLPP_CLIP_ENABLED
}
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
- return clip(__hadd2(a, b));
+ #ifdef MSCCLPP_CLIP_ENABLED  // __half2 sum is clip()ed only when MSCCLPP_CLIP_ENABLED is defined
+ return clip(__hadd2(a, b));
+ #else  // MSCCLPP_CLIP_ENABLED not defined: plain __hadd2, no clip()
+ return __hadd2(a, b);
+ #endif  // MSCCLPP_CLIP_ENABLED
}
template <>
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
- return clip(__hadd2(a, b));
+ #ifdef MSCCLPP_CLIP_ENABLED  // __bfloat162 sum is clip()ed only when MSCCLPP_CLIP_ENABLED is defined
+ return clip(__hadd2(a, b));
+ #else  // MSCCLPP_CLIP_ENABLED not defined: plain __hadd2, no clip()
+ return __hadd2(a, b);
+ #endif  // MSCCLPP_CLIP_ENABLED
}
template <typename T>
@@ -558,7 +570,7 @@ __global__ void __launch_bounds__(512, 1)
template <typename T>
-__global__ void __launch_bounds__(512, 1)
+__global__ void __launch_bounds__(1024, 1)  // max block size 512 -> 1024 so the bf16 launch in allreduce() (1024 threads) is within bounds; NOTE(review): applies to every T, may change register allocation for non-bf16 — confirm perf
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
@@ -1045,6 +1057,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
allreduce8Mod<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
channelScratchOffset, rank, nRanksPerNode, worldSize, nelems);
} else {
+ nThreadsPerBlock = std::is_same<T,__bfloat16>::value ? 1024 : nThreadsPerBlock;  // bf16 only: 1024 threads/block; must not exceed allreduce8's __launch_bounds__ max (1024)
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems);