2
0
Ficheiros
rocm-systems/src/device/prims_simple.h
T
2024-01-24 15:25:33 -08:00

864 linhas
38 KiB
C++

/*************************************************************************
* Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
*
* See LICENSE.txt for license information
************************************************************************/
#if defined(ENABLE_NPKIT)
#include "npkit/npkit.h"
#endif
#include "msccl/msccl_struct.h"
#include "network/unpack/unpack.h"
template<typename T, typename RedOp, typename Fan, int Direct,
int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, int MultimemSrcs, int MultimemDsts>
class Primitives<
T, RedOp, Fan, Direct, ProtoSimple<SlicePerChunk, StepPerSlice, Unroll, MultimemSrcs, MultimemDsts>, P2p
> {
// Compile-time upper bounds on the number of receive/send peers, taken from Fan.
static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend;
// Buffer selectors passed as the SrcBuf/DstBuf template arguments of genericOp
// (-1 is used there to mean "no user buffer").
static constexpr int Input=0, Output=1;
// Bit values stored in `flags`.
//  - Role* bits are assigned per thread in the constructor.
//  - Aborted is set by checkAbort(), ThreadsSynced by barrier().
//  - OffsFifoEnabled/SizesFifoEnabled/Direct*/Nvls*/NetDeviceUnpack are set
//    while loading connections (loadRecvConn/loadSendConn).
//  - AnyNetDeviceUnpack is only set by code that is commented out in the
//    constructor, as far as the code visible in this class shows.
static constexpr int RoleInput = 0x01,
RoleOutput = 0x02,
RoleWaitRecv = 0x04,
RoleWaitSend = 0x08,
RolePostSend = 0x10,
RolePostRecv = 0x20,
Aborted = 0x40,
OffsFifoEnabled = 0x80,
SizesFifoEnabled = 0x100,
DirectWrite = 0x200,
DirectRead = 0x400,
ThreadsSynced = 0x800,
NvlsMinPolling = 0x1000,
NetDeviceUnpack = 0x2000,
AnyNetDeviceUnpack = 0x4000,
NvlsDirectRead = 0x8000,
NvlsDirectWrite = 0x10000;
// tid: thread index passed to the constructor; tidInBlock: threadIdx.x.
const int tid, tidInBlock;
// Number of threads cooperating in this primitive (constructor argument).
const int nthreads;
// Threads with tid < nworkers take part in the data copy/reduce; the
// constructor sets this to nthreads.
int nworkers;
// Size of one FIFO step in elements of T (see the constructor initializer).
const int stepSize;
Fan fan;
int index; // Peer index I'm responsible for
int flags; // Combination of the bit values declared above
int group;
// Step counter for this thread's connection; advanced by StepPerSlice in
// waitPeer()/postPeer() and saved back to the connection in the destructor.
uint64_t step;
int *connOffsFifoPtr; // (flags & OffsFifoEnabled)
union {
T *userBuff; // (flags & (RoleInput|RoleOutput))
T *connEltsFifo; // !(flags & (RoleInput|RoleOutput))
};
union {
int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled)
T *directBuff; // !(flags & SizesFifoEnabled)
};
// Step counter this thread polls (Wait roles) or publishes to (Post roles);
// points at the connection's head or tail depending on the role.
uint64_t *connStepPtr;
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
// Pointers into ncclShmem.groups[group] set by the constructor; presumably
// consumed by barrier_by_group() -- TODO confirm, its definition is not in this file.
uint64_t* barriers;
uint64_t* barrier_next;
// Loaded from the send connection; postPeer() stores 1 to it when non-null.
uint32_t* next_hdp_reg;
// NOTE(review): not referenced anywhere in the code visible in this class.
void* mhandle;
// Device handle taken from the recv connection when it is of type NCCL_NET_DEVICE_UNPACK.
void* netDeviceHandle;
#if defined(ENABLE_NPKIT)
public:
// NPKit profiling state: event-context index and timestamps accumulated
// around the reduceCopy calls in genericOp().
int npKitCtxIdx = 0;
uint64_t npKitDataProcessEntryTime = 0;
uint64_t npKitDataProcessExitTime = 0;
uint64_t npKitDataProcessTotalTime = 0;
private:
#endif
// Synchronize the threads of this primitive.
// Don't use barrier 0 as it's used by the final sync.
inline __device__ void barrier() {
  // Remember that this thread has synchronized at least once; the destructor
  // checks this bit to decide whether an extra barrier is required.
  flags |= ThreadsSynced;
  if (nthreads != WARP_SIZE) {
    // More than a single warp: use the group barrier.
    barrier_by_group();
  } else {
    // Exactly one warp: a warp-level sync is sufficient.
    __syncwarp();
  }
}
// Synchronization point used between the wait phase and the copy/reduce
// phase. In this implementation it simply forwards to barrier().
inline __device__ void subBarrier() {
barrier();
}
// Called from spin loops. Counts one more spin and, every
// NCCL_SPINS_BEFORE_CHECK_ABORT spins, polls the communicator abort flag.
// Once an abort has been observed the Aborted bit stays set in `flags` and
// ncclShmem.aborted is raised.
// Returns true when this thread has observed an abort.
inline __device__ bool checkAbort(int &spins) {
  spins += 1;
  const bool notYetAborted = (flags & Aborted) == 0;
  if (notYetAborted && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
    const auto abortRequested = __atomic_load_n(ncclShmem.comm.abortFlag, __ATOMIC_SEQ_CST);
    if (abortRequested) {
      flags |= Aborted;
      ncclShmem.aborted = 1;
    }
    // Restart the spin count so the flag is polled again after another period.
    spins = 0;
  }
  return flags & Aborted;
}
// Read the current value of a connection step counter.
// When compiled for CUDA arch >= 900 with CUDART >= 12.1 and NvlsMinPolling
// is set, the value is obtained with a multimem min-reduction load;
// otherwise a relaxed atomic load is used.
inline __device__ uint64_t loadStepValue(uint64_t* ptr) {
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
  if (flags & NvlsMinPolling) {
    uint64_t minStep;
    asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(minStep) : "l"(cvta_to_global(ptr)));
    return minStep;
  }
