From eeea3b693b66cc42e7f80b1826a050c548412c87 Mon Sep 17 00:00:00 2001
From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com>
Date: Thu, 16 May 2024 10:11:12 -0700
Subject: [PATCH] Report error when collective is not enabled in build (#1177)

* Report error when collective is not enabled in build

* Fix typo
---
 src/enqueue.cc | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/enqueue.cc b/src/enqueue.cc
index a77f5b1e01..5544c02d63 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -171,11 +171,15 @@ static void finishWork(struct ncclWork* work, int WarpSize) {
   }
 }
 
-static void appendWorkElemP2p(
+static ncclResult_t appendWorkElemP2p(
     struct ncclComm* comm, struct ncclKernelPlan* plan, int channelId,
     struct ncclWorkElemP2p const *elem, bool fuseOk
   ) {
   int funcIndex = ncclDevFuncId_P2p();
+  if (funcIndex < 0) {
+    WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
+    return ncclInvalidUsage;
+  }
   struct ncclKernelPlan::Channel* chan = &plan->channels[channelId];
   struct ncclWorkList* q = ncclIntruQueueTail(&chan->workQueue);
   if (q && funcIndex == q->work.header.funcIndex) {
@@ -190,7 +194,7 @@ static void appendWorkElemP2p(
     int e = chan->p2pTailElem[elem->p2pType-1];
     q->work.p2pElems[e] = *elem; // C++ struct assignment
     chan->p2pTailElem[elem->p2pType-1] += 2;
-    return;
+    return ncclSuccess;
   }
 NewWork:
   finishWorkP2p(&q->work, comm->WarpSize);
@@ -204,6 +208,7 @@ static void appendWorkElemP2p(
   chan->p2pTailElem[elem->p2pType-1] += 2;
   chan->nWork += 1;
   ncclIntruQueueEnqueue(&chan->workQueue, q);
+  return ncclSuccess;
 }
 
 static ncclResult_t addProxyOpIfNeeded(struct ncclComm* comm, struct ncclKernelPlan* plan, struct ncclProxyOp* op) {
@@ -1699,6 +1704,10 @@ RCCL_PARAM(IntraNetThreshold, "INTRANET_THRESHOLD", 8388608);
 static ncclResult_t computeCollWorkFunc(struct ncclInfo* collInfo) {
   collInfo->workFuncIndex = ncclDevFuncId(collInfo->coll, collInfo->opFull.op, collInfo->datatype, 
                                           collInfo->algorithm, collInfo->protocol);
+  if (collInfo->workFuncIndex < 0) {
+    WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
+    return ncclInvalidUsage;
+  }
   return ncclSuccess;
 }
 