diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index 9f46ff9..fac105a 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -199,11 +199,17 @@ template __global__ void __launch_bounds__(32, 1) allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, uint32_t flag) { + size_t nelems, uint32_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer) { // This version of allreduce only works for single nodes if (worldSize != nRanksPerNode) return; if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); const int nPeers = nRanksPerNode - 1; + + uint32_t flag = *deviceFlag; + + size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0; + channelScratchOffset = scratchBaseOffset; + const int nBlocksPerPeer = gridDim.x / nPeers; const int localBlockIdx = blockIdx.x % nBlocksPerPeer; const int tid = threadIdx.x + localBlockIdx * blockDim.x; @@ -237,13 +243,20 @@ __global__ void __launch_bounds__(32, 1) data = add_vectors(data, src[idx]); dst[idx] = data; } + __syncthreads(); + + deviceSyncer->sync(gridDim.x); + + if (blockIdx.x == 0 && threadIdx.x == 0) { + *deviceFlag = *deviceFlag + 1; + } } template __global__ void __launch_bounds__(1024, 1) allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, uint32_t flag + size_t nelems, uint32_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer #if defined(ENABLE_NPKIT) , NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) { @@ -290,6 +303,12 @@ __global__ void __launch_bounds__(1024, 1) if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T); const int nPktsPerRank = nelemsPerRank / 2; + + uint32_t flag = *deviceFlag; + + size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0; + channelScratchOffset = scratchBaseOffset; + // thread block & channel info const int nBlocksPerPeer = gridDim.x / nPeers; const int localBlockIdx = blockIdx.x % nBlocksPerPeer; @@ -339,6 +358,8 @@ __global__ void __launch_bounds__(1024, 1) channels[index].write(offset, packet); } } + __syncthreads(); + // step 3: get data result from scratch buffer mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset); const int dstOffset = remoteRank * nPktsPerRank; @@ -348,6 +369,7 @@ __global__ void __launch_bounds__(1024, 1) result[idx].x = data.x; result[idx].y = data.y; } + __syncthreads(); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \ defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT) NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer, @@ -358,6 +380,11 @@ __global__ void __launch_bounds__(1024, 1) #if defined(ENABLE_NPKIT) NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head); #endif + deviceSyncer->sync(gridDim.x); + + if (blockIdx.x == 0 && threadIdx.x == 0) { + *deviceFlag = *deviceFlag + 1; + } } template @@ -977,7 +1004,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< mscclpp::DeviceHandle* smScrChannels, mscclpp::DeviceHandle* smOutChannels, size_t channelInOffset, size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, - size_t nelems, cudaStream_t stream) { + size_t nelems, cudaStream_t stream, uint32_t* deviceFlag, mscclpp::DeviceSyncer* syncer) { static uint32_t flag = 1; nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode; @@ -986,8 +1013,8 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< int nBlocks = nRanksPerNode - 1; int nThreadsPerBlock = 32; allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, - channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++); - } else if (sizeof(T) * nelems <= (1 << 20)) { + channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, deviceFlag, syncer); + } else if (sizeof(T) * nelems <= (1 << 18)) { int nBlocks = 4*(nRanksPerNode - 1); int nThreadsPerBlock = 1024; if (nelems >= 8192) { @@ -998,11 +1025,11 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent); allreduce7<<>>( buff, scratch, resultBuff, smChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, - nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); + nelems, deviceFlag, syncer, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); #else allreduce7<<>>(buff, scratch, resultBuff, smChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, - flag++); + deviceFlag, syncer); #endif } else { int nBlocks = 8 * (nRanksPerNode - 1); diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index b8d04d4..ac33e1c 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -93,6 +93,8 @@ struct ncclComm { uint32_t numScratchBuff; uint32_t buffFlag; + uint32_t* deviceFlag; + mscclpp::DeviceSyncer *syncer; }; struct handleInfo { @@ -228,7 +230,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t bytes = count * ncclTypeSize(datatype); // Creating the channels - if (count * ncclTypeSize(datatype) <= (1 << 20)) { + if (count * ncclTypeSize(datatype) <= (1 << 18)) { auto sendIt = comm->channelScratchInfos.find(sendKey); if (sendIt == comm->channelScratchInfos.end()) { std::vector channels = @@ -286,24 +288,28 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, switch (datatype) { case ncclFloat16: CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, - smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, + comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, comm->syncer)); break; case ncclFloat32: CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), - NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, + comm->syncer)); break; case ncclBfloat16: CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff, smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, - NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, + comm->syncer)); break; case ncclInt32: case ncclUint32: CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), - NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); + NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, + comm->syncer)); break; default: WARN("datatype is invalid"); @@ -409,6 +415,13 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt commPtr->scratchBuff = mscclpp::GpuBuffer(SCRATCH_SIZE).memory(); commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); + + cudaMalloc((void**)&(commPtr->syncer), sizeof(mscclpp::DeviceSyncer)); + cudaMemset((void*)(commPtr->syncer), 0, sizeof(mscclpp::DeviceSyncer)); + + uint32_t initFlag = 1; + cudaMalloc((void**)&(commPtr->deviceFlag), sizeof(uint32_t)); + cudaMemcpy((void*)(commPtr->deviceFlag), &initFlag, sizeof(uint32_t), cudaMemcpyHostToDevice); } NCCL_API ncclResult_t ncclGetVersion(int* version) { @@ -506,6 +519,8 @@ NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) { NpKit::Shutdown(); } #endif + cudaFree(comm->deviceFlag); + cudaFree(comm->syncer); delete comm; return ncclSuccess; }