/*************************************************************************
 * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
 * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#ifndef NCCL_COMMON_KERNEL_H_
#define NCCL_COMMON_KERNEL_H_

#include "devcomm.h"
#include <cstdio>
#include <cstdint>

#include <hip/hip_runtime.h>

// Define min for ssize_t.
// Returns the smaller of an int and an ssize_t as an int. When b is the
// smaller operand it is narrowed to int; the cast makes that narrowing explicit.
static __device__ int min(int a, ssize_t b) { return (a < b) ? a : static_cast<int>(b); }

// Reduction operands are moved around as 64-bit packs.
typedef uint64_t PackType;

#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
// HIP path: FUNC is applied directly to the whole 64-bit pack.
// NOTE(review): presumably FUNC provides a PackType overload that interprets
// the pack as elements of T itself — confirm against the reduce functors.
template<class FUNC, typename T>
struct MULTI {
  __device__ PackType operator()(const PackType x, const PackType y) const {
    return FUNC()(x, y);
  }
};
#else
// unpack x and y to elements of type T and apply FUNC to each element
template<class FUNC, typename T>
struct MULTI {
  __device__ PackType operator()(const PackType x, const PackType y) const;
};

// 8-bit signed elements: the pack is split into two 32-bit words and FUNC is
// called on each word (four chars at a time).
template<class FUNC>
struct MULTI<FUNC, int8_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(uint32_t),
      "PackType must be twice the size of uint32_t.");
  union converter {
    PackType storage;
    struct {
      uint32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter cx, cy, cr;
    cx.storage = x;
    cy.storage = y;

    // for char, we do these as vector ops
    cr.a = FUNC()(cx.a, cy.a);
    cr.b = FUNC()(cx.b, cy.b);

    return cr.storage;
  }
};

// 8-bit unsigned elements: same 2x32-bit word split as int8_t.
template<class FUNC>
struct MULTI<FUNC, uint8_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(uint32_t),
      "PackType must be twice the size of uint32_t.");
  union converter {
    PackType storage;
    struct {
      uint32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter cx, cy, cr;
    cx.storage = x;
    cy.storage = y;

    // for char, we do these as vector ops
    cr.a = FUNC()(cx.a, cy.a);
    cr.b = FUNC()(cx.b, cy.b);

    return cr.storage;
  }
};

// Two signed 32-bit elements per pack.
template<class FUNC>
struct MULTI<FUNC, int32_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(int32_t),
      "PackType must be twice the size of int.");
  union converter {
    PackType storage;
    struct {
      int32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter cx, cy, cr;
    cx.storage = x;
    cy.storage = y;

    cr.a = FUNC()(cx.a, cy.a);
    cr.b = FUNC()(cx.b, cy.b);

    return cr.storage;
  }
};

// Two unsigned 32-bit elements per pack.
template<class FUNC>
struct MULTI<FUNC, uint32_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(uint32_t),
      "PackType must be twice the size of uint32_t.");
  union converter {
    PackType storage;
    struct {
      uint32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter cx, cy, cr;
    cx.storage = x;
    cy.storage = y;

    cr.a = FUNC()(cx.a, cy.a);
    cr.b = FUNC()(cx.b, cy.b);

    return cr.storage;
  }
};

// Four half elements per pack, processed as two half2 values.
template<class FUNC>
struct MULTI<FUNC, half> {
  static_assert(sizeof(PackType) == 4 * sizeof(half),
      "PackType must be four times the size of half.");

  struct PackHalf2 {
    half2 a, b;
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    struct PackHalf2 cx, cy, cr;
    cx = *(reinterpret_cast<const struct PackHalf2*>(&x));
    cy = *(reinterpret_cast<const struct PackHalf2*>(&y));

    cr.a = FUNC()(cx.a, cy.a);
    cr.b = FUNC()(cx.b, cy.b);

    return *(reinterpret_cast<PackType*>(&cr));
  }
};

// Two float elements per pack.
template<class FUNC>
struct MULTI<FUNC, float> {
  static_assert(sizeof(PackType) == 2 * sizeof(float),
      "PackType must be twice the size of float.");
  union converter {
    PackType storage;
    struct {
      float a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter cx, cy, cr;
    cx.storage = x;
    cy.storage = y;

    cr.a = FUNC()(cx.a, cy.a);
    cr.b = FUNC()(cx.b, cy.b);

    return cr.storage;
  }
};

// One double per pack; the bits are reinterpreted with the CUDA cast intrinsics.
template<class FUNC>
struct MULTI<FUNC, double> {
  static_assert(sizeof(PackType) == sizeof(double),
      "PackType must be the same size as double.");
  __device__ PackType operator()(const PackType x, const PackType y) const {
    double rv = FUNC()(__longlong_as_double(x), __longlong_as_double(y));
    return __double_as_longlong(rv);
  }
};

// One uint64_t per pack: FUNC sees the pack unchanged.
template<class FUNC>
struct MULTI<FUNC, uint64_t> {
  static_assert(sizeof(PackType) == sizeof(uint64_t),
      "PackType must be the same size as uint64_t.");
  __device__ PackType operator()(const PackType x, const PackType y) const {
    uint64_t rv = FUNC()(x, y);
    return rv;
  }
};
// One int64_t per pack; operands are cast so FUNC is invoked with signed values.
template<class FUNC>
struct MULTI<FUNC, int64_t> {
  static_assert(sizeof(PackType) == sizeof(int64_t),
      "PackType must be the same size as int64_t.");
  __device__ PackType operator()(const PackType x, const PackType y) const {
    int64_t rv = FUNC()((int64_t)x, (int64_t)y);
    return rv;
  }
};
#endif //defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)

// Read one T through a volatile pointer.
template<typename T> inline __device__
T vFetch(const volatile T* ptr) {
  return *ptr;
}

// Write one T through a volatile pointer.
template<typename T> inline __device__
void vStore(volatile T* ptr, const T val) {
  *ptr = val;
}

#if CUDART_VERSION < 9000 && !(defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__))
// Pre-CUDA-9 half: copy the raw .x member through the volatile pointer.
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
  half r;
  r.x = ptr->x;
  return r;
}

template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
  ptr->x = val.x;
}
#else
// NOTE: the C-style casts to half* drop the volatile qualifier, so these two
// half specializations perform ordinary (non-volatile) accesses.
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
  half r;
  r = ((half*)ptr)[0];
  return r;
}

template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
  ((half*)ptr)[0] = val;
}

// bfloat16: copy the underlying .data member through the volatile pointer.
template<> inline __device__
rccl_bfloat16 vFetch<rccl_bfloat16>(const volatile rccl_bfloat16* ptr) {
  rccl_bfloat16 r;
  r.data = ptr->data;
  return r;
}

template<> inline __device__
void vStore<rccl_bfloat16>(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) {
  ptr->data = val.data;
}
#endif

// 128-bit pack: two 64-bit lanes (.x and .y), each reduced as one PackType.
typedef ulong2 Pack128;

// In-place 128-bit reduction: x = FUNC(x, y), applied lane by lane via MULTI.
template<class FUNC, typename T>
struct MULTI128 {
  __device__ void operator()(Pack128& x, const Pack128& y) const {
    x.x = MULTI<FUNC, T>()(x.x, y.x);
    x.y = MULTI<FUNC, T>()(x.y, y.y);
  }
};

// Load 16 bytes from p into v.
// HIP: two plain 64-bit member loads.
// CUDA: one ld.volatile.global.v2.u64; PTX requires p to be 16-byte aligned.
inline __device__ void Fetch128(Pack128& v, const Pack128* p) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  v.x = p->x;
  v.y = p->y;
#else
  asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
#endif
}

// Store the 16 bytes of v to p.
// HIP: two plain 64-bit member stores.
// CUDA: one st.volatile.global.v2.u64; PTX requires p to be 16-byte aligned.
inline __device__ void Store128(Pack128* p, const Pack128& v) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  p->x = v.x;
  p->y = v.y;
#else
  asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
#endif
}
// ReduceCopyMulti: per-element (type T) reduce/copy helper.
// Inputs: w = warp index, nw = number of warps, t = lane inside the warp;
// s/d are source/destination arrays offset by elemOffset; Nelem is the element count.
// NOTE(review): in this copy of the file, every character between a '<' and the
// next '>' appears to have been lost (template parameter lists, loop bounds such
// as "i<...", explicit template arguments). The bodies below are therefore
// incomplete and do not compile as shown; restore them from upstream before
// changing any logic.
template __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
    int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
  // Element stride covered by all warps together in one UNROLL step.
  const int inc = nw * UNROLL * WARP_SIZE;
  // First element handled by this thread: warp w's UNROLL*WARP_SIZE tile, lane t.
  int offset = w * UNROLL * WARP_SIZE + t;

  const T* srcs[MAXSRCS];
  for (int i=0; i
// NOTE(review): the rest of ReduceCopyMulti and the template header of
// ReduceCopy128bMulti are missing at this point (see note above).
// ReduceCopy128bMulti: same scheme as ReduceCopyMulti, but moves Pack128
// (16-byte) units via Fetch128/Store128 and reduces with MULTI128; Npack is
// the number of Pack128 units.
__device__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
    int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
  const int inc = nw * UNROLL * WARP_SIZE;
  int offset = w * UNROLL * WARP_SIZE + t;

  const Pack128* srcs[MAXSRCS];
  // NOTE(review): text missing between "i" and "()" below (loop bounds, loads,
  // and the MULTI128 template arguments).
  for (int i=0; i()(vals[u], vals2[u]);
  }
  #pragma unroll 1
  for (int i=MINSRCS; i()(vals[u], vals2[u]);
  }

  // Store
  for (int i = 0; i < MINDSTS; i++) {
    for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
  }
  #pragma unroll 1
  for (int i=MINDSTS; i
// NOTE(review): the remainder of ReduceCopy128bMulti and the template header of
// ptrAlign128 are missing at this point.
// ptrAlign128: returns ptr modulo alignof(int32_t); non-zero means ptr is not
// aligned to alignof(int32_t).
// NOTE(review): the caller's comment says "16B aligned", but this only checks
// int32_t alignment. On the non-HIP path, Fetch128/Store128 use
// ld/st.volatile.global.v2.u64, which PTX requires to be 16-byte aligned — confirm
// this check is sufficient for every platform that still builds that path.
__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(int32_t); }

// Number of T elements held in one Pack128.
#define PACKELEMS (sizeof(Pack128) / sizeof(T))

#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
// Multiply UNROLL by 2 if single source/single destination
#define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1))
#else
// Try to limit consecutive load/stores to 8.
// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
#endif

// ReduceOrCopyMulti: reduce N elements from nsrcs source arrays with FUNC and
// write the result to ndsts destination arrays, using all nthreads threads.
// Work is split into stages: Pack128 passes via ReduceCopy128bMulti (intended
// only when ptrAlign128 reports every pointer aligned), then a per-element
// ReduceCopyMulti pass for the rounded-down bulk, then a final per-element pass
// for whatever remains. Returns early as soon as Nrem reaches 0.
template __device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
    int nsrcs, const T** srcs, int ndsts, T** dsts,
    int N) {
  int Nrem = N;
  if (Nrem <= 0) return;

  int w = tid / WARP_SIZE;       // Warp number
  int nw = nthreads / WARP_SIZE; // Number of warps
  int t = tid % WARP_SIZE;       // Thread (inside the warp)

  // Check that all is 16B aligned. If not don't use 16B load/stores.
  int align = 0;
  #pragma unroll
  // NOTE(review): text missing between "i" and "(w, ..." below (alignment loops,
  // offset/Npack/Nelem setup and the explicit template arguments of the first
  // ReduceCopy128bMulti call).
  for (int i=0; i(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);

    Nrem -= Nelem;
    if (Nrem == 0) return;
    offset += Nelem;

    // slightly less optimized for section when we don't have full unrolling
    Npack = Nrem / PACKELEMS;
    Nelem = Npack * PACKELEMS;

    ReduceCopy128bMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack);

    Nrem -= Nelem;
    if (Nrem == 0) return;
    offset += Nelem;
  }

  // unrolled, by-type (mostly for unaligned buffers)
  int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down

  ReduceCopyMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem);

  Nrem -= Nelem;
  if (Nrem == 0) return;
  offset += Nelem;

  // no unroll, by type. Should finish what's remaining.
  ReduceCopyMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem);
}

#endif // NCCL_COMMON_KERNEL_H_