SPLITCOMM design fix in src/misc/msccl (#1715)
* Fix TOCTOU in mscclInit
* Improving vector resize thread safety
* Initial commit rank to comm change
* Removing unwanted include header changes
* Updated CHANGELOG.md
* Update CHANGELOG.md
Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
---------
Co-authored-by: Jeffrey Novotny <jnovotny@amd.com>
[ROCm/rccl commit: e94b360246]
Этот коммит содержится в:
@@ -8,6 +8,8 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
|
||||
|
||||
* Resolved an issue when using more than 64 channels when multiple collectives are used in the same `ncclGroup()` call.
|
||||
* Fixed unit test failures in tests ending with `ManagedMem` and `ManagedMemGraph` suffixes.
|
||||
* Fixed the known issue "When splitting a communicator using `ncclCommSplit` in some GPU configurations, MSCCL initialization can cause a segmentation fault." with a design change to use `comm` instead of `rank` as the key for `mscclStatus`. The global map from `comm` to `mscclStatus` is still not thread-safe, so reads and writes to it must be explicitly protected by mutexes. This has been tested for correctness, and there is a plan to use a thread-safe map data structure in upcoming changes.
|
||||
|
||||
|
||||
### Added
|
||||
|
||||
@@ -304,4 +306,4 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
|
||||
### Changed
|
||||
- Switched to hip-clang as default compiler
|
||||
### Deprecated
|
||||
- Deprecated hcc build
|
||||
- Deprecated hcc build
|
||||
@@ -96,7 +96,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
|
||||
NCCLCHECK(Recorder::instance().record(rrAllGather, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
|
||||
@@ -124,7 +124,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
|
||||
NCCLCHECK(Recorder::instance().record(rrAllReduce, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
|
||||
@@ -149,7 +149,7 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
|
||||
NCCLCHECK(Recorder::instance().record(rrAllToAll, sendbuff, recvbuff, count, datatype, comm, stream));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
|
||||
@@ -195,7 +195,7 @@ ncclResult_t ncclAllToAllv_impl(const void *sendbuff, const size_t sendcounts[],
|
||||
NCCLCHECK(Recorder::instance().record(rrAllToAllv, sendbuff, recvbuff, 0, datatype, comm, stream, -1, sendcounts, sdispls, recvcounts, rdispls));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
|
||||
0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream);
|
||||
@@ -241,7 +241,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou
|
||||
NCCLCHECK(Recorder::instance().record(rrBroadcast, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
|
||||
@@ -271,7 +271,7 @@ ncclResult_t ncclGather_impl(const void* sendbuff, void* recvbuff, size_t sendco
|
||||
NCCLCHECK(Recorder::instance().record(rrGather, sendbuff, recvbuff, sendcount, datatype, comm, stream, root));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream);
|
||||
@@ -310,7 +310,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count,
|
||||
NCCLCHECK(Recorder::instance().record(rrReduce, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
|
||||
@@ -337,7 +337,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
|
||||
NCCLCHECK(Recorder::instance().record(rrReduceScatter, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
|
||||
@@ -360,7 +360,7 @@ ncclResult_t ncclScatter_impl(const void* sendbuff, void* recvbuff, size_t recvc
|
||||
NCCLCHECK(Recorder::instance().record(rrScatter, sendbuff, recvbuff, recvcount, datatype, comm, stream, root));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream);
|
||||
@@ -400,7 +400,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da
|
||||
NCCLCHECK(Recorder::instance().record(rrSend, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
|
||||
@@ -426,7 +426,7 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype
|
||||
NCCLCHECK(Recorder::instance().record(rrRecv, info));
|
||||
}
|
||||
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
|
||||
|
||||
@@ -139,7 +139,7 @@ typedef ncclResult_t (*ncclMemAlloc_fn_t)(void** ptr, size_t size);
|
||||
typedef ncclResult_t (*ncclMemFree_fn_t)(void* ptr);
|
||||
|
||||
typedef ncclResult_t (*mscclLoadAlgo_fn_t)(const char* mscclAlgoFilePath,
|
||||
mscclAlgoHandle_t* mscclAlgoHandle, int rank);
|
||||
mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm);
|
||||
|
||||
typedef ncclResult_t (*mscclRunAlgo_fn_t)(
|
||||
const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
|
||||
|
||||
@@ -17,7 +17,15 @@ void mscclSetIsCallerFlag();
|
||||
void mscclClearIsCallerFlag();
|
||||
bool mscclIsCaller();
|
||||
|
||||
bool mscclAvailable(int rank = -1);
|
||||
/**
|
||||
* @brief mscclAvailable() is used to determine if msccl functionality is available
|
||||
* @param comm is an optional rccl communicator, if provided uses the mscclStatus
|
||||
* from a global map<comm -> mscclStatus> to determine if msccl is available. If not available
|
||||
* in the map, this invocation inserts a new key-value pair into the global map.
|
||||
* If comm == nullptr, on the first invocation it initializes a static thread local variable
|
||||
* mscclStatus and uses the same object in subsequent calls from the same thread if comm is nullptr
|
||||
*/
|
||||
bool mscclAvailable(const ncclComm_t comm = nullptr);
|
||||
|
||||
ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired);
|
||||
|
||||
@@ -33,7 +41,7 @@ ncclResult_t mscclEnqueueCheck(
|
||||
|
||||
ncclResult_t mscclGroupEnd();
|
||||
|
||||
ncclResult_t mscclTeardown(int rank);
|
||||
ncclResult_t mscclTeardown(const ncclComm_t comm);
|
||||
|
||||
size_t mscclKernMaxLocalSize();
|
||||
|
||||
|
||||
@@ -11,13 +11,13 @@
|
||||
#include "comm.h"
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream);
|
||||
ncclResult_t mscclGetCaptureStatus(const ncclComm_t comm, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream);
|
||||
ncclResult_t mscclSetupSyncFlags(const ncclComm_t comm, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm);
|
||||
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo,const ncclComm_t comm);
|
||||
|
||||
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType);
|
||||
|
||||
|
||||
@@ -8,15 +8,15 @@
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
bool mscclInitialized(int rank);
|
||||
bool mscclInitialized(const ncclComm_t comm);
|
||||
|
||||
void mscclSetInitialized(int rank, bool initialized = true);
|
||||
void mscclSetInitialized(const ncclComm_t comm, bool initialized = true);
|
||||
|
||||
void mscclRemoveRank(int rank);
|
||||
void mscclRemoveRank(const ncclComm_t comm);
|
||||
|
||||
mscclStatus& mscclGetStatus(int rank);
|
||||
mscclStatus& mscclGetStatus(const ncclComm_t comm);
|
||||
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank);
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs(const ncclComm_t comm);
|
||||
|
||||
mscclThreadLocalStatus& mscclGetThreadLocalStatus();
|
||||
|
||||
|
||||
@@ -1745,7 +1745,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
|
||||
|
||||
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
|
||||
NCCLCHECK(mscclInit(comm));
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
status.needsProxy |= mscclNeedsProxy;
|
||||
}
|
||||
|
||||
@@ -2530,7 +2530,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) {
|
||||
NCCLCHECK(ncclTunerPluginUnload(comm));
|
||||
}
|
||||
if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) {
|
||||
NCCLCHECK(mscclTeardown(comm->rank));
|
||||
NCCLCHECK(mscclTeardown(comm));
|
||||
}
|
||||
NCCLCHECK(commFree(comm));
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ ncclMemFree_impl(void* ptr);
|
||||
|
||||
ncclResult_t
|
||||
mscclLoadAlgo_impl(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle,
|
||||
int rank);
|
||||
const ncclComm_t comm);
|
||||
|
||||
ncclResult_t
|
||||
mscclRunAlgo_impl(const void* sendBuff, const size_t sendCounts[], const size_t sDisPls[],
|
||||
@@ -380,7 +380,7 @@ NCCL_API(ncclResult_t, ncclMemAlloc, void** ptr, size_t size);
|
||||
NCCL_API(ncclResult_t, ncclMemFree, void* ptr);
|
||||
|
||||
NCCL_API(ncclResult_t, mscclLoadAlgo, const char* mscclAlgoFilePath,
|
||||
mscclAlgoHandle_t* mscclAlgoHandle, int rank);
|
||||
mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm);
|
||||
|
||||
NCCL_API(ncclResult_t, mscclRunAlgo, const void* sendBuff, const size_t sendCounts[],
|
||||
const size_t sDisPls[], void* recvBuff, const size_t recvCounts[],
|
||||
@@ -620,10 +620,10 @@ ncclMemFree(void* ptr)
|
||||
}
|
||||
|
||||
ncclResult_t
|
||||
mscclLoadAlgo(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, int rank)
|
||||
mscclLoadAlgo(const char* mscclAlgoFilePath, mscclAlgoHandle_t* mscclAlgoHandle, const ncclComm_t comm)
|
||||
{
|
||||
return ::rccl::RcclGetFunctionTable()->mscclLoadAlgo_fn(mscclAlgoFilePath,
|
||||
mscclAlgoHandle, rank);
|
||||
mscclAlgoHandle, comm);
|
||||
}
|
||||
|
||||
ncclResult_t
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "msccl/msccl_setup.h"
|
||||
#include "msccl/msccl_status.h"
|
||||
|
||||
#include "rccl/rccl.h"
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
#include "mscclpp/mscclpp_nccl.h"
|
||||
#endif
|
||||
@@ -59,8 +60,8 @@ bool mscclIsCaller() {
|
||||
return mscclGetThreadLocalStatus().mscclIsCallerFlag;
|
||||
}
|
||||
|
||||
bool mscclAvailable(int rank) {
|
||||
return mscclEnabled() && mscclInitialized(rank);
|
||||
bool mscclAvailable(const ncclComm_t comm) {
|
||||
return mscclEnabled() && mscclInitialized(comm);
|
||||
}
|
||||
|
||||
static bool allProcessHostsUnique(ncclComm_t comm) {
|
||||
@@ -118,7 +119,7 @@ static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-tes
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
|
||||
static thread_local bool mscclAlgoMetaLoaded = false;
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
|
||||
int maxNchannels = *numChannelsRequired;
|
||||
*numChannelsRequired = 0;
|
||||
@@ -213,7 +214,7 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
bool useInternalScheduler = false;
|
||||
|
||||
const char* mscclSchedulerPath = getenv(mscclSchedulerPathEnv);
|
||||
@@ -243,12 +244,14 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
ncclResult_t mscclInit(const ncclComm_t comm) {
|
||||
{
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
|
||||
// freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection
|
||||
if (!mscclInitialized(comm->rank)) {
|
||||
if (!mscclInitialized(comm)) {
|
||||
static std::mutex initMutex;
|
||||
std::lock_guard<std::mutex> lock(initMutex);
|
||||
status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS);
|
||||
for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) {
|
||||
status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1;
|
||||
@@ -264,13 +267,16 @@ ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
auto &m = status.algoMetas[i];
|
||||
if (m.nRanks == comm->nRanks) {
|
||||
// Load algorithms
|
||||
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) {
|
||||
mscclAlgoHandle_t mscclAlgoHandle;
|
||||
{
|
||||
static std::mutex loadAlgoMutex;
|
||||
std::lock_guard<std::mutex> lock(loadAlgoMutex);
|
||||
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank));
|
||||
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()){
|
||||
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm));
|
||||
}
|
||||
// Connect algorithms
|
||||
mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank];
|
||||
}
|
||||
// Connect algorithms
|
||||
mscclAlgoHandle_t mscclAlgoHandle = status.rankToAlgoHandles[i][comm->rank];
|
||||
if (status.connectedAlgos[comm].find(mscclAlgoHandle) == status.connectedAlgos[comm].end()) {
|
||||
NCCLCHECK(mscclSetupConnections(status.hostAlgos[mscclAlgoHandle], comm));
|
||||
status.connectedAlgos[comm].insert(mscclAlgoHandle);
|
||||
@@ -278,17 +284,20 @@ ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
static std::mutex mscclInitMutex;
|
||||
std::lock_guard<std::mutex> lock(mscclInitMutex);
|
||||
if (mscclInitialized(comm)){
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
if (mscclInitialized(comm->rank)) {
|
||||
return ncclSuccess;
|
||||
status.workIndex = 1;
|
||||
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
|
||||
status.lastStream = nullptr;
|
||||
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
|
||||
|
||||
mscclSetInitialized(comm);
|
||||
}
|
||||
|
||||
status.workIndex = 1;
|
||||
NCCLCHECK(ncclCudaCalloc(&status.syncFlags, MSCCL_MAX_NUM_THREAD_BLOCKS));
|
||||
status.lastStream = nullptr;
|
||||
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
|
||||
|
||||
mscclSetInitialized(comm->rank);
|
||||
}
|
||||
|
||||
INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize());
|
||||
@@ -304,8 +313,8 @@ ncclResult_t mscclGroupStart() {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
static ncclResult_t mscclInternalSchedulerSelectAlgo(const ncclComm_t comm, struct mscclSchedulerParam* param) {
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
param->scheduled = false;
|
||||
|
||||
// Current MSCCL doesn't support pre/post op
|
||||
@@ -350,13 +359,13 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSched
|
||||
}
|
||||
|
||||
static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) {
|
||||
mscclStatus& status = mscclGetStatus(param->comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(param->comm);
|
||||
if (status.mscclSchedulerPtr) {
|
||||
NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p)));
|
||||
} else {
|
||||
// Disable MSCCL algorithms if machine type is not matching
|
||||
if (param->comm->topo->mscclEnabled || mscclForceEnabled()) {
|
||||
NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p)));
|
||||
NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm, &(param->p)));
|
||||
} else {
|
||||
param->p.scheduled = false;
|
||||
}
|
||||
@@ -520,7 +529,7 @@ ncclResult_t mscclEnqueueCheck(
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (comm->mscclppCompatible) {
|
||||
INFO(NCCL_COLL, "MSCCL++: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
|
||||
|
||||
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
|
||||
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
|
||||
@@ -566,7 +575,7 @@ ncclResult_t mscclEnqueueCheck(
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (comm->mscclppCompatible) {
|
||||
INFO(NCCL_COLL, "MSCCL++: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
|
||||
|
||||
const bool sendBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, sendBuff);
|
||||
const bool recvBuffRegistered = mscclpp_BuffIsRegistered(comm->mscclpp_comm, recvBuff);
|
||||
@@ -631,8 +640,8 @@ ncclResult_t mscclGroupEnd() {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
static ncclResult_t mscclInternalUnloadAlgo(const ncclComm_t comm, mscclAlgoHandle_t mscclAlgoHandle) {
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
|
||||
free(status.hostAlgos[mscclAlgoHandle]);
|
||||
status.hostAlgos.erase(mscclAlgoHandle);
|
||||
@@ -649,12 +658,12 @@ static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlg
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
|
||||
static ncclResult_t mscclInternalSchedulerTeardown(const ncclComm_t comm) {
|
||||
ncclResult_t ret = ncclSuccess, tmpRet = ncclSuccess;
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
for (auto &m : status.rankToAlgoHandles) {
|
||||
for (auto &p : m) {
|
||||
tmpRet = mscclInternalUnloadAlgo(rank, p.second);
|
||||
tmpRet = mscclInternalUnloadAlgo(comm, p.second);
|
||||
if (ret == ncclSuccess) {
|
||||
ret = tmpRet;
|
||||
}
|
||||
@@ -665,13 +674,13 @@ static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ncclResult_t mscclTeardown(int rank) {
|
||||
ncclResult_t mscclTeardown(const ncclComm_t comm) {
|
||||
{
|
||||
if (!mscclInitialized(rank)) {
|
||||
mscclRemoveRank(rank);
|
||||
if (!mscclInitialized(comm)) {
|
||||
mscclRemoveRank(comm);
|
||||
return ncclSuccess;
|
||||
}
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
for (auto &p : status.hostAlgos) {
|
||||
free(p.second);
|
||||
status.freeAlgoHandles.push_back(p.first);
|
||||
@@ -694,14 +703,14 @@ ncclResult_t mscclTeardown(int rank) {
|
||||
dlclose(status.mscclSchedulerLib);
|
||||
status.mscclSchedulerLib = nullptr;
|
||||
} else {
|
||||
NCCLCHECK(mscclInternalSchedulerTeardown(rank));
|
||||
NCCLCHECK(mscclInternalSchedulerTeardown(comm));
|
||||
}
|
||||
NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus)));
|
||||
for (auto &p : status.graphWorkFifoStatus) {
|
||||
NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second)));
|
||||
}
|
||||
mscclSetInitialized(rank, false);
|
||||
mscclRemoveRank(rank);
|
||||
mscclSetInitialized(comm, false);
|
||||
mscclRemoveRank(comm);
|
||||
}
|
||||
|
||||
INFO(NCCL_INIT, "MSCCL: Teardown finished");
|
||||
|
||||
@@ -22,10 +22,10 @@ static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nC
|
||||
return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop;
|
||||
}
|
||||
|
||||
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
ncclResult_t mscclGetCaptureStatus(const ncclComm_t comm, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(rank);
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm);
|
||||
cudaStreamCaptureStatus captureStatus;
|
||||
unsigned long long captureId;
|
||||
CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &captureId, &threadLocalStatus.graph, nullptr, nullptr));
|
||||
@@ -47,7 +47,7 @@ ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) {
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
|
||||
status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1;
|
||||
status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1;
|
||||
@@ -72,8 +72,8 @@ ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
ncclResult_t mscclSetupSyncFlags(const ncclComm_t comm, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
if (threadLocalStatus.captureStatus == mscclNewCapture ||
|
||||
status.workIndex > (1ULL << (8*sizeof(status.workIndex))) - 2 * NCCL_MAX_OPS - 1) {
|
||||
@@ -85,7 +85,7 @@ ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) {
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
|
||||
// Check whether there are enough channels
|
||||
if (hostAlgo->nChannels > comm->nChannels) {
|
||||
@@ -122,7 +122,7 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm)
|
||||
}
|
||||
|
||||
static ncclResult_t mscclSetupProxyImpl(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
struct ncclProxyOp proxyOp = {};
|
||||
proxyOp.connIndex = 0;
|
||||
@@ -183,12 +183,12 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) {
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm->rank);
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm);
|
||||
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
|
||||
INFO(NCCL_NET, "mscclSetupProxy: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
|
||||
}
|
||||
if (threadLocalStatus.captureStatus == mscclNoCapture) {
|
||||
INFO(NCCL_NET,"mscclSetupProxy: no capture\n");
|
||||
@@ -412,7 +412,7 @@ RCCL_PARAM(MscclForceFullOps, "MSCCL_FORCE_FULLOPS", 0);
|
||||
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
|
||||
ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
|
||||
ncclComm_t comm, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
|
||||
if (status.lastStream != stream && status.lastStream != nullptr) {
|
||||
@@ -481,7 +481,7 @@ ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count
|
||||
|
||||
if (threadLocalStatus.captureStatus == mscclUnknownCaptureStatus) {
|
||||
INFO(NCCL_NET, "MSCCL: reading capture status");
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
|
||||
}
|
||||
mscclWorkFifoStatus* workFifoStatus = nullptr;
|
||||
if (threadLocalStatus.captureStatus == mscclNoCapture) {
|
||||
|
||||
@@ -7,61 +7,66 @@
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
#include "comm.h"
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
using namespace std;
|
||||
|
||||
struct mscclRankState {
|
||||
int rank;
|
||||
bool initialized;
|
||||
mscclStatus status;
|
||||
mscclSavedProxyArgs savedProxyArgs;
|
||||
|
||||
mscclRankState() : rank(-1), initialized(false), status(), savedProxyArgs() {}
|
||||
mscclRankState() : initialized(false), status(), savedProxyArgs() {}
|
||||
explicit mscclRankState(const mscclRankState&) = default;
|
||||
};
|
||||
|
||||
static mutex rankStatesMutex;
|
||||
static unordered_map<int, shared_ptr<mscclRankState>> rankStates;
|
||||
/*
|
||||
* @brief rankStates is intended to hold mscclRankState for each communicator in a rccl process.
|
||||
* "rankStates" is not thread-safe, hence reads/writes on this data structure need to be handled explicitly by
|
||||
* the block of code that accesses the elements of this map, using a lock guard or another mutual-exclusion mechanism.
|
||||
*/
|
||||
static unordered_map<ncclComm_t, shared_ptr<mscclRankState>> rankStates;
|
||||
|
||||
static inline mscclRankState& mscclGetRankState(int rank) {
|
||||
// In the unlikely case of negative rank, return a per-thread state
|
||||
if (rank < 0) {
|
||||
static inline mscclRankState& mscclGetRankState(const ncclComm_t comm) {
|
||||
// The following condition (comm == nullptr) evaluates to true when mscclAvailable() is called with default params
|
||||
if (comm == nullptr) {
|
||||
static thread_local shared_ptr<mscclRankState> threadRankState(new mscclRankState());
|
||||
return *threadRankState;
|
||||
}
|
||||
|
||||
lock_guard<mutex> lock(rankStatesMutex);
|
||||
|
||||
auto rankStateIt = rankStates.find(rank);
|
||||
auto rankStateIt = rankStates.find(comm);
|
||||
if (rankStateIt == rankStates.end()) {
|
||||
// Create a per rank threadRankState rather than per thread
|
||||
shared_ptr<mscclRankState> newthreadRankState(new mscclRankState());
|
||||
newthreadRankState->rank = rank;
|
||||
rankStateIt = rankStates.insert(make_pair(rank, newthreadRankState)).first;
|
||||
// newthreadRankState->rank = rank;
|
||||
rankStateIt = rankStates.insert(make_pair(comm, newthreadRankState)).first;
|
||||
}
|
||||
return *(rankStateIt->second);
|
||||
}
|
||||
|
||||
bool mscclInitialized(int rank) {
|
||||
return mscclGetRankState(rank).initialized;
|
||||
bool mscclInitialized(const ncclComm_t comm) {
|
||||
return mscclGetRankState(comm).initialized;
|
||||
}
|
||||
|
||||
void mscclSetInitialized(int rank, bool initialized) {
|
||||
auto& state = mscclGetRankState(rank);
|
||||
void mscclSetInitialized(const ncclComm_t comm, bool initialized) {
|
||||
auto& state = mscclGetRankState(comm);
|
||||
assert(!initialized || !state.initialized);
|
||||
state.initialized = initialized;
|
||||
}
|
||||
|
||||
void mscclRemoveRank(int rank) {
|
||||
void mscclRemoveRank(const ncclComm_t comm) {
|
||||
lock_guard<mutex> lock(rankStatesMutex);
|
||||
rankStates.erase(rank);
|
||||
rankStates.erase(comm);
|
||||
}
|
||||
|
||||
mscclStatus& mscclGetStatus(int rank) {
|
||||
return mscclGetRankState(rank).status;
|
||||
mscclStatus& mscclGetStatus(const ncclComm_t comm) {
|
||||
return mscclGetRankState(comm).status;
|
||||
}
|
||||
|
||||
mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
|
||||
@@ -69,6 +74,6 @@ mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
|
||||
return threadLocalStatus;
|
||||
}
|
||||
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank) {
|
||||
return mscclGetRankState(rank).savedProxyArgs;
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs(const ncclComm_t comm) {
|
||||
return mscclGetRankState(comm).savedProxyArgs;
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@
|
||||
|
||||
using namespace rccl;
|
||||
|
||||
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
|
||||
ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) {
|
||||
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm);
|
||||
ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm) {
|
||||
Recorder::instance().record("mscclLoadAlgo");
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
|
||||
if (status.freeAlgoHandles.size() == 0) {
|
||||
WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS);
|
||||
@@ -28,7 +28,7 @@ ncclResult_t mscclLoadAlgo_impl(const char *mscclAlgoFilePath, mscclAlgoHandle_t
|
||||
|
||||
struct mscclAlgo* hostAlgo;
|
||||
NCCLCHECK(ncclCalloc(&hostAlgo, 1));
|
||||
NCCLCHECK(mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, rank));
|
||||
NCCLCHECK(mscclGetAlgoFromXmlFile(mscclAlgoFilePath, hostAlgo, comm->rank));
|
||||
status.hostAlgos[*mscclAlgoHandle] = hostAlgo;
|
||||
|
||||
struct mscclAlgo* devAlgo;
|
||||
@@ -53,7 +53,7 @@ ncclResult_t mscclRunAlgo_impl(
|
||||
NVTX3_FUNC_WITH_PARAMS(MSCCL, NcclNvtxParamsMSCCL,
|
||||
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(dataType), op, dataType));
|
||||
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclStatus& status = mscclGetStatus(comm);
|
||||
struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle];
|
||||
struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];
|
||||
|
||||
@@ -65,13 +65,13 @@ ncclResult_t mscclRunAlgo_impl(
|
||||
|
||||
CUDACHECK(hipSetDevice(comm->cudaDev));
|
||||
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm, stream));
|
||||
|
||||
NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));
|
||||
|
||||
NCCLCHECK(mscclSetupScratch(hostAlgo, stream));
|
||||
|
||||
NCCLCHECK(mscclSetupSyncFlags(comm->rank, stream));
|
||||
NCCLCHECK(mscclSetupSyncFlags(comm, stream));
|
||||
|
||||
NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream));
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#ifndef NCCL_H_
|
||||
#define NCCL_H_
|
||||
|
||||
@@ -764,10 +763,10 @@ typedef int mscclAlgoHandle_t;
|
||||
|
||||
@param[in] mscclAlgoFilePath Path to MSCCL algorithm file
|
||||
@param[out] mscclAlgoHandle Returned handle to MSCCL algorithm
|
||||
@param[in] rank Current rank */
|
||||
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
|
||||
@param[in] comm Current rank's communicator */
|
||||
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm);
|
||||
/*! @cond include_hidden */
|
||||
ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
|
||||
ncclResult_t pmscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, const ncclComm_t comm);
|
||||
/*! @endcond */
|
||||
|
||||
/*! @brief MSCCL Run Algorithm
|
||||
|
||||
Ссылка в новой задаче
Block a user