#endif
  // volatile is faster than acquire but not as correct. Make sure reduceCopy
  // loads data using volatile so it doesn't see stale data in L1.
  const uint64_t current = __atomic_load_n(ptr, __ATOMIC_RELAXED);
  return current;
}
// Wait until this thread's peer connection is ready for the next slice, then
// publish the pointer the workers should read from / write to.
//
// Only threads holding RoleWaitRecv (when Recv) or RoleWaitSend (when Send)
// do any work here; all other threads fall straight through.
//
// Template parameters:
//   DirectRecv/DirectSend - direct (no intermediate FIFO) paths are compiled in.
//   Recv/Send             - the operation receives from / sends to peers.
//   Src/Dst               - 1 when slot 0 of srcs/dsts is taken by a user buffer,
//                           which shifts the per-peer slots by one.
// Arguments:
//   srcIx/dstIx - element offsets applied to directBuff on the direct paths.
//   offset      - element offset of the current slice within the chunk.
//   nelts       - number of elements in the slice (written to the sizes FIFO in bytes).
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
__device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
// When both directions are compiled in, the thread's own role decides which one it serves.
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
const bool noRecvWait = DirectRecv && Src && (flags & DirectRead); // no wait when directly reading from remote input
const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write
if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) ||
((flags & (Send*RoleWaitSend)) && !noSendWait)) {
int spins = 0;
// Spin until the peer's counter has advanced far enough for this slice. For a
// sender the polled counter gets NCCL_STEPS added before the comparison.
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
__builtin_amdgcn_s_sleep(1);
connStepCache = loadStepValue(connStepPtr);
if (checkAbort(spins)) break;
//if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
// checkAbort() resets spins to 0 each time it polls the abort flag, so this traces once per polling period.
if (spins == 0) traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
}
__asm__ __volatile__("s_wakeup");
}
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
// Senders using a sizes FIFO record the slice size in bytes for this step.
if (isSendNotRecv && (flags & SizesFifoEnabled))
__atomic_store_n(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T), __ATOMIC_RELAXED);
// Select the shared pointer table for this direction; the Src/Dst shift
// skips slot 0 when it is occupied by the user buffer.
void **ptrs = isSendNotRecv ? (ncclShmem.groups[group].dsts + Dst)
: (ncclShmem.groups[group].srcs + Src);
if (flags & OffsFifoEnabled)
// The offsets FIFO holds a byte offset into the connection buffer for this step.
ptrs[index] = connEltsFifo + loadInt(connOffsFifoPtr + (step%NCCL_STEPS))/sizeof(T);
else if (isSendNotRecv && DirectSend) {
if (flags & (DirectWrite | NvlsDirectWrite)) {
ptrs[index] = directBuff + dstIx + offset;
} else if (flags & DirectRead) { // empty send
ptrs[index] = nullptr;
} else {
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
}
} else if (!isSendNotRecv && DirectRecv) {
if (flags & (DirectRead | NvlsDirectRead)) {
ptrs[index] = directBuff + srcIx + offset;
} else if (flags & DirectWrite) {
ptrs[index] = directBuff + dstIx + offset; // send to next from my output buffer
} else {
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
}
}
else {
// Default path: the FIFO slot that corresponds to the current step.
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
}
if ((flags & (AnyNetDeviceUnpack)) && (flags & (Recv*RoleWaitRecv))) {
ncclNetDeviceIncrementHead(group);
}
// This slice is accounted for; move on to the next one.
step += StepPerSlice;
}
}
// Publish progress to the peer after a slice has been processed.
//   dataStored - true when data was actually written for this slice; only
//                then does a posting sender issue a memory fence first.
// Threads holding RolePostRecv (when Recv) or RolePostSend (when Send)
// advance `step` by StepPerSlice and store it to the connection counter.
template<int Recv, int Send>
inline __device__ void postPeer(bool dataStored) {
  if (Send && (flags & RolePostSend) && dataStored) {
    // Make the written data visible before the step counter is updated.
#ifdef __GFX9__
    __threadfence();
#else
    __threadfence_system();
#endif
  }
  if ((flags & Send*RolePostSend) && next_hdp_reg) {
    // The connection supplied a register pointer: write 1 to it before posting.
    STORE((unsigned int *)next_hdp_reg, 0x1);
  }
  const int postRoles = Recv*RolePostRecv | Send*RolePostSend;
  if (flags & postRoles) {
    step += StepPerSlice;
    STORE(connStepPtr, step);
  }
}
// Core routine behind all send/recv/reduce/copy primitives of this class.
//
// Processes `nelem` elements in at most SlicePerChunk slices. For each slice:
//   1. the RoleInput/RoleOutput threads publish the user-buffer pointers,
//   2. the Wait-role threads wait for their peers and publish FIFO/direct pointers,
//   3. after a sub-barrier all workers run reduceCopy over the published pointers,
//   4. after a barrier the Post-role threads publish the new step to the peers.
//
// Template parameters:
//   DirectRecv1/DirectSend1 - request the direct paths (only effective when the
//                             class-level Direct parameter is non-zero).
//   Recv/Send               - the operation involves receive / send peers.
//   SrcBuf/DstBuf           - Input, Output, or -1 for "no user buffer".
// Arguments:
//   srcIx/dstIx - element offsets into the source / destination user buffer.
//   nelem       - number of elements to process (negative values are treated as 0).
//   postOp      - forwarded to reduceCopy.
template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
__device__ __forceinline__ void genericOp(
intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp
) {
constexpr int DirectRecv = /*1 &&*/ Direct && DirectRecv1;
constexpr int DirectSend = /*1 &&*/ Direct && DirectSend1;
// 1 when a user buffer takes part as source / destination, 0 otherwise.
constexpr int Src = SrcBuf != -1;
constexpr int Dst = DstBuf != -1;
nelem = nelem < 0 ? 0 : nelem;
int sliceSize = stepSize*StepPerSlice;
// Spread nelem over the slices in multiples of 16 elements, but never use
// less than 1/32 of the full slice capacity.
sliceSize = max(divUp(nelem, 16*SlicePerChunk)*16, sliceSize/32);
int slice = 0;
int offset = 0;
if (tid < nworkers && offset < nelem) {
// Worker-only loop for non-empty slices. Non-workers and empty slices are
// processed in the loop following this if block. The benefit of splitting
// the loop like this is we pull two branches out of the critical path.
// Using "number of branch insns (taken or not) encountered dynamically"
// as the performance metric, then:
// perf_orig = 2*numslices
// perf_new = 2+numslices
// So the new code and old code behave the same for numslices=2, and for
// numslices>2 the new code is superior. And note that in the case
// numslices=1, the loop is trivially unrollable (single iteration) so we
// don't incur that tail branch and we still have perf_new=2.
//
// ORIGINAL CODE:
// unrolled for(slices) {
// if(worker) { // This branch removed
// wait();
// subBarrier();
// if(slice not empty) // This branch removed
// ReduceCopyMulti();
// }
// barrier();
// post();
// } // Since we no longer unroll, new branch added here
#pragma unroll 1
do {
// Clamp the last slice to the remaining element count.
sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
// The thread owning the input/output user buffer publishes it in slot 0.
if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput)))
ncclShmem.groups[group].srcs[0] = userBuff + srcIx + offset;
if (Dst && (flags & (DstBuf==Input ? RoleInput : RoleOutput)))
ncclShmem.groups[group].dsts[0] = userBuff + dstIx + offset;
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(srcIx, dstIx, offset, sliceSize);
// All pointers must be published before any worker reads them.
subBarrier();
/* if user abort the kernel, we don't need to actually perform copy/reduce; just set size
* to 0 to avoid unnecessary workload. */
int workSize = ncclShmem.aborted ? 0 : sliceSize;
if (flags & AnyNetDeviceUnpack) {
ncclNetDeviceUnpack<Recv>(tid, tidInBlock, nworkers, group, ncclShmem.groups[group].devicePlugin.unpack.unpackNetDeviceIndexMask, Src, workSize);
// Sync here to make sure all workers are reading from the updated srcs
subBarrier();
}
if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]
/* NVLS can have srcs[0] == dsts[0], but we cannot enter this "if branch",
* so we need to check whether MultimemSrcs and MultimemDsts are 0. */
&& MultimemSrcs == 0 && MultimemDsts == 0) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (Send) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
}
#endif
// Data is already in place locally; only forward it to the send peers
// (dsts+1 skips the local destination in slot 0).
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, MaxSend, /*PreOpSrcs*/0>
(tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false,
1, ncclShmem.groups[group].srcs,
fan.nsend(), ncclShmem.groups[group].dsts+1,
workSize);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
}
#endif
}
} else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) {
// For broadcast in CollNet to do empty send
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
}
#endif
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, /*PreOpSrcs*/0>
(tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp,
Recv, ncclShmem.groups[group].srcs,
Dst, ncclShmem.groups[group].dsts,
workSize);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
}
#endif
} else {
// General case: combine all receive sources (plus the optional user source)
// into all send destinations (plus the optional user destination).
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
npKitDataProcessEntryTime = NPKIT_GET_GPU_TIMESTAMP();
}
#endif
// Number of sources passed to reduceCopy as PreOpSrcs: none unless the
// source is the input buffer; with a full direct-receive fan, the remote
// inputs are included as well.
constexpr int PreOpSrcs = SrcBuf != Input ? 0 :
DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
reduceCopy<Unroll, RedOp, T,
MultimemSrcs, Recv+Src, Recv*MaxRecv+Src,
MultimemDsts, Send+Dst, Send*MaxSend+Dst, PreOpSrcs>
(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs,
Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts,
workSize);
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME)
if (tid == 0) {
npKitDataProcessExitTime = NPKIT_GET_GPU_TIMESTAMP();
npKitDataProcessTotalTime += npKitDataProcessExitTime - npKitDataProcessEntryTime;
}
#endif
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
}
#endif
}
barrier(); // This barrier has a counterpart in following loop
postPeer<Recv, Send>(0 < sliceSize);
offset += sliceSize;
slice += 1;
} while (slice < SlicePerChunk && offset < nelem);
}
// Non-workers come straight here. Workers too but only once the remaining
// slices are all empty. Since empty slices are the uncommon case, and
// worker perf is the limiter, perf-wise this loop is effectively unentered,
// hence just a single branch insn.
#pragma unroll 1
while (slice < SlicePerChunk) {
sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
{ // Only workers could have Wait roles so we know the slice must be empty
// since we've exited the loop above.
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(0, 0, 0, 0);
}
barrier(); // Has counterpart in preceding worker-only loop.
postPeer<Recv, Send>(0 < sliceSize);
offset += sliceSize;
slice += 1;
}
}
// Local (no peer communication) reduce and/or copy used by the MSCCL primitives.
//   REDUCE    - reduce the given sources together with dsts[0] into dsts[0].
//               Note: dsts[0] is appended to srcs, so srcs must have room for
//               one extra entry at index nsrcs.
//   COPY      - copy srcs[0] to dsts[0]; with MULTISRCS also srcs[i] to dsts[i]
//               for every further source.
//   MULTISRCS - more than one source is supplied.
//   MULTIDSTS - not used by this implementation.
// `ndsts` is not used either. Negative `nelem` is treated as 0. Every thread
// reaches the trailing barrier, whether or not it is a worker.
template <int REDUCE, int COPY, int MULTISRCS, int MULTIDSTS>
__device__ __forceinline__ void mscclGenericOp(T** srcs, int nsrcs, T** dsts, int ndsts, int nelem) {
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY)
  if (tid == 0) {
    NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_ENTRY, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
        ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
  }
