diff --git a/cmake/MSCCLPP.cmake b/cmake/MSCCLPP.cmake index c23ce59c80..4ddd62422c 100644 --- a/cmake/MSCCLPP.cmake +++ b/cmake/MSCCLPP.cmake @@ -53,47 +53,51 @@ if(ENABLE_MSCCLPP) ) endif() - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/non-multiple-128-fix.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) - - execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch - WORKING_DIRECTORY ${MSCCLPP_SOURCE} - ) + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/cpx.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) execute_process( - COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/no-cache.patch + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/read-allred.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mscclpp_ibv_access_relaxed_ordering.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/mem-reg.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/non-multiple-128-fix.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/bf16-tuning.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/no-cache.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + + execute_process( + COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/device-flag.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) - set(CMAKE_INHERITED_ARGS "") set(CMAKE_ARGS_LIST "CMAKE_PREFIX_PATH;CMAKE_INSTALL_RPATH_USE_LINK_PATH;HIP_COMPILER") foreach(arg IN LISTS CMAKE_ARGS_LIST) @@ -138,10 +142,15 @@ if(ENABLE_MSCCLPP) find_package(mscclpp_nccl REQUIRED) execute_process( - COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/no-cache.patch + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/device-flag.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} ) + execute_process( + COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/no-cache.patch + WORKING_DIRECTORY ${MSCCLPP_SOURCE} + ) + execute_process( COMMAND git apply --reverse ${CMAKE_CURRENT_SOURCE_DIR}/ext-src/reg-fix.patch WORKING_DIRECTORY ${MSCCLPP_SOURCE} diff --git a/ext-src/device-flag.patch b/ext-src/device-flag.patch new file mode 100644 index 0000000000..4d20c0006c --- /dev/null +++ b/ext-src/device-flag.patch @@ -0,0 +1,199 @@ +diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp +index 9f46ff9..fac105a 100644 +--- a/apps/nccl/src/allreduce.hpp ++++ b/apps/nccl/src/allreduce.hpp +@@ -199,11 +199,17 @@ template + __global__ void __launch_bounds__(32, 1) + allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, + size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, +- size_t nelems, uint32_t flag) { ++ size_t nelems, uint32_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer) { + // This version of allreduce only works for single nodes + if (worldSize != nRanksPerNode) return; + if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); + const int nPeers = nRanksPerNode - 1; ++ ++ uint32_t flag = *deviceFlag; ++ ++ size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0; ++ channelScratchOffset = scratchBaseOffset; ++ + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; + const int tid = threadIdx.x + localBlockIdx * blockDim.x; +@@ -237,13 +243,20 @@ __global__ void __launch_bounds__(32, 1) + data = add_vectors(data, src[idx]); + dst[idx] = data; + } ++ __syncthreads(); ++ ++ deviceSyncer->sync(gridDim.x); ++ ++ if (blockIdx.x == 0 && threadIdx.x == 0) { ++ *deviceFlag = *deviceFlag + 1; ++ } + } + + template + __global__ void __launch_bounds__(1024, 1) + allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* smChannels, + size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, +- size_t nelems, uint32_t flag ++ size_t nelems, uint32_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer + #if defined(ENABLE_NPKIT) + , + NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) { +@@ -290,6 +303,12 @@ __global__ void __launch_bounds__(1024, 1) + if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T); + + const int nPktsPerRank = nelemsPerRank / 2; ++ ++ uint32_t flag = *deviceFlag; ++ ++ size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0; ++ channelScratchOffset = scratchBaseOffset; ++ + // thread block & channel info + const int nBlocksPerPeer = gridDim.x / nPeers; + const int localBlockIdx = blockIdx.x % nBlocksPerPeer; +@@ -339,6 +358,8 @@ __global__ void __launch_bounds__(1024, 1) + channels[index].write(offset, packet); + } + } ++ __syncthreads(); ++ + // step 3: get data result from scratch buffer + mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset); + const int dstOffset = remoteRank * nPktsPerRank; +@@ -348,6 +369,7 @@ __global__ void __launch_bounds__(1024, 1) + result[idx].x = data.x; + result[idx].y = data.y; + } ++ __syncthreads(); + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \ + defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT) + NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer, +@@ -358,6 +380,11 @@ __global__ void __launch_bounds__(1024, 1) + #if defined(ENABLE_NPKIT) + NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head); + #endif ++ deviceSyncer->sync(gridDim.x); ++ ++ if (blockIdx.x == 0 && threadIdx.x == 0) { ++ *deviceFlag = *deviceFlag + 1; ++ } + } + + template +@@ -977,7 +1004,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + mscclpp::DeviceHandle* smScrChannels, + mscclpp::DeviceHandle* smOutChannels, size_t channelInOffset, + size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, +- size_t nelems, cudaStream_t stream) { ++ size_t nelems, cudaStream_t stream, uint32_t* deviceFlag, mscclpp::DeviceSyncer* syncer) { + static uint32_t flag = 1; + + nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode; +@@ -986,8 +1013,8 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + int nBlocks = nRanksPerNode - 1; + int nThreadsPerBlock = 32; + allreduceAllToAll<<>>(buff, scratch, resultBuff, smChannels, +- channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++); +- } else if (sizeof(T) * nelems <= (1 << 20)) { ++ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, deviceFlag, syncer); ++ } else if (sizeof(T) * nelems <= (1 << 18)) { + int nBlocks = 4*(nRanksPerNode - 1); + int nThreadsPerBlock = 1024; + if (nelems >= 8192) { +@@ -998,11 +1025,11 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle< + size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent); + allreduce7<<>>( + buff, scratch, resultBuff, smChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, +- nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); ++ nelems, deviceFlag, syncer, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); + #else + allreduce7<<>>(buff, scratch, resultBuff, smChannels, channelInOffset, + channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, +- flag++); ++ deviceFlag, syncer); + #endif + } else { + int nBlocks = 8 * (nRanksPerNode - 1); +diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu +index b8d04d4..ac33e1c 100644 +--- a/apps/nccl/src/nccl.cu ++++ b/apps/nccl/src/nccl.cu +@@ -93,6 +93,8 @@ struct ncclComm { + + uint32_t numScratchBuff; + uint32_t buffFlag; ++ uint32_t* deviceFlag; ++ mscclpp::DeviceSyncer *syncer; + }; + + struct handleInfo { +@@ -228,7 +230,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + size_t bytes = count * ncclTypeSize(datatype); + + // Creating the channels +- if (count * ncclTypeSize(datatype) <= (1 << 20)) { ++ if (count * ncclTypeSize(datatype) <= (1 << 18)) { + auto sendIt = comm->channelScratchInfos.find(sendKey); + if (sendIt == comm->channelScratchInfos.end()) { + std::vector channels = +@@ -286,24 +288,28 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, + switch (datatype) { + case ncclFloat16: + CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, +- smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); ++ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, ++ comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, comm->syncer)); + break; + case ncclFloat32: + CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, + smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, + comm->comm->bootstrap()->getRank(), +- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); ++ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, ++ comm->syncer)); + break; + case ncclBfloat16: + CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff, + smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, +- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); ++ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, ++ comm->syncer)); + break; + case ncclInt32: + case ncclUint32: + CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels, + smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), +- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream)); ++ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, ++ comm->syncer)); + break; + default: + WARN("datatype is invalid"); +@@ -409,6 +415,13 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt + commPtr->scratchBuff = mscclpp::GpuBuffer(SCRATCH_SIZE).memory(); + commPtr->remoteScratchRegMemories = + setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); ++ ++ cudaMalloc((void**)&(commPtr->syncer), sizeof(mscclpp::DeviceSyncer)); ++ cudaMemset((void*)(commPtr->syncer), 0, sizeof(mscclpp::DeviceSyncer)); ++ ++ uint32_t initFlag = 1; ++ cudaMalloc((void**)&(commPtr->deviceFlag), sizeof(uint32_t)); ++ cudaMemcpy((void*)(commPtr->deviceFlag), &initFlag, sizeof(uint32_t), cudaMemcpyHostToDevice); + } + + NCCL_API ncclResult_t ncclGetVersion(int* version) { +@@ -506,6 +519,8 @@ NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) { + NpKit::Shutdown(); + } + #endif ++ cudaFree(comm->deviceFlag); ++ cudaFree(comm->syncer); + delete comm; + return ncclSuccess; + } diff --git a/src/init.cc b/src/init.cc index 45259a01db..839fbb517d 100644 --- a/src/init.cc +++ b/src/init.cc @@ -114,7 +114,7 @@ bool operator ==(const ncclUniqueId& a, const ncclUniqueId& b) { return memcmp(a.internal, b.internal, NCCL_UNIQUE_ID_BYTES) == 0; } -RCCL_PARAM(MscclppThreshold, "MSCCLPP_THRESHOLD", (size_t)(1024*1024)); +RCCL_PARAM(MscclppThreshold, "MSCCLPP_THRESHOLD", (size_t)(16*1024*1024)); #endif static constexpr int64_t defaultEnableMscclpp = 0;