82afb2bcfe
* Internal RCCL/NCCL functionality exposed when RCCL_EXPOSE_STATIC is enabled * Algo/protocol/max channels can be obtained with the new RCCL API * Introduce rccl_static and rccl_static_inline macros to work around invisible functions in core source files like enqueue.cc * Add usage example in topo-explorer tool
644 řádky
24 KiB
C++
644 řádky
24 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
|
* Modifications Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
|
|
#ifndef NCCL_DEVICE_COMMON_H_
|
|
#define NCCL_DEVICE_COMMON_H_
|
|
|
|
#include "collectives.h"
|
|
#include "device.h"
|
|
#include "op128.h"
|
|
#include "reduce_kernel.h"
|
|
#include "device_table.h"
|
|
#include "network/unpack/unpack_defs.h"
|
|
// Maximum arity used by device code: one less than NCCL_MAX_TREE_ARITY.
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
|
|
|
// __syncwarp() expands to nothing in this header, so calls to it in the code
// below compile away.
// NOTE(review): presumably the target hardware keeps a warp converged, making
// the barrier unnecessary -- confirm.
#define __syncwarp()
|
|
|
|
// STORE(DST, SRC): atomically store SRC through the pointer DST.
// The store is relaxed when __GFX9__ is defined and sequentially consistent
// otherwise. The expansion is a braced block, not a do/while(0) statement.
#if defined(__GFX9__)
#define STORE(DST, SRC) \
  { __atomic_store_n((DST), (SRC), __ATOMIC_RELAXED); }
#else
#define STORE(DST, SRC) \
  { __atomic_store_n((DST), (SRC), __ATOMIC_SEQ_CST); }
#endif
|
|
|
|
// __trace_hwreg(): record the HW_REG_HW_ID hardware register into
// collTrace->data_0. A variable named `collTrace` must be in scope at the
// point of use. On the gfx11xx/gfx12xx targets listed below the macro expands
// to nothing.
#if !(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1200__) || defined(__gfx1201__))
#define __trace_hwreg() \
  asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (collTrace->data_0));
#else
#define __trace_hwreg()
#endif
|
|
// Collective tracing helpers. With ENABLE_COLLTRACE defined they append
// records to the trace buffer referenced by ncclShmem.collTrace; otherwise
// every macro expands to nothing.
#if defined(ENABLE_COLLTRACE)

// INC_COLL_TRACE: reserve the next trace slot (atomic increment of the tail
// counter, wrapped modulo COLLTRACE_NUM_ITEMS) and fill in the fields common
// to every record. Declares the locals `pos` and `collTrace` in the enclosing
// scope.
#define INC_COLL_TRACE \
  uint32_t pos = __hip_atomic_fetch_add(&ncclShmem.collTraceTail->tail, 1, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_WORKGROUP)%COLLTRACE_NUM_ITEMS; \
  struct ncclCollTrace* collTrace = ncclShmem.collTrace+pos; \
  collTrace->timeStamp = wall_clock64(); \
  collTrace->bid = blockIdx.x; \
  collTrace->tid = threadIdx.x; \
  collTrace->channelId = ncclShmem.channelId;
  // TODO: switch to atomicInc after llvm crash is fixed
  // uint32_t pos = atomicInc(&ncclShmem.collTraceTail->tail, COLLTRACE_NUM_ITEMS)

