428837ffe4
Co-authored-by: mberenjk <mberenjk@amd.com>
435 строки
19 KiB
C++
435 строки
19 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
|
* Modifications Copyright (c) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
|
|
* Modifications Copyright (c) Microsoft Corporation. Licensed under the MIT License.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
#ifndef MSSCLKERNELIMPL_H
|
|
#define MSSCLKERNELIMPL_H
|
|
|
|
#include "device.h"
|
|
#include "primitives.h"
|
|
#include "collectives.h"
|
|
|
|
#include "msccl/msccl_struct.h"
|
|
#include "msccl/msccl_kernel.h"
|
|
|
|
// Block-shared MSCCL state, declared `__shared__` here and defined elsewhere.
// mscclRunInterpreter in this file fills its `mscclTB` and `work` members before using them.
extern __shared__ struct mscclShmemData mscclShmem;
|
|
|
|
// Multiplier reserved for the gridoffset_iter component of a synchronization flag.
#define MSCCL_MAX_ITER 65536

// A flag packs the 3-tuple (workindex, gridoffset_iter, step) into one uint64_t so that
// comparing flags numerically is the same as comparing the tuples lexicographically:
// a threadblock is ahead of another iff its flag is larger.
// (The ordering assumes gridoffset_iter < MSCCL_MAX_ITER and step < MSCCL_MAX_NUM_STEPS.)
// Every argument and the whole expansion are parenthesized, and the arithmetic is widened
// to uint64_t before the first multiplication, so the macro can be used inside larger
// expressions and with expression arguments.
#define COMPUTE_FLAG(__WORKINDEX__,__GRIDOFFSET_ITER__,__STEP__) \
  ((uint64_t)MSCCL_MAX_ITER*MSCCL_MAX_NUM_STEPS*(uint64_t)(__WORKINDEX__) + ((uint64_t)(__GRIDOFFSET_ITER__) * MSCCL_MAX_NUM_STEPS + (uint64_t)(__STEP__)))

// Recovers the workindex component of a flag produced by COMPUTE_FLAG.
#define GET_WORKINDEX_FROM_FLAG(__FLAG__) \
  ((__FLAG__) / ((uint64_t)MSCCL_MAX_ITER*MSCCL_MAX_NUM_STEPS))
