Remove nontemporality from stores, put in casts to global address space (#1982)

* Implements casting key loads and stores to address_space(1) so that vector global load and store instructions are emitted by the compiler instead of more costly flat loads and stores
* Removes nontemporality from some key stores for gfx950.
This commit is contained in:
alex-breslow-amd
2025-10-28 10:34:48 -07:00
committed by GitHub
orang tua f290e302d3
melakukan e69b11eba5
5 mengubah file dengan 74 tambahan dan 18 penghapusan
+1
Melihat File
@@ -464,6 +464,7 @@ set(SRC_FILES
src/device/reduce_kernel.h
src/device/reduce_scatter.h
src/device/rccl_metadata.h
src/device/rccl_ptr.h
src/device/sendrecv.h
src/device/common.cu
src/device/onerank.cu
+21 -8
Melihat File
@@ -9,14 +9,16 @@
#include <type_traits>
#include "device/rccl_ptr.h"
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
v0 = __builtin_nontemporal_load(ptr);
v1 = __builtin_nontemporal_load(ptr+1);
v0 = __builtin_nontemporal_load((u64_gptr) ptr);
v1 = __builtin_nontemporal_load((u64_gptr) ptr+1);
}
inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) {
__builtin_nontemporal_store(v0, ptr);
__builtin_nontemporal_store(v1, ptr+1);
*((u64_gptr) ptr) = v0;
*((u64_gptr) ptr + 1) = v1;
}
inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) {
@@ -297,6 +299,18 @@ DEFINE_ld_st__size(8, uint64_t, b64, l)
#undef DEFINE_ld_st__size_space
#undef DEFINE_ld_st__size
#ifdef __gfx950__
__device__ __forceinline__ void store16global(uintptr_t addr, BytePack<16> value){
*(u64_gptr) addr = *(u64_gptr) value.u64;
*((u64_gptr) addr+1) = *((u64_gptr) value.u64+1);
}
#else
__device__ __forceinline__ void store16global(uintptr_t addr, BytePack<16> value){
__builtin_nontemporal_store(value.u64[0], (u64_gptr) addr);
__builtin_nontemporal_store(value.u64[1], (u64_gptr) addr + 1);
}
#endif
#define DEFINE_ld_st_16__space(space, addr_cxx_ty, addr_reg_ty) \
template<> \
__device__ __forceinline__ BytePack<16> ld_##space<16>(addr_cxx_ty addr) { \
@@ -308,14 +322,13 @@ DEFINE_ld_st__size(8, uint64_t, b64, l)
template<> \
__device__ __forceinline__ BytePack<16> ld_volatile_##space<16>(addr_cxx_ty addr) { \
BytePack<16> ans; \
ans.u64[0] = __builtin_nontemporal_load((uint64_t*)addr); \
ans.u64[1] = __builtin_nontemporal_load((uint64_t*)addr+1); \
*(u64_gptr) ans.u64 = __builtin_nontemporal_load((u64_gptr)addr); \
*((u64_gptr) ans.u64+1) = __builtin_nontemporal_load((u64_gptr)addr+1); \
return ans; \
} \
template<> \
__device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \
__builtin_nontemporal_store(value.u64[0], (uint64_t*)addr); \
__builtin_nontemporal_store(value.u64[1], (uint64_t*)addr+1); \
store16##space(addr, value); \
}
DEFINE_ld_st_16__space(global, uintptr_t, l)
+1
Melihat File
@@ -11,6 +11,7 @@
#include <type_traits>
#include "reduce_kernel.h" // for reduction funcs
#include "rccl_metadata.h"
#include "rccl_ptr.h"
#include "common_kernel.h"
#include "common.h"
+16 -10
Melihat File
@@ -10,6 +10,7 @@
#include "npkit/npkit.h"
#endif
#include "device/rccl_ptr.h"
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int Metadata, int Pipeline, int useAcc>
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pipeline, useAcc>:
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pipeline, useAcc>> {
@@ -166,8 +167,8 @@ private:
asm volatile ("global_load_b128 %0, %1, off glc slc dlc\n"
"s_waitcnt vmcnt(0)\n" : "=v"(i4.i4) : "v"(&src->i4));
#else
i4.v[0] = __builtin_nontemporal_load(src->v);
i4.v[1] = __builtin_nontemporal_load(src->v+1);
*((u64_gptr)i4.v) = __builtin_nontemporal_load((u64_gptr)src->v);
*((u64_gptr)i4.v + 1) = __builtin_nontemporal_load((u64_gptr)src->v+1);
#endif
#if defined(ENABLE_NPKIT) && (defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_ENTRY) && defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_EXIT) || defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME))
npkitWaitRecvSpins++;
@@ -236,8 +237,8 @@ private:
asm volatile ("global_load_b128 %0, %1, off glc slc dlc\n"
"s_waitcnt vmcnt(0)\n" : "=v"(line[i].i4) : "v"(&src->i4));
#else
line[i].v[0] = __builtin_nontemporal_load(src->v);
line[i].v[1] = __builtin_nontemporal_load(src->v+1);
line[i].v[0] = __builtin_nontemporal_load((u64_gptr)src->v);
line[i].v[1] = __builtin_nontemporal_load((u64_gptr)src->v+1);
#endif
#else
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4) : "memory");
@@ -266,8 +267,8 @@ private:
i4.flag1 = flag;
i4.data2 = (val >> 32);
i4.flag2 = flag;
__builtin_nontemporal_store(i4.v[0], dst->v);
__builtin_nontemporal_store(i4.v[1], dst->v+1);
*((u64_gptr) dst->v) = *((u64_gptr) i4.v);
*((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1);
#if defined(__gfx950__) && ROCM_VERSION < 70200
__builtin_amdgcn_fence(__ATOMIC_RELEASE, ""); // flush cache
#endif
@@ -292,25 +293,25 @@ private:
#ifdef __GFX11__
u1 = __atomic_load_n((uint8_t*)src, __ATOMIC_RELAXED);
#else
u1 = __builtin_nontemporal_load((uint8_t*)src);
u1 = __builtin_nontemporal_load((u8_gptr)src);
#endif
else if(sizeof(U) == 2)
#ifdef __GFX11__
u2 = __atomic_load_n((uint16_t*)src, __ATOMIC_RELAXED);
#else
u2 = __builtin_nontemporal_load((uint16_t*)src);
u2 = __builtin_nontemporal_load((u16_gptr)src);
#endif
else if(sizeof(U) == 4)
#ifdef __GFX11__
u4 = __atomic_load_n((uint32_t*)src, __ATOMIC_RELAXED);
#else
u4 = __builtin_nontemporal_load((uint32_t*)src);
u4 = __builtin_nontemporal_load((u32_gptr)src);
#endif
else
#ifdef __GFX11__
u8 = __atomic_load_n((uint64_t*)src, __ATOMIC_RELAXED);
#else
u8 = __builtin_nontemporal_load((uint64_t*)src);
u8 = __builtin_nontemporal_load((u64_gptr)src);
#endif
#else
if(sizeof(U) == 1)
@@ -397,7 +398,12 @@ private:
}
};
__device__ void storeData(T *dst, uint64_t val, int eltN) {
if (__all((reinterpret_cast<uintptr_t>(dst) & (sizeof(T) - 1)) == 0 && sizeof(T) * eltN == sizeof(val))){
*((u64_gptr) dst) = val;
return;
}
union {
uint64_t u8;
T elt[EltPerLine];
+35
Melihat File
@@ -0,0 +1,35 @@
#pragma once
/*
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
// Defines a series of global address space pointers. Casting to these
// pointers in hot code paths should improve performance since global
// aperture vector instrutions like global_store_dwordx4 can be used.
// These are cheaper than flat loads and stores.
// Verify the intended effect by inspecting assembly. If you see
// flat in the name of the emitted instruction, something is wrong.
using u64_gptr = __attribute__((address_space(1))) uint64_t*;
using u32_gptr = __attribute__((address_space(1))) uint32_t*;
using u16_gptr = __attribute__((address_space(1))) uint16_t*;
using u8_gptr = __attribute__((address_space(1))) uint8_t*;