// traceKernelLaunch(launch_type, ix): record the start of batch `ix`. The
// first work item in ncclShmem.workStorage supplies the P2p or collective
// details, depending on ncclShmem.workType.
#define traceKernelLaunch(launch_type, ix) { \
    INC_COLL_TRACE \
    collTrace->funcIndex = ncclShmem.funcId; \
    __trace_hwreg()\
    collTrace->batchIx = ix; \
    if (ncclShmem.workType == ncclDevWorkTypeP2p) { \
      struct ncclDevWorkP2p *wP2p = (struct ncclDevWorkP2p*)ncclShmem.workStorage; \
      collTrace->p2p.sendRank = wP2p->sendRank; \
      collTrace->p2p.recvRank = wP2p->recvRank; \
      collTrace->p2p.nSendChannels = wP2p->nSendChannels; \
      collTrace->p2p.nRecvChannels = wP2p->nRecvChannels; \
      collTrace->p2p.channelBase = wP2p->channelBase; \
      collTrace->p2p.sendConnIndex = wP2p->sendConnIndex; \
      collTrace->p2p.recvConnIndex = wP2p->recvConnIndex; \
      collTrace->p2p.sendProtoLL = wP2p->sendProtoLL; \
      collTrace->p2p.recvProtoLL = wP2p->recvProtoLL; \
      collTrace->p2p.sendRegistered = wP2p->sendRegistered; \
      collTrace->p2p.recvRegistered = wP2p->recvRegistered; \
      collTrace->p2pOpCount[0] = wP2p->sendOpCount; \
      collTrace->p2pOpCount[1] = wP2p->recvOpCount; \
      collTrace->type = (launch_type) | ncclCollTraceP2pElemType; \
    } else if (ncclShmem.workType == ncclDevWorkTypeColl) { \
      struct ncclDevWorkColl *wColl = (struct ncclDevWorkColl*)ncclShmem.workStorage; \
      collTrace->coll.nWarps = wColl->nWarps; \
      collTrace->coll.nChannels = wColl->channelHi-wColl->channelLo+1; \
      collTrace->coll.bid = ncclShmem.channelId - wColl->channelLo; \
      collTrace->coll.root = wColl->root; \
      collTrace->opCount = wColl->opCount; \
      collTrace->type = (launch_type) | ncclCollTraceCollElemType; \
    } \
  }

// traceKernelEnd(end_type): record the end of the kernel together with the
// operation counter(s) of the first work item in ncclShmem.workStorage.
#define traceKernelEnd(end_type) { \
    INC_COLL_TRACE \
    if (ncclShmem.workType == ncclDevWorkTypeP2p) { \
      struct ncclDevWorkP2p *wP2p = (struct ncclDevWorkP2p*)ncclShmem.workStorage; \
      collTrace->p2pOpCount[0] = wP2p->sendOpCount; \
      collTrace->p2pOpCount[1] = wP2p->recvOpCount; \
      collTrace->type = (end_type) | ncclCollTraceP2pElemType; \
    } else if (ncclShmem.workType == ncclDevWorkTypeColl) { \
      struct ncclDevWorkColl *wColl = (struct ncclDevWorkColl*)ncclShmem.workStorage; \
      collTrace->opCount = wColl->opCount; \
      collTrace->type = (end_type) | ncclCollTraceCollElemType; \
    } \
  }

// traceData(data2, data4, data8_0, data8_1): record four caller-supplied
// values in the funcIndex, data_0, opCount and data_1 fields of a new record.
#define traceData(data2, data4, data8_0, data8_1) { \
    INC_COLL_TRACE \
    collTrace->funcIndex = data2; \
    collTrace->data_0 = data4; \
    collTrace->opCount = data8_0; \
    collTrace->data_1 = data8_1; \
    collTrace->type = ncclCollTraceDataType; \
  }

#else

#define traceKernelLaunch(launch_type, batchIx)
#define traceKernelEnd(end_type)
#define traceData(data2, data4, data8_0, data8_1)

#endif
|
|
|
|
// NCCL_GRID_CONSTANT decorates the kernel argument struct. It is currently
// empty in both branches; the __grid_constant__ variant is kept commented out.
#if __CUDA_ARCH__ >= 700
  // __grid_constant__ appears to break cuda-gdb
  //#define NCCL_GRID_CONSTANT __grid_constant__
  #define NCCL_GRID_CONSTANT
#else
  #define NCCL_GRID_CONSTANT
#endif
|
|
|
|
struct ncclShmemGroup {
|
|
ncclConnInfo *recvConns[NCCL_MAX_ARITY];
|
|
ncclConnInfo *sendConns[NCCL_MAX_ARITY];
|
|
void* userInput;
|
|
void* userOutput;
|
|
void* srcs[NCCL_MAX_ARITY+1];
|
|
void* dsts[NCCL_MAX_ARITY+1];
|
|
uint64_t barrier;
|
|
union {
|
|
unpackGroupShmem unpack;
|
|
} devicePlugin;
|
|
int32_t dstSizes[NCCL_MAX_ARITY+1];
|
|
};
|
|
|
|
struct ncclShmemData {
|
|
struct ncclDevKernelArgs args;
|
|
int channelId;
|
|
int aborted;
|
|
alignas(16) struct ncclDevComm comm;
|
|
alignas(16) struct ncclDevChannel channel;
|
|
|
|
int batchIx, nextBatchIx;
|
|
enum ncclDevWorkType workType;
|
|
uint8_t directMode;
|
|
uint16_t funcId;
|
|
int nWorks;
|
|
int workSize;
|
|
uint32_t workConsumed;
|
|
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
|
|
uint64_t redOpArgs[NCCL_MAX_NVLS_ARITY+1];
|
|
|
|
alignas(16) char workStorage[1024];
|
|
|
|
alignas(16) union {
|
|
unpackShmem unpack;
|
|
} devicePlugin;
|
|
#ifdef ENABLE_COLLTRACE
|
|
struct ncclCollTrace* collTrace;
|
|
union ncclCollTraceTail* collTraceTail;
|
|
#endif
|
|
#ifdef ENABLE_PROFILING
|
|
struct ncclProf prof;
|
|
#endif
|
|
#ifdef ENABLE_FAULT_INJECTION
|
|
uint64_t faults;
|
|
#endif
|
|
};
|
|
|
|
// Shared-memory instance of the block-wide kernel state.
extern __shared__ ncclShmemData ncclShmem;