|
|
|
|
// Collective tracing helpers. They only produce code when ENABLE_COLLTRACE is defined.
// INC_COLL_TRACE reserves a trace slot: it atomically increments
// ncclShmem.collTraceTail->tail, uses the previous value modulo COLLTRACE_NUM_ITEMS as the
// slot index, and stamps that slot with wall_clock64() and blockIdx.x. It declares the
// locals `pos` and `collTrace` in the scope where it is expanded.
#ifdef ENABLE_COLLTRACE
|
|
#define INC_COLL_TRACE \
|
|
uint32_t pos = atomicAdd(&ncclShmem.collTraceTail->tail, 1)%COLLTRACE_NUM_ITEMS; \
|
|
struct ncclCollTrace* collTrace = ncclShmem.collTrace+pos; \
|
|
collTrace->timeStamp = wall_clock64(); \
|
|
collTrace->bid = blockIdx.x;
|
|
// TODO: switch to atomicInc after llvm crash is fixed
|
|
// uint32_t pos = atomicInc(&ncclShmem.collTraceTail->tail, COLLTRACE_NUM_ITEMS)
|
|
|
|
// traceData(data2, data4, data8_0, data8_1): reserves a slot with INC_COLL_TRACE and stores
// the four arguments into funcIndex, data_0, opCount and data_1 respectively, tagging the
// record with ncclCollTraceDataType. Expands to a brace-enclosed block.
#define traceData(data2, data4, data8_0, data8_1) { \
|
|
INC_COLL_TRACE \
|
|
collTrace->funcIndex = data2; \
|
|
collTrace->data_0 = data4; \
|
|
collTrace->opCount = data8_0; \
|
|
collTrace->data_1 = data8_1; \
|
|
collTrace->type = ncclCollTraceDataType; \
|
|
}
|
|
#else
|
|
// Tracing disabled: traceData expands to nothing.
#define traceData(data2, data4, data8_0, data8_1)
|
|
#endif
|
|
|
|
// Block-level barrier used by the MSCCL interpreter.
//   nthreads: number of threads taking part in the barrier.
// On HIP builds the barrier is only valid for nthreads == NCCL_MAX_NTHREADS (asserted);
// on other builds the count is passed to the PTX barrier instruction.
inline __device__ static void barrier(int nthreads) {
|
|
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
|
assert(nthreads == NCCL_MAX_NTHREADS);
|
|
// Wait for the vmcnt and lgkmcnt counters to reach zero, then issue s_barrier.
// NOTE(review): presumably this drains outstanding memory operations before the workgroup barrier — confirm against the ISA guide.
__asm__ __volatile__("s_waitcnt vmcnt(0) lgkmcnt(0)\ns_barrier");
|
|
#else
|
|
// Operand %1 is the constant 15 (barrier id) and %0 is nthreads, i.e. "bar.sync 15, nthreads".
asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(15));
|
|
#endif
|
|
}
|
|
|
|
// Copy 8-byte aligned data. You must call with at least `(bytes+7)/8` threads.
|
|
inline __device__ static void copyToShmem8(int tid, void* dst, void const* src, int bytes) {
|
|
int offset = sizeof(uint32_t) * tid;
|
|
if (offset < bytes) {
|
|
uint32_t *src2 = (uint32_t*)((char const*)src + offset);
|
|
uint32_t *dst2 = (uint32_t*)((char*)dst + offset);
|
|
*dst2 = *src2;
|
|
offset += WARP_SIZE*sizeof(uint32_t);
|
|
}
|
|
}
|
|
|
|
// Cooperatively copy `size` 32-bit words from `src` to `dst`.
//   dst, src : word arrays holding at least `size` elements.
//   size     : number of uint32_t words to copy.
//   tid      : index of the calling thread within the copying group.
//   nthreads : number of threads in the group; thread `tid` copies words
//              tid, tid + nthreads, tid + 2*nthreads, ...
// The index is a uint64_t so that it has the same type as `size`: no signed/unsigned
// comparison and no int overflow for sizes above INT_MAX.
__device__ __forceinline__ static void threadBlockCopy(
  uint32_t *dst, uint32_t const *src, uint64_t size, int tid, int nthreads) {
  for (uint64_t i = tid; i < size; i += nthreads) {
    dst[i] = src[i];
  }
}
|
|
|
|
// Reduction loop body shared by the cases of mscclReduce.
// For r in [0, numloops): computes srcOffset from srcBaseOffset plus
// reductionSrcOffsets[t->reductionPointer + r] * sizePerMscclChunk, loads one
// BytePerPack-byte pack from srcPointer + srcOffset with ld_volatile_global, and folds it
// into the accumulator `o` with applyReduce(redFn, ...).
// The macro is not self-contained: it reads and writes the names srcOffset, srcBaseOffset,
// reduceInput, o, t, srcPointer, sizePerMscclChunk and redFn from the expanding scope.
#define MSCCL_REDUCE_UNROLL_LOOP_A(numloops, BytePerPack) \
|
|
for (int r = 0; r < numloops; r++) { \
|
|
srcOffset = srcBaseOffset + (ssize_t)mscclShmem.mscclTB.reductionSrcOffsets[t->reductionPointer+r] * sizePerMscclChunk; \
|
|
reduceInput = ld_volatile_global<BytePerPack>((uintptr_t)(srcPointer + srcOffset)); \
|
|
o = applyReduce(redFn, reduceInput, o); \
|
|
}
|
|
|
|
// Reduces one BytePerPack-byte pack: loads the pack at
// dstPointer + dstOffset + currIdx*elemsPerPack, folds `numReductions` source packs into it
// with redFn, and stores the result back to the same destination address.
//   c                 : chunk index within the current transmission (enters the source base offset).
//   numReductions     : number of source packs to fold in.
//   currIdx           : pack index handled by the calling thread (a pack is BytePerPack/sizeof(T) elements).
//   sizePerMscclChunk : chunk stride in elements, used to scale chunk indices into offsets.
//   redFn             : reduction functor passed to applyReduce.
//   t                 : current transmission; its reductionPointer selects the source-offset entries.
//   gridOffset        : element offset of the current grid iteration.
//   srcOffset         : in/out reference; overwritten on every loop iteration by the macro.
//   dstOffset         : element offset of the destination chunk.
//   srcPointer, dstPointer : base pointers of the source and destination buffers.
// The local names below are referenced by MSCCL_REDUCE_UNROLL_LOOP_A and must keep their spelling.
template<typename T, typename RedOp, int BytePerPack>
|
|
__device__ __forceinline__ static void mscclReduce(int c, int numReductions, int currIdx, ssize_t sizePerMscclChunk, RedOp redFn,
|
|
struct mscclTransmission* t, ssize_t gridOffset, ssize_t &srcOffset, ssize_t dstOffset, T *srcPointer, T *dstPointer) {
|
|
const int elemsPerPack = BytePerPack/sizeof(T);
|
|
T* dstIndex = dstPointer + dstOffset + currIdx*elemsPerPack;
|
|
BytePack<BytePerPack> reduceInput;
|
|
// The accumulator starts from the current destination contents.
BytePack<BytePerPack> o = ld_volatile_global<BytePerPack>((uintptr_t)dstIndex);
|
|
ssize_t srcBaseOffset = gridOffset + (ssize_t)c * sizePerMscclChunk + currIdx*elemsPerPack;
|
|
// Literal trip counts (7, and 15 on gfx90a) are given their own cases with `#pragma unroll`;
// every other count goes through the runtime-bounded loop in `default`.
switch (numReductions) {
|
|
case 7:
|
|
#pragma unroll
|
|
MSCCL_REDUCE_UNROLL_LOOP_A(7, BytePerPack);
|
|
break;
|
|
#if defined(__gfx90a__)
|
|
case 15:
|
|
#pragma unroll
|
|
MSCCL_REDUCE_UNROLL_LOOP_A(15, BytePerPack);
|
|
break;
|
|
#endif
|
|
default:
|
|
MSCCL_REDUCE_UNROLL_LOOP_A(numReductions, BytePerPack);
|
|
break;
|
|
}
|
|
st_global<BytePerPack>((uintptr_t)dstIndex, o);
|
|
}
|
|
|
|
|
|
// Interprets the MSCCL program of one threadblock: copies algo->mscclTBs[blockIdx.x] and
// work[blockIdx.x] into shared memory, then executes the block's transmissions for every
// grid iteration of the work item.
// Template parameters:
//   T       : element type of the user buffers.
//   RedOp   : reduction functor type, constructed from work.redOpArg.
//   Proto   : protocol type handed to Primitives (LL, LL128 or Simple are checked via Proto::Id).
//   fullOps : when true, the fused recv*Send / recvReduceCopy op types and the traceData calls are enabled.
// Parameters:
//   comm : device communicator, copied into ncclShmem.comm.
//   algo : algorithm description; only mscclTBs[blockIdx.x] is read here.
//   work : array of work items indexed by blockIdx.x.
// NOTE(review): nthreads is fixed to NCCL_MAX_NTHREADS, so the kernel is presumably expected
// to be launched with exactly that many threads per block — confirm against the launcher.
template<typename T, typename RedOp, typename Proto, bool fullOps>
|
|
__device__ __forceinline__ void mscclRunInterpreter(
|
|
struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) {
|
|
const int tid = threadIdx.x;
|
|
const int bid = blockIdx.x;
|
|
const int nthreads = NCCL_MAX_NTHREADS;
|
|
|
|
#if defined(ENABLE_NPKIT)
|
|
uint64_t timestamp_entry = 0;
|
|
if (tid == 0) {
|
|
timestamp_entry = NPKIT_GET_GPU_TIMESTAMP();
|
|
}
|
|
#endif
|
|
// initialize mscclShmem.mscclTB
|
|
threadBlockCopy(
|
|
(uint32_t *)&mscclShmem.mscclTB, (uint32_t *)(algo->mscclTBs + bid),
|
|
sizeof(struct mscclThreadBlock) / sizeof(uint32_t), tid, nthreads);
|
|
__synclds(); // publish mscclShmem.mscclTB.channelId
|
|
|
|
// initialize ncclShmem and mscclShmem.work
|
|
int channelId = mscclShmem.mscclTB.channelId;
|
|
{
|
|
void *dst, *src;
|
|
int bytes = 0;
|
|
// Use first 3 warps to load comm, channel, and work into shmem
|
|
switch (tid/WARP_SIZE) {
|
|
case 0:
|
|
dst = &ncclShmem.comm;
|
|
src = comm;
|
|
bytes = sizeof(ncclDevComm);
|
|
break;
|
|
case 1:
|
|
// Get address of channel without incurring indirect load from ncclDevComm::channels
|
|
dst = &ncclShmem.channel;
|
|
src = &((ncclDevCommAndChannels*)comm)->channels[channelId];
|
|
bytes = sizeof(ncclDevChannel);
|
|
break;
|
|
case 2:
|
|
dst = &mscclShmem.work;
|
|
src = work + blockIdx.x;
|
|
bytes = sizeof(mscclWork);
|
|
break;
|
|
case 3:
|
|
/* set abort flag to 0 */
|
|
if (tid%WARP_SIZE == 0) ncclShmem.aborted = 0;
|
|
#ifdef ENABLE_COLLTRACE
|
|
else if (tid%WARP_SIZE == 1) ncclShmem.collTrace = comm->collTrace + COLLTRACE_NUM_ITEMS*channelId;
|
|
else if (tid%WARP_SIZE == 2) ncclShmem.collTraceTail = comm->collTraceTail + channelId;
|
|
#endif
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
// Warps other than 0-2 arrive here with bytes == 0 (dst/src unset), so they copy nothing.
copyToShmem8(tid%WARP_SIZE, dst, src, bytes);
|
|
}
|
|
|
|
#if defined(ENABLE_NPKIT)
|
|
int npKitCtxIdx = bid;
|
|
int xcc_id = 0;
|
|
if (tid == 0) {
|
|
ncclShmem.event_buffer_head = 0;
|
|
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
|
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_XCC_ID)" : "=s" (xcc_id));
|
|
#endif
|
|
}
|
|
#endif
|
|
__synclds(); // publish shmem
|
|
|
|
if (fullOps && tid == 0) {
|
|
traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0);
|
|
}
|
|
|
|
// Thread 0 stores work.workFifoDoneAck through the work.workFifoDone pointer.
// NOTE(review): presumably this tells the host the work-fifo entry has been consumed — confirm.
if (tid == 0)
|
|
*mscclShmem.work.workFifoDone = mscclShmem.work.workFifoDoneAck;
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU)
|
|
if (tid == 0) {
|
|
uint64_t* cpuTimestamp = ncclShmem.comm.cpuTimestamp;
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_CPU, 0, xcc_id, *cpuTimestamp);
|
|
}
|
|
#endif
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_TIME_SYNC_GPU, 0, xcc_id, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
|
|
// User pointers for primitives
|
|
T* thisInput = (T*)mscclShmem.work.sendBuff;
|
|
T* thisOutput = (T*)mscclShmem.work.recvBuff;
|
|
T* thisScratch = (T*)mscclShmem.work.scratchBuffer;
|
|
int recvPeer = mscclShmem.mscclTB.recvPeer;
|
|
int sendPeer = mscclShmem.mscclTB.sendPeer;
|
|
|
|
// Chunk size in elements: bytes per step divided by sizeof(T), multiplied by MSCCL_CHUNKSTEPS for the Simple protocol only.
const ssize_t chunkSize = int(Proto::calcBytePerStep()/sizeof(T) * (Proto::Id == NCCL_PROTO_SIMPLE ? MSCCL_CHUNKSTEPS : 1));
|
|
// minChunkSize is assigned only for LL / LL128 and is read only in the non-Simple branch of the grid loop below.
int minChunkSize;
|
|
if (Proto::Id == NCCL_PROTO_LL)
|
|
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T));
|
|
if (Proto::Id == NCCL_PROTO_LL128) {
|
|
// We should not need the final /2 but it makes performance much, much smoother. Might be a bug somewhere.
|
|
minChunkSize = nthreads*(Proto::calcBytePerGrain()/sizeof(T))/2;
|
|
}
|
|
|
|
RedOp redFn(mscclShmem.work.redOpArg);
|
|
Primitives<T, RedOp, FanAsymmetric<1,1>, 1, Proto, 0> prims
|
|
(tid, nthreads, &recvPeer, &sendPeer, thisInput, thisOutput, mscclShmem.work.redOpArg);
|
|
|
|
#if defined(ENABLE_NPKIT)
|
|
if (tid == 0) {
|
|
prims.npKitCtxIdx = npKitCtxIdx;
|
|
}
|
|
#endif
|
|
|
|
const ssize_t sizePerMscclChunk = mscclShmem.work.sizePerMscclChunk;
|
|
uint32_t maxAllowedCount = mscclShmem.work.maxAllowedCount;
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RUN_ENTRY)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_ENTRY, mscclShmem.work.sizePerMscclChunk*mscclShmem.work.nChunksPerLoop, xcc_id, timestamp_entry);
|
|
}
|
|
#endif
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_INIT_ENTRY)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_INIT_ENTRY, 0, xcc_id, timestamp_entry);
|
|
}
|
|
#endif
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_INIT_EXIT)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_INIT_EXIT, 0, xcc_id, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
|
|
// msccl flags all start out with 0. this is used as a part of the flag to make sure different work items deal with different synchronization flags
|
|
// this still needs more work. when we make a way around the queue, the flag might have been set to undesired values. will be fixed in subsequent versions.
|
|
const int64_t workIndex = mscclShmem.work.workIndex;
|
|
volatile struct mscclFlag* mscclFlags = mscclShmem.work.syncFlags;
|
|
// Grid loop: walk [0, sizePerMscclChunk) in steps of chunkSize; `iter` counts the steps and is the middle component of the sync flag.
for (ssize_t gridOffset = 0, iter = 0; gridOffset < sizePerMscclChunk; gridOffset += chunkSize, iter++) {
|
|
ssize_t realChunkSize;
|
|
if (Proto::Id == NCCL_PROTO_SIMPLE) {
|
|
realChunkSize = min(chunkSize, sizePerMscclChunk-gridOffset);
|
|
realChunkSize = roundUp(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
|
|
}
|
|
else
|
|
realChunkSize = min(chunkSize, divUp(sizePerMscclChunk-gridOffset, minChunkSize)*minChunkSize);
|
|
realChunkSize = int(realChunkSize);
|
|
// nelem is the element count actually left in this iteration (never beyond the end of the chunk).
int nelem = min(realChunkSize, sizePerMscclChunk-gridOffset);
|
|
|
|
ssize_t srcOffset, dstOffset;
|
|
T *srcPointer, *dstPointer;
|
|
int step = 0;
|
|
// Execute the transmissions of this threadblock in program order.
for (int i = 0; i < mscclShmem.mscclTB.nSteps; i++){
|
|
struct mscclTransmission* t = &mscclShmem.mscclTB.transmissions[i];
|
|
// first wait if there is a dependence
|
|
int16_t numDependencies = t->numDependencies;
|
|
if (numDependencies > 0){
|
|
// One thread per dependency spins until the flag of the block it depends on has reached
// (workIndex, iter, dependentStep) and still belongs to the same workIndex.
if (tid < numDependencies) {
|
|
int16_t dependentPointer = t->dependencePointer;
|
|
int8_t dependentBid = mscclShmem.mscclTB.dependentBid[dependentPointer+tid];
|
|
int16_t dependentStep = mscclShmem.mscclTB.dependentStep[dependentPointer+tid];
|
|
uint64_t goalFlag = COMPUTE_FLAG(workIndex, iter, dependentStep);
|
|
while (true){
|
|
uint64_t curFlag = __atomic_load_n(&(mscclFlags + dependentBid)->flag, __ATOMIC_RELAXED);
|
|
if (curFlag >= goalFlag && GET_WORKINDEX_FROM_FLAG(curFlag) == workIndex) break;
|
|
}
|
|
}
|
|
// The step counter is advanced by numDependencies-1 before all threads meet at the barrier.
step += numDependencies-1;
|
|
barrier(nthreads);
|
|
}
|
|
|
|
// Map the buffer selectors of the transmission onto input / output / scratch base pointers.
srcPointer = (t->srcBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->srcBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
|
|
dstPointer = (t->dstBuffer == MSCCL_INPUT_BUFFER) ? thisInput : ((t->dstBuffer == MSCCL_OUTPUT_BUFFER) ? thisOutput : thisScratch);
|
|
prims.setDataPtrs(srcPointer, dstPointer);
|
|
|
|
int count = t->count;
|
|
// Handle the t->count chunks of this transmission, at most maxAllowedCount chunks per pass.
for (int c = 0; c < count; c += maxAllowedCount) {
|
|
srcOffset = gridOffset + (ssize_t) (t->srcOffset+c) * sizePerMscclChunk;
|
|
dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk;
|
|
int thisCount = min(maxAllowedCount, count - c);
|
|
int thisNelem = nelem * thisCount;
|
|
if (t->type == MSCCL_SEND) {
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_ENTRY)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
|
|
#endif
|
|
prims.send(srcOffset, thisNelem); // LL.send is the only situation where there is no barrier at the end.
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_SEND_EXIT)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_SEND_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
}
|
|
else if (t->type == MSCCL_RECV) {
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_ENTRY)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
prims.recv(dstOffset, thisNelem);
|
|
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_EXIT)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
}
|
|
else if (t->type == MSCCL_REDUCE) {
|
|
// Local reduction: each thread folds numReductions source packs into the destination via mscclReduce.
int numReductions = t->numReductions;
|
|
int currIdx = tid;
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_ENTRY)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
dstOffset = gridOffset + (ssize_t) (t->dstOffset+c) * sizePerMscclChunk;
|
|
// process 16-byte packed elements
|
|
const int elemsPerPack = 16/sizeof(T);
|
|
while (currIdx < thisNelem/elemsPerPack) {
|
|
mscclReduce<T, RedOp, 16>(c, numReductions, currIdx, sizePerMscclChunk, redFn, t, gridOffset, srcOffset, dstOffset, srcPointer, dstPointer);
|
|
currIdx += nthreads;
|
|
}
|
|
// process remaining elements
|
|
// The tail (fewer than elemsPerPack elements) is reduced one element per thread, starting right after the last full pack.
currIdx = tid + (thisNelem/elemsPerPack)*elemsPerPack;
|
|
if (currIdx < thisNelem) {
|
|
mscclReduce<T, RedOp, sizeof(T)>(c, numReductions, currIdx, sizePerMscclChunk, redFn, t, gridOffset, srcOffset, dstOffset, srcPointer, dstPointer);
|
|
}
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_REDUCE_EXIT)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_REDUCE_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
barrier(nthreads);
|
|
if (c == 0) step += (numReductions-1); // only advance step once!
|
|
}
|
|
else if (fullOps && t->type == MSCCL_RECV_COPY_SEND)
|
|
prims.recvCopySend(dstOffset, thisNelem);
|
|
else if (fullOps && t->type == MSCCL_RECV_REDUCE_SEND)
|
|
prims.recvReduceSend(srcOffset, thisNelem);
|
|
else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY_SEND)
|
|
prims.recvReduceCopySend(srcOffset, dstOffset, thisNelem);
|
|
else if (fullOps && t->type == MSCCL_RECV_REDUCE_COPY) {
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_ENTRY, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
prims.recvReduceCopy(srcOffset, dstOffset, thisNelem);
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_EXIT)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RECV_REDUCE_COPY_EXIT, thisNelem*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
}
|
|
else if (t->type == MSCCL_LOCAL_COPY)
|
|
prims.localCopy(srcPointer+srcOffset, dstPointer+dstOffset, thisNelem);
|
|
else
|
|
// Any other op type (including the fused ops when fullOps is false) leaves the interpreter
// immediately; nothing after the loops (NPKIT flush, closing traceData) runs in that case.
return;
|
|
}
|
|
// When t->hasDependence is set, the last thread publishes this block's progress flag
// (workIndex, iter, step). Release ordering is used for REDUCE and RECV steps whose
// destination is not the scratch buffer, relaxed ordering otherwise.
// NOTE(review): hasDependence presumably means another threadblock waits on this step — confirm.
if (t->hasDependence && tid == nthreads-1)
|
|
__atomic_store_n(&mscclFlags[bid].flag, (uint64_t) COMPUTE_FLAG(workIndex, iter, step), ((t->type == MSCCL_REDUCE || t->type == MSCCL_RECV) && (t->dstBuffer != MSCCL_SCRATCH_BUFFER)) ? __ATOMIC_RELEASE : __ATOMIC_RELAXED);
|
|
step++;
|
|
}
|
|
}
|
|
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_MSCCL_RUN_EXIT)
|
|
if (tid == 0) {
|
|
NpKit::CollectGpuEventLDS(NPKIT_EVENT_MSCCL_RUN_EXIT, mscclShmem.work.sizePerMscclChunk*mscclShmem.work.nChunksPerLoop, xcc_id, NPKIT_GET_GPU_TIMESTAMP());
|
|
}
|
|
#endif
|
|
#if defined(ENABLE_NPKIT)
|
|
// Flush the events gathered in ncclShmem.event_buffer into this block's NpKit context and advance the context's head.
__synclds();
|
|
NpKitEventCollectContext* ctx = ncclShmem.comm.npKitEventCollectContexts + npKitCtxIdx;
|
|
copyToShmem16(tid, ctx->event_buffer+ctx->event_buffer_head, ncclShmem.event_buffer, sizeof(NpKitEvent)*ncclShmem.event_buffer_head);
|
|
if (tid == 0) ctx->event_buffer_head += ncclShmem.event_buffer_head;
|
|
#endif
|
|
|
|
if (fullOps && tid == 0) {
|
|
traceData(__LINE__, mscclShmem.work.fnIndex, (uint64_t)mscclShmem.work.sendBuff, 0);
|
|
}
|
|
}
|
|
|
|
// Defines the three __global__ kernel entry points (LL, LL128, Simple) for one combination of
// reduction op, element type and fullOps flag. Each kernel forwards (comm, algo, work) to
// mscclRunInterpreter instantiated with Func<devredop><type> and the matching protocol type.
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, type, fullOps) \
|
|
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
|
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL, fullOps>(comm, algo, work); \
|
|
} \
|
|
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, LL128, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
|
mscclRunInterpreter<type, Func##devredop<type>, ProtoLL128, fullOps>(comm, algo, work); \
|
|
} \
|
|
__global__ void MSCCL_KERNEL_ENTRY_NAME(devredop, type, Simple, fullOps)(struct ncclDevComm* comm, struct mscclAlgo* algo, struct mscclWork* work) { \
|
|
mscclRunInterpreter<type, Func##devredop<type>, ProtoSimple<MSCCL_CHUNKSTEPS/MSCCL_SLICESTEPS, MSCCL_SLICESTEPS>, fullOps>(comm, algo, work); \
|
|
}
|
|
|
|
// Instantiates the kernel entry points of one reduction op for all twelve element types
// listed below (integer, floating-point, bfloat16 and the two 8-bit float types).
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(devredop, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, half, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, float, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, double, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, hip_bfloat16, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_float8, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, rccl_bfloat8, fullOps)
|
|
|
|
// Same as MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP but restricted to the six integer element types.
// It is not used by MSCCL_IMPL_KERNEL_ENTRY_FUNC below.
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_NOFLOAT(devredop, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int8_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint8_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int32_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint32_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, int64_t, fullOps) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP_TYPE(devredop, uint64_t, fullOps)
|
|
|
|
// Emits every kernel entry point defined by this header: the Sum, Prod and MinMax reduction
// ops over all element types, each with fullOps == false.
#define MSCCL_IMPL_KERNEL_ENTRY_FUNC() \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Sum, false) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(Prod, false) \
|
|
MSCCL_IMPL_KERNEL_ENTRY_FUNC_DEVREDOP(MinMax, false)
|
|
|
|
#endif
|