From 72ef10005026e4ea1f480b8495a4bc20efbaee07 Mon Sep 17 00:00:00 2001 From: Wenkai Du Date: Mon, 31 Oct 2022 08:54:34 -0700 Subject: [PATCH] Fix P2P scheduling --- src/enqueue.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/enqueue.cc b/src/enqueue.cc index 235842d6df..4fc77a3175 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -324,7 +324,7 @@ static ncclResult_t addP2pToPlan( struct ncclWorkElemP2p elem = {0}; elem.proto = info.protocol; - elem.peer = peer; + elem.peer = addr == nullptr ? -1 : peer; elem.nWarps = NCCL_MAX_NTHREADS/comm->WarpSize; elem.p2pType = isSendNotRecv ? ncclWorkP2pTypeSend : ncclWorkP2pTypeRecv; elem.buffLo32 = uint32_t(reinterpret_cast<uintptr_t>(addr)); @@ -342,7 +342,7 @@ static ncclResult_t addP2pToPlan( // Calculate the opCount after appendWorkElemP2p since it will always return // with channel->nWork equal to one plus the work index this p2p settled in. proxyOp.opCount = uint64_t(plan->channels[channelId].nWork)<<1 | 1; - NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp)); + if (addr != nullptr) NCCLCHECK(addProxyOpIfNeeded(comm, plan, &proxyOp)); return ncclSuccess; } @@ -646,6 +646,8 @@ static ncclResult_t scheduleP2pTasksToPlan( ncclIntruQueueDequeue(&peers[recvPeer].recvQueue); tasks->nTasksP2p -= 1; } + } else { + NCCLCHECK(addP2pToPlan(comm, plan, nWorkBudget, /*isSendNotRecv=*/false, recvPeer, 0, nullptr, 0, recvIdx)); } if (sendChunkBytes != 0) { if (sendChunkBytes == -1) sendChunkBytes = 0; @@ -659,6 +661,8 @@ static ncclResult_t scheduleP2pTasksToPlan( ncclIntruQueueDequeue(&peers[sendPeer].sendQueue); tasks->nTasksP2p -= 1; } + } else { + NCCLCHECK(addP2pToPlan(comm, plan, nWorkBudget, /*isSendNotRecv=*/true, sendPeer, 0, nullptr, 0, sendIdx)); } } while (sendBytes != 0 || recvBytes != 0); }