From b797b62f6b9ef6f05b2cc690a287dd9841fa3757 Mon Sep 17 00:00:00 2001 From: Nilesh M Negi Date: Mon, 9 Jun 2025 01:26:07 -0500 Subject: [PATCH] [DEVICE] Use threadfence on gfx950 for LL protocol (#1686) Signed-off-by: nileshnegi [ROCm/rccl commit: b926203c0553e812f55513e14b751527c8d40ff4] --- projects/rccl/src/device/primitives.h | 11 ++++------- projects/rccl/src/device/prims_ll.h | 6 +++++- projects/rccl/src/device/prims_ll128.h | 6 +++++- projects/rccl/src/device/prims_simple.h | 6 +++++- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/projects/rccl/src/device/primitives.h b/projects/rccl/src/device/primitives.h index 3ef9fd6126..c0536f1cf4 100644 --- a/projects/rccl/src/device/primitives.h +++ b/projects/rccl/src/device/primitives.h @@ -15,13 +15,7 @@ #define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000 -#if defined(__gfx942__) || defined(__gfx950__) -#define __THREAD_FENCE __threadfence_block() -#else -#define __THREAD_FENCE __threadfence() -#endif - -#define barrier_by_group() do { \ +#define barrier_by_group_common(__THREAD_FENCE) do { \ if (nthreads == NCCL_MAX_NTHREADS) { \ __THREAD_FENCE; __builtin_amdgcn_s_barrier(); \ } else { \ @@ -53,6 +47,9 @@ } \ } while (0) +#define barrier_by_group() barrier_by_group_common(__threadfence()) +#define barrier_by_group_block() barrier_by_group_common(__threadfence_block()) + /* Protocol classes: ProtoSimple, ProtoLL, ProtoLL128 * We use these as template args to the Primtiives class instead of integral * enums (e.g. NCCL_PROTO_LL) because for SIMPLE we need to carry a few extra diff --git a/projects/rccl/src/device/prims_ll.h b/projects/rccl/src/device/prims_ll.h index 387348942c..4a2c2b9c8b 100644 --- a/projects/rccl/src/device/prims_ll.h +++ b/projects/rccl/src/device/prims_ll.h @@ -71,7 +71,11 @@ private: inline __device__ void barrier() { #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) if (nthreads != WARP_SIZE) - barrier_by_group(); + #if defined(__gfx942__) + barrier_by_group_block(); + #else + barrier_by_group(); + #endif #else if (nthreads == WARP_SIZE) { __syncwarp(); diff --git a/projects/rccl/src/device/prims_ll128.h b/projects/rccl/src/device/prims_ll128.h index bcc72d075c..6758b78808 100644 --- a/projects/rccl/src/device/prims_ll128.h +++ b/projects/rccl/src/device/prims_ll128.h @@ -76,7 +76,11 @@ private: inline __device__ void barrier() { #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) if (nthreads != WARP_SIZE) - barrier_by_group(); + #if defined(__gfx942__) || defined(__gfx950__) + barrier_by_group_block(); + #else + barrier_by_group(); + #endif #else barrier_sync(15-group, nthreads); #endif diff --git a/projects/rccl/src/device/prims_simple.h b/projects/rccl/src/device/prims_simple.h index e0d89a9620..270d01d827 100644 --- a/projects/rccl/src/device/prims_simple.h +++ b/projects/rccl/src/device/prims_simple.h @@ -79,7 +79,11 @@ private: if (nthreads == WARP_SIZE) __syncwarp(); else - barrier_by_group(); + #if defined(__gfx942__) || defined(__gfx950__) + barrier_by_group_block(); + #else + barrier_by_group(); + #endif } inline __device__ void subBarrier() { barrier();