diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 29a951ab11..d4c89343f2 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -11,11 +11,7 @@ #include "collectives.h" #include "devcomm.h" -#if defined(__gfx940__) #define COLL_UNROLL 4 -#else -#define COLL_UNROLL 2 -#endif #define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index f204aee88e..b6a1a3bc7c 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -214,10 +214,17 @@ __device__ __forceinline__ void reduceCopy( if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize); aligned = !(__any(!aligned)); if (aligned) { +#if defined(__gfx90a__) + reduceCopyPacks 1) ? 2 : Unroll), BigPackSize, + MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs> + (nThreads, thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead); +#else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); +#endif if (nBytesAhead == 0) return; reduceCopyPacks 1) { + reduceCopyPacks + (nThreads, thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, nBytesBehind, nBytesAhead); + } else { + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + } +#else reduceCopyPacks (nThreads, /*&*/thread, redArg, preOpArgs, postOp, nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); +#endif if (nBytesAhead == 0) return; reduceCopyPacks