From 58e53dfd37abdb36523d7d5d92ba1b54571ccc05 Mon Sep 17 00:00:00 2001 From: Nusrat Islam Date: Fri, 7 Jul 2023 14:29:27 -0500 Subject: [PATCH] device: fine tune MI200/MI250 simple protocol performance With Simple protocol, unroll factor of 4 offers better performance for most of the collectives (on MI200. MI250, and MI300) except large message allreduce with Ring algorithm on MI250 and MI200). This PR changes the default unroll factor to 4 while adding fine tuning for reduction operations. --- src/collectives/device/common.h | 4 ---- src/collectives/device/common_kernel.h | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 29a951ab11..d4c89343f2 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -11,11 +11,7 @@ #include "collectives.h" #include "devcomm.h" -#if defined(__gfx940__) #define COLL_UNROLL 4 -#else -#define COLL_UNROLL 2 -#endif #define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index f204aee88e..b6a1a3bc7c 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -214,10 +214,17 @@ __device__ __forceinline__ void reduceCopy( if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize); aligned = !(__any(!aligned)); if (aligned) { +#if defined(__gfx90a__) + reduceCopyPacks 1) ? 2 : Unroll), BigPackSize, + MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs> + (nThreads, thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead); +#else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); +#endif if (nBytesAhead == 0) return; reduceCopyPacks 1) { + reduceCopyPacks + (nThreads, thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead); + } else { + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + } +#else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); +#endif if (nBytesAhead == 0) return; reduceCopyPacks