From 11fabf1de163ebac72f1860395b7a72a919cd39d Mon Sep 17 00:00:00 2001 From: alex-breslow-amd Date: Tue, 22 Jul 2025 07:15:15 -0700 Subject: [PATCH] Cheaper threadfence for gfx942 in postPeer [1/N]: enable for single node allreduce (#1766) Boosts single node bfloat16 allreduce performance by up to 20% for some data sizes and provides gating with the RCCL_GFX942_CHEAP_FENCE_OFF environment variable --- CMakeLists.txt | 9 +++++++ src/device/all_reduce.h | 22 +++++++++++------- src/device/gfx9_threadfence.h | 44 +++++++++++++++++++++++++++++++++++ src/device/primitives.h | 3 ++- src/device/prims_simple.h | 11 +++++---- src/device/rccl_metadata.h | 34 +++++++++++++++++++++++++++ src/enqueue.cc | 1 + src/include/comm.h | 1 + src/include/device.h | 2 +- src/init.cc | 3 +++ 10 files changed, 116 insertions(+), 14 deletions(-) create mode 100644 src/device/gfx9_threadfence.h create mode 100644 src/device/rccl_metadata.h diff --git a/CMakeLists.txt b/CMakeLists.txt index c058cdfbd7..1942a1d7a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,7 @@ option(PROFILE "Enable profiling" option(TIMETRACE "Enable time-trace during compilation" OFF) option(TRACE "Enable additional tracing" OFF) option(FAULT_INJECTION "Enable fault injection" ON) +option(DISABLE_CHEAP_THREADFENCE "Compile-time killswitch for simpler fence" OFF) # Default GPU architectures to build #================================================================================================== @@ -437,6 +438,7 @@ set(SRC_FILES src/device/broadcast.h src/device/common.h src/device/common_kernel.h + src/device/gfx9_threadfence.h src/device/op128.h src/device/primitives.h src/device/prims_ll128.h @@ -445,6 +447,7 @@ set(SRC_FILES src/device/reduce.h src/device/reduce_kernel.h src/device/reduce_scatter.h + src/device/rccl_metadata.h src/device/sendrecv.h src/device/common.cu src/device/onerank.cu @@ -1112,6 +1115,12 @@ if (FAULT_INJECTION) target_compile_definitions(rccl PRIVATE ENABLE_FAULT_INJECTION) 
message(STATUS "Fault injection enabled") endif() +if (DISABLE_CHEAP_THREADFENCE) + target_compile_definitions(rccl PRIVATE DISABLE_CHEAP_THREADFENCE) + message(STATUS "Cheap thread fence disabled") +else() + message(STATUS "Cheap thread fence enabled for some collectives/parameters") +endif() ## Set RCCL linked library directories target_link_directories(rccl PRIVATE ${ROCM_SMI_LIB_DIR}) diff --git a/src/device/all_reduce.h b/src/device/all_reduce.h index c1e9823763..6183577708 100644 --- a/src/device/all_reduce.h +++ b/src/device/all_reduce.h @@ -14,7 +14,7 @@ #endif namespace { - template + template #if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__) __device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) { #else @@ -61,7 +61,7 @@ namespace { // Coverity reports that the callee treats &ring->next as an array. However, due to the use of // FanSymmetric<1>, only the first element is ever accessed, so it's fine. // coverity[callee_ptr_arith:FALSE] - Primitives, 0, Proto, 0> prims + Primitives, 0, Proto, 0, false, RCCLMetadata> prims (tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex, work); #if defined(ENABLE_NPKIT) @@ -562,15 +562,21 @@ namespace { #define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \ if(work->rcclUseOneSlice){ \ using Proto = ProtoSimple; \ - runRing(tid, nthreads, work); \ - } else{ \ + if(work->regUsed || work->netRegUsed || work->gfx942CheapFenceOff){ \ + runRing(tid, nthreads, work); \ + } \ + else { \ + runRing(tid, nthreads, work); \ + } \ + } \ + else{ \ using Proto = ProtoSimple; \ - runRing(tid, nthreads, work); \ + runRing(tid, nthreads, work); \ } #else #define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \ using Proto = ProtoSimple; \ - runRing(tid, nthreads, work); + runRing(tid, nthreads, work); #endif template @@ -1099,7 +1105,7 @@ struct RunWorkColl struct RunWorkColl { 
__device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { - runRing(tid, nthreads, work); + runRing(tid, nthreads, work); } }; @@ -1113,7 +1119,7 @@ struct RunWorkColl { template struct RunWorkColl { __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) { - runRing(tid, nthreads, work); + runRing(tid, nthreads, work); } }; diff --git a/src/device/gfx9_threadfence.h b/src/device/gfx9_threadfence.h new file mode 100644 index 0000000000..352cf6d291 --- /dev/null +++ b/src/device/gfx9_threadfence.h @@ -0,0 +1,44 @@ +/* +Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +#pragma once + +// This is only okay when the protocol buffer is allocated in uncached memory. 
+#if defined(__gfx942__) && defined(HIP_UNCACHED_MEMORY) && !defined(DISABLE_CHEAP_THREADFENCE) +#define RCCL_CHEAP_THREADFENCE_OK_SOMETIMES 1 +#else +#define RCCL_CHEAP_THREADFENCE_OK_SOMETIMES 0 +#endif /* 1 only for gfx942 builds with HIP_UNCACHED_MEMORY defined and without the DISABLE_CHEAP_THREADFENCE killswitch */ + +template +inline __device__ void gfx9ThreadFence(); /* primary template is declared only; callers pick one of the two specializations below at compile time */ + +template<> +inline __device__ void gfx9ThreadFence() { /* asm-only fence: s_waitcnt followed by buffer_inv. The "memory" clobbers make each asm a compiler barrier so loads/stores are not moved across the fence. NOTE(review): the specialization argument is missing in this copy of the patch - presumably this is the cheap variant; confirm against the repository */ + asm volatile("s_waitcnt lgkmcnt(0) vmcnt(0)" ::: "memory"); + asm volatile("buffer_inv sc0 sc1" ::: "memory"); +} + +template<> +inline __device__ void gfx9ThreadFence() { /* other specialization: plain __threadfence() */ + __threadfence(); +} diff --git a/src/device/primitives.h b/src/device/primitives.h index 61edf9f921..79a93239c0 100644 --- a/src/device/primitives.h +++ b/src/device/primitives.h @@ -10,6 +10,7 @@ #include #include "reduce_kernel.h" // for reduction funcs +#include "rccl_metadata.h" #include "common_kernel.h" #include "common.h" @@ -136,7 +137,7 @@ struct FanSymmetric { }; // The primitives class. Specialized per protocol in the other headers. -template +template class Primitives; // Used by LL & LL128 to implement direct members in the naive way. 
diff --git a/src/device/prims_simple.h b/src/device/prims_simple.h index e453a31b95..58526495bd 100644 --- a/src/device/prims_simple.h +++ b/src/device/prims_simple.h @@ -10,6 +10,8 @@ #include "npkit/npkit.h" #endif +#include "device/gfx9_threadfence.h" +#include "device/rccl_metadata.h" #include "msccl/msccl_struct.h" #include "network/unpack/unpack.h" #include @@ -21,9 +23,9 @@ enum primsMode { }; template + int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts, bool isNetOffload, int Metadata> class Primitives< - T, RedOp, Fan, Direct, ProtoSimple, P2p, isNetOffload + T, RedOp, Fan, Direct, ProtoSimple, P2p, isNetOffload, Metadata > { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; @@ -199,12 +201,13 @@ private: template inline __device__ void postPeer(bool dataStored) { - if (Send && (flags & RolePostSend) && dataStored) + if (Send && (flags & RolePostSend) && dataStored){ #ifdef __GFX9__ - __threadfence(); + gfx9ThreadFence(); #else __threadfence_system(); #endif + } if ((flags & Send*RolePostSend) && next_hdp_reg) STORE((unsigned int *)next_hdp_reg, 0x1); diff --git a/src/device/rccl_metadata.h b/src/device/rccl_metadata.h new file mode 100644 index 0000000000..dbd7330091 --- /dev/null +++ b/src/device/rccl_metadata.h @@ -0,0 +1,34 @@ +#pragma once +/* +Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +/* This file implements methods to extract metadata from an integer Metadata field passed in as a template parameter. 
Feel free to add additional fields below.*/ + +#define RCCL_METADATA_EMPTY 0 /* no metadata bits set */ +#define RCCL_ONE_NODE_RING_SIMPLE (1 << 0) /* bit 0; the flag tested by isOneNodeRingSimple() */ + +/* Returns true when the RCCL_ONE_NODE_RING_SIMPLE bit is set in metadata. */ constexpr bool isOneNodeRingSimple(int metadata) { + return (metadata & RCCL_ONE_NODE_RING_SIMPLE) != 0; +} + +/* Compile-time checks of isOneNodeRingSimple(): true for the flag itself, false for 0. */ static_assert(isOneNodeRingSimple(RCCL_ONE_NODE_RING_SIMPLE), "RCCL_ONE_NODE_RING_SIMPLE should be set to (1 << 0)"); +static_assert(isOneNodeRingSimple(0) == 0, "RCCL_ONE_NODE_RING_SIMPLE should not be set when metadata is 0"); \ No newline at end of file diff --git a/src/enqueue.cc b/src/enqueue.cc index fe3a8d6334..3241652781 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -366,6 +366,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { devWork.redOpArgIsPtr = task->opDev.scalarArgIsPtr; devWork.oneNode = (comm->nNodes == 1); devWork.rcclUseOneSlice = comm->rcclUseOneSlice; + devWork.gfx942CheapFenceOff = comm->gfx942CheapFenceOff; devWork.isOneRPN = comm->isOneRPN; devWork.netRegUsed = devWork.regUsed = 0; devWork.profilerEnabled = ncclProfilerPluginLoaded() && (task->eActivationMask & ncclProfileKernelCh); diff --git a/src/include/comm.h b/src/include/comm.h index 8ff52af799..bd6ef4721d 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -488,6 +488,7 @@ struct ncclComm { int node; int nNodes; int rcclUseOneSlice; // RCCL: true if this comm is using one slice per primitive + int gfx942CheapFenceOff; // RCCL: true if gfx942 cheap fence is disabled int localRank; int localRanks; int maxLocalRanks; diff --git a/src/include/device.h b/src/include/device.h index 40c17b1d03..c7536f819e 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -305,7 +305,7 @@ struct alignas(16) ncclDevWorkColl { // nChannels == (channelHi - channelLo) + 1 uint32_t channelLo:8, channelHi:8; uint32_t nWarps:8; - uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1, rcclUseOneSlice:1; + uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1, rcclUseOneSlice:1, 
gfx942CheapFenceOff:1; uint32_t root:30, connIndex:2; uint16_t pivotA2ANumBiRings:15, profilerEnabled:1; void* recvbuff; diff --git a/src/init.cc b/src/init.cc index aa1db419ef..56c99feaef 100644 --- a/src/init.cc +++ b/src/init.cc @@ -103,6 +103,8 @@ RCCL_PARAM(MscclppThreshold, "MSCCLPP_THRESHOLD", (size_t)(16*1024*1024)); static constexpr int64_t defaultEnableMscclpp = 0; RCCL_PARAM(MscclppEnabled, "MSCCLPP_ENABLE", defaultEnableMscclpp); RCCL_PARAM(MscclppForceEnabled, "MSCCLPP_FORCE_ENABLE", 0); +/* Turn off cheap fence for gfx942; any non-zero value disables it */ +RCCL_PARAM(Gfx942CheapFenceOff, "GFX942_CHEAP_FENCE_OFF", 0); // GDRCOPY support: Off by default NCCL_PARAM(GdrCopyEnable, "GDRCOPY_ENABLE", 0); @@ -1365,6 +1367,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, 0)); // RCCL: Only use one slice per primitive on some single node gfx9xx systems comm->rcclUseOneSlice = !managed && nNodes == 1; + /* Normalize to 0/1: this value is copied into the 1-bit gfx942CheapFenceOff field, where an even setting such as 2 would truncate to 0 and leave the cheap fence on */ comm->gfx942CheapFenceOff = rcclParamGfx942CheapFenceOff() != 0; if (managed && nNodes > 1) { // This forces the minimum channels to 24 allGather3Data[rank].nc = 6;