diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index 76674ba..7a2cd4a 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -368,7 +368,10 @@ __global__ void __launch_bounds__(512, 1) const size_t chanOffset = nPeer * blockIdx.x; // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); - const size_t nInt4PerRank = nInt4 / worldSize; + size_t nInt4PerRank = nInt4 / worldSize; + if (nInt4 % worldSize) + nInt4PerRank = nInt4PerRank + 1; + auto smChans = smChannels + chanOffset; auto smOutChans = smOutChannels + chanOffset;