#endif
  if (nelem < 0) nelem = 0;
  const bool isWorker = tid < nworkers;
  if (isWorker && REDUCE) {
    // The destination doubles as the last reduction source.
    srcs[nsrcs] = dsts[0];
    nsrcs += 1;
    if (!MULTISRCS) {
      reduceCopy<Unroll, RedOp, T, 0, 2, 2, 0, 1, 1, 0>
        (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 2, (void **)srcs, 1, (void **)dsts, nelem);
    } else {
      reduceCopy<Unroll, RedOp, T, 0, 3, MSCCL_MAX_REDUCE_FUSION, 0, 1, 1, 0>
        (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nsrcs, (void **)srcs, 1, (void **)dsts, nelem);
    }
  }
  if (isWorker && COPY) {
    // First pair is always copied; the remaining pairs only with MULTISRCS.
    reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
      (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)srcs, 1, (void **)dsts, nelem);
    if (MULTISRCS) {
      for (int pair = 1; pair < nsrcs; ++pair) {
        reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, 0>
          (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)&srcs[pair], 1, (void **)&dsts[pair], nelem);
      }
    }
  }
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT)
  if (tid == 0) {
    NpKit::CollectGpuEvent(NPKIT_EVENT_MSCCL_GENERIC_OP_EXIT, nelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(),
        ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx);
  }
