Comhaid
rocm-systems/src/misc/msccl/msccl_lifecycle.cc
T
corey-derochie-amd 853a0586b4 Moved mscclpp_ncclGetUniqueId call into ncclCommInitRankFunc (#1332)
* Moved call to `mscclpp_ncclGetUniqueId` into `ncclCommInitRankFunc` to avoid setting up transport early in environments where MSCCL++ isn't valid.

* Checking `mscclEnabled` for the process and the topology to gate MSCCL++.

* Allowed `mscclForceEnable` to enable MSCCL++.
2024-09-16 16:41:40 -06:00

692 línte
27 KiB
C++

/*************************************************************************
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
************************************************************************/
#include <atomic>
#include <map>
#include <mutex>
#include <set>
#include <dirent.h>
#include <dlfcn.h>
#include <error.h>
#include <link.h>
#include "alloc.h"
#include "checks.h"
#include "graph/topo.h"
#include "msccl/msccl_lifecycle.h"
#include "msccl/msccl_parser.h"
#include "msccl/msccl_setup.h"
#include "msccl/msccl_status.h"
#ifdef ENABLE_MSCCLPP
#include "mscclpp/mscclpp_nccl.h"
#endif
// Runtime parameters declared through RCCL_PARAM (macro defined elsewhere). Each
// declaration provides the rcclParam<Name>() accessor used below; the string is the
// parameter key and the last argument its default value.
// MscclEnabled (default 1): value reported by mscclEnabled() when MSCCL kernels are compiled in.
RCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 1);
// MscclForceEnabled (default 0): lets the internal scheduler select algorithms even
// when the communicator topology has not enabled MSCCL (see mscclSchedulerSelectAlgo).
RCCL_PARAM(MscclForceEnabled, "MSCCL_FORCE_ENABLE", 0);
// MscclEnableSingleProcess (default 0): skips the one-rank-per-process check in mscclCommCompatible().
RCCL_PARAM(MscclEnableSingleProcess, "MSCCL_ENABLE_SINGLE_PROCESS", 0);
// Name of an environment variable for an algorithm file path.
// NOTE(review): not referenced in the visible part of this file - confirm it is still needed.
static const char* mscclAlgoFilePathEnv = "MSCCL_ALGO_FILE_PATH";
// Reports whether MSCCL is switched on for this process.
// Builds without COMPILE_MSCCL_KERNEL always report false; otherwise the answer
// comes from the MscclEnabled runtime parameter (non-zero means enabled).
bool mscclEnabled() {
#ifndef COMPILE_MSCCL_KERNEL
  return false;
#else
  return rcclParamMscclEnabled() != 0;
#endif
}
// Reports whether MSCCL has been force-enabled through the MscclForceEnabled
// runtime parameter. Builds without COMPILE_MSCCL_KERNEL always report false.
bool mscclForceEnabled() {
#ifndef COMPILE_MSCCL_KERNEL
  return false;
#else
  return rcclParamMscclForceEnabled() != 0;
#endif
}
// Raises the calling thread's "MSCCL is the caller" flag.
void mscclSetIsCallerFlag() {
  auto& tls = mscclGetThreadLocalStatus();
  tls.mscclIsCallerFlag = true;
}
// Lowers the calling thread's "MSCCL is the caller" flag.
void mscclClearIsCallerFlag() {
  auto& tls = mscclGetThreadLocalStatus();
  tls.mscclIsCallerFlag = false;
}
// Returns the current value of the calling thread's "MSCCL is the caller" flag.
bool mscclIsCaller() {
  const auto& tls = mscclGetThreadLocalStatus();
  return tls.mscclIsCallerFlag;
}
// MSCCL is available for `rank` only when it is enabled for the process and
// the per-rank state has been initialized. The initialization state is not
// queried at all when MSCCL is disabled.
bool mscclAvailable(int rank) {
  if (!mscclEnabled()) {
    return false;
  }
  return mscclInitialized(rank);
}
// Decides whether a communicator may use MSCCL.
// Unless the MscclEnableSingleProcess parameter is set, the communicator is
// rejected as soon as two of its ranks report the same (hostHash, pidHash)
// pair, i.e. when more than one rank lives in the same process.
static bool mscclCommCompatible(ncclComm_t comm) {
  // Single process usage was explicitly allowed: nothing to verify.
  if (rcclParamMscclEnableSingleProcess()) {
    return true;
  }

  // Process hashes already seen, grouped by host hash.
  std::map<uint64_t, std::set<uint64_t>> seenPidsPerHost;
  for (int r = 0; r < comm->nRanks; ++r) {
    const uint64_t host = comm->peerInfo[r].hostHash;
    const uint64_t pid = comm->peerInfo[r].pidHash;
    // insert() reports through .second whether the pid was new for this host.
    const bool firstOccurrence = seenPidsPerHost[host].insert(pid).second;
    if (!firstOccurrence) {
      return false;
    }
  }
  return true;
}
#ifdef ENABLE_MSCCLPP
// MSCCL++ compatibility of a communicator; applies exactly the same rule as
// mscclCommCompatible().
bool mscclppCommCompatible(ncclComm_t comm) {
  const bool compatible = mscclCommCompatible(comm);
  return compatible;
}
#endif
// Printable names of the MSCCL function kinds. The table is indexed with the
// function id in log messages (e.g. mscclFuncNames[m.func]).
// NOTE(review): the entry order presumably has to match the mscclFunc_t enum
// declared elsewhere - verify when adding a new function kind.
const char *mscclFuncNames[] = {
"mscclFuncReduce",
"mscclFuncBroadcast",
"mscclFuncAllReduce",
"mscclFuncReduceScatter",
"mscclFuncAllGather",
"mscclFuncSend",
"mscclFuncRecv",
"mscclFuncGather",
"mscclFuncScatter",
"mscclFuncAllToAll",
"mscclFuncAllToAllv",
};
// Environment variable naming an external scheduler library to dlopen.
static const char* mscclSchedulerPathEnv = "MSCCL_SCHEDULER";
// Library name passed to dlopen when MSCCL_SCHEDULER is not set.
static const char* mscclSchedulerDefaultPath = "libmsccl-scheduler.so";
// Environment variable overriding the algorithm directory of the internal scheduler.
static const char* mscclAlgoDirEnv = "MSCCL_ALGO_DIR";
// Default algorithm directory, appended to the directory that contains this library.
static const char* mscclAlgoDefaultDir = "msccl-algorithms";
// Weak symbol: its address is null unless some linked object defines it, which is
// why callers test `mscclUnitTestMode && mscclUnitTestMode()` before using it.
extern "C" bool mscclUnitTestMode() __attribute__((__weak__));
// Directory used instead of mscclAlgoDefaultDir when mscclUnitTestMode() returns true.
static const char* mscclUnitTestAlgoDefaultDir = "msccl-unit-test-algorithms";
// Fallback locations (relative to the directory that contains this library),
// tried when the default directory cannot be opened.
static const char* mscclAlgoShareDirPath = "../share/rccl/msccl-algorithms";
static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-test-algorithms";
// Initializes the internal MSCCL scheduler for one communicator.
//
// The first successful call on a thread scans the algorithm directory, parses
// the metadata of every regular file or symlink with mscclGetAlgoMetaFromXmlFile()
// and stores it in the rank's status.algoMetas. Later calls on the same thread
// reuse the already loaded metadata.
//
// The directory is MSCCL_ALGO_DIR when that variable is set. Otherwise it is
// derived from the location of this library: first the default directory next
// to it, then the "../share/rccl/..." fallback.
//
// numChannelsRequired: on entry the maximum number of channels allowed; on exit
//   the largest channel count needed by any algorithm whose rank count matches
//   the communicator and which fits under that maximum (0 when there is none).
// Returns ncclSuccess, ncclInvalidUsage when the directory cannot be located,
// opened or closed, or the error reported by the metadata parser.
static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
  static thread_local bool mscclAlgoMetaLoaded = false;
  mscclStatus& status = mscclGetStatus(comm->rank);
  const int maxNchannels = *numChannelsRequired;
  *numChannelsRequired = 0;

  // Query numChannelsRequired from loaded algorithm metas
  // NOTE(review): unlike the loading path below, an oversized algorithm is only
  // reported here and stays in status.algoMetas - confirm that is intended.
  if (mscclAlgoMetaLoaded) {
    for (auto& m : status.algoMetas) {
      if (comm->nRanks == m.nRanks) {
        if (m.nChannels <= maxNchannels) {
          *numChannelsRequired = std::max(*numChannelsRequired, m.nChannels);
        } else {
          WARN("NCCL_MAX_NCHANNELS:%d is lesser than number of channels required by MSCCL:%d, so disabling MSCCL for %s between minBytes:%ld and maxBytes:%ld from file: %s",
               maxNchannels, m.nChannels, mscclFuncNames[m.func], m.minBytes, m.maxBytes, m.filePath.c_str());
        }
      }
    }
    return ncclSuccess;
  }

  const char* mscclAlgoDir = getenv(mscclAlgoDirEnv);
  const char* mscclAlgoShareDir = nullptr;
  std::string mscclAlgoDirStr;
  std::string mscclAlgoShareDirStr;
  if (mscclAlgoDir == nullptr) {
    // Try to find default algorithm directory based on librccl.so path
    Dl_info dl_info;
    struct link_map* link_map_ptr = nullptr;
    if (!dladdr1((void*)mscclInternalSchedulerInit, &dl_info, (void**)&link_map_ptr, RTLD_DL_LINKMAP)) {
      WARN("MSCCL Internal Scheduler: dladdr1 failed");
      return ncclInvalidUsage;
    }
    const std::string selfLibPath = link_map_ptr->l_name;
    const std::string selfLibDir = selfLibPath.substr(0, selfLibPath.find_last_of("/\\") + 1);
    const bool unitTestMode = mscclUnitTestMode && mscclUnitTestMode();
    mscclAlgoDirStr = selfLibDir + (unitTestMode ? mscclUnitTestAlgoDefaultDir : mscclAlgoDefaultDir);
    mscclAlgoDir = mscclAlgoDirStr.c_str();
    // Get share Directory Paths
    mscclAlgoShareDirStr = selfLibDir + (unitTestMode ? mscclUnitTestAlgoShareDirPath : mscclAlgoShareDirPath);
    mscclAlgoShareDir = mscclAlgoShareDirStr.c_str();
  }

  const char* fullDirPath = mscclAlgoDir;
  DIR* dp = opendir(mscclAlgoDir);
  // The share-folder fallback only exists when the path was derived from the
  // library location; with MSCCL_ALGO_DIR set there is no second candidate and
  // opendir() must not be handed a null pointer.
  if (dp == nullptr && mscclAlgoShareDir != nullptr) {
    //Try to find the algorithm directory under share folder based on librccl.so path
    fullDirPath = mscclAlgoShareDir;
    dp = opendir(mscclAlgoShareDir);
  }
  if (dp == nullptr) {
    WARN("MSCCL Internal Scheduler: Unable to find algorithm files in %s. Please ensure that RCCL has been installed properly" , fullDirPath);
    return ncclInvalidUsage;
  }
  INFO(NCCL_INIT, "Using MSCCL files from %s", fullDirPath);

  // Closes the directory on every early return (e.g. a parser error raised
  // through NCCLCHECK) so the DIR handle is never leaked.
  struct DirCloser {
    DIR* dir;
    ~DirCloser() {
      if (dir != nullptr) {
        closedir(dir);
      }
    }
  } dirCloser{dp};

  struct dirent* entry = nullptr;
  while ((entry = readdir(dp)) != nullptr) {
    if (entry->d_type != DT_LNK && entry->d_type != DT_REG) {
      continue;
    }
    status.algoMetas.emplace_back();
    std::string fullPath = fullDirPath;
    fullPath += "/";
    fullPath += entry->d_name;
    NCCLCHECK(mscclGetAlgoMetaFromXmlFile(fullPath.c_str(), &(status.algoMetas.back())));
    auto& meta = status.algoMetas.back();
    if (meta.nRanks != comm->nRanks) {
      continue;
    }
    if (meta.nChannels <= maxNchannels) {
      *numChannelsRequired = std::max(*numChannelsRequired, meta.nChannels);
    } else {
      WARN("NCCL_MAX_NCHANNELS:%d is lesser than number of channels required by MSCCL:%d, so disabling MSCCL for %s between minBytes:%ld and maxBytes:%ld from file: %s",
           maxNchannels, meta.nChannels, mscclFuncNames[meta.func], meta.minBytes, meta.maxBytes,
           meta.filePath.c_str());
      // Drop the algorithm: it cannot run with the available channels.
      status.algoMetas.pop_back();
    }
  }

  // Normal path: close explicitly so a closedir() failure can be reported.
  dirCloser.dir = nullptr;
  if (closedir(dp)) {
    WARN("MSCCL Internal Scheduler: closedir failed, error %d", errno);
    return ncclInvalidUsage;
  }
  status.rankToAlgoHandles.resize(status.algoMetas.size());
  mscclAlgoMetaLoaded = true;
  return ncclSuccess;
}
// Selects and initializes the MSCCL scheduler for a communicator.
//
// Records in comm->mscclCompatible whether the communicator can use MSCCL and
// returns immediately when it cannot. Otherwise an external scheduler library
// is loaded from MSCCL_SCHEDULER (or the default library name). When the
// library or its "mscclScheduler" symbol is missing, the internal scheduler is
// used instead.
//
// numChannelsRequired: passed through to the internal scheduler (maximum in,
//   requirement out); set to MAXCHANNELS when an external scheduler is used.
// Returns ncclSuccess or the error reported by the chosen scheduler's init.
ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
  comm->mscclCompatible = mscclCommCompatible(comm);
  if (!comm->mscclCompatible) {
    return ncclSuccess;
  }

  mscclStatus& status = mscclGetStatus(comm->rank);
  bool useInternalScheduler = false;

  const char* mscclSchedulerPath = getenv(mscclSchedulerPathEnv);
  if (mscclSchedulerPath) {
    status.mscclSchedulerLib = dlopen(mscclSchedulerPath, RTLD_NOW | RTLD_LOCAL);
  } else {
    status.mscclSchedulerLib = dlopen(mscclSchedulerDefaultPath, RTLD_NOW | RTLD_LOCAL);
  }

  if (status.mscclSchedulerLib == nullptr) {
    INFO(NCCL_INIT, "MSCCL: No external scheduler found, using internal implementation");
    useInternalScheduler = true;
  } else {
    status.mscclSchedulerPtr = (mscclSchedulerInterface *)dlsym(status.mscclSchedulerLib, "mscclScheduler");
    if (status.mscclSchedulerPtr == nullptr) {
      INFO(NCCL_INIT, "MSCCL: Failed to find mscclScheduler symbol, using internal implementation");
      // The library is useless without the symbol. Release it here, because
      // teardown only closes the library when mscclSchedulerPtr is set.
      dlclose(status.mscclSchedulerLib);
      status.mscclSchedulerLib = nullptr;
      useInternalScheduler = true;
    }
  }

  if (useInternalScheduler) {
    NCCLCHECK(mscclInternalSchedulerInit(comm, numChannelsRequired));
  } else {
    NCCLCHECK(status.mscclSchedulerPtr->init());
    *numChannelsRequired = MAXCHANNELS;
  }
  return ncclSuccess;
}
// Initializes the per-rank MSCCL state for a communicator.
// The algorithm pre-processing below runs on every call for a compatible
// communicator that uses the internal scheduler; the one-time resources
// (work index, sync flags, default work FIFO) are only created while the rank
// is not yet marked as initialized.
// Returns ncclSuccess, or the first error raised through NCCLCHECK.
ncclResult_t mscclInit(ncclComm_t comm) {
{
mscclStatus& status = mscclGetStatus(comm->rank);
// freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection
if (!mscclInitialized(comm->rank)) {
status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS);
// Fill with handles in descending order: MSCCL_MAX_NUM_ALGOS - 1 first, 0 last.
for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) {
status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1;
}
status.needsProxy = false;
}
// Pre-process all algorithms for internal scheduler and for different comms.
// This is a temp fix to bypass the issue that stream cannot be synchronized during HIP graph capturing,
// should use dynamic loading approach after the issue is fixed.
// A null mscclSchedulerPtr means no external scheduler is in use.
if (comm->mscclCompatible && !status.mscclSchedulerPtr) {
for (size_t i = 0; i < status.algoMetas.size(); i++) {
auto &m = status.algoMetas[i];
if (m.nRanks == comm->nRanks) {
// Load algorithms
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) {
// One process-wide mutex serializes the calls to mscclLoadAlgo.
static std::mutex loadAlgoMutex;
std::lock_guard<std::mutex> lock(loadAlgoMutex);
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank));
}
// Connect algorithms
mscclAlgoHandle_t mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank];
// Connections are set up at most once per (communicator, algorithm) pair.
if (status.connectedAlgos[comm].find(mscclAlgoHandle) == status.connectedAlgos[comm].end()) {
NCCLCHECK(mscclSetupConnections(status.hostAlgos[mscclAlgoHandle], comm));
status.connectedAlgos[comm].insert(mscclAlgoHandle);
}
}
}
}
// Everything below is one-time setup; skip it when the rank is already initialized.
if (mscclInitialized(comm->rank)) {
return ncclSuccess;
}
status.workIndex = 1;
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
status.lastStream = nullptr;
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
mscclSetInitialized(comm->rank);
}
// Only reached on the call that performed the one-time setup.
INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize());
return ncclSuccess;
}
// Opens one nesting level of an MSCCL group on the calling thread.
// When no group is currently tracked, the group starts out as "supported";
// an already tracked group keeps its current status.
ncclResult_t mscclGroupStart() {
  auto& tls = mscclGetThreadLocalStatus();
  ++tls.groupDepth;
  const bool noGroupTracked = (tls.groupStatus == mscclNoGroup);
  if (noGroupTracked) {
    tls.groupStatus = mscclGroupSupportedOp;
  }
  return ncclSuccess;
}
// Internal scheduler: picks the first loaded algorithm that fits the operation
// described by `param`.
//
// On return param->scheduled tells whether an algorithm was found; when it is
// true, param->handle holds the handle loaded for param->rank.
// Always returns ncclSuccess.
static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) {
  mscclStatus& status = mscclGetStatus(rank);
  param->scheduled = false;

  // Current MSCCL doesn't support pre/post op
  if (param->op >= ncclAvg) {
    return ncclSuccess;
  }

  // Whether the operation is in-place. The rule depends on the function kind:
  // same pointer, or a pointer at this rank's slot inside the larger buffer.
  bool isInPlace = false;
  switch (param->func) {
    case mscclFuncReduce:
    case mscclFuncBroadcast:
    case mscclFuncAllReduce:
    case mscclFuncAllToAll:
    case mscclFuncAllToAllv:
      isInPlace = param->sendBuff == param->recvBuff;
      break;
    case mscclFuncAllGather:
    case mscclFuncGather:
      isInPlace = (char*)param->sendBuff == (char*)param->recvBuff + param->rank * param->count * ncclTypeSize(param->dataType);
      break;
    case mscclFuncReduceScatter:
    case mscclFuncScatter:
      isInPlace = (char*)param->recvBuff == (char*)param->sendBuff + param->rank * param->count * ncclTypeSize(param->dataType);
      break;
    default:
      break;
  }

  // Search suitable algorithms; the first match wins.
  for (size_t idx = 0; idx < status.algoMetas.size(); idx++) {
    auto& meta = status.algoMetas[idx];
    size_t nBytes = param->count * ncclTypeSize(param->dataType) * meta.sizeMultiplier;

    // The element count must be a non-zero multiple of the chunks per loop and
    // the byte size must lie in [minBytes, maxBytes] (maxBytes == 0: no upper bound).
    const bool sizeFits =
      param->count > 0 && (param->count % meta.nChunksPerLoop) == 0 &&
      nBytes >= meta.minBytes && (meta.maxBytes == 0 || nBytes <= meta.maxBytes);
    if (!sizeFits) {
      continue;
    }
    if (!(meta.nRanks == param->nRanks)) {
      continue;
    }
    if (!(meta.func == param->func)) {
      continue;
    }
    const bool placementSupported = isInPlace ? meta.inPlace : meta.outOfPlace;
    if (!placementSupported) {
      continue;
    }

    param->handle = status.rankToAlgoHandles[idx][param->rank];
    param->scheduled = true;
    return ncclSuccess;
  }
  return ncclSuccess;
}
// Asks the active scheduler to select an algorithm for a saved operation.
// An external scheduler, when loaded, always gets the request. The internal
// scheduler is only consulted when the topology has MSCCL enabled or MSCCL is
// force-enabled; otherwise the operation is marked as not scheduled.
static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) {
  mscclStatus& status = mscclGetStatus(param->comm->rank);

  if (status.mscclSchedulerPtr) {
    NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p)));
    return ncclSuccess;
  }

  // Disable MSCCL algorithms if machine type is not matching
  const bool internalSchedulerAllowed = param->comm->topo->mscclEnabled || mscclForceEnabled();
  if (!internalSchedulerAllowed) {
    param->p.scheduled = false;
    return ncclSuccess;
  }

  NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p)));
  return ncclSuccess;
}
// Copies the arguments of one collective / point-to-point call into `param`.
// The scheduler-visible part goes into param->p, including the communicator's
// rank, rank count and current opCount; the communicator and stream are kept
// alongside. Pointers are stored as given; nothing is deep-copied here.
// Always returns ncclSuccess.
static ncclResult_t mscclSetSavedSchedulerParam(
    const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
    void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
    size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
    mscclFunc_t func, ncclComm_t comm, hipStream_t stream,
    struct mscclSavedSchedulerParam* param) {
  auto& sched = param->p;

  // Send side.
  sched.sendBuff = sendBuff;
  sched.sendCounts = sendCounts;
  sched.sDisPls = sDisPls;

  // Receive side.
  sched.recvBuff = recvBuff;
  sched.recvCounts = recvCounts;
  sched.rDisPls = rDisPls;

  // Description of the operation.
  sched.count = count;
  sched.dataType = dataType;
  sched.root = root;
  sched.peer = peer;
  sched.op = op;
  sched.func = func;

  // Communicator-derived values and execution context.
  sched.rank = comm->rank;
  sched.nRanks = comm->nRanks;
  param->comm = comm;
  param->stream = stream;
  sched.opCount = comm->opCount;

  return ncclSuccess;
}
// Makes private copies of the per-rank count and displacement arrays of a
// saved operation and repoints param->p at those copies, so the arrays are
// owned by `param` from then on.
// Each array holds param->p.nRanks entries and is copied only when its pointer
// is non-null; a null array is left untouched instead of being dereferenced.
// Always returns ncclSuccess.
static ncclResult_t mscclSaveCountsAndDispls(struct mscclSavedSchedulerParam* param) {
  if (param->p.sendCounts) {
    param->savedSendCounts.assign(param->p.sendCounts, param->p.sendCounts + param->p.nRanks);
    param->p.sendCounts = param->savedSendCounts.data();
  }
  if (param->p.sDisPls) {
    param->savedSDisPls.assign(param->p.sDisPls, param->p.sDisPls + param->p.nRanks);
    param->p.sDisPls = param->savedSDisPls.data();
  }
  if (param->p.recvCounts) {
    param->savedRecvCounts.assign(param->p.recvCounts, param->p.recvCounts + param->p.nRanks);
    param->p.recvCounts = param->savedRecvCounts.data();
  }
  if (param->p.rDisPls) {
    param->savedRDisPls.assign(param->p.rDisPls, param->p.rDisPls + param->p.nRanks);
    param->p.rDisPls = param->savedRDisPls.data();
  }
  return ncclSuccess;
}
// Runs every operation saved on the calling thread through its selected MSCCL
// algorithm, in the order the operations were saved.
// The saved list is emptied on every exit path, including an early error
// return, so a failed operation is never replayed by a later call.
// Returns ncclSuccess or the first error reported by mscclRunAlgo.
static ncclResult_t mscclRunSavedParams() {
  mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();

  // Clears the saved list when the function returns, whatever the path.
  struct SavedParamsClearer {
    mscclThreadLocalStatus& status;
    ~SavedParamsClearer() { status.savedSchedulerParams.clear(); }
  } savedParamsClearer{threadLocalStatus};

  for (auto& param : threadLocalStatus.savedSchedulerParams) {
    INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d",
      mscclFuncNames[param.p.func], param.p.opCount, param.p.sendBuff, param.p.recvBuff, param.p.count,
      param.p.dataType, param.p.op, param.p.root, param.comm, param.p.nRanks, param.stream, param.comm->tasks.nTasksP2p + param.comm->tasks.nTasksColl, param.comm->localRankToRank[param.comm->localRank]);
    NCCLCHECK(mscclRunAlgo(
      param.p.sendBuff, param.p.sendCounts, param.p.sDisPls,
      param.p.recvBuff, param.p.recvCounts, param.p.rDisPls,
      param.p.count, param.p.dataType, param.p.root, param.p.peer, param.p.op, param.p.handle, param.comm, param.stream));
  }
  return ncclSuccess;
}
// Replays every operation saved on the calling thread through the regular
// nccl* entry points instead of MSCCL.
// The thread's "MSCCL is the caller" flag is raised for the duration of the
// replay. It is lowered again, and the saved list emptied, on every exit path,
// so an error in one operation cannot leave the flag set or leave stale
// operations behind.
// Returns ncclSuccess, the first error reported by a replayed call, or
// ncclInvalidUsage for an unknown function kind.
static ncclResult_t mscclFallBackSavedParams() {
  mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
  mscclSetIsCallerFlag();

  // Undoes the flag and drops the saved operations when the function returns.
  struct FallbackScope {
    mscclThreadLocalStatus& status;
    ~FallbackScope() {
      mscclClearIsCallerFlag();
      status.savedSchedulerParams.clear();
    }
  } fallbackScope{threadLocalStatus};

  for (auto& param : threadLocalStatus.savedSchedulerParams) {
    switch (param.p.func) {
      case mscclFuncReduce:
        NCCLCHECK(ncclReduce(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.op, param.p.root, param.comm, param.stream));
        break;
      case mscclFuncBroadcast:
        NCCLCHECK(ncclBroadcast(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.root, param.comm, param.stream));
        break;
      case mscclFuncAllReduce:
        NCCLCHECK(ncclAllReduce(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.op, param.comm, param.stream));
        break;
      case mscclFuncReduceScatter:
        NCCLCHECK(ncclReduceScatter(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.op, param.comm, param.stream));
        break;
      case mscclFuncAllGather:
        NCCLCHECK(ncclAllGather(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.comm, param.stream));
        break;
      case mscclFuncSend:
        NCCLCHECK(ncclSend(param.p.sendBuff, param.p.count, param.p.dataType,
          param.p.peer, param.comm, param.stream));
        break;
      case mscclFuncRecv:
        NCCLCHECK(ncclRecv(param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.peer, param.comm, param.stream));
        break;
      case mscclFuncGather:
        NCCLCHECK(ncclGather(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.root, param.comm, param.stream));
        break;
      case mscclFuncScatter:
        NCCLCHECK(ncclScatter(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.p.root, param.comm, param.stream));
        break;
      case mscclFuncAllToAll:
        NCCLCHECK(ncclAllToAll(param.p.sendBuff, param.p.recvBuff, param.p.count, param.p.dataType,
          param.comm, param.stream));
        break;
      case mscclFuncAllToAllv:
        NCCLCHECK(ncclAllToAllv(
          param.p.sendBuff, param.p.sendCounts, param.p.sDisPls,
          param.p.recvBuff, param.p.recvCounts, param.p.rDisPls,
          param.p.dataType, param.comm, param.stream));
        break;
      default:
        WARN("Invalid MSCCL function type in saved parameter");
        return ncclInvalidUsage;
    }
  }
  return ncclSuccess;
}
#ifdef ENABLE_MSCCLPP
// Tells whether an all-reduce with this data type and reduction operator may be
// handed to MSCCL++: the type must be one of the listed types (bfloat16 only
// when RCCL_BFLOAT16 is defined) and the operator must be ncclSum.
static inline bool isMscclppAllReduceSupported(ncclDataType_t dataType, ncclRedOp_t op) {
  bool typeSupported = false;
  switch (dataType) {
    case ncclFloat16:
    case ncclInt32:
    case ncclUint32:
    case ncclFloat32:
#ifdef RCCL_BFLOAT16
    case ncclBfloat16:
#endif
      typeSupported = true;
      break;
    default:
      break;
  }
  if (!typeSupported) {
    return false;
  }
  return (op == ncclSum);
}
#endif
#ifdef ENABLE_MSCCLPP
// Tries to execute the operation that was just saved through MSCCL++.
//
// MSCCL++ is used only when all of the following hold: the communicator is
// MSCCL++ compatible, the thread's capture status is not mscclNoCapture, the
// communicator is MSCCL compatible, the message size is a non-zero multiple of
// 32 bytes, neither buffer is managed memory, and the operation is a supported
// all-reduce or an all-gather within comm->mscclpp_threshold.
//
// On success with *handled == true the operation has been issued and its entry
// (the last one) has been removed from the thread's saved list. Only that
// entry is removed, so operations saved earlier in an open group are kept.
// Returns ncclSuccess or the first error raised by a check or MSCCL++ call.
static ncclResult_t mscclppTryEnqueue(
    mscclThreadLocalStatus& threadLocalStatus,
    const void* sendBuff, void* recvBuff, size_t count, size_t nBytes,
    ncclDataType_t dataType, int root, ncclRedOp_t op, mscclFunc_t func,
    ncclComm_t comm, hipStream_t stream, bool* handled) {
  *handled = false;
  if (!comm->mscclppCompatible) {
    return ncclSuccess;
  }
  if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
    INFO(NCCL_COLL, "MSCCL++: reading capture status");
    NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
  }
  /* check if one rank per GPU and graph mode is enabled */
  if (!((threadLocalStatus.captureStatus != mscclNoCapture) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0)) {
    return ncclSuccess;
  }

  bool isManagedBuffer = false;
  if (sendBuff) {
    CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
  }
  if (!isManagedBuffer && recvBuff) {
    CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
  }
  if (isManagedBuffer) {
    /* MSCCL++ not enabled for managed memory buffers */
    return ncclSuccess;
  }

  if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
    INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
      "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
    NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
    threadLocalStatus.savedSchedulerParams.pop_back();
    *handled = true;
  } else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
    INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
      "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream);
    NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
    threadLocalStatus.savedSchedulerParams.pop_back();
    *handled = true;
  }
  return ncclSuccess;
}
#endif

