From 2223cccf158519159f08f4175982dae20fa95dbc Mon Sep 17 00:00:00 2001 From: Wenkai Du Date: Thu, 15 Aug 2019 09:16:11 -0700 Subject: [PATCH] Tune LL threshold for VEGA Also move the abort check after SPINS_BEFORE_CHECK_ABORT, as in NCCL --- src/collectives/device/primitives.h | 4 ++-- src/include/enqueue.h | 1 + src/init.cc | 4 ++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index c38341f92c..13429f2850 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -84,8 +84,8 @@ class ncclPrimitives { __device__ int checkAbort(volatile uint64_t* remoteOpCount) { spins++; - abort = LOAD(comm->abortFlag); if (spins == SPINS_BEFORE_CHECK_ABORT) { + abort = LOAD(comm->abortFlag); checkMismatch(remoteOpCount); spins = 0; } @@ -404,8 +404,8 @@ class ncclLLPrimitives { __device__ int checkAbort(volatile uint64_t* remoteOpCount) { spins++; - abort = LOAD(comm->abortFlag); if (spins == SPINS_BEFORE_CHECK_ABORT) { + abort = LOAD(comm->abortFlag); checkMismatch(remoteOpCount); spins = 0; } diff --git a/src/include/enqueue.h b/src/include/enqueue.h index 35d006e512..c40957df91 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -15,6 +15,7 @@ #define NCCL_LL_CHANNEL_THRESHOLD 8 // Per thread size before we start increasing nrings #define NCCL_THREAD_THRESHOLD 256 // Per thread size before we switch to non-LL #define NCCL_THREAD_THRESHOLD_PREVOLTA 32 // Per thread size before we switch to non-LL for pre-Volta archs +#define NCCL_THREAD_THRESHOLD_VEGA 8 // Per thread size before we switch to non-LL for VEGA #define NCCL_LL_MIN_NTHREADS 256 ncclResult_t ncclEnqueueCheck(struct ncclInfo* info); diff --git a/src/init.cc b/src/init.cc index 3025d3b085..23c1f53c2e 100644 --- a/src/init.cc +++ b/src/init.cc @@ -150,7 +150,11 @@ NCCL_PARAM(TreeThreshold, "TREE_THRESHOLD", 0); int ncclThreadThreshold(int minCompCap, int multiNode) { int threshold = 
ncclParamThreadThreshold(); if (threshold == -2) { // user has not set this env variable +#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) + threshold = NCCL_THREAD_THRESHOLD_VEGA; +#else threshold = (minCompCap <= 6) ? NCCL_THREAD_THRESHOLD_PREVOLTA : NCCL_THREAD_THRESHOLD; +#endif // multiply by 2 if running on multiple nodes if (multiNode) { threshold *= 2;