#endif
  barrier();
}
// Scatter/Gather generic op
// skip: my own rank order in the buffer chunks
// shift: peer offset to avoid all ranks sending to or receiving from same peer
//
// With Send set, each slice of the input buffer is copied to every send peer
// (scatter); otherwise, with Recv set, each slice is copied from every receive
// peer into the output buffer (gather).
//   inpIx/outIx - element offsets into the input / output user buffer.
//   totalElem   - total element count, used to clip the per-peer size.
//   peerElem    - number of elements per peer.
//   peerOffset  - element distance between consecutive peers' regions.
//   postOp      - forwarded to reduceCopy on the gather path.
template <int DirectRecv1, int DirectSend1, int Recv, int Send>
__device__ __forceinline__ void
ScatterGatherOp(intptr_t inpIx, intptr_t outIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift, bool postOp) {
constexpr int DirectRecv = /*1 &&*/ Direct && DirectRecv1;
constexpr int DirectSend = /*1 &&*/ Direct && DirectSend1;
int offset = 0; // slice offset
int sliceSize = stepSize*StepPerSlice;
int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size
#pragma unroll 1
for (int slice=0; slice<SlicePerChunk; ++slice) {
// Elements left for this slice, clamped to [0, dataSize].
ssize_t realSize = max(0, min(dataSize, peerElem-offset));
bool fenceNeeded = false;
if (tid < nworkers) {
if (Send) {
// Scatter pre-scales data of input buffer only in non-Direct case
constexpr int PreOpSrcs = DirectSend ? 0 : 1;
if (flags & RoleInput) ncclShmem.groups[group].srcs[0] = userBuff + inpIx + offset;
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
waitPeer<0, DirectSend, 0, 1, 1, 0>(0, inpIx, offset, realSize);
subBarrier();
#pragma unroll 1
// Loop over peers
for (int j=0; j<fan.nsend(); j++) {
// Rotate the peer order by `shift`.
int i = (j+shift)%fan.nsend();
ssize_t pOffset = i*peerOffset;
// Skip the data I am responsible of reducing myself
if (skip >= 0 && i >= skip) pOffset += peerElem;
void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset;
ssize_t realPeerSize = min(realSize, totalElem-pOffset);
// A null destination marks an empty send (see waitPeer); nothing to copy.
if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) {
reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, PreOpSrcs>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
// Mark for threadfence at the end
fenceNeeded |= true;
}
}
} else if (Recv) {
if (flags & RoleOutput) ncclShmem.groups[group].dsts[0] = userBuff + outIx + offset;
ssize_t pOffset = index*peerOffset;
if (skip >= 0 && index >= skip) pOffset += peerElem;
// Adjust remote index with peer offset in case we are directly pulling from peer's output buffer
waitPeer<DirectRecv, 0, 1, 0, 0, 1>(outIx+pOffset, outIx+pOffset, offset, realSize);
subBarrier();
#pragma unroll 1
for (int j=0; j<fan.nrecv(); j++) {
int i = (j+shift)%fan.nrecv();
pOffset = i*peerOffset;
if (skip >= 0 && i >= skip) pOffset += peerElem;
void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset;
ssize_t realPeerSize = min(realSize, totalElem-pOffset);
// On the direct path the source may already be the destination; then no copy is needed.
if (DirectRecv && ncclShmem.groups[group].srcs[i] == dst0) realPeerSize = 0;
if (realPeerSize > 0) reduceCopy<Unroll, RedOp, T, 0,1,1, 0,1,1, /*PreOpSrcs=*/0>(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize);
}
}
}
// Combine the per-thread flag with a warp vote so that the posting thread
// fences if any participating lane stored data.
fenceNeeded = __any(fenceNeeded);
postPeer<Recv, Send>(fenceNeeded);
offset += realSize;
}
}
// Load receive-connection state into this thread. Only threads holding
// RoleWaitRecv or RolePostRecv do anything.
//   peer      - channel peer whose recv[connIndex] connection is used.
//   connIndex - index of the connection within peer->recv.
//   e         - work element (may be nullptr); its regUsed/direct fields select
//               the direct-access mode when user buffers are registered.
__device__ __forceinline__ void loadRecvConn(ncclDevChannelPeer *peer, int connIndex, struct ncclWorkElem* e) {
if (flags & (RoleWaitRecv|RolePostRecv)) {
auto *conn = &peer->recv[connIndex];
if (conn->netDeviceHandle.netDeviceType == NCCL_NET_DEVICE_UNPACK) {
// handle must be a device ptr
netDeviceHandle = conn->netDeviceHandle.handle;
// Cache the handle
ncclNetDeviceUnpackSetup(netDeviceHandle, group, index);
flags |= NetDeviceUnpack;
}
// Resume from the step saved on the connection, rounded up to a whole chunk.
step = conn->step;
step = roundUp(step, SlicePerChunk*StepPerSlice);
if (flags & RolePostRecv) {
// The posting thread publishes to the connection's head counter.
connStepPtr = conn->head;
STORE(connStepPtr, step); // Return credits in case we rounded up.
}
if (flags & RoleWaitRecv) {
ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs()
flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
// The waiting thread polls the connection's tail counter.
connStepPtr = conn->tail;
connStepCache = loadStepValue(connStepPtr);
flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
if (Direct) {
// User buffers have been registered
if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) {
if (connIndex == 1 && P2p == 0) {
flags |= DirectRead; // scatter-reduce use direct pull
} else {
flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite :
(e->direct & NCCL_DIRECT_READ) ? DirectRead : 0;
}
} else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) {
if (connIndex == 1 && P2p == 0) {
flags |= DirectRead; // scatter-reduce use direct pull
} else {
// direct read not allowed in non-register case
// otherwise, in one-to-multi send, we could mix empty send and intermediate send
flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0;
}
} else if ((conn->flags & NCCL_NVLS_MIN_POLL) && e != nullptr && e->regUsed) {
/* NVLS direct */
flags |= NvlsDirectRead;
}
}
if (flags & OffsFifoEnabled)
connOffsFifoPtr = conn->offsFifo;
// Base of this connection's data FIFO for the Simple protocol.
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
}
}
}
// Load send-connection state into this thread. Only threads holding
// RoleWaitSend or RolePostSend do anything.
//   peer      - channel peer whose send[connIndex] connection is used.
//   connIndex - index of the connection within peer->send.
//   e         - work element (may be nullptr); its regUsed/direct fields select
//               the direct-access mode when user buffers are registered.
__device__ __forceinline__ void loadSendConn(ncclDevChannelPeer *peer, int connIndex, struct ncclWorkElem* e) {
if (flags & (RoleWaitSend|RolePostSend)) {
auto *conn = &peer->send[connIndex];
// Resume from the step saved on the connection, rounded up to a whole chunk.
step = conn->step;
step = roundUp(step, SlicePerChunk*StepPerSlice);
if (flags & RolePostSend) {
// The posting thread publishes to the connection's tail counter.
connStepPtr = conn->tail;
next_hdp_reg = conn->next_hdp_reg;
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
}
if (flags & RoleWaitSend) {
ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs()
flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
// The waiting thread polls the connection's head counter.
connStepPtr = conn->head;
connStepCache = loadStepValue(connStepPtr);
flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
if (flags & OffsFifoEnabled)
connOffsFifoPtr = conn->offsFifo;
connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
// connSizesFifoPtr and directBuff share storage, so the direct modes below
// are only considered when the connection has no sizes FIFO.
if (conn->sizesFifo != nullptr) {
flags |= SizesFifoEnabled;
connSizesFifoPtr = conn->sizesFifo;
} else if (Direct) {
// User buffers have been registered
if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) {
if (connIndex == 1 && P2p == 0) {
flags |= DirectRead; // scatter-reduce use direct pull
} else {
flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite :
(e->direct & NCCL_DIRECT_READ) ? DirectRead : 0;
}
} else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) {
if (connIndex == 1 && P2p == 0) {
flags |= DirectRead; // scatter-reduce use direct pull
} else {
// direct read not allowed in non-register case
// otherwise, in one-to-multi send, we could mix empty send and intermediate send
flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0;
}
} else if ((conn->flags & NCCL_NVLS_MIN_POLL) && e != nullptr && e->regUsed) {
/* NVLS direct */
flags |= NvlsDirectWrite;
}
}
}
}
}
public:
// Construct the per-thread primitive state.
//
//   tid, nthreads   - this thread's index within the primitive and the number
//                     of cooperating threads.
//   recvPeers/sendPeers - peer lists; a -1 entry ends a list before
//                     MaxRecv/MaxSend entries have been read.
//   inputBuf/outputBuf - user buffers, forwarded to setDataPtrs().
//   redOpArg        - reduction-operator argument, forwarded to setDataPtrs().
//   group           - index into ncclShmem.groups used by this primitive.
//   connIndexRecv/connIndexSend - connection index within the peer's recv/send arrays.
//   e               - optional work element describing registered buffers.
//   stepSize_       - step size in elements; 0 selects the default derived from
//                     the Simple-protocol buffer size.
//
// Threads are split into groups of ThreadPerSync: group 0 holds the WaitRecv
// threads and the input-buffer owner, group 1 the WaitSend threads and the
// output-buffer owner, the second-to-last group the PostRecv threads and the
// last group the PostSend threads.
__forceinline__ __device__ Primitives(
    int tid, int nthreads, int const *recvPeers, int const *sendPeers,
    void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group=0,
    uint8_t connIndexRecv = 0, uint8_t connIndexSend = 0, struct ncclWorkElem* e = nullptr, int stepSize_=0
  ):
  // Members are always initialized in declaration order; the list is written
  // in that same order (tid, tidInBlock, nthreads, stepSize, group) so it
  // reads truthfully and does not trigger -Wreorder.
  tid(tid), tidInBlock(threadIdx.x), nthreads(nthreads),
  stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_),
  group(group) {

  barriers = &ncclShmem.groups[group].barrier;
  barrier_next = ncclShmem.groups[group].barrier_next;
  // For send operations, we need an extra warp to overlap the threadfence and the copy
  // (in this version every thread is a worker: nworkers == nthreads).
  this->nworkers = nthreads;

  // Count the valid peers on each side.
  int nrecv=0, nsend=0;
  while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
  while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
  this->fan = Fan(nrecv, nsend);

  constexpr int ThreadPerSync = 8;
  static_assert(MaxSend <= ThreadPerSync && MaxRecv <= ThreadPerSync, "Not enough threads to cover all peers");

  // g: which ThreadPerSync-sized group this thread falls in; index: position
  // inside that group, which doubles as the peer index for role threads.
  int g = tid / ThreadPerSync;
  int ng = nthreads / ThreadPerSync;
  index = tid % ThreadPerSync;
  flags = 0;
  if (g == 0) {
    if (index < nrecv) flags |= RoleWaitRecv;
    if (index == nrecv) flags |= RoleInput;
  } else if (g == 1) {
    if (index < nsend) flags |= RoleWaitSend;
    if (index == nsend) flags |= RoleOutput;
  } else if (g == ng - 2) {
    if (index < nrecv) flags |= RolePostRecv;
  } else if (g == ng - 1) {
    if (index < nsend) flags |= RolePostSend;
  }

  // Threads without a connection role keep peer 0; loadRecvConn/loadSendConn
  // check the role bits themselves and do nothing for such threads.
  int peer = 0;
  if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
  if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];

  loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, e);
  loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, e);

  // if (barrierAny(flags & NetDeviceUnpack)) {
  //   flags |= AnyNetDeviceUnpack;
  //   // g == 0 is the first ThreadPerSync # of threads of this warp
  //   // g == 0 is also the RoleWaitRecv threads of this group, thus the thread ID will correlate to the peer index
  //   if (g == 0) {
  //     uint32_t mask = __ballot_sync((1U << ThreadPerSync) - 1, (flags & NetDeviceUnpack) ? 1 : 0);
  //     // We only want to update the shared memory variable with a single thread
  //     if (tid == 0) {
  //       ncclShmem.groups[this->group].devicePlugin.unpack.unpackNetDeviceIndexMask = mask;
  //     }
  //   }
  // }

  setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclWorkElemReg*)e);
}
// Tear-down: persist the step counters on the connections so the next
// operation can resume from them, then synchronize.
__forceinline__ __device__ ~Primitives() {
  // Ensure ncclShmem.groups[].send/recvConns are available
  if ((flags & ThreadsSynced) == 0) {
    barrier();
  }
  // Save steps for the next operation
  if (flags & (RolePostSend|RolePostRecv)) {
    if (flags & RolePostSend) {
      ncclShmem.groups[group].sendConns[index]->step = step;
    } else {
      ncclShmem.groups[group].recvConns[index]->step = step;
    }
  }
  const bool unpackReceiver = (flags & (AnyNetDeviceUnpack)) && (flags & (RoleWaitRecv));
  if (unpackReceiver) {
    ncclNetDeviceSaveHead(netDeviceHandle, group);
  }
  barrier();
}
// Bind the user buffers to the RoleInput/RoleOutput threads and, when Direct
// is enabled, exchange direct-access buffer pointers (and reduction
// arguments) with the peers through the connections' exchange slots.
//   inputBuf/outputBuf - user buffers.
//   redOpArg           - reduction-operator argument for the local input.
//   e                  - optional registered-buffer work element.
// NOTE(review): a few branches below dereference `e` without a null check
// (when the exchange slot is null); the accompanying source comment states
// that regUsed must be 1 in that case -- confirm that callers guarantee it.
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclWorkElemReg* e) {
if (flags & RoleInput) {
userBuff = (T*)inputBuf;
ncclShmem.redOpArgs[0] = redOpArg; // scaler for local input
}
if (flags & RoleOutput) userBuff = (T*)outputBuf;
// `flags == (flags|A|B)` is true exactly when both bits A and B are set in flags.
bool recvProvider = flags == (flags|RoleWaitRecv|DirectWrite);
bool sendAcceptor = (flags == (flags|RoleWaitSend|DirectWrite)) || (flags == (flags|RoleWaitSend|NvlsDirectWrite));
bool sendProvider = flags == (flags|RoleWaitSend|DirectRead); // sender provides direct buffer (to be fetched)
bool recvAcceptor = flags == (flags|RoleWaitRecv|DirectRead) || (flags == (flags|RoleWaitRecv|NvlsDirectRead)); // receiver accepts direct buffer
int regUsed = e != nullptr ? e->elem.regUsed : 0;
// Receiver with direct write: publish our output buffer so the sender can write into it.
if (Direct && recvProvider) {
int spins = 0;
void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange;
// Wait for consumer to consume previous value before trampling it.
if (slot) {
// atomicAdd with 0 is used here as an atomic read of the slot.
while ((void *)atomicAdd((unsigned long long *) slot,0) != nullptr && !checkAbort(spins));
directBuff = (T*)outputBuf;
// Encode pointer by XOR'ing against some address they definitely wouldn't send
// since we want to allow them sending us nullptr while not colliding with
// the empty slot value.
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(directBuff) ^ reinterpret_cast<uintptr_t>(slot));
}
}
// Sender with direct write: obtain the receiver's buffer to write into.
if (Direct && sendAcceptor) {
int spins = 0;
void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange;
void *ptr;
// Spin until the peer has published a pointer (or an abort is seen).
while (slot) {
ptr = (void *)atomicAdd((unsigned long long *) slot,0);
if (ptr != nullptr || checkAbort(spins)) break;
}
if (slot) {
// With registered buffers the address comes from the work element;
// otherwise decode the XOR-encoded pointer from the slot.
directBuff = regUsed ? (T*)(e->dnOutputs[index]) :
reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
// Mark the slot as consumed.
*slot = nullptr;
} else {
/* slot is NULL, it must be regUsed == 1 */
directBuff = (T*)e->dnOutputs[index];
}
}
// Sender with direct read: publish our buffer (and pre-scalers) for the receiver to pull from.
if (Direct && sendProvider) {
int spins = 0;
void *volatile *slot = ncclShmem.groups[group].sendConns[index]->ptrExchange;
volatile uint64_t* argSlot0 = ncclShmem.groups[group].sendConns[index]->redOpArgExchange;
volatile uint64_t* argSlot1 = ncclShmem.groups[group].sendConns[index]->redOpArgExchange+1;
// Wait for consumer to consume previous value before trampling it.
if (slot && argSlot0 && argSlot1) {
while (((void *)atomicAdd((unsigned long long *) slot,0) != nullptr || *argSlot0 != 0 || *argSlot1 !=0) && !checkAbort(spins));
// If there is no recv, then we are directly pulling from input buffer (e.g. directScatter)
// Otherwise, we are pulling from output buffer (e.g. recvCopyDirectSend)
directBuff = MaxRecv == 0 ? (T*)inputBuf : (T*)outputBuf;
// Exchange pre-scalers for use in direct pull
// Each slot carries one 32-bit half of redOpArg; bit 32 is set so that a
// published value is never 0 (0 means "empty").
*argSlot0 = (uint64_t(1)<<32) | (uint32_t)redOpArg;
*argSlot1 = (uint64_t(1)<<32) | (uint32_t)(redOpArg>>32);
// Encode pointer by XOR'ing against some address they definitely wouldn't send
// since we want to allow them sending us nullptr while not colliding with
// the empty slot value.
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(directBuff) ^ reinterpret_cast<uintptr_t>(slot));
}
}
// Receiver with direct read: obtain the sender's buffer (and pre-scalers) to pull from.
if (Direct && recvAcceptor) {
int spins = 0;
void *volatile *slot = ncclShmem.groups[group].recvConns[index]->ptrExchange;
volatile uint64_t* argSlot0 = ncclShmem.groups[group].recvConns[index]->redOpArgExchange;
volatile uint64_t* argSlot1 = ncclShmem.groups[group].recvConns[index]->redOpArgExchange+1;
void *ptr;
// Spin until the peer has published a pointer (or an abort is seen).
while (slot) {
ptr = (void *)atomicAdd((unsigned long long *) slot,0);
if (ptr != nullptr || checkAbort(spins)) break;
}
if (slot && argSlot0 && argSlot1) {
directBuff = regUsed ? (T*)(MaxSend == 0 ? e->upOutputs[index] : e->dnInputs[index]) :
reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
if (MaxSend != 0) { // reduce group rather than gather group
// Store scalers for remote inputs
uint64_t arg0, arg1;
while (true) {
arg0 = *argSlot0;
arg1 = *argSlot1;
if ((arg0 != 0 && arg1 != 0) || checkAbort(spins)) break;
}
// Reassemble the 64-bit argument from the two 32-bit halves.
ncclShmem.redOpArgs[1 + index] = ((arg1 & 0xffffffff) << 32) | (arg0 & 0xffffffff);
}
// Mark all exchange slots as consumed.
*argSlot0 = 0; *argSlot1 = 0;
*slot = nullptr;
} else {
directBuff = (T*)e->dnInputs[index];
}
}
}
// Advance the user buffer pointer by `delta` elements of T. Only the threads
// owning a user buffer (RoleInput/RoleOutput) are affected.
__device__ void moveDataPtrs(intptr_t delta) {
if (flags & (RoleInput|RoleOutput))
userBuff += delta;
}
// Set MSCCL data pointers
// Rebinds the user buffers on the RoleInput/RoleOutput threads only; unlike
// the four-argument overload it performs no direct-buffer exchange and does
// not touch the reduction arguments.
__device__ __forceinline__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
if (flags & RoleInput) userBuff = (T*)inputBuf;
if (flags & RoleOutput) userBuff = (T*)outputBuf;
}
// ---------------------------------------------------------------------------
// Public point-to-point / reduce primitives. Each one is a thin wrapper that
// instantiates genericOp<DirectRecv1, DirectSend1, Recv, Send, SrcBuf, DstBuf>.
// inpIx/outIx are element offsets into the input/output user buffers, eltN is
// the element count, and -1 is passed for an offset or buffer that is unused.
// ---------------------------------------------------------------------------

