Added useAcc as a template parameter to address the performance regression (#1856)
* Added useAcc as a template parameter to address the 2% performance regression in allreduceWithBias
---------
Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>
[ROCm/rccl commit: c61152baa4]
Этот коммит содержится в:
@@ -21,13 +21,13 @@
|
||||
HIP_FILE=$1
|
||||
|
||||
if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then
|
||||
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int COLL_UNROLL\2>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, COLL_UNROLL>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(runRing<T.*?)((, (true|false))?>\()/\1, COLL_UNROLL\2/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeUpDown<T.*?)>\(/\1, COLL_UNROLL>(/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeSplit<T.*?)>\(/\1, COLL_UNROLL>(/' "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkColl<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"
|
||||
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL\2>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, USE_ACC, COLL_UNROLL>/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(runRing<T.*?)((, (true|false))?>\()/\1, USE_ACC, COLL_UNROLL\2/g' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeUpDown<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
|
||||
perl -pi -e 's/(runTreeSplit<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkColl<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
|
||||
sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
|
||||
echo "Added COLL_UNROLL and USE_ACC template arguments to $HIP_FILE"
|
||||
|
||||
echo "Added COLL_UNROLL template argument to $HIP_FILE"
|
||||
fi
|
||||
@@ -162,7 +162,7 @@ namespace {
|
||||
} else if (inputBuf != outputBuf + ringRanks[0] * count) {
|
||||
inputBuf = inputBuf + partOffset;
|
||||
outputBuf = outputBuf + partOffset + ringRanks[0] * count;
|
||||
reduceCopy<COLL_UNROLL, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
|
||||
reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
|
||||
(tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, partCount);
|
||||
}
|
||||
#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__)
|
||||
@@ -303,7 +303,7 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
if (!work->regUsed) {
|
||||
if (tid < tidEndGather) {
|
||||
// Gather
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -329,7 +329,7 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
} else {
|
||||
/* direct allgather */
|
||||
if (tid < tidEndGather) {
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsGather, nvls->up, nvls->up, NULL, NULL,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -409,7 +409,7 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCCL_P
|
||||
ssize_t userOneBeg = rank*countPerRank + railOneOffset;
|
||||
int outIsDst = (inPlace && rank == ncclShmem.comm.rank) ? 0 : 1;
|
||||
if (nSrcs != 0 && outIsDst+nDsts != 0) {
|
||||
reduceCopy<ncclCollUnroll(), RedOp, T,
|
||||
reduceCopy<ncclCollUnroll(), USE_ACC, RedOp, T,
|
||||
/*MultimemSrcs,MinSrcs,MaxSrcs=*/0,1,1,
|
||||
/*MultimemDsts=*/0, 0+MinDsts, 1+MaxDsts,
|
||||
/*PreOpSrcs=*/0>
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto, int RCCLMetadata, int COLL_UNROLL>
|
||||
template<typename T, typename RedOp, typename Proto, int RCCLMetadata, int USE_ACC, int COLL_UNROLL>
|
||||
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__)
|
||||
__device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
#else
|
||||
@@ -420,7 +420,7 @@ namespace {
|
||||
|
||||
if (tree->up == -1) {
|
||||
// Reduce and broadcast. Max number of recv is 2, max number of send is 2
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto,USE_ACC >
|
||||
prims(tid, nthreads, tree->down, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
|
||||
|
||||
#if defined(ENABLE_NPKIT)
|
||||
@@ -566,7 +566,7 @@ namespace {
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS_SINGLE_NODE, ALLREDUCE_SLICESTEPS_SINGLE_NODE>; \
|
||||
if(work->regUsed || work->netRegUsed || work->gfx942CheapFenceOff){ \
|
||||
runRing<T, RedOp, Proto, RCCL_METADATA_EMPTY>(tid, nthreads, work); \
|
||||
} \
|
||||
} \
|
||||
else { \
|
||||
runRing<T, RedOp, Proto, RCCL_ONE_NODE_RING_SIMPLE>(tid, nthreads, work); \
|
||||
} \
|
||||
@@ -772,7 +772,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
|
||||
if (tid < tidEndScatter) {
|
||||
// Scatter
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -784,7 +784,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
}
|
||||
} else if (tid < tidEndGather) {
|
||||
// Gather
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
|
||||
work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -819,7 +819,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
|
||||
if (tid < tidEndScatter) {
|
||||
// Scatter
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -831,7 +831,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
// coverity[overrun-call] => Coverity think prims.index can be greater than 1
|
||||
} else if (tid < tidEndGather) {
|
||||
// Gather
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
|
||||
work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -842,7 +842,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
}
|
||||
} else if (tid < tidEndReduce && nvls->headRank != -1) {
|
||||
// Reduce, send to network
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
|
||||
// Coverity complains about a possible overrun inside the class below, but that's actually
|
||||
// a false positive.
|
||||
// coverity[identity_transfer:FALSE]
|
||||
@@ -857,7 +857,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
|
||||
}
|
||||
} else if (tid < tidEndBcast && nvls->headRank != -1) {
|
||||
// Recv from network, broadcast
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 0, 1>;
|
||||
// Coverity complains about a possible overrun inside the class below, but that's actually
|
||||
// a false positive.
|
||||
// coverity[identity_transfer:FALSE]
|
||||
@@ -907,7 +907,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
|
||||
|
||||
if (tid < tidEndScatter) {
|
||||
// Scatter
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -919,7 +919,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
|
||||
}
|
||||
} else if (tid < tidEndGather) {
|
||||
// Gather
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
|
||||
work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -932,7 +932,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
|
||||
} else if (tid < tidEndReduce && nvls->headRank != -1) {
|
||||
if (!hasUp) {
|
||||
// Reduce and Broadcast
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 1>;
|
||||
Primitives<T, RedOp, FanSymmetric<3>, /*Direct=*/1, Proto, 0>
|
||||
prims(tid - tidEndGather, nThreadsReduce, treeDown, treeDown, NULL, NULL,
|
||||
work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, work);
|
||||
@@ -946,7 +946,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
|
||||
}
|
||||
} else {
|
||||
// Reduce, send to network
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
|
||||
// Coverity reports that the callee treats &treeUp as an array. However, due to the use of
|
||||
// FanAsymmetric<3, 1>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
@@ -964,7 +964,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
|
||||
}
|
||||
} else if (tid < tidEndBcast && nvls->headRank != -1) {
|
||||
// Recv from network, broadcast
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 0, 1>;
|
||||
// Coverity reports that the callee treats &treeUp as an array. However, due to the use of
|
||||
// FanAsymmetric<1, 3>, only the first element is ever accessed, so it's fine.
|
||||
// coverity[callee_ptr_arith:FALSE]
|
||||
|
||||
@@ -52,7 +52,7 @@ namespace {
|
||||
if (num_hops == 0 && work->sendbuff != work->recvbuff) {
|
||||
const T* sendbuff = (const T*)work->sendbuff + send_offset;
|
||||
T* recvbuff = (T *)work->recvbuff + recv_offset;
|
||||
reduceCopy<COLL_UNROLL, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
|
||||
reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
|
||||
tid, nthreads, 0, nullptr, false, 1, (void **)&sendbuff, 1, (void **)&recvbuff, send_recv_size);
|
||||
} else {
|
||||
for (ssize_t prims_offset = 0; prims_offset < send_recv_size; prims_offset += prims_size) {
|
||||
|
||||
@@ -91,7 +91,7 @@ namespace {
|
||||
} else if (inputBuf != outputBuf && rank == root) {
|
||||
inputBuf = inputBuf + gridOffset;
|
||||
outputBuf = outputBuf + gridOffset;
|
||||
reduceCopy<COLL_UNROLL, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
|
||||
reduceCopy<COLL_UNROLL, 0, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
|
||||
(tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, channelCount);
|
||||
}
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_EXIT)
|
||||
@@ -126,4 +126,4 @@ struct RunWorkColl<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128
|
||||
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
runRing<T, RedOp, ProtoLL128>(tid, nthreads, work);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -392,14 +392,14 @@ __device__ __forceinline__ void loadWorkBatchToShmem(
|
||||
}
|
||||
}
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
|
||||
struct RunWorkColl {
|
||||
__device__ void run(int tid, int tn, struct ncclDevWorkColl* work) {
|
||||
// Put NOT IMPLEMENTED behavior here.
|
||||
}
|
||||
};
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
|
||||
struct RunWorkBatch;
|
||||
|
||||
// Specialized for P2p in sendrecv.h
|
||||
@@ -407,7 +407,7 @@ template<typename T, typename RedOp>
|
||||
struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE>;
|
||||
|
||||
// Specialized here for non-P2p (Coll and CollReg)
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
|
||||
struct RunWorkBatch {
|
||||
// This __forceinline__ is necessary. The compiler was inserting a function call
|
||||
// here from the LL ncclKernel.
|
||||
@@ -437,7 +437,7 @@ struct RunWorkBatch {
|
||||
// Coverity reports a possible thread divergence due to not all threads participating in the collective.
|
||||
// However, the code ensures that the participation is on a per-warp basis.
|
||||
// coverity[device_thread_diverged:FALSE]
|
||||
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, COLL_UNROLL>().run(tid, subtn, work);
|
||||
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, USE_ACC, COLL_UNROLL>().run(tid, subtn, work);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -672,14 +672,14 @@ __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONST
|
||||
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
|
||||
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
|
||||
__device__ void ncclDevFunc_##suffix() { \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, unroll>().run(); \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
|
||||
}
|
||||
#else
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \
|
||||
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
|
||||
__device__ __attribute__((noinline)) void ncclDevFunc_##suffix() { \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, unroll>().run(); \
|
||||
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -618,7 +618,7 @@ __device__ __attribute__((noinline)) void reduceCopyPacksWithBias(
|
||||
thread = warp*WARP_SIZE + lane;
|
||||
}
|
||||
|
||||
template<int Unroll, typename RedFn, typename T,
|
||||
template<int Unroll, int useAcc, typename RedFn, typename T,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
|
||||
@@ -641,7 +641,7 @@ __device__ __forceinline__ void reduceCopy(
|
||||
|
||||
IntBytes nBytesBehind = 0;
|
||||
IntBytes nBytesAhead = nElts*sizeof(T);
|
||||
bool useAcc = accPtrFn() != nullptr;
|
||||
//bool useAcc = accPtrFn() != nullptr;
|
||||
|
||||
#if __cpp_if_constexpr
|
||||
if constexpr (BigPackSize > sizeof(T)) {
|
||||
@@ -763,7 +763,7 @@ __device__ __forceinline__ void reduceCopy(
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
}
|
||||
|
||||
template<int Unroll, typename RedFn, typename T,
|
||||
template<int Unroll, int useAcc, typename RedFn, typename T,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes>
|
||||
@@ -773,7 +773,7 @@ __device__ __forceinline__ void reduceCopy(
|
||||
int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs,
|
||||
IntBytes nElts, void *accPtr = nullptr
|
||||
) {
|
||||
reduceCopy<Unroll, RedFn, T,
|
||||
reduceCopy<Unroll, useAcc, RedFn, T,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs,
|
||||
MultimemDsts, MinDsts, MaxDsts, PreOpSrcs, IntBytes>
|
||||
(thread, nThreads, redArg, preOpArgs, postOp,
|
||||
|
||||
@@ -4,14 +4,15 @@ import sys
|
||||
import subprocess
|
||||
|
||||
# Order of redops, tys, protos, algos must match src/include/device.h
|
||||
all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"]
|
||||
all_colls = ["AllGather","AllReduce","AllReduceWithBias","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"]
|
||||
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
|
||||
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
|
||||
all_protos = ["LL","LL128","SIMPLE"]
|
||||
all_algos = ["TREE","RING", "PAT"]
|
||||
all_unroll = ["1", "2", "4"]
|
||||
use_acc = ["0", "1"]
|
||||
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, all_unroll]
|
||||
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_unroll]
|
||||
|
||||
################################################################################
|
||||
# The first command line argument is the path to the directory to generate and
|
||||
@@ -76,43 +77,47 @@ func_pattern = sys.argv[6:7]
|
||||
if func_pattern and func_pattern[0]:
|
||||
func_pattern = func_pattern[0]
|
||||
else:
|
||||
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
func_pattern = "AllGather|AllReduce|AllReduceWithBias|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
|
||||
|
||||
################################################################################
|
||||
|
||||
algos_of_coll = {
|
||||
"AllGather": ["RING", "PAT"],
|
||||
"AllReduce": ["RING", "TREE"],
|
||||
"AllToAllPivot": ["RING"],
|
||||
"Broadcast": ["RING"],
|
||||
"Reduce": ["RING"],
|
||||
"ReduceScatter": ["RING", "PAT"],
|
||||
"SendRecv": ["RING"]
|
||||
"AllGather": ["RING", "PAT"],
|
||||
"AllReduce": ["RING", "TREE"],
|
||||
"AllReduceWithBias": ["RING", "TREE"],
|
||||
"AllToAllPivot": ["RING"],
|
||||
"Broadcast": ["RING"],
|
||||
"Reduce": ["RING"],
|
||||
"ReduceScatter": ["RING", "PAT"],
|
||||
"SendRecv": ["RING"]
|
||||
}
|
||||
|
||||
protos_of_coll = {
|
||||
"AllGather": all_protos,
|
||||
"AllReduce": all_protos,
|
||||
"AllToAllPivot": ["SIMPLE"],
|
||||
"Broadcast": all_protos,
|
||||
"Reduce": all_protos,
|
||||
"ReduceScatter": all_protos,
|
||||
"SendRecv": ["SIMPLE"]
|
||||
"AllGather": all_protos,
|
||||
"AllReduce": all_protos,
|
||||
"AllReduceWithBias": all_protos,
|
||||
"AllToAllPivot": ["SIMPLE"],
|
||||
"Broadcast": all_protos,
|
||||
"Reduce": all_protos,
|
||||
"ReduceScatter": all_protos,
|
||||
"SendRecv": ["SIMPLE"]
|
||||
}
|
||||
|
||||
redops_of_coll = {
|
||||
"AllGather": ["Sum"],
|
||||
"AllReduce": all_redops,
|
||||
"AllToAllPivot": ["Sum"],
|
||||
"Broadcast": ["Sum"],
|
||||
"Reduce": all_redops,
|
||||
"ReduceScatter": all_redops,
|
||||
"SendRecv": ["Sum"]
|
||||
"AllGather": ["Sum"],
|
||||
"AllReduce": all_redops,
|
||||
"AllReduceWithBias": all_redops,
|
||||
"AllToAllPivot": ["Sum"],
|
||||
"Broadcast": ["Sum"],
|
||||
"Reduce": all_redops,
|
||||
"ReduceScatter": all_redops,
|
||||
"SendRecv": ["Sum"]
|
||||
}
|
||||
|
||||
tys_of_coll = {
|
||||
"AllGather": ["i8"],
|
||||
"AllReduce": all_tys,
|
||||
"AllReduceWithBias": all_tys,
|
||||
"AllToAllPivot": ["i8"],
|
||||
"Broadcast": ["i8"],
|
||||
"Reduce": all_tys,
|
||||
@@ -121,11 +126,12 @@ tys_of_coll = {
|
||||
}
|
||||
|
||||
coll_camel_to_lower = {
|
||||
"AllGather": "all_gather",
|
||||
"AllReduce": "all_reduce",
|
||||
"AllToAllPivot": "alltoall_pivot",
|
||||
"Broadcast": "broadcast",
|
||||
"Reduce": "reduce",
|
||||
"AllGather": "all_gather",
|
||||
"AllReduce": "all_reduce",
|
||||
"AllReduceWithBias": "allreduce_with_bias",
|
||||
"AllToAllPivot": "alltoall_pivot",
|
||||
"Broadcast": "broadcast",
|
||||
"Reduce": "reduce",
|
||||
"ReduceScatter": "reduce_scatter",
|
||||
"SendRecv": "sendrecv"
|
||||
}
|
||||
@@ -174,7 +180,11 @@ def calc_unroll_for_local_arch():
|
||||
return all_unroll
|
||||
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
def func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
if acc == "1" and coll != "AllReduceWithBias":
|
||||
return False
|
||||
if acc == "0" and coll == "AllReduceWithBias":
|
||||
return False
|
||||
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
|
||||
@@ -222,10 +232,10 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
# For each loop layer remove the last element in item_list
|
||||
item_list.pop()
|
||||
else:
|
||||
coll, algo, proto, redop, ty, unroll = item_list
|
||||
coll, algo, proto, redop, ty, acc, unroll = item_list
|
||||
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
yield(coll, algo, proto, redop, ty, unroll)
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
yield(coll, algo, proto, redop, ty, acc, unroll)
|
||||
|
||||
# Parse ONLY_FUNCS input and feed it to func_filter
|
||||
def parse_input(func_pattern):
|
||||
@@ -245,33 +255,35 @@ def parse_input(func_pattern):
|
||||
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
|
||||
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, acc, unroll):
|
||||
if coll in ("AllReduce", "AllReduceWithBias", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
||||
ty = "u"+ty[1:]
|
||||
# map signed integer min/max to unsigned for non-NVLS
|
||||
elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
|
||||
ty = "u"+ty[1:]
|
||||
return (coll, algo, proto, redop, ty, unroll)
|
||||
return (coll, algo, proto, redop, ty, acc, unroll)
|
||||
|
||||
# Order rows are enumerated must match formula of `ncclDevFuncId()`:
|
||||
def enumerate_func_rows():
|
||||
for unroll in all_unroll:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
yield (coll, algo, proto, redop, ty, unroll)
|
||||
for acc in use_acc:
|
||||
for unroll in all_unroll:
|
||||
for coll in all_colls:
|
||||
for algo in all_algos:
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty, acc, unroll):
|
||||
yield (coll, algo, proto, redop, ty, acc, unroll)
|
||||
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
use_acc.index(acc),
|
||||
all_colls.index(coll),
|
||||
all_algos.index(algo),
|
||||
all_protos.index(proto),
|
||||
@@ -320,7 +332,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
|
||||
index1 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
if unroll != "1": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
@@ -337,7 +349,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
|
||||
index2 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
if unroll != "2": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
@@ -354,7 +366,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index4 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
coll, algo, proto, redop, ty, acc, unroll = fn
|
||||
if unroll != "4": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
@@ -469,7 +481,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# Maps to .cu filename which implements this func. The only constraint is that
|
||||
# "coll" is reflected in the name: formally that no two funcs having different
|
||||
# coll's map to the same filename.
|
||||
def impl_filename(coll, algo, proto, redop, ty, unroll):
|
||||
def impl_filename(coll, algo, proto, redop, ty, acc, unroll):
|
||||
return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
|
||||
|
||||
# Partition the functions and kernels to the .cu filenames. The partition is
|
||||
@@ -518,6 +530,8 @@ for name in name_to_funcs.keys():
|
||||
print("-- Generating %s" % os.path.join(gensrc, name))
|
||||
|
||||
out = f.write
|
||||
if coll == "AllReduceWithBias":
|
||||
coll = "AllReduce"
|
||||
out(
|
||||
'#include "common.h"\n'
|
||||
'#include "{lower_coll}.h"\n'
|
||||
@@ -525,14 +539,14 @@ for name in name_to_funcs.keys():
|
||||
)
|
||||
|
||||
for fn in fns:
|
||||
(coll, algo, proto, redop, ty, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, unroll)
|
||||
(coll, algo, proto, redop, ty, acc, unroll) = fn
|
||||
sym = paste("_", coll, algo, proto, redop, ty, acc, unroll)
|
||||
if proto == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out(
|
||||
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {unroll})\n"
|
||||
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {unroll})\n"
|
||||
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"), unroll=unroll)
|
||||
algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, unroll=unroll)
|
||||
)
|
||||
if proto == "LL128":
|
||||
out("#endif\n")
|
||||
|
||||
@@ -376,7 +376,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct n
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128, fullOps>(comm, algo, work); \
|
||||
} \
|
||||
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS, 2>, fullOps>(comm, algo, work); \
|
||||
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS, 0, 2>, fullOps>(comm, algo, work); \
|
||||
}
|
||||
|
||||
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
|
||||
|
||||
@@ -49,7 +49,7 @@ namespace {
|
||||
redOpArg = *reinterpret_cast<uint64_t*>(redOpArg);
|
||||
}
|
||||
}
|
||||
reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/1>
|
||||
reduceCopy<COLL_UNROLL,0, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/1>
|
||||
(tid, tn, redOpArg, &redOpArg, true, 1, &src, 1, &dst, i1-i0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
* to how that protocol operates with a consistent interface so that our
|
||||
* algorithm code can operate protocol parametrically.
|
||||
*/
|
||||
template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
|
||||
template<int SlicePerChunk_1, int StepPerSlice_1,int useAcc, int Unroll_1, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
|
||||
struct ProtoSimple {
|
||||
static constexpr int Id = NCCL_PROTO_SIMPLE;
|
||||
static constexpr int SlicePerChunk = SlicePerChunk_1;
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>> {
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, useAcc>> {
|
||||
|
||||
// In the case of Fan::MaxRecv == 0, we need to force MaxRecv to 1 for this to compile
|
||||
// This is because of a recv buffer which is allocated to MaxRecv length in send-only cases
|
||||
@@ -437,7 +437,7 @@ private:
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
T *accElts = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
|
||||
T *accElts = (DstBuf == -1 || !useAcc) ? nullptr : userBufs[Acc] + dstIx;
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
|
||||
@@ -20,9 +20,9 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>> {
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, useAcc>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1, Acc=2;;
|
||||
@@ -366,7 +366,7 @@ private:
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
T *accPtr = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
|
||||
T *accPtr = (DstBuf == -1 || !useAcc) ? nullptr : userBufs[Acc] + dstIx;
|
||||
int wireOffset = WireWordPerSlice*warp + 2*wid;
|
||||
const int nwarps = nthreads/WARP_SIZE;
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
|
||||
@@ -23,9 +23,9 @@ enum primsMode {
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp, typename Fan, int Direct,
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata>
|
||||
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata, int useAcc>
|
||||
class Primitives<
|
||||
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll, MultimemSrcs, MultimemDsts>, P2p, isNetOffload, Metadata
|
||||
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, useAcc, Unroll, MultimemSrcs, MultimemDsts>, P2p, isNetOffload, Metadata
|
||||
> {
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
@@ -299,7 +299,7 @@ private:
|
||||
}
|
||||
#endif
|
||||
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
|
||||
(tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false,
|
||||
1, ncclShmem.groups[group].srcs,
|
||||
fan.nsend(), ncclShmem.groups[group].dsts+1,
|
||||
@@ -335,7 +335,7 @@ private:
|
||||
}
|
||||
#endif
|
||||
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp,
|
||||
Recv, ncclShmem.groups[group].srcs,
|
||||
Dst, ncclShmem.groups[group].dsts,
|
||||
@@ -373,7 +373,7 @@ private:
|
||||
DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
|
||||
if (Send && Dst && ncclShmem.groups[group].dsts[1] == nullptr) {
|
||||
// this case should only be directCopySend() with registered buffers and send to net peer
|
||||
reduceCopy<Unroll, RedOp, T,
|
||||
reduceCopy<Unroll, useAcc, RedOp, T,
|
||||
0, Recv + Src, Recv * MaxRecv + Src,
|
||||
0, 1, 1, PreOpSrcs>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
@@ -381,7 +381,7 @@ private:
|
||||
1, ncclShmem.groups[group].dsts,
|
||||
workSize);
|
||||
} else {
|
||||
reduceCopy<Unroll, RedOp, T,
|
||||
reduceCopy<Unroll, useAcc, RedOp, T,
|
||||
MultimemSrcs, Recv + Src, Recv * MaxRecv + Src,
|
||||
MultimemDsts, Send + Dst, Send * MaxSend + Dst, PreOpSrcs>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
@@ -453,19 +453,19 @@ private:
|
||||
srcs[nsrcs] = dsts[0];
|
||||
nsrcs++;
|
||||
if (MULTISRCS){
|
||||
reduceCopy<Unroll, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nsrcs, (void **)srcs, 1, (void **)dsts, nelem);
|
||||
} else {
|
||||
reduceCopy<Unroll, RedOp, T, 0, 2, 2, 0, 1, 1, 0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 2, 2, 0, 1, 1, 0>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 2, (void **)srcs, 1, (void **)dsts, nelem);
|
||||
}
|
||||
}
|
||||
if (COPY){
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)srcs, 1, (void **)dsts, nelem);
|
||||
if (MULTISRCS) {
|
||||
for (int i = 1; i < nsrcs; i++){
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)&srcs[i], 1, (void **)&dsts[i], nelem);
|
||||
}
|
||||
}
|
||||
@@ -617,7 +617,7 @@ private:
|
||||
void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset;
|
||||
ssize_t realPeerSize = min(realSize, totalElem-pOffset);
|
||||
if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) {
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
|
||||
// Mark for threadfence at the end
|
||||
fenceNeeded |= true;
|
||||
}
|
||||
@@ -637,7 +637,7 @@ private:
|
||||
void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset;
|
||||
ssize_t realPeerSize = min(realSize, totalElem-pOffset);
|
||||
if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0;
|
||||
if (realPeerSize > 0) reduceCopy<Unroll, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
if (realPeerSize > 0) reduceCopy<Unroll, useAcc, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1190,7 +1190,7 @@ public:
|
||||
|
||||
int workSize = ncclShmem.aborted ? 0 : nelem;
|
||||
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 2, 0, 1, 1, /*PreOpSrcs*/0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 2, 0, 1, 1, /*PreOpSrcs*/0>
|
||||
(tid, nthreads, ncclShmem.redOpArgs[0], nullptr, /*postOp=*/false,
|
||||
nSrcs, srcs, 1, ncclShmem.groups[group].dsts, workSize);
|
||||
|
||||
@@ -1284,7 +1284,7 @@ public:
|
||||
|
||||
int workSize = ncclShmem.aborted ? 0 : nelem;
|
||||
|
||||
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 2, /*PreOpSrcs*/0>
|
||||
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 2, /*PreOpSrcs*/0>
|
||||
(tid, nthreads, ncclShmem.redOpArgs[0], nullptr, /*postOp=*/false,
|
||||
1, ncclShmem.groups[group].srcs, nDsts, dsts, workSize);
|
||||
|
||||
|
||||
@@ -258,7 +258,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
|
||||
if (!work->regUsed) {
|
||||
if (tid < tidEndScatter) {
|
||||
// Scatter
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -270,7 +270,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
|
||||
// coverity[overrun-call] => Coverity think prims.index can be greater than 1
|
||||
} else if (tid < tidEndReduce) {
|
||||
// Reduce through NVLS
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, work->recvbuff,
|
||||
work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0);
|
||||
@@ -283,7 +283,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
|
||||
} else {
|
||||
if (tid < tidEndScatter) {
|
||||
// Scatter
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
|
||||
prims(tid, nThreadsScatter, nvls->up, nvls->up, NULL, NULL,
|
||||
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
|
||||
@@ -295,7 +295,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
|
||||
prims.gather(0, 0, 0, 0, -1, 0);
|
||||
} else if (tid < tidEndReduce) {
|
||||
// Reduce through NVLS
|
||||
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
|
||||
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0>
|
||||
prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, &nvls->down, NULL, work->recvbuff,
|
||||
work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, work);
|
||||
|
||||
@@ -247,10 +247,10 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
|
||||
if (isCopy) {
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
reduceCopy<COLL_UNROLL*2, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
|
||||
reduceCopy<COLL_UNROLL*2, 0, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
|
||||
(subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes);
|
||||
#else
|
||||
reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
|
||||
reduceCopy<COLL_UNROLL, 0, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
|
||||
(subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes);
|
||||
#endif
|
||||
} else if (isSend) {
|
||||
@@ -258,9 +258,9 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
runSend<ProtoLL>(subtid, subtn, group, work);
|
||||
} else {
|
||||
#if defined(__gfx90a__)
|
||||
runSend<ProtoSimple<1,1,8>>(subtid, subtn, group, work);
|
||||
runSend<ProtoSimple<1,1,0,8>>(subtid, subtn, group, work);
|
||||
#elif defined(__gfx908__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
runSend<ProtoSimple<1,1,4>>(subtid, subtn, group, work);
|
||||
runSend<ProtoSimple<1,1,0,4>>(subtid, subtn, group, work);
|
||||
#else
|
||||
runSend<ProtoSimple<1,1>>(subtid, subtn, group, work);
|
||||
#endif
|
||||
@@ -270,13 +270,13 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
runRecv<ProtoLL>(subtid, subtn, group, work);
|
||||
} else {
|
||||
#if defined(__gfx90a__)
|
||||
runRecv<ProtoSimple<1,1,8>>(subtid, subtn, group, work);
|
||||
runRecv<ProtoSimple<1,1,0,8>>(subtid, subtn, group, work);
|
||||
#elif defined(__gfx908__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
runRecv<ProtoSimple<1,1,4>>(subtid, subtn, group, work);
|
||||
runRecv<ProtoSimple<1,1,0,4>>(subtid, subtn, group, work);
|
||||
#else
|
||||
runRecv<ProtoSimple<1,1>>(subtid, subtn, group, work);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -755,6 +755,6 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto)
|
||||
return ncclDevFuncRowToId[row];
|
||||
}
|
||||
|
||||
inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - 1]; }
|
||||
inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - AR_WITH_BIAS_FUNC_COUNTS - NCCL_NUM_ONERANK - 1]; }
|
||||
|
||||
#endif
|
||||
|
||||
@@ -40,7 +40,8 @@ typedef enum {
|
||||
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
|
||||
|
||||
#define NCCL_NUM_ONERANK 12
|
||||
#define FUNC_INDEX_TOTAL 821 + NCCL_NUM_ONERANK
|
||||
#define AR_WITH_BIAS_FUNC_COUNTS 324
|
||||
#define FUNC_INDEX_TOTAL 821 + AR_WITH_BIAS_FUNC_COUNTS + NCCL_NUM_ONERANK
|
||||
|
||||
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
|
||||
typedef enum {
|
||||
@@ -49,11 +50,12 @@ typedef enum {
|
||||
ncclFuncAllGather = 2,
|
||||
ncclFuncReduceScatter = 3,
|
||||
ncclFuncAllReduce = 4,
|
||||
ncclFuncSendRecv = 5,
|
||||
ncclFuncSend = 6,
|
||||
ncclFuncRecv = 7,
|
||||
ncclFuncAllToAllPivot = 8,
|
||||
ncclNumFuncs = 9
|
||||
ncclFuncAllReduceWithBias = 5,
|
||||
ncclFuncSendRecv = 6,
|
||||
ncclFuncSend = 7,
|
||||
ncclFuncRecv = 8,
|
||||
ncclFuncAllToAllPivot = 9,
|
||||
ncclNumFuncs = 10
|
||||
} ncclFunc_t;
|
||||
|
||||
#define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*
|
||||
|
||||
Ссылка в новой задаче
Block a user