// Patch against apps/nccl/src/allreduce.hpp (unified diff, three hunks).
//
// NOTE(review): this copy of the patch looks damaged. All hunks are folded
// onto a single line, and angle-bracket contents seem to be missing:
// `template` has no parameter list, `std::is_same::value` has no type
// arguments, `<<>>` has no launch configuration, and `DeviceHandle*` has no
// element type. It presumably will not apply or compile as shown; recover the
// original patch before acting on it.
//
// What the visible text changes:
//  1. add_elements (generic T, __half2 and __bfloat162 versions): the clip()
//     call around the sum is kept only when MSCCLPP_CLIP_ENABLED is defined;
//     otherwise the raw `a + b` / `__hadd2(a, b)` result is returned.
//     NOTE(review): builds that do not define the macro lose the clipping the
//     old code always applied -- confirm this new default is intended and
//     that the macro is set by the build wherever clipping is still needed.
//  2. allreduce8: __launch_bounds__ goes from (512, 1) to (1024, 1).
//  3. Host allreduce(), in the branch that launches allreduce8:
//     nThreadsPerBlock becomes 1024 when a std::is_same<...> test holds and
//     is left unchanged otherwise.
//     NOTE(review): the compared types are not visible here. The wider launch
//     bound in (2) is written on the kernel template itself, so it presumably
//     affects every instantiation, not only the type selected by this test --
//     TODO confirm that other element types do not regress (e.g. through
//     tighter per-thread register limits).
diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index fac105a..9ef93ce 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -71,17 +71,29 @@ __forceinline__ __device__ __bfloat162 clip(__bfloat162 val) { template __forceinline__ __device__ T add_elements(T a, T b) { - return clip(a + b); + #ifdef MSCCLPP_CLIP_ENABLED + return clip(a + b); + #else + return a + b; + #endif } template <> __forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) { - return clip(__hadd2(a, b)); + #ifdef MSCCLPP_CLIP_ENABLED + return clip(__hadd2(a, b)); + #else + return __hadd2(a, b); + #endif } template <> __forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) { - return clip(__hadd2(a, b)); + #ifdef MSCCLPP_CLIP_ENABLED + return clip(__hadd2(a, b)); + #else + return __hadd2(a, b); + #endif } template @@ -558,7 +570,7 @@ __global__ void __launch_bounds__(512, 1) template -__global__ void __launch_bounds__(512, 1) +__global__ void __launch_bounds__(1024, 1) allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, mscclpp::DeviceHandle* smOutChannels, size_t channelOutDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) { @@ -1045,6 +1057,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< allreduce8Mod<<>>(buff, scratch, resultBuff, smScrChannels, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems); } else { + nThreadsPerBlock = std::is_same::value ? 1024 : nThreadsPerBlock; allreduce8<<>>(buff, scratch, resultBuff, smScrChannels, smOutChannels, channelOutOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems);