# Patch summary: add an MSCCL++ "force enable" switch.
#
# src/include/comm.h
#   Adds `bool mscclppForceEnable` to struct ncclComm, directly after the
#   existing mscclppCompatible / mscclpp_comm / mscclpp_threshold members and
#   ahead of the #endif that follows them.
#
# src/init.cc
#   Declares RCCL_PARAM(MscclppForceEnabled, "MSCCLPP_FORCE_ENABLE", 0). The
#   trailing 0 is presumably the default value (the neighbouring MscclppEnabled
#   parameter passes defaultEnableMscclpp in the same slot) -- TODO confirm
#   against the RCCL_PARAM macro definition.
#   In the hunk attributed to ncclCommInitRankFunc, the branch that precedes
#   the "MSCCL++: Cannot enable MSCCL++ on %s architecture" warning now stores
#   true or false in comm->mscclppForceEnable depending on whether
#   rcclParamMscclppForceEnabled() is non-zero.
#
# src/misc/msccl/msccl_lifecycle.cc
#   Both hunks (attributed to mscclEnqueueCheck) widen the first term of the
#   gate from (graphMode || buffsRegistered) to
#   (graphMode || buffsRegistered || comm->mscclppForceEnable). The remaining
#   terms are unchanged: comm->mscclCompatible, nBytes > 0, and
#   (nBytes & 31) == 0, i.e. nBytes is a multiple of 32.
#   The INFO lines for "mscclpp_ncclAllReduce" and "mscclpp_ncclAllGather" gain
#   a trailing "ubr %d" field that prints buffsRegistered.
#
# NOTE(review): comm->mscclppForceEnable is written only on the branch shown
#   here; nothing in this patch initialises it on the WARN path. Confirm the
#   comm structure is zero-initialised before mscclEnqueueCheck reads it.
# NOTE(review): the five added if/else lines are equivalent to one assignment,
#   comm->mscclppForceEnable = (rcclParamMscclppForceEnabled() != 0);
# NOTE(review): the context comment "check if one rank per GPU and graph mode
#   is enabled" no longer describes the full condition, which now also passes
#   for registered buffers or the force flag.
# NOTE(review): the context lines show const_cast(sendBuff) and
#   const_cast(recvBuff) with no template argument; this looks like an
#   extraction artefact (an angle-bracketed type stripped) -- confirm against
#   the tree.
# NOTE(review): this copy of the patch has its line breaks collapsed; it has to
#   be regenerated from the repository before it can be applied.
diff --git a/src/include/comm.h b/src/include/comm.h index 254981263e..db4df7c0cb 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -598,6 +598,7 @@ struct ncclComm { bool mscclppCompatible; struct mscclppComm* mscclpp_comm; size_t mscclpp_threshold; + bool mscclppForceEnable; #endif // Whether this comm is compatible with MSCCL diff --git a/src/init.cc b/src/init.cc index c4b0c06e3d..47df7f06a3 100644 --- a/src/init.cc +++ b/src/init.cc @@ -121,6 +121,7 @@ static constexpr int64_t defaultEnableMscclpp = 0; #endif RCCL_PARAM(MscclppEnabled, "MSCCLPP_ENABLE", defaultEnableMscclpp); +RCCL_PARAM(MscclppForceEnabled, "MSCCLPP_FORCE_ENABLE", 0); // GDRCOPY support: Off by default NCCL_PARAM(GdrCopyEnable, "GDRCOPY_ENABLE", 1); @@ -1956,6 +1957,11 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank); mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId; ncclCommToUniqueIdMap[comm] = job->commId; + if (rcclParamMscclppForceEnabled()) { + comm->mscclppForceEnable = true; + } else { + comm->mscclppForceEnable = false; + } } else { WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName); } diff --git a/src/misc/msccl/msccl_lifecycle.cc b/src/misc/msccl/msccl_lifecycle.cc index e73d68b86b..da4ad3d14c 100644 --- a/src/misc/msccl/msccl_lifecycle.cc +++ b/src/misc/msccl/msccl_lifecycle.cc @@ -528,22 +528,24 @@ ncclResult_t mscclEnqueueCheck( const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered; /* check if one rank per GPU and graph mode is enabled */ - if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { + if ((graphMode || buffsRegistered || comm->mscclppForceEnable) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { bool isManagedBuffer = false; if (sendBuff) 
CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(sendBuff))); if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(recvBuff))); if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ } else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d", - "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d", + "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, + graphMode, buffsRegistered); NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break; } else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d", - "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d", + "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, + graphMode, buffsRegistered); NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break; @@ -572,22 +574,24 
@@ ncclResult_t mscclEnqueueCheck( const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered; /* check if one rank per GPU and graph mode is enabled */ - if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { + if ((graphMode || buffsRegistered || comm->mscclppForceEnable) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) { bool isManagedBuffer = false; if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(sendBuff))); if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast(recvBuff))); if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ } else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d", - "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d", + "mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, + graphMode, buffsRegistered); NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break; } else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) { - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d" , - "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode); + INFO(NCCL_COLL,"%s: 
opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d" , + "mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, + graphMode, buffsRegistered); NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream)); threadLocalStatus.savedSchedulerParams.clear(); break;