Support fused all reduce and elementwise operations (#1729)
* Support fused all reduce and elementwise operations Add additional "acc" parameter to RCCL Replayer logs Add flag which indicates availability of new API * Fix Recorder json parsing * Remove unreachable code * Remove extra acc pointer check * . * Revert "[DEVICE] Adding ability to choose unroll factor at runtime (#1734)" This reverts commit4cadf3597c. * Use noinline to reduce kernels linking time * Don't use noinline for gfx942 and gfx950 to avoid perf regression --------- Co-authored-by: AtlantaPepsi <timhu102@amd.com> Co-authored-by: BertanDogancay <bertan.dogancay@gmail.com> [ROCm/rccl commit:9a4213356d]
Этот коммит содержится в:
@@ -14,10 +14,9 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
|
||||
### Added
|
||||
|
||||
* Added new GPU target `gfx950`.
|
||||
* Added support for `unroll=1` in device-code generation to improve performance.
|
||||
* Set a default of 112 channels for a single node with `8 * gfx950`.
|
||||
* Added support for `unroll=1` in device-code generation to improve performance,
|
||||
* Set a default of 112 channels for a single node with `8 * gfx950`,
|
||||
* Enabled LL128 protocol on `gfx950`.
|
||||
* Adding ability to choose unroll factor at runtime via `RCCL_UNROLL_FACTOR`. This can be set at runtime to 1, 2, or 4. This change currently increases compilation and linking time because it triples the number of kernels generated.
|
||||
* Added MSCCL support for AllGather multinode gfx942/gfx950 (i.e., 16 and 32 GPUs). To enable, set the environment variable `RCCL_MSCCL_FORCE_ENABLE=1`. Max message size for MSCCL AllGather usage is `12292 * sizeof(datatype) * nGPUs`.
|
||||
* Thread thresholds for LL/LL128 are selected in Tuning Models for the MI300X. This impacts the number of channels used for AG and RS. Channel tuning model is bypassed if `NCCL_THREAD_THRESHOLDS`, `NCCL_MIN_NCHANNELS', or 'NCCL_MAX_NCHANNELS` are set.
|
||||
* Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges.
|
||||
|
||||
@@ -778,13 +778,9 @@ endif()
|
||||
|
||||
set(GEN_DIR "${HIPIFY_DIR}/gensrc")
|
||||
|
||||
if(ONLY_FUNCS)
|
||||
message(WARNING "Using ONLY_FUNCS = ${ONLY_FUNCS}. Not meant for release builds.")
|
||||
endif()
|
||||
|
||||
# Execute the python script to generate required files
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${ONLY_FUNCS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_SOURCE_DIR}/src/device/generate.py ${GEN_DIR} ${IFC_ENABLED} ${COLLTRACE} ${ENABLE_MSCCL_KERNEL} ${BUILD_LOCAL_GPU_TARGET_ONLY} ${ONLY_FUNCS}
|
||||
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||
RESULT_VARIABLE gen_py_result
|
||||
ERROR_VARIABLE gen_py_error
|
||||
|
||||
@@ -89,7 +89,7 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen
|
||||
|
||||
struct ncclInfo info = { ncclFuncAllGather, "AllGather",
|
||||
sendbuff, recvbuff, sendcount, datatype, ncclSum, 0, comm, stream, /* Args */
|
||||
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS };
|
||||
ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -117,7 +117,7 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
|
||||
// RCCL update slice steps for AllReduce if single node
|
||||
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
|
||||
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
|
||||
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS };
|
||||
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -135,6 +135,26 @@ ncclResult_t ncclAllReduce_impl(const void* sendbuff, void* recvbuff, size_t cou
|
||||
return ncclEnqueueCheck(&info);
|
||||
}
|
||||
|
||||
ncclResult_t ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc) {
|
||||
NVTX3_FUNC_WITH_PARAMS(AllReduce, NcclNvtxParamsAllReduce,
|
||||
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), op, datatype));
|
||||
|
||||
if (acc == nullptr) {
|
||||
WARN("ncclAllReduceWithBias : acc cannot be nullptr");
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
|
||||
// RCCL update slice steps for AllReduce if single node
|
||||
struct ncclInfo info = { ncclFuncAllReduce, "AllReduce",
|
||||
sendbuff, recvbuff, count, datatype, op, 0, comm, stream, /* Args */
|
||||
ALLREDUCE_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLREDUCE_SLICESTEPS_SINGLE_NODE : ALLREDUCE_SLICESTEPS, acc };
|
||||
|
||||
NCCLCHECK(Recorder::instance().record(rrAllReduceWithBias, info));
|
||||
|
||||
return ncclEnqueueCheck(&info);
|
||||
}
|
||||
|
||||
RCCL_PARAM(AllToAllPivotEnable, "ALL_TO_ALL_PIVOT_ENABLE", 0);
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
@@ -164,7 +184,7 @@ ncclResult_t ncclAllToAll_impl(const void* sendbuff, void* recvbuff, size_t coun
|
||||
rankOffset >= 744 * 1024 && rankAlign != 4 && rcclParamAllToAllPivotEnable()) {
|
||||
struct ncclInfo info = { ncclFuncAllToAllPivot, "AllToAllPivot",
|
||||
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream, /* Args */
|
||||
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS };
|
||||
ALLTOALL_PIVOT_CHUNKSTEPS, ALLTOALL_PIVOT_SLICESTEPS, nullptr };
|
||||
return ncclEnqueueCheck(&info);
|
||||
} else {
|
||||
int nRanks;
|
||||
@@ -240,7 +260,7 @@ ncclResult_t ncclBroadcast_impl(const void* sendbuff, void* recvbuff, size_t cou
|
||||
|
||||
struct ncclInfo info = { ncclFuncBroadcast, "Broadcast",
|
||||
sendbuff, recvbuff, count, datatype, ncclSum, root, comm, stream, /* Args */
|
||||
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS };
|
||||
BROADCAST_CHUNKSTEPS, BROADCAST_SLICESTEPS, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -311,7 +331,7 @@ ncclResult_t ncclReduce_impl(const void* sendbuff, void* recvbuff, size_t count,
|
||||
|
||||
struct ncclInfo info = { ncclFuncReduce, "Reduce",
|
||||
sendbuff, recvbuff, count, datatype, op, root, comm, stream, /* Args */
|
||||
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS };
|
||||
REDUCE_CHUNKSTEPS, REDUCE_SLICESTEPS, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -338,7 +358,7 @@ ncclResult_t ncclReduceScatter_impl(const void* sendbuff, void* recvbuff, size_t
|
||||
|
||||
struct ncclInfo info = { ncclFuncReduceScatter, "ReduceScatter",
|
||||
sendbuff, recvbuff, recvcount, datatype, op, 0, comm, stream, /* Args */
|
||||
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS };
|
||||
REDUCESCATTER_CHUNKSTEPS, comm -> rcclUseOneSlice ? REDUCESCATTER_SLICESTEPS_SINGLE_NODE : REDUCESCATTER_SLICESTEPS, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -403,7 +423,7 @@ ncclResult_t ncclSend_impl(const void* sendbuff, size_t count, ncclDataType_t da
|
||||
|
||||
struct ncclInfo info = { ncclFuncSend, "Send",
|
||||
NULL, (void*)sendbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
|
||||
1, 1 };
|
||||
1, 1, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
@@ -429,7 +449,7 @@ ncclResult_t ncclRecv_impl(void* recvbuff, size_t count, ncclDataType_t datatype
|
||||
|
||||
struct ncclInfo info = { ncclFuncRecv, "Recv",
|
||||
NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
|
||||
1, 1 };
|
||||
1, 1, nullptr };
|
||||
|
||||
if (!mscclIsCaller()) // when msccl falls back to
|
||||
{
|
||||
|
||||
@@ -589,7 +589,12 @@ struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
|
||||
runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(tid, nthreads, work);
|
||||
using Proto = ProtoSimple<1, 1>;
|
||||
if (work->acc != nullptr) {
|
||||
runTreeSplit<T, RedOp, Proto>(tid, nthreads, work);
|
||||
} else {
|
||||
runTreeUpDown<T, RedOp, Proto>(tid, nthreads, work);
|
||||
}
|
||||
// Check-here
|
||||
// #if CUDART_VERSION >= 11020 && CUDART_VERSION < 11040 && __CUDA_ARCH__ >= 800
|
||||
// runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(tid, nthreads, work);
|
||||
|
||||
@@ -13,6 +13,31 @@ __shared__ ncclShmemData ncclShmem;
|
||||
__shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)];
|
||||
#endif
|
||||
|
||||
struct RunWorkNop {
|
||||
__device__ void run() {}
|
||||
};
|
||||
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args);
|
||||
}
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args);
|
||||
}
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
|
||||
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
__device__ void ncclDevFunc_Nop();
|
||||
#else
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
#include "collectives.h"
|
||||
#include "device.h"
|
||||
#include "op128.h"
|
||||
#include "device_table.h"
|
||||
#include "reduce_kernel.h"
|
||||
#include "device_table.h"
|
||||
#include "network/unpack/unpack_defs.h"
|
||||
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
||||
|
||||
@@ -121,8 +121,10 @@ struct ncclShmemGroup {
|
||||
ncclConnInfo *sendConns[NCCL_MAX_ARITY];
|
||||
void* userInput;
|
||||
void* userOutput;
|
||||
void* userAcc;
|
||||
void* srcs[NCCL_MAX_ARITY+1];
|
||||
void* dsts[NCCL_MAX_ARITY+1];
|
||||
void* acc;
|
||||
uint64_t barrier;
|
||||
union {
|
||||
unpackGroupShmem unpack;
|
||||
@@ -606,7 +608,21 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
|
||||
if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) {
|
||||
SpecializedRunWorkBatch().run();
|
||||
} else {
|
||||
NCCL_CALL_FUNCTIONS<COLL_UNROLL>(ncclShmem.funcId);
|
||||
#ifdef USE_INDIRECT_FUNCTION_CALL
|
||||
if (COLL_UNROLL == 1)
|
||||
ncclDevFuncTable_1[ncclShmem.funcId]();
|
||||
else if (COLL_UNROLL == 2)
|
||||
ncclDevFuncTable_2[ncclShmem.funcId]();
|
||||
else
|
||||
ncclDevFuncTable_4[ncclShmem.funcId]();
|
||||
#else
|
||||
if (COLL_UNROLL == 1)
|
||||
NCCL_CALL_FUNCTIONS_1(ncclShmem.funcId);
|
||||
else if (COLL_UNROLL == 2)
|
||||
NCCL_CALL_FUNCTIONS_2(ncclShmem.funcId);
|
||||
else
|
||||
NCCL_CALL_FUNCTIONS_4(ncclShmem.funcId);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (ncclShmem.nextBatchIx == -1) break;
|
||||
@@ -643,6 +659,15 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
|
||||
#endif
|
||||
}
|
||||
|
||||
__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
|
||||
#endif
|
||||
|
||||
#define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \
|
||||
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
|
||||
|
||||
|
||||
@@ -31,7 +31,11 @@ template<typename RedFn, typename T, int Unroll, int BytePerPack,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn>
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
__device__ __forceinline__ void reduceCopyPacks(
|
||||
#else
|
||||
__device__ __attribute__((noinline)) void reduceCopyPacks(
|
||||
#endif
|
||||
int nThreads, int &thread,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
|
||||
@@ -207,15 +211,209 @@ __device__ __forceinline__ void reduceCopyPacks(
|
||||
thread = warp*WARP_SIZE + lane;
|
||||
}
|
||||
|
||||
template<typename RedFn, typename T, int Unroll, int BytePerPack,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
__device__ __forceinline__ void reduceCopyPacksWithBias(
|
||||
#else
|
||||
__device__ __attribute__((noinline)) void reduceCopyPacksWithBias(
|
||||
#endif
|
||||
int nThreads, int &thread,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
|
||||
IntBytes &nBytesBehind, IntBytes &nBytesAhead, AccPtrFn const &accPtrFn
|
||||
) {
|
||||
static_assert(std::is_signed<IntBytes>::value, "IntBytes must be a signed integral type.");
|
||||
//if (BytePerPack == 0) __trap();
|
||||
|
||||
// A hunk is the amount of contiguous data a warp consumes per loop iteration
|
||||
// assuming all threads partake.
|
||||
constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack;
|
||||
int nWarps = nThreads/WARP_SIZE;
|
||||
int warp = thread/WARP_SIZE;
|
||||
int lane = thread%WARP_SIZE;
|
||||
|
||||
// This thread's initial position.
|
||||
IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack);
|
||||
IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack);
|
||||
// Number of hunks to be consumed over all warps.
|
||||
IntBytes nHunksAhead = nBytesAhead/(BytePerHunk + !BytePerHunk);
|
||||
// Advance collective position.
|
||||
nBytesBehind += nHunksAhead*BytePerHunk;
|
||||
nBytesAhead -= nHunksAhead*BytePerHunk;
|
||||
if (Unroll==1 && BytePerPack <= nBytesAhead) {
|
||||
// Only Unroll=1 can do partial hunks (where not all threads partake).
|
||||
nHunksAhead += 1;
|
||||
nBytesBehind += nBytesAhead - (nBytesAhead%(BytePerPack + !BytePerPack));
|
||||
nBytesAhead = nBytesAhead%(BytePerPack + !BytePerPack);
|
||||
}
|
||||
nHunksAhead -= warp;
|
||||
|
||||
RedFn redFn(redArg);
|
||||
uintptr_t minSrcs[MinSrcs + !MinSrcs];
|
||||
uintptr_t minDsts[MinDsts + !MinDsts];
|
||||
uintptr_t accPtr = cvta_to_global(accPtrFn()) + threadBytesBehind;
|
||||
BytePack<BytePerPack> bias[Unroll];
|
||||
|
||||
#pragma unroll
|
||||
for (int s=0; s < MinSrcs; s++) {
|
||||
minSrcs[s] = cvta_to_global(srcPtrFn(s)) + threadBytesBehind;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int d=0; d < MinDsts; d++) {
|
||||
// Yes, for some template arguments this code will be unreachable. That's fine.
|
||||
// coverity[dead_error_line]
|
||||
minDsts[d] = cvta_to_global(dstPtrFn(d)) + threadBytesBehind;
|
||||
}
|
||||
|
||||
// We dictate loop termination condition according to whether partial hunks
|
||||
// can be handled or not.
|
||||
while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) {
|
||||
BytePack<BytePerPack> acc[Unroll];
|
||||
|
||||
// minSrcs[0] cannot be nullptr so we always process it
|
||||
{ RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0);
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
if (0 < MultimemSrcs) {
|
||||
// applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
|
||||
acc[u] = applyLoadMultimem<RedFn, BytePerPack>(redFn, minSrcs[0]);
|
||||
} else {
|
||||
// Use volatile loads in case credits are polled for with volatile (instead of acquire).
|
||||
acc[u] = ld_volatile_global<BytePerPack>(minSrcs[0]);
|
||||
// coverity[dead_error_condition]
|
||||
bias[u] = ld_volatile_global<BytePerPack>(accPtr);
|
||||
accPtr += WARP_SIZE*BytePerPack;
|
||||
if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]);
|
||||
}
|
||||
minSrcs[0] += WARP_SIZE*BytePerPack;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll Unroll
|
||||
for (int s=1; s < MinSrcs; s++) {
|
||||
// Yes, for some template arguments this code will be unreachable. That's fine.
|
||||
// coverity[dead_error_begin]
|
||||
BytePack<BytePerPack> tmp[Unroll];
|
||||
// coverity[dead_error_line]
|
||||
RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
if (s < MultimemSrcs) {
|
||||
// applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
|
||||
// coverity[dead_error_line]
|
||||
tmp[u] = applyLoadMultimem<RedFn, BytePerPack>(redFn, minSrcs[s]);
|
||||
} else {
|
||||
// Use volatile loads in case credits are polled for with volatile (instead of acquire).
|
||||
tmp[u] = ld_volatile_global<BytePerPack>(minSrcs[s]);
|
||||
}
|
||||
minSrcs[s] += WARP_SIZE*BytePerPack;
|
||||
}
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
// coverity[dead_error_line]
|
||||
if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
|
||||
acc[u] = applyReduce(redFn, acc[u], tmp[u]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) {
|
||||
uintptr_t src = cvta_to_global(srcPtrFn(s)) + threadBytesBehind;
|
||||
BytePack<BytePerPack> tmp[Unroll];
|
||||
// Yes, for some template arguments this code will be unreachable. That's fine.
|
||||
// coverity[dead_error_line]
|
||||
RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
// Use volatile loads in case credits are polled for with volatile (instead of acquire).
|
||||
tmp[u] = ld_volatile_global<BytePerPack>(src);
|
||||
src += WARP_SIZE*BytePerPack;
|
||||
}
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
// Yes, for some template arguments this code will be unreachable. That's fine.
|
||||
// coverity[dead_error_line]
|
||||
if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
|
||||
acc[u] = applyReduce(redFn, acc[u], tmp[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++)
|
||||
acc[u] = applyPostOp(redFn, acc[u]);
|
||||
}
|
||||
|
||||
#pragma unroll Unroll
|
||||
for (int d=0; d < MinDsts; d++) {
|
||||
#pragma unroll Unroll
|
||||
// Yes, for some template arguments this code will be unreachable. That's fine.
|
||||
// coverity[dead_error_begin]
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
// coverity[dead_error_condition]
|
||||
if (d < MultimemDsts) {
|
||||
multimem_st_global(minDsts[d], acc[u]);
|
||||
} else {
|
||||
if (d == 0)
|
||||
st_global<BytePerPack>(minDsts[d], applyReduce(redFn, acc[u], bias[u]));
|
||||
else
|
||||
st_global<BytePerPack>(minDsts[d], acc[u]);
|
||||
}
|
||||
minDsts[d] += WARP_SIZE*BytePerPack;
|
||||
}
|
||||
}
|
||||
for (int d=MinDsts; (MinDsts < MaxDsts) && (d < MaxDsts) && (d < nDsts); d++) {
|
||||
uintptr_t dstPtr = cvta_to_global(dstPtrFn(d));
|
||||
uintptr_t dst = dstPtr + threadBytesBehind;
|
||||
#pragma unroll Unroll
|
||||
for (int u=0; u < Unroll; u++) {
|
||||
st_global<BytePerPack>(dst, acc[u]);
|
||||
dst += WARP_SIZE*BytePerPack;
|
||||
}
|
||||
}
|
||||
|
||||
nWarps = nThreads/WARP_SIZE;
|
||||
#pragma unroll
|
||||
for (int s=0; s < MinSrcs; s++) {
|
||||
minSrcs[s] += (nWarps-1)*BytePerHunk;
|
||||
}
|
||||
#pragma unroll
|
||||
// Yes, for some template arguments this code will be unreachable. That's fine.
|
||||
// coverity[dead_error_line]
|
||||
for (int d=0; d < MinDsts; d++) {
|
||||
minDsts[d] += (nWarps-1)*BytePerHunk;
|
||||
}
|
||||
accPtr += (nWarps-1)*BytePerHunk;
|
||||
threadBytesBehind += nWarps*BytePerHunk;
|
||||
threadBytesAhead -= nWarps*BytePerHunk;
|
||||
nHunksAhead -= nWarps;
|
||||
}
|
||||
|
||||
nWarps = nThreads/WARP_SIZE;
|
||||
warp = thread/WARP_SIZE;
|
||||
lane = thread%WARP_SIZE;
|
||||
// The last loop iteration could have been partial, i.e. not taken by all
|
||||
// threads. The threads that weren't included need an extra subtraction to
|
||||
// make the value warp uniform.
|
||||
if (Unroll==1 && nHunksAhead > 0) nHunksAhead -= nWarps;
|
||||
// Rotate warps so the warp which got the least work here will be warp 0.
|
||||
// This effectively assigns: warp = (warp-nHunks+nWarps)%nWarps;
|
||||
warp = -nHunksAhead;
|
||||
thread = warp*WARP_SIZE + lane;
|
||||
}
|
||||
|
||||
template<int Unroll, typename RedFn, typename T,
|
||||
int MultimemSrcs, int MinSrcs, int MaxSrcs,
|
||||
int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn>
|
||||
typename IntBytes, typename SrcPtrFn, typename DstPtrFn, typename AccPtrFn>
|
||||
__device__ __forceinline__ void reduceCopy(
|
||||
int thread, int nThreads,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
int nSrcs, SrcPtrFn const &srcPtrFn, int nDsts, DstPtrFn const &dstPtrFn,
|
||||
IntBytes nElts
|
||||
IntBytes nElts, AccPtrFn const &accPtrFn
|
||||
) {
|
||||
static_assert(MultimemSrcs <= MinSrcs && MultimemDsts <= MinDsts, "Multimem pointers cannot exceed respective Min values.");
|
||||
//int nWarps = nThreads/WARP_SIZE;
|
||||
@@ -230,6 +428,7 @@ __device__ __forceinline__ void reduceCopy(
|
||||
|
||||
IntBytes nBytesBehind = 0;
|
||||
IntBytes nBytesAhead = nElts*sizeof(T);
|
||||
bool useAcc = accPtrFn() != nullptr;
|
||||
|
||||
#if __cpp_if_constexpr
|
||||
if constexpr (BigPackSize > sizeof(T)) {
|
||||
@@ -243,11 +442,23 @@ __device__ __forceinline__ void reduceCopy(
|
||||
aligned = !(__any(!aligned));
|
||||
if (aligned) {
|
||||
#if defined(__gfx90a__)
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
#else
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
@@ -255,6 +466,12 @@ __device__ __forceinline__ void reduceCopy(
|
||||
#endif
|
||||
if (nBytesAhead == 0) return;
|
||||
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
@@ -285,17 +502,35 @@ __device__ __forceinline__ void reduceCopy(
|
||||
*/
|
||||
#if defined(__gfx90a__)
|
||||
if (MinSrcs > 1) {
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, (Unroll*4 + sizeof(T) - 1)/sizeof(T), sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, nBytesBehind, nBytesAhead);
|
||||
} else {
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead);
|
||||
}
|
||||
#else
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
@@ -303,6 +538,12 @@ __device__ __forceinline__ void reduceCopy(
|
||||
#endif
|
||||
if (nBytesAhead == 0) return;
|
||||
|
||||
if (useAcc)
|
||||
reduceCopyPacksWithBias<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
nSrcs, srcPtrFn, nDsts, dstPtrFn, /*&*/nBytesBehind, /*&*/nBytesAhead, accPtrFn);
|
||||
else
|
||||
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),
|
||||
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
|
||||
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
|
||||
@@ -317,14 +558,14 @@ __device__ __forceinline__ void reduceCopy(
|
||||
int thread, int nThreads,
|
||||
uint64_t redArg, uint64_t *preOpArgs, bool postOp,
|
||||
int nSrcs, void** srcPtrs, int nDsts, void** dstPtrs,
|
||||
IntBytes nElts
|
||||
IntBytes nElts, void *accPtr = nullptr
|
||||
) {
|
||||
reduceCopy<Unroll, RedFn, T,
|
||||
MultimemSrcs, MinSrcs, MaxSrcs,
|
||||
MultimemDsts, MinDsts, MaxDsts, PreOpSrcs, IntBytes>
|
||||
(thread, nThreads, redArg, preOpArgs, postOp,
|
||||
nSrcs, [=]__device__(int i) { return srcPtrs[i]; },
|
||||
nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts);
|
||||
nDsts, [=]__device__(int i) { return dstPtrs[i]; }, nElts, [=]__device__() { return accPtr; });
|
||||
}
|
||||
|
||||
#endif // COMMON_KERNEL_H_
|
||||
|
||||
@@ -32,7 +32,7 @@ else:
|
||||
# developing device code. The regex supports non-space containing globs '*',
|
||||
# and union 'a|b'. The string representing the function has the form:
|
||||
#
|
||||
# <coll> <algo> <proto> <redop> <type> <unroll>
|
||||
# <coll> <algo> <proto> <redop> <type>
|
||||
#
|
||||
# The possible values for redop, type, algo, proto can be found in the all_<foo>
|
||||
# lists at the top of this file.
|
||||
@@ -45,30 +45,23 @@ else:
|
||||
# # Only AllReduce and Reduce
|
||||
# make ONLY_FUNCS="AllReduce|Reduce"
|
||||
#
|
||||
# # Only AllGather with unroll=4
|
||||
# make ONLY_FUNCS="AllGather * * * * 4"
|
||||
#
|
||||
# # Only non-reductions:
|
||||
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
#
|
||||
# # Only AllReduce Sum int32_t (but all algos, protos)
|
||||
# make ONLY_FUNCS="AllReduce * * Sum i32"
|
||||
#
|
||||
# # Only AllReduce RING Max float (but all protos and unrolls)
|
||||
# make ONLY_FUNCS="AllReduce RING * Max f32"
|
||||
# # Only AllReduce RING Max float (but all protos)
|
||||
# make ONLY_FUNCS="AllReduce RING * Max float"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16 1"
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32 *"
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types for AllReduce and all redops for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32 *"
|
||||
#
|
||||
#
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|AllGather RING LL/LL128/SIMPLE Sum i8 1/2/4|AllToAllPivot RING SIMPLE Sum i8 1/2/4|Broadcast RING LL/LL128/SIMPLE Sum i8 1/2/4|Reduce RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|ReduceScatter RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|SendRecv RING SIMPLE Sum i8 1/2/4"
|
||||
#
|
||||
# # ONLY_FUNCS can be used together for debugging
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|AllGather RING LL/SIMPLE Sum int8_t|AllToAllPivot RING SIMPLE Sum int8_t|Broadcast RING LL/SIMPLE Sum int8_t|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8|SendRecv RING SIMPLE Sum int8_t"
|
||||
|
||||
# Paste all non-None arguments together with `sep`.
|
||||
def paste(sep, *args):
|
||||
@@ -77,8 +70,9 @@ def paste(sep, *args):
|
||||
is_ifc = 1 if sys.argv[2] == "ON" else 0
|
||||
is_colltrace = 1 if sys.argv[3] == "ON" else 0
|
||||
is_msccl_kernels = 1 if sys.argv[4] == "ON" else 0
|
||||
is_local_arch_only = 1 if sys.argv[5] == "ON" else 0
|
||||
|
||||
func_pattern = sys.argv[5:]
|
||||
func_pattern = sys.argv[6:7]
|
||||
if func_pattern and func_pattern[0]:
|
||||
func_pattern = func_pattern[0]
|
||||
else:
|
||||
@@ -139,13 +133,49 @@ coll_lower_to_camel = {coll_camel_to_lower[x]: x for x in coll_camel_to_lower}
|
||||
|
||||
################################################################################
|
||||
|
||||
seen_unroll = []
|
||||
def calc_unroll_for_local_arch():
|
||||
if not is_local_arch_only:
|
||||
return
|
||||
|
||||
rocminfo_path = os.environ.get('ROCM_PATH') + "/bin/rocminfo"
|
||||
|
||||
res = subprocess.run([rocminfo_path], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
rocminfo_output = res.stdout
|
||||
|
||||
# Parse rocminfo binary output
|
||||
gfx_targets = {}
|
||||
curr_name = None
|
||||
for line in rocminfo_output.splitlines():
|
||||
line = line.strip()
|
||||
|
||||
if line.startswith("Name:"):
|
||||
name = line.split(':')[-1].strip()
|
||||
if "gfx" in name:
|
||||
curr_name = name
|
||||
if line.startswith("Compute Unit:") and curr_name:
|
||||
cu_count = int(line.split(':')[-1].strip())
|
||||
gfx_targets[(curr_name, cu_count)] = None
|
||||
curr_name = None
|
||||
|
||||
# We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts
|
||||
# Use (gfx_name, cu_count) as key for dictionary and convert it to list here
|
||||
gfx_targets = list(gfx_targets.keys())
|
||||
|
||||
# Homogeneous system is required to build for only 1 varient of unroll factor
|
||||
if len(gfx_targets) == 1:
|
||||
gfx_name, cu_count = gfx_targets[0]
|
||||
if "gfx950" == gfx_name:
|
||||
return 1
|
||||
elif "gfx908" == gfx_name or ("gfx942" == gfx_name and cu_count > 80):
|
||||
return 2
|
||||
else:
|
||||
return 4
|
||||
|
||||
# Helper function to check if the conditions for the collective is being met
|
||||
def func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
def func_validate(coll, algo, proto, redop, ty):
|
||||
if redop == "SumPostDiv" and ty[0] not in ("i","u"):
|
||||
return False
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll] or unroll not in all_unroll:
|
||||
if algo not in algos_of_coll[coll] or proto not in protos_of_coll[coll] or redop not in redops_of_coll[coll] or ty not in tys_of_coll[coll]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -161,7 +191,10 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
|
||||
# If the paramter is equal to '*', include all possible cases for it
|
||||
if current_element == "*":
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type> <unroll>
|
||||
if current_idx == 0:
|
||||
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
|
||||
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type>
|
||||
# Get the current list from all_params
|
||||
current_list = all_params[current_idx]
|
||||
|
||||
@@ -177,12 +210,12 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
# Check if the current element is recognized
|
||||
elements = current_element.split("/")
|
||||
current_param = all_params[current_idx]
|
||||
|
||||
|
||||
# Iterate over the elements in the elements list
|
||||
for item in elements:
|
||||
if item not in current_param:
|
||||
raise ValueError(f"Error: {item} is unrecognized or does not belong to this category {current_param}.")
|
||||
|
||||
|
||||
for item in elements:
|
||||
item_list.append(item)
|
||||
yield from func_filter(function_params, current_idx+1, item_list)
|
||||
@@ -192,9 +225,7 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
else:
|
||||
coll, algo, proto, redop, ty, unroll = item_list
|
||||
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
if not unroll in seen_unroll:
|
||||
seen_unroll.append(unroll)
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
yield(coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
# Parse ONLY_FUNCS input and feed it to func_filter
|
||||
@@ -216,6 +247,10 @@ def parse_input(func_pattern):
|
||||
# Maps functions to the chosen representative for the equivalence class it
|
||||
# belongs to. For instance (sum, signed int) maps to (sum, unsigned int).
|
||||
def equivalent_primary(coll, algo, proto, redop, ty, unroll):
|
||||
# if local arch only, we only need to build for 1 varient of coll_unroll.
|
||||
# map the other varient of coll_unroll to this one.
|
||||
if coll_unroll:
|
||||
unroll = str(coll_unroll)
|
||||
if coll in ("AllReduce", "Reduce", "ReduceScatter"):
|
||||
# map signed integer sum/prod to unsigned
|
||||
if redop in ("Sum","Prod","PreMulSum","SumPostDiv") and ty[0]=="i":
|
||||
@@ -233,13 +268,13 @@ def enumerate_func_rows():
|
||||
for proto in all_protos:
|
||||
for redop in all_redops:
|
||||
for ty in all_tys:
|
||||
if func_validate(coll, algo, proto, redop, ty, unroll):
|
||||
if func_validate(coll, algo, proto, redop, ty):
|
||||
yield (coll, algo, proto, redop, ty, unroll)
|
||||
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty> <unroll>
|
||||
# Sort the hashmap based on custom key <coll> <algo> <proto> <redop> <ty>
|
||||
def custom_sort_key(fn):
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
|
||||
|
||||
return (
|
||||
all_unroll.index(unroll),
|
||||
all_colls.index(coll),
|
||||
@@ -251,6 +286,8 @@ def custom_sort_key(fn):
|
||||
|
||||
################################################################################
|
||||
|
||||
coll_unroll = calc_unroll_for_local_arch()
|
||||
|
||||
# Corresponds to ncclDevFuncRowToId[]
|
||||
func_rows = [fn for fn in enumerate_func_rows()]
|
||||
|
||||
@@ -267,8 +304,6 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
print("-- Generating %s" % os.path.join(gensrc, "device_table.h"))
|
||||
out = f.write
|
||||
|
||||
out("#include \"common.h\"\n\n")
|
||||
|
||||
if is_ifc: func_declaration = "__device__ void"
|
||||
else: func_declaration = "__device__ __attribute__((noinline)) void"
|
||||
|
||||
@@ -285,86 +320,113 @@ with open(os.path.join(gensrc, "device_table.h"), "w") as f:
|
||||
out("\n")
|
||||
|
||||
out("typedef void(*ncclDevFuncPtr_t)();\n\n")
|
||||
|
||||
# Generate function tables per unroll factor
|
||||
tableIdx = 0
|
||||
for curr_unroll in seen_unroll:
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_%s[] = {\n" % curr_unroll)
|
||||
tableIdx = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if curr_unroll != unroll: continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (tableIdx, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (tableIdx, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (tableIdx, sym))
|
||||
tableIdx += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
# Construct indirection function workaround
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_1[] = {\n")
|
||||
index1 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "1": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index1, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index1, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index1, sym))
|
||||
index1 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_2[] = {\n")
|
||||
index2 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "2": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index2, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index2, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index2, sym))
|
||||
index2 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
out("__device__ ncclDevFuncPtr_t const ncclDevFuncTable_4[] = {\n")
|
||||
index4 = 0
|
||||
for fn in primary_funcs:
|
||||
coll, algo, proto, redop, ty, unroll = fn
|
||||
if unroll != "4": continue
|
||||
sym = paste("_", "ncclDevFunc", *fn)
|
||||
if fn[2] == "LL128":
|
||||
out("#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) && defined(ENABLE_LL128)\n")
|
||||
out("/*%4d*/ %s,\n#else\n" % (index4, sym))
|
||||
fn_ll = fn[:2] + ("LL",) + fn[3:]
|
||||
sym_ll = paste("_", "ncclDevFunc", *fn_ll)
|
||||
out("/*%4d*/ %s,\n#endif\n" % (index4, sym_ll))
|
||||
else:
|
||||
out("/*%4d*/ %s,\n" % (index4, sym))
|
||||
index4 += 1
|
||||
out("nullptr};\n")
|
||||
out("\n")
|
||||
|
||||
if not is_ifc:
|
||||
out("template<int unroll, unsigned short f, unsigned short l>\n"
|
||||
"struct Caller {\n"
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller1 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call(unsigned short funcIndex) noexcept\n"
|
||||
" void call1(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller<unroll, f, m>::call(funcIndex) : Caller<unroll, m, l>::call(funcIndex);\n"
|
||||
" return (funcIndex < m) ? Caller1<f, m>::call1(funcIndex) : Caller1<m, l>::call1(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n")
|
||||
|
||||
for curr_unroll in seen_unroll:
|
||||
out("template<unsigned short f>\n")
|
||||
out("struct Caller<%s, f, f + 1>{\n" % curr_unroll)
|
||||
out(" static __forceinline__ __device__ __host__\n");
|
||||
out(" void call(unsigned short funcIndex) noexcept { ncclDevFuncTable_%s[f](); }\n" % curr_unroll)
|
||||
out("};\n")
|
||||
|
||||
out("\n")
|
||||
# Create NCCL_CALL_FUNCTION helper function that will call the appropriate device function
|
||||
out("template <int unroll>\n"
|
||||
"__forceinline__ __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {\n")
|
||||
if is_ifc:
|
||||
for curr_unroll in seen_unroll:
|
||||
out(" if (unroll == %s) { ncclDevFuncTable_%s[funcIndex](); }\n" % (curr_unroll, curr_unroll))
|
||||
else:
|
||||
out(f" Caller<unroll, 0, {tableIdx}>::call(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
|
||||
# Create RCCL
|
||||
out("template<int SpecializedFnId, typename SpecializedRunWorkBatch, bool COLLTRACE, int COLL_UNROLL>\n");
|
||||
out("__device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* args);\n\n");
|
||||
|
||||
out("struct RunWorkNop {\n");
|
||||
out(" __device__ void run() {}\n");
|
||||
out("};\n\n");
|
||||
|
||||
out("template <int UNROLL, bool COLLTRACE>\n"
|
||||
"__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void rcclGenericKernel(ncclDevKernelArgs4K const args4K) {\n"
|
||||
" ncclKernelMain<-1, RunWorkNop, COLLTRACE, UNROLL>(&args4K.args);\n"
|
||||
"}\n\n")
|
||||
|
||||
out("struct rcclKernelItem {\n");
|
||||
out(" void* funcPtr;\n");
|
||||
out(" int unroll;\n");
|
||||
out("};\n\n");
|
||||
|
||||
out("/* This table contains all the __global__ functions that were compiled */\n");
|
||||
out("static struct rcclKernelItem rcclKernelTable[] = {\n")
|
||||
for unroll in seen_unroll:
|
||||
out(" {(void*)&(rcclGenericKernel<%s, false>), %s},\n" % (unroll, unroll))
|
||||
out("#ifdef ENABLE_COLLTRACE\n")
|
||||
for unroll in seen_unroll:
|
||||
out(" {(void*)&(rcclGenericKernel<%s, true>), %s},\n" % (unroll, unroll))
|
||||
out("#endif\n");
|
||||
out("};\n\n");
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller1<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call1(unsigned short funcIndex) noexcept { ncclDevFuncTable_1[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_1(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller1<0, {index1}>::call1(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller2 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call2(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller2<f, m>::call2(funcIndex) : Caller2<m, l>::call2(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller2<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call2(unsigned short funcIndex) noexcept { ncclDevFuncTable_2[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_2(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller2<0, {index2}>::call2(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
out("template<unsigned short f, unsigned short l>\n"
|
||||
"struct Caller4 {\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept\n"
|
||||
" {\n"
|
||||
" constexpr unsigned short m = f + (l - f) / 2;\n"
|
||||
" return (funcIndex < m) ? Caller4<f, m>::call4(funcIndex) : Caller4<m, l>::call4(funcIndex);\n"
|
||||
" }\n"
|
||||
"};\n"
|
||||
"\n"
|
||||
"template<unsigned short f>\n"
|
||||
"struct Caller4<f, f + 1>{\n"
|
||||
" static __forceinline__ __device__ __host__\n"
|
||||
" void call4(unsigned short funcIndex) noexcept { ncclDevFuncTable_4[f](); }\n"
|
||||
"};\n")
|
||||
out("__forceinline__ __device__ void NCCL_CALL_FUNCTIONS_4(unsigned short funcIndex) noexcept {\n")
|
||||
out(f" Caller4<0, {index4}>::call4(funcIndex);\n")
|
||||
out("}\n\n")
|
||||
|
||||
# Generate <gensrc>/device_table.cpp
|
||||
if is_colltrace:
|
||||
@@ -374,7 +436,7 @@ if is_colltrace:
|
||||
out = f.write
|
||||
out('#include "nccl_common.h"\n#include "device.h"\n')
|
||||
out("\n")
|
||||
|
||||
|
||||
seen_fns = set()
|
||||
out("const char* funcNames[FUNC_INDEX_TOTAL] = {\n")
|
||||
for fn in primary_funcs:
|
||||
@@ -397,13 +459,10 @@ with open(os.path.join(gensrc, "host_table.cpp"), "w") as f:
|
||||
# The mapping from function rows to valid primary function ids.
|
||||
out("extern int const ncclDevFuncRowToId[] = {\n")
|
||||
index = 0
|
||||
offset = len(func_rows)//len(all_unroll)
|
||||
start = all_unroll.index(seen_unroll[0]) * offset
|
||||
end = start + offset
|
||||
for fn in func_rows[start:end]:
|
||||
for fn in func_rows[:len(func_rows)//3]:
|
||||
fn_id, comment = -1, ""
|
||||
if fn is not None:
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)] % offset if primary_to_index[equivalent_primary(*fn)] != -1 else -1
|
||||
fn_id = primary_to_index[equivalent_primary(*fn)]
|
||||
comment = " // " + paste(" ", *fn[:-1])
|
||||
out("/*%4d*/ %d,%s\n" % (index, fn_id, comment))
|
||||
index += 1
|
||||
|
||||
@@ -18,7 +18,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>:
|
||||
// This is because of a recv buffer which is allocated to MaxRecv length in send-only cases
|
||||
static constexpr int MaxRecv = Fan::MaxRecv > 1 ? Fan::MaxRecv : 1;
|
||||
static constexpr int MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
static constexpr int Input=0, Output=1, Acc=2;
|
||||
RedOp redOp;
|
||||
const int tid;
|
||||
const int nthreads;
|
||||
@@ -26,7 +26,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload>:
|
||||
const int group;
|
||||
const int stepLines;
|
||||
Fan fan;
|
||||
T *userBufs[2];
|
||||
T *userBufs[3];
|
||||
struct ncclConnInfo* recvConn = NULL;
|
||||
volatile uint64_t* recvConnHeadPtr = NULL;
|
||||
uint64_t recvConnHead;
|
||||
@@ -436,6 +436,7 @@ private:
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstElts = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
T *accElts = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
|
||||
|
||||
// Always waitSend in case of cleanup
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
@@ -460,14 +461,15 @@ private:
|
||||
nelem -= tid*EltPerLine;
|
||||
srcElts += tid*EltPerLine;
|
||||
dstElts += tid*EltPerLine;
|
||||
if (accElts != nullptr) accElts += tid*EltPerLine;
|
||||
int offset = tid;
|
||||
int eltPerTrip = nthreads*EltPerLine;
|
||||
while (nelem > 0) {
|
||||
int eltInLine = EltPerLine < nelem ? EltPerLine : nelem;
|
||||
|
||||
DataLoader dl;
|
||||
DataLoader dl, accdl;
|
||||
ncclLLFifoLine line[MaxRecv];
|
||||
uint64_t data, peerData;
|
||||
uint64_t data, peerData, accData;
|
||||
if (SRC) {
|
||||
dl.loadBegin(srcElts, eltInLine);
|
||||
srcElts += eltPerTrip;
|
||||
@@ -502,7 +504,14 @@ private:
|
||||
storeLL(sendPtr(0)+offset, data, sendFlag(0));
|
||||
}
|
||||
if (DST) {
|
||||
storeData(dstElts, data, eltInLine);
|
||||
if (accElts != nullptr) {
|
||||
accdl.loadBegin(accElts, eltInLine);
|
||||
accElts += eltPerTrip;
|
||||
accData = accdl.loadFinish();
|
||||
storeData(dstElts, applyReduce(redOp, accData, data), eltInLine);
|
||||
} else {
|
||||
storeData(dstElts, data, eltInLine);
|
||||
}
|
||||
dstElts += eltPerTrip;
|
||||
}
|
||||
nelem -= eltPerTrip;
|
||||
@@ -672,7 +681,7 @@ public:
|
||||
loadRecvSync();
|
||||
// coverity[var_deref_model:FALSE]
|
||||
loadSendSync();
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr);
|
||||
}
|
||||
|
||||
__device__ ~Primitives() {
|
||||
@@ -685,9 +694,10 @@ public:
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, void const *acc = nullptr) {
|
||||
userBufs[Input] = (T*)inputBuf;
|
||||
userBufs[Output] = (T*)outputBuf;
|
||||
userBufs[Acc] = (T*)acc;
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
|
||||
@@ -25,7 +25,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>> {
|
||||
|
||||
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
|
||||
static constexpr int Input=0, Output=1;
|
||||
static constexpr int Input=0, Output=1, Acc=2;;
|
||||
RedOp redOp;
|
||||
const int tid;
|
||||
const int nthreads;
|
||||
@@ -36,7 +36,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload>:
|
||||
const bool flagThread;
|
||||
const int group;
|
||||
Fan fan;
|
||||
T *userBufs[2];
|
||||
T *userBufs[3];
|
||||
struct ncclConnInfo* recvConn = NULL;
|
||||
volatile uint64_t* recvConnHeadPtr = NULL;
|
||||
uint64_t recvConnHead;
|
||||
@@ -347,6 +347,7 @@ private:
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx;
|
||||
T *accPtr = (DstBuf == -1 || userBufs[Acc] == nullptr) ? nullptr : userBufs[Acc] + dstIx;
|
||||
int wireOffset = WireWordPerSlice*warp + 2*wid;
|
||||
const int nwarps = nthreads/WARP_SIZE;
|
||||
nelem = nelem < 0 ? 0 : nelem;
|
||||
@@ -356,12 +357,25 @@ private:
|
||||
nelem -= DataEltPerSlice*warp;
|
||||
srcPtr += DataEltPerSlice*warp;
|
||||
dstPtr += DataEltPerSlice*warp;
|
||||
if (accPtr != nullptr) accPtr += DataEltPerSlice*warp;
|
||||
while (nelem > 0) {
|
||||
const int eltInSlice = min(nelem, DataEltPerSlice);
|
||||
uint64_t regs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
|
||||
if (SRC) loadRegsBegin(regs, srcPtr, eltInSlice);
|
||||
recvReduceSendCopy<NCCL_LL128_SHMEM_ELEMS_PER_THREAD, RECV, SEND, SrcBuf, DstBuf>(regs, wireOffset, postOp);
|
||||
if (DST) storeRegs(dstPtr, regs, eltInSlice);
|
||||
if (DST) {
|
||||
if (accPtr != nullptr) {
|
||||
uint64_t accRegs[NCCL_LL128_SHMEM_ELEMS_PER_THREAD];
|
||||
loadRegsBegin(accRegs, accPtr, eltInSlice);
|
||||
loadRegsFinish(accRegs);
|
||||
accPtr += DataEltPerSlice*nwarps;
|
||||
#pragma unroll
|
||||
for (int u=0; u<NCCL_LL128_SHMEM_ELEMS_PER_THREAD; u++) {
|
||||
regs[u] = applyReduce(redOp, accRegs[u], regs[u]);
|
||||
}
|
||||
}
|
||||
storeRegs(dstPtr, regs, eltInSlice);
|
||||
}
|
||||
|
||||
wireOffset += WireWordPerSlice*nwarps;
|
||||
srcPtr += DataEltPerSlice*nwarps;
|
||||
@@ -529,7 +543,7 @@ public:
|
||||
loadRecvSync();
|
||||
// coverity[var_deref_model:FALSE]
|
||||
loadSendSync();
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr);
|
||||
}
|
||||
|
||||
__device__ ~Primitives() {
|
||||
@@ -542,9 +556,10 @@ public:
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, void const *acc = nullptr) {
|
||||
userBufs[Input] = (T*)inputBuf;
|
||||
userBufs[Output] = (T*)outputBuf;
|
||||
userBufs[Acc] = (T*)acc;
|
||||
}
|
||||
|
||||
__device__ void moveDataPtrs(intptr_t delta) {
|
||||
|
||||
@@ -265,6 +265,8 @@ private:
|
||||
T* userOutput = (T*)ncclShmem.groups[group].userOutput;
|
||||
if (Src) ncclShmem.groups[group].srcs[0] = (SrcBuf==Input ? userInput : userOutput) + srcIx + offset;
|
||||
if (Dst) ncclShmem.groups[group].dsts[0] = (DstBuf==Input ? userInput : userOutput) + dstIx + offset;
|
||||
T* userAcc = (T*)ncclShmem.groups[group].userAcc;
|
||||
ncclShmem.groups[group].acc = (Dst && userAcc != nullptr) ? userAcc + dstIx + offset : nullptr;
|
||||
}
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(srcIx, dstIx, offset, sliceSize);
|
||||
subBarrier();
|
||||
@@ -385,7 +387,7 @@ private:
|
||||
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
|
||||
Recv * fan.nrecv() + Src, ncclShmem.groups[group].srcs,
|
||||
Send * fan.nsend() + Dst, ncclShmem.groups[group].dsts,
|
||||
workSize);
|
||||
workSize, ncclShmem.groups[group].acc);
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
|
||||
@@ -827,7 +829,7 @@ public:
|
||||
|
||||
// coverity[negative_returns:FALSE] => coverity thinks that index could be -1 but that's not actually the case
|
||||
// coverity[var_deref_model] => coverity thinks work can dereferenced if NULL but this is not the case
|
||||
setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer);
|
||||
setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclDevWorkCollReg*)collWork, sendIpcReg || recvIpcReg, peer, collWork != nullptr ? collWork->acc : nullptr);
|
||||
// coverity[uninit_member] => coverity thinks fan.n is not initialized
|
||||
} else if (mode == primsModePatRs || mode == primsModePatAg) { // Connect to all ranks +/- 2^n
|
||||
flags |= PatMode;
|
||||
@@ -902,10 +904,11 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer) {
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclDevWorkCollReg* work, uint8_t ipcReg, int peer, void const *acc) {
|
||||
if (tid==0) {
|
||||
ncclShmem.groups[group].userInput = (void*)inputBuf;
|
||||
ncclShmem.groups[group].userOutput = (void*)outputBuf;
|
||||
ncclShmem.groups[group].userAcc = (void*)acc;
|
||||
ncclShmem.redOpArgs[0] = redOpArg; // scaler for local input
|
||||
}
|
||||
|
||||
@@ -1023,7 +1026,7 @@ public:
|
||||
}
|
||||
|
||||
// Set MSCCL data pointers
|
||||
__device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
__device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf = nullptr) {
|
||||
if (tid==0) {
|
||||
ncclShmem.groups[group].userInput = (T*)inputBuf;
|
||||
ncclShmem.groups[group].userOutput = (T*)outputBuf;
|
||||
|
||||
@@ -28,31 +28,30 @@
|
||||
|
||||
using namespace rccl;
|
||||
|
||||
/* [RCCL] Determine which GPU kernel to execute */
|
||||
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task)
|
||||
{
|
||||
// At this time, unroll factor is controlled only by passed in unroll argument
|
||||
// After more investigation, this may be further tuned by the actual task being processed
|
||||
struct ncclKernelMatch {
|
||||
void* kernelFn;
|
||||
bool specialized;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]) / 2;
|
||||
int firstKernel = useCollTrace ? numKernels : 0;
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll + ((p_comm)->collTraceEnabled ? 3 : 0))
|
||||
static ncclKernelMatch const ncclKerns[6] = {
|
||||
{(void *)ncclDevKernel_Generic_1, true},
|
||||
{(void *)ncclDevKernel_Generic_2, true},
|
||||
{(void *)ncclDevKernel_Generic_4, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_1, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_2, true},
|
||||
{(void *)ncclDevKernelDebug_Generic_4, true}
|
||||
};
|
||||
#else
|
||||
int numKernels = sizeof(rcclKernelTable) / sizeof(rcclKernelTable[0]);
|
||||
int firstKernel = 0;
|
||||
#define ncclGetKernelIndex(p_comm) ((p_comm)->unroll)
|
||||
static ncclKernelMatch const ncclKerns[3] = {
|
||||
{(void*)ncclDevKernel_Generic_1, true},
|
||||
{(void*)ncclDevKernel_Generic_2, true},
|
||||
{(void*)ncclDevKernel_Generic_4, true}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Check if the requested unroll exists
|
||||
for (int kernelIdx = 0; kernelIdx < numKernels; kernelIdx++) {
|
||||
if (rcclKernelTable[firstKernel + kernelIdx].unroll == unroll) {
|
||||
return rcclKernelTable[firstKernel + kernelIdx].funcPtr;
|
||||
}
|
||||
}
|
||||
|
||||
// If does not match, return null
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int rcclProtoGrainSize(int proto, ncclComm *comm){
|
||||
switch (proto) {
|
||||
case NCCL_PROTO_LL: return 16;
|
||||
@@ -82,7 +81,7 @@ NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0);
|
||||
|
||||
// Returns maximum kernel stack size of all CUDA kernels
|
||||
ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize) {
|
||||
constexpr int KernelCount = sizeof(rcclKernelTable)/sizeof(rcclKernelTable[0]);
|
||||
constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]);
|
||||
ncclResult_t result = ncclSuccess;
|
||||
|
||||
if (maxStackSize) *maxStackSize = 0;
|
||||
@@ -95,7 +94,7 @@ ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* ma
|
||||
int ncclMaxSharedMem = rcclShmemDynamicSize(cudaArch, WarpSize);
|
||||
|
||||
for (int k=0; k < KernelCount; k++) {
|
||||
void* fn = rcclKernelTable[k].funcPtr;
|
||||
void* fn = ncclKerns[k].kernelFn;
|
||||
cudaFuncAttributes attr = {0};
|
||||
if (fn == nullptr) continue;
|
||||
|
||||
@@ -356,6 +355,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) {
|
||||
|
||||
devWork.sendbuff = (void*)task->sendbuff;
|
||||
devWork.recvbuff = (void*)task->recvbuff;
|
||||
devWork.acc = (void*)task->acc;
|
||||
devWork.sendbuffOffset = task->sendbuffOffset;
|
||||
devWork.recvbuffOffset = task->recvbuffOffset;
|
||||
devWork.sendbuffRmtAddrs = task->sendbuffRmtAddrs;
|
||||
@@ -806,12 +806,8 @@ static ncclResult_t scheduleCollTasksToPlan(
|
||||
//plan->channelMask.masks[channelId/64] |= (2ull<<devWork->channelHi) - (1ull<<devWork->channelLo);
|
||||
plan->threadPerBlock = std::max(plan->threadPerBlock, 192 /* 3*WARP_SIZE */);
|
||||
if (!plan->kernelSpecialized) {
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled);
|
||||
#else
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, false);
|
||||
#endif
|
||||
plan->kernelSpecialized = true;
|
||||
plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn;
|
||||
plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized;
|
||||
}
|
||||
|
||||
if (comm->rank == 0) {
|
||||
@@ -1127,12 +1123,8 @@ static ncclResult_t scheduleP2pTasksToPlan(
|
||||
|
||||
plan->threadPerBlock = std::max(plan->threadPerBlock, NCCL_MAX_NTHREADS);
|
||||
if (!plan->kernelSpecialized) {
|
||||
#ifdef ENABLE_COLLTRACE
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, comm->collTraceEnabled);
|
||||
#else
|
||||
plan->kernelFn = rcclGetKernelIndex(comm->unroll, false);
|
||||
#endif
|
||||
plan->kernelSpecialized = true;
|
||||
plan->kernelFn = ncclKerns[ncclGetKernelIndex(comm)].kernelFn;
|
||||
plan->kernelSpecialized = ncclKerns[ncclGetKernelIndex(comm)].specialized;
|
||||
}
|
||||
|
||||
// Compute how much to split operations
|
||||
@@ -2505,6 +2497,7 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) {
|
||||
t->sliceSteps = info->sliceSteps;
|
||||
t->eActivationMask = __atomic_load_n(&ncclProfilerEventMask, __ATOMIC_RELAXED);
|
||||
t->opCount = comm->opCount;
|
||||
t->acc = info->acc;
|
||||
|
||||
planner->nTasksColl += 1;
|
||||
ncclTaskCollSorterInsert(&planner->collSorter, t, t->trafficBytes);
|
||||
@@ -2554,8 +2547,8 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
|
||||
}
|
||||
NCCLCHECKGOTO(ArgsCheck(info), ret, fail);
|
||||
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d",
|
||||
info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->count,
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p acc %p count %zu datatype %d op %d root %d comm %p [nranks=%d] stream %p task %d globalrank %d",
|
||||
info->opName, info->comm->opCount, info->sendbuff, info->recvbuff, info->acc, info->count,
|
||||
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream,
|
||||
info->comm->planner.nTasksP2p + info->comm->planner.nTasksColl,
|
||||
info->comm->localRankToRank[info->comm->localRank]);
|
||||
|
||||
@@ -31,7 +31,7 @@
|
||||
#define RCCL_API_TRACE_VERSION_MAJOR 0
|
||||
|
||||
// should be increased every time new members are added to existing dispatch tables
|
||||
#define RCCL_API_TRACE_VERSION_PATCH 0
|
||||
#define RCCL_API_TRACE_VERSION_PATCH 1
|
||||
|
||||
#if !defined(RCCL_EXTERN_C_INIT)
|
||||
# ifdef __cplusplus
|
||||
@@ -61,6 +61,10 @@ typedef ncclResult_t (*ncclAllReduce_fn_t)(const void* sendbuff, void* recvbuff,
|
||||
size_t count, ncclDataType_t datatype,
|
||||
ncclRedOp_t op, struct ncclComm* comm,
|
||||
hipStream_t stream);
|
||||
typedef ncclResult_t (*ncclAllReduceWithBias_fn_t)(const void* sendbuff, void* recvbuff,
|
||||
size_t count, ncclDataType_t datatype,
|
||||
ncclRedOp_t op, struct ncclComm* comm,
|
||||
hipStream_t stream, const void* acc);
|
||||
typedef ncclResult_t (*ncclAllToAll_fn_t)(const void* sendbuff, void* recvbuff,
|
||||
size_t count, ncclDataType_t datatype,
|
||||
ncclComm_t comm, hipStream_t stream);
|
||||
@@ -194,6 +198,7 @@ typedef struct rcclApiFuncTable
|
||||
mscclUnloadAlgo_fn_t mscclUnloadAlgo_fn;
|
||||
ncclCommRegister_fn_t ncclCommRegister_fn;
|
||||
ncclCommDeregister_fn_t ncclCommDeregister_fn;
|
||||
ncclAllReduceWithBias_fn_t ncclAllReduceWithBias_fn;
|
||||
|
||||
} rcclApiFuncTable;
|
||||
|
||||
|
||||
@@ -195,6 +195,7 @@ struct ncclTaskColl {
|
||||
ncclFunc_t func;
|
||||
void const* sendbuff;
|
||||
void* recvbuff;
|
||||
void const* acc;
|
||||
size_t count;
|
||||
int root;
|
||||
ncclDataType_t datatype;
|
||||
|
||||
@@ -310,6 +310,7 @@ struct alignas(16) ncclDevWorkColl {
|
||||
uint16_t pivotA2ANumBiRings:15, profilerEnabled:1;
|
||||
void* recvbuff;
|
||||
void* sendbuff;
|
||||
void *acc;
|
||||
uintptr_t sendbuffOffset;
|
||||
uintptr_t recvbuffOffset;
|
||||
uintptr_t* sendbuffRmtAddrs;
|
||||
|
||||
@@ -29,6 +29,7 @@ struct ncclInfo {
|
||||
// Algorithm details
|
||||
int chunkSteps;
|
||||
int sliceSteps;
|
||||
const void* acc;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -74,5 +74,10 @@ typedef enum {
|
||||
|
||||
#define NCCL_ALGO_PROTO_IGNORE -1.0
|
||||
|
||||
#define NCCL_NUM_UNROLLS 3 // 1/2/4
|
||||
#define NCCL_UNROLL_1 0
|
||||
#define NCCL_UNROLL_2 1
|
||||
#define NCCL_UNROLL_4 2
|
||||
|
||||
#define NCCL_NUM_FLOATS 6 // half/float/double/rccl_bfloat16/rccl_float8/rccl_bfloat8
|
||||
#endif
|
||||
|
||||
@@ -16,6 +16,7 @@ typedef enum {
|
||||
rrAllGather,
|
||||
rrReduceScatter,
|
||||
rrAllReduce,
|
||||
rrAllReduceWithBias,
|
||||
rrSend,
|
||||
rrRecv,
|
||||
rrAllToAll,
|
||||
@@ -51,6 +52,7 @@ constexpr const char* rcclCallStr[]
|
||||
"AllGather",
|
||||
"ReduceScatter",
|
||||
"AllReduce",
|
||||
"AllReduceWithBias",
|
||||
"Send",
|
||||
"Recv",
|
||||
"AllToAll",
|
||||
@@ -94,6 +96,7 @@ struct rcclApiCall {
|
||||
uint64_t opCount = 0;
|
||||
const void* sendbuff = NULL;
|
||||
void* recvbuff = NULL;
|
||||
const void* acc = NULL;
|
||||
void* sendPtrBase = NULL;
|
||||
void* recvPtrBase = NULL;
|
||||
size_t sendPtrExtent = 0;
|
||||
|
||||
@@ -610,6 +610,8 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in
|
||||
|
||||
// RCCL: create persistent stream for calloc
|
||||
CUDACHECK(hipStreamCreateWithFlags(&comm->sideStream, hipStreamNonBlocking));
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
comm->checkPointers = ncclParamCheckPointers() == 1 ? true : false;
|
||||
comm->dmaBufSupport = (dmaBufSupported(comm) == ncclSuccess) ? true : false;
|
||||
|
||||
@@ -1958,9 +1960,6 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
|
||||
|
||||
NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail);
|
||||
|
||||
// RCCL: determine and set unroll factor for comm
|
||||
NCCLCHECK(commSetUnrollFactor(comm));
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
if (job->parent) {
|
||||
if (job->parent->mscclppCompatible) {
|
||||
|
||||
@@ -153,6 +153,11 @@ ncclCommRegister_impl(const ncclComm_t comm, void* buff, size_t size, void** han
|
||||
ncclResult_t
|
||||
ncclCommDeregister_impl(const ncclComm_t comm, void* handle);
|
||||
|
||||
ncclResult_t
|
||||
ncclAllReduceWithBias_impl(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm,
|
||||
cudaStream_t stream, const void* acc);
|
||||
|
||||
namespace rccl
|
||||
{
|
||||
namespace
|
||||
@@ -211,10 +216,11 @@ RCCL_ASSERT_OFFSET(rcclApiFuncTable, mscclRunAlgo_fn, 33);
|
||||
RCCL_ASSERT_OFFSET(rcclApiFuncTable, mscclUnloadAlgo_fn, 34);
|
||||
RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclCommRegister_fn, 35);
|
||||
RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclCommDeregister_fn, 36);
|
||||
RCCL_ASSERT_OFFSET(rcclApiFuncTable, ncclAllReduceWithBias_fn, 37);
|
||||
|
||||
#undef RCCL_ASSERT_OFFSET
|
||||
|
||||
static_assert(sizeof(rcclApiFuncTable) == compute_table_size(37),
|
||||
static_assert(sizeof(rcclApiFuncTable) == compute_table_size(38),
|
||||
"Update table major/step version and add a new offset assertion if this "
|
||||
"fails to compile");
|
||||
|
||||
@@ -261,7 +267,8 @@ RcclGetFunctionTable_impl()
|
||||
&mscclRunAlgo_impl,
|
||||
&mscclUnloadAlgo_impl,
|
||||
&ncclCommRegister_impl,
|
||||
&ncclCommDeregister_impl };
|
||||
&ncclCommDeregister_impl,
|
||||
&ncclAllReduceWithBias_impl };
|
||||
|
||||
#if defined(RCCL_ROCPROFILER_REGISTER) && RCCL_ROCPROFILER_REGISTER > 0
|
||||
std::array<void*, 1> table_array{ tbl };
|
||||
@@ -301,6 +308,9 @@ NCCL_API(ncclResult_t, ncclAllGather, const void* sendbuff, void* recvbuff,
|
||||
NCCL_API(ncclResult_t, ncclAllReduce, const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, hipStream_t stream);
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllReduceWithBias, const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm* comm, hipStream_t stream, const void* acc);
|
||||
|
||||
NCCL_API(ncclResult_t, ncclAllToAll, const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclComm_t comm, hipStream_t stream);
|
||||
|
||||
@@ -411,6 +421,14 @@ ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t
|
||||
datatype, op, comm, stream);
|
||||
}
|
||||
|
||||
ncclResult_t
|
||||
ncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
ncclRedOp_t op, ncclComm* comm, cudaStream_t stream, const void* acc)
|
||||
{
|
||||
return ::rccl::RcclGetFunctionTable()->ncclAllReduceWithBias_fn(sendbuff, recvbuff, count,
|
||||
datatype, op, comm, stream, acc);
|
||||
}
|
||||
|
||||
ncclResult_t
|
||||
ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
|
||||
ncclComm_t comm, hipStream_t stream)
|
||||
|
||||
@@ -35,6 +35,7 @@ rcclApiCall::rcclApiCall(rcclCall_t type, const ncclInfo& info)://name(rcclCallS
|
||||
opCount(info.comm->opCount),
|
||||
sendbuff(info.sendbuff),
|
||||
recvbuff(info.recvbuff),
|
||||
acc(info.acc),
|
||||
count(info.count),
|
||||
datatype(info.datatype),
|
||||
op(info.op),
|
||||
@@ -69,7 +70,7 @@ std::string alloc_fmt = "%s : [returned ptr : %p, size : %zu, context : [";
|
||||
std::string free_fmt = "%s : [ptr : %p, context : [";
|
||||
std::string redop_fmt = "%s : [scalar : %p, datatype : %d, op : %d, residence : %d, comm : %p, context : [";
|
||||
std::string redopdestroy_fmt = "%s : [op : %d, comm : %p, context : [";
|
||||
std::string coll_fmt = "%s : [opCount : %lx, sendbuff : [addr : %p, base : %p, size : %zu], recvbuff : [addr : %p, base : %p, size : %zu], count : %zu, datatype : %d, op : %d, root : %d, comm : %p, nranks : %d, stream : %p, task : %d, globalrank : %d, context : [";
|
||||
std::string coll_fmt = "%s : [opCount : %lx, sendbuff : [addr : %p, base : %p, size : %zu], recvbuff : [addr : %p, base : %p, size : %zu], acc : %p, count : %zu, datatype : %d, op : %d, root : %d, comm : %p, nranks : %d, stream : %p, task : %d, globalrank : %d, context : [";
|
||||
|
||||
Recorder::Recorder()
|
||||
{
|
||||
@@ -256,7 +257,7 @@ void Recorder::write(const rcclApiCall &call)
|
||||
default: // collectives
|
||||
len = snprintf(buffer, 4096, coll_fmt.c_str(),
|
||||
rcclCallStr[call.type], call.opCount, call.sendbuff, call.sendPtrBase, call.sendPtrExtent,
|
||||
call.recvbuff, call.recvPtrBase, call.recvPtrExtent, call.count, call.datatype,
|
||||
call.recvbuff, call.recvPtrBase, call.recvPtrExtent, call.acc, call.count, call.datatype,
|
||||
call.op, call.root, call.comm, call.nRanks, call.stream, call.nTasks, call.globalRank);
|
||||
|
||||
}
|
||||
@@ -686,9 +687,9 @@ void parseJsonEntry(const char* entry, std::vector<rcclApiCall>& calls)
|
||||
default:
|
||||
assert(sscanf(str.c_str() + end + 3, (coll_fmt.substr(5) + ctxt_fmt).c_str(),
|
||||
&call.opCount, &call.sendbuff, &call.sendPtrBase, &call.sendPtrExtent, &call.recvbuff, &call.recvPtrBase, &call.recvPtrExtent,
|
||||
&call.count, &call.datatype, &call.op, &call.root,
|
||||
&call.acc, &call.count, &call.datatype, &call.op, &call.root,
|
||||
&call.comm, &call.nRanks, &call.stream, &call.nTasks, &call.globalRank, &call.timestamp, &call.tid,
|
||||
&call.hipDev, &call.graphCaptured, &call.graphID) == 21);
|
||||
&call.hipDev, &call.graphCaptured, &call.graphID) == 22);
|
||||
}
|
||||
calls.push_back(call);
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#define RCCL_FLOAT8 1
|
||||
#define RCCL_GATHER_SCATTER 1
|
||||
#define RCCL_ALLTOALLV 1
|
||||
#define RCCL_ALLREDUCE_WITH_BIAS 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@@ -563,6 +564,27 @@ ncclResult_t pncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream);
|
||||
/*! @endcond */
|
||||
|
||||
/*! @brief All-Reduce-with-Bias
|
||||
@details Reduces data arrays of length *count* in *sendbuff* using *op* operation, and
|
||||
leaves identical copies of result on each *recvbuff*.
|
||||
In-place operation will happen if sendbuff == recvbuff.
|
||||
@return Result code. See @ref rccl_result_code for more details.
|
||||
|
||||
@param[in] sendbuff Input data array to reduce
|
||||
@param[out] recvbuff Data array to store reduced result array
|
||||
@param[in] count Number of elements in data buffer
|
||||
@param[in] datatype Data buffer element datatype
|
||||
@param[in] op Reduction operator
|
||||
@param[in] comm Communicator group object to execute on
|
||||
@param[in] stream HIP stream to execute collective on
|
||||
@param[in] acc Bias data array to reduce */
|
||||
ncclResult_t ncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream, const void* acc);
|
||||
/*! @cond include_hidden */
|
||||
ncclResult_t pncclAllReduceWithBias(const void* sendbuff, void* recvbuff, size_t count,
|
||||
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, hipStream_t stream, const void* acc);
|
||||
/*! @endcond */
|
||||
|
||||
/*! @brief Reduce-Scatter
|
||||
@details Reduces data in *sendbuff* using *op* operation and leaves reduced result
|
||||
scattered over the devices so that *recvbuff* on rank i will contain the i-th
|
||||
|
||||
@@ -127,43 +127,14 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count,
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
//RCCL runtime param to set Unroll Factor
|
||||
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
|
||||
|
||||
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
|
||||
|
||||
//If RCCL runtime param is set, it will override defaults, if supported
|
||||
if (rcclParamUnrollFactor() != 0) {
|
||||
#if ENABLE_COLLTRACE
|
||||
if(rcclGetKernelIndex(rcclParamUnrollFactor(), comm->collTraceEnabled)) {
|
||||
#else
|
||||
if(rcclGetKernelIndex(rcclParamUnrollFactor(), false)) {
|
||||
#endif
|
||||
comm->unroll = rcclParamUnrollFactor();
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll);
|
||||
return ncclSuccess;
|
||||
}
|
||||
else {
|
||||
// Fall back to default unroll
|
||||
WARN("Requested RCCL_UNROLL_FACTOR: %ld is invalid and does not exist in `rcclKernelTable`. Falling back to pre-set unroll.", rcclParamUnrollFactor());
|
||||
}
|
||||
}
|
||||
|
||||
if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
|
||||
//on gfx950, use unroll=1 for single-node and unroll=2 for multi-node
|
||||
if (comm->nNodes == 1)
|
||||
comm->unroll = 1;
|
||||
else
|
||||
comm->unroll = 2;
|
||||
}
|
||||
else if((IsArchMatch(devProp.gcnArchName, "gfx908")) ||
|
||||
(IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))
|
||||
//on MI300X and gfx908, use unroll=2
|
||||
comm->unroll = 2;
|
||||
if(IsArchMatch(devProp.gcnArchName, "gfx950"))
|
||||
comm->unroll = NCCL_UNROLL_1;
|
||||
else if(IsArchMatch(devProp.gcnArchName, "gfx908") || ((IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80)))
|
||||
comm->unroll = NCCL_UNROLL_2;
|
||||
else
|
||||
comm->unroll = 4;
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll);
|
||||
comm->unroll = NCCL_UNROLL_4;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -487,6 +487,13 @@ void Replayer::replay()
|
||||
NCCL_CALL(ncclAllReduce(sbuffer, rbuffer, call.count, call.datatype, call.op, commMap[call.comm], streams[call.stream].first));
|
||||
break;
|
||||
}
|
||||
case rrAllReduceWithBias:
|
||||
{
|
||||
std::vector<char> acc(call.count * ncclTypeSize(call.datatype));
|
||||
NCCL_CALL(ncclAllReduceWithBias(sbuffer, rbuffer, call.count, call.datatype, call.op, commMap[call.comm], streams[call.stream].first, acc.data()));
|
||||
HIP_CALL(hipStreamSynchronize(streams[call.stream].first)); // TODO: remove, and further verify behavior of fused AR
|
||||
break;
|
||||
}
|
||||
// a2av
|
||||
case rrAllToAllv:
|
||||
{
|
||||
@@ -670,4 +677,4 @@ int main(int argc, char **argv)
|
||||
replayer.replay();
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
Ссылка в новой задаче
Block a user