diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 0ad5ddeef7..436063c1f9 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -128,12 +128,18 @@ static std::nullptr_t ptradd(std::nullptr_t ptr, int i) { return nullptr; } +// use different unroll numbers for all primitives for best throughput +#define COPY_UNROLL 4 +#define REDUCE_UNROLL 2 +#define DOUBLECOPY_UNROLL 2 +#define REDUCECOPY_UNROLL 2 // Implementation of primitive types -template > +template > class Primitives { private: - template // either WaitFunc or PostFunc static __device__ __attribute__((noinline)) void @@ -204,28 +210,28 @@ class Primitives { static __device__ void Copy(const int tid, const int nthreads, const T* src, T* dst, int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { - GenericOp(tid, nthreads, src, nullptr, dst, nullptr, len, maxOffset, step, flags...); + GenericOp(tid, nthreads, src, nullptr, dst, nullptr, len, maxOffset, step, flags...); } template static __device__ void DoubleCopy(const int tid, const int nthreads, const T* src, T* dst1, T* dst2, int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { - GenericOp(tid, nthreads, src, nullptr, dst1, dst2, len, maxOffset, step, flags...); + GenericOp(tid, nthreads, src, nullptr, dst1, dst2, len, maxOffset, step, flags...); } template static __device__ void Reduce(const int tid, const int nthreads, const T* src1, const T* src2, T* dst, int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { - GenericOp(tid, nthreads, src1, src2, dst, nullptr, len, maxOffset, step, flags...); + GenericOp(tid, nthreads, src1, src2, dst, nullptr, len, maxOffset, step, flags...); } template static __device__ void ReduceCopy(const int tid, const int nthreads, const T* src1, const T* src2, T* dst1, T* dst2, int len, int maxOffset, uint64_t step, SYNC_Ts... flags) { - GenericOp(tid, nthreads, src1, src2, dst1, dst2, len, maxOffset, step, flags...); + GenericOp(tid, nthreads, src1, src2, dst1, dst2, len, maxOffset, step, flags...); } };