Files
rocm-systems/projects/rccl/ext-src/device-flag.patch
T
Nusrat Islam 691e98940c Fix MSCCLPP accuracy issue for allreduce7 (#1634)
* ext-src: fix a graph-mode bug in allreduce7

* change MSCCLPP threshold to 16MB

* ext-src: change message size threshold for allreduce7

* ext-src: address review comments

[ROCm/rccl commit: f20c33effd]
2025-04-18 08:54:32 -05:00

200 lines
9.5 KiB
Diff

diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 9f46ff9..fac105a 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -199,11 +199,17 @@ template <typename T>
__global__ void __launch_bounds__(32, 1)
allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
- size_t nelems, uint32_t flag) {
+ size_t nelems, uint32_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer) {  // deviceFlag: per-launch flag kept in device memory; deviceSyncer: used below to sync all gridDim.x blocks
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
const int nPeers = nRanksPerNode - 1;
+
+ uint32_t flag = *deviceFlag;  // read from device memory on every launch; a by-value host flag would be frozen into a captured CUDA graph (the graph-mode fix)
+
+ size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0;  // odd flag -> second half of scratch, even flag -> first half; alternates every launch
+ channelScratchOffset = scratchBaseOffset;  // NOTE: overrides the host-supplied channelScratchOffset argument
+
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
@@ -237,13 +243,20 @@ __global__ void __launch_bounds__(32, 1)
data = add_vectors<T>(data, src[idx]);
dst[idx] = data;
}
+ __syncthreads();  // block-level barrier before the grid-wide sync
+
+ deviceSyncer->sync(gridDim.x);  // every block has already read *deviceFlag above; only now is it safe to advance it
+
+ if (blockIdx.x == 0 && threadIdx.x == 0) {  // exactly one thread advances the flag for the next launch
+ *deviceFlag = *deviceFlag + 1;
+ }
}
template <typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
- size_t nelems, uint32_t flag
+ size_t nelems, uint32_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer  // deviceFlag: per-launch flag kept in device memory; deviceSyncer: used at the end to sync all gridDim.x blocks
#if defined(ENABLE_NPKIT)
,
NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -290,6 +303,12 @@ __global__ void __launch_bounds__(1024, 1)
if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);
const int nPktsPerRank = nelemsPerRank / 2;
+
+ uint32_t flag = *deviceFlag;  // read from device memory on every launch so CUDA-graph replays do not reuse a captured value
+
+ size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0;  // odd flag -> second half of scratch, even flag -> first half; alternates every launch
+ channelScratchOffset = scratchBaseOffset;  // NOTE: overrides the host-supplied channelScratchOffset argument
+
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
@@ -339,6 +358,8 @@ __global__ void __launch_bounds__(1024, 1)
channels[index].write(offset, packet);
}
}
+ __syncthreads();  // block-level barrier between step 2 and step 3
+
// step 3: get data result from scratch buffer
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
@@ -348,6 +369,7 @@ __global__ void __launch_bounds__(1024, 1)
result[idx].x = data.x;
result[idx].y = data.y;
}
+ __syncthreads();  // block-level barrier after step 3, before NpKit event collection
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
@@ -358,6 +380,11 @@ __global__ void __launch_bounds__(1024, 1)
#if defined(ENABLE_NPKIT)
NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
#endif
+ deviceSyncer->sync(gridDim.x);  // all blocks must pass this point before the flag is advanced
+
+ if (blockIdx.x == 0 && threadIdx.x == 0) {  // exactly one thread advances the flag for the next launch
+ *deviceFlag = *deviceFlag + 1;
+ }
}
template <typename T>
@@ -977,7 +1004,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
- size_t nelems, cudaStream_t stream) {
+ size_t nelems, cudaStream_t stream, uint32_t* deviceFlag, mscclpp::DeviceSyncer* syncer) {  // deviceFlag/syncer are forwarded to allreduceAllToAll and allreduce7
static uint32_t flag = 1;
nRanksPerNode = (worldSize < nRanksPerNode) ? worldSize : nRanksPerNode;
@@ -986,8 +1013,8 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
int nBlocks = nRanksPerNode - 1;
int nThreadsPerBlock = 32;
allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels,
- channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, flag++);
- } else if (sizeof(T) * nelems <= (1 << 20)) {
+ channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize, nelems, deviceFlag, syncer);
+ } else if (sizeof(T) * nelems <= (1 << 18)) {  // 256 KiB; same cutoff as the channel-setup check in nccl.cu
int nBlocks = 4*(nRanksPerNode - 1);
int nThreadsPerBlock = 1024;
if (nelems >= 8192) {
@@ -998,11 +1025,11 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
allreduce7<<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
buff, scratch, resultBuff, smChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize,
- nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
+ nelems, deviceFlag, syncer, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
channelScratchOffset, rank, nRanksPerNode, worldSize, nelems,
- flag++);
+ deviceFlag, syncer);
#endif
} else {
int nBlocks = 8 * (nRanksPerNode - 1);
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index b8d04d4..ac33e1c 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -93,6 +93,8 @@ struct ncclComm {
uint32_t numScratchBuff;
uint32_t buffFlag;
+ uint32_t* deviceFlag = nullptr;  // allreduce launch flag in device memory; nullptr until fallback init allocates it
+ mscclpp::DeviceSyncer *syncer = nullptr;  // device-memory barrier state passed to the allreduce kernels; nullptr until allocated
};
struct handleInfo {
@@ -228,7 +230,7 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
size_t bytes = count * ncclTypeSize(datatype);
// Creating the channels
- if (count * ncclTypeSize(datatype) <= (1 << 20)) {
+ if (count * ncclTypeSize(datatype) <= (1 << 18)) {  // 256 KiB: keep equal to the allreduce7 cutoff in allreduce.hpp
auto sendIt = comm->channelScratchInfos.find(sendKey);
if (sendIt == comm->channelScratchInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
@@ -286,24 +288,28 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
switch (datatype) {
case ncclFloat16:
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels,
- smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
+ comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag, comm->syncer));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch,
comm->comm->bootstrap()->getRank(),
- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag,
+ comm->syncer));
break;
case ncclBfloat16:
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank,
- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag,
+ comm->syncer));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream, comm->deviceFlag,
+ comm->syncer));
break;
default:
WARN("datatype is invalid");
@@ -409,6 +415,13 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
commPtr->scratchBuff = mscclpp::GpuBuffer(SCRATCH_SIZE).memory();
commPtr->remoteScratchRegMemories =
setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
+
+ MSCCLPP_CUDATHROW(cudaMalloc((void**)&(commPtr->syncer), sizeof(mscclpp::DeviceSyncer)));  // fail loudly rather than hand kernels an unallocated pointer
+ MSCCLPP_CUDATHROW(cudaMemset((void*)(commPtr->syncer), 0, sizeof(mscclpp::DeviceSyncer)));  // barrier state must start zeroed
+
+ uint32_t initFlag = 1;  // flag value seen by the first allreduce kernel launch
+ MSCCLPP_CUDATHROW(cudaMalloc((void**)&(commPtr->deviceFlag), sizeof(uint32_t)));
+ MSCCLPP_CUDATHROW(cudaMemcpy((void*)(commPtr->deviceFlag), &initFlag, sizeof(uint32_t), cudaMemcpyHostToDevice));
}
NCCL_API ncclResult_t ncclGetVersion(int* version) {
@@ -506,6 +519,8 @@ NCCL_API ncclResult_t ncclCommDestroy(ncclComm_t comm) {
NpKit::Shutdown();
}
#endif
+ cudaFree(comm->deviceFlag);  // no-op when nullptr (fallback init never ran or failed early)
+ cudaFree(comm->syncer);
delete comm;
return ncclSuccess;
}