SPLITCOMM design fix in src/misc/msccl (#1715)

* Fix TOC-TOU in mcclInit

* Improving vector resize thread safety

* Initial commit rank to comm change

* Removing unwanted include header changes

* Updated CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>

---------

Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>

[ROCm/rccl commit: e94b360246]
Этот коммит содержится в:
Avinash
2025-06-01 21:00:38 -05:00
коммит произвёл GitHub
родитель 4277b5aa88
Коммит a50ff2c3d3
13 изменённых файлов: 134 добавлений и 111 удалений
+3 -1
Просмотреть файл
@@ -8,6 +8,8 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
* Resolved an issue when using more than 64 channels when multiple collectives are used in the same `ncclGroup()` call.
* Fixed unit test failures in tests ending with `ManagedMem` and `ManagedMemGraph` suffixes.
* Fixed the known issue "When splitting a communicator using `ncclCommSplit` in some GPU configurations, MSCCL initialization can cause a segmentation fault." with a design change to use `comm` instead of `rank` for `mscclStatus`. The Global map for `comm` to `mscclStatus` is still not thread safe but should be explicitly handled by mutexes for read writes. This is tested for correctness, but there is a plan to use a thread-safe map data structure in upcoming changes.
### Added
@@ -304,4 +306,4 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
### Changed
- Switched to hip-clang as default compiler
### Deprecated
- Deprecated hcc build
- Deprecated hcc build
+11 -11
Просмотреть файл
@@ -96,7 +96,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
NCCLCHECK(Recorder::instance().record(rrAllGather, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
@@ -124,7 +124,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
NCCLCHECK(Recorder::instance().record(rrAllReduce, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
@@ -149,7 +149,7 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
NCCLCHECK(Recorder::instance().record(rrAllToAll, sendbuff, recvbuff, count, datatype, comm, stream));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
@@ -195,7 +195,7 @@ ncclResult_t ncclAllToAllv_impl(const void *sendbuff, const size_t sendcounts[],
NCCLCHECK(Recorder::instance().record(rrAllToAllv, sendbuff, recvbuff, 0, datatype, comm, stream, -1, sendcounts, sdispls, recvcounts, rdispls));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream);
@@ -241,7 +241,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou
NCCLCHECK(Recorder::instance().record(rrBroadcast, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
@@ -271,7 +271,7 @@ ncclResult_t ncclGather_impl(const void* sendbuff, void* recvbuff, size_t sendco
NCCLCHECK(Recorder::instance().record(rrGather, sendbuff, recvbuff, sendcount, datatype, comm, stream, root));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream);
@@ -310,7 +310,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count,
NCCLCHECK(Recorder::instance().record(rrReduce, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
@@ -337,7 +337,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
NCCLCHECK(Recorder::instance().record(rrReduceScatter, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
@@ -360,7 +360,7 @@ ncclResult_t ncclScatter_impl(const void* sendbuff, void* recvbuff, size_t recvc
NCCLCHECK(Recorder::instance().record(rrScatter, sendbuff, recvbuff, recvcount, datatype, comm, stream, root));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream);
@@ -400,7 +400,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da
NCCLCHECK(Recorder::instance().record(rrSend, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
@@ -426,7 +426,7 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype
NCCLCHECK(Recorder::instance().record(rrRecv, info));
}
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
if (mscclAvailable(comm) && !mscclIsCaller()) {
return mscclEnqueueCheck(
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
+1 -1
Просмотреть файл
@@ -139,7 +139,7 @@ typedef ncclResult_t (*ncclMemAlloc_fn_t)(void** ptr, size_t size);
typedef ncclResult_t (*ncclMemFree_fn_t)(void* ptr);
typedef ncclResult_t (*mscclLoadAlgo_fn_t)(const char* mscclAlgoFilePath,
mscclAlgoHandle_t* mscclAlgoHandle, int rank);
mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm);
typedef ncclResult_t (*mscclRunAlgo_fn_t)(
const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
+10 -2
Просмотреть файл
@@ -17,7 +17,15 @@ void mscclSetIsCallerFlag();
void mscclClearIsCallerFlag();
bool mscclIsCaller();
bool mscclAvailable(int rank = -1);
/**
* @brief mscclAvailable() is used to determine if msccl functionality is avaliable
* @param comm is an optional rccl communicator, if provided uses the mscclStatus
* from a global map<comm -> mscclStatus> to determine if msccl is available. If not available
* in the map, this invocations inserts a new key value pair in the global map.
* If comm == nullptr, on the first invocation it initializes a static thread local variable
* mscclStatus and uses the same object in subsequent calls from same thread if comm is null ptr
*/
bool mscclAvailable(const ncclComm_t comm = nullptr);
ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired);
@@ -33,7 +41,7 @@ ncclResult_t mscclEnqueueCheck(
ncclResult_t mscclGroupEnd();
ncclResult_t mscclTeardown(int rank);
ncclResult_t mscclTeardown(const ncclComm_t comm);
size_t mscclKernMaxLocalSize();
+3 -3
Просмотреть файл
@@ -11,13 +11,13 @@
#include "comm.h"
#include "msccl/msccl_struct.h"
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream);
ncclResult_t mscclGetCaptureStatus(const ncclComm_t comm, hipStream_t stream);
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream);
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream);
ncclResult_t mscclSetupSyncFlags(const ncclComm_t comm, hipStream_t stream);
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm);
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo,const ncclComm_t comm);
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType);
+5 -5
Просмотреть файл
@@ -8,15 +8,15 @@
#include "msccl/msccl_struct.h"
bool mscclInitialized(int rank);
bool mscclInitialized(const ncclComm_t comm);
void mscclSetInitialized(int rank, bool initialized = true);
void mscclSetInitialized(const ncclComm_t comm, bool initialized = true);
void mscclRemoveRank(int rank);
void mscclRemoveRank(const ncclComm_t comm);
mscclStatus& mscclGetStatus(int rank);
mscclStatus& mscclGetStatus(const ncclComm_t comm);
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank);
mscclSavedProxyArgs& mscclGetSavedProxyArgs(const ncclComm_t comm);
mscclThreadLocalStatus& mscclGetThreadLocalStatus();
+2 -2
Просмотреть файл
@@ -1745,7 +1745,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
NCCLCHECK(mscclInit(comm));
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
status.needsProxy |= mscclNeedsProxy;
}
@@ -2530,7 +2530,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) {
NCCLCHECK(ncclTunerPluginUnload(comm));
}
if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) {
NCCLCHECK(mscclTeardown(comm->rank));
NCCLCHECK(mscclTeardown(comm));
}
NCCLCHECK(commFree(comm));
+4 -4
Просмотреть файл
@@ -135,7 +135,7 @@ ncclMemFree_impl(void* ptr);
ncclResult_t
mscclLoadAlgo_impl(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle,
int rank);
const ncclComm_t comm);
ncclResult_t
mscclRunAlgo_impl(const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
@@ -380,7 +380,7 @@ NCCL_API(ncclResult_t, ncclMemAlloc, void** ptr, size_t size);
NCCL_API(ncclResult_t, ncclMemFree, void* ptr);
NCCL_API(ncclResult_t, mscclLoadAlgo, const char* mscclAlgoFilePath,
mscclAlgoHandle_t* mscclAlgoHandle, int rank);
mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm);
NCCL_API(ncclResult_t, mscclRunAlgo, const void* sendBuff, const size_t sendCounts[],
const size_t sDisPls[], void* recvBuff, const size_t recvCounts[],
@@ -620,10 +620,10 @@ ncclMemFree(void* ptr)
}
ncclResult_t
mscclLoadAlgo(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, int rank)
mscclLoadAlgo(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm)
{
return ::rccl::RcclGetFunctionTable()->mscclLoadAlgo_fn(mscclAlgoFilePath,
mscclAlgoHandle, rank);
mscclAlgoHandle, comm);
}
ncclResult_t
+47 -38
Просмотреть файл
@@ -22,6 +22,7 @@
#include "msccl/msccl_setup.h"
#include "msccl/msccl_status.h"
#include "rccl/rccl.h"
#ifdef ENABLE_MSCCLPP
#include "mscclpp/mscclpp_nccl.h"
#endif
@@ -59,8 +60,8 @@ bool mscclIsCaller() {
return mscclGetThreadLocalStatus().mscclIsCallerFlag;
}
bool mscclAvailable(int rank) {
return mscclEnabled() && mscclInitialized(rank);
bool mscclAvailable(const ncclComm_t comm) {
return mscclEnabled() && mscclInitialized(comm);
}
static bool allProcessHostsUnique(ncclComm_t comm) {
@@ -118,7 +119,7 @@ static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-tes
static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
static thread_local bool mscclAlgoMetaLoaded = false;
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
int maxNchannels = *numChannelsRequired;
*numChannelsRequired = 0;
@@ -213,7 +214,7 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
return ncclSuccess;
}
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
bool useInternalScheduler = false;
const char* mscclSchedulerPath = getenv(mscclSchedulerPathEnv);
@@ -243,12 +244,14 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
return ncclSuccess;
}
ncclResult_t mscclInit(ncclComm_t comm) {
ncclResult_t mscclInit(const ncclComm_t comm) {
{
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
// freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection
if (!mscclInitialized(comm->rank)) {
if (!mscclInitialized(comm)) {
static std::mutex initMutex;
std::lock_guard<std::mutex> lock(initMutex);
status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS);
for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) {
status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1;
@@ -264,13 +267,16 @@ ncclResult_t mscclInit(ncclComm_t comm) {
auto &m = status.algoMetas[i];
if (m.nRanks == comm->nRanks) {
// Load algorithms
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) {
mscclAlgoHandle_t mscclAlgoHandle;
{
static std::mutex loadAlgoMutex;
std::lock_guard<std::mutex> lock(loadAlgoMutex);
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank));
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()){
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm));
}
// Connect algorithms
mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank];
}
// Connect algorithms
mscclAlgoHandle_t mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank];
if (status.connectedAlgos[comm].find(mscclAlgoHandle) == status.connectedAlgos[comm].end()) {
NCCLCHECK(mscclSetupConnections(status.hostAlgos[mscclAlgoHandle], comm));
status.connectedAlgos[comm].insert(mscclAlgoHandle);
@@ -278,17 +284,20 @@ ncclResult_t mscclInit(ncclComm_t comm) {
}
}
}
{
static std::mutex mscclInitMutex;
std::lock_guard<std::mutex> lock(mscclInitMutex);
if (mscclInitialized(comm)){
return ncclSuccess;
}
if (mscclInitialized(comm->rank)) {
return ncclSuccess;
status.workIndex = 1;
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
status.lastStream = nullptr;
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
mscclSetInitialized(comm);
}
status.workIndex = 1;
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
status.lastStream = nullptr;
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
mscclSetInitialized(comm->rank);
}
INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize());
@@ -304,8 +313,8 @@ ncclResult_t mscclGroupStart() {
return ncclSuccess;
}
static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) {
mscclStatus& status = mscclGetStatus(rank);
static ncclResult_t mscclInternalSchedulerSelectAlgo(const ncclComm_t comm, struct mscclSchedulerParam* param) {
mscclStatus& status = mscclGetStatus(comm);
param->scheduled = false;
// Current MSCCL doesn't support pre/post op
@@ -350,13 +359,13 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSched
}
static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) {
mscclStatus& status = mscclGetStatus(param->comm->rank);
mscclStatus& status = mscclGetStatus(param->comm);
if (status.mscclSchedulerPtr) {
NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p)));
} else {
// Disable MSCCL algorithms if machine type is not matching
if (param->comm->topo->mscclEnabled || mscclForceEnabled()) {
NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p)));
NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm, &(param->p)));
} else {
param->p.scheduled = false;
}
@@ -520,7 +529,7 @@ ncclResult_t mscclEnqueueCheck(
#ifdef ENABLE_MSCCLPP
if (comm->mscclppCompatible) {
INFO(NCCL_COLL, "MSCCL++: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
@@ -566,7 +575,7 @@ ncclResult_t mscclEnqueueCheck(
#ifdef ENABLE_MSCCLPP
if (comm->mscclppCompatible) {
INFO(NCCL_COLL, "MSCCL++: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
@@ -631,8 +640,8 @@ ncclResult_t mscclGroupEnd() {
return ncclSuccess;
}
static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) {
mscclStatus& status = mscclGetStatus(rank);
static ncclResult_t mscclInternalUnloadAlgo(const ncclComm_t comm, mscclAlgoHandle_t mscclAlgoHandle) {
mscclStatus& status = mscclGetStatus(comm);
free(status.hostAlgos[mscclAlgoHandle]);
status.hostAlgos.erase(mscclAlgoHandle);
@@ -649,12 +658,12 @@ static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlg
return ncclSuccess;
}
static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
static ncclResult_t mscclInternalSchedulerTeardown(const ncclComm_t comm) {
ncclResult_t ret = ncclSuccess, tmpRet = ncclSuccess;
mscclStatus& status = mscclGetStatus(rank);
mscclStatus& status = mscclGetStatus(comm);
for (auto &m : status.rankToAlgoHandles) {
for (auto &p : m) {
tmpRet = mscclInternalUnloadAlgo(rank, p.second);
tmpRet = mscclInternalUnloadAlgo(comm, p.second);
if (ret == ncclSuccess) {
ret = tmpRet;
}
@@ -665,13 +674,13 @@ static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
return ret;
}
ncclResult_t mscclTeardown(int rank) {
ncclResult_t mscclTeardown(const ncclComm_t comm) {
{
if (!mscclInitialized(rank)) {
mscclRemoveRank(rank);
if (!mscclInitialized(comm)) {
mscclRemoveRank(comm);
return ncclSuccess;
}
mscclStatus& status = mscclGetStatus(rank);
mscclStatus& status = mscclGetStatus(comm);
for (auto &p : status.hostAlgos) {
free(p.second);
status.freeAlgoHandles.push_back(p.first);
@@ -694,14 +703,14 @@ ncclResult_t mscclTeardown(int rank) {
dlclose(status.mscclSchedulerLib);
status.mscclSchedulerLib = nullptr;
} else {
NCCLCHECK(mscclInternalSchedulerTeardown(rank));
NCCLCHECK(mscclInternalSchedulerTeardown(comm));
}
NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus)));
for (auto &p : status.graphWorkFifoStatus) {
NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second)));
}
mscclSetInitialized(rank, false);
mscclRemoveRank(rank);
mscclSetInitialized(comm, false);
mscclRemoveRank(comm);
}
INFO(NCCL_INIT, "MSCCL: Teardown finished");
+13 -13
Просмотреть файл
@@ -22,10 +22,10 @@ static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nC
return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop;
}
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(rank);
ncclResult_t mscclGetCaptureStatus(const ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(comm);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(rank);
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm);
cudaStreamCaptureStatus captureStatus;
unsigned long long captureId;
CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &captureId, &threadLocalStatus.graph, nullptr, nullptr));
@@ -47,7 +47,7 @@ ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) {
}
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1;
status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1;
@@ -72,8 +72,8 @@ ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
return ncclSuccess;
}
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(rank);
ncclResult_t mscclSetupSyncFlags(const ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(comm);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
if (threadLocalStatus.captureStatus == mscclNewCapture ||
status.workIndex > (1ULL << (8*sizeof(status.workIndex))) - 2 * NCCL_MAX_OPS - 1) {
@@ -85,7 +85,7 @@ ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) {
}
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
// Check whether there are enough channels
if (hostAlgo->nChannels > comm->nChannels) {
@@ -122,7 +122,7 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm)
}
static ncclResult_t mscclSetupProxyImpl(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
struct ncclProxyOp proxyOp = {};
proxyOp.connIndex = 0;
@@ -183,12 +183,12 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) {
}
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm->rank);
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm);
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
INFO(NCCL_NET, "mscclSetupProxy: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
}
if (threadLocalStatus.captureStatus == mscclNoCapture) {
INFO(NCCL_NET,"mscclSetupProxy: no capture\n");
@@ -412,7 +412,7 @@ RCCL_PARAM(MscclForceFullOps, "MSCCL_FORCE_FULLOPS", 0);
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
if (status.lastStream != stream && status.lastStream != nullptr) {
@@ -481,7 +481,7 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
INFO(NCCL_NET, "MSCCL: reading capture status");
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
}
mscclWorkFifoStatus* workFifoStatus = nullptr;
if (threadLocalStatus.captureStatus == mscclNoCapture) {
+25 -20
Просмотреть файл
@@ -7,61 +7,66 @@
#include "msccl/msccl_struct.h"
#include "debug.h"
#include "comm.h"
#include <memory>
#include <mutex>
#include <unordered_map>
using namespace std;
struct mscclRankState {
int rank;
bool initialized;
mscclStatus status;
mscclSavedProxyArgs savedProxyArgs;
mscclRankState() : rank(-1), initialized(false), status(), savedProxyArgs() {}
mscclRankState() : initialized(false), status(), savedProxyArgs() {}
explicit mscclRankState(const mscclRankState&) = default;
};
static mutex rankStatesMutex;
static unordered_map<int, shared_ptr<mscclRankState>> rankStates;
/*
* @brief rankStates is intended to hold mscclRankState for each communicator in a rccl process.
* "rankStates" is not threadsafe, hence read/writes on this data strcutures need to be handled explicitly by
* block of code that is accessing the elements in this map using a lock guard or any mutual exclusion device.
*/
static unordered_map<ncclComm_t, shared_ptr<mscclRankState>> rankStates;
static inline mscclRankState& mscclGetRankState(int rank) {
// In the unlikely case of negative rank, return a per-thread state
if (rank < 0) {
static inline mscclRankState& mscclGetRankState(const ncclComm_t comm) {
//the following condition comm == nullptr evaluates true when mscclAvailable() called with default params
if (comm == nullptr) {
static thread_local shared_ptr<mscclRankState> threadRankState(new mscclRankState());
return *threadRankState;
}
lock_guard<mutex> lock(rankStatesMutex);
auto rankStateIt = rankStates.find(rank);
auto rankStateIt = rankStates.find(comm);
if (rankStateIt == rankStates.end()) {
// Create a per rank threadRankState rather than per thread
shared_ptr<mscclRankState> newthreadRankState(new mscclRankState());
newthreadRankState->rank = rank;
rankStateIt = rankStates.insert(make_pair(rank, newthreadRankState)).first;
// newthreadRankState->rank = rank;
rankStateIt = rankStates.insert(make_pair(comm, newthreadRankState)).first;
}
return *(rankStateIt->second);
}
bool mscclInitialized(int rank) {
return mscclGetRankState(rank).initialized;
bool mscclInitialized(const ncclComm_t comm) {
return mscclGetRankState(comm).initialized;
}
void mscclSetInitialized(int rank, bool initialized) {
auto& state = mscclGetRankState(rank);
void mscclSetInitialized(const ncclComm_t comm, bool initialized) {
auto& state = mscclGetRankState(comm);
assert(!initialized || !state.initialized);
state.initialized = initialized;
}
void mscclRemoveRank(int rank) {
void mscclRemoveRank(const ncclComm_t comm) {
lock_guard<mutex> lock(rankStatesMutex);
rankStates.erase(rank);
rankStates.erase(comm);
}
mscclStatus& mscclGetStatus(int rank) {
return mscclGetRankState(rank).status;
mscclStatus& mscclGetStatus(const ncclComm_t comm) {
return mscclGetRankState(comm).status;
}
mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
@@ -69,6 +74,6 @@ mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
return threadLocalStatus;
}
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank) {
return mscclGetRankState(rank).savedProxyArgs;
mscclSavedProxyArgs& mscclGetSavedProxyArgs(const ncclComm_t comm) {
return mscclGetRankState(comm).savedProxyArgs;
}
+7 -7
Просмотреть файл
@@ -14,10 +14,10 @@
using namespace rccl;
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) {
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm);
ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm) {
Recorder::instance().record("mscclLoadAlgo");
mscclStatus& status = mscclGetStatus(rank);
mscclStatus& status = mscclGetStatus(comm);
if (status.freeAlgoHandles.size() == 0) {
WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS);
@@ -28,7 +28,7 @@ ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t
struct mscclAlgo* hostAlgo;
NCCLCHECK(ncclCalloc(&hostAlgo, 1));
NCCLCHECK(mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, rank));
NCCLCHECK(mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, comm->rank));
status.hostAlgos[*mscclAlgoHandle] = hostAlgo;
struct mscclAlgo* devAlgo;
@@ -53,7 +53,7 @@ ncclResult_t mscclRunAlgo_impl(
NVTX3_FUNC_WITH_PARAMS(MSCCL, NcclNvtxParamsMSCCL,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(dataType), op, dataType));
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm);
struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle];
struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];
@@ -65,13 +65,13 @@ ncclResult_t mscclRunAlgo_impl(
CUDACHECK(hipSetDevice(comm->cudaDev));
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));
NCCLCHECK(mscclSetupScratch(hostAlgo, stream));
NCCLCHECK(mscclSetupSyncFlags(comm->rank, stream));
NCCLCHECK(mscclSetupSyncFlags(comm, stream));
NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream));
+3 -4
Просмотреть файл
@@ -5,7 +5,6 @@
*
* See LICENSE.txt for license information
************************************************************************/
#ifndef NCCL_H_
#define NCCL_H_
@@ -764,10 +763,10 @@ typedef int mscclAlgoHandle_t;
@param[in] mscclAlgoFilePath Path to MSCCL algorithm file
@param[out] mscclAlgoHandle Returned handle to MSCCL algorithm
@param[in] rank Current rank */
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
@param[in] comm Current rank's communicator */
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm);
/*! @cond include_hidden */
ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm);
/*! @endcond */
/*! @brief MSCCL Run Algorithm