// Per-warp scratch space, addressed through ncclScratchForWarp(). The array
// is unsized when __CUDA_ARCH__ >= 700; otherwise it holds
// ncclShmemScratchWarpSize() bytes for each of NCCL_MAX_NTHREADS/WARP_SIZE warps.
#if __CUDA_ARCH__ >= 700
  extern __shared__ ulong2 ncclShmemPerWarp[/*ncclShmemDynamicSize()/sizeof(ulong2)*/];
#else
  extern __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)];
#endif
|
|
|
|
#ifdef ENABLE_FAULT_INJECTION
// Fault injection: when RANDOM_DELAY_ON_WARP_START is set in ncclShmem.faults,
// the first lane of each warp sleeps for one of four durations. The choice
// comes from two bits of the wall clock, shifted by twice the warp index.
__device__ inline void insert_random_delay_per_warp() {
  const bool delayEnabled = (ncclShmem.faults & RANDOM_DELAY_ON_WARP_START) != 0;
  if (!delayEnabled || threadIdx.x%WARP_SIZE != 0) return;

  const auto selector = (wall_clock64()>>(threadIdx.x/WARP_SIZE*2))&0x3;
  // The sleep arguments are kept as literals in each branch.
  if (selector == 0) {
    __builtin_amdgcn_s_sleep(0);
  } else if (selector == 1) {
    __builtin_amdgcn_s_sleep(8);
  } else if (selector == 2) {
    __builtin_amdgcn_s_sleep(16);
  } else {
    __builtin_amdgcn_s_sleep(32);
  }
}
#endif
|
|
|
|
// Returns the start of the scratch region of warp `warp` inside
// ncclShmemPerWarp; regions are ncclShmemScratchWarpSize() bytes apart.
__device__ inline void* ncclScratchForWarp(int warp) {
  char* scratchBase = (char*)ncclShmemPerWarp;
  return scratchBase + warp*ncclShmemScratchWarpSize();
}
|
|
|
|
// Named barrier `name`, issued with the PTX "barrier.sync.aligned"
// instruction. The non-aligned form ("barrier.sync %0;") was the disabled
// alternative.
__device__ inline void barrier_sync(int name) {
  asm volatile("barrier.sync.aligned %0;" :: "r"(name) : "memory");
}

// Same as above with an explicit participant count `nThreads`. The
// non-aligned form ("barrier.sync %0, %1;") was the disabled alternative.
__device__ inline void barrier_sync(int name, int nThreads) {
  asm volatile("barrier.sync.aligned %0, %1;" :: "r"(name), "r"(nThreads) : "memory");
}
|
|
// Named barrier `name` using the aligned PTX barrier instruction.
__device__ inline void barrier_sync_aligned(int name) {
  asm volatile(
    "barrier.sync.aligned %0;"
    :: "r"(name) : "memory");
}

// Aligned named barrier `name` with an explicit participant count `nThreads`.
__device__ inline void barrier_sync_aligned(int name, int nThreads) {
  asm volatile(
    "barrier.sync.aligned %0, %1;"
    :: "r"(name), "r"(nThreads) : "memory");
}
|
|
|
|
// Barrier `name` combined with an OR-reduction of `vote` over the
// participants; returns the reduced predicate.
__device__ inline bool barrier_red_or(bool vote, int name) {
  int reduced;
  asm volatile("{ .reg .pred p;"
      "  setp.ne.s32 p, %1, 0;"
      "  barrier.red.or.pred p, %2, p; "
      "  selp.s32 %0, 1, 0, p; }"
      : "=r"(reduced) : "r"((int)vote), "r"(name) : "memory");
  return reduced != 0;
}

