Added useAcc as a template parameter to address the performance regression (#1856)

* Added useAcc as a template parameter to address the 2% performance regression in allreduceWithBias
---------

Co-authored-by: Marzieh Berenjkoub <mberenjk@amd.com>


[ROCm/rccl commit: c61152baa4]
Этот коммит содержится в:
mberenjk
2025-08-14 15:58:54 -05:00
коммит произвёл GitHub
родитель d3e9db9432
Коммит c76a4492f1
18 изменённых файлов: 153 добавлений и 137 удалений
+8 -8
Просмотреть файл
@@ -21,13 +21,13 @@
HIP_FILE=$1
if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int COLL_UNROLL\2>/g' "$HIP_FILE"
perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, COLL_UNROLL>/g' "$HIP_FILE"
perl -pi -e 's/(runRing<T.*?)((, (true|false))?>\()/\1, COLL_UNROLL\2/g' "$HIP_FILE"
perl -pi -e 's/(runTreeUpDown<T.*?)>\(/\1, COLL_UNROLL>(/' "$HIP_FILE"
perl -pi -e 's/(runTreeSplit<T.*?)>\(/\1, COLL_UNROLL>(/' "$HIP_FILE"
sed -i "s/\\(struct RunWorkColl<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"
sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, COLL_UNROLL>/" "$HIP_FILE"
perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL\2>/g' "$HIP_FILE"
perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, USE_ACC, COLL_UNROLL>/g' "$HIP_FILE"
perl -pi -e 's/(runRing<T.*?)((, (true|false))?>\()/\1, USE_ACC, COLL_UNROLL\2/g' "$HIP_FILE"
perl -pi -e 's/(runTreeUpDown<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
perl -pi -e 's/(runTreeSplit<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
sed -i "s/\\(struct RunWorkColl<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
echo "Added COLL_UNROLL and USE_ACC template arguments to $HIP_FILE"
echo "Added COLL_UNROLL template argument to $HIP_FILE"
fi
+4 -4
Просмотреть файл
@@ -162,7 +162,7 @@ namespace {
} else if (inputBuf != outputBuf + ringRanks[0] * count) {
inputBuf = inputBuf + partOffset;
outputBuf = outputBuf + partOffset + ringRanks[0] * count;
reduceCopy<COLL_UNROLL, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
(tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, partCount);
}
#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__)
@@ -303,7 +303,7 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
if (!work->regUsed) {
if (tid < tidEndGather) {
// Gather
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -329,7 +329,7 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
} else {
/* direct allgather */
if (tid < tidEndGather) {
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsGather, nvls->up, nvls->up, NULL, NULL,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -409,7 +409,7 @@ struct RunWorkColl<ncclFuncAllGather, T, RedOp, NCCL_ALGO_COLLNET_DIRECT, NCCL_P
ssize_t userOneBeg = rank*countPerRank + railOneOffset;
int outIsDst = (inPlace && rank == ncclShmem.comm.rank) ? 0 : 1;
if (nSrcs != 0 && outIsDst+nDsts != 0) {
reduceCopy<ncclCollUnroll(), RedOp, T,
reduceCopy<ncclCollUnroll(), USE_ACC, RedOp, T,
/*MultimemSrcs,MinSrcs,MaxSrcs=*/0,1,1,
/*MultimemDsts=*/0, 0+MinDsts, 1+MaxDsts,
/*PreOpSrcs=*/0>
+14 -14
Просмотреть файл
@@ -14,7 +14,7 @@
#endif
namespace {
template<typename T, typename RedOp, typename Proto, int RCCLMetadata, int COLL_UNROLL>
template<typename T, typename RedOp, typename Proto, int RCCLMetadata, int USE_ACC, int COLL_UNROLL>
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__)
__device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
#else
@@ -420,7 +420,7 @@ namespace {
if (tree->up == -1) {
// Reduce and broadcast. Max number of recv is 2, max number of send is 2
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0>
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto,USE_ACC >
prims(tid, nthreads, tree->down, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
#if defined(ENABLE_NPKIT)
@@ -566,7 +566,7 @@ namespace {
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS_SINGLE_NODE, ALLREDUCE_SLICESTEPS_SINGLE_NODE>; \
if(work->regUsed || work->netRegUsed || work->gfx942CheapFenceOff){ \
runRing<T, RedOp, Proto, RCCL_METADATA_EMPTY>(tid, nthreads, work); \
} \
} \
else { \
runRing<T, RedOp, Proto, RCCL_ONE_NODE_RING_SIMPLE>(tid, nthreads, work); \
} \
@@ -772,7 +772,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
if (tid < tidEndScatter) {
// Scatter
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -784,7 +784,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
}
} else if (tid < tidEndGather) {
// Gather
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1);
@@ -819,7 +819,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
if (tid < tidEndScatter) {
// Scatter
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -831,7 +831,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
// coverity[overrun-call] => Coverity thinks prims.index can be greater than 1
} else if (tid < tidEndGather) {
// Gather
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1);
@@ -842,7 +842,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
}
} else if (tid < tidEndReduce && nvls->headRank != -1) {
// Reduce, send to network
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
// Coverity complains about a possible overrun inside the class below, but that's actually
// a false positive.
// coverity[identity_transfer:FALSE]
@@ -857,7 +857,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_SIMPL
}
} else if (tid < tidEndBcast && nvls->headRank != -1) {
// Recv from network, broadcast
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 0, 1>;
// Coverity complains about a possible overrun inside the class below, but that's actually
// a false positive.
// coverity[identity_transfer:FALSE]
@@ -907,7 +907,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
if (tid < tidEndScatter) {
// Scatter
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -919,7 +919,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
}
} else if (tid < tidEndGather) {
// Gather
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_NVLS_ARITY, 0>, /*Direct=*/0, Proto, 0>
prims(tid - tidEndScatter, nThreadsGather, nvls->up, NULL, NULL, work->recvbuff,
work->redOpArg, 1 * Proto::MaxGroupWidth, 1, 1);
@@ -932,7 +932,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
} else if (tid < tidEndReduce && nvls->headRank != -1) {
if (!hasUp) {
// Reduce and Broadcast
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 1>;
Primitives<T, RedOp, FanSymmetric<3>, /*Direct=*/1, Proto, 0>
prims(tid - tidEndGather, nThreadsReduce, treeDown, treeDown, NULL, NULL,
work->redOpArg, 2 * Proto::MaxGroupWidth, 0, 0, work);
@@ -946,7 +946,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
}
} else {
// Reduce, send to network
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
// Coverity reports that the callee treats &treeUp as an array. However, due to the use of
// FanAsymmetric<3, 1>, only the first element is ever accessed, so it's fine.
// coverity[callee_ptr_arith:FALSE]
@@ -964,7 +964,7 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_NVLS_TREE, NCCL_PROTO_
}
} else if (tid < tidEndBcast && nvls->headRank != -1) {
// Recv from network, broadcast
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 0, 1>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 0, 1>;
// Coverity reports that the callee treats &treeUp as an array. However, due to the use of
// FanAsymmetric<1, 3>, only the first element is ever accessed, so it's fine.
// coverity[callee_ptr_arith:FALSE]
+1 -1
Просмотреть файл
@@ -52,7 +52,7 @@ namespace {
if (num_hops == 0 && work->sendbuff != work->recvbuff) {
const T* sendbuff = (const T*)work->sendbuff + send_offset;
T* recvbuff = (T *)work->recvbuff + recv_offset;
reduceCopy<COLL_UNROLL, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
reduceCopy<COLL_UNROLL, USE_ACC, RedOp, T, 0,1, 1, 0, 1, 1, 0>(
tid, nthreads, 0, nullptr, false, 1, (void **)&sendbuff, 1, (void **)&recvbuff, send_recv_size);
} else {
for (ssize_t prims_offset = 0; prims_offset < send_recv_size; prims_offset += prims_size) {
+2 -2
Просмотреть файл
@@ -91,7 +91,7 @@ namespace {
} else if (inputBuf != outputBuf && rank == root) {
inputBuf = inputBuf + gridOffset;
outputBuf = outputBuf + gridOffset;
reduceCopy<COLL_UNROLL, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
reduceCopy<COLL_UNROLL, 0, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs=*/0>
(tid - workNthreads, nthreads - workNthreads, work->redOpArg, &work->redOpArg, false, 1, (void**)&inputBuf, 1, (void**)&outputBuf, channelCount);
}
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_BROADCAST_RING_EXIT)
@@ -126,4 +126,4 @@ struct RunWorkColl<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
runRing<T, RedOp, ProtoLL128>(tid, nthreads, work);
}
};
};
+8 -8
Просмотреть файл
@@ -392,14 +392,14 @@ __device__ __forceinline__ void loadWorkBatchToShmem(
}
}
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
struct RunWorkColl {
__device__ void run(int tid, int tn, struct ncclDevWorkColl* work) {
// Put NOT IMPLEMENTED behavior here.
}
};
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
struct RunWorkBatch;
// Specialized for P2p in sendrecv.h
@@ -407,7 +407,7 @@ template<typename T, typename RedOp>
struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE>;
// Specialized here for non-P2p (Coll and CollReg)
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
struct RunWorkBatch {
// This __forceinline__ is necessary. The compiler was inserting a function call
// here from the LL ncclKernel.
@@ -437,7 +437,7 @@ struct RunWorkBatch {
// Coverity reports a possible thread divergence due to not all threads participating in the collective.
// However, the code ensures that the participation is on a per-warp basis.
// coverity[device_thread_diverged:FALSE]
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, COLL_UNROLL>().run(tid, subtn, work);
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, USE_ACC, COLL_UNROLL>().run(tid, subtn, work);
}
}
};
@@ -672,14 +672,14 @@ __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONST
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
#ifdef USE_INDIRECT_FUNCTION_CALL
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
__device__ void ncclDevFunc_##suffix() { \
RunWorkBatch<coll, ty, redop<ty>, algo, proto, unroll>().run(); \
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
}
#else
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
__device__ __attribute__((noinline)) void ncclDevFunc_##suffix() { \
RunWorkBatch<coll, ty, redop<ty>, algo, proto, unroll>().run(); \
RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
}
#endif
+4 -4
Просмотреть файл
@@ -618,7 +618,7 @@ __device__ __attribute__((noinline)) void reduceCopyPacksWithBias(
thread = warp*WARP_SIZE + lane;
}
template<int Unroll, typename RedFn, typename T,
template<int Unroll, int useAcc, typename RedFn, typename T,
int MultimemSrcs, int MinSrcs, int MaxSrcs,
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
@@ -641,7 +641,7 @@ __device__ __forceinline__ void reduceCopy(
IntBytes nBytesBehind = 0;
IntBytes nBytesAhead = nElts*sizeof(T);
bool useAcc = accPtrFn() != nullptr;
//bool useAcc = accPtrFn() != nullptr;
#if __cpp_if_constexpr
if constexpr (BigPackSize > sizeof(T)) {
@@ -763,7 +763,7 @@ __device__ __forceinline__ void reduceCopy(
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
}
template<int Unroll, typename RedFn, typename T,
template<int Unroll, int useAcc, typename RedFn, typename T,
int MultimemSrcs, int MinSrcs, int MaxSrcs,
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
typename IntBytes>
@@ -773,7 +773,7 @@ __device__ __forceinline__ void reduceCopy(
int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs,
IntBytes nElts, void *accPtr = nullptr
) {
reduceCopy<Unroll, RedFn, T,
reduceCopy<Unroll, useAcc, RedFn, T,
MultimemSrcs, MinSrcs, MaxSrcs,
MultimemDsts, MinDsts, MaxDsts, PreOpSrcs, IntBytes>
(thread, nThreads, redArg, preOpArgs, postOp,
+67 -53
Просмотреть файл
@@ -4,14 +4,15 @@ import sys
import subprocess
# Order of redops, tys, protos, algos must match src/include/device.h
all_colls = ["AllGather","AllReduce","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"]
all_colls = ["AllGather","AllReduce","AllReduceWithBias","AllToAllPivot","Broadcast","Reduce","ReduceScatter","SendRecv"]
all_redops = ["Sum","Prod","MinMax","PreMulSum","SumPostDiv"]
all_tys = ["i8","u8","i32","u32","i64","u64","f16","f32","f64","bf16","f8e4m3","f8e5m2"]
all_protos = ["LL","LL128","SIMPLE"]
all_algos = ["TREE","RING", "PAT"]
all_unroll = ["1", "2", "4"]
use_acc = ["0", "1"]
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, all_unroll]
all_params = [all_colls, all_algos, all_protos, all_redops, all_tys, use_acc, all_unroll]
################################################################################
# The first command line argument is the path to the directory to generate and
@@ -76,43 +77,47 @@ func_pattern = sys.argv[6:7]
if func_pattern and func_pattern[0]:
func_pattern = func_pattern[0]
else:
func_pattern = "AllGather|AllReduce|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
func_pattern = "AllGather|AllReduce|AllReduceWithBias|AllToAllPivot|Broadcast|Reduce|ReduceScatter|SendRecv"
################################################################################
algos_of_coll = {
"AllGather": ["RING", "PAT"],
"AllReduce": ["RING", "TREE"],
"AllToAllPivot": ["RING"],
"Broadcast": ["RING"],
"Reduce": ["RING"],
"ReduceScatter": ["RING", "PAT"],
"SendRecv": ["RING"]
"AllGather": ["RING", "PAT"],
"AllReduce": ["RING", "TREE"],
"AllReduceWithBias": ["RING", "TREE"],
"AllToAllPivot": ["RING"],
"Broadcast": ["RING"],
"Reduce": ["RING"],
"ReduceScatter": ["RING", "PAT"],
"SendRecv": ["RING"]
}
protos_of_coll = {
"AllGather": all_protos,
"AllReduce": all_protos,
"AllToAllPivot": ["SIMPLE"],
"Broadcast": all_protos,
"Reduce": all_protos,
"ReduceScatter": all_protos,
"SendRecv": ["SIMPLE"]
"AllGather": all_protos,
"AllReduce": all_protos,
"AllReduceWithBias": all_protos,
"AllToAllPivot": ["SIMPLE"],
"Broadcast": all_protos,
"Reduce": all_protos,
"ReduceScatter": all_protos,
"SendRecv": ["SIMPLE"]
}
redops_of_coll = {
"AllGather": ["Sum"],
"AllReduce": all_redops,
"AllToAllPivot": ["Sum"],
"Broadcast": ["Sum"],
"Reduce": all_redops,
"ReduceScatter": all_redops,
"SendRecv": ["Sum"]
"AllGather": ["Sum"],
"AllReduce": all_redops,
"AllReduceWithBias": all_redops,
"AllToAllPivot": ["Sum"],
"Broadcast": ["Sum"],
"Reduce": all_redops,
"ReduceScatter": all_redops,
"SendRecv": ["Sum"]
}
tys_of_coll = {
"AllGather": ["i8"],
"AllReduce": all_tys,
"AllReduceWithBias": all_tys,
"AllToAllPivot": ["i8"],
"Broadcast": ["i8"],
"Reduce": all_tys,
@@ -121,11 +126,12 @@ tys_of_coll = {
}
coll_camel_to_lower = {
"AllGather": "all_gather",
"AllReduce": "all_reduce",
"AllToAllPivot": "alltoall_pivot",
"Broadcast": "broadcast",
"Reduce": "reduce",
"AllGather": "all_gather",
"AllReduce": "all_reduce",
"AllReduceWithBias": "allreduce_with_bias",
"AllToAllPivot": "alltoall_pivot",
"Broadcast": "broadcast",
"Reduce": "reduce",
"ReduceScatter": "reduce_scatter",
"SendRecv": "sendrecv"
}
@@ -174,7 +180,11 @@ def calc_unroll_for_local_arch():
return all_unroll
# Helper function to check whether the conditions for the collective are met
def func_validate(coll, algo, proto, redop, ty, unroll):
def func_validate(coll, algo, proto, redop, ty, acc, unroll):
if acc == "1" and coll != "AllReduceWithBias":
return False
if acc == "0" and coll == "AllReduceWithBias":
return False
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
return False
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
@@ -222,10 +232,10 @@ def func_filter(function_params, current_idx, item_list=None):
# For each loop layer remove the last element in item_list
item_list.pop()
else:
coll, algo, proto, redop, ty, unroll = item_list
coll, algo, proto, redop, ty, acc, unroll = item_list
if func_validate(coll, algo, proto, redop, ty, unroll):
yield(coll, algo, proto, redop, ty, unroll)
if func_validate(coll, algo, proto, redop, ty, acc, unroll):
yield(coll, algo, proto, redop, ty, acc, unroll)
# Parse ONLY_FUNCS input and feed it to func_filter
def parse_input(func_pattern):
@@ -245,33 +255,35 @@ def parse_input(func_pattern):
# Maps functions to the chosen representative for the equivalence class it
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
def equivalent_primary(coll, algo, proto, redop, ty, acc, unroll):
if coll in ("AllReduce", "AllReduceWithBias", "Reduce", "ReduceScatter"):
# map signed integer sum/prod to unsigned
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
ty = "u"+ty[1:]
# map signed integer min/max to unsigned for non-NVLS
elif redop=="MinMax" and ty[0]=="i" and ("NVLS" not in algo):
ty = "u"+ty[1:]
return (coll, algo, proto, redop, ty, unroll)
return (coll, algo, proto, redop, ty, acc, unroll)
# The order in which rows are enumerated must match the formula of `ncclDevFuncId()`:
def enumerate_func_rows():
for unroll in all_unroll:
for coll in all_colls:
for algo in all_algos:
for proto in all_protos:
for redop in all_redops:
for ty in all_tys:
if func_validate(coll, algo, proto, redop, ty, unroll):
yield (coll, algo, proto, redop, ty, unroll)
for acc in use_acc:
for unroll in all_unroll:
for coll in all_colls:
for algo in all_algos:
for proto in all_protos:
for redop in all_redops:
for ty in all_tys:
if func_validate(coll, algo, proto, redop, ty, acc, unroll):
yield (coll, algo, proto, redop, ty, acc, unroll)
# Sort the hashmap based on custom key <unroll> <acc> <coll> <algo> <proto> <redop> <ty>
def custom_sort_key(fn):
coll, algo, proto, redop, ty, unroll = fn
coll, algo, proto, redop, ty, acc, unroll = fn
return (
all_unroll.index(unroll),
use_acc.index(acc),
all_colls.index(coll),
all_algos.index(algo),
all_protos.index(proto),
@@ -320,7 +332,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
index1 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
coll, algo, proto, redop, ty, acc, unroll = fn
if unroll != "1": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
@@ -337,7 +349,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
index2 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
coll, algo, proto, redop, ty, acc, unroll = fn
if unroll != "2": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
@@ -354,7 +366,7 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
index4 = 0
for fn in primary_funcs:
coll, algo, proto, redop, ty, unroll = fn
coll, algo, proto, redop, ty, acc, unroll = fn
if unroll != "4": continue
sym = paste("_", "ncclDevFunc", *fn)
if fn[2] == "LL128":
@@ -469,7 +481,7 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
# Maps to .cu filename which implements this func. The only constraint is that
# "coll" is reflected in the name: formally that no two funcs having different
# coll's map to the same filename.
def impl_filename(coll, algo, proto, redop, ty, unroll):
def impl_filename(coll, algo, proto, redop, ty, acc, unroll):
return "%s.cpp" % paste("_", coll_camel_to_lower[coll], redop and redop.lower(), ty)
# Partition the functions and kernels to the .cu filenames. The partition is
@@ -518,6 +530,8 @@ for name in name_to_funcs.keys():
print("-- Generating %s" % os.path.join(gensrc, name))
out = f.write
if coll == "AllReduceWithBias":
coll = "AllReduce"
out(
'#include "common.h"\n'
'#include "{lower_coll}.h"\n'
@@ -525,14 +539,14 @@ for name in name_to_funcs.keys():
)
for fn in fns:
(coll, algo, proto, redop, ty, unroll) = fn
sym = paste("_", coll, algo, proto, redop, ty, unroll)
(coll, algo, proto, redop, ty, acc, unroll) = fn
sym = paste("_", coll, algo, proto, redop, ty, acc, unroll)
if proto == "LL128":
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
out(
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {unroll})\n"
"DEFINE_ncclDevFunc({sym}, ncclFunc{coll}, {redop_cxx}, {ty_cxx}, NCCL_ALGO_{algo}, NCCL_PROTO_{proto}, {acc}, {unroll})\n"
.format(sym=sym, coll=coll, redop_cxx=redop_to_cxx[redop], ty_cxx=ty_to_cxx[ty],
algo=(algo or "RING"), proto=(proto or "SIMPLE"), unroll=unroll)
algo=(algo or "RING"), proto=(proto or "SIMPLE"), acc=acc, unroll=unroll)
)
if proto == "LL128":
out("#endif\n")
+1 -1
Просмотреть файл
@@ -376,7 +376,7 @@ __global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct n
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128, fullOps>(comm, algo, work); \
} \
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS, 2>, fullOps>(comm, algo, work); \
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS, 0, 2>, fullOps>(comm, algo, work); \
}
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
+1 -1
Просмотреть файл
@@ -49,7 +49,7 @@ namespace {
redOpArg = *reinterpret_cast<uint64_t*>(redOpArg);
}
}
reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/1>
reduceCopy<COLL_UNROLL,0, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/1>
(tid, tn, redOpArg, &redOpArg, true, 1, &src, 1, &dst, i1-i0);
}
}
+1 -1
Просмотреть файл
@@ -55,7 +55,7 @@
* to how that protocol operates with a consistent interface so that our
* algorithm code can operate protocol parametrically.
*/
template<int SlicePerChunk_1, int StepPerSlice_1, int Unroll_1, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
template<int SlicePerChunk_1, int StepPerSlice_1,int useAcc, int Unroll_1, int MultimemSrcs_1 = 0, int MultimemDsts_1 = 0>
struct ProtoSimple {
static constexpr int Id = NCCL_PROTO_SIMPLE;
static constexpr int SlicePerChunk = SlicePerChunk_1;
+4 -4
Просмотреть файл
@@ -10,9 +10,9 @@
#include "npkit/npkit.h"
#endif
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload>
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>> {
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int useAcc>
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, useAcc>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, useAcc>> {
// In the case of Fan::MaxRecv == 0, we need to force MaxRecv to 1 for this to compile
// This is because of a recv buffer which is allocated to MaxRecv length in send-only cases
@@ -437,7 +437,7 @@ private:
constexpr int DST = DstBuf != -1 ? 1 : 0;
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
T *accElts = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
T *accElts = (DstBuf == -1 || !useAcc) ? nullptr : userBufs[Acc] + dstIx;
// Always waitSend in case of cleanup
nelem = nelem < 0 ? 0 : nelem;
+4 -4
Просмотреть файл
@@ -20,9 +20,9 @@
#endif
#endif
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload>
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>> {
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int useAcc>
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, useAcc>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, useAcc>> {
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1, Acc=2;;
@@ -366,7 +366,7 @@ private:
constexpr int DST = DstBuf != -1 ? 1 : 0;
T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
T *accPtr = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
T *accPtr = (DstBuf == -1 || !useAcc) ? nullptr : userBufs[Acc] + dstIx;
int wireOffset = WireWordPerSlice*warp + 2*wid;
const int nwarps = nthreads/WARP_SIZE;
nelem = nelem < 0 ? 0 : nelem;
+14 -14
Просмотреть файл
@@ -23,9 +23,9 @@ enum primsMode {
};
template<typename T, typename RedOp, typename Fan, int Direct,
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata>
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata, int useAcc>
class Primitives<
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll, MultimemSrcs, MultimemDsts>, P2p, isNetOffload, Metadata
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, useAcc, Unroll, MultimemSrcs, MultimemDsts>, P2p, isNetOffload, Metadata
> {
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
static constexpr int Input=0, Output=1;
@@ -299,7 +299,7 @@ private:
}
#endif
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
(tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false,
1, ncclShmem.groups[group].srcs,
fan.nsend(), ncclShmem.groups[group].dsts+1,
@@ -335,7 +335,7 @@ private:
}
#endif
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
(tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp,
Recv, ncclShmem.groups[group].srcs,
Dst, ncclShmem.groups[group].dsts,
@@ -373,7 +373,7 @@ private:
DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
if (Send && Dst && ncclShmem.groups[group].dsts[1] == nullptr) {
// this case should only be directCopySend() with registered buffers and send to net peer
reduceCopy<Unroll, RedOp, T,
reduceCopy<Unroll, useAcc, RedOp, T,
0, Recv + Src, Recv * MaxRecv + Src,
0, 1, 1, PreOpSrcs>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
@@ -381,7 +381,7 @@ private:
1, ncclShmem.groups[group].dsts,
workSize);
} else {
reduceCopy<Unroll, RedOp, T,
reduceCopy<Unroll, useAcc, RedOp, T,
MultimemSrcs, Recv + Src, Recv * MaxRecv + Src,
MultimemDsts, Send + Dst, Send * MaxSend + Dst, PreOpSrcs>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
@@ -453,19 +453,19 @@ private:
srcs[nsrcs] = dsts[0];
nsrcs++;
if (MULTISRCS){
reduceCopy<Unroll, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nsrcs, (void **)srcs, 1, (void **)dsts, nelem);
} else {
reduceCopy<Unroll, RedOp, T, 0, 2, 2, 0, 1, 1, 0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 2, 2, 0, 1, 1, 0>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 2, (void **)srcs, 1, (void **)dsts, nelem);
}
}
if (COPY){
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)srcs, 1, (void **)dsts, nelem);
if (MULTISRCS) {
for (int i = 1; i < nsrcs; i++){
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)&srcs[i], 1, (void **)&dsts[i], nelem);
}
}
@@ -617,7 +617,7 @@ private:
void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset;
ssize_t realPeerSize = min(realSize, totalElem-pOffset);
if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) {
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
// Mark for threadfence at the end
fenceNeeded |= true;
}
@@ -637,7 +637,7 @@ private:
void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset;
ssize_t realPeerSize = min(realSize, totalElem-pOffset);
if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0;
if (realPeerSize > 0) reduceCopy<Unroll, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
if (realPeerSize > 0) reduceCopy<Unroll, useAcc, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
}
}
}
@@ -1190,7 +1190,7 @@ public:
int workSize = ncclShmem.aborted ? 0 : nelem;
reduceCopy<Unroll, RedOp, T, 0, 1, 2, 0, 1, 1, /*PreOpSrcs*/0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 2, 0, 1, 1, /*PreOpSrcs*/0>
(tid, nthreads, ncclShmem.redOpArgs[0], nullptr, /*postOp=*/false,
nSrcs, srcs, 1, ncclShmem.groups[group].dsts, workSize);
@@ -1284,7 +1284,7 @@ public:
int workSize = ncclShmem.aborted ? 0 : nelem;
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 2, /*PreOpSrcs*/0>
reduceCopy<Unroll, useAcc, RedOp, T, 0, 1, 1, 0, 1, 2, /*PreOpSrcs*/0>
(tid, nthreads, ncclShmem.redOpArgs[0], nullptr, /*postOp=*/false,
1, ncclShmem.groups[group].srcs, nDsts, dsts, workSize);
+4 -4
Просмотреть файл
@@ -258,7 +258,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
if (!work->regUsed) {
if (tid < tidEndScatter) {
// Scatter
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, NULL, nvls->up, work->sendbuff, NULL,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -270,7 +270,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
// coverity[overrun-call] => Coverity think prims.index can be greater than 1
} else if (tid < tidEndReduce) {
// Reduce through NVLS
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, work->recvbuff,
work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0);
@@ -283,7 +283,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
} else {
if (tid < tidEndScatter) {
// Scatter
using Proto = ProtoSimple<1, 1, COLL_UNROLL>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL>;
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_NVLS_ARITY>, /*Direct=*/0, Proto, 0>
prims(tid, nThreadsScatter, nvls->up, nvls->up, NULL, NULL,
work->redOpArg, 0 * Proto::MaxGroupWidth, 1, 1);
@@ -295,7 +295,7 @@ struct RunWorkColl<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_NVLS, NCCL_PROTO_S
prims.gather(0, 0, 0, 0, -1, 0);
} else if (tid < tidEndReduce) {
// Reduce through NVLS
using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
using Proto = ProtoSimple<1, 1, USE_ACC, COLL_UNROLL, 1, 0>;
Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0>
prims(tid - tidEndScatter, nThreadsReduce, &nvls->down, &nvls->down, NULL, work->recvbuff,
work->redOpArg, 3 * Proto::MaxGroupWidth, 0, 0, work);
+7 -7
Просмотреть файл
@@ -247,10 +247,10 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
if (isCopy) {
#if defined(__gfx942__) || defined(__gfx950__)
reduceCopy<COLL_UNROLL*2, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
reduceCopy<COLL_UNROLL*2, 0, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
(subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes);
#else
reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
reduceCopy<COLL_UNROLL, 0, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
(subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes);
#endif
} else if (isSend) {
@@ -258,9 +258,9 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
runSend<ProtoLL>(subtid, subtn, group, work);
} else {
#if defined(__gfx90a__)
runSend<ProtoSimple<1,1,8>>(subtid, subtn, group, work);
runSend<ProtoSimple<1,1,0,8>>(subtid, subtn, group, work);
#elif defined(__gfx908__) || defined(__gfx942__) || defined(__gfx950__)
runSend<ProtoSimple<1,1,4>>(subtid, subtn, group, work);
runSend<ProtoSimple<1,1,0,4>>(subtid, subtn, group, work);
#else
runSend<ProtoSimple<1,1>>(subtid, subtn, group, work);
#endif
@@ -270,13 +270,13 @@ struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
runRecv<ProtoLL>(subtid, subtn, group, work);
} else {
#if defined(__gfx90a__)
runRecv<ProtoSimple<1,1,8>>(subtid, subtn, group, work);
runRecv<ProtoSimple<1,1,0,8>>(subtid, subtn, group, work);
#elif defined(__gfx908__) || defined(__gfx942__) || defined(__gfx950__)
runRecv<ProtoSimple<1,1,4>>(subtid, subtn, group, work);
runRecv<ProtoSimple<1,1,0,4>>(subtid, subtn, group, work);
#else
runRecv<ProtoSimple<1,1>>(subtid, subtn, group, work);
#endif
}
}
}
};
};
+1 -1
Просмотреть файл
@@ -755,6 +755,6 @@ inline int ncclDevFuncId(int coll, int devRedOp, int type, int algo, int proto)
return ncclDevFuncRowToId[row];
}
inline int ncclDevFuncId_P2p() { return ncclDevFuncRowToId[FUNC_INDEX_TOTAL - NCCL_NUM_ONERANK - 1]; }
// Looks up the device function id stored for the P2P entry.
// The row index is FUNC_INDEX_TOTAL minus the AllReduce-with-bias rows,
// minus the one-rank rows, minus one.
inline int ncclDevFuncId_P2p() {
  const int p2pRow = FUNC_INDEX_TOTAL - AR_WITH_BIAS_FUNC_COUNTS - NCCL_NUM_ONERANK - 1;
  return ncclDevFuncRowToId[p2pRow];
}
#endif
+8 -6
Просмотреть файл
@@ -40,7 +40,8 @@ typedef enum {
// Pointer-to-function type for a debug logging callback.
// Arguments, in order: log level, flags word, file name, line number,
// and a format string followed by its variadic arguments.
// NOTE(review): fmt looks printf-style and file/line look like the call site — confirm.
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
#define NCCL_NUM_ONERANK 12
#define FUNC_INDEX_TOTAL 821 + NCCL_NUM_ONERANK
// Row count that ncclDevFuncId_P2p() subtracts (together with NCCL_NUM_ONERANK)
// from FUNC_INDEX_TOTAL when indexing ncclDevFuncRowToId.
// NOTE(review): presumably the number of rows generated for
// ncclFuncAllReduceWithBias — confirm against the function-table generator.
#define AR_WITH_BIAS_FUNC_COUNTS 324
// Sum of 821, the AllReduce-with-bias rows and the one-rank rows.
// The replacement list is parenthesized so the macro behaves as a single
// value inside larger expressions (e.g. `2 * FUNC_INDEX_TOTAL` or
// `x - FUNC_INDEX_TOTAL`); without the parentheses operator precedence
// would split the sum.
#define FUNC_INDEX_TOTAL (821 + AR_WITH_BIAS_FUNC_COUNTS + NCCL_NUM_ONERANK)
#define NCCL_NUM_FUNCTIONS 5 // Send/Recv not included for now
typedef enum {
@@ -49,11 +50,12 @@ typedef enum {
ncclFuncAllGather = 2,
ncclFuncReduceScatter = 3,
ncclFuncAllReduce = 4,
ncclFuncSendRecv = 5,
ncclFuncSend = 6,
ncclFuncRecv = 7,
ncclFuncAllToAllPivot = 8,
ncclNumFuncs = 9
ncclFuncAllReduceWithBias = 5,
ncclFuncSendRecv = 6,
ncclFuncSend = 7,
ncclFuncRecv = 8,
ncclFuncAllToAllPivot = 9,
ncclNumFuncs = 10
} ncclFunc_t;
#define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*