diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index 9dc9a7cd44..d136768397 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -805,6 +805,12 @@ static ncclResult_t scheduleCollTasksToPlan( } proxyOp->channelId = c; proxyOp->opCount = proxyOpId; + proxyOp->connIndex = 0; + if (task->protocol == NCCL_PROTO_SIMPLE && task->algorithm == NCCL_ALGO_RING) { + if (comm->useIntraNet && nBytes > rcclParamIntraNetThreshold()) { + proxyOp->connIndex = NCCL_CONN_IDX_P2P_NET; + } + } addWorkBatchToPlan(comm, plan, c, workNode->workType, task->devFuncId, plan->workBytes); NCCLCHECK(addProxyOpIfNeeded(comm, plan, proxyOp)); } @@ -1992,13 +1998,6 @@ static ncclResult_t calcCollChunking( } } - proxyOp->connIndex = 0; - if (info->protocol == NCCL_PROTO_SIMPLE && info->algorithm == NCCL_ALGO_RING) { - if (comm->useIntraNet && nBytes > rcclParamIntraNetThreshold()) { - proxyOp->connIndex = NCCL_CONN_IDX_P2P_NET; - } - } - *outChunkSize = chunkSize; return ncclSuccess; } diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index f990efc04b..c4b0c06e3d 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -1631,6 +1631,17 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p } NCCLCHECKGOTO(ncclTransportRingConnect(comm), ret, fail); + // Connect NET for intranode use + if (comm->graphs[NCCL_ALGO_RING].nIntraChannels && rcclParamP2pNetDisable() == 0) { + comm->useIntraNet = 1; + for (int c = 0; c < comm->nChannels; c++) { + struct ncclChannel* channel = comm->channels+c; + if (comm->nRanks == 1) continue; + NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->ring.prev, 1, &channel->ring.next, NCCL_CONN_IDX_P2P_NET), ret, fail); + } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &comm->graphs[NCCL_ALGO_RING], NCCL_CONN_IDX_P2P_NET), ret, fail); + } + // Connect Trees NCCLCHECKGOTO(ncclTransportTreeConnect(comm), ret, fail);