// As above, with an explicit participant count `nThreads`.
__device__ inline bool barrier_red_or(bool vote, int name, int nThreads) {
  int reduced;
  asm volatile("{ .reg .pred p;"
      "  setp.ne.s32 p, %1, 0;"
      "  barrier.red.or.pred p, %2, %3, p; "
      "  selp.s32 %0, 1, 0, p; }"
      : "=r"(reduced) : "r"((int)vote), "r"(name), "r"(nThreads) : "memory");
  return reduced != 0;
}
|
|
// Aligned barrier `name` combined with an OR-reduction of `vote` over the
// participants; returns the reduced predicate.
__device__ inline bool barrier_red_or_aligned(bool vote, int name) {
  int ans;
  asm volatile("{ .reg .pred p;"
      "  setp.ne.s32 p, %1, 0;"
      "  barrier.red.or.pred.aligned p, %2, p; "
      "  selp.s32 %0, 1, 0, p; }"
      : "=r"(ans) : "r"((int)vote), "r"(name) : "memory");
  return bool(ans);
}

// As above, with an explicit participant count `nThreads`.
__device__ inline bool barrier_red_or_aligned(bool vote, int name, int nThreads) {
  int ans;
  // `volatile` matches the other barrier helpers: a barrier is a side effect,
  // so the statement must not be dropped or merged when the result is unused
  // or looks redundant to the optimizer.
  asm volatile("{ .reg .pred p;"
      "  setp.ne.s32 p, %1, 0;"
      "  barrier.red.or.pred.aligned p, %2, %3, p; "
      "  selp.s32 %0, 1, 0, p; }"
      : "=r"(ans) : "r"((int)vote), "r"(name), "r"(nThreads) : "memory");
  return bool(ans);
}
|
|
|
|
// __insert_timestamp(line_num): with ENABLE_PROFILING defined, append
// (line_num, wall_clock64()) to ncclShmem.prof while fewer than
// PROFILE_NUM_ITEMS entries have been recorded; otherwise expands to nothing.
// The expansion already ends with a semicolon.
#if defined(ENABLE_PROFILING)
#define __insert_timestamp(line_num) do { \
      if (ncclShmem.prof.count < PROFILE_NUM_ITEMS) { \
        ncclShmem.prof.elem[ncclShmem.prof.count].line = line_num; \
        ncclShmem.prof.elem[ncclShmem.prof.count].timeStamp = wall_clock64(); \
        ncclShmem.prof.count++; \
      } \
    } while(0);
