diff --git a/src/device/all_reduce.h b/src/device/all_reduce.h index bfdf29f619..45b06be119 100644 --- a/src/device/all_reduce.h +++ b/src/device/all_reduce.h @@ -61,7 +61,7 @@ namespace { // Coverity reports that the callee treats &ring->next as an array. However, due to the use of // FanSymmetric<1>, only the first element is ever accessed, so it's fine. // coverity[callee_ptr_arith:FALSE] - Primitives, 0, Proto, 0, false, RCCLMetadata, Pipeline> prims + Primitives, 0, Proto, 0, false, RCCLMetadata, Pipeline, USE_ACC> prims (tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex, work); #if defined(ENABLE_NPKIT) @@ -252,7 +252,7 @@ namespace { #endif { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims + Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC> prims (tid, nthreads, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work); #if defined(ENABLE_NPKIT) @@ -301,7 +301,7 @@ namespace { } { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims + Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC> prims (tid, nthreads, &tree->up, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work); #if defined(ENABLE_NPKIT) @@ -463,7 +463,7 @@ namespace { // Coverity reports that the callee treats &tree->up as an array. However, due to the use of // FanAsymmetric, only the first element is ever accessed, so it's fine. // coverity[callee_ptr_arith:FALSE] - Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline> + Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC> prims(tid, nthreadsSplit, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, work); #if defined(ENABLE_NPKIT) @@ -508,7 +508,7 @@ namespace { // Coverity reports that the callee treats &tree->up as an array. However, due to the use of // FanAsymmetric<1, n>, only the first element is ever accessed, so it's fine. // coverity[callee_ptr_arith:FALSE] - Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline> + Primitives, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC> prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 1*Proto::MaxGroupWidth, 0, 0, work); diff --git a/src/device/generate.py b/src/device/generate.py index 3ab07b4a4d..ce3e679606 100755 --- a/src/device/generate.py +++ b/src/device/generate.py @@ -603,6 +603,8 @@ for name in name_to_funcs.keys(): for fn in fns: (coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll) + if coll == "AllReduceWithBias": + coll = "AllReduce" if proto == "LL128": out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n") out( diff --git a/src/device/prims_ll.h b/src/device/prims_ll.h index 2cf1045c84..620cd9e75f 100644 --- a/src/device/prims_ll.h +++ b/src/device/prims_ll.h @@ -10,9 +10,9 @@ #include "npkit/npkit.h" #endif -template -class Primitives: - public PrimitivesWithoutDirect> { +template +class Primitives: + public PrimitivesWithoutDirect> { // In the case of Fan::MaxRecv == 0, we need to force MaxRecv to 1 for this to compile // This is because of a recv buffer which is allocated to MaxRecv length in send-only cases diff --git a/src/device/prims_ll128.h b/src/device/prims_ll128.h index f36964c888..3d26fc58e5 100644 --- a/src/device/prims_ll128.h +++ b/src/device/prims_ll128.h @@ -20,9 +20,9 @@ #endif #endif -template -class Primitives: - public PrimitivesWithoutDirect> { +template +class Primitives: + public PrimitivesWithoutDirect> { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1, Acc=2;; diff --git a/src/device/prims_simple.h b/src/device/prims_simple.h index 10fa7cebca..f8187bdb28 100644 --- a/src/device/prims_simple.h +++ b/src/device/prims_simple.h @@ -304,7 +304,7 @@ private: } #endif - reduceCopy + reduceCopy (tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false, 1, ncclShmem.groups[group].srcs, fan.nsend(), ncclShmem.groups[group].dsts+1, @@ -340,7 +340,7 @@ private: } #endif - reduceCopy + reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp, Recv, ncclShmem.groups[group].srcs, Dst, ncclShmem.groups[group].dsts, @@ -378,7 +378,7 @@ private: DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; if (Send && Dst && ncclShmem.groups[group].dsts[1] == nullptr) { // this case should only be directCopySend() with registered buffers and send to net peer - reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, @@ -386,7 +386,7 @@ private: 1, ncclShmem.groups[group].dsts, workSize); } else { - reduceCopy (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, diff --git a/src/enqueue.cc b/src/enqueue.cc index 4f720cfb99..f94c11ea21 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -500,7 +500,14 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool } NCCLCHECK(getAlgoInfo(comm, &agg, collNetSupport, nvlsSupport, nTasksPerChannel, simInfo)); - agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline); + if(agg.func==ncclFuncAllReduce && agg.acc != nullptr) + { + agg.devFuncId = ncclDevFuncId(ncclFuncAllReduceWithBias, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline); + } + else + { + agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline); + } if (agg.devFuncId < 0) { WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__); return ncclInvalidUsage; @@ -523,6 +530,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool struct ncclTaskColl* next = aggBeg->next; aggBeg->algorithm = agg.algorithm; aggBeg->protocol = agg.protocol; + aggBeg->acc = agg.acc; aggBeg->pipeline = agg.pipeline; if (aggBeg->protocol == NCCL_PROTO_LL) aggBeg->trafficBytes *= 4; aggBeg->nMaxChannels = agg.nMaxChannels; @@ -574,6 +582,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool struct ncclDevWorkColl devWork = {}; devWork.sendbuff = (void*)task->sendbuff; devWork.recvbuff = (void*)task->recvbuff; + devWork.acc = (void*)task->acc; devWork.sendbuffOffset = task->sendbuffOffset; devWork.recvbuffOffset = task->recvbuffOffset; devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs;