Remove nontemporality from stores, put in casts to global address space (#1982)
* Implements casting key loads and stores to address_space(1) so that vector global load and store instructions are emitted by the compiler instead of more costly flat loads and stores * Removes nontemporality from some key stores for gfx950.
This commit is contained in:
committed by
GitHub
orang tua
f290e302d3
melakukan
e69b11eba5
@@ -464,6 +464,7 @@ set(SRC_FILES
|
||||
src/device/reduce_kernel.h
|
||||
src/device/reduce_scatter.h
|
||||
src/device/rccl_metadata.h
|
||||
src/device/rccl_ptr.h
|
||||
src/device/sendrecv.h
|
||||
src/device/common.cu
|
||||
src/device/onerank.cu
|
||||
|
||||
+21
-8
@@ -9,14 +9,16 @@
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "device/rccl_ptr.h"
|
||||
|
||||
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
|
||||
v0 = __builtin_nontemporal_load(ptr);
|
||||
v1 = __builtin_nontemporal_load(ptr+1);
|
||||
v0 = __builtin_nontemporal_load((u64_gptr) ptr);
|
||||
v1 = __builtin_nontemporal_load((u64_gptr) ptr+1);
|
||||
}
|
||||
|
||||
inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) {
|
||||
__builtin_nontemporal_store(v0, ptr);
|
||||
__builtin_nontemporal_store(v1, ptr+1);
|
||||
*((u64_gptr) ptr) = v0;
|
||||
*((u64_gptr) ptr + 1) = v1;
|
||||
}
|
||||
|
||||
inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) {
|
||||
@@ -297,6 +299,18 @@ DEFINE_ld_st__size(8, uint64_t, b64, l)
|
||||
#undef DEFINE_ld_st__size_space
|
||||
#undef DEFINE_ld_st__size
|
||||
|
||||
#ifdef __gfx950__
|
||||
// gfx950 variant of store16global: writes the 16-byte pack `value` to the
// address `addr` as two plain (not nontemporal) 64-bit stores.
//
// addr  - destination address; it is cast to an address_space(1) pointer so
//         the compiler can emit global stores instead of flat stores.
// value - the 16 bytes to write, passed by value.
//
// Only the destination is cast to the global address space. `value` is a
// by-value local of this function, so its halves are read directly (as the
// non-gfx950 variant does) rather than through a global-address-space pointer.
__device__ __forceinline__ void store16global(uintptr_t addr, BytePack<16> value){
  u64_gptr dst = (u64_gptr) addr;
  dst[0] = value.u64[0];
  dst[1] = value.u64[1];
}
|
||||
#else
|
||||
__device__ __forceinline__ void store16global(uintptr_t addr, BytePack<16> value){
|
||||
__builtin_nontemporal_store(value.u64[0], (u64_gptr) addr);
|
||||
__builtin_nontemporal_store(value.u64[1], (u64_gptr) addr + 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
#define DEFINE_ld_st_16__space(space, addr_cxx_ty, addr_reg_ty) \
|
||||
template<> \
|
||||
__device__ __forceinline__ BytePack<16> ld_##space<16>(addr_cxx_ty addr) { \
|
||||
@@ -308,14 +322,13 @@ DEFINE_ld_st__size(8, uint64_t, b64, l)
|
||||
template<> \
|
||||
__device__ __forceinline__ BytePack<16> ld_volatile_##space<16>(addr_cxx_ty addr) { \
|
||||
BytePack<16> ans; \
|
||||
ans.u64[0] = __builtin_nontemporal_load((uint64_t*)addr); \
|
||||
ans.u64[1] = __builtin_nontemporal_load((uint64_t*)addr+1); \
|
||||
*(u64_gptr) ans.u64 = __builtin_nontemporal_load((u64_gptr)addr); \
|
||||
*((u64_gptr) ans.u64+1) = __builtin_nontemporal_load((u64_gptr)addr+1); \
|
||||
return ans; \
|
||||
} \
|
||||
template<> \
|
||||
__device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \
|
||||
__builtin_nontemporal_store(value.u64[0], (uint64_t*)addr); \
|
||||
__builtin_nontemporal_store(value.u64[1], (uint64_t*)addr+1); \
|
||||
store16##space(addr, value); \
|
||||
}
|
||||
|
||||
DEFINE_ld_st_16__space(global, uintptr_t, l)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <type_traits>
|
||||
#include "reduce_kernel.h" // for reduction funcs
|
||||
#include "rccl_metadata.h"
|
||||
#include "rccl_ptr.h"
|
||||
#include "common_kernel.h"
|
||||
#include "common.h"
|
||||
|
||||
|
||||
+16
-10
@@ -10,6 +10,7 @@
|
||||
#include "npkit/npkit.h"
|
||||
#endif
|
||||
|
||||
#include "device/rccl_ptr.h"
|
||||
template<typename T, typename RedOp, typename Fan, int Direct, int P2p, bool isNetOffload, int Metadata, int Pipeline, int useAcc>
|
||||
class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pipeline, useAcc>:
|
||||
public PrimitivesWithoutDirect<Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pipeline, useAcc>> {
|
||||
@@ -166,8 +167,8 @@ private:
|
||||
asm volatile ("global_load_b128 %0, %1, off glc slc dlc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(i4.i4) : "v"(&src->i4));
|
||||
#else
|
||||
i4.v[0] = __builtin_nontemporal_load(src->v);
|
||||
i4.v[1] = __builtin_nontemporal_load(src->v+1);
|
||||
*((u64_gptr)i4.v) = __builtin_nontemporal_load((u64_gptr)src->v);
|
||||
*((u64_gptr)i4.v + 1) = __builtin_nontemporal_load((u64_gptr)src->v+1);
|
||||
#endif
|
||||
#if defined(ENABLE_NPKIT) && (defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_ENTRY) && defined(ENABLE_NPKIT_EVENT_PRIM_LL_DATA_PROCESS_EXIT) || defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME))
|
||||
npkitWaitRecvSpins++;
|
||||
@@ -236,8 +237,8 @@ private:
|
||||
asm volatile ("global_load_b128 %0, %1, off glc slc dlc\n"
|
||||
"s_waitcnt vmcnt(0)\n" : "=v"(line[i].i4) : "v"(&src->i4));
|
||||
#else
|
||||
line[i].v[0] = __builtin_nontemporal_load(src->v);
|
||||
line[i].v[1] = __builtin_nontemporal_load(src->v+1);
|
||||
line[i].v[0] = __builtin_nontemporal_load((u64_gptr)src->v);
|
||||
line[i].v[1] = __builtin_nontemporal_load((u64_gptr)src->v+1);
|
||||
#endif
|
||||
#else
|
||||
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(line[i].data1), "=r"(line[i].flag1), "=r"(line[i].data2), "=r"(line[i].flag2) : "l"(&src->i4) : "memory");
|
||||
@@ -266,8 +267,8 @@ private:
|
||||
i4.flag1 = flag;
|
||||
i4.data2 = (val >> 32);
|
||||
i4.flag2 = flag;
|
||||
__builtin_nontemporal_store(i4.v[0], dst->v);
|
||||
__builtin_nontemporal_store(i4.v[1], dst->v+1);
|
||||
*((u64_gptr) dst->v) = *((u64_gptr) i4.v);
|
||||
*((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1);
|
||||
#if defined(__gfx950__) && ROCM_VERSION < 70200
|
||||
__builtin_amdgcn_fence(__ATOMIC_RELEASE, ""); // flush cache
|
||||
#endif
|
||||
@@ -292,25 +293,25 @@ private:
|
||||
#ifdef __GFX11__
|
||||
u1 = __atomic_load_n((uint8_t*)src, __ATOMIC_RELAXED);
|
||||
#else
|
||||
u1 = __builtin_nontemporal_load((uint8_t*)src);
|
||||
u1 = __builtin_nontemporal_load((u8_gptr)src);
|
||||
#endif
|
||||
else if(sizeof(U) == 2)
|
||||
#ifdef __GFX11__
|
||||
u2 = __atomic_load_n((uint16_t*)src, __ATOMIC_RELAXED);
|
||||
#else
|
||||
u2 = __builtin_nontemporal_load((uint16_t*)src);
|
||||
u2 = __builtin_nontemporal_load((u16_gptr)src);
|
||||
#endif
|
||||
else if(sizeof(U) == 4)
|
||||
#ifdef __GFX11__
|
||||
u4 = __atomic_load_n((uint32_t*)src, __ATOMIC_RELAXED);
|
||||
#else
|
||||
u4 = __builtin_nontemporal_load((uint32_t*)src);
|
||||
u4 = __builtin_nontemporal_load((u32_gptr)src);
|
||||
#endif
|
||||
else
|
||||
#ifdef __GFX11__
|
||||
u8 = __atomic_load_n((uint64_t*)src, __ATOMIC_RELAXED);
|
||||
#else
|
||||
u8 = __builtin_nontemporal_load((uint64_t*)src);
|
||||
u8 = __builtin_nontemporal_load((u64_gptr)src);
|
||||
#endif
|
||||
#else
|
||||
if(sizeof(U) == 1)
|
||||
@@ -397,7 +398,12 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
__device__ void storeData(T *dst, uint64_t val, int eltN) {
|
||||
if (__all((reinterpret_cast<uintptr_t>(dst) & (sizeof(T) - 1)) == 0 && sizeof(T) * eltN == sizeof(val))){
|
||||
*((u64_gptr) dst) = val;
|
||||
return;
|
||||
}
|
||||
union {
|
||||
uint64_t u8;
|
||||
T elt[EltPerLine];
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
/*
|
||||
Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
// Defines a series of global address space pointers. Casting to these
|
||||
// pointers in hot code paths should improve performance since global
|
||||
// aperture vector instructions like global_store_dwordx4 can be used.
|
||||
// These are cheaper than flat loads and stores.
|
||||
// Verify the intended effect by inspecting assembly. If you see
|
||||
// flat in the name of the emitted instruction, something is wrong.
|
||||
using u64_gptr = __attribute__((address_space(1))) uint64_t*;
|
||||
using u32_gptr = __attribute__((address_space(1))) uint32_t*;
|
||||
using u16_gptr = __attribute__((address_space(1))) uint16_t*;
|
||||
using u8_gptr = __attribute__((address_space(1))) uint8_t*;
|
||||
|
||||
Reference in New Issue
Block a user