Fix the AllReduce-with-bias (ar_with_bias) test issue when running rccl-tests (#1912)
* Fix the AllReduce-with-bias (AR_With_Bias) issue when running rccl-tests
This commit is contained in:
@@ -61,7 +61,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &ring->next as an array. However, due to the use of
|
||||
// FanSymmetric<1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, RCCLMetadata, Pipeline> prims
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, RCCLMetadata, Pipeline, USE_ACC> prims
|
||||
(tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -252,7 +252,7 @@ namespace {
|
||||
#endif
|
||||
|
||||
{ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC> prims
|
||||
(tid, nthreads, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -301,7 +301,7 @@ namespace {
|
||||
}
|
||||
|
||||
{ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC> prims
|
||||
(tid, nthreads, &tree->up, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -463,7 +463,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &tree->up as an array. However, due to the use of
|
||||
// FanAsymmetric<n, 1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline>
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -508,7 +508,7 @@ namespace {
|
||||
// Coverity reports that the callee treats &tree->up as an array. However, due to the use of
|
||||
// FanAsymmetric<1, n>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline>
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, work->sendbuff, work->recvbuff,
|
||||
work->redOpArg, 1*Proto::MaxGroupWidth, 0, 0, work);
|
||||
|
||||
|
||||
@@ -603,6 +603,8 @@ for name in name_to_funcs.keys():
|
||||
for fn in fns:
|
||||
(coll, algo, proto, redop, ty, acc, pipeline, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, acc, pipeline, unroll)
|
||||
if coll == "AllReduceWithBias":
|
||||
coll = "AllReduce"
|
||||
if proto == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out(
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, useAcc>> {
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int Metadata, int Pipeline, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pipeline, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pipeline, useAcc>> {
|
||||
|
||||
// In the case of Fan::MaxRecv == 0, we need to force MaxRecv to 1 for this to compile
|
||||
// This is because of a recv buffer which is allocated to MaxRecv length in send-only cases
|
||||
|
||||
@@ -20,9 +20,9 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, useAcc>> {
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int Metadata, int Pipeline, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, Metadata, Pipeline, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, Metadata, Pipeline, useAcc>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1, Acc=2;;
|
||||
|
||||
@@ -304,7 +304,7 @@ private:
|
||||
}
|
||||
#endif
|
||||
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
|
||||
reduceCopy<Unroll, useAcc && Dst, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
|
||||
(tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false,
|
||||
1, ncclShmem.groups[group].srcs,
|
||||
fan.nsend(), ncclShmem.groups[group].dsts+1,
|
||||
@@ -340,7 +340,7 @@ private:
|
||||
}
|
||||
#endif
|
||||
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
|
||||
reduceCopy<Unroll, useAcc && Dst, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp,
|
||||
Recv, ncclShmem.groups[group].srcs,
|
||||
Dst, ncclShmem.groups[group].dsts,
|
||||
@@ -378,7 +378,7 @@ private:
|
||||
DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
|
||||
if (Send && Dst && ncclShmem.groups[group].dsts[1] == nullptr) {
|
||||
// this case should only be directCopySend() with registered buffers and send to net peer
|
||||
reduceCopy<Unroll, useAcc, RedOp, T,
|
||||
reduceCopy<Unroll, useAcc && Dst, RedOp, T,
|
||||
0, Recv + Src, Recv * MaxRecv + Src,
|
||||
0, 1, 1, PreOpSrcs, Pipeline>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
@@ -386,7 +386,7 @@ private:
|
||||
1, ncclShmem.groups[group].dsts,
|
||||
workSize);
|
||||
} else {
|
||||
reduceCopy<Unroll, useAcc, RedOp, T,
|
||||
reduceCopy<Unroll, useAcc && Dst, RedOp, T,
|
||||
MultimemSrcs, Recv + Src, Recv * MaxRecv + Src,
|
||||
MultimemDsts, Send + Dst, Send * MaxSend + Dst, PreOpSrcs, Pipeline>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
|
||||
+10
-1
@@ -500,7 +500,14 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool
|
||||
}
|
||||
|
||||
NCCLCHECK(getAlgoInfo(comm, &agg, collNetSupport, nvlsSupport, nTasksPerChannel, simInfo));
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline);
|
||||
if(agg.func==ncclFuncAllReduce && agg.acc != nullptr)
|
||||
{
|
||||
agg.devFuncId = ncclDevFuncId(ncclFuncAllReduceWithBias, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline);
|
||||
}
|
||||
else
|
||||
{
|
||||
agg.devFuncId = ncclDevFuncId(agg.func, agg.opDev.op, agg.datatype, agg.algorithm, agg.protocol, agg.pipeline);
|
||||
}
|
||||
if (agg.devFuncId < 0) {
|
||||
WARN("%s: unsupported collective. Please ensure the collective has been enabled in build.", __func__);
|
||||
return ncclInvalidUsage;
|
||||
@@ -523,6 +530,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool
|
||||
struct ncclTaskColl* next = aggBeg->next;
|
||||
aggBeg->algorithm = agg.algorithm;
|
||||
aggBeg->protocol = agg.protocol;
|
||||
aggBeg->acc = agg.acc;
|
||||
aggBeg->pipeline = agg.pipeline;
|
||||
if (aggBeg->protocol == NCCL_PROTO_LL) aggBeg->trafficBytes *= 4;
|
||||
aggBeg->nMaxChannels = agg.nMaxChannels;
|
||||
@@ -574,6 +582,7 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool
|
||||
struct ncclDevWorkColl devWork = {};
|
||||
devWork.sendbuff = (void*)task->sendbuff;
|
||||
devWork.recvbuff = (void*)task->recvbuff;
|
||||
devWork.acc = (void*)task->acc;
|
||||
devWork.sendbuffOffset = task->sendbuffOffset;
|
||||
devWork.recvbuffOffset = task->recvbuffOffset;
|
||||
devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs;
|
||||
|
||||
Reference in New Issue
Block a user