misc/msccl: force use of mscclpp (#1581)

This commit is contained in:
Nusrat Islam
2025-03-04 12:48:59 -06:00
committed by GitHub
parent d88cca3098
commit ac823818aa
3 changed files with 21 additions and 10 deletions
+1
View File
@@ -598,6 +598,7 @@ struct ncclComm {
bool mscclppCompatible;
struct mscclppComm* mscclpp_comm;
size_t mscclpp_threshold;
bool mscclppForceEnable;
#endif
// Whether this comm is compatible with MSCCL
+6
View File
@@ -121,6 +121,7 @@ static constexpr int64_t defaultEnableMscclpp = 0;
#endif
RCCL_PARAM(MscclppEnabled, "MSCCLPP_ENABLE", defaultEnableMscclpp);
RCCL_PARAM(MscclppForceEnabled, "MSCCLPP_FORCE_ENABLE", 0);
// GDRCOPY support: Off by default
NCCL_PARAM(GdrCopyEnable, "GDRCOPY_ENABLE", 1);
@@ -1956,6 +1957,11 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank);
mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId;
ncclCommToUniqueIdMap[comm] = job->commId;
if (rcclParamMscclppForceEnabled()) {
comm->mscclppForceEnable = true;
} else {
comm->mscclppForceEnable = false;
}
} else {
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
}
+14 -10
View File
@@ -528,22 +528,24 @@ ncclResult_t mscclEnqueueCheck(
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
/* check if one rank per GPU and graph mode is enabled */
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
if ((graphMode || buffsRegistered || comm->mscclppForceEnable) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
bool isManagedBuffer = false;
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
graphMode, buffsRegistered);
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;
}
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d",
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
graphMode, buffsRegistered);
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;
@@ -572,22 +574,24 @@ ncclResult_t mscclEnqueueCheck(
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
/* check if one rank per GPU and graph mode is enabled */
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
if ((graphMode || buffsRegistered || comm->mscclppForceEnable) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
bool isManagedBuffer = false;
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d",
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
graphMode, buffsRegistered);
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;
}
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d" ,
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d" ,
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
graphMode, buffsRegistered);
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
threadLocalStatus.savedSchedulerParams.clear();
break;