Files
rocm-systems/src/device/reduce.h
T
2025-03-27 12:53:04 -05:00

81 γραμμές
3.3 KiB
C++

/*************************************************************************
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "device.h"
#include "collectives.h"
#include "primitives.h"
namespace {
  // Ring implementation of the Reduce collective for the channel described by ncclShmem.
  //
  // Template parameters:
  //   T      element type (its size is handed to ncclCollCbdPart)
  //   RedOp  reduction operator type forwarded to Primitives
  //   Proto  protocol type forwarded to Primitives; Proto::Id is handed to ncclCollCbdPart
  //
  // Parameters:
  //   tid, nthreads  forwarded unchanged to the Primitives constructor
  //   work           collective descriptor; root, sendbuff, recvbuff, redOpArg and
  //                  connIndex are read from it
  //
  // The element range [gridOffset, gridOffset + channelCount) obtained from
  // ncclCollCbdPart is walked in steps of chunkCount elements. The primitive issued
  // for each chunk depends on where this rank sits relative to the root:
  //   - the last entry of ring->userRanks is the root  -> send only
  //   - this rank is the root                          -> recvReduceCopy (postOp enabled)
  //   - otherwise                                      -> recvReduceSend
  //
  // The function is forced out-of-line except when USE_INDIRECT_FUNCTION_CALL is
  // defined and the target is neither gfx942 nor gfx950.
  template<typename T, typename RedOp, typename Proto>
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__)
  __device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
#else
  __device__ __attribute__((noinline)) void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
#endif
    ncclRing* const ringInfo = &ncclShmem.channel.ring;
    const int numRanks = ncclShmem.comm.nRanks;
    const int myRank = ncclShmem.comm.rank;
    // Last slot of the ring's userRanks table (presumably this rank's ring predecessor).
    const int prevRank = ringInfo->userRanks[numRanks - 1];
    const int root = work->root;

    // Per-channel partition, filled in by ncclCollCbdPart (all counts are in elements).
    size_t chunkCount;
    size_t channelCount;
    size_t gridOffset;
    ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T),
                    (size_t*)nullptr, &gridOffset, &channelCount, &chunkCount);

    // Coverity reports that the callee treats &ring->next as an array. However, due to the use of
    // FanSymmetric<1>, only the first element is ever accessed, so it's fine.
    // coverity[callee_ptr_arith:FALSE]
    Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0> prims(
        tid, nthreads, &ringInfo->prev, &ringInfo->next,
        work->sendbuff, work->recvbuff, work->redOpArg,
        0, work->connIndex, work->connIndex);

    if (prevRank == root) {
      // Only push local data to the next rank; nothing is received here.
      for (size_t chunkStart = 0; chunkStart < channelCount; chunkStart += chunkCount) {
        const size_t elemIdx = gridOffset + chunkStart;
        // The final chunk may be shorter than chunkCount.
        const int count = min(chunkCount, channelCount - chunkStart);
        prims.send(elemIdx, count);
      }
    }
    else if (myRank == root) {
      // Root: receive, reduce with local input and store at the same offset in the output.
      for (size_t chunkStart = 0; chunkStart < channelCount; chunkStart += chunkCount) {
        const size_t elemIdx = gridOffset + chunkStart;
        const int count = min(chunkCount, channelCount - chunkStart);
        prims.recvReduceCopy(elemIdx, elemIdx, count, /*postOp=*/true);
      }
    }
    else {
      // Intermediate rank: receive, reduce with local input and forward.
      for (size_t chunkStart = 0; chunkStart < channelCount; chunkStart += chunkCount) {
        const size_t elemIdx = gridOffset + chunkStart;
        const int count = min(chunkCount, channelCount - chunkStart);
        prims.recvReduceSend(elemIdx, count);
      }
    }
  }
}
// Reduce / ring algorithm / SIMPLE protocol: forwards to runRing with ProtoSimple
// parameterized by REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS and REDUCE_SLICESTEPS.
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    runRing<T, RedOp, ProtoSimple<REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS>>(
        tid, nthreads, work);
  }
};
// Reduce / ring algorithm / LL protocol: forwards to runRing with ProtoLL.
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    using Proto = ProtoLL;
    runRing<T, RedOp, Proto>(tid, nthreads, work);
  }
};
// Reduce / ring algorithm / LL128 protocol: forwards to runRing with ProtoLL128.
template<typename T, typename RedOp>
struct RunWorkColl<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
  __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
    using Proto = ProtoLL128;
    runRing<T, RedOp, Proto>(tid, nthreads, work);
  }
};