From e69b11eba5538ddf27cf41a91894e1e64efae77e Mon Sep 17 00:00:00 2001 From: alex-breslow-amd Date: Tue, 28 Oct 2025 10:34:48 -0700 Subject: [PATCH] Remove nontemporality from stores, put in casts to global address space (#1982) * Implements casting key loads and stores to address_space(1) so that vector global load and store instructions are emitted by the compiler instead of more costly flat loads and stores * Removes nontemporality from some key stores for gfx950. --- CMakeLists.txt | 1 + src/device/op128.h | 29 +++++++++++++++++++++-------- src/device/primitives.h | 1 + src/device/prims_ll.h | 26 ++++++++++++++++---------- src/device/rccl_ptr.h | 35 +++++++++++++++++++++++++++++++++++ 5 files changed, 74 insertions(+), 18 deletions(-) create mode 100644 src/device/rccl_ptr.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f2c90e46a..a226bb3c21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -464,6 +464,7 @@ set(SRC_FILES src/device/reduce_kernel.h src/device/reduce_scatter.h src/device/rccl_metadata.h + src/device/rccl_ptr.h src/device/sendrecv.h src/device/common.cu src/device/onerank.cu diff --git a/src/device/op128.h b/src/device/op128.h index 21335ac3c4..18c003d80b 100644 --- a/src/device/op128.h +++ b/src/device/op128.h @@ -9,14 +9,16 @@ #include +#include "device/rccl_ptr.h" + inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) { - v0 = __builtin_nontemporal_load(ptr); - v1 = __builtin_nontemporal_load(ptr+1); + v0 = __builtin_nontemporal_load((u64_gptr) ptr); + v1 = __builtin_nontemporal_load((u64_gptr) ptr+1); } inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) { - __builtin_nontemporal_store(v0, ptr); - __builtin_nontemporal_store(v1, ptr+1); + *((u64_gptr) ptr) = v0; + *((u64_gptr) ptr + 1) = v1; } inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) { @@ -297,6 +299,18 @@ DEFINE_ld_st__size(8, uint64_t, b64, l) #undef DEFINE_ld_st__size_space #undef DEFINE_ld_st__size 
+#ifdef __gfx950__ +__device__ __forceinline__ void store16global(uintptr_t addr, BytePack<16> value){ + *(u64_gptr) addr = *(u64_gptr) value.u64; + *((u64_gptr) addr+1) = *((u64_gptr) value.u64+1); +} +#else +__device__ __forceinline__ void store16global(uintptr_t addr, BytePack<16> value){ + __builtin_nontemporal_store(value.u64[0], (u64_gptr) addr); + __builtin_nontemporal_store(value.u64[1], (u64_gptr) addr + 1); +} +#endif + #define DEFINE_ld_st_16__space(space, addr_cxx_ty, addr_reg_ty) \ template<> \ __device__ __forceinline__ BytePack<16> ld_##space<16>(addr_cxx_ty addr) { \ @@ -308,14 +322,13 @@ DEFINE_ld_st__size(8, uint64_t, b64, l) template<> \ __device__ __forceinline__ BytePack<16> ld_volatile_##space<16>(addr_cxx_ty addr) { \ BytePack<16> ans; \ - ans.u64[0] = __builtin_nontemporal_load((uint64_t*)addr); \ - ans.u64[1] = __builtin_nontemporal_load((uint64_t*)addr+1); \ + *(u64_gptr) ans.u64 = __builtin_nontemporal_load((u64_gptr)addr); \ + *((u64_gptr) ans.u64+1) = __builtin_nontemporal_load((u64_gptr)addr+1); \ return ans; \ } \ template<> \ __device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \ - __builtin_nontemporal_store(value.u64[0], (uint64_t*)addr); \ - __builtin_nontemporal_store(value.u64[1], (uint64_t*)addr+1); \ + store16##space(addr, value); \ } DEFINE_ld_st_16__space(global, uintptr_t, l) diff --git a/src/device/primitives.h b/src/device/primitives.h index 19bccc8e2b..28ecc0d36c 100644 --- a/src/device/primitives.h +++ b/src/device/primitives.h @@ -11,6 +11,7 @@ #include #include "reduce_kernel.h" // for reduction funcs #include "rccl_metadata.h" +#include "rccl_ptr.h" #include "common_kernel.h" #include "common.h" diff --git a/src/device/prims_ll.h b/src/device/prims_ll.h index 703860a642..1d6ba34822 100644 --- a/src/device/prims_ll.h +++ b/src/device/prims_ll.h @@ -10,6 +10,7 @@ #include "npkit/npkit.h" #endif +#include "device/rccl_ptr.h" template class Primitives: public PrimitivesWithoutDirect> { 
@@ -166,8 +167,8 @@ private: asm volatile ("global_load_b128 %0, %1, off glc slc dlc\n" "s_waitcnt vmcnt(0)\n" : "=v"(i4.i4) : "v"(&src->i4)); #else - i4.v[0] = __builtin_nontemporal_load(src->v); - i4.v[1] = __builtin_nontemporal_load(src->v+1); + *((u64_gptr)i4.v) = __builtin_nontemporal_load((u64_gptr)src->v); + *((u64_gptr)i4.v + 1) = __builtin_nontemporal_load((u64_gptr)src->v+1); #endif #if defined(ENABLE_NPKIT) && (defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_ENTRY) && defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_EXIT) || defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)) npkitWaitRecvSpins++; @@ -236,8 +237,8 @@ private: asm volatile ("global_load_b128 %0, %1, off glc slc dlc\n" "s_waitcnt vmcnt(0)\n" : "=v"(line[i].i4) : "v"(&src->i4)); #else - line[i].v[0] = __builtin_nontemporal_load(src->v); - line[i].v[1] = __builtin_nontemporal_load(src->v+1); + line[i].v[0] = __builtin_nontemporal_load((u64_gptr)src->v); + line[i].v[1] = __builtin_nontemporal_load((u64_gptr)src->v+1); #endif #else asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4) : "memory"); @@ -266,8 +267,8 @@ private: i4.flag1 = flag; i4.data2 = (val >> 32); i4.flag2 = flag; - __builtin_nontemporal_store(i4.v[0], dst->v); - __builtin_nontemporal_store(i4.v[1], dst->v+1); + *((u64_gptr) dst->v) = *((u64_gptr) i4.v); + *((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1); #if defined(__gfx950__) && ROCM_VERSION < 70200 __builtin_amdgcn_fence(__ATOMIC_RELEASE, ""); // flush cache #endif @@ -292,25 +293,25 @@ private: #ifdef __GFX11__ u1 = __atomic_load_n((uint8_t*)src, __ATOMIC_RELAXED); #else - u1 = __builtin_nontemporal_load((uint8_t*)src); + u1 = __builtin_nontemporal_load((u8_gptr)src); #endif else if(sizeof(U) == 2) #ifdef __GFX11__ u2 = __atomic_load_n((uint16_t*)src, __ATOMIC_RELAXED); #else - u2 = __builtin_nontemporal_load((uint16_t*)src); + u2 = 
__builtin_nontemporal_load((u16_gptr)src); #endif else if(sizeof(U) == 4) #ifdef __GFX11__ u4 = __atomic_load_n((uint32_t*)src, __ATOMIC_RELAXED); #else - u4 = __builtin_nontemporal_load((uint32_t*)src); + u4 = __builtin_nontemporal_load((u32_gptr)src); #endif else #ifdef __GFX11__ u8 = __atomic_load_n((uint64_t*)src, __ATOMIC_RELAXED); #else - u8 = __builtin_nontemporal_load((uint64_t*)src); + u8 = __builtin_nontemporal_load((u64_gptr)src); #endif #else if(sizeof(U) == 1) @@ -397,7 +398,12 @@ private: } }; + __device__ void storeData(T *dst, uint64_t val, int eltN) { + if (__all((reinterpret_cast(dst) & (sizeof(T) - 1)) == 0 && sizeof(T) * eltN == sizeof(val))){ + *((u64_gptr) dst) = val; + return; + } union { uint64_t u8; T elt[EltPerLine]; diff --git a/src/device/rccl_ptr.h b/src/device/rccl_ptr.h new file mode 100644 index 0000000000..51c9b32e1d --- /dev/null +++ b/src/device/rccl_ptr.h @@ -0,0 +1,35 @@ +#pragma once + +/* +Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ + +// Defines a series of global address space pointers. Casting to these +// pointers in hot code paths should improve performance since global +// aperture vector instructions like global_store_dwordx4 can be used. +// These are cheaper than flat loads and stores. +// Verify the intended effect by inspecting assembly. If you see +// flat in the name of the emitted instruction, something is wrong. +using u64_gptr = __attribute__((address_space(1))) uint64_t*; +using u32_gptr = __attribute__((address_space(1))) uint32_t*; +using u16_gptr = __attribute__((address_space(1))) uint16_t*; +using u8_gptr = __attribute__((address_space(1))) uint8_t*; +