Files
rocm-systems/src/device/sendrecv.h
T
2025-02-25 16:13:48 -05:00

277 wiersze
11 KiB
C++

/*************************************************************************
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "device.h"
#include "collectives.h"
#include "primitives.h"
#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif
// RunWorkBatch specialization for point-to-point send/recv batches.
//
// A thread block processes ncclShmem.nWorks ncclDevWorkP2p descriptors stored
// in ncclShmem.workStorage. The partitioning step in run() gives each work one
// lane per direction, so at most 16 works per batch are supported. The block's
// warps are divided evenly across works (nWarps/nWorks each), and each work's
// warps are split into a send half and a recv half. A work whose send peer is
// this rank is executed as a local reduceCopy instead of via Primitives.
//
// T must be a 1-byte type: every size handled here is a byte count.
template<typename T, typename RedOp>
struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
  static_assert(sizeof(T)==1, "SendRecv only works on single byte types T.");

  // Send this channel's slice of `work` to peer work->sendRank.
  //   tid, tn : thread index / thread count among the threads assigned to this send.
  //   group   : barrier group id for Primitives. run() computes it so that the
  //             sends/recvs running concurrently in this block get distinct ids.
  //   work    : P2P descriptor. run() has already narrowed sendAddr/sendBytes to
  //             this channel's part.
  // Data is pushed in chunks of at most u32fp8Decode(sendChunkSize_u32fp8) bytes.
  // The do/while issues at least one directSend, even when bytes == 0. When the
  // buffer is registered (sendRegistered != 0), exactly one directSend of
  // min(chunkSize, bytes) is issued.
  // NOTE(review): presumably the host sizes the chunk to cover the whole
  // registered buffer in that case — confirm.
  template<typename Proto>
  __device__ void runSend(int tid, int tn, int group, struct ncclDevWorkP2p* work) {
    size_t bytes = work->sendBytes;
    int chunkSize = u32fp8Decode(work->sendChunkSize_u32fp8);
#if defined(ENABLE_NPKIT)
    bool isNpKitThread = (tid == 0);
    int npKitCtxIdx = blockIdx.x + group;
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, NPKIT_GET_CPU_TIMESTAMP_FROM_BLOCK,
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
    Primitives<T, RedOp, FanAsymmetric<0, 1>, 0, Proto, 1>
      prims(tid, tn, nullptr, &work->sendRank, work->sendAddr, nullptr,
            /*redOpArg(ignored)=*/0, group, work->sendConnIndex, work->sendConnIndex, nullptr,
            /*userBufferMode=*/work->sendRegistered, ncclShmem.comm.p2pChunkSize);
#if defined(ENABLE_NPKIT)
    if (isNpKitThread) {
      prims.npKitCtxIdx = npKitCtxIdx;
    }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_SEND_RECV_SEND_ENTRY)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_SEND_RECV_SEND_ENTRY, bytes*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
      prims.npKitDataProcessTotalTime = 0;
    }
#endif
    size_t cursor = 0;
    do {
      int n = min(size_t(chunkSize), bytes-cursor);
      prims.directSend(cursor, cursor, n);
      cursor += n;
    } while (cursor < bytes && work->sendRegistered == 0);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_SEND_RECV_SEND_EXIT)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_SEND_RECV_SEND_EXIT, bytes*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
  }

  // Receive this channel's slice of `work` from peer work->recvRank into
  // work->recvAddr. The parameters have the same meaning as in runSend.
  // Chunking follows u32fp8Decode(recvChunkSize_u32fp8). When recvRegistered
  // != 0, exactly one directRecv of min(chunkSize, bytes) is issued.
  template<typename Proto>
  __device__ void runRecv(int tid, int tn, int group, struct ncclDevWorkP2p* work) {
    size_t bytes = work->recvBytes;
    int chunkSize = u32fp8Decode(work->recvChunkSize_u32fp8);
#if defined(ENABLE_NPKIT)
    bool isNpKitThread = (tid == 0);
    int npKitCtxIdx = blockIdx.x + group;
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, NPKIT_GET_CPU_TIMESTAMP_FROM_BLOCK,
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
    Primitives<T, RedOp, FanAsymmetric<1, 0>, 0, Proto, 1>
      prims(tid, tn, &work->recvRank, nullptr, nullptr, work->recvAddr,
            /*redOpArg(ignored)=*/0, group, work->recvConnIndex, work->recvConnIndex, nullptr,
            /*userBufferMode=*/work->recvRegistered, ncclShmem.comm.p2pChunkSize);
#if defined(ENABLE_NPKIT)
    if (isNpKitThread) {
      prims.npKitCtxIdx = npKitCtxIdx;
    }
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_SEND_RECV_RECV_ENTRY)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_SEND_RECV_RECV_ENTRY, bytes*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
      prims.npKitDataProcessTotalTime = 0;
    }