// Entry point for every operation that may be served by MSCCL.
//
// The operation is appended to the calling thread's saved list and then
// handled according to the thread's group status:
//  - mscclNoGroup: run immediately through MSCCL++ (if built in and
//    applicable), else through a selected MSCCL algorithm, else fall back to
//    the regular nccl* call.
//  - mscclGroupSupportedOp: same MSCCL++ attempt; otherwise, when an MSCCL
//    algorithm is selected, keep the operation saved (with private copies of
//    its count/displacement arrays) until mscclGroupEnd. If none is selected,
//    mark the group unsupported and replay everything saved so far through the
//    regular nccl* calls.
//  - mscclGroupUnsupportedOp: replay through the regular nccl* calls right away.
// Returns ncclSuccess, ncclInvalidUsage for an unknown group status, or the
// first error raised by a nested call.
ncclResult_t mscclEnqueueCheck(
    const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
    void* recvBuff, const size_t recvCounts[], const size_t rDisPls[],
    size_t count, ncclDataType_t dataType, int root, int peer, ncclRedOp_t op,
    mscclFunc_t func, ncclComm_t comm, hipStream_t stream) {
  mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();

  threadLocalStatus.savedSchedulerParams.push_back({});
  NCCLCHECK(mscclSetSavedSchedulerParam(
    sendBuff, sendCounts, sDisPls, recvBuff, recvCounts, rDisPls,
    count, dataType, root, peer, op, func, comm, stream,
    &threadLocalStatus.savedSchedulerParams.back()));

#ifdef ENABLE_MSCCLPP
  const size_t nBytes = count * ncclTypeSize(dataType);
  bool handledByMscclpp = false;
#endif

  switch (threadLocalStatus.groupStatus) {
    case mscclNoGroup:
#ifdef ENABLE_MSCCLPP
      NCCLCHECK(mscclppTryEnqueue(threadLocalStatus, sendBuff, recvBuff, count, nBytes,
        dataType, root, op, func, comm, stream, &handledByMscclpp));
      if (handledByMscclpp) {
        break;
      }
#endif
      if (comm->mscclCompatible) {
        NCCLCHECK(mscclSchedulerSelectAlgo(&threadLocalStatus.savedSchedulerParams.back()));
        if (threadLocalStatus.savedSchedulerParams.back().p.scheduled) {
          NCCLCHECK(mscclRunSavedParams());
          break;
        }
      }
      NCCLCHECK(mscclFallBackSavedParams());
      break;
    case mscclGroupSupportedOp:
#ifdef ENABLE_MSCCLPP
      NCCLCHECK(mscclppTryEnqueue(threadLocalStatus, sendBuff, recvBuff, count, nBytes,
        dataType, root, op, func, comm, stream, &handledByMscclpp));
      if (handledByMscclpp) {
        break;
      }
#endif
      if (comm->mscclCompatible) {
        NCCLCHECK(mscclSchedulerSelectAlgo(&threadLocalStatus.savedSchedulerParams.back()));
        if (threadLocalStatus.savedSchedulerParams.back().p.scheduled) {
          // Only save counts and displs when there is suitable MSCCL algorithm for this
          NCCLCHECK(mscclSaveCountsAndDispls(&threadLocalStatus.savedSchedulerParams.back()));
          break;
        }
      }
      // No MSCCL algorithm for this operation: the whole group goes through the
      // regular path from now on, starting with everything saved so far.
      threadLocalStatus.groupStatus = mscclGroupUnsupportedOp;
      NCCLCHECK(mscclFallBackSavedParams());
      break;
    case mscclGroupUnsupportedOp:
      NCCLCHECK(mscclFallBackSavedParams());
      break;
    default:
      return ncclInvalidUsage;
  }
  return ncclSuccess;
}
// Closes one nesting level of an MSCCL group on the calling thread.
// When the outermost level is closed, the operations saved for a "supported"
// group are executed and the thread goes back to the no-group state. The state
// is reset even if executing the saved operations fails, so a failure cannot
// leave the thread in group mode with a depth of zero.
// Returns ncclSuccess or the error reported by mscclRunSavedParams.
ncclResult_t mscclGroupEnd() {
  mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
  threadLocalStatus.groupDepth--;
  if (threadLocalStatus.groupDepth == 0) {
    ncclResult_t runResult = ncclSuccess;
    if (threadLocalStatus.groupStatus == mscclGroupSupportedOp) {
      runResult = mscclRunSavedParams();
    }
    // Reset before reporting any error from the run above.
    threadLocalStatus.groupStatus = mscclNoGroup;
    NCCLCHECK(runResult);
  }
  return ncclSuccess;
}
// Releases one loaded algorithm of `rank`: frees its host copy, frees its
// device copy, returns the handle to the free list and removes it from every
// communicator's set of connected algorithms.
// Returns ncclSuccess, or the error from ncclCudaFree (in which case the
// device entry, free list and connected sets are left untouched).
static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) {
  mscclStatus& status = mscclGetStatus(rank);

  // Host-side copy.
  auto& hostAlgos = status.hostAlgos;
  free(hostAlgos[mscclAlgoHandle]);
  hostAlgos.erase(mscclAlgoHandle);

  // Device-side copy.
  auto& devAlgos = status.devAlgos;
  NCCLCHECK(ncclCudaFree(devAlgos[mscclAlgoHandle]));
  devAlgos.erase(mscclAlgoHandle);

  // The handle can be handed out again.
  status.freeAlgoHandles.push_back(mscclAlgoHandle);

  // Forget the algorithm in every per-communicator connection set.
  for (auto it = status.connectedAlgos.begin(); it != status.connectedAlgos.end(); ++it) {
    it->second.erase(mscclAlgoHandle);
  }
  return ncclSuccess;
}
// Tears down the internal scheduler state of `rank`: unloads every algorithm
// handle recorded in rankToAlgoHandles, then drops the algorithm metadata and
// the handle tables.
// Unloading continues after a failure; the first error encountered is returned
// (ncclSuccess when every unload succeeded).
static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
  mscclStatus& status = mscclGetStatus(rank);
  ncclResult_t firstError = ncclSuccess;

  for (auto& handlesOfAlgo : status.rankToAlgoHandles) {
    for (auto& rankAndHandle : handlesOfAlgo) {
      const ncclResult_t unloadResult = mscclInternalUnloadAlgo(rank, rankAndHandle.second);
      if (firstError == ncclSuccess) {
        firstError = unloadResult;
      }
    }
  }

  status.algoMetas.clear();
  status.rankToAlgoHandles.clear();
  return firstError;
}
// Releases all MSCCL resources held for `rank` and removes the rank's state.
// A rank that was never initialized is simply removed.
// Returns ncclSuccess, or the first error raised through CUDACHECK / NCCLCHECK.
// NOTE(review): those macros return early, so a failing free leaves the
// remaining resources unreleased and the rank still registered - confirm that
// this is acceptable.
ncclResult_t mscclTeardown(int rank) {
{
if (!mscclInitialized(rank)) {
mscclRemoveRank(rank);
return ncclSuccess;
}
mscclStatus& status = mscclGetStatus(rank);
// Free the host copies of all loaded algorithms.
for (auto &p : status.hostAlgos) {
free(p.second);
status.freeAlgoHandles.push_back(p.first);
}
// Free the device copies of all loaded algorithms.
for (auto &p : status.devAlgos) {
CUDACHECK(hipFree(p.second));
}
CUDACHECK(hipFree(status.syncFlags));
// The maps are emptied here, before the scheduler teardown below. The internal
// teardown looks the handles up again through mscclInternalUnloadAlgo, where
// they are then no longer present in these maps.
status.hostAlgos.clear();
status.devAlgos.clear();
status.freeAlgoHandles.clear();
for (auto &p : status.scratchBuffers) {
CUDACHECK(hipFree(p.second));
}
status.scratchBuffers.clear();
status.connectedAlgos.clear();
// External scheduler: tear it down and close its library. Otherwise tear down
// the internal scheduler.
if (status.mscclSchedulerPtr) {
NCCLCHECK(status.mscclSchedulerPtr->teardown());
status.mscclSchedulerPtr = nullptr;
dlclose(status.mscclSchedulerLib);
status.mscclSchedulerLib = nullptr;
} else {
NCCLCHECK(mscclInternalSchedulerTeardown(rank));
}
// Destroy the default work FIFO status and every per-graph one.
NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus)));
for (auto &p : status.graphWorkFifoStatus) {
NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second)));
}
mscclSetInitialized(rank, false);
mscclRemoveRank(rank);
}
INFO(NCCL_INIT, "MSCCL: Teardown finished");
return ncclSuccess;
}