55 γραμμές
2.0 KiB
Diff
55 γραμμές
2.0 KiB
Diff
|
|
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
|
||
|
|
index fac105a..9ef93ce 100644
|
||
|
|
--- a/apps/nccl/src/allreduce.hpp
|
||
|
|
+++ b/apps/nccl/src/allreduce.hpp
|
||
|
|
@@ -71,17 +71,29 @@ __forceinline__ __device__ __bfloat162 clip(__bfloat162 val) {
|
||
|
|
|
||
|
|
template <typename T>
|
||
|
|
__forceinline__ __device__ T add_elements(T a, T b) {
|
||
|
|
- return clip(a + b);
|
||
|
|
+ #ifdef MSCCLPP_CLIP_ENABLED
|
||
|
|
+ return clip(a + b);
|
||
|
|
+ #else
|
||
|
|
+ return a + b;
|
||
|
|
+ #endif
|
||
|
|
}
|
||
|
|
|
||
|
|
template <>
|
||
|
|
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
|
||
|
|
- return clip(__hadd2(a, b));
|
||
|
|
+ #ifdef MSCCLPP_CLIP_ENABLED
|
||
|
|
+ return clip(__hadd2(a, b));
|
||
|
|
+ #else
|
||
|
|
+ return __hadd2(a, b);
|
||
|
|
+ #endif
|
||
|
|
}
|
||
|
|
|
||
|
|
template <>
|
||
|
|
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
|
||
|
|
- return clip(__hadd2(a, b));
|
||
|
|
+ #ifdef MSCCLPP_CLIP_ENABLED
|
||
|
|
+ return clip(__hadd2(a, b));
|
||
|
|
+ #else
|
||
|
|
+ return __hadd2(a, b);
|
||
|
|
+ #endif
|
||
|
|
}
|
||
|
|
|
||
|
|
template <typename T>
|
||
|
|
@@ -558,7 +570,7 @@ __global__ void __launch_bounds__(512, 1)
|
||
|
|
|
||
|
|
|
||
|
|
template <typename T>
|
||
|
|
-__global__ void __launch_bounds__(512, 1)
|
||
|
|
+__global__ void __launch_bounds__(1024, 1)
|
||
|
|
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
|
||
|
|
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
|
||
|
|
size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
|
||
|
|
@@ -1045,6 +1057,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
|
||
|
|
allreduce8Mod<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
|
||
|
|
channelScratchOffset, rank, nRanksPerNode, worldSize, nelems);
|
||
|
|
} else {
|
||
|
|
+ nThreadsPerBlock = std::is_same<T,__bfloat16>::value ? 1024 : nThreadsPerBlock;
|
||
|
|
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels,
|
||
|
|
smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
|
||
|
|
worldSize, nelems);
|