#endif
    size_t cursor = 0;
    do {
      int n = min(size_t(chunkSize), bytes-cursor);
      prims.directRecv(cursor, n);
      cursor += n;
    } while (cursor < bytes && work->recvRegistered == 0);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_SEND_RECV_RECV_EXIT)
    if (isNpKitThread) {
      NpKit::CollectGpuEvent(NPKIT_EVENT_SEND_RECV_RECV_EXIT, bytes*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(),
          ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
    }
#endif
  }

  // Entry point executed by every thread of the block. The steps are:
  //  1. Warp 0 narrows each work's send/recv range to this channel's part
  //     (ncclP2pChannelToPart / ncclP2pPartBounds). It publishes bitmasks of
  //     which works actually have a send/recv part on this channel.
  //  2. Every warp derives its work index, its send/recv role and a barrier
  //     group id.
  //  3. The warp dispatches to reduceCopy (self-copy), runSend or runRecv.
  // Both __syncthreads() calls are reached by all threads, because the early
  // return of idle warps comes after them.
#if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) && !defined(__gfx950__)
  __device__ void run() {
#else
  __device__ __attribute__((noinline)) void run() {
#endif
    const int tid = threadIdx.x;
    const int tn = blockDim.x;
    const int wid = tid/WARP_SIZE;
    const int nWarps = tn/WARP_SIZE;
    const int lane = tid%WARP_SIZE;

    struct Shared {
      uint32_t workSendMask; // bitmasks of which work indices have send/recv
      uint32_t workRecvMask;
    };
    Shared* shared = (Shared*)ncclScratchForWarp(0);

    struct ncclDevWorkP2p* works = (ncclDevWorkP2p*)ncclShmem.workStorage;
    int nWorks = ncclShmem.nWorks;

    if (wid == 0) {
      // Modify the memory range of each work[] to reflect this channel's
      // partition of the work. Since integer divides are very heavy it's
      // best to do them all in one warp.
      //
      // Lanes 0..15 own the recv side of works 0..15 and lanes 16..31 own the
      // send side. On 64-wide wavefronts, lanes 32..63 would map onto the same
      // send slots (lane%16, lane >= 16). They would then redundantly
      // read-modify-write works[].sendAddr/sendBytes at the same time as lanes
      // 16..31, so those lanes are kept idle. With 32-wide warps, `lane < 32`
      // is always true.
      int workIx = lane%16;
      int isSend = lane < 16 ? 0 : 1;
      bool ownsSlot = lane < 32;
      bool hasWork = false;
      if (ownsSlot && workIx < nWorks) {
        struct ncclDevWorkP2p* work = &works[workIx];
        size_t bytes = isSend ? work->sendBytes : work->recvBytes;
        int nParts = isSend ? work->nSendChannels : work->nRecvChannels;
        int part = ncclP2pChannelToPart(work->nP2pChannels, work->channelBase, ncclShmem.channelId, ncclShmem.comm.p2pnChannelsPerPeer);
        hasWork = (part < nParts);
        if (nParts != 0) {
          size_t partBeg, partEnd;
          ncclP2pPartBounds(nParts, part, bytes, &partBeg, &partEnd);
          (isSend ? work->sendAddr : work->recvAddr) = (char*)(isSend ? work->sendAddr : work->recvAddr) + partBeg;
          (isSend ? work->sendBytes : work->recvBytes) = partEnd - partBeg;
        }
      }
      // Only ballot bits 0..31 carry information: recv flags in bits 0..15 and
      // send flags in bits 16..31. Higher bits, present on 64-wide
      // wavefronts, are dropped.
      uint32_t mask = (uint32_t)__ballot(hasWork);
      if (lane == 0) {
        shared->workSendMask = mask>>16;
        shared->workRecvMask = mask & 0xffff;
      }
    }

    // The fastest way to compute a warp-uniform division x/y with a result in
    // [0, WARP_SIZE) is to let each lane test one candidate quotient and count
    // the candidates that do not exceed the numerator:
    //   __popcll(__ballot(y*(lane+1) <= x))
    // On NVIDIA hardware this takes 1/3 the time of standard division and about
    // 3/4 the time of approximate floating point division:
    //   __float2int_rd(__fdividef(float(x),float(y))).
    // nWarpPerWork = nWarps/nWorks
    int nWarpPerWork = __popcll(__ballot(nWorks*(lane+1) <= nWarps));
    int nRecvWarpPerWork = nWarpPerWork/2;
    int nSendWarpPerWork = nWarpPerWork - nRecvWarpPerWork;
    // This might reduce nWarpPerWork which is probably desirable. It is better
    // to have a balanced number of reading and writing threads even if that
    // leaves warps unused.
    nWarpPerWork = nSendWarpPerWork + nRecvWarpPerWork;
    // The work index this warp belongs to: workIx = wid/nWarpPerWork
    int workIx = __popcll(__ballot((lane+1)*nWarpPerWork <= wid));

    __syncthreads(); // Wait for works[] and shared->* to be updated by warp=0
    uint32_t workSendMask = shared->workSendMask;
    uint32_t workRecvMask = shared->workRecvMask;
    __syncthreads(); // release scratch space used by shared->*
    if (nWorks <= workIx) return; // leftover warps beyond nWorks*nWarpPerWork

    // Thread range for whole work (send & recv combined)
    int subtid = tid - workIx*nWarpPerWork*WARP_SIZE;
    int subtn = nWarpPerWork*WARP_SIZE;

    // A send primitive of sufficient size requires 2 barrier ids.
    constexpr int nSendWarpsForExtraGroup = NCCL_SIMPLE_EXTRA_GROUP_IF_NTHREADS_GE/WARP_SIZE;
    // Count up all group ids used below this workIx:
    int group, extra;
    // Each recv gets one group id:
    group = __popcll(workRecvMask & ((1<<workIx)-1));
    // Sends accompanying recvs get one and maybe an extra:
    extra = (nSendWarpPerWork >= nSendWarpsForExtraGroup) ? 1 : 0;
    group += __popcll((workSendMask & workRecvMask) & ((1<<workIx)-1))*(1+extra);
    // Sends without recvs use more warps so compute extra accordingly:
    extra = (nWarpPerWork >= nSendWarpsForExtraGroup) ? 1 : 0;
    group += __popcll((workSendMask & ~workRecvMask) & ((1<<workIx)-1))*(1+extra);

    struct ncclDevWorkP2p* work = &works[workIx];
    bool hasSend = 1 & (workSendMask>>workIx);
    bool hasRecv = 1 & (workRecvMask>>workIx);
    bool isCopy = work->sendRank == ncclShmem.comm.rank;
    bool isSend = !hasRecv || (hasSend && subtid < nSendWarpPerWork*WARP_SIZE);

    if (!isCopy && hasSend && hasRecv) {
      // Translate thread ids to reflect just this send or recv as opposed to whole work.
      if (isSend) {
        subtn = nSendWarpPerWork*WARP_SIZE;
      } else {
        subtid -= nSendWarpPerWork*WARP_SIZE;
        subtn = nRecvWarpPerWork*WARP_SIZE;
        // Skip past the group id(s) claimed by this work's send half.
        group += 1 + (nSendWarpPerWork >= nSendWarpsForExtraGroup ? 1 : 0);
      }
    }

    if (isCopy) {
      // Self-send: copy sendAddr -> recvAddr locally with all of this work's threads.
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
      reduceCopy<COLL_UNROLL*2, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
        (subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes);
#else
      reduceCopy<COLL_UNROLL, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>
        (subtid, subtn, 0, nullptr, false, 1, &work->sendAddr, 1, &work->recvAddr, (ssize_t)work->sendBytes);
#endif
    } else if (isSend) {
      if (work->sendProtoLL) {
        runSend<ProtoLL>(subtid, subtn, group, work);
      } else {
#if defined(__gfx90a__)
        runSend<ProtoSimple<1,1,8>>(subtid, subtn, group, work);
#elif defined(__gfx908__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
        runSend<ProtoSimple<1,1,4>>(subtid, subtn, group, work);
#else
        runSend<ProtoSimple<1,1>>(subtid, subtn, group, work);
#endif
      }
    } else {
      if (work->recvProtoLL) {
        runRecv<ProtoLL>(subtid, subtn, group, work);
      } else {
#if defined(__gfx90a__)
        runRecv<ProtoSimple<1,1,8>>(subtid, subtn, group, work);
#elif defined(__gfx908__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
        runRecv<ProtoSimple<1,1,4>>(subtid, subtn, group, work);
#else
        runRecv<ProtoSimple<1,1>>(subtid, subtn, group, work);
#endif
      }
    }
  }
};