#else
#define __insert_timestamp(line_num)
#endif
|
|
|
|
// Copy 16-byte aligned data. You must call with at least `(bytes+15)/16` threads.
|
|
inline __device__ void copyToShmem16(int tid, void* dst, void const* src, int bytes) {
|
|
int offset = 16*tid;
|
|
if (offset < bytes) {
|
|
ulong2 *src2, *dst2;
|
|
src2 = (ulong2*)((char const*)src + offset);
|
|
dst2 = (ulong2*)((char*)dst + offset);
|
|
dst2->x = src2->x;
|
|
dst2->y = src2->y;
|
|
}
|
|
}
|
|
|
|
// Must run with at least 64 threads
|
|
__device__ __forceinline__ void loadWorkBatchToShmem(
|
|
int tid, int tn, struct ncclDevKernelArgs const* args, int batchIx
|
|
) {
|
|
int lane = tid%WARP_SIZE;
|
|
int workCursor = 0; // num works written in previous loop iterations.
|
|
while (true) {
|
|
struct ncclDevWorkBatch batch = ((struct ncclDevWorkBatch*)(args+1))[batchIx];
|
|
|
|
// fnsOfBitset[n] = index of n'th set bit in batch.offsetBitset.
|
|
// PTX has instruction "fns" (find n-th set) but it expands to a lot of SASS,
|
|
// since we know all lanes will be querying the same bitmask we can compute
|
|
// much faster using shared memory.
|
|
uint8_t* fnsOfBitset = (uint8_t*)ncclScratchForWarp(threadIdx.x/WARP_SIZE);
|
|
int nWorks = 0;
|
|
__syncwarp();
|
|
|
|
if (WARP_SIZE == 64) {
|
|
if (uint64_t(batch.offsetBitset) & (1ull<<lane)) {
|
|
int nWorksBelow = __popc(uint64_t(batch.offsetBitset) & ((1ull<<lane)-1));
|
|
fnsOfBitset[nWorksBelow] = lane;
|
|
}
|
|
nWorks = __popc(uint64_t(batch.offsetBitset));
|
|
} else {
|
|
// WARP_SIZE == 32
|
|
if (uint32_t(batch.offsetBitset) & (1u<<lane)) {
|
|
int nWorksBelow = __popc(uint32_t(batch.offsetBitset) & ((1u<<lane)-1));
|
|
fnsOfBitset[nWorksBelow] = lane;
|
|
}
|
|
int nWorksLow32 = __popc(uint32_t(batch.offsetBitset)); // just of low 32 bits
|
|
if (uint32_t(batch.offsetBitset>>32) & (1u<<lane)) {
|
|
int nWorksBelow = nWorksLow32;
|
|
nWorksBelow += __popc(uint32_t(batch.offsetBitset>>32) & ((1u<<lane)-1));
|
|
fnsOfBitset[nWorksBelow] = 32 + lane;
|
|
}
|
|
nWorks = nWorksLow32 + __popc(uint32_t(batch.offsetBitset>>32)); // add high 32 bits
|
|
}
|
|
|
|
int workSize;
|
|
int nPacks; // total number of packs loaded, each pack is 16 bytes
|
|
int packInWork; // my pack index within work struct
|
|
int dstWork; // my work index in contiguous destination shmem
|
|
switch (batch.workType) {
|
|
case (int)ncclDevWorkTypeP2p:
|
|
workSize = sizeof(struct ncclDevWorkP2p);
|
|
nPacks = nWorks*(workSize/16);
|
|
packInWork = tid%(workSize/16);
|
|
dstWork = tid/(workSize/16);
|
|
break;
|
|
case (int)ncclDevWorkTypeColl:
|
|
workSize = sizeof(struct ncclDevWorkColl);
|
|
nPacks = nWorks*(workSize/16);
|
|
packInWork = tid%(workSize/16);
|
|
dstWork = tid/(workSize/16);
|
|
break;
|
|
case (int)ncclDevWorkTypeCollReg:
|
|
default:
|
|
workSize = sizeof(struct ncclDevWorkCollReg);
|
|
nPacks = nWorks*(workSize/16);
|
|
packInWork = tid%(workSize/16);
|
|
dstWork = tid/(workSize/16);
|
|
break;
|
|
}
|
|
if (tid == 0) {
|
|
ncclShmem.workSize = workSize;
|
|
ncclShmem.workConsumed = batch.offsetBase + (64-__clzll(batch.offsetBitset))*workSize;
|
|
}
|
|
// We deliberately replicate these div and mod calculations into the case
|
|
// blocks above so that they get constant divisor optimizations by the compiler.
|
|
// packInWork = tid%(workSize/16);
|
|
// dstWork = tid/(workSize/16);
|
|
|
|
// We can only assume we have 64 threads, which means we can read at most 1024 bytes
|
|
// here which is the per batch maximum.
|
|
if (tid < nPacks) {
|
|
int srcWork = fnsOfBitset[dstWork]; // find n'th set bit in batch.offsetBitset
|
|
ulong2 tmp;
|
|
// The loads done in these two cases must be kept separate since we are
|
|
// relying on the compiler to use "ld.param" in the first one. The parameter
|
|
// space is not generically addressable, so any attempt to load through
|
|
// a pointer that *might* be parameter space backed will cause the
|
|
// compiler to spill the parameter struct (4K!) to each thread's local space
|
|
// before creating a pointer (to the spill) and decimate perf.
|
|
//
|
|
// An example of what not to do would be the following:
|
|
//
|
|
// if (condition) {
|
|
// // The compiler could spill parameter_variable to local space and take
|
|
// // the address of that, since when src is loaded below it could also
|
|
// // be global space.
|
|
// src = ¶meter_variable;
|
|
// } else {
|
|
// src = &global_variable;
|
|
// }
|
|
// memcpy(dst, src, n);
|
|
if (ncclShmem.args.workStorageType == ncclDevWorkStorageTypeArgs) {
|
|
char* src = (char*)args + (batch.offsetBase + srcWork*workSize + packInWork*16);
|
|
tmp = *(ulong2*)src; // becomes ld.param.v2.u64
|
|
}
|
|
if (ncclShmem.args.workStorageType != ncclDevWorkStorageTypeArgs) {
|
|
char* src = (char*)ncclShmem.args.workBuf + ((batch.offsetBase + srcWork*workSize + packInWork*16) & ncclShmem.args.workMask);
|
|
tmp = *(ulong2*)src; // becomes ld.v2.u64
|
|
}
|
|
char* dst = ncclShmem.workStorage;
|
|
dst += (workCursor + dstWork)*workSize + packInWork*16;
|
|
*(ulong2*)dst = tmp;
|
|
}
|
|
workCursor += nWorks;
|
|
|
|
if (batch.nextExtends) {
|
|
batchIx += batch.nextJump;
|
|
tid -= 64; // Rotate threads so we use the next two warps for next batch struct.
|
|
if (tid < 0) tid += tn;
|
|
} else {
|
|
if (tid == 0) {
|
|
ncclShmem.batchIx = batchIx;
|
|
ncclShmem.nextBatchIx = (batch.nextJump == 0) ? -1 : batchIx + batch.nextJump;
|
|
ncclShmem.workType = (enum ncclDevWorkType)batch.workType;
|
|
ncclShmem.nWorks = workCursor;
|
|
ncclShmem.funcId = batch.funcId;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
|
|
struct RunWorkColl {
|
|
__device__ void run(int tid, int tn, struct ncclDevWorkColl* work) {
|
|
// Put NOT IMPLEMENTED behavior here.
|
|
}
|
|
};
|
|
|
|
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
|
|
struct RunWorkBatch;
|
|
|
|
// Specialized for P2p in sendrecv.h
|
|
template<typename T, typename RedOp>
|
|
struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE>;
|
|
|
|
// Specialized here for non-P2p (Coll and CollReg)
|
|
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int COLL_UNROLL>
|
|
struct RunWorkBatch {
|
|
// This __forceinline__ is necessary. The compiler was inserting a function call
|
|
// here from the LL ncclKernel.
|
|
__device__ __forceinline__ void run() {
|
|
int tid = threadIdx.x;
|
|
int tn = blockDim.x;
|
|
|
|
if (RedOpArg<RedOp>::ArgUsed) {
|
|
int nWorks = ncclShmem.nWorks;
|
|
for (int w=tid; w < nWorks; w += tn) {
|
|
struct ncclDevWorkColl* work = (ncclDevWorkColl*)(ncclShmem.workStorage + w*ncclShmem.workSize);
|
|
if (work->redOpArgIsPtr) {
|
|
work->redOpArg = RedOpArg<RedOp>::loadArg(reinterpret_cast<void*>(work->redOpArg));
|
|
}
|
|
}
|
|
__syncthreads();
|
|
}
|
|
|
|
#pragma unroll 1
|
|
for (int w=0; w < ncclShmem.nWorks; w++) {
|
|
struct ncclDevWorkColl* work = (struct ncclDevWorkColl*)(ncclShmem.workStorage + w*ncclShmem.workSize);
|
|
if (w != 0) {
|
|
struct ncclDevWorkColl* workPrev = (struct ncclDevWorkColl*)(ncclShmem.workStorage + (w-1)*ncclShmem.workSize);
|
|
if (work->nWarps != workPrev->nWarps) __syncthreads();
|
|
}
|
|
int subtn = work->nWarps*WARP_SIZE;
|
|
// Coverity reports a possible thread divergence due to not all threads participating in the collective.
|
|
// However, the code ensures that the participation is on a per-warp basis.
|
|
// coverity[device_thread_diverged:FALSE]
|
|
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, COLL_UNROLL>().run(tid, subtn, work);
|
|
}
|
|
}
|
|
};
|
|
|
|
template<int SpecializedFnId, typename SpecializedRunWorkBatch, bool COLLTRACE, int COLL_UNROLL>
|
|
__device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* args) {
|
|
const int tid = threadIdx.x;
|
|
int tn = blockDim.x;
|
|
int x = tid;
|
|
int total = 0, y;
|
|
int num = MAXCHANNELS/64 > 0 ? MAXCHANNELS/64 : 1;
|
|
|
|
// Copy kernel args to shmem and then only read those. Otherwise the compiler
|
|
// will end up putting the args into thread local stack which is very wasteful.
|
|
if (tid < sizeof(ncclDevKernelArgs)/sizeof(uint32_t)) {
|
|
((uint32_t*)&ncclShmem.args)[tid] = ((uint32_t*)args)[tid];
|
|
}
|
|
|
|
// To map blockId to channelId, we need the n'th set bit of channelMask which
|
|
// is the inverse of counting the number of set bits among the the first n.
|
|
// PTX has the fns instruction which does this but is extremely slow. We can
|
|
// do better when we know all threads are querying the same bitmask.
|
|
switch (tid/WARP_SIZE) {
|
|
case 0:
|
|
//ncclShmem.channelId = blockIdx.x;
|
|
for (int i = 0; i < num; i++) {
|
|
if (args->channelMask.masks[i] & (1ull<<x)) {
|
|
y = __popcll(args->channelMask.masks[i] & ((1ull<<x)-1));
|
|
y = total + y;
|
|
if (blockIdx.x == y) {
|
|
ncclShmem.channelId = x + total;
|
|
break;
|
|
}
|
|
}
|
|
if (WARP_SIZE < 64) {
|
|
x = WARP_SIZE + tid;
|
|
if (args->channelMask.masks[i] & (1ull<<x)) {
|
|
y = __popcll(args->channelMask.masks[i] & ((1ull<<x)-1));
|
|
y = y + total;
|
|
if (blockIdx.x == y) {
|
|
ncclShmem.channelId = x + total;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
total = total + __popcll(args->channelMask.masks[i]);
|
|
}
|
|
break;
|
|
case 1:
|
|
if (tid < WARP_SIZE + NCCL_MAX_GROUPS)
|
|
ncclShmem.groups[tid-WARP_SIZE].barrier = 0;
|
|
break;
|
|
case 2:
|
|
#ifdef ENABLE_FAULT_INJECTION
|
|
/* load faults injection before first sync threads */
|
|
if (tid == 2*WARP_SIZE) ncclShmem.faults = args->comm->faults;
|
|
#endif
|
|
break;
|
|
case 3:
|
|
/* set abort flag to 0 */
|
|
if (tid == 3*WARP_SIZE) ncclShmem.aborted = 0;
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
__syncthreads(); // publish ncclShmem.{args, channelId}
|
|
|
|
// Use first 2 warps to load comm and channel, and reamaining load work batch.
|
|
switch (tid/WARP_SIZE) {
|
|
case 0:
|
|
{ void* dst = &ncclShmem.comm;
|
|
void* src = ncclShmem.args.comm;
|
|
int bytes = sizeof(ncclDevComm);
|
|
static_assert(sizeof(ncclDevComm) <= 16*WARP_SIZE, "ncclDevComm cannot be loaded by a single warp in one insn.");
|
|
copyToShmem16(tid, dst, src, bytes);
|
|
} break;
|
|
case 1:
|
|
{ // Get address of channel without incurring indirect load from ncclDevComm::channels
|
|
void* dst = &ncclShmem.channel;
|
|
void* src = &((ncclDevCommAndChannels*)ncclShmem.args.comm)->channels[ncclShmem.channelId];
|
|
int bytes = sizeof(ncclDevChannel);
|
|
static_assert(sizeof(ncclDevChannel) <= 16*WARP_SIZE, "ncclDevChannel cannot be loaded by a single warp in one insn.");
|
|
copyToShmem16(tid-WARP_SIZE, dst, src, bytes);
|
|
} break;
|
|
default:
|
|
{ int subtid = tid - 2*WARP_SIZE;
|
|
int subtn = tn - 2*WARP_SIZE;
|
|
// Coverity reports a possible thread divergence due to not all threads participating in the collective.
|
|
// However, the code ensures that the participation is on a per-warp basis.
|
|
// coverity[device_thread_diverged:FALSE]
|
|
loadWorkBatchToShmem(subtid, subtn, args, /*batchIx=*/blockIdx.x);
|
|
} break;
|
|
}
|
|
#ifdef ENABLE_COLLTRACE
|
|
if (tid == 0) {
|
|
ncclShmem.collTrace = args->comm->collTrace + COLLTRACE_NUM_ITEMS*ncclShmem.channelId;
|
|
ncclShmem.collTraceTail = args->comm->collTraceTail + ncclShmem.channelId;
|
|
}
|
|
#endif
|
|
__syncthreads(); // publish shmem
|
|
|
|
#ifdef ENABLE_PROFILING
|
|
if (tid == 0) {
|
|
ncclShmem.prof.count = 0;
|
|
ncclShmem.prof.seq = ncclShmem.comm.devProf[blockIdx.x].seq;
|
|
}
|
|
#endif
|
|
if (tid == 0) __insert_timestamp(__LINE__);
|
|
if (COLLTRACE && tid%WARP_SIZE == 0) traceKernelLaunch(ncclCollTraceKernelLaunchType, 0);
|
|
|
|
if (tid == 0 && ncclShmem.args.workStorageType == ncclDevWorkStorageTypeFifo) {
|
|
// ncclShmem.workConsumed written by loadWorkBatchToShmem before __syncthreads()
|
|
ncclShmem.comm.workConsumed[ncclShmem.channelId] = ncclShmem.workConsumed;
|
|
}
|
|
|
|
while (true) {
|
|
if (tid == 0) __insert_timestamp(__LINE__);
|
|
|
|
if (0 <= SpecializedFnId && ncclShmem.funcId == (unsigned)SpecializedFnId) {
|
|
SpecializedRunWorkBatch().run();
|
|
} else {
|
|
#ifdef USE_INDIRECT_FUNCTION_CALL
|
|
if (COLL_UNROLL == 1)
|
|
ncclDevFuncTable_1[ncclShmem.funcId]();
|
|
else if (COLL_UNROLL == 2)
|
|
ncclDevFuncTable_2[ncclShmem.funcId]();
|
|
else
|
|
ncclDevFuncTable_4[ncclShmem.funcId]();
|
|
#else
|
|
if (COLL_UNROLL == 1)
|
|
NCCL_CALL_FUNCTIONS_1(ncclShmem.funcId);
|
|
else if (COLL_UNROLL == 2)
|
|
NCCL_CALL_FUNCTIONS_2(ncclShmem.funcId);
|
|
else
|
|
NCCL_CALL_FUNCTIONS_4(ncclShmem.funcId);
|
|
#endif
|
|
}
|
|
|
|
if (ncclShmem.nextBatchIx == -1) break;
|
|
int batchIx = ncclShmem.nextBatchIx;
|
|
__syncthreads();
|
|
switch (tid/WARP_SIZE) {
|
|
case 1:
|
|
if (tid < WARP_SIZE + NCCL_MAX_GROUPS)
|
|
ncclShmem.groups[tid-WARP_SIZE].barrier = 0;
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
loadWorkBatchToShmem(tid%WARP_SIZE, tn, args, batchIx);
|
|
__syncthreads();
|
|
|
|
// Check whether the last operation was aborted and make sure all threads exit
|
|
bool aborted = false;
|
|
if (tid == 0) aborted = *ncclShmem.comm.abortFlag;
|
|
aborted = __any(aborted); // publish ncclShmem.work
|
|
if (tid == 0 && ncclShmem.args.workStorageType == ncclDevWorkStorageTypeFifo) {
|
|
// ncclShmem.workConsumed written by loadWorkBatchToShmem before barrier_red_or()
|
|
ncclShmem.comm.workConsumed[ncclShmem.channelId] = ncclShmem.workConsumed;
|
|
}
|
|
if (aborted) break;
|
|
|
|
if (COLLTRACE && tid%WARP_SIZE == 0) traceKernelLaunch(ncclCollTraceCollLaunchType, batchIx);
|
|
}
|
|
if (COLLTRACE && tid%WARP_SIZE == 0) traceKernelEnd(ncclCollTraceKernelEndType);
|
|
|
|
#ifdef ENABLE_PROFILING
|
|
if (ncclShmem.comm.devProf->seq < PROFILE_NUM_LAUNCHES) {
|
|
__syncthreads();
|
|
copyToShmem16(tid, ncclShmem.comm.devProf+MAXCHANNELS*ncclShmem.prof.seq+blockIdx.x, &ncclShmem.prof, sizeof(struct ncclProf));
|
|
if (tid == 0) ncclShmem.comm.devProf[blockIdx.x].seq++;
|
|
}
|
|
#endif
|
|
}
|
|
|
|
// Generic kernel entry points, one per suffix 1, 2 and 4. Each takes the
// argument block by value.
__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);

// "Debug" counterparts of the kernels above, declared only when
// ENABLE_COLLTRACE is defined.
#ifdef ENABLE_COLLTRACE
__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
#endif
|
|
|
|
// DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) defines the
// device function ncclDevFunc_<suffix>, which runs the matching RunWorkBatch
// instantiation. The function is additionally marked noinline unless
// USE_INDIRECT_FUNCTION_CALL is defined.
#ifdef USE_INDIRECT_FUNCTION_CALL
#define NCCL_DEV_FUNC_QUALIFIERS __device__
#else
#define NCCL_DEV_FUNC_QUALIFIERS __device__ __attribute__((noinline))
#endif

#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, unroll) \
  NCCL_DEV_FUNC_QUALIFIERS void ncclDevFunc_##suffix() { \
    RunWorkBatch<coll, ty, redop<ty>, algo, proto, unroll>().run(); \
  }
|
|
|
|
#endif
|