// Send from the input buffer.
__device__ __forceinline__ void send(intptr_t inpIx, int eltN) {
genericOp<0, 0, 0, 1, Input, -1>(inpIx, -1, eltN, false);
}
// Send from the output buffer.
__device__ __forceinline__ void sendFromOutput(intptr_t outIx, int eltN) {
genericOp<0, 0, 0, 1, Output, -1>(outIx, -1, eltN, false);
}
// Send from the input buffer, with the direct-send path requested.
__device__ __forceinline__ void directSend(intptr_t inpIx, intptr_t outIx, int eltN) {
genericOp<0, 1, 0, 1, Input, -1>(inpIx, outIx, eltN, false);
}
// Send from the output buffer, with the direct-send path requested.
__device__ __forceinline__ void directSendFromOutput(intptr_t outIx, int eltN) {
genericOp<0, 1, 0, 1, Output, -1>(outIx, outIx, eltN, false);
}
// Receive into the output buffer.
__device__ __forceinline__ void recv(intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 0, -1, Output>(-1, outIx, eltN, postOp);
}
// Receive into the output buffer, with the direct-receive path requested.
__device__ __forceinline__ void directRecv(intptr_t outIx, int eltN) {
genericOp<1, 0, 1, 0, -1, Output>(-1, outIx, eltN, /*postOp=*/false);
}
// Same as directRecv but also passes a source offset (inpIx), which is
// applied to the peer's buffer on the direct-read path.
__device__ __forceinline__ void directRecvCopy(intptr_t inpIx, intptr_t outIx, int eltN) {
genericOp<1, 0, 1, 0, -1, Output>(inpIx, outIx, eltN, /*postOp=*/false);
}
// Copy input to output and send.
__device__ __forceinline__ void copySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
}
// Copy input to output and send, with the direct-send path requested.
__device__ __forceinline__ void directCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, eltN, postOp);
}
// Receive and forward to the send peers without touching a user buffer.
__device__ __forceinline__ void recvSend(int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, -1, -1>(-1, -1, eltN, postOp);
}
// Receive, store into the output buffer and forward.
__device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, eltN, postOp);
}
// Receive, store into the output buffer and forward; direct paths requested on both sides.
__device__ __forceinline__ void directRecvCopySend(intptr_t outIx, int eltN) {
genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, eltN, false);
}
// Receive and forward with direct paths requested on both sides and no user buffer.
__device__ __forceinline__ void directRecvDirectSend(intptr_t inpIx, intptr_t outIx, int eltN) {
genericOp<1, 1, 1, 1, -1, -1>(inpIx, outIx, eltN, false);
}
// Receive, store into the output buffer and forward; direct path requested for the send only.
__device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, eltN, postOp);
}
// Receive, reduce with the input buffer and store into the output buffer.
__device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, eltN, postOp);
}
// Receive, reduce with the input buffer and send the result.
__device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp);
}
// Receive (direct path requested), reduce with the input buffer and send the result.
__device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, eltN, postOp);
}
// Receive, reduce with the input buffer, store into the output buffer and send.
__device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
}
// Receive, reduce with the input buffer, store into the output buffer and send.
__device__ __forceinline__ void directRecvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
// Direct is only for the send part
genericOp<0, 1, 1, 1, Input, Output>(inpIx, outIx, eltN, postOp);
}
// ---------------------------------------------------------------------------
// Scatter/gather primitives: thin wrappers around
// ScatterGatherOp<DirectRecv1, DirectSend1, Recv, Send>. See ScatterGatherOp
// for the meaning of totalElem, peerElem, peerOffset, skip and shift.
// ---------------------------------------------------------------------------

// Scatter from the input buffer to the send peers.
__device__ __forceinline__ void
scatter(intptr_t inpIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift) {
ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
}
// Scatter with the direct-send path requested.
__device__ __forceinline__ void
directScatter(intptr_t inpIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift) {
ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
}
// Gather from the receive peers into the output buffer.
__device__ __forceinline__ void
gather(intptr_t outIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift, bool postOp=false) {
ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, postOp);
}
// Gather with the direct-receive path requested.
__device__ __forceinline__ void
directGather(intptr_t outIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift) {
ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
}
// MSCCL primitives
// sendWithBarrier: in this implementation identical to send(); the barriers
// are the ones genericOp already performs for every slice.
__device__ __forceinline__ void sendWithBarrier(intptr_t inpIx, int eltN) {
send(inpIx, eltN);
}
// localCopy: copy eltN elements from srcs to dsts on the local device through
// mscclGenericOp with COPY enabled and a single source/destination.
__device__ __forceinline__ void localCopy(T* srcs, T* dsts, int eltN) {
return mscclGenericOp<0,1,0,0>(&srcs, 1, &dsts, 1, eltN);
}
};