misc/msccl: force use of mscclpp (#1581)
This commit is contained in:
@@ -598,6 +598,7 @@ struct ncclComm {
|
||||
bool mscclppCompatible;
|
||||
struct mscclppComm* mscclpp_comm;
|
||||
size_t mscclpp_threshold;
|
||||
bool mscclppForceEnable;
|
||||
#endif
|
||||
|
||||
// Whether this comm is compatible with MSCCL
|
||||
|
||||
@@ -121,6 +121,7 @@ static constexpr int64_t defaultEnableMscclpp = 0;
|
||||
#endif
|
||||
|
||||
RCCL_PARAM(MscclppEnabled, "MSCCLPP_ENABLE", defaultEnableMscclpp);
|
||||
RCCL_PARAM(MscclppForceEnabled, "MSCCLPP_FORCE_ENABLE", 0);
|
||||
|
||||
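// GDRCOPY support: the parameter below defaults to 1 (NOTE(review): this comment previously said "Off by default" — confirm the intended default)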
|
||||
NCCL_PARAM(GdrCopyEnable, "GDRCOPY_ENABLE", 1);
|
||||
@@ -1956,6 +1957,11 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
TRACE_CALL("mscclpp_ncclCommInitRank (*comm=%p, nranks=%d, commId=hash:0x%llx, myrank=%d)", comm->mscclpp_comm, job->nranks, mscclppUniqueIdHash, job->myrank);
|
||||
mscclpp_commToUniqueIdMap[comm->mscclpp_comm] = mscclppUniqueId;
|
||||
ncclCommToUniqueIdMap[comm] = job->commId;
|
||||
if (rcclParamMscclppForceEnabled()) {
|
||||
comm->mscclppForceEnable = true;
|
||||
} else {
|
||||
comm->mscclppForceEnable = false;
|
||||
}
|
||||
} else {
|
||||
WARN("MSCCL++: Cannot enable MSCCL++ on %s architecture", devProp.gcnArchName);
|
||||
}
|
||||
|
||||
@@ -528,22 +528,24 @@ ncclResult_t mscclEnqueueCheck(
|
||||
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
|
||||
|
||||
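/* MSCCL++ eligibility: the comm must be MSCCL-compatible (presumably one rank per GPU), and graph mode, registered send/recv buffers, or the force-enable flag must be active; nBytes must be non-zero and a multiple of 32 */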
|
||||
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
if ((graphMode || buffsRegistered || comm->mscclppForceEnable) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
bool isManagedBuffer = false;
|
||||
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
|
||||
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
|
||||
|
||||
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
|
||||
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
|
||||
graphMode, buffsRegistered);
|
||||
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
}
|
||||
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d",
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
|
||||
graphMode, buffsRegistered);
|
||||
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
@@ -572,22 +574,24 @@ ncclResult_t mscclEnqueueCheck(
|
||||
const bool buffsRegistered = sendBuffRegistered && recvBuffRegistered;
|
||||
|
||||
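/* MSCCL++ eligibility: the comm must be MSCCL-compatible (presumably one rank per GPU), and graph mode, registered send/recv buffers, or the force-enable flag must be active; nBytes must be non-zero and a multiple of 32 */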
|
||||
if ((graphMode || buffsRegistered) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
if ((graphMode || buffsRegistered || comm->mscclppForceEnable) && comm->mscclCompatible && nBytes > 0 && (nBytes & 31) == 0) {
|
||||
bool isManagedBuffer = false;
|
||||
if (sendBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(sendBuff)));
|
||||
if (!isManagedBuffer && recvBuff) CUDACHECK(hipPointerGetAttribute(&isManagedBuffer, HIP_POINTER_ATTRIBUTE_IS_MANAGED, const_cast<void*>(recvBuff)));
|
||||
|
||||
if (isManagedBuffer) { /* MSCCL++ not enabled for managed memory buffers */ }
|
||||
else if (func == mscclFuncAllReduce && nBytes <= comm->mscclpp_threshold && isMscclppAllReduceSupported(dataType, op)) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d",
|
||||
"mscclpp_ncclAllReduce", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
|
||||
graphMode, buffsRegistered);
|
||||
NCCLCHECK(mscclpp_ncclAllReduce(sendBuff, recvBuff, count, dataType, op, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
}
|
||||
else if (func == mscclFuncAllGather && nBytes * comm->nRanks <= comm->mscclpp_threshold) {
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d" ,
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream, graphMode);
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p graphMode %d ubr %d" ,
|
||||
"mscclpp_ncclAllGather", comm->opCount, sendBuff, recvBuff, count, dataType, op, root, comm, comm->nRanks, stream,
|
||||
graphMode, buffsRegistered);
|
||||
NCCLCHECK(mscclpp_ncclAllGather(sendBuff, recvBuff, count, dataType, comm->mscclpp_comm, stream));
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user