Comhaid
rocm-systems/src/misc/msccl/msccl_setup.cc
T

311 línte
11 KiB
C++

/*************************************************************************
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
************************************************************************/
#include "checks.h"
#include "collectives.h"
#include "proxy.h"
#include "transport.h"
#include "msccl/msccl_lifecycle.h"
#include "msccl/msccl_kernel.h"
#include "msccl/msccl_setup.h"
#include "msccl/msccl_status.h"
// Tunable read by mscclSetupKernel through rcclParamMscclEnableDoneEvent(); the value 1
// (also the default declared here) means the kernel launch records comm->doneEvent and
// a stream switch waits on it.
// The parameter only exists when HIP_EVENT_DISABLE_FENCE is not defined; when that macro
// is defined, mscclSetupKernel treats the done event as always enabled.
// NOTE(review): RCCL_PARAM is defined elsewhere; presumably the string literal names the
// environment variable backing this parameter -- confirm against the macro definition.
#ifndef HIP_EVENT_DISABLE_FENCE
RCCL_PARAM(MscclEnableDoneEvent, "MSCCL_ENABLE_DONE_EVENT", 1);
#endif
/**
 * Computes the step/chunk sizing for one MSCCL launch and stores it in the MSCCL status.
 *
 * Fields written: stepSize, chunkSteps, sliceSteps, chunkSize, chunkEffectiveSize,
 * dataType, nBytes and maxAllowedCount.
 *
 * @param hostAlgo  Host-side algorithm description; protocol, chunkSteps, sliceSteps,
 *                  sizeMultiplier and nChunksPerLoop are read.
 * @param comm      Communicator; comm->buffSizes[protocol] supplies the buffer size.
 * @param count     Element count supplied by the caller.
 * @param dataType  Element type; ncclTypeSize(dataType) converts count to bytes.
 * @return ncclSuccess on success, ncclInternalError if nChunksPerLoop is not positive.
 */
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
  mscclStatus& status = mscclGetStatus();

  // nChunksPerLoop is used as a divisor below; reject a non-positive value instead of
  // dividing by zero.
  if (hostAlgo->nChunksPerLoop <= 0) {
    WARN("MSCCL: invalid number of chunks per loop (%d)", (int)hostAlgo->nChunksPerLoop);
    return ncclInternalError;
  }

  status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
  // Only the Simple protocol takes chunk/slice steps from the algorithm; others use 1.
  status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1;
  status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1;
  status.chunkSize = status.stepSize * status.chunkSteps;

  // Payload capacity of a chunk: LL keeps half of the chunk, LL128 keeps
  // NCCL_LL128_DATAELEMS out of every NCCL_LL128_LINEELEMS elements.
  status.chunkEffectiveSize = status.chunkSize;
  if (hostAlgo->protocol == NCCL_PROTO_LL) status.chunkEffectiveSize /= 2;
  if (hostAlgo->protocol == NCCL_PROTO_LL128) status.chunkEffectiveSize = (status.chunkSize / NCCL_LL128_LINEELEMS) * NCCL_LL128_DATAELEMS;

  status.dataType = dataType;
  status.nBytes = count * ncclTypeSize(status.dataType) * hostAlgo->sizeMultiplier;

  // Bytes carried by one algorithm chunk. This is 0 only for a zero-byte payload, which
  // must not be used as a divisor; in that case the payload imposes no limit, so the
  // count gets the same cap that applies to any very small payload.
  const size_t bytesPerChunk = DIVUP(status.nBytes, (size_t)(hostAlgo->nChunksPerLoop));
  if (bytesPerChunk == 0) {
    status.maxAllowedCount = MSCCL_MAX_COUNT - 1;
    return ncclSuccess;
  }

  // std::max keeps the result at 1 or more, so no separate zero check is required.
  status.maxAllowedCount = std::max((uint32_t)1, (uint32_t)(status.chunkEffectiveSize / bytesPerChunk));
  if (status.maxAllowedCount >= MSCCL_MAX_COUNT) {
    status.maxAllowedCount = MSCCL_MAX_COUNT - 1;
  }
  return ncclSuccess;
}
/**
 * Ensures the MSCCL scratch buffer is large enough for the current launch, growing it
 * when necessary. The buffer is never shrunk.
 *
 * Reads status.nBytes, so the sizing must already have been computed for this launch.
 *
 * @param hostAlgo  Host-side algorithm description; nScratchChunks and nChunksPerLoop are read.
 * @param stream    Stream that is synchronized before the old buffer is released.
 * @return ncclSuccess on success, ncclInternalError if nChunksPerLoop is not positive,
 *         otherwise the error reported by the failing HIP/NCCL call.
 */
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
  mscclStatus& status = mscclGetStatus();

  // nChunksPerLoop is the divisor of the size computation below.
  if (hostAlgo->nChunksPerLoop <= 0) {
    WARN("MSCCL: invalid number of chunks per loop (%d)", (int)hostAlgo->nChunksPerLoop);
    return ncclInternalError;
  }

  size_t sizeNeeded = (status.nBytes * (size_t)(hostAlgo->nScratchChunks)) / (size_t)(hostAlgo->nChunksPerLoop);
  if (sizeNeeded > status.scratchBufferSize) {
    // Wait for work already queued on the stream before releasing the buffer.
    CUDACHECK(hipStreamSynchronize(stream));
    CUDACHECK(hipFree(status.scratchBuffer));
    // Forget the released buffer right away: if the allocation below fails, the status
    // must not keep a freed pointer together with its old, non-zero size.
    status.scratchBuffer = nullptr;
    status.scratchBufferSize = 0;
    NCCLCHECK(ncclCudaCalloc((char**)&status.scratchBuffer, sizeNeeded));
    status.scratchBufferSize = sizeNeeded;
  }
  return ncclSuccess;
}
/**
 * Resets the MSCCL sync flags and the work index shortly before the work index would
 * wrap around.
 *
 * @param stream  Stream on which the flag buffer is cleared asynchronously.
 * @return ncclSuccess, or the error reported by hipMemsetAsync.
 */
ncclResult_t mscclSetupSyncFlags(hipStream_t stream) {
  mscclStatus& status = mscclGetStatus();

  // Largest value that fits in workIndex's storage width. It is built by shifting an
  // all-ones value right, so the shift amount stays below the operand width even when
  // workIndex is itself 64 bits wide (a left shift of 1ULL by 64 would be undefined).
  const unsigned long long maxWorkIndex =
      ~0ULL >> (8 * (sizeof(unsigned long long) - sizeof(status.workIndex)));
  // Reset while 2 * NCCL_MAX_OPS values are still left before the wrap.
  // NOTE(review): presumably this margin covers the operations that can still be in
  // flight -- confirm.
  const unsigned long long resetThreshold = maxWorkIndex - 2ULL * NCCL_MAX_OPS;

  if (status.workIndex > resetThreshold) {
    CUDACHECK(hipMemsetAsync(status.syncFlags, 0, sizeof(struct mscclFlag) * MSCCL_MAX_NUM_THREAD_BLOCKS, stream));
    status.workIndex = 1; // setting the workIndex back to 1 for next iterations
  }
  return ncclSuccess;
}
/**
 * Registers and establishes the peer-to-peer connections required by an MSCCL algorithm.
 *
 * For every algorithm channel the send and receive peers are handed to
 * ncclTransportP2pConnect; afterwards one ncclTransportP2pSetup call connects them.
 *
 * @param hostAlgo  Host-side algorithm description (channels and their peers).
 * @param comm      Communicator on which the connections are created.
 * @return ncclSuccess on success; ncclInvalidUsage if the algorithm needs more channels
 *         than the communicator has or a channel lists more peers than supported;
 *         otherwise the error reported by the transport layer.
 */
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
  // Check whether there is enough channels
  if (hostAlgo->nChannels > comm->nChannels) {
    WARN("MSCCL: number of channels available (%d) less than required (%d)", comm->nChannels, hostAlgo->nChannels);
    return ncclInvalidUsage;
  }

  // Flag MSCCL connections
  for (int i = 0; i < hostAlgo->nChannels; i++) {
    struct mscclChannelInfo* mCh = hostAlgo->mscclChannels + i;

    // The peer lists are copied into fixed-size stack arrays; refuse counts that do
    // not fit rather than writing past the end of those arrays.
    if (mCh->nSendPeers > MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL ||
        mCh->nRecvPeers > MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL) {
      WARN("MSCCL: channel %d has too many peers (send %d, recv %d, max %d)",
           i, (int)mCh->nSendPeers, (int)mCh->nRecvPeers, (int)MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL);
      return ncclInvalidUsage;
    }

    int sendPeers[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
    for (int p = 0; p < mCh->nSendPeers; p++) {
      sendPeers[p] = mCh->sendPeerInfo[p].peer;
    }
    int recvPeers[MSCCL_MAX_NUM_THREAD_BLOCKS_PER_CHANNEL];
    for (int p = 0; p < mCh->nRecvPeers; p++) {
      recvPeers[p] = mCh->recvPeerInfo[p].peer;
    }
    NCCLCHECK(ncclTransportP2pConnect(comm, i, mCh->nRecvPeers, recvPeers, mCh->nSendPeers, sendPeers, 0 /*connIndex*/));
  }

  // Connect MSCCL connections. The caller flag is cleared before the result is
  // inspected so that a failed setup does not leave the flag set.
  mscclSetIsCallerFlag();
  const ncclResult_t setupResult = ncclTransportP2pSetup(comm, NULL, 0);
  mscclClearIsCallerFlag();
  NCCLCHECK(setupResult);

  INFO(NCCL_INIT, "MSCCL: Setup connections finished, used %ld", allocTracker[comm->cudaDev].totalAllocSize);
  return ncclSuccess;
}
/**
 * Posts the proxy operations for one MSCCL launch and starts the proxy.
 *
 * A single ncclProxyOp is filled from the MSCCL status and reused for every peer of
 * every algorithm channel; only channelId and nsteps change between posts. A peer
 * whose step count is not positive gets no proxy operation.
 *
 * @param hostAlgo  Host-side algorithm description (protocol, channels, peers).
 * @param comm      Communicator whose channels and shared collOpCount are used.
 * @return ncclSuccess, or the first error from mscclSaveProxy / ncclProxyStart.
 */
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
  mscclStatus& mStatus = mscclGetStatus();

  // Fields shared by every proxy operation of this launch.
  struct ncclProxyOp op = {};
  op.opCount = comm->sharedRes->collOpCount;
  op.connIndex = 0;
  op.protocol = hostAlgo->protocol;
  op.dtype = mStatus.dataType;
  op.redOp = 0;
  op.pattern = 0;
  op.root = 0;
  op.chunkSize = mStatus.chunkSize;
  op.chunkSteps = mStatus.chunkSteps;
  op.sliceSteps = mStatus.sliceSteps;
  op.nbytes = mStatus.stepSize*op.sliceSteps;

  // Number of passes over the algorithm needed to move nBytes, scaled by chunkSteps.
  const size_t bytesPerLoop = (size_t)hostAlgo->nChunksPerLoop*(size_t)mStatus.chunkEffectiveSize;
  const int loopCount = (int)(DIVUP(mStatus.nBytes, bytesPerLoop));
  const int stepsPerTransmission = loopCount * mStatus.chunkSteps;

  // Transmissions a peer performs in one pass: every recorded count c contributes its
  // number of transmissions times the pieces needed to fit c+1 into maxAllowedCount.
  auto transmissionsFor = [&mStatus](const struct mscclChannelPeerInfo* peer) {
    int total = 0;
    for (int k = 0; k < peer->nExistingCounts; k++) {
      const int cnt = peer->existingCounts[k];
      const int pieces = DIVUP(cnt+1, mStatus.maxAllowedCount);
      total += peer->nTransmissionsOfCount[cnt] * pieces;
    }
    return total;
  };

  for (int ch = 0; ch < hostAlgo->nChannels; ch++) {
    struct mscclChannelInfo* algoChannel = &hostAlgo->mscclChannels[ch];
    struct ncclChannel* commChannel = &comm->channels[ch];
    op.channelId = ch;

    // Receive side first, then send side.
    for (int r = 0; r < algoChannel->nRecvPeers; r++) {
      struct mscclChannelPeerInfo* peer = &algoChannel->recvPeerInfo[r];
      op.nsteps = stepsPerTransmission * transmissionsFor(peer);
      if (op.nsteps <= 0) continue;
      NCCLCHECK(mscclSaveProxy(comm, commChannel, proxyRecv, peer->peer, &op, 0));
    }

    for (int s = 0; s < algoChannel->nSendPeers; s++) {
      struct mscclChannelPeerInfo* peer = &algoChannel->sendPeerInfo[s];
      op.nsteps = stepsPerTransmission * transmissionsFor(peer);
      if (op.nsteps <= 0) continue;
      NCCLCHECK(mscclSaveProxy(comm, commChannel, proxySend, peer->peer, &op, 0));
    }
  }

  NCCLCHECK(ncclProxyStart(comm));
  comm->sharedRes->collOpCount++;
  return ncclSuccess;
}
/**
 * Translates a host-side reduction operator into its device-side description.
 *
 * Built-in sum/prod/max/min map directly to a device operator. ncclAvg is expressed as
 * a sum followed by a division by nRanks for integer types, or as a sum of inputs
 * pre-multiplied by 1/nRanks for floating-point types; the scalar is stored in
 * opFull->scalarArg. Any other value is treated as a user-created operator and copied
 * from comm->userRedOps.
 *
 * @param opFull    Output: device operator, scalar argument and pointer flag.
 * @param op        Host reduction operator.
 * @param datatype  Element type of the reduction.
 * @param comm      Communicator; nRanks and userRedOps are read.
 * @return ncclSuccess on success; ncclInvalidArgument if ncclAvg is requested for a
 *         data type not handled here, or if a user-created operator was registered
 *         for a different data type.
 */
static ncclResult_t hostToDevRedOp(
    ncclDevRedOpFull *opFull, ncclRedOp_t op, ncclDataType_t datatype, ncclComm *comm
  ) {
  // Anonymous union used to place a scalar of the requested type into the 64-bit
  // scalar argument.
  union {
    int8_t i8;
    uint8_t u8;
    int32_t i32;
    uint32_t u32;
    int64_t i64;
    uint64_t u64;
    half f16;
    #if defined(RCCL_BFLOAT16)
    rccl_bfloat16 bf16;
    #endif
    float f32;
    double f64;
    void *ptr;
  };
  u64 = 0;
  opFull->scalarArgIsPtr = false;
  switch (int(op)) {
  case ncclSum: opFull->op = ncclDevSum; break;
  case ncclProd: opFull->op = ncclDevProd; break;
  case ncclMax: opFull->op = ncclDevMax; break;
  case ncclMin: opFull->op = ncclDevMin; break;
  case ncclAvg:
    switch ((int)datatype) {
    case ncclInt8: case ncclInt32: case ncclInt64:
    case ncclUint8: case ncclUint32: case ncclUint64:
      opFull->op = ncclDevSumPostDiv;
      u64 = comm->nRanks;
      break;
    case ncclFloat16:
      opFull->op = ncclDevPreMulSum;
      f16 = __float2half(float(1.0/comm->nRanks)); // __double2half not supported pre CUDA 11.x
      break;
    #if defined(RCCL_BFLOAT16)
    case ncclBfloat16:
      opFull->op = ncclDevPreMulSum;
      bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks));
      break;
    #endif
    case ncclFloat32:
      opFull->op = ncclDevPreMulSum;
      f32 = float(1.0/comm->nRanks);
      break;
    case ncclFloat64:
      opFull->op = ncclDevPreMulSum;
      f64 = 1.0/comm->nRanks;
      break;
    default:
      // Without this branch an unhandled data type would leave opFull->op
      // uninitialised while the function still reported success.
      WARN("MSCCL: ncclAvg is not supported for data type %d", (int)datatype);
      return ncclInvalidArgument;
    }
    // scalarArgIsPtr is already false from the initialisation above.
    opFull->scalarArg = u64;
    break;
  default: { // user created
    int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps);
    ncclUserRedOp *user = &comm->userRedOps[ix];
    if (datatype != user->datatype) {
      WARN("Data type supplied to user-created ncclRedOp_t does not match type "
           "given to reduction operation");
      return ncclInvalidArgument;
    }
    *opFull = user->opFull;
    break;
  }
  }
  return ncclSuccess;
}
// Three null table slots, one per protocol, for a (reduction op, data type) pair.
// Not referenced by the other macros below.
#define MSCCL_KERNEL_ENTRY_DEVREDOP_NULL() \
nullptr, \
nullptr, \
nullptr
// The three kernel entry points of one (reduction op, data type) pair, listed in the
// protocol order LL, LL128, Simple.
#define MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, type) \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL), \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128), \
(void *)MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple)
// All entries of one reduction op: ten data types, three protocols each. The order of
// the types here fixes the table layout.
// NOTE(review): mscclSetupKernel indexes the table with the ncclDataType_t value, so
// this order presumably mirrors that enum -- confirm.
#define MSCCL_KERNEL_ENTRY_DEVREDOP(devredop) \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int8_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint8_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int32_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint32_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, int64_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, uint64_t), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, half), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, float), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, double), \
MSCCL_KERNEL_ENTRY_DEVREDOP_TYPE(devredop, rccl_bfloat16)
// The complete table contents: Sum, Prod, Max, Min, in that order.
#define MSCCL_KERNEL_ENTRY() \
MSCCL_KERNEL_ENTRY_DEVREDOP(Sum), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Prod), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Max), \
MSCCL_KERNEL_ENTRY_DEVREDOP(Min)
// Flat kernel lookup table, indexed by mscclSetupKernel as
//   (devRedOp * ncclNumTypes + dataType) * NCCL_NUM_PROTOCOLS + protocol.
// It is sized for every device reduction op except two: ncclDevPreMulSum and
// ncclDevSumPostDiv (the ops required by ncclAvg) have no entries.
void* mscclKernelEntries[(ncclNumDevRedOps - 2) * ncclNumTypes * NCCL_NUM_PROTOCOLS] = {
MSCCL_KERNEL_ENTRY()
};
/**
 * Selects the MSCCL kernel for (reduction op, data type, protocol), fills the work
 * descriptor and launches the kernel on the given stream.
 *
 * When the done event is enabled and the previous launch used a different stream,
 * the stream first waits on comm->doneEvent; the launch then records that event again.
 * On success status.workIndex is incremented and status.lastStream is updated.
 *
 * @param sendBuff  Source buffer handed to the kernel.
 * @param recvBuff  Destination buffer handed to the kernel.
 * @param count     Element count; multiplied by hostAlgo->sizeMultiplier for the kernel.
 * @param dataType  Element type.
 * @param op        Host reduction operator.
 * @param hostAlgo  Host-side algorithm description (nBlocks, protocol, nChunksPerLoop, ...).
 * @param devAlgo   Algorithm pointer passed through to the kernel as an argument.
 * @param comm      Communicator (devComm, doneEvent).
 * @param stream    Stream on which the kernel is launched.
 * @return ncclSuccess on success; ncclInvalidUsage if no kernel exists for the
 *         requested combination; otherwise the error from the failing call.
 */
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
    ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
    ncclComm_t comm, hipStream_t stream) {
  mscclStatus& status = mscclGetStatus();
  bool enableDoneEvent =
#ifndef HIP_EVENT_DISABLE_FENCE
    (rcclParamMscclEnableDoneEvent() == 1);
#else
    true;
#endif

  // Resolve the kernel before touching the stream so that an unsupported request
  // fails without enqueuing anything.
  ncclDevRedOpFull opFull;
  NCCLCHECK(hostToDevRedOp(&opFull, op, dataType, comm));

  // The kernel table has no entries for the last two device reduction ops (the ones
  // used by ncclAvg). Validate every index component so the lookup below cannot read
  // outside the table or pick up another combination's kernel.
  const int devOp = (int)opFull.op;
  const int typeIdx = (int)dataType;
  const int proto = (int)hostAlgo->protocol;
  if (devOp < 0 || devOp >= (int)ncclNumDevRedOps - 2 ||
      typeIdx < 0 || typeIdx >= (int)ncclNumTypes ||
      proto < 0 || proto >= (int)NCCL_NUM_PROTOCOLS) {
    WARN("MSCCL: no kernel for reduction op %d, data type %d, protocol %d", devOp, typeIdx, proto);
    return ncclInvalidUsage;
  }
  void *func = mscclKernelEntries[(devOp * ncclNumTypes + typeIdx) * NCCL_NUM_PROTOCOLS + proto];
  if (func == nullptr) {
    // Table slots may be null; never hand a null function pointer to the launcher.
    WARN("MSCCL: no kernel for reduction op %d, data type %d, protocol %d", devOp, typeIdx, proto);
    return ncclInvalidUsage;
  }

  // Order this launch after the previous one when it ran on a different stream.
  if (enableDoneEvent && (status.lastStream != stream && status.lastStream != nullptr)) {
    CUDACHECK(hipStreamWaitEvent(stream, comm->doneEvent, 0));
  }

  dim3 grid = {(uint32_t)hostAlgo->nBlocks, 1, 1};
  dim3 block = {NCCL_MAX_NTHREADS, 1, 1};

  mscclWork work;
  work.syncFlags = status.syncFlags;
  work.scratchBuffer = status.scratchBuffer;
  work.sendBuff = sendBuff;
  work.recvBuff = recvBuff;
  work.count = count * hostAlgo->sizeMultiplier; // count is sum of all ranks in MSCCL kernel
  work.redOpArg = opFull.scalarArg;
  work.workIndex = status.workIndex;
  work.nChunksPerLoop = hostAlgo->nChunksPerLoop;
  work.maxAllowedCount = status.maxAllowedCount;
  work.hasReduce = hostAlgo->hasReduce;
  work.redOpArgIsPtr = opFull.scalarArgIsPtr;

  void *args[3] = {&comm->devComm, &devAlgo, &work};
  if (enableDoneEvent) {
    CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, comm->doneEvent, 0));
  } else {
    CUDACHECK(hipExtLaunchKernel(func, grid, block, args, 0, stream, NULL, NULL, 0));
  }

  status.workIndex++;
  status.lastStream = stream;
  return ncclSuccess;
}
// Determine the maximum kernel stack size of all MSCCL kernels
size_t mscclKernMaxLocalSize() {
ncclResult_t res = ncclSuccess;
int numMscclKerns = sizeof(mscclKernelEntries)/sizeof(void *);
hipFuncAttributes attr = {0};
size_t max = 0;
for (int i = 0; i < numMscclKerns; i++) {
if (mscclKernelEntries[i] != nullptr) {
CUDACHECKGOTO(hipFuncGetAttributes(&attr, reinterpret_cast<const void*>(mscclKernelEntries[i])), res, error);
if (attr.localSizeBytes > max) max = attr.localSizeBytes;
}
}
error:
return (res != ncclSuccess) ? 0 : max;
}