device: fine tune MI200/MI250 simple protocol performance

With Simple protocol, unroll factor of 4 offers better
performance for most of the collectives (on MI200. MI250, and
MI300) except large message allreduce with Ring algorithm
on MI250 and MI200). This PR changes the default unroll factor
to 4 while adding fine tuning for reduction operations.
Этот коммит содержится в:
Nusrat Islam
2023-07-07 14:29:27 -05:00
родитель 6ef70811d2
Коммит 58e53dfd37
2 изменённых файлов: 21 добавлений и 4 удалений
-4
Просмотреть файл
@@ -11,11 +11,7 @@
#include "collectives.h"
#include "devcomm.h"
#if defined(__gfx940__)
#define COLL_UNROLL 4
#else
#define COLL_UNROLL 2
#endif
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
+21
Просмотреть файл
@@ -214,10 +214,17 @@ __device__ __forceinline__ void reduceCopy(
if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize);
aligned = !(__any(!aligned));
if (aligned) {
#if defined(__gfx90a__)
reduceCopyPacks<RedFn, T, ((MinSrcs > 1) ? 2 : Unroll), BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead);
#else
reduceCopyPacks<RedFn, T, Unroll*((MinSrcs == 1 && MinDsts == 1) ? 2 : 1), BigPackSize,
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
#endif
if (nBytesAhead == 0) return;
reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
@@ -228,10 +235,24 @@ __device__ __forceinline__ void reduceCopy(
}
}
#if defined(__gfx90a__)
if (MinSrcs > 1) {
reduceCopyPacks<RedFn, T, Unroll/2*(16/sizeof(T))/2, sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead);
} else {
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
}
#else
reduceCopyPacks<RedFn, T, Unroll*(16/sizeof(T))/2, /*BytePerPack=*/sizeof(T),
MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
(nThreads, /*&*/thread, redArg, preOpArgs, postOp,
nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
#endif
if (nBytesAhead == 0) return;
reduceCopyPacks<RedFn, T, /*Unroll=*/1, /*BytePerPack=*/sizeof(T),