Merge remote-tracking branch 'nccl/master' into develop
[ROCm/rccl commit: 3a919c1f49]
Этот коммит содержится в:
@@ -133,6 +133,7 @@ else()
|
||||
src/collectives/device/broadcast.cu
|
||||
src/collectives/device/reduce_scatter.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/onerank_reduce.cu
|
||||
src/collectives/device/functions.cu)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
##### version
|
||||
NCCL_MAJOR := 2
|
||||
NCCL_MINOR := 10
|
||||
NCCL_PATCH := 3
|
||||
NCCL_MINOR := 11
|
||||
NCCL_PATCH := 4
|
||||
NCCL_SUFFIX :=
|
||||
PKG_REVISION := 1
|
||||
|
||||
@@ -232,7 +232,7 @@ struct unexConn {
|
||||
struct remAllocState {
|
||||
int cudaDev;
|
||||
int listenFd;
|
||||
int stop;
|
||||
volatile int stop;
|
||||
};
|
||||
|
||||
struct extState {
|
||||
@@ -287,7 +287,7 @@ void* ncclRemoteMemAllocationService(void* args) {
|
||||
for (int s=0; s<MAX_SEGMENTS; s++) segments[s] = NULL;
|
||||
for (int s=0; s<MAX_SEGMENTS; s++) {
|
||||
pollfds[s].fd = -1;
|
||||
pollfds[s].events = POLLHUP;
|
||||
pollfds[s].events = POLLIN;
|
||||
}
|
||||
pollfds[MAX_SEGMENTS].fd = state->listenFd;
|
||||
pollfds[MAX_SEGMENTS].events = POLLIN;
|
||||
@@ -315,7 +315,7 @@ void* ncclRemoteMemAllocationService(void* args) {
|
||||
}
|
||||
}
|
||||
for (int s=0; s<MAX_SEGMENTS; s++) {
|
||||
if (pollfds[s].revents & POLLHUP) {
|
||||
if (pollfds[s].revents & (POLLIN|POLLHUP)) {
|
||||
if (hipFree(segments[s]) != hipSuccess) {
|
||||
WARN("[Rem Allocator] hipFree %p failed", segments[s]);
|
||||
}
|
||||
@@ -462,7 +462,7 @@ ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int s
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks) {
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag) {
|
||||
if (nranks == 1) return ncclSuccess;
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d tag %x - ENTER", rank, nranks, tag);
|
||||
|
||||
@@ -483,6 +483,22 @@ ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, in
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, int nranks, void* allData, int size) {
|
||||
if (nranks == 1) return ncclSuccess;
|
||||
char* data = (char*)allData;
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d size %d - ENTER", rank, nranks, size);
|
||||
|
||||
for (int i=1; i<nranks; i++) {
|
||||
int src = (rank - i + nranks) % nranks;
|
||||
int dst = (rank + i) % nranks;
|
||||
NCCLCHECK(bootstrapSend(commState, ranks[dst], /*tag=*/i, data+rank*size, size));
|
||||
NCCLCHECK(bootstrapRecv(commState, ranks[src], /*tag=*/i, data+src*size, size));
|
||||
}
|
||||
|
||||
TRACE(NCCL_INIT, "rank %d nranks %d size %d - DONE", rank, nranks, size);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t unexpectedEnqueue(struct extState* state, int peer, int tag, int fd, union socketAddress *addr) {
|
||||
// New unexpected connection
|
||||
struct unexConn* unex;
|
||||
|
||||
@@ -46,8 +46,6 @@ __device__ void AllReduceCliqueSplitKernel(struct ncclWorkElem* args)
|
||||
size_t const currBlockStop = min(currBlockStart + perBlockN, N);
|
||||
size_t const blockN = currBlockStop - currBlockStart;
|
||||
|
||||
FUNC redOp(FuncTraits<FUNC>().make(args->comm->nRanks));
|
||||
|
||||
if (blockN > 0)
|
||||
{
|
||||
// Prepare input / output subarrays
|
||||
@@ -65,8 +63,8 @@ __device__ void AllReduceCliqueSplitKernel(struct ncclWorkElem* args)
|
||||
|
||||
// Perform the reduction
|
||||
#define ALL_REDUCE_CLIQUE_UNROLL 1
|
||||
ReduceOrCopyMulti<ALL_REDUCE_CLIQUE_UNROLL, FUNC, T, NUM_RANKS, NUM_RANKS, NUM_RANKS, NUM_RANKS>(
|
||||
threadIdx.x, blockDim.x, redOp, NUM_RANKS, true, NUM_RANKS, srcs, NUM_RANKS, dsts, blockN);
|
||||
ReduceOrCopyMulti<ALL_REDUCE_CLIQUE_UNROLL, FUNC, T, NUM_RANKS, NUM_RANKS, NUM_RANKS, NUM_RANKS, 0>(
|
||||
threadIdx.x, blockDim.x, nullptr, false, NUM_RANKS, srcs, NUM_RANKS, dsts, blockN);
|
||||
}
|
||||
|
||||
// Even if there was nothing for this GPU to do, it must participate in a barrier
|
||||
|
||||
@@ -274,9 +274,9 @@ bool CliqueManager::IsSupported(ncclFunc_t const coll,
|
||||
{
|
||||
if (m_cliqueMode == CLIQUE_DISABLED) return false;
|
||||
|
||||
// Filter based on total input size for each collective type
|
||||
// Filter based on total input size for each collective type and ops sum/prod/min/max
|
||||
size_t totalBytes = count * ncclTypeSize(datatype);
|
||||
if (coll == ncclFuncAllReduce && (totalBytes <= m_allReduceByteLimit)) return true;
|
||||
if (coll == ncclFuncAllReduce && (totalBytes <= m_allReduceByteLimit) && op < ncclAvg) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ include ../../../makefiles/version.mk
|
||||
BUILDDIR ?= $(abspath ../../../build)
|
||||
OBJDIR := $(BUILDDIR)/obj/collectives/device
|
||||
|
||||
LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu sendrecv.cu
|
||||
LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu sendrecv.cu onerank_reduce.cu
|
||||
|
||||
LIBSRCFILES += functions.cu
|
||||
|
||||
@@ -36,7 +36,7 @@ $(RULESFILE) :
|
||||
|
||||
-include $(RULESFILE)
|
||||
|
||||
LIBOBJ := $(GENOBJS) $(OBJDIR)/functions.o
|
||||
LIBOBJ := $(GENOBJS) $(OBJDIR)/functions.o $(OBJDIR)/onerank_reduce.o
|
||||
|
||||
-include $(DEPFILES)
|
||||
|
||||
@@ -63,6 +63,11 @@ $(OBJDIR)/functions.o : functions.cu $(OBJDIR)/functions.dep
|
||||
mkdir -p `dirname $@`
|
||||
$(NVCC) $(NVCUFLAGS) -dc $< -o $@
|
||||
|
||||
$(OBJDIR)/onerank_reduce.o : onerank_reduce.cu $(OBJDIR)/onerank_reduce.dep
|
||||
@printf "Compiling %-35s > %s\n" $< $@
|
||||
mkdir -p `dirname $@`
|
||||
$(NVCC) $(NVCUFLAGS) -dc $< -o $@
|
||||
|
||||
# ... and create the device-side linked object with all those.
|
||||
$(DEVOBJ) : $(LIBOBJ)
|
||||
$(NVCC) $(NVCUFLAGS) -dlink $^ -o $@
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -28,7 +28,7 @@ namespace {
|
||||
T *inputBuf = (T*)args->sendbuff;
|
||||
T *outputBuf = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, 0, args->coll.connIndex);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->coll.redOpArg, args->coll.connIndex << 16);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
@@ -79,7 +79,7 @@ namespace {
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<ALLGATHER_CHUNKSTEPS/ALLGATHER_SLICESTEPS, ALLGATHER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
@@ -87,14 +87,14 @@ struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SI
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllGather, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -38,7 +38,7 @@ namespace {
|
||||
}
|
||||
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto> prims
|
||||
(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex);
|
||||
(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->coll.redOpArg, args->coll.connIndex << 16);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
@@ -110,12 +110,12 @@ namespace {
|
||||
ACCUMULATE_COUNTER(directRecv);
|
||||
}
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0 && args->op.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk);
|
||||
if (tid == 0 && args->coll.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runTreeUpDown(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runTreeUpDown(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -135,7 +135,7 @@ namespace {
|
||||
|
||||
{ // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto> prims
|
||||
(tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff);
|
||||
(tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->coll.redOpArg);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
@@ -161,7 +161,7 @@ namespace {
|
||||
|
||||
{ // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto> prims
|
||||
(tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff);
|
||||
(tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg);
|
||||
if (tree->up == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
@@ -187,7 +187,7 @@ namespace {
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runTreeSplit(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runTreeSplit(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -219,7 +219,7 @@ namespace {
|
||||
if (tree->up == -1) {
|
||||
// Reduce and broadcast. Max number of recv is 3, max number of send is 3
|
||||
Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff);
|
||||
prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
@@ -236,7 +236,7 @@ namespace {
|
||||
* but the ctor above for tree roots would be DirectRecv=0 DirectSend=1.
|
||||
*/
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth);
|
||||
prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->coll.redOpArg, 0*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
@@ -255,7 +255,7 @@ namespace {
|
||||
else {
|
||||
// Broadcast down. Max number of recv is 1, max number of send is 3 (binary tree + local)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg, 1*Proto::MaxGroupWidth);
|
||||
if (tree->down[0] == -1) {
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*int(chunkSize);
|
||||
@@ -276,7 +276,7 @@ namespace {
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
@@ -284,14 +284,14 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SI
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
static constexpr int COLLNET_COPY_THREADS = 64;
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -315,27 +315,37 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO
|
||||
|
||||
if (tid >= tidStartScatter && tid < tidStartReduce && hasUp) {
|
||||
// Scatter
|
||||
int group = (2*Proto::MaxGroupWidth) | (1<<16);
|
||||
Primitives<T, RedOp, FanAsymmetric<0, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, 2*Proto::MaxGroupWidth);
|
||||
prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize;
|
||||
int nelem = min(tree->nHeads*chunkSize, size-offset);
|
||||
prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
if (args->regUsed) {
|
||||
prims.directScatter(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
} else {
|
||||
prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
}
|
||||
}
|
||||
} else if (tid >= tidStartReduce && tree->out != -1) {
|
||||
int group = (3*Proto::MaxGroupWidth) | (1<<16);
|
||||
if (hasDn) {
|
||||
// Reduce, send to network
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth);
|
||||
prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
if (args->regUsed) {
|
||||
prims.directRecvReduceSend(offset, offset, nelem);
|
||||
} else {
|
||||
prims.recvReduceSend(offset, nelem);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Directly send to network
|
||||
Primitives<T, RedOp, FanAsymmetric<0, 1>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth);
|
||||
prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, args->coll.redOpArg, group);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
@@ -344,27 +354,29 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO
|
||||
}
|
||||
} else if (tid < tidStartBcast && hasUp) {
|
||||
// Gather
|
||||
int group = (0*Proto::MaxGroupWidth) | (0<<16);
|
||||
Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DIRECT_ARITY, 0>, /*Direct=*/0, Proto>
|
||||
prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth);
|
||||
prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize;
|
||||
int nelem = min(tree->nHeads*chunkSize, size-offset);
|
||||
prims.gather(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
prims.directGather(offset, nelem, chunkSize, tree->headRank, tree->shift);
|
||||
}
|
||||
} else if (tid >= tidStartBcast && tid < tidStartScatter && tree->out != -1) {
|
||||
int group = (1*Proto::MaxGroupWidth) | (0<<16);
|
||||
if (hasDn) {
|
||||
// Recv from network, broadcast
|
||||
Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DIRECT_ARITY>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
prims.recvCopySend(offset, nelem, /*postOp=*/true);
|
||||
prims.recvCopyDirectSend(offset, offset, nelem, /*postOp=*/true);
|
||||
}
|
||||
} else {
|
||||
// Recv from network (no post thread needed)
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto>
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth);
|
||||
prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, args->coll.redOpArg, group);
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize;
|
||||
int nelem = min(chunkSize, size-offset);
|
||||
@@ -377,28 +389,28 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runTreeUpDown<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, RedOp, T, args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL128> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, RedOp, T, args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -32,7 +32,7 @@ namespace {
|
||||
T *inputBuf = (T*)args->sendbuff;
|
||||
T *outputBuf = (T*)args->recvbuff;
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, 0, args->coll.connIndex);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->coll.redOpArg, args->coll.connIndex << 16);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
@@ -70,14 +70,14 @@ namespace {
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0 && args->op.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk);
|
||||
if (tid == 0 && args->coll.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<BROADCAST_CHUNKSTEPS/BROADCAST_SLICESTEPS, BROADCAST_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
@@ -85,14 +85,14 @@ struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SI
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncBroadcast, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -12,97 +12,92 @@
|
||||
#include "devcomm.h"
|
||||
|
||||
#define COLL_UNROLL 2
|
||||
#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY
|
||||
#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree
|
||||
|
||||
#define __syncwarp()
|
||||
|
||||
#define NCCL_FUNC5(func, algo, redop, type) \
|
||||
NCCL_FUNC_NAME(func, algo, LL, redop, type), \
|
||||
NCCL_FUNC_NAME(func, algo, LL, redop, type), \
|
||||
NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type)
|
||||
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
|
||||
|
||||
#define NCCL_FUNC4(func, redop, type) \
|
||||
NCCL_FUNC5(func, TREE, redop, type), \
|
||||
NCCL_FUNC5(func, RING, redop, type), \
|
||||
NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
#define NCCL_FUNC4(func, devredop, type, nullify) \
|
||||
NCCL_FUNC5(func, TREE, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, RING, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, COLLNET, devredop, type, nullify)
|
||||
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, uint8_t), \
|
||||
NCCL_FUNC4(func, redop, int32_t), \
|
||||
NCCL_FUNC4(func, redop, uint32_t), \
|
||||
NCCL_FUNC4(func, redop, int64_t), \
|
||||
NCCL_FUNC4(func, redop, uint64_t), \
|
||||
NCCL_FUNC4(func, redop, half), \
|
||||
NCCL_FUNC4(func, redop, float), \
|
||||
NCCL_FUNC4(func, redop, double), \
|
||||
NCCL_FUNC4(func, redop, rccl_bfloat16)
|
||||
#define NCCL_FUNCS3B(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t)
|
||||
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, double, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0)
|
||||
|
||||
// Must be consistent with ncclRedOp_t
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum ), \
|
||||
NCCL_FUNCS3A(func, Prod), \
|
||||
NCCL_FUNCS3A(func, Max ), \
|
||||
NCCL_FUNCS3A(func, Min ), \
|
||||
NCCL_FUNCS3A(func, Avg)
|
||||
NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1)
|
||||
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
// [RCCL] Adding clique-based kernels for AllReduce, in place of unused Ring LL128 kernels
|
||||
#define NCCL_FUNC5B(func, algo, redop, type) \
|
||||
NCCL_FUNC_NAME(func, algo, LL, redop, type), \
|
||||
NCCL_FUNC_NAME(func, algo, LL128, redop, type), \
|
||||
NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type)
|
||||
#define NCCL_FUNC5B(func, algo, devredop, type, nullify) \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
|
||||
|
||||
#define NCCL_FUNC4B(func, redop, type) \
|
||||
NCCL_FUNC5(func, TREE, redop, type), \
|
||||
NCCL_FUNC5B(func, RING, redop, type), \
|
||||
NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
#define NCCL_FUNC4B(func, devredop, type, nullify) \
|
||||
NCCL_FUNC5B(func, TREE, devredop, type, nullify), \
|
||||
NCCL_FUNC5B(func, RING, devredop, type, nullify), \
|
||||
NCCL_FUNC5B(func, COLLNET, devredop, type, nullify)
|
||||
|
||||
#define NCCL_FUNCS3C(func, redop) \
|
||||
NCCL_FUNC4B(func, redop, int8_t), \
|
||||
NCCL_FUNC4B(func, redop, uint8_t), \
|
||||
NCCL_FUNC4B(func, redop, int32_t), \
|
||||
NCCL_FUNC4B(func, redop, uint32_t), \
|
||||
NCCL_FUNC4B(func, redop, int64_t), \
|
||||
NCCL_FUNC4B(func, redop, uint64_t), \
|
||||
NCCL_FUNC4B(func, redop, half), \
|
||||
NCCL_FUNC4B(func, redop, float), \
|
||||
NCCL_FUNC4B(func, redop, double), \
|
||||
NCCL_FUNC4B(func, redop, rccl_bfloat16)
|
||||
#define NCCL_FUNCS3C(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4B(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4B(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4B(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4B(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4B(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4B(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4B(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4B(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4B(func, devredop, double, nullForFloat), \
|
||||
NCCL_FUNC4B(func, devredop, rccl_bfloat16, nullForFloat)
|
||||
|
||||
#define NCCL_FUNCS2C(func) \
|
||||
NCCL_FUNCS3C(func, Sum ), \
|
||||
NCCL_FUNCS3C(func, Prod), \
|
||||
NCCL_FUNCS3C(func, Max ), \
|
||||
NCCL_FUNCS3C(func, Min ), \
|
||||
NCCL_FUNCS3C(func, Avg)
|
||||
NCCL_FUNCS3C(func, Sum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3C(func, Prod, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3C(func, Max, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3C(func, Min, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3C(func, PreMulSum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3C(func, SumPostDiv, /*nullForFloat=*/1)
|
||||
|
||||
// Must be consistent with ncclFunc_t
|
||||
#define NCCL_FUNCS() { \
|
||||
NCCL_FUNCS2B(Broadcast), \
|
||||
NCCL_FUNCS2A(Reduce), \
|
||||
NCCL_FUNCS2B(AllGather), \
|
||||
NCCL_FUNCS2A(ReduceScatter), \
|
||||
NCCL_FUNCS2C(AllReduce), \
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t) }
|
||||
// [/RCCL]
|
||||
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
using ncclKernelFunc_t = void (*)(struct ncclWorkElem* args);
|
||||
@@ -113,13 +108,25 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
|
||||
// confuses clang. This will be fixed in the next clang release.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(BUILD_ALLREDUCE_ONLY)
|
||||
NCCL_FUNC4B(AllReduce, Sum, float),
|
||||
NCCL_FUNC4B(AllReduce, Sum, float, 0),
|
||||
#else
|
||||
NCCL_FUNCS2B(Broadcast),
|
||||
NCCL_FUNCS2A(Reduce),
|
||||
NCCL_FUNCS2B(AllGather),
|
||||
NCCL_FUNCS2A(ReduceScatter),
|
||||
NCCL_FUNCS2C(AllReduce),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
|
||||
#endif
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
|
||||
#endif
|
||||
#endif
|
||||
@@ -142,7 +149,7 @@ struct Caller<f, f + 1>{
|
||||
void call(struct ncclWorkElem* const c) noexcept { ncclFuncs[f](c); }
|
||||
};
|
||||
|
||||
static_assert(FUNC_INDEX_P2P == 2250, "Wrong P2P function index");
|
||||
static_assert(FUNC_INDEX_P2P == 2710, "Wrong P2P function index");
|
||||
|
||||
inline
|
||||
__device__
|
||||
@@ -165,7 +172,7 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
else
|
||||
assert("Unsupported function index");
|
||||
#else
|
||||
if (c->funcIndex < 450) {
|
||||
if (c->funcIndex < 540) {
|
||||
if (c->funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(c);
|
||||
@@ -176,8 +183,8 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
else if (c->funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t(c);
|
||||
else ncclFunction_Broadcast_COLLNET_SIMPLE_Sum_int8_t(c);
|
||||
}
|
||||
else if (c->funcIndex < 900) Caller<450, 900>::call(c);
|
||||
else if (c->funcIndex < 1350) {
|
||||
else if (c->funcIndex < 1080) Caller<540, 1080>::call(c);
|
||||
else if (c->funcIndex < 1620) {
|
||||
if (c->funcIndex % 9 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(c);
|
||||
else if (c->funcIndex % 9 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(c);
|
||||
@@ -188,8 +195,46 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
else if (c->funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t(c);
|
||||
else ncclFunction_AllGather_COLLNET_SIMPLE_Sum_int8_t(c);
|
||||
}
|
||||
else if (c->funcIndex < 2250) Caller<1350, 2250>::call(c);
|
||||
else ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c);
|
||||
else if (c->funcIndex < 2700) Caller<1620, 2700>::call(c);
|
||||
else {
|
||||
switch (c->funcIndex - 2700) {
|
||||
case 0:
|
||||
ncclFunction_OneRankReduce_PreMulSum_int8_t(c);
|
||||
break;
|
||||
case 1:
|
||||
ncclFunction_OneRankReduce_PreMulSum_uint8_t(c);
|
||||
break;
|
||||
case 2:
|
||||
ncclFunction_OneRankReduce_PreMulSum_int32_t(c);
|
||||
break;
|
||||
case 3:
|
||||
ncclFunction_OneRankReduce_PreMulSum_uint32_t(c);
|
||||
break;
|
||||
case 4:
|
||||
ncclFunction_OneRankReduce_PreMulSum_int64_t(c);
|
||||
break;
|
||||
case 5:
|
||||
ncclFunction_OneRankReduce_PreMulSum_uint64_t(c);
|
||||
break;
|
||||
case 6:
|
||||
ncclFunction_OneRankReduce_PreMulSum_half(c);
|
||||
break;
|
||||
case 7:
|
||||
ncclFunction_OneRankReduce_PreMulSum_float(c);
|
||||
break;
|
||||
case 8:
|
||||
ncclFunction_OneRankReduce_PreMulSum_double(c);
|
||||
break;
|
||||
case 9:
|
||||
ncclFunction_OneRankReduce_PreMulSum_rccl_bfloat16(c);
|
||||
break;
|
||||
case 10:
|
||||
ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -203,29 +248,42 @@ class ncclFunction {
|
||||
#define traceColl(fIdx) \
|
||||
uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
|
||||
shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
|
||||
shmem.comm.collTrace[pos].opCount = elems[0].op.opCount; \
|
||||
shmem.comm.collTrace[pos].bid = bid; \
|
||||
shmem.comm.collTrace[pos].funcIndex = fIdx; \
|
||||
if (fIdx == FUNC_INDEX_P2P) { \
|
||||
shmem.comm.collTrace[pos].opCount = elems[0].p2p.opCount; \
|
||||
shmem.comm.collTrace[pos].p2p.nThreads = elems[0].p2p.nThreads; \
|
||||
shmem.comm.collTrace[pos].p2p.delta = (uint16_t)(elems[0].p2p.delta); \
|
||||
} else { \
|
||||
shmem.comm.collTrace[pos].opCount = elems[0].coll.opCount; \
|
||||
shmem.comm.collTrace[pos].coll.nThreads = elems[0].nThreads; \
|
||||
shmem.comm.collTrace[pos].coll.bid = elems[0].coll.bid; \
|
||||
shmem.comm.collTrace[pos].coll.nChannels = elems[0].coll.nChannels; \
|
||||
}
|
||||
#define traceKernelLaunch(fIdx) { \
|
||||
traceColl(fIdx); \
|
||||
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceKernelLaunchType; \
|
||||
if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \
|
||||
traceColl(fIdx); \
|
||||
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceKernelLaunchType; \
|
||||
} \
|
||||
}
|
||||
#define traceCollEnd(fIdx) { \
|
||||
traceColl(fIdx); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceCollEndType; \
|
||||
if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \
|
||||
traceColl(fIdx); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceCollEndType; \
|
||||
} \
|
||||
}
|
||||
#define traceKernelEnd(fIdx) { \
|
||||
if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \
|
||||
traceColl(fIdx); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceKernelEndType; \
|
||||
} \
|
||||
}
|
||||
#define traceAbort(fIdx) { \
|
||||
traceColl(fIdx); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \
|
||||
if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \
|
||||
traceColl(fIdx); \
|
||||
shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \
|
||||
} \
|
||||
}
|
||||
// traceData(int16_t data2, uint32_t data4, uint64_t data8_0, uint64_t data8_1)
|
||||
#define traceData(data2, data4, data8_0, data8_1) { \
|
||||
@@ -285,14 +343,24 @@ __device__ int copyToShmem(T *dst, T const *src, int turn=0) {
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWorkElement {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem*) {
|
||||
__device__ void run(ncclWorkElem*) {
|
||||
// Put NOT IMPLEMENTED behavior here.
|
||||
}
|
||||
};
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
__device__ constexpr int ncclWorkElemFactors[NCCL_NUM_ALGORITHMS] =
|
||||
#else
|
||||
static __device__ __constant__ int ncclWorkElemFactors[NCCL_NUM_ALGORITHMS] =
|
||||
#endif
|
||||
{/*Tree*/1, /*Ring and P2P*/1, /*CollNet*/NCCL_REG_ELEM_FACTOR};
|
||||
|
||||
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
|
||||
struct RunWork {
|
||||
__device__ __attribute__((noinline)) void run(ncclWork *w) {
|
||||
// This __forceinline__ is necessary. The compiler was inserting a function call
|
||||
// here from the LL ncclKernel.
|
||||
__device__ __forceinline__ void run(ncclWork *w) {
|
||||
int tid = threadIdx.x;
|
||||
/* Some invariants that must hold:
|
||||
* 1. All elems[] have same funcIndex.
|
||||
* 2. All elems[] have same nThreads.
|
||||
@@ -300,20 +368,23 @@ struct RunWork {
|
||||
* for all elems[].
|
||||
*
|
||||
* If (1) isn't true then we might be in the wrong function since dispatch
|
||||
* on ncclFuncs[w->elems[0].funcIndex] is how we got here.
|
||||
* on ncclFuncs[w->funcIndex] is how we got here.
|
||||
*
|
||||
* If (2) or (3) aren't true, then threads from different work elements
|
||||
* could race for barrier resources (barrier numbers 0...15) which is fatal.
|
||||
*
|
||||
* Important, to ensure (3), implementations of
|
||||
* `RunWorkElement<Fn,T,RedOp,Algo,Proto>::run()` may only use values which
|
||||
* are the same for all elems[] when deciding how to map threads to groups,
|
||||
* such as the following:
|
||||
* IMPORTANT!!! To ensure (3), implementations of
|
||||
* `RunWorkElement<Fn,T,RedOp,Algo,Proto>::run()` may only use the following
|
||||
* when deciding how to map threads to groups:
|
||||
* Fn, T, RedOp, Algo, Proto, nThreads
|
||||
*
|
||||
* This last one is difficult to enforce and diagnosing it is a headache.
|
||||
* Device-side developers, consider yourselves warned.
|
||||
* This last one is difficult to enforce so I hope everyone reads this.
|
||||
*/
|
||||
if (tid < w->elems[0].nThreads) {
|
||||
#pragma unroll 1
|
||||
for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e+=ncclWorkElemFactors[Algo])
|
||||
RunWorkElement<Fn, T, RedOp, Algo, Proto>().run(&w->elems[e]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -323,6 +394,7 @@ struct ncclShmemGroup {
|
||||
ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY];
|
||||
void* srcs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
void* dsts[NCCL_MAX_DIRECT_ARITY+1];
|
||||
int totalSendSize[NCCL_MAX_SLICE_PER_CHUNK];
|
||||
uint64_t barrier;
|
||||
uint64_t barrier_next[MAXWARPS];
|
||||
};
|
||||
@@ -333,6 +405,7 @@ struct ncclShmemData {
|
||||
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
|
||||
};
|
||||
uint32_t sync[MAXWARPS];
|
||||
uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
ncclDevComm comm;
|
||||
ncclChannel channel;
|
||||
ncclWork work;
|
||||
@@ -362,9 +435,13 @@ __device__ void ncclKernel(ncclWorkElem first) {
|
||||
turn = copyToShmem(&shmem.channel, channel, turn);
|
||||
|
||||
// To optimize for latency, (only) the first operation is passed as argument.
|
||||
if (bid == 0 && first.active != 0)
|
||||
if (bid == 0 && first.active != 0) {
|
||||
turn = copyToShmem(&shmem.work.elems[0], &first, turn);
|
||||
|
||||
if (1 <= tid && tid < NCCL_MAX_WORK_ELEMENTS && tid % ncclWorkElemFactors[Algo] == 0) {
|
||||
shmem.work.elems[tid].active = 0;
|
||||
shmem.work.elems[tid].redOpArgIsPtr = 0;
|
||||
}
|
||||
}
|
||||
struct ncclWorkElem* elems = shmem.work.elems;
|
||||
__syncthreads(); // publish shmem
|
||||
|
||||
@@ -401,13 +478,36 @@ __device__ void ncclKernel(ncclWorkElem first) {
|
||||
if (tid == 0)
|
||||
channel->index = workFifoIx; // write back to real channel, not shmem shadow
|
||||
|
||||
if (tid < NCCL_MAX_WORK_ELEMENTS && tid % ncclWorkElemFactors[Algo] == 0) {
|
||||
ncclWorkElem *we = &shmem.work.elems[tid];
|
||||
if (we->redOpArgIsPtr && we->active != 0) {
|
||||
/* redOpArg is a pointer to the scalar value, so we'll dereference it
|
||||
* here so that redOpArg holds the bits of the scalar going forward.
|
||||
* The tricky thing is we don't know its type T since that's encoded in
|
||||
* the funcIndex. Because it would be difficult to get sizeof(T) from
|
||||
* funcIndex, we'll cheat and just dereference the largest possible size
|
||||
* given the alignment of the pointer. We might be reading in more bytes
|
||||
* than we need but that's harmless.
|
||||
*/
|
||||
if (we->coll.redOpArg%2 != 0)
|
||||
we->coll.redOpArg = *reinterpret_cast<uint8_t*>(we->coll.redOpArg);
|
||||
else if (we->coll.redOpArg%4 != 0)
|
||||
we->coll.redOpArg = *reinterpret_cast<uint16_t*>(we->coll.redOpArg);
|
||||
else if (we->coll.redOpArg%8 != 0)
|
||||
we->coll.redOpArg = *reinterpret_cast<uint32_t*>(we->coll.redOpArg);
|
||||
else
|
||||
we->coll.redOpArg = *reinterpret_cast<uint64_t*>(we->coll.redOpArg);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (shmem.work.elems[0].funcIndex == FnIndex)
|
||||
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
|
||||
else
|
||||
NCCL_CALL_FUNCTIONS(&elems[0]);
|
||||
|
||||
if (shmem.work.elems[0].active == 2) {
|
||||
if (COLLTRACE && tid == 0) traceCollEnd(0xffff)
|
||||
if (COLLTRACE && tid == 0) traceKernelEnd(elems->funcIndex)
|
||||
break;
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -415,43 +515,51 @@ __device__ void ncclKernel(ncclWorkElem first) {
|
||||
}
|
||||
}
|
||||
|
||||
#define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex) \
|
||||
#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \
|
||||
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \
|
||||
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(ncclWorkElem first) { \
|
||||
if (first.comm->collTraceThread) \
|
||||
ncclKernel<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(first); \
|
||||
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(first); \
|
||||
else \
|
||||
ncclKernel<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(first); \
|
||||
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(first); \
|
||||
}
|
||||
|
||||
// Examples : AllReduce, RING, LL, Sum, uint8
|
||||
/* Functions for aggregation case */
|
||||
#define IMPL_COLL_FUNC(func, algo, proto, redop, type) \
|
||||
__device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \
|
||||
RunWorkElement<ncclFunc##func, type, Func##redop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(args); \
|
||||
#define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \
|
||||
__device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(struct ncclWorkElem* args) { \
|
||||
RunWork<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem->work); \
|
||||
}
|
||||
|
||||
// Only generate inline kernels for LL
|
||||
#define IMPL_COLL4(func, algo, redop, type, ncclType) \
|
||||
IMPL_COLL_FUNC(func, algo, LL, redop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, SIMPLE, redop, type)
|
||||
#define IMPL_COLL4(func, algo, devredop, type, ncclType) \
|
||||
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) \
|
||||
|
||||
#define IMPL_COLL3(func, redop, type, ncclType) \
|
||||
IMPL_COLL4(func, TREE, redop, type, ncclType) \
|
||||
IMPL_COLL4(func, RING, redop, type, ncclType) \
|
||||
IMPL_COLL4(func, COLLNET, redop, type, ncclType)
|
||||
#define IMPL_COLL3(func, devredop, type, ncclType) \
|
||||
IMPL_COLL4(func, TREE, devredop, type, ncclType) \
|
||||
IMPL_COLL4(func, RING, devredop, type, ncclType) \
|
||||
IMPL_COLL4(func, COLLNET, devredop, type, ncclType)
|
||||
|
||||
#define IMPL_COLL2(func, redop) \
|
||||
IMPL_COLL3(func, redop, int8_t, ncclInt8) \
|
||||
IMPL_COLL3(func, redop, uint8_t, ncclUint8) \
|
||||
IMPL_COLL3(func, redop, int32_t, ncclInt32) \
|
||||
IMPL_COLL3(func, redop, uint32_t, ncclUint32) \
|
||||
IMPL_COLL3(func, redop, int64_t, ncclInt64) \
|
||||
IMPL_COLL3(func, redop, uint64_t, ncclUint64) \
|
||||
IMPL_COLL3(func, redop, half, ncclFloat16) \
|
||||
IMPL_COLL3(func, redop, float, ncclFloat32) \
|
||||
IMPL_COLL3(func, redop, double, ncclFloat64) \
|
||||
IMPL_COLL3(func, redop, rccl_bfloat16, ncclBfloat16)
|
||||
#define IMPL_COLL2(func, devredop) \
|
||||
IMPL_COLL3(func, devredop, int8_t, ncclInt8) \
|
||||
IMPL_COLL3(func, devredop, uint8_t, ncclUint8) \
|
||||
IMPL_COLL3(func, devredop, int32_t, ncclInt32) \
|
||||
IMPL_COLL3(func, devredop, uint32_t, ncclUint32) \
|
||||
IMPL_COLL3(func, devredop, int64_t, ncclInt64) \
|
||||
IMPL_COLL3(func, devredop, uint64_t, ncclUint64) \
|
||||
IMPL_COLL3(func, devredop, half, ncclFloat16) \
|
||||
IMPL_COLL3(func, devredop, float, ncclFloat32) \
|
||||
IMPL_COLL3(func, devredop, double, ncclFloat64) \
|
||||
IMPL_COLL3(func, devredop, rccl_bfloat16, ncclBfloat16)
|
||||
|
||||
#define IMPL_COLL2A(func, devredop) \
|
||||
IMPL_COLL3(func, devredop, int8_t, ncclInt8) \
|
||||
IMPL_COLL3(func, devredop, uint8_t, ncclUint8) \
|
||||
IMPL_COLL3(func, devredop, int32_t, ncclInt32) \
|
||||
IMPL_COLL3(func, devredop, uint32_t, ncclUint32) \
|
||||
IMPL_COLL3(func, devredop, int64_t, ncclInt64) \
|
||||
IMPL_COLL3(func, devredop, uint64_t, ncclUint64)
|
||||
|
||||
// Reductions define all functions
|
||||
#define IMPL_COLL_R(func) \
|
||||
@@ -459,40 +567,49 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, red
|
||||
IMPL_COLL2(func, Prod) \
|
||||
IMPL_COLL2(func, Min) \
|
||||
IMPL_COLL2(func, Max) \
|
||||
IMPL_COLL2(func, Avg)
|
||||
IMPL_COLL2(func, PreMulSum) \
|
||||
IMPL_COLL2A(func, SumPostDiv)
|
||||
|
||||
// [RCCL] Define clique-based implementations (repurposed LL128)
|
||||
#define IMPL_COLL4_CLIQUE(func, algo, redop, type, ncclType) \
|
||||
IMPL_COLL_FUNC(func, algo, LL, redop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, LL128, redop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, SIMPLE, redop, type)
|
||||
#define IMPL_COLL4_CLIQUE(func, algo, devredop, type, ncclType) \
|
||||
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \
|
||||
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) \
|
||||
|
||||
#define IMPL_COLL3_CLIQUE(func, redop, type, ncclType) \
|
||||
IMPL_COLL4(func, TREE, redop, type, ncclType) \
|
||||
IMPL_COLL4_CLIQUE(func, RING, redop, type, ncclType) \
|
||||
IMPL_COLL4(func, COLLNET, redop, type, ncclType)
|
||||
#define IMPL_COLL3_CLIQUE(func, devredop, type, ncclType) \
|
||||
IMPL_COLL4_CLIQUE(func, TREE, devredop, type, ncclType) \
|
||||
IMPL_COLL4_CLIQUE(func, RING, devredop, type, ncclType) \
|
||||
IMPL_COLL4_CLIQUE(func, COLLNET, devredop, type, ncclType)
|
||||
|
||||
#define IMPL_COLL2_CLIQUE(func, redop) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, int8_t, ncclInt8) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, uint8_t, ncclUint8) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, int32_t, ncclInt32) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, uint32_t, ncclUint32) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, int64_t, ncclInt64) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, uint64_t, ncclUint64) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, half, ncclFloat16) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, float, ncclFloat32) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, double, ncclFloat64) \
|
||||
IMPL_COLL3_CLIQUE(func, redop, rccl_bfloat16, ncclBfloat16)
|
||||
#define IMPL_COLL2_CLIQUE(func, devredop) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, int8_t, ncclInt8) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, uint8_t, ncclUint8) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, int32_t, ncclInt32) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, uint32_t, ncclUint32) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, int64_t, ncclInt64) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, uint64_t, ncclUint64) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, half, ncclFloat16) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, float, ncclFloat32) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, double, ncclFloat64) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, rccl_bfloat16, ncclBfloat16)
|
||||
|
||||
#define IMPL_COLL2A_CLIQUE(func, devredop) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, int8_t, ncclInt8) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, uint8_t, ncclUint8) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, int32_t, ncclInt32) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, uint32_t, ncclUint32) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, int64_t, ncclInt64) \
|
||||
IMPL_COLL3_CLIQUE(func, devredop, uint64_t, ncclUint64)
|
||||
|
||||
#define IMPL_COLL_CLIQUE(func) \
|
||||
IMPL_COLL2_CLIQUE(func, Sum) \
|
||||
IMPL_COLL2_CLIQUE(func, Prod) \
|
||||
IMPL_COLL2_CLIQUE(func, Min) \
|
||||
IMPL_COLL2_CLIQUE(func, Max) \
|
||||
IMPL_COLL2_CLIQUE(func, Avg)
|
||||
IMPL_COLL2_CLIQUE(func, PreMulSum) \
|
||||
IMPL_COLL2A_CLIQUE(func, SumPostDiv)
|
||||
// [/RCCL]
|
||||
|
||||
|
||||
// Copy primitives only define one function for copy
|
||||
#define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t, ncclInt8);
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ typedef uint64_t PackType;
|
||||
|
||||
template<typename Fn>
|
||||
struct FuncTraits /*{
|
||||
__device__ static Fn make();
|
||||
__device__ static T preOp(Fn, T);
|
||||
__device__ static T postOp(Fn, T);
|
||||
}*/;
|
||||
@@ -501,12 +500,12 @@ inline __device__ void Store128(Pack128* p, Pack128& v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS, int PreOpN, typename Int>
|
||||
__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
|
||||
FUNC fn, int const numPreOpSrcs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem
|
||||
uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Nelem
|
||||
) {
|
||||
const int inc = nw * UNROLL * WARP_SIZE;
|
||||
int offset = w * UNROLL * WARP_SIZE + t;
|
||||
const Int inc = nw * UNROLL * WARP_SIZE;
|
||||
Int offset = w * UNROLL * WARP_SIZE + t;
|
||||
|
||||
const T* srcs[MAXSRCS];
|
||||
for (int i=0; i<MAXSRCS; i++) srcs[i] = s[i]+elemOffset+offset;
|
||||
@@ -517,15 +516,17 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
T vals[UNROLL];
|
||||
// Load and reduce
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = vFetch(srcs[0]+u*WARP_SIZE);
|
||||
if (numPreOpSrcs) {
|
||||
if (PreOpN) {
|
||||
FUNC fn(redOpArgs[0]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits<FUNC>().preOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i=1; i<MINSRCS; i++) {
|
||||
T vals2[UNROLL];
|
||||
FUNC fn(redOpArgs[i]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
|
||||
if (i < numPreOpSrcs) {
|
||||
if (i<PreOpN) {
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = FuncTraits<FUNC>().preOp(fn, vals2[u]);
|
||||
}
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]);
|
||||
@@ -534,12 +535,17 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
for (int i=MINSRCS; i<MAXSRCS; i++) {
|
||||
if (i<nsrcs) {
|
||||
T vals2[UNROLL];
|
||||
FUNC fn(redOpArgs[i]);
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = vFetch(srcs[i]+u*WARP_SIZE);
|
||||
if (i<PreOpN) {
|
||||
for (int u = 0; u < UNROLL; ++u) vals2[u] = FuncTraits<FUNC>().preOp(fn, vals2[u]);
|
||||
}
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
FUNC fn(redOpArgs[0]);
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits<FUNC>().postOp(fn, vals[u]);
|
||||
}
|
||||
@@ -561,12 +567,12 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const
|
||||
}
|
||||
}
|
||||
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS, int PreOpN, typename Int>
|
||||
__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
|
||||
FUNC fn, int numPreOpSrcs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack
|
||||
uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Npack
|
||||
) {
|
||||
const int inc = nw * UNROLL * WARP_SIZE;
|
||||
int offset = w * UNROLL * WARP_SIZE + t;
|
||||
const Int inc = nw * UNROLL * WARP_SIZE;
|
||||
Int offset = w * UNROLL * WARP_SIZE + t;
|
||||
|
||||
const Pack128* srcs[MAXSRCS];
|
||||
for (int i=0; i<MAXSRCS; i++) srcs[i] = ((const Pack128*)(s[i]+elemOffset))+offset;
|
||||
@@ -577,15 +583,17 @@ __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, c
|
||||
Pack128 vals[UNROLL];
|
||||
// Load and reduce
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
|
||||
if (numPreOpSrcs) {
|
||||
if (PreOpN) {
|
||||
FUNC fn(redOpArgs[0]);
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().preOp(fn, vals[u]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i=1; i<MINSRCS; i++) {
|
||||
Pack128 vals2[UNROLL];
|
||||
FUNC fn(redOpArgs[i]);
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
if (i < numPreOpSrcs) {
|
||||
if (i<PreOpN) {
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().preOp(fn, vals2[u]);
|
||||
}
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(fn, vals[u], vals2[u]);
|
||||
@@ -594,12 +602,17 @@ __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, c
|
||||
for (int i=MINSRCS; i<MAXSRCS; i++) {
|
||||
if (i<nsrcs) {
|
||||
Pack128 vals2[UNROLL];
|
||||
FUNC fn(redOpArgs[i]);
|
||||
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
|
||||
if (i<PreOpN) {
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().preOp(fn, vals2[u]);
|
||||
}
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(fn, vals[u], vals2[u]);
|
||||
}
|
||||
}
|
||||
|
||||
if (postOp) {
|
||||
FUNC fn(redOpArgs[0]);
|
||||
#pragma unroll
|
||||
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>().postOp(fn, vals[u]);
|
||||
}
|
||||
@@ -627,11 +640,11 @@ __device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(uint32_t); }
|
||||
#define PACKELEMS (sizeof(Pack128) / sizeof(T))
|
||||
#define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1))
|
||||
|
||||
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
|
||||
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS, int PreOpN, typename Int>
|
||||
__device__ __forceinline__ void ReduceOrCopyMulti(
|
||||
const int tid, const int nthreads, FUNC fn, int numPreOpSrcs, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, int N
|
||||
const int tid, const int nthreads, uint64_t* redOpArgs, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, Int N
|
||||
) {
|
||||
int Nrem = N;
|
||||
Int Nrem = N;
|
||||
if (Nrem <= 0) return;
|
||||
|
||||
int w = tid / WARP_SIZE; // Warp number
|
||||
@@ -647,17 +660,17 @@ __device__ __forceinline__ void ReduceOrCopyMulti(
|
||||
for (int i=0; i<MINDSTS; i++) align |= ptrAlign128(dsts[i]);
|
||||
for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) align |= ptrAlign128(dsts[i]);
|
||||
|
||||
int offset = 0;
|
||||
Int offset = 0;
|
||||
if (align == 0) {
|
||||
// fast path: use 128b loads/stores to do the bulk of the work,
|
||||
// assuming the pointers we have are all 128-bit aligned.
|
||||
|
||||
// main loop
|
||||
int Npack = (Nrem / (PACKELEMS*AUTOUNROLL*WARP_SIZE)) * (AUTOUNROLL*WARP_SIZE); // round down
|
||||
int Nelem = Npack * PACKELEMS;
|
||||
Int Npack = (Nrem / (PACKELEMS*AUTOUNROLL*WARP_SIZE)) * (AUTOUNROLL*WARP_SIZE); // round down
|
||||
Int Nelem = Npack * PACKELEMS;
|
||||
|
||||
ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS, PreOpN>
|
||||
(w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
@@ -667,8 +680,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti(
|
||||
Npack = Nrem / PACKELEMS;
|
||||
Nelem = Npack * PACKELEMS;
|
||||
|
||||
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS, PreOpN>
|
||||
(w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
@@ -676,18 +689,18 @@ __device__ __forceinline__ void ReduceOrCopyMulti(
|
||||
}
|
||||
|
||||
// unrolled, by-type (mostly for unaligned buffers)
|
||||
int Nelem = (Nrem / (AUTOUNROLL*PACKELEMS/2*WARP_SIZE)) * (AUTOUNROLL*PACKELEMS/2*WARP_SIZE); // round down
|
||||
Int Nelem = (Nrem / (AUTOUNROLL*PACKELEMS/2*WARP_SIZE)) * (AUTOUNROLL*PACKELEMS/2*WARP_SIZE); // round down
|
||||
|
||||
ReduceCopyMulti<FUNC, T, AUTOUNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem);
|
||||
ReduceCopyMulti<FUNC, T, AUTOUNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS, PreOpN>
|
||||
(w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem);
|
||||
|
||||
Nrem -= Nelem;
|
||||
if (Nrem == 0) return;
|
||||
offset += Nelem;
|
||||
|
||||
// no unroll, by type. Should finish what's remaining.
|
||||
ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>
|
||||
(w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem);
|
||||
ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS, PreOpN>
|
||||
(w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem);
|
||||
}
|
||||
|
||||
#endif // COMMON_KERNEL_H_
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
@@ -13,66 +13,100 @@ __device__ struct ncclShmemData* ncclShmem;
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
#else
|
||||
#define NCCL_FUNC5(func, algo, redop, type) \
|
||||
NCCL_FUNC_NAME(func, algo, LL, redop, type), \
|
||||
NCCL_FUNC_NAME(func, algo, LL128, redop, type), \
|
||||
NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type)
|
||||
#define NCCL_FUNC5(func, algo, devredop, type, nullify) \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
|
||||
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
|
||||
|
||||
#define NCCL_FUNC4(func, redop, type) \
|
||||
NCCL_FUNC5(func, TREE, redop, type), \
|
||||
NCCL_FUNC5(func, RING, redop, type), \
|
||||
NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
#define NCCL_FUNC4(func, devredop, type, nullify) \
|
||||
NCCL_FUNC5(func, TREE, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, RING, devredop, type, nullify), \
|
||||
NCCL_FUNC5(func, COLLNET, devredop, type, nullify)
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, uint8_t), \
|
||||
NCCL_FUNC4(func, redop, int32_t), \
|
||||
NCCL_FUNC4(func, redop, uint32_t), \
|
||||
NCCL_FUNC4(func, redop, int64_t), \
|
||||
NCCL_FUNC4(func, redop, uint64_t), \
|
||||
NCCL_FUNC4(func, redop, half), \
|
||||
NCCL_FUNC4(func, redop, float), \
|
||||
NCCL_FUNC4(func, redop, double)
|
||||
#define NCCL_FUNCS3B(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t)
|
||||
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, double, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0)
|
||||
#else
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, devredop, nullForFloat) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint32_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, uint64_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, half, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, float, nullForFloat), \
|
||||
NCCL_FUNC4(func, devredop, double, nullForFloat)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0), \
|
||||
NCCL_FUNC4(func, devredop, int8_t, 0)
|
||||
#endif
|
||||
|
||||
// Must be consistent with ncclRedOp_t
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum ), \
|
||||
NCCL_FUNCS3A(func, Prod), \
|
||||
NCCL_FUNCS3A(func, Max ), \
|
||||
NCCL_FUNCS3A(func, Min )
|
||||
NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \
|
||||
NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1)
|
||||
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
|
||||
// Must be consistent with ncclFunc_t
|
||||
#define NCCL_FUNCS() { \
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),\
|
||||
NCCL_FUNCS2B(Broadcast), \
|
||||
NCCL_FUNCS2A(Reduce), \
|
||||
NCCL_FUNCS2B(AllGather), \
|
||||
NCCL_FUNCS2A(ReduceScatter), \
|
||||
NCCL_FUNCS2A(AllReduce) }
|
||||
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
__device__ ncclKern_t ncclFuncs[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
|
||||
__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
|
||||
// Don't try to initialize the host shadow copy of this device-side global
|
||||
// variable. There is no host pointer to a device-side function, which
|
||||
// confuses clang. This will be fixed in the next clang release.
|
||||
#if __CUDA_ARCH__
|
||||
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
|
||||
#endif
|
||||
NCCL_FUNCS2B(Broadcast),
|
||||
NCCL_FUNCS2A(Reduce),
|
||||
NCCL_FUNCS2B(AllGather),
|
||||
|
||||
@@ -17,7 +17,7 @@ targets="GENOBJS := \\\\\n"
|
||||
|
||||
for base in sendrecv all_reduce all_gather broadcast reduce reduce_scatter; do
|
||||
opn=0
|
||||
for op in sum prod min max avg; do
|
||||
for op in sum prod min max premulsum sumpostdiv; do
|
||||
dtn=0
|
||||
# Order must match that of the ncclDataType_t enum
|
||||
for dt in ${datatypes}; do
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#include "devcomm.h"
|
||||
#include "collectives.h"
|
||||
#include "reduce_kernel.h"
|
||||
#include "common.h"
|
||||
|
||||
namespace {
  // Local pass from a work element's send buffer to its receive buffer, used by
  // the NCCL_ONERANK_REDUCE_NAME entry points.
  //
  // For every active element of the work item held in ncclShmem, the element
  // range [0, count) is split between the element's nChannels blocks; this
  // block (index coll.bid) hands its contiguous sub-range to ReduceOrCopyMulti
  // with a single source, a single destination and the element's redOpArg.
  //
  // Template parameters:
  //   T      element type of the buffers
  //   RedOp  reduction functor type forwarded to ReduceOrCopyMulti
  template<typename T, typename RedOp>
  __device__ __forceinline__ void oneRankReduce() {
    ncclWork *w = &ncclShmem->work;
    int tid = threadIdx.x;
    int tn = blockDim.x;
    #pragma unroll 1
    for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) {
      ncclWorkElem *we = &w->elems[e];
      intptr_t eltN = we->coll.count;
      int bid = we->coll.bid;
      int bn = we->coll.nChannels;
      T const *src = (T const*)we->sendbuff;
      T *dst = (T*)we->recvbuff;

      // each block/channel gets a roughly equal segment of 16 byte packs
      constexpr int EltPerPack = 16/sizeof(T);
      // Number of packs covering eltN elements (ceiling division). This must be
      // a pack count, not an element count: i0/i1 below are computed in packs
      // and only then scaled by EltPerPack. Using a rounded-up element count
      // here would make the scaled ranges EltPerPack times too long, so the
      // first few blocks would take all the data and the others none.
      intptr_t packN = (eltN + EltPerPack-1)/EltPerPack;
      // Blocks with index < packN%bn receive one extra pack.
      intptr_t i0 = (bid+0)*(packN/bn) + (bid+0 < packN%bn ? bid+0 : packN%bn);
      intptr_t i1 = (bid+1)*(packN/bn) + (bid+1 < packN%bn ? bid+1 : packN%bn);
      // Convert pack indices to element indices; the last pack may be partial,
      // so clamp both ends to the real element count.
      i0 *= EltPerPack;
      i0 = i0 < eltN ? i0 : eltN;
      i1 *= EltPerPack;
      i1 = i1 < eltN ? i1 : eltN;
      src += i0;
      dst += i0;
      // NOTE(review): the trailing template argument and the `true` argument
      // occupy the slots the Primitives call sites use for PreOpN and postOp
      // respectively - confirm against ReduceOrCopyMulti's declaration.
      ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1, 1>
        (tid, tn, &(we->coll.redOpArg), true, 1, &src, 1, &dst, i1-i0);
    }
  }
}
|
||||
|
||||
// Defines the device function named NCCL_ONERANK_REDUCE_NAME(devredop, type).
// Its body forwards to oneRankReduce<type, Func##devredop<type>>(); the `args`
// parameter is not referenced by the generated body.
#define INSTANTIATE(devredop, type) \
|
||||
__device__ void NCCL_ONERANK_REDUCE_NAME(devredop, type)(struct ncclWorkElem* args) { \
|
||||
oneRankReduce<type, Func##devredop<type>>(); \
|
||||
}
|
||||
|
||||
// One PreMulSum entry point per element type. The rccl_bfloat16 variant is
// only emitted when RCCL_BFLOAT16 is defined.
INSTANTIATE(PreMulSum, int8_t)
|
||||
INSTANTIATE(PreMulSum, uint8_t)
|
||||
INSTANTIATE(PreMulSum, int32_t)
|
||||
INSTANTIATE(PreMulSum, uint32_t)
|
||||
INSTANTIATE(PreMulSum, int64_t)
|
||||
INSTANTIATE(PreMulSum, uint64_t)
|
||||
INSTANTIATE(PreMulSum, half)
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
INSTANTIATE(PreMulSum, rccl_bfloat16)
|
||||
#endif
|
||||
INSTANTIATE(PreMulSum, float)
|
||||
INSTANTIATE(PreMulSum, double)
|
||||
@@ -161,13 +161,13 @@ struct PrimitivesWithoutDirect {
|
||||
#define INIT_COUNTER \
|
||||
if (tid == 0) { t0 = __builtin_amdgcn_s_memrealtime(); }
|
||||
#define ACCUMULATE_COUNTER(prim) \
|
||||
if (tid == 0 && args->op.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0); \
|
||||
if (tid == 0 && args->coll.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0); \
|
||||
devProf->elems[blockIdx.x].prim##_byte += nelem * sizeof(T); }
|
||||
#else
|
||||
#define INIT_COUNTER \
|
||||
if (tid == 0) { t0 = __builtin_amdgcn_s_memrealtime(); ws = devProf->elems[blockIdx.x].wait_cycle; }
|
||||
#define ACCUMULATE_COUNTER(prim) \
|
||||
if (tid == 0 && args->op.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0 \
|
||||
if (tid == 0 && args->coll.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0 \
|
||||
+ ws - devProf->elems[blockIdx.x].wait_cycle); \
|
||||
devProf->elems[blockIdx.x].prim##_byte += nelem * sizeof(T); }
|
||||
#endif
|
||||
|
||||
@@ -284,7 +284,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL>:
|
||||
}
|
||||
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
__device__ __forceinline__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx;
|
||||
@@ -389,9 +389,9 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL>:
|
||||
public:
|
||||
__device__ Primitives(
|
||||
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0, int connIndex=0
|
||||
void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0, int connIndex=0
|
||||
):
|
||||
redOp(FuncTraits<RedOp>().make(ncclShmem->comm.nRanks)),
|
||||
redOp(redOpArg),
|
||||
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group),
|
||||
stepLines(ncclShmem->comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)),
|
||||
barriers(&ncclShmem->groups[group].barrier), barrier_next(ncclShmem->groups[group].barrier_next) {
|
||||
|
||||
@@ -280,7 +280,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128>:
|
||||
static constexpr int DataEltPerSlice = (WireWordPerSlice - WireWordPerSlice/NCCL_LL128_LINEELEMS)*(sizeof(uint64_t)/sizeof(T));
|
||||
|
||||
template <int RECV, int SEND, int SrcBuf, int DstBuf>
|
||||
__device__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
__device__ __forceinline__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
|
||||
constexpr int SRC = SrcBuf != -1 ? 1 : 0;
|
||||
constexpr int DST = DstBuf != -1 ? 1 : 0;
|
||||
static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh");
|
||||
@@ -357,9 +357,9 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128>:
|
||||
public:
|
||||
__device__ Primitives(
|
||||
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0, int connIndex=0
|
||||
void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0, int connIndex=0
|
||||
):
|
||||
redOp(FuncTraits<RedOp>().make(ncclShmem->comm.nRanks)),
|
||||
redOp(redOpArg),
|
||||
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
|
||||
flagThread((tid%8)==7), group(group),
|
||||
stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -20,14 +21,14 @@ class Primitives<
|
||||
Aborted = 0x40,
|
||||
PtrsFifoEnabled = 0x80,
|
||||
SizesFifoEnabled = 0x100,
|
||||
DirectEnabled = 0x200,
|
||||
ThreadsSynced = 0x400;
|
||||
DirectWrite = 0x200,
|
||||
DirectRead = 0x400,
|
||||
ThreadsSynced = 0x800;
|
||||
const int tid;
|
||||
int nthreads;
|
||||
int nworkers;
|
||||
const int stepSize;
|
||||
Fan fan;
|
||||
RedOp const redOp;
|
||||
int index; // Peer index I'm responsible for
|
||||
int flags;
|
||||
int group;
|
||||
@@ -45,7 +46,6 @@ class Primitives<
|
||||
uint64_t connStepCache; // Cache last seen value of (*connStepPtr)
|
||||
uint64_t* barriers;
|
||||
uint64_t* barrier_next;
|
||||
const int connIndex;
|
||||
const uint64_t opCount;
|
||||
|
||||
// Don't use barrier 0 as it's used by the final sync
|
||||
@@ -84,12 +84,15 @@ class Primitives<
|
||||
}
|
||||
|
||||
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
|
||||
inline __device__ void waitPeer(intptr_t dstIx, intptr_t remoteOutIx, int offset, int nelts) {
|
||||
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
|
||||
bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
|
||||
__device__ __forceinline__ void waitPeer(intptr_t dstIx, intptr_t remoteIx, int offset, int nelts) {
|
||||
const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
|
||||
const bool noRecvWait = DirectRecv && Src && (flags & DirectRead); // no wait when directly reading from remote input
|
||||
const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write
|
||||
#if defined(ENABLE_PROFILING) && !defined(ENABLE_TIMING_PROFILE)
|
||||
uint64_t t0 = __builtin_amdgcn_s_memrealtime();
|
||||
uint64_t t0 = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) ||
|
||||
((flags & (Send*RoleWaitSend)) && !noSendWait)) {
|
||||
int spins = 0;
|
||||
while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
|
||||
__builtin_amdgcn_s_sleep(8);
|
||||
@@ -98,7 +101,9 @@ class Primitives<
|
||||
//if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice));
|
||||
}
|
||||
__asm__ __volatile__("s_wakeup");
|
||||
}
|
||||
|
||||
if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) {
|
||||
if (isSendNotRecv && (flags & SizesFifoEnabled))
|
||||
STORE(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T));
|
||||
|
||||
@@ -106,10 +111,26 @@ class Primitives<
|
||||
: (ncclShmem->groups[group].srcs + Src);
|
||||
if (flags & PtrsFifoEnabled)
|
||||
loadPtr(connPtrsFifoPtr + step%NCCL_STEPS, ptrs[index]);
|
||||
else if ((isSendNotRecv ? DirectSend : DirectRecv) && (flags & DirectEnabled))
|
||||
ptrs[index] = directBuff + (isSendNotRecv ? remoteOutIx : dstIx) + offset;
|
||||
else
|
||||
else if (isSendNotRecv && DirectSend) {
|
||||
if (flags & DirectWrite) {
|
||||
ptrs[index] = directBuff + remoteIx + offset;
|
||||
} else if (flags & DirectRead) { // empty send
|
||||
ptrs[index] = nullptr;
|
||||
} else {
|
||||
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
|
||||
}
|
||||
} else if (!isSendNotRecv && DirectRecv) {
|
||||
if (flags & DirectRead) {
|
||||
ptrs[index] = directBuff + remoteIx + offset;
|
||||
} else if (flags & DirectWrite) {
|
||||
ptrs[index] = directBuff + dstIx + offset; // send to next from my output buffer
|
||||
} else {
|
||||
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
|
||||
}
|
||||
}
|
||||
else {
|
||||
ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize;
|
||||
}
|
||||
step += StepPerSlice;
|
||||
#if defined(ENABLE_PROFILING) && !defined(ENABLE_TIMING_PROFILE)
|
||||
if (opCount) {
|
||||
@@ -131,8 +152,8 @@ class Primitives<
|
||||
}
|
||||
|
||||
template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
|
||||
inline __device__ void genericOp(
|
||||
intptr_t srcIx, intptr_t dstIx, intptr_t remoteOutIx, int nelem, bool postOp
|
||||
__device__ __forceinline__ void genericOp(
|
||||
intptr_t srcIx, intptr_t dstIx, intptr_t remoteIx, int nelem, bool postOp
|
||||
) {
|
||||
constexpr int DirectRecv = 1 && Direct && DirectRecv1;
|
||||
constexpr int DirectSend = 1 && Direct && DirectSend1;
|
||||
@@ -180,7 +201,7 @@ class Primitives<
|
||||
uint64_t t0;
|
||||
if (tid == 0) t0 = __builtin_amdgcn_s_memrealtime();
|
||||
#endif
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(dstIx, remoteOutIx, offset, sliceSize);
|
||||
waitPeer<DirectRecv, DirectSend, Recv, Send, Src, Dst>(dstIx, remoteIx, offset, sliceSize);
|
||||
subBarrier();
|
||||
#ifdef ENABLE_PROFILING
|
||||
if (tid == 0 && opCount) ncclShmem->comm.devProf->elems[blockIdx.x].wait_cycle += (__builtin_amdgcn_s_memrealtime() - t0);
|
||||
@@ -189,15 +210,24 @@ class Primitives<
|
||||
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
|
||||
if (Send) {
|
||||
// (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0).
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, (1-Send)+MaxSend>
|
||||
(tid, nworkers, redOp, 0, false,
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, (1-Send)+MaxSend, 0>
|
||||
(tid, nworkers, nullptr, false,
|
||||
1, (T const**)ncclShmem->groups[group].srcs,
|
||||
fan.nsend(), (T**)ncclShmem->groups[group].dsts+1,
|
||||
sliceSize);
|
||||
}
|
||||
} else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem->groups[group].dsts[Dst] == nullptr) {
|
||||
// For broadcast in CollNet to do empty send
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, 0>
|
||||
(tid, nworkers, ncclShmem->redOpArgs, postOp,
|
||||
Recv, (T const**)ncclShmem->groups[group].srcs,
|
||||
Dst, (T**)ncclShmem->groups[group].dsts,
|
||||
sliceSize);
|
||||
} else {
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, Recv+Src, Recv*MaxRecv+Src, Send+Dst, Send*MaxSend+Dst>
|
||||
(tid, nworkers, redOp, SrcBuf==Input ? 1 : 0, postOp,
|
||||
constexpr int PreOpN = SrcBuf != Input ? 0 :
|
||||
DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, Recv+Src, Recv*MaxRecv+Src, Send+Dst, Send*MaxSend+Dst, PreOpN>
|
||||
(tid, nworkers, ncclShmem->redOpArgs, postOp,
|
||||
Recv*fan.nrecv()+Src, (T const**)ncclShmem->groups[group].srcs,
|
||||
Send*fan.nsend()+Dst, (T**)ncclShmem->groups[group].dsts,
|
||||
sliceSize);
|
||||
@@ -231,10 +261,12 @@ class Primitives<
|
||||
}
|
||||
}
|
||||
|
||||
// Scatter and gather do not support Direct
|
||||
template <int Recv, int Send>
|
||||
inline __device__ void
|
||||
// Scatter/Gather generic op
|
||||
template <int DirectRecv1, int DirectSend1, int Recv, int Send>
|
||||
__device__ __forceinline__ void
|
||||
ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) {
|
||||
constexpr int DirectRecv = 1 && Direct && DirectRecv1;
|
||||
constexpr int DirectSend = 1 && Direct && DirectSend1;
|
||||
int offset = 0; // slice offset
|
||||
int sliceSize = stepSize*StepPerSlice;
|
||||
int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size
|
||||
@@ -243,12 +275,14 @@ class Primitives<
|
||||
for (int slice=0; slice<SlicePerChunk; ++slice) {
|
||||
int realSize = max(0, min(dataSize, peerElem-offset));
|
||||
if (tid < nworkers) {
|
||||
if (Send && (flags & RoleInput)) ncclShmem->groups[group].srcs[0] = userBuff + inpIx + offset;
|
||||
if (Recv && (flags & RoleOutput)) ncclShmem->groups[group].dsts[0] = userBuff + outIx + offset;
|
||||
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
|
||||
waitPeer<0, 0, Recv, Send, 0, 0>(0, 0, 0, realSize);
|
||||
subBarrier();
|
||||
if (Send) {
|
||||
// Scatter pre-scales data of input buffer only in non-Direct case
|
||||
constexpr int PreOpN = DirectSend ? 0 : 1;
|
||||
if (flags & RoleInput) ncclShmem->groups[group].srcs[0] = userBuff + inpIx + offset;
|
||||
if (tid == 0) ncclShmem->groups[group].totalSendSize[slice] = 0; // Skip the threadfence
|
||||
// realSize is not accurate here; but intra-node does not rely on sizes FIFO
|
||||
waitPeer<0, DirectSend, 0, 1, 1, 0>(0, inpIx, offset, realSize);
|
||||
subBarrier();
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<fan.nsend(); j++) {
|
||||
int i = (j+shift)%fan.nsend();
|
||||
@@ -256,33 +290,45 @@ class Primitives<
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
const T* src0 = (T*)ncclShmem->groups[group].srcs[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1>(tid, nworkers, redOp, 1, false, 1, &src0, 1, (T**)ncclShmem->groups[group].dsts+i, realPeerSize);
|
||||
if (realPeerSize > 0 && ncclShmem->groups[group].dsts[i] != nullptr) {
|
||||
ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, PreOpN>(tid, nworkers, ncclShmem->redOpArgs, false, 1, &src0, 1, (T**)ncclShmem->groups[group].dsts+i, realPeerSize);
|
||||
if (tid == 0) ncclShmem->groups[group].totalSendSize[slice] += realPeerSize;
|
||||
}
|
||||
}
|
||||
} else if (Recv) {
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<fan.nrecv(); j++) {
|
||||
int i = (j+shift)%fan.nrecv();
|
||||
int peerOffset = i*peerElem;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
T* dst0 = (T*)ncclShmem->groups[group].dsts[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1>(tid, nworkers, redOp, 0, postOp, 1, (T const**)ncclShmem->groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
if (flags & RoleOutput) ncclShmem->groups[group].dsts[0] = userBuff + outIx + offset;
|
||||
int peerOffset = index*peerElem;
|
||||
if (skip >= 0 && index >= skip) peerOffset += peerElem;
|
||||
// Adjust remote index with peer offset in case we are directly pulling from peer's output buffer
|
||||
waitPeer<DirectRecv, 0, 1, 0, 0, 1>(outIx, outIx+peerOffset, offset, realSize);
|
||||
subBarrier();
|
||||
if (DirectRecv && ncclShmem->groups[group].srcs[0] == ncclShmem->groups[group].dsts[0]) {
|
||||
// Since waitPeer sets srcs[0] to output buffer + offset, we are doing a direct-write based recv
|
||||
// Do nothing
|
||||
} else {
|
||||
#pragma unroll 1
|
||||
for (int j=0; j<fan.nrecv(); j++) {
|
||||
int i = (j+shift)%fan.nrecv();
|
||||
peerOffset = i*peerElem;
|
||||
if (skip >= 0 && i >= skip) peerOffset += peerElem;
|
||||
T* dst0 = (T*)ncclShmem->groups[group].dsts[0] + peerOffset;
|
||||
int realPeerSize = min(realSize, totalElem-peerOffset);
|
||||
if (realPeerSize > 0) ReduceOrCopyMulti<Unroll, RedOp, T, 1, 1, 1, 1, 0>(tid, nworkers, ncclShmem->redOpArgs, postOp, 1, (const T**)ncclShmem->groups[group].srcs+i, 1, &dst0, realPeerSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
//if (Send && (flags & RolePostSend) && realSize > 0 && index == 0) __threadfence_system();
|
||||
if (Send && (flags & RolePostSend) && ncclShmem->groups[group].totalSendSize[slice] > 0 && index == 0)
|
||||
__threadfence_system();
|
||||
__syncwarp();
|
||||
postPeer<Recv, Send>();
|
||||
offset += realSize;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadRecvConn(ncclPeer *peer) {
|
||||
__device__ __forceinline__ void loadRecvConn(ncclPeer *peer, int connIndex, struct ncclWorkElem* e) {
|
||||
if (flags & (RoleWaitRecv|RolePostRecv)) {
|
||||
// For other colls: group <= 2, hence always use conn 0
|
||||
// For P2P: Direct is set to 1, hence always use conn 0
|
||||
// Ideally we should be accepting connIndex from the constructor!
|
||||
auto *conn = &peer->recv[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = roundUp(step, SlicePerChunk*StepPerSlice);
|
||||
@@ -295,7 +341,25 @@ class Primitives<
|
||||
connStepPtr = conn->tail;
|
||||
connStepCache = LOAD(connStepPtr);
|
||||
flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0;
|
||||
flags |= (Direct && (conn->direct & NCCL_DIRECT_GPU)) ? DirectEnabled : 0;
|
||||
if (Direct) {
|
||||
// User buffers have been registered
|
||||
if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) {
|
||||
if (connIndex == 1) {
|
||||
flags |= DirectRead; // scatter-reduce use direct pull
|
||||
} else {
|
||||
flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite :
|
||||
(e->direct & NCCL_DIRECT_READ) ? DirectRead : 0;
|
||||
}
|
||||
} else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) {
|
||||
if (connIndex == 1) {
|
||||
flags |= DirectRead; // scatter-reduce use direct pull
|
||||
} else {
|
||||
// direct read not allowed in non-register case
|
||||
// otherwise, in one-to-multi send, we could mix empty send and intermediate send
|
||||
flags |= (conn->direct & NCCL_DIRECT_WRITE) ? DirectWrite : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (flags & PtrsFifoEnabled)
|
||||
connPtrsFifoPtr = conn->ptrsFifo;
|
||||
else
|
||||
@@ -304,11 +368,8 @@ class Primitives<
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadSendConn(ncclPeer *peer) {
|
||||
__device__ __forceinline__ void loadSendConn(ncclPeer *peer, int connIndex, struct ncclWorkElem* e) {
|
||||
if (flags & (RoleWaitSend|RolePostSend)) {
|
||||
// For other colls: group <= 2, hence always use conn 0
|
||||
// For P2P: Direct is set to 1, hence always use conn 0
|
||||
// Ideally we should be accepting connIndex from the constructor!
|
||||
auto *conn = &peer->send[connIndex].conn;
|
||||
step = conn->step;
|
||||
step = roundUp(step, SlicePerChunk*StepPerSlice);
|
||||
@@ -328,9 +389,25 @@ class Primitives<
|
||||
if (conn->sizesFifo != nullptr) {
|
||||
flags |= SizesFifoEnabled;
|
||||
connSizesFifoPtr = conn->sizesFifo;
|
||||
} else if (Direct) {
|
||||
// User buffers have been registered
|
||||
if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) {
|
||||
if (connIndex == 1) {
|
||||
flags |= DirectRead; // scatter-reduce use direct pull
|
||||
} else {
|
||||
flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite :
|
||||
(e->direct & NCCL_DIRECT_READ) ? DirectRead : 0;
|
||||
}
|
||||
} else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) {
|
||||
if (connIndex == 1) {
|
||||
flags |= DirectRead; // scatter-reduce use direct pull
|
||||
} else {
|
||||
// direct read not allowed in non-register case
|
||||
// otherwise, in one-to-multi send, we could mix empty send and intermediate send
|
||||
flags |= (conn->direct & NCCL_DIRECT_WRITE) ? DirectWrite : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (Direct && (conn->direct & NCCL_DIRECT_GPU))
|
||||
flags |= DirectEnabled;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -338,19 +415,19 @@ class Primitives<
|
||||
public:
|
||||
__device__ Primitives(
|
||||
int tid, int nthreads, int const *recvPeers, int const *sendPeers,
|
||||
void const *inputBuf, void *outputBuf, int group=0, int connIndex=0
|
||||
void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr
|
||||
):
|
||||
tid(tid),
|
||||
stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)),
|
||||
redOp(FuncTraits<RedOp>::make(ncclShmem->comm.nRanks)),
|
||||
connIndex((NCCL_MAX_DIRECT_ARITY==Fan::MaxSend || NCCL_MAX_DIRECT_ARITY==Fan::MaxRecv)?(group/2):connIndex),
|
||||
barriers(&ncclShmem->groups[group].barrier), barrier_next(ncclShmem->groups[group].barrier_next),
|
||||
opCount(ncclShmem->work.elems[0].op.opCount) {
|
||||
opCount(ncclShmem->work.elems[0].coll.opCount) {
|
||||
|
||||
// For send operations, we need an extra warp to overlap the threadfence and the copy
|
||||
this->nthreads = nthreads;
|
||||
this->nworkers = nthreads;
|
||||
this->group = group;
|
||||
this->group = group & (uint16_t)0xFFFF;
|
||||
int connIndex = group >> 16;
|
||||
barriers = &ncclShmem->groups[this->group].barrier;
|
||||
barrier_next = ncclShmem->groups[this->group].barrier_next;
|
||||
|
||||
int nrecv=0, nsend=0;
|
||||
while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
|
||||
@@ -380,10 +457,10 @@ class Primitives<
|
||||
if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index];
|
||||
if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index];
|
||||
|
||||
loadRecvConn(&ncclShmem->channel.devPeers[peer]);
|
||||
loadSendConn(&ncclShmem->channel.devPeers[peer]);
|
||||
loadRecvConn(&ncclShmem->channel.devPeers[peer], connIndex, e);
|
||||
loadSendConn(&ncclShmem->channel.devPeers[peer], connIndex, e);
|
||||
|
||||
setDataPtrs(inputBuf, outputBuf);
|
||||
setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclWorkRegElem*)e);
|
||||
}
|
||||
|
||||
__device__ ~Primitives() {
|
||||
@@ -400,10 +477,19 @@ class Primitives<
|
||||
barrier();
|
||||
}
|
||||
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf) {
|
||||
if (flags & RoleInput) userBuff = (T*)inputBuf;
|
||||
__device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclWorkRegElem* e) {
|
||||
if (flags & RoleInput) {
|
||||
userBuff = (T*)inputBuf;
|
||||
ncclShmem->redOpArgs[0] = redOpArg; // scaler for local input
|
||||
}
|
||||
if (flags & RoleOutput) userBuff = (T*)outputBuf;
|
||||
if (Direct && flags == (flags|RoleWaitRecv|DirectEnabled)) {
|
||||
bool recvProvider = flags == (flags|RoleWaitRecv|DirectWrite);
|
||||
bool sendAcceptor = flags == (flags|RoleWaitSend|DirectWrite);
|
||||
bool sendProvider = flags == (flags|RoleWaitSend|DirectRead); // sender provides direct buffer (to be fetched)
|
||||
bool recvAcceptor = flags == (flags|RoleWaitRecv|DirectRead); // receiver accepts direct buffer
|
||||
int regUsed = e != nullptr ? e->elem.regUsed : 0;
|
||||
|
||||
if (Direct && recvProvider) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem->groups[group].recvConns[index]->ptrExchange;
|
||||
// Wait for consumer to consume previous value before trampling it.
|
||||
@@ -412,9 +498,9 @@ class Primitives<
|
||||
// Encode pointer by XOR'ing against some address they definitely wouldn't send
|
||||
// since we want to allow them sending us nullptr while not colliding with
|
||||
// the empty slot value.
|
||||
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(outputBuf) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(directBuff) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
}
|
||||
if (Direct && flags == (flags|RoleWaitSend|DirectEnabled)) {
|
||||
if (Direct && sendAcceptor) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem->groups[group].sendConns[index]->ptrExchange;
|
||||
void *ptr;
|
||||
@@ -422,7 +508,51 @@ class Primitives<
|
||||
ptr = LOAD(slot);
|
||||
if (ptr != nullptr || checkAbort(spins)) break;
|
||||
}
|
||||
directBuff = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
directBuff = regUsed ? (T*)(e->dnOutputs[index]) :
|
||||
reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
*slot = nullptr;
|
||||
}
|
||||
if (Direct && sendProvider) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem->groups[group].sendConns[index]->ptrExchange;
|
||||
volatile uint64_t* argSlot0 = ncclShmem->groups[group].sendConns[index]->redOpArgExchange;
|
||||
volatile uint64_t* argSlot1 = ncclShmem->groups[group].sendConns[index]->redOpArgExchange+1;
|
||||
// Wait for consumer to consume previous value before trampling it.
|
||||
while ((*slot != nullptr || *argSlot0 != 0 || *argSlot1 !=0) && !checkAbort(spins));
|
||||
// If there is no recv, then we are directly pulling from input buffer (e.g. directScatter)
|
||||
// Otherwise, we are pulling from output buffer (e.g. recvCopyDirectSend)
|
||||
directBuff = MaxRecv == 0 ? (T*)inputBuf : (T*)outputBuf;
|
||||
// Exchange pre-scalers for use in direct pull
|
||||
*argSlot0 = (uint64_t(1)<<32) | (uint32_t)redOpArg;
|
||||
*argSlot1 = (uint64_t(1)<<32) | (uint32_t)(redOpArg>>32);
|
||||
// Encode pointer by XOR'ing against some address they definitely wouldn't send
|
||||
// since we want to allow them sending us nullptr while not colliding with
|
||||
// the empty slot value.
|
||||
*slot = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(directBuff) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
}
|
||||
if (Direct && recvAcceptor) {
|
||||
int spins = 0;
|
||||
void *volatile *slot = ncclShmem->groups[group].recvConns[index]->ptrExchange;
|
||||
volatile uint64_t* argSlot0 = ncclShmem->groups[group].recvConns[index]->redOpArgExchange;
|
||||
volatile uint64_t* argSlot1 = ncclShmem->groups[group].recvConns[index]->redOpArgExchange+1;
|
||||
void *ptr;
|
||||
while (true) {
|
||||
ptr = *slot;
|
||||
if (ptr != nullptr || checkAbort(spins)) break;
|
||||
}
|
||||
directBuff = regUsed ? (T*)(MaxSend == 0 ? e->upOutputs[index] : e->dnInputs[index]) :
|
||||
reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(ptr) ^ reinterpret_cast<uintptr_t>(slot));
|
||||
if (MaxSend != 0) { // reduce group rather than gather group
|
||||
// Store scalers for remote inputs
|
||||
uint64_t arg0, arg1;
|
||||
while (true) {
|
||||
arg0 = *argSlot0;
|
||||
arg1 = *argSlot1;
|
||||
if ((arg0 != 0 && arg1 != 0) || checkAbort(spins)) break;
|
||||
}
|
||||
ncclShmem->redOpArgs[1+index] = ((arg1 & 0xffffffff)<<32) | (arg0 & 0xffffffff);
|
||||
}
|
||||
*argSlot0 = 0; *argSlot1 = 0;
|
||||
*slot = nullptr;
|
||||
}
|
||||
}
|
||||
@@ -465,6 +595,9 @@ class Primitives<
|
||||
__device__ __forceinline__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) {
|
||||
genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, false);
|
||||
}
|
||||
__device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
@@ -473,6 +606,9 @@ class Primitives<
|
||||
__device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, -1, eltN, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, intptr_t remoteInpIx, int eltN, bool postOp=false) {
|
||||
genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, remoteInpIx, eltN, postOp);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) {
|
||||
genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp);
|
||||
@@ -484,11 +620,19 @@ class Primitives<
|
||||
|
||||
__device__ __forceinline__ void
|
||||
scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false);
|
||||
ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directScatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) {
|
||||
ScatterGatherOp<1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp);
|
||||
ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp);
|
||||
}
|
||||
__device__ __forceinline__ void
|
||||
directGather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift) {
|
||||
ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, /*postOp=*/false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -27,7 +27,7 @@ namespace {
|
||||
const int root = args->coll.root;
|
||||
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->coll.redOpArg, args->coll.connIndex << 16);
|
||||
|
||||
auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int {
|
||||
int realChunkSize;
|
||||
@@ -71,7 +71,7 @@ namespace {
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<REDUCE_CHUNKSTEPS/REDUCE_SLICESTEPS, REDUCE_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
@@ -79,14 +79,14 @@ struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPL
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
template<typename T>
|
||||
struct FuncNull {
|
||||
__device__ FuncNull(uint64_t opArg=0) {}
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return 0;
|
||||
}
|
||||
@@ -22,6 +23,7 @@ struct FuncNull {
|
||||
|
||||
template<typename T>
|
||||
struct FuncSum {
|
||||
__device__ FuncSum(uint64_t opArg=0) {}
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return x + y;
|
||||
}
|
||||
@@ -29,6 +31,7 @@ struct FuncSum {
|
||||
|
||||
template<typename T>
|
||||
struct FuncProd {
|
||||
__device__ FuncProd(uint64_t opArg=0) {}
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return x * y;
|
||||
}
|
||||
@@ -36,6 +39,7 @@ struct FuncProd {
|
||||
|
||||
template<typename T>
|
||||
struct FuncMax {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return (x < y) ? y : x;
|
||||
}
|
||||
@@ -43,6 +47,7 @@ struct FuncMax {
|
||||
|
||||
template<typename T>
|
||||
struct FuncMin {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
__device__ T operator()(const T x, const T y) const {
|
||||
return (x < y) ? x : y;
|
||||
}
|
||||
@@ -53,7 +58,6 @@ struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
__device__ static Fn make(int rankN) { return Fn(); }
|
||||
template<typename T>
|
||||
__device__ static T preOp(Fn, T x) { return x; }
|
||||
template<typename T>
|
||||
@@ -75,6 +79,7 @@ static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) {
|
||||
|
||||
template<>
|
||||
struct FuncSum<int8_t> {
|
||||
__device__ FuncSum(uint64_t opArg=0) {}
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
|
||||
int32_t rv, z=0;
|
||||
@@ -90,6 +95,7 @@ struct FuncSum<int8_t> {
|
||||
};
|
||||
template<>
|
||||
struct FuncSum<uint8_t> {
|
||||
__device__ FuncSum(uint64_t opArg=0) {}
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
|
||||
int32_t rv, z=0;
|
||||
@@ -119,6 +125,7 @@ static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) {
|
||||
|
||||
template<>
|
||||
struct FuncProd<int8_t> {
|
||||
__device__ FuncProd(uint64_t opArg=0) {}
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
return mulChar4(x, y);
|
||||
}
|
||||
@@ -128,6 +135,7 @@ struct FuncProd<int8_t> {
|
||||
};
|
||||
template<>
|
||||
struct FuncProd<uint8_t> {
|
||||
__device__ FuncProd(uint64_t opArg=0) {}
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
return mulChar4(x, y);
|
||||
}
|
||||
@@ -138,6 +146,7 @@ struct FuncProd<uint8_t> {
|
||||
|
||||
template<>
|
||||
struct FuncMax<int8_t> {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
union converter { uint32_t storage; char4 a; };
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
|
||||
@@ -161,6 +170,7 @@ struct FuncMax<int8_t> {
|
||||
};
|
||||
template<>
|
||||
struct FuncMax<uint8_t> {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
union converter { uint32_t storage; uchar4 a; };
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
|
||||
@@ -185,6 +195,7 @@ struct FuncMax<uint8_t> {
|
||||
|
||||
template<>
|
||||
struct FuncMin<int8_t> {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
union converter { uint32_t storage; char4 a; };
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
|
||||
@@ -208,6 +219,7 @@ struct FuncMin<int8_t> {
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<uint8_t> {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
union converter { uint32_t storage; uchar4 a; };
|
||||
__device__ uint32_t operator()(const uint32_t x, const uint32_t y) const {
|
||||
#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500)
|
||||
@@ -232,6 +244,7 @@ struct FuncMin<uint8_t> {
|
||||
|
||||
template<>
|
||||
struct FuncSum<half> {
|
||||
__device__ FuncSum(uint64_t opArg=0) {}
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
|
||||
return __hadd2(x, y);
|
||||
@@ -256,18 +269,16 @@ struct FuncSum<half> {
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncSum<rccl_bfloat16> {
|
||||
__device__ FuncSum(uint64_t opArg=0) {}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hadd(x, y);
|
||||
#else
|
||||
return x + y;
|
||||
#endif
|
||||
return (rccl_bfloat16)((float)x + (float)y);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncProd<half> {
|
||||
__device__ FuncProd(uint64_t opArg=0) {}
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
|
||||
return __hmul2(x, y);
|
||||
@@ -292,18 +303,16 @@ struct FuncProd<half> {
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncProd<rccl_bfloat16> {
|
||||
__device__ FuncProd(uint64_t opArg=0) {}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmul(x, y);
|
||||
#else
|
||||
return x * y;
|
||||
#endif
|
||||
return (rccl_bfloat16)((float)x * (float)y);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMax<half> {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
float2 fx, fy, fr;
|
||||
fx = __half22float2(x);
|
||||
@@ -324,18 +333,16 @@ struct FuncMax<half> {
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncMax<rccl_bfloat16> {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmax(x, y);
|
||||
#else
|
||||
return x < y ? y : x;
|
||||
#endif
|
||||
return (float)x < (float)y ? y : x;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMin<half> {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
__device__ half2 operator()(const half2 x, const half2 y) const {
|
||||
float2 fx, fy, fr;
|
||||
fx = __half22float2(x);
|
||||
@@ -356,24 +363,23 @@ struct FuncMin<half> {
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncMin<rccl_bfloat16> {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __hmin(x, y);
|
||||
#else
|
||||
return x < y ? x : y;
|
||||
#endif
|
||||
return (float)x < (float)y ? x : y;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template<>
|
||||
struct FuncMax<float> {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
__device__ float operator()(float x, float y) const {
|
||||
return fmaxf(x, y);
|
||||
}
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<float> {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
__device__ float operator()(float x, float y) const {
|
||||
return fminf(x, y);
|
||||
}
|
||||
@@ -381,71 +387,98 @@ struct FuncMin<float> {
|
||||
|
||||
template<>
|
||||
struct FuncMax<double> {
|
||||
__device__ FuncMax(uint64_t opArg=0) {}
|
||||
__device__ double operator()(double x, double y) const {
|
||||
return fmax(x, y);
|
||||
}
|
||||
};
|
||||
template<>
|
||||
struct FuncMin<double> {
|
||||
__device__ FuncMin(uint64_t opArg=0) {}
|
||||
__device__ double operator()(double x, double y) const {
|
||||
return fmin(x, y);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct FuncAvg: FuncSum<T> {
|
||||
static_assert(!std::is_floating_point<T>::value, "Uhoh");
|
||||
struct IsFloatingPoint: std::false_type {};
|
||||
template<>
|
||||
struct IsFloatingPoint<half>: std::true_type {};
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct IsFloatingPoint<rccl_bfloat16>: std::true_type {};
|
||||
#endif
|
||||
template<>
|
||||
struct IsFloatingPoint<float>: std::true_type {};
|
||||
template<>
|
||||
struct IsFloatingPoint<double>: std::true_type {};
|
||||
|
||||
template<typename T, bool IsFloating=IsFloatingPoint<T>::value>
|
||||
struct FuncSumPostDiv;
|
||||
|
||||
template<typename T>
|
||||
struct FuncSumPostDiv<T, /*IsFloating=*/false>: FuncSum<T> {
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = false;
|
||||
int n;
|
||||
__device__ FuncSumPostDiv(uint64_t opArg): n(opArg) {}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ T preOp(T x) const { return x; }
|
||||
__device__ T postOp(T x) const { return T(x/n); }
|
||||
};
|
||||
|
||||
template<typename ...Arg>
|
||||
__device__ FuncAvg(int n): n(n) {}
|
||||
template<typename T>
|
||||
struct FuncSumPostDiv<T, /*IsFloating=*/true> {
|
||||
static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types.");
|
||||
};
|
||||
|
||||
__device__ T preOp(T x) const {
|
||||
return x;
|
||||
}
|
||||
__device__ T postOp(T x) const {
|
||||
return T(x/n);
|
||||
}
|
||||
template<typename T>
|
||||
struct FuncPreMulSum: FuncSum<T> { // integral T since all floats are specialized below
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
T scale;
|
||||
__device__ FuncPreMulSum(uint64_t opArg) { scale = *(T*)&opArg; }
|
||||
// inherits FuncSum::operator()
|
||||
__device__ T preOp(T x) const { return x*scale; }
|
||||
__device__ T postOp(T x) const { return x; }
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<double>: FuncSum<double> {
|
||||
struct FuncPreMulSum<double>: FuncSum<double> {
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
double rcp;
|
||||
__device__ FuncAvg(int n) {
|
||||
rcp = __drcp_rn(double(n));
|
||||
double scale;
|
||||
__device__ FuncPreMulSum(uint64_t opArg) {
|
||||
scale = *(double*)&opArg;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ double preOp(double x) const {
|
||||
return IsPreOpIdentity ? x : x*rcp;
|
||||
return IsPreOpIdentity ? x : x*scale;
|
||||
}
|
||||
__device__ double postOp(double x) const {
|
||||
return IsPostOpIdentity ? x : x*rcp;
|
||||
return IsPostOpIdentity ? x : x*scale;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<float>: FuncSum<float> {
|
||||
struct FuncPreMulSum<float>: FuncSum<float> {
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
float rcp;
|
||||
__device__ FuncAvg(int n) {
|
||||
rcp = __frcp_rn(float(n));
|
||||
float scale;
|
||||
__device__ FuncPreMulSum(uint64_t opArg) {
|
||||
scale = *(float*)&opArg;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ float preOp(float x) const {
|
||||
return IsPreOpIdentity ? x : x*rcp;
|
||||
return IsPreOpIdentity ? x : x*scale;
|
||||
}
|
||||
__device__ float postOp(float x) const {
|
||||
return IsPostOpIdentity ? x : x*rcp;
|
||||
return IsPostOpIdentity ? x : x*scale;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct FuncAvg<half>: FuncSum<half> {
|
||||
struct FuncPreMulSum<half>: FuncSum<half> {
|
||||
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
|
||||
// Obviously, the only invalid combination is both true. An improvement would be
|
||||
// make this parameterized as a build time setting and passed here through
|
||||
@@ -455,11 +488,8 @@ struct FuncAvg<half>: FuncSum<half> {
|
||||
|
||||
#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610
|
||||
half2 scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale.x = __float2half(__frsqrt_rn(float(n)));
|
||||
else
|
||||
scale.x = __float2half(__frcp_rn(float(n)));
|
||||
__device__ FuncPreMulSum(uint64_t opArg) {
|
||||
scale.x = *(half*)&opArg;
|
||||
scale.y = scale.x;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
@@ -477,11 +507,8 @@ struct FuncAvg<half>: FuncSum<half> {
|
||||
}
|
||||
#else
|
||||
float scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale = __frsqrt_rn(float(n));
|
||||
else
|
||||
scale = __frcp_rn(float(n));
|
||||
__device__ FuncPreMulSum(uint64_t opArg) {
|
||||
scale = __half2float(*(half*)&opArg);
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ half preOp(half x) const {
|
||||
@@ -515,64 +542,54 @@ struct FuncAvg<half>: FuncSum<half> {
|
||||
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
template<>
|
||||
struct FuncAvg<rccl_bfloat16>: FuncSum<rccl_bfloat16> {
|
||||
struct FuncPreMulSum<rccl_bfloat16>: FuncSum<rccl_bfloat16> {
|
||||
// Change these to switch between all prescale, all postscale, or both by sqrt(N).
|
||||
// Obviously, the only invalid combination is both true. An improvement would be
|
||||
// make this parameterized as a build time setting and passed here through
|
||||
// preprocessor definitions.
|
||||
static constexpr bool IsPreOpIdentity = true;
|
||||
static constexpr bool IsPostOpIdentity = false;
|
||||
static constexpr bool IsPreOpIdentity = false;
|
||||
static constexpr bool IsPostOpIdentity = true;
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale.x = __float2bfloat16(__frsqrt_rn(float(n)));
|
||||
else
|
||||
scale.x = __float2bfloat16(__frcp_rn(float(n)));
|
||||
scale.y = scale.x;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const {
|
||||
return IsPreOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
__device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const {
|
||||
return IsPostOpIdentity ? x : __hmul(x, scale.x);
|
||||
}
|
||||
#else
|
||||
float scale;
|
||||
__device__ FuncAvg(int n) {
|
||||
if (!IsPreOpIdentity && !IsPostOpIdentity)
|
||||
scale = __frsqrt_rn(float(n));
|
||||
else
|
||||
scale = __frcp_rn(float(n));
|
||||
__device__ FuncPreMulSum(uint64_t opArg) {
|
||||
scale = *(rccl_bfloat16*)&opArg;
|
||||
}
|
||||
// inherits FuncSum::operator()
|
||||
__device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const {
|
||||
return IsPreOpIdentity ? x : (rccl_bfloat16)(x*scale);
|
||||
return IsPreOpIdentity ? x : (rccl_bfloat16)((float)x*scale);
|
||||
}
|
||||
__device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const {
|
||||
return IsPostOpIdentity ? x : (rccl_bfloat16)(x*scale);
|
||||
return IsPostOpIdentity ? x : (rccl_bfloat16)((float)x*scale);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
struct FuncTraits<FuncAvg<T>> {
|
||||
static constexpr bool IsPreOpIdentity = FuncAvg<T>::IsPreOpIdentity;
|
||||
static constexpr bool IsPostOpIdentity = FuncAvg<T>::IsPostOpIdentity;
|
||||
struct FuncTraits<FuncPreMulSum<T>> {
|
||||
static constexpr bool IsPreOpIdentity = FuncPreMulSum<T>::IsPreOpIdentity;
|
||||
static constexpr bool IsPostOpIdentity = FuncPreMulSum<T>::IsPostOpIdentity;
|
||||
|
||||
__device__ static FuncAvg<T> make(int rankN) {
|
||||
return FuncAvg<T>(rankN);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U preOp(FuncAvg<T> fn, U x) {
|
||||
__device__ static U preOp(FuncPreMulSum<T> fn, U x) {
|
||||
return fn.preOp(x);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U postOp(FuncAvg<T> fn, U x) {
|
||||
__device__ static U postOp(FuncPreMulSum<T> fn, U x) {
|
||||
return fn.postOp(x);
|
||||
}
|
||||
};
|
||||
template<typename T>
|
||||
struct FuncTraits<FuncSumPostDiv<T>> {
|
||||
static constexpr bool IsPreOpIdentity = FuncSumPostDiv<T>::IsPreOpIdentity;
|
||||
static constexpr bool IsPostOpIdentity = FuncSumPostDiv<T>::IsPostOpIdentity;
|
||||
|
||||
template<typename U>
|
||||
__device__ static U preOp(FuncSumPostDiv<T> fn, U x) {
|
||||
return fn.preOp(x);
|
||||
}
|
||||
template<typename U>
|
||||
__device__ static U postOp(FuncSumPostDiv<T> fn, U x) {
|
||||
return fn.postOp(x);
|
||||
}
|
||||
};
|
||||
#endif // REDUCE_KERNEL_H_
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
namespace {
|
||||
template<typename T, typename RedOp, typename Proto>
|
||||
__device__ void runRing(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void runRing(ncclWorkElem *args) {
|
||||
const int tid = threadIdx.x;
|
||||
const int nthreads = args->nThreads;
|
||||
const int bid = args->coll.bid;
|
||||
@@ -26,7 +26,7 @@ namespace {
|
||||
const ssize_t size = args->coll.count;
|
||||
|
||||
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto>
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex);
|
||||
prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->coll.redOpArg, args->coll.connIndex << 16);
|
||||
|
||||
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
|
||||
ssize_t realChunkSize;
|
||||
@@ -69,7 +69,7 @@ namespace {
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
using Proto = ProtoSimple<REDUCESCATTER_CHUNKSTEPS/REDUCESCATTER_SLICESTEPS, REDUCESCATTER_SLICESTEPS>;
|
||||
runRing<T, RedOp, Proto>(args);
|
||||
}
|
||||
@@ -77,14 +77,14 @@ struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROT
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL>(args);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWorkElement<ncclFuncReduceScatter, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL128> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
|
||||
__device__ __forceinline__ void run(ncclWorkElem *args) {
|
||||
runRing<T, RedOp, ProtoLL128>(args);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
template<typename T, typename RedOp>
|
||||
struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
__device__ __attribute__((noinline)) void run(ncclWork *work) {
|
||||
__device__ __forceinline__ void run(ncclWork *work) {
|
||||
int tid = threadIdx.x;
|
||||
int group = 0;
|
||||
const int rank = ncclShmem->comm.rank;
|
||||
@@ -39,16 +39,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
|
||||
if (delta == 0) {
|
||||
if (sendbuff != recvbuff) {
|
||||
// local copy : ReduceOrCopyMulti takes an int as number of elements,
|
||||
// so we split it in blocks of 1G elements.
|
||||
int blockSize = 1<<30;
|
||||
for (size_t offset=0; offset<sendCount; offset += blockSize) {
|
||||
size_t remaining = sendCount - offset;
|
||||
if (remaining < blockSize) blockSize = remaining;
|
||||
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1>(tid, nThreadsSegment, RedOp(), 0, false, 1, &sendbuff, 1, &recvbuff, blockSize);
|
||||
sendbuff += blockSize;
|
||||
recvbuff += blockSize;
|
||||
}
|
||||
ReduceOrCopyMulti<COLL_UNROLL, RedOp, T, 1, 1, 1, 1, 0>(tid, nThreadsSegment, nullptr, false, 1, &sendbuff, 1, &recvbuff, sendCount);
|
||||
}
|
||||
}
|
||||
else {
|
||||
@@ -58,7 +49,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
int const nt = nThreadsSplit;
|
||||
int const chunkSize = args->p2p.recvChunkSize/sizeof(T);
|
||||
Primitives<T, RedOp, FanAsymmetric<1, 0>, 0, Proto> prims
|
||||
(tid-t0, nt, &peer, nullptr, nullptr, recvbuff, groupRecv, args->p2p.recvIdx);
|
||||
(tid-t0, nt, &peer, nullptr, nullptr, recvbuff, /*redOpArg(ignored)=*/0, groupRecv | (args->p2p.recvIdx << 16));
|
||||
ssize_t offset = 0;
|
||||
do {
|
||||
int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T)));
|
||||
@@ -74,7 +65,7 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
|
||||
int const nt = nThreadsSegment - nThreadsSplit;
|
||||
int const chunkSize = args->p2p.sendChunkSize/sizeof(T);
|
||||
Primitives<T, RedOp, FanAsymmetric<0, 1>, 0, Proto> prims
|
||||
(tid-t0, nt, nullptr, &peer, sendbuff, nullptr, groupSend, args->p2p.sendIdx);
|
||||
(tid-t0, nt, nullptr, &peer, sendbuff, nullptr, /*redOpArg(ignored)=*/0, groupSend | (args->p2p.sendIdx << 16));
|
||||
ssize_t offset = 0;
|
||||
do {
|
||||
int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T)));
|
||||
|
||||
@@ -139,7 +139,7 @@ void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *file
|
||||
int cudaDev;
|
||||
hipGetDevice(&cudaDev);
|
||||
int pid = getpid();
|
||||
int tid = gettid();
|
||||
int tid = syscall(SYS_gettid);
|
||||
|
||||
char buffer[1024];
|
||||
size_t len = 0;
|
||||
|
||||
+442
-203
@@ -12,81 +12,61 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_ext.h>
|
||||
#include "gdrwrap.h"
|
||||
#include "bootstrap.h"
|
||||
#include <cstring>
|
||||
|
||||
#include <cstring> // std::memcpy
|
||||
|
||||
// Only generate inline kernels for LL
|
||||
#define NCCL_FUNC5(func, algo, redop, dtype) \
|
||||
NCCL_KERN_NAME(func, algo, LL, redop, dtype), \
|
||||
NCCL_KERN_NAME(func, algo, LL, redop, dtype), \
|
||||
NCCL_KERN_NAME(func, algo, LL, redop, dtype)
|
||||
#define NCCL_FUNC5(func, algo, devredop, dtype) \
|
||||
(void*)NCCL_KERN_NAME(func, algo, LL, devredop, dtype), \
|
||||
(void*)NCCL_KERN_NAME(func, algo, LL, devredop, dtype), \
|
||||
(void*)NCCL_KERN_NAME(func, algo, LL, devredop, dtype)
|
||||
|
||||
#define NCCL_FUNC4(func, redop, type) \
|
||||
NCCL_FUNC5(func, TREE, redop, type), \
|
||||
NCCL_FUNC5(func, RING, redop, type), \
|
||||
NCCL_FUNC5(func, COLLNET, redop, type)
|
||||
#define NCCL_FUNC4(func, devredop, type) \
|
||||
(void*)NCCL_FUNC5(func, TREE, devredop, type), \
|
||||
(void*)NCCL_FUNC5(func, RING, devredop, type), \
|
||||
(void*)NCCL_FUNC5(func, COLLNET, devredop, type)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, uint8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int32_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, uint32_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int64_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, uint64_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, half), \
|
||||
(void*)NCCL_FUNC4(func, redop, float), \
|
||||
(void*)NCCL_FUNC4(func, redop, double), \
|
||||
(void*)NCCL_FUNC4(func, redop, __nv_bfloat16)
|
||||
#define NCCL_FUNCS3B(func, redop) \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, redop, int8_t)
|
||||
#else
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, uint8_t), \
|
||||
NCCL_FUNC4(func, redop, int32_t), \
|
||||
NCCL_FUNC4(func, redop, uint32_t), \
|
||||
NCCL_FUNC4(func, redop, int64_t), \
|
||||
NCCL_FUNC4(func, redop, uint64_t), \
|
||||
NCCL_FUNC4(func, redop, half), \
|
||||
NCCL_FUNC4(func, redop, float), \
|
||||
NCCL_FUNC4(func, redop, double), \
|
||||
NCCL_FUNC4(func, redop, rccl_bfloat16)
|
||||
#define NCCL_FUNCS3B(func, redop) \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t), \
|
||||
NCCL_FUNC4(func, redop, int8_t)
|
||||
#endif
|
||||
#define NCCL_FUNCS3A(func, devredop) \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, uint8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int32_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, uint32_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int64_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, uint64_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, half), \
|
||||
(void*)NCCL_FUNC4(func, devredop, float), \
|
||||
(void*)NCCL_FUNC4(func, devredop, double), \
|
||||
(void*)NCCL_FUNC4(func, devredop, rccl_bfloat16)
|
||||
#define NCCL_FUNCS3B(func, devredop) \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t), \
|
||||
(void*)NCCL_FUNC4(func, devredop, int8_t)
|
||||
|
||||
// Must be consistent with ncclRedOp_t -- but we only generate kernel for sums.
|
||||
// Must be consistent with ncclDevRedOp_t -- but we only generate kernel for sums.
|
||||
#define NCCL_FUNCS2A(func) \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum)
|
||||
NCCL_FUNCS3A(func, Sum), /*Sum*/ \
|
||||
NCCL_FUNCS3A(func, Sum), /*Prod*/ \
|
||||
NCCL_FUNCS3A(func, Sum), /*Max*/ \
|
||||
NCCL_FUNCS3A(func, Sum), /*Min*/ \
|
||||
NCCL_FUNCS3A(func, Sum), /*PreMulSum*/ \
|
||||
NCCL_FUNCS3A(func, Sum) /*SumPostDiv*/
|
||||
#define NCCL_FUNCS2B(func) \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3A(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum), \
|
||||
NCCL_FUNCS3B(func, Sum)
|
||||
NCCL_FUNCS3B(func, Sum), /*Sum*/ \
|
||||
NCCL_FUNCS3B(func, Sum), /*Prod*/ \
|
||||
NCCL_FUNCS3B(func, Sum), /*Max*/ \
|
||||
NCCL_FUNCS3B(func, Sum), /*Min*/ \
|
||||
NCCL_FUNCS3B(func, Sum), /*PreMulSum*/ \
|
||||
NCCL_FUNCS3B(func, Sum) /*SumPostDiv*/
|
||||
|
||||
typedef void(*ncclKern_t)(struct ncclWorkElem first);
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
@@ -145,7 +125,6 @@ static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** wor
|
||||
// Initialize with work elem if provided
|
||||
if (base) memcpy(e, base, sizeof(struct ncclWorkElem));
|
||||
e->active = 1;
|
||||
e->index = opIndex;
|
||||
channel->workFifoTail++;
|
||||
channel->workCount++;
|
||||
if (work) *work = w;
|
||||
@@ -183,10 +162,11 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph
|
||||
|
||||
if (c == 0) {
|
||||
// As we inline the first coll directly, we can free it immediately.
|
||||
// Except P2P or aggregation cases
|
||||
// Except P2P or aggregation or registration cases
|
||||
struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS);
|
||||
struct ncclWorkElem* elem = work->elems;
|
||||
if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1) elem->active = 0;
|
||||
if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1 && elem->regUsed == 0)
|
||||
elem->active = 0;
|
||||
}
|
||||
|
||||
if (channel->gdrMemDesc) {
|
||||
@@ -371,7 +351,7 @@ RCCL_PARAM(SharpThreshold, "SHARP_THRESHOLD", 16384);
|
||||
|
||||
static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetTypeSupport) {
|
||||
if (info->comm->collNetSupport > 0 && info->nBytes < rcclParamSharpThreshold()) {
|
||||
ncclRedOp_t netOp = info->op == ncclAvg ? ncclSum : info->op;
|
||||
ncclRedOp_t netOp = info->op == ncclAvg || info->op >= ncclNumOps ? ncclSum : info->op;
|
||||
NCCLCHECK(collNetReduceSupport(info->datatype, netOp, collNetTypeSupport));
|
||||
} else {
|
||||
*collNetTypeSupport = 0;
|
||||
@@ -381,30 +361,35 @@ static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNet
|
||||
|
||||
static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) {
|
||||
struct ncclComm* comm = info->comm;
|
||||
float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete.
|
||||
// Find algorithm / protocol.
|
||||
info->algorithm = -1;
|
||||
info->protocol = -1;
|
||||
if (comm->nRanks == 1) return ncclSuccess;
|
||||
int nAlgos = NCCL_NUM_ALGORITHMS;
|
||||
for (int a=0; a<nAlgos; a++) {
|
||||
if (a == NCCL_ALGO_COLLNET && collNetTypeSupport != 1) continue;
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
float time;
|
||||
NCCLCHECK(ncclTopoGetAlgoTime(info, a, p, numPipeOps, &time));
|
||||
if (time >= 0 && time < minTime) {
|
||||
info->algorithm = a;
|
||||
info->protocol = p;
|
||||
minTime = time;
|
||||
if (comm->nRanks == 1) {
|
||||
info->algorithm = NCCL_ALGO_RING;
|
||||
info->protocol = NCCL_PROTO_SIMPLE;
|
||||
}
|
||||
else {
|
||||
float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete.
|
||||
// Find algorithm / protocol.
|
||||
info->algorithm = -1;
|
||||
info->protocol = -1;
|
||||
int nAlgos = NCCL_NUM_ALGORITHMS;
|
||||
for (int a=0; a<nAlgos; a++) {
|
||||
if (a == NCCL_ALGO_COLLNET && collNetTypeSupport != 1) continue;
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
float time;
|
||||
NCCLCHECK(ncclTopoGetAlgoTime(info, a, p, numPipeOps, &time));
|
||||
if (time >= 0 && time < minTime) {
|
||||
info->algorithm = a;
|
||||
info->protocol = p;
|
||||
minTime = time;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (info->algorithm == -1 || info->protocol == -1) {
|
||||
WARN("Error : no algorithm/protocol available");
|
||||
return ncclInternalError;
|
||||
}
|
||||
//if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
|
||||
TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
|
||||
}
|
||||
if (info->algorithm == -1 || info->protocol == -1) {
|
||||
WARN("Error : no algorithm/protocol available");
|
||||
return ncclInternalError;
|
||||
}
|
||||
//if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
|
||||
TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime);
|
||||
|
||||
int nc = (info->nChannels > 0) ? info->nChannels : comm->nChannels;
|
||||
int nt = comm->maxThreads[info->algorithm][info->protocol];
|
||||
@@ -498,15 +483,23 @@ comp_next:
|
||||
NCCLCHECK(getPatternInfo(info));
|
||||
NCCLCHECK(getLoopInfo(info));
|
||||
|
||||
work->op.opCount = info->comm->collOpCount;
|
||||
work->coll.opCount = info->comm->collOpCount;
|
||||
work->sendbuff = info->sendbuff;
|
||||
work->recvbuff = info->recvbuff;
|
||||
work->coll.root = info->root;
|
||||
work->coll.count = info->count;
|
||||
work->coll.nChannels = info->nChannels;
|
||||
work->nThreads = info->nThreads;
|
||||
work->coll.redOpArg = info->opFull.scalarArg;
|
||||
work->redOpArgIsPtr = info->opFull.scalarArgIsPtr;
|
||||
|
||||
work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
|
||||
if (info->comm->nRanks == 1) {
|
||||
// one-rank reduce index
|
||||
work->funcIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
|
||||
|
||||
work->coll.connIndex = 0;
|
||||
proxyArgs->connIndex = 0;
|
||||
@@ -533,7 +526,7 @@ comp_next:
|
||||
info->comm->nChannels,
|
||||
&work->clique.nChannels));
|
||||
work->clique.count = info->count;
|
||||
work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol);
|
||||
work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol);
|
||||
|
||||
// Setup pointers to where all the input/output pointers will be
|
||||
NCCLCHECK(info->comm->cliqueManager->WaitForPointers(work));
|
||||
@@ -563,6 +556,8 @@ comp_next:
|
||||
while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 32768) chunkSize /= 2;
|
||||
// Use lastChunkSize as chunkSize
|
||||
work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
|
||||
// Set direct direction for broadcast-gather (read or write)
|
||||
work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? NCCL_DIRECT_WRITE : NCCL_DIRECT_READ;
|
||||
} else if (info->protocol == NCCL_PROTO_LL) {
|
||||
const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine);
|
||||
const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize;
|
||||
@@ -592,7 +587,7 @@ comp_next:
|
||||
proxyArgs->protocol = info->protocol;
|
||||
proxyArgs->dtype = info->datatype;
|
||||
proxyArgs->redOp = info->algorithm != NCCL_ALGO_COLLNET ? ncclNumOps : // Only set redOp when using CollNet
|
||||
info->op == ncclAvg ? ncclSum : // Network sees avg as sum
|
||||
info->opFull.op==ncclDevPreMulSum || info->opFull.op==ncclDevSumPostDiv ? ncclSum : // Network sees avg as sum
|
||||
info->op;
|
||||
proxyArgs->pattern = info->pattern;
|
||||
proxyArgs->root = info->root;
|
||||
@@ -618,12 +613,61 @@ static ncclResult_t checkSetStream(struct ncclInfo* info) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
struct ncclBuffRegHandle {
|
||||
hipIpcMemHandle_t sendBuffIpc;
|
||||
hipIpcMemHandle_t recvBuffIpc;
|
||||
ssize_t sendBuffOffset;
|
||||
ssize_t recvBuffOffset;
|
||||
};
|
||||
|
||||
// Register input and output buffers
|
||||
// Exchange with ranks on the same host
|
||||
static ncclResult_t ncclRegBuffAndExchange(struct ncclInfo* info, struct ncclBuffRegInfo* regInfo) {
|
||||
ncclComm_t comm = info->comm;
|
||||
if (comm->localRanks == 1) return ncclSuccess;
|
||||
if (comm->pfnCuMemGetAddressRange == NULL) return ncclSuccess; // CUDA toolkit or driver version too old
|
||||
|
||||
struct ncclBuffRegHandle regHandles[NCCL_MAX_INTRA_RANKS];
|
||||
// Get IPC handles
|
||||
// Note: the handle only corresponds to the base address of the allocation
|
||||
CUDACHECK(hipIpcGetMemHandle(®Handles[comm->intraNodeRank].sendBuffIpc, (void*)info->sendbuff));
|
||||
CUDACHECK(hipIpcGetMemHandle(®Handles[comm->intraNodeRank].recvBuffIpc, (void*)info->recvbuff));
|
||||
// Get offset of user buffer within allocation
|
||||
void* baseAddr;
|
||||
size_t size;
|
||||
CUDACHECK(comm->pfnCuMemGetAddressRange(&baseAddr, &size, (void*)info->sendbuff));
|
||||
regHandles[comm->intraNodeRank].sendBuffOffset = (char*)info->sendbuff - (char*)baseAddr;
|
||||
CUDACHECK(comm->pfnCuMemGetAddressRange(&baseAddr, &size, (void*)info->recvbuff));
|
||||
regHandles[comm->intraNodeRank].recvBuffOffset = (char*)info->recvbuff - (char*)baseAddr;
|
||||
TRACE(NCCL_COLL, "Base %p size %lu offset %ld", baseAddr, size, regHandles[comm->intraNodeRank].recvBuffOffset);
|
||||
|
||||
// Exchange handles within node
|
||||
NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->intraNodeGlobalRanks, comm->intraNodeRank, comm->localRanks, regHandles, sizeof(struct ncclBuffRegHandle)));
|
||||
// Open handles at local process
|
||||
for (int i=0; i<comm->localRanks; i++) {
|
||||
if (i == comm->intraNodeRank) {
|
||||
regInfo->sendbuffsBase[i] = regInfo->recvbuffsBase[i] = NULL;
|
||||
continue;
|
||||
}
|
||||
CUDACHECK(hipIpcOpenMemHandle(regInfo->sendbuffsBase+i, regHandles[i].sendBuffIpc, hipIpcMemLazyEnablePeerAccess));
|
||||
CUDACHECK(hipIpcOpenMemHandle(regInfo->recvbuffsBase+i, regHandles[i].recvBuffIpc, hipIpcMemLazyEnablePeerAccess));
|
||||
// Get real address of buffer
|
||||
regInfo->sendbuffs[i] = (char*)regInfo->sendbuffsBase[i] + regHandles[i].sendBuffOffset;
|
||||
regInfo->recvbuffs[i] = (char*)regInfo->recvbuffsBase[i] + regHandles[i].recvBuffOffset;
|
||||
}
|
||||
regInfo->nBuffs = comm->localRanks;
|
||||
TRACE(NCCL_COLL, "Rank %d exchanged %d buffers", comm->rank, regInfo->nBuffs);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Compute enqueue element, save it in list
|
||||
// Compute CUDA launch parameters
|
||||
// Capture time code in view of CUDA graph
|
||||
static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
|
||||
ncclComm_t comm = info->comm;
|
||||
if (comm->nRanks == 1) {
|
||||
if (comm->nRanks == 1 &&
|
||||
// User-defined reduction ops may need to alter the data even for unitary reductions
|
||||
info->op < ncclNumOps) {
|
||||
if (info->sendbuff != info->recvbuff)
|
||||
CUDACHECK(hipMemcpyAsync(info->recvbuff, info->sendbuff, info->nBytes, hipMemcpyDeviceToDevice, info->stream));
|
||||
return ncclSuccess;
|
||||
@@ -651,6 +695,19 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
|
||||
comm->args.active = 2; // I am so far the last element; may be changed later in aggregation mode
|
||||
}
|
||||
|
||||
// Register and exchange input and output buffers
|
||||
if (comm->usingCudaGraph && // only in CUDA graph mode
|
||||
comm->graphRegister == 1 && // when registration is enabled
|
||||
info->algorithm == NCCL_ALGO_COLLNET && // limited to CollNet for now
|
||||
comm->intraHighestTransportType == TRANSPORT_P2P && // only when all ranks can p2p each other
|
||||
comm->intraRanks == 1) { // only in multi-process mode
|
||||
NCCLCHECK(ncclRegBuffAndExchange(info, &eqElem->buffRegInfo));
|
||||
// Disable inline argument because we need kernel to copy the entire ncclWork from workFifo
|
||||
// because the registered addresses are in ncclWork
|
||||
if (eqElem->buffRegInfo.nBuffs > 0) comm->args.active = 0;
|
||||
comm->enqueueInfo->nRegBuffs += eqElem->buffRegInfo.nBuffs;
|
||||
}
|
||||
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -667,41 +724,15 @@ static inline int findShortestChannel(ncclComm_t comm) {
|
||||
return minC;
|
||||
}
|
||||
|
||||
static inline ncclResult_t getNextChannel(ncclComm_t comm, int* nextChannel) {
|
||||
if (comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) {
|
||||
*nextChannel = findShortestChannel(comm);
|
||||
static inline int getNextChannel(ncclComm_t comm, int aggMode) {
|
||||
int nextChannel = 0;
|
||||
if (aggMode && comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) {
|
||||
nextChannel = findShortestChannel(comm);
|
||||
} else {
|
||||
*nextChannel = comm->lastChannel % comm->nChannels;
|
||||
nextChannel = comm->lastChannel % comm->nChannels;
|
||||
comm->lastChannel++;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
// Dynamic enqueue code
|
||||
static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem* eqElem) {
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs;
|
||||
|
||||
int nChannels = work->coll.nChannels;
|
||||
for (int bid=0; bid<nChannels; bid++) {
|
||||
int channelId = comm->lastChannel % comm->nChannels;
|
||||
struct ncclChannel* channel = comm->channels+channelId;
|
||||
|
||||
// Proxy
|
||||
proxyArgs->subs[0].channel = channel;
|
||||
proxyArgs->opCount = comm->collOpCount;
|
||||
proxyArgs->commOpCount = comm->opCount;
|
||||
|
||||
if (proxyArgs->subs[0].nsteps) NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks));
|
||||
|
||||
comm->lastChannel++;
|
||||
work->coll.bid = bid % nChannels;
|
||||
NCCLCHECK(getNextOp(channel, NULL, work));
|
||||
//INFO(NCCL_COLL, "Host enqueue: bid %d channel %d index %ld nThreads %d funcIndex %d count %ld nChannels %d",
|
||||
// work->coll.bid, channelId, channel->workFifoTail, work->nThreads, work->funcIndex, work->coll.count, work->coll.nChannels);
|
||||
}
|
||||
comm->collOpCount++;
|
||||
return ncclSuccess;
|
||||
return nextChannel;
|
||||
}
|
||||
|
||||
ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
@@ -733,7 +764,7 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
|
||||
channelUsed += info->nChannels;
|
||||
// We can use fast path if all collectives are the same
|
||||
homogeneous &= info->coll == comm->asyncOps[0].coll &&
|
||||
info->op == comm->asyncOps[0].op &&
|
||||
info->opFull.op == comm->asyncOps[0].opFull.op &&
|
||||
info->datatype == comm->asyncOps[0].datatype;
|
||||
if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport));
|
||||
}
|
||||
@@ -818,13 +849,22 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
enum { COLL_SEGMENT=0, P2P_SEGMENT=1 };
|
||||
enum { RingTree_Segment=0, P2P_Segment=1, CollNet_Segment=2 };
|
||||
static int getSegment(int type, int delta, struct ncclWork* work) {
|
||||
if (type == P2P_SEGMENT) { // P2P
|
||||
// Current ncclWork is full
|
||||
if (work->elems[NCCL_MAX_WORK_ELEMENTS-1].active != 0) return -1;
|
||||
|
||||
if (type == P2P_Segment) { // P2P
|
||||
// Do not mix P2P and collective ops
|
||||
if (work->elems[0].funcIndex != FUNC_INDEX_P2P) return -1;
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS && work->elems[s].p2p.delta != delta; s++) {
|
||||
if (work->elems[s].active == 0) return s;
|
||||
}
|
||||
} else { // aggregation
|
||||
} else if (type == CollNet_Segment) { // CollNet
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s+=NCCL_REG_ELEM_FACTOR) {
|
||||
if (work->elems[s].active == 0) return s;
|
||||
}
|
||||
} else { // Ring or Tree
|
||||
for (int s=0; s<NCCL_MAX_WORK_ELEMENTS; s++) {
|
||||
if (work->elems[s].active == 0) return s;
|
||||
}
|
||||
@@ -838,7 +878,7 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct
|
||||
elem->nThreads = NCCL_MAX_NTHREADS;
|
||||
elem->sendbuff = info->sendbuff;
|
||||
elem->recvbuff = info->recvbuff;
|
||||
elem->op.opCount = info->comm->p2pOpCount;
|
||||
elem->p2p.opCount = info->comm->p2pOpCount;
|
||||
elem->p2p.sendCount = info->sendbytes;
|
||||
elem->p2p.recvCount = info->recvbytes;
|
||||
elem->p2p.sendChunkSize = info->sendChunkSize;
|
||||
@@ -847,13 +887,14 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) {
|
||||
static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s,
|
||||
struct ncclBuffRegInfo* regInfo, struct ncclChannel* channel, struct ncclComm* comm) {
|
||||
// Copy element into corresponding segment of ncclWork
|
||||
memcpy(work->elems+s, elem, sizeof(struct ncclWorkElem));
|
||||
work->elems[s].active = 1;
|
||||
|
||||
// Determine nThreads at dynamic time
|
||||
if (type == P2P_SEGMENT) {
|
||||
if (type == P2P_Segment) {
|
||||
const int nsegments = s+1;
|
||||
int nThreads = 512;
|
||||
while (nsegments*nThreads > 256) nThreads /= 2;
|
||||
@@ -861,6 +902,33 @@ static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */
|
||||
for (int i=0; i<nsegments; i++) work->elems[i].p2p.nThreads = nThreads;
|
||||
}
|
||||
|
||||
// Copy registered buffer addresses into ncclWork
|
||||
if (regInfo->nBuffs > 0) {
|
||||
struct ncclWorkRegElem* regElem = (struct ncclWorkRegElem*)(work->elems+s);
|
||||
// For CollNet
|
||||
for (int i=0; i<NCCL_MAX_DIRECT_ARITY; i++) {
|
||||
int peer = channel->collTree.down[i];
|
||||
if (peer == -1) break;
|
||||
int j = comm->rankToIntraNodeRank[peer];
|
||||
if (j < 0) {
|
||||
WARN("Invalid intra-node rank %d for peer %d", j, peer);
|
||||
return ncclInternalError;
|
||||
}
|
||||
regElem->dnInputs[i] = regInfo->sendbuffs[j];
|
||||
regElem->dnOutputs[i] = regInfo->recvbuffs[j];
|
||||
}
|
||||
for (int i=0; i<NCCL_MAX_DIRECT_ARITY; i++) {
|
||||
int peer = channel->collTree.up[i];
|
||||
if (peer == -1) break;
|
||||
int j = comm->rankToIntraNodeRank[peer];
|
||||
if (j < 0) {
|
||||
WARN("Invalid intra-node rank %d for peer %d", j, peer);
|
||||
return ncclInternalError;
|
||||
}
|
||||
regElem->upOutputs[i] = regInfo->recvbuffs[j];
|
||||
}
|
||||
work->elems[s].regUsed = 1;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -873,9 +941,9 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
int segment = -1;
|
||||
if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) {
|
||||
if (channel->workCount) {
|
||||
// Try to pack more segments into a single operation
|
||||
segment = getSegment(P2P_SEGMENT, workElem->p2p.delta, w);
|
||||
segment = getSegment(P2P_Segment, workElem->p2p.delta, w);
|
||||
}
|
||||
if (segment == -1) {
|
||||
NCCLCHECK(getNextOp(channel, &w, NULL));
|
||||
@@ -884,7 +952,7 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e
|
||||
|
||||
// store work element into FIFO
|
||||
NCCLCHECK(ncclProxySaveP2p(comm, proxyArgs));
|
||||
NCCLCHECK(enqueueSegOp(P2P_SEGMENT, workElem, w, segment));
|
||||
NCCLCHECK(enqueueSegOp(P2P_Segment, workElem, w, segment, &eqElem->buffRegInfo, channel, comm));
|
||||
comm->p2pOpCount++;
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -920,15 +988,18 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem) {
|
||||
// Dynamic enqueue function for collective kernels
|
||||
// Supports both aggregated and non-aggregated modes
|
||||
ncclResult_t ncclEnqueueCollKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem, int aggMode) {
|
||||
struct ncclWorkElem* work = &eqElem->work;
|
||||
struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs;
|
||||
|
||||
int nChannels = work->coll.nChannels;
|
||||
size_t channelSize = work->coll.count*ncclTypeSize(proxyArgs->dtype)/work->coll.nChannels;
|
||||
int segmentType = proxyArgs->redOp == ncclNumOps ? RingTree_Segment : CollNet_Segment; // redOp is only set when using CollNet
|
||||
|
||||
for (int bid=0; bid<nChannels; bid++) {
|
||||
int channelId;
|
||||
NCCLCHECK(getNextChannel(comm, &channelId));
|
||||
int channelId = getNextChannel(comm, aggMode);
|
||||
struct ncclChannel* channel = comm->channels+channelId;
|
||||
|
||||
// Proxy
|
||||
@@ -937,18 +1008,19 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem*
|
||||
proxyArgs->commOpCount = comm->opCount;
|
||||
if (proxyArgs->subs[0].nsteps) NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks));
|
||||
|
||||
// Try to reuse last work if not full yet
|
||||
work->coll.bid = bid % nChannels;
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
struct ncclWork* w = channel->workFifo+opIndex;
|
||||
struct ncclWork* w = NULL;
|
||||
int segment = -1;
|
||||
if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0 &&
|
||||
// All elems in work must have same (funcIndex,nThreads),
|
||||
// see "src/collectives/device/common.h"
|
||||
w->elems[0].funcIndex == work->funcIndex &&
|
||||
w->elems[0].nThreads == work->nThreads) {
|
||||
if (aggMode && channel->workCount) {
|
||||
// Try to pack more segments into a single operation
|
||||
segment = getSegment(COLL_SEGMENT, 0, w);
|
||||
int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS;
|
||||
w = channel->workFifo+opIndex;
|
||||
// All elems in work must have same (funcIndex,nThreads),
|
||||
// see "src/collectives/device/common.h"
|
||||
if (w->elems[0].funcIndex == work->funcIndex &&
|
||||
w->elems[0].nThreads == work->nThreads) {
|
||||
segment = getSegment(segmentType, 0, w);
|
||||
}
|
||||
}
|
||||
if (segment == -1) {
|
||||
NCCLCHECK(getNextOp(channel, &w, NULL));
|
||||
@@ -956,7 +1028,7 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem*
|
||||
}
|
||||
|
||||
// store work element into FIFO
|
||||
NCCLCHECK(enqueueSegOp(COLL_SEGMENT, work, w, segment));
|
||||
NCCLCHECK(enqueueSegOp(segmentType, work, w, segment, &eqElem->buffRegInfo, channel, comm));
|
||||
channel->totalSize += channelSize;
|
||||
}
|
||||
comm->collOpCount++;
|
||||
@@ -968,17 +1040,15 @@ void HIPRT_CB ncclEnqueueHostSetup(void* arg) {
|
||||
ncclResult_t ret;
|
||||
struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)arg;
|
||||
ncclComm_t comm = eqInfo->comm;
|
||||
int aggMode = eqInfo->elemList->count() > 1 ? 1 : 0;
|
||||
|
||||
// Iterate through the element list
|
||||
struct ncclQueueElem* eqElem = eqInfo->elemList->begin();
|
||||
while (eqElem != NULL) {
|
||||
if (eqElem->work.funcIndex == FUNC_INDEX_P2P) {
|
||||
NCCLCHECKGOTO(ncclEnqueueP2pKernel(comm, eqElem), ret, cb_end);
|
||||
} else if (eqInfo->elemList->count() > 1) {
|
||||
// We have more than one operation, hence aggregating
|
||||
NCCLCHECKGOTO(ncclEnqueueAsyncKernel(comm, eqElem), ret, cb_end);
|
||||
} else {
|
||||
NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem), ret, cb_end);
|
||||
NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem, aggMode), ret, cb_end);
|
||||
}
|
||||
eqElem = eqInfo->elemList->getNext();
|
||||
}
|
||||
@@ -996,51 +1066,95 @@ cb_end:
|
||||
template void HIPRT_CB ncclEnqueueHostSetup<0>(void*);
|
||||
template void HIPRT_CB ncclEnqueueHostSetup<1>(void*);
|
||||
|
||||
void* graphHelperFunc(void *args) {
|
||||
struct ncclGraphHelperResources* res = (struct ncclGraphHelperResources*)args;
|
||||
if (res == NULL) {
|
||||
WARN("CUDA Graph helper resource is null");
|
||||
return NULL;
|
||||
}
|
||||
int dev = res->comm->cudaDev;
|
||||
CUDACHECKIGNORE(hipSetDevice(dev));
|
||||
INFO(NCCL_COLL, "CUDA Graph helper thread created for device %d", dev);
|
||||
|
||||
volatile enum helperThreadState* state = &res->threadState;
|
||||
volatile int* ipcTail = &res->ipcTail;
|
||||
while (1) {
|
||||
int ipcTailMark = *ipcTail;
|
||||
int ipcCount = 0;
|
||||
while (res->ipcHead != ipcTailMark) {
|
||||
if (res->ipcBases[res->ipcHead] != NULL)
|
||||
CUDACHECKIGNORE(hipIpcCloseMemHandle(res->ipcBases[res->ipcHead]));
|
||||
res->ipcBases[res->ipcHead] = NULL;
|
||||
res->ipcHead = (res->ipcHead+1)%NCCL_IPC_POOL_SIZE;
|
||||
ipcCount++;
|
||||
}
|
||||
TRACE(NCCL_COLL, "CUDA Graph helper thread closed %d IPC handles", ipcCount);
|
||||
pthread_mutex_lock(&res->threadLock);
|
||||
while (res->ipcHead == *ipcTail && *state != ThreadStop) {
|
||||
pthread_cond_wait(&res->threadCond, &res->threadLock);
|
||||
}
|
||||
pthread_mutex_unlock(&res->threadLock);
|
||||
if (*state == ThreadStop) {
|
||||
INFO(NCCL_COLL, "CUDA Graph helper thread for device %d returning", dev);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ncclResult_t ncclGetCudaGraph(ncclComm_t comm, cudaGraph_t* graph) {
|
||||
comm->usingCudaGraph = 0;
|
||||
#if CUDART_VERSION >= 11030
|
||||
cudaStreamCaptureStatus captureStatus;
|
||||
unsigned long long cudaGraphId;
|
||||
hipStreamCaptureStatus captureStatus;
|
||||
unsigned long long hipGraphId;
|
||||
if (comm->driverVersion < 11030) {
|
||||
CUDACHECK(cudaStreamIsCapturing(comm->userStream, &captureStatus));
|
||||
if (captureStatus != cudaStreamCaptureStatusNone) {
|
||||
CUDACHECK(hipStreamIsCapturing(comm->userStream, &captureStatus));
|
||||
if (captureStatus != hipStreamCaptureStatusNone) {
|
||||
WARN("The installed CUDA driver is older than the minimum version (R465) required for NCCL's CUDA Graphs support");
|
||||
return ncclInvalidUsage;
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
CUDACHECK(cudaStreamGetCaptureInfo_v2(comm->userStream, &captureStatus, &cudaGraphId, graph, NULL, NULL));
|
||||
if (captureStatus == cudaStreamCaptureStatusActive) {
|
||||
if (cudaGraphId != comm->lastCudaGraphId) {
|
||||
INFO(NCCL_COLL, "stream is being captured by a new graph, id %llu", cudaGraphId);
|
||||
CUDACHECK(hipStreamGetCaptureInfo_v2(comm->userStream, &captureStatus, &hipGraphId, graph, NULL, NULL));
|
||||
if (captureStatus == hipStreamCaptureStatusActive) {
|
||||
if (hipGraphId != comm->lastCudaGraphId) {
|
||||
INFO(NCCL_COLL, "stream is being captured by a new graph, id %llu", hipGraphId);
|
||||
// We are in a new graph, hence need to forget the last setup node so that
|
||||
// the first setup node in the new graph will not have a dependency
|
||||
comm->lastCudaGraphId = cudaGraphId;
|
||||
comm->lastCudaGraphId = hipGraphId;
|
||||
comm->lastSetupNode = NULL;
|
||||
}
|
||||
if (comm->launchMode == ncclComm::GROUP) comm->launchMode = ncclComm::GROUP_GRAPH;
|
||||
comm->usingCudaGraph = 1;
|
||||
|
||||
// Create helper thread that closes IPC handles during graph destruction
|
||||
// Only create this thread when buffer registration is enabled
|
||||
if ((!comm->graphHelperThread) && comm->graphRegister == 1 && comm->disableGraphHelper == 0) {
|
||||
pthread_mutex_init(&comm->graphHelperResources->threadLock, NULL);
|
||||
pthread_cond_init(&comm->graphHelperResources->threadCond, NULL);
|
||||
comm->graphHelperResources->threadState = ThreadStart;
|
||||
pthread_create(&comm->graphHelperThread, NULL, graphHelperFunc, comm->graphHelperResources);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph) {
|
||||
ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, hipGraph_t graph) {
|
||||
#if CUDART_VERSION >= 11030
|
||||
struct ncclQueueInfo* eqInfo = comm->enqueueInfo;
|
||||
// Create a CUDA object to wrap around the argument space
|
||||
// which CUDA graph would manage lifetime of
|
||||
cudaUserObject_t object;
|
||||
CUDACHECK(cudaUserObjectCreate(&object, eqInfo, ncclDestroyQueueInfo, 1/*initialRefcount*/, cudaUserObjectNoDestructorSync));
|
||||
CUDACHECK(cudaGraphRetainUserObject(graph, object, 1, cudaGraphUserObjectMove));
|
||||
hipUserObject_t object;
|
||||
CUDACHECK(hipUserObjectCreate(&object, eqInfo, ncclDestroyQueueInfo, 1/*initialRefcount*/, hipUserObjectNoDestructorSync));
|
||||
CUDACHECK(hipGraphRetainUserObject(graph, object, 1, hipGraphUserObjectMove));
|
||||
|
||||
cudaHostFn_t fn = ncclEnqueueHostSetup<1>;
|
||||
hipHostFn_t fn = ncclEnqueueHostSetup<1>;
|
||||
// Add a CPU node to the graph
|
||||
cudaGraphNode_t setupNode;
|
||||
cudaHostNodeParams setupNodeParams = {fn, eqInfo};
|
||||
hipGraphNode_t setupNode;
|
||||
hipHostNodeParams setupNodeParams = {fn, eqInfo};
|
||||
int numDependencies = comm->lastSetupNode == NULL ? 0 : 1;
|
||||
CUDACHECK(cudaGraphAddHostNode(&setupNode, graph, &comm->lastSetupNode, numDependencies, &setupNodeParams));
|
||||
CUDACHECK(cudaStreamUpdateCaptureDependencies(comm->userStream, &setupNode, 1, cudaStreamAddCaptureDependencies));
|
||||
CUDACHECK(hipGraphAddHostNode(&setupNode, graph, &comm->lastSetupNode, numDependencies, &setupNodeParams));
|
||||
CUDACHECK(hipStreamUpdateCaptureDependencies(comm->userStream, &setupNode, 1, hipStreamAddCaptureDependencies));
|
||||
comm->lastSetupNode = setupNode;
|
||||
return ncclSuccess;
|
||||
#else
|
||||
@@ -1049,6 +1163,74 @@ ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Translate a host-side reduction op handle into the device-side description.
//
// opFull   : output; receives the device op and its scalar argument.
// op       : host reduction op (built-in value or user-created handle).
// datatype : element type of the operation.
// comm     : communicator; provides nRanks for ncclAvg and the table of
//            user-created ops.
//
// Built-in Sum/Prod/Max/Min map one-to-one and carry a zero scalar argument.
// ncclAvg becomes SumPostDiv with scalar nRanks for integer types, and
// PreMulSum with scalar 1/nRanks for floating-point types.
// Any other value is treated as a user-created op whose stored opFull is
// copied verbatim.
//
// Returns ncclInvalidArgument when a user-created op is used with a datatype
// other than the one it was created for, or when ncclAvg is requested for a
// datatype this function has no case for; ncclSuccess otherwise.
static ncclResult_t hostToDevRedOp(
    ncclDevRedOpFull *opFull, ncclRedOp_t op, ncclDataType_t datatype, ncclComm *comm
  ) {
  // The scalar is written through the typed member and read back as raw bits via u64.
  union {
    int8_t i8;
    uint8_t u8;
    int32_t i32;
    uint32_t u32;
    int64_t i64;
    uint64_t u64;
    half f16;
    #if defined(RCCL_BFLOAT16)
    rccl_bfloat16 bf16;
    #endif
    float f32;
    double f64;
    void *ptr;
  };
  u64 = 0;
  opFull->scalarArgIsPtr = false;
  // Give the scalar a defined value for ops that do not use one; callers copy and log it unconditionally.
  opFull->scalarArg = 0;
  switch (int(op)) {
  case ncclSum:  opFull->op = ncclDevSum;  break;
  case ncclProd: opFull->op = ncclDevProd; break;
  case ncclMax:  opFull->op = ncclDevMax;  break;
  case ncclMin:  opFull->op = ncclDevMin;  break;
  case ncclAvg:
    switch ((int)datatype) {
    case ncclInt8:  case ncclInt32:  case ncclInt64:
    case ncclUint8: case ncclUint32: case ncclUint64:
      opFull->op = ncclDevSumPostDiv;
      u64 = comm->nRanks;
      break;
    case ncclFloat16:
      opFull->op = ncclDevPreMulSum;
      f16 = __float2half(float(1.0/comm->nRanks)); // __double2half not supported pre CUDA 11.x
      break;
    #if defined(RCCL_BFLOAT16)
    case ncclBfloat16:
      opFull->op = ncclDevPreMulSum;
      bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks));
      break;
    #endif
    case ncclFloat32:
      opFull->op = ncclDevPreMulSum;
      f32 = float(1.0/comm->nRanks);
      break;
    case ncclFloat64:
      opFull->op = ncclDevPreMulSum;
      f64 = 1.0/comm->nRanks;
      break;
    default:
      // Without this, opFull->op would be left unset and success returned.
      WARN("Data type %d is not supported by ncclAvg", (int)datatype);
      return ncclInvalidArgument;
    }
    opFull->scalarArgIsPtr = false;
    opFull->scalarArg = u64;
    break;
  default: { // user created
    int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps);
    ncclUserRedOp *user = &comm->userRedOps[ix];
    if (datatype != user->datatype) {
      WARN("Data type supplied to user-created ncclRedOp_t does not match type "
           "given to reduction operation");
      return ncclInvalidArgument;
    }
    *opFull = user->opFull;
    break;
  }
  }
  return ncclSuccess;
}
|
||||
|
||||
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
|
||||
// [RCCL] Check for clique-based kernel support
|
||||
{
|
||||
@@ -1064,40 +1246,39 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
|
||||
}
|
||||
// [/RCCL]
|
||||
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
bool isAsync = ncclAsyncMode();
|
||||
int savedDev = -1;
|
||||
// Check arguments
|
||||
NCCLCHECK(PtrCheck(info->comm, info->opName, "comm"));
|
||||
if (isAsync && info->comm->checkPointers) {
|
||||
CUDACHECKGOTO(hipGetDevice(&savedDev), ret, end);
|
||||
CUDACHECKGOTO(hipSetDevice(info->comm->cudaDev), ret, end);
|
||||
}
|
||||
NCCLCHECKGOTO(ArgsCheck(info), ret, end);
|
||||
|
||||
// Copy reduction op state from op handle into info struct here since the
|
||||
// op handle may be destroyed before ncclGroupEnd().
|
||||
NCCLCHECKGOTO(hostToDevRedOp(&info->opFull, info->op, info->datatype, info->comm), ret, end);
|
||||
|
||||
// Launch asynchronously if needed
|
||||
if (ncclAsyncMode()) {
|
||||
ncclResult_t ret = ncclSuccess;
|
||||
int savedDev = -1;
|
||||
// Check arguments
|
||||
NCCLCHECK(PtrCheck(info->comm, info->opName, "comm"));
|
||||
if (info->comm->checkPointers) {
|
||||
CUDACHECKGOTO(hipGetDevice(&savedDev), ret, end);
|
||||
CUDACHECKGOTO(hipSetDevice(info->comm->cudaDev), ret, end);
|
||||
}
|
||||
NCCLCHECKGOTO(ArgsCheck(info), ret, end);
|
||||
if (isAsync) {
|
||||
// Always register comm even in case of error to make sure ncclGroupEnd
|
||||
// cleans it up.
|
||||
NCCLCHECKGOTO(ncclAsyncColl(info->comm), ret, end);
|
||||
NCCLCHECKGOTO(checkSetStream(info), ret, end);
|
||||
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p devRedOp %d isPtr %d scaler %lx",
|
||||
info->opName, info->coll == ncclFuncSendRecv ? info->comm->p2pOpCount : info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count,
|
||||
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream);
|
||||
info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream, info->opFull.op, info->opFull.scalarArgIsPtr, info->opFull.scalarArg);
|
||||
|
||||
if (info->coll == ncclFuncSendRecv) { //p2p stored separately
|
||||
NCCLCHECKGOTO(ncclSaveP2p(info), ret, end);
|
||||
} else {
|
||||
NCCLCHECKGOTO(ncclSaveAsyncColl(info), ret, end);
|
||||
}
|
||||
|
||||
end:
|
||||
if (savedDev != -1) CUDACHECK(hipSetDevice(savedDev));
|
||||
ncclAsyncErrCheck(ret);
|
||||
return ret;
|
||||
} else {
|
||||
NCCLCHECK(PtrCheck(info->comm, info->opName, "comm"));
|
||||
NCCLCHECK(ArgsCheck(info));
|
||||
NCCLCHECK(checkSetStream(info));
|
||||
NCCLCHECKGOTO(checkSetStream(info), ret, end);
|
||||
|
||||
INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p",
|
||||
info->opName, info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count,
|
||||
@@ -1106,24 +1287,82 @@ end:
|
||||
// Check whether we are in cuda graph mode
|
||||
cudaGraph_t graph;
|
||||
ncclComm_t comm = info->comm;
|
||||
NCCLCHECK(ncclGetCudaGraph(comm, &graph));
|
||||
NCCLCHECKGOTO(ncclGetCudaGraph(comm, &graph), ret, end);
|
||||
|
||||
// Common part between graph mode and non-graph mode
|
||||
NCCLCHECK(ncclSetupCollKernel(info));
|
||||
NCCLCHECKGOTO(ncclSetupCollKernel(info), ret, end);
|
||||
|
||||
// Host setup
|
||||
if (comm->usingCudaGraph) {
|
||||
NCCLCHECK(ncclCudaGraphHostSetup(comm, graph));
|
||||
NCCLCHECKGOTO(ncclCudaGraphHostSetup(comm, graph), ret, end);
|
||||
} else {
|
||||
ncclEnqueueHostSetup<0>(comm->enqueueInfo);
|
||||
NCCLCHECK(comm->enqueueInfo->ret);
|
||||
NCCLCHECKGOTO(comm->enqueueInfo->ret, ret, end);
|
||||
}
|
||||
|
||||
// Common part between graph mode and non-graph mode
|
||||
NCCLCHECK(ncclLaunchBarrier(comm));
|
||||
NCCLCHECK(ncclLaunchKernel(comm));
|
||||
NCCLCHECK(ncclRecordEvents(comm));
|
||||
NCCLCHECK(ncclLaunchReset(comm));
|
||||
return ncclSuccess;
|
||||
NCCLCHECKGOTO(ncclLaunchBarrier(comm), ret, end);
|
||||
NCCLCHECKGOTO(ncclLaunchKernel(comm), ret, end);
|
||||
NCCLCHECKGOTO(ncclRecordEvents(comm), ret, end);
|
||||
NCCLCHECKGOTO(ncclLaunchReset(comm), ret, end);
|
||||
}
|
||||
end:
|
||||
if (isAsync && savedDev != -1) CUDACHECK(hipSetDevice(savedDev));
|
||||
if (isAsync) ncclAsyncErrCheck(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
NCCL_API(ncclResult_t, ncclRedOpCreatePreMulSum, ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm);
|
||||
ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm) {
|
||||
if (comm->userRedOpFreeHead == comm->userRedOpCapacity) {
|
||||
// double capacity and resize
|
||||
int cap = 2*comm->userRedOpCapacity;
|
||||
if (cap < 4) cap = 4;
|
||||
ncclUserRedOp *ops = new ncclUserRedOp[cap];
|
||||
std::memcpy(ops, comm->userRedOps, comm->userRedOpCapacity*sizeof(ncclUserRedOp));
|
||||
for(int ix=comm->userRedOpCapacity; ix < cap; ix++)
|
||||
ops[ix].freeNext = ix + 1;
|
||||
delete[] comm->userRedOps;
|
||||
comm->userRedOps = ops;
|
||||
comm->userRedOpCapacity = cap;
|
||||
}
|
||||
// pop from free list
|
||||
int ix = comm->userRedOpFreeHead;
|
||||
ncclUserRedOp *user = &comm->userRedOps[ix];
|
||||
comm->userRedOpFreeHead = user->freeNext;
|
||||
|
||||
user->freeNext = -1; // allocated
|
||||
user->datatype = datatype;
|
||||
user->opFull.op = ncclDevPreMulSum;
|
||||
if (residence == ncclScalarHostImmediate) {
|
||||
user->opFull.scalarArgIsPtr = false;
|
||||
std::memcpy(&user->opFull.scalarArg, scalar, ncclTypeSize(datatype));
|
||||
} else {
|
||||
user->opFull.scalarArgIsPtr = true;
|
||||
user->opFull.scalarArg = reinterpret_cast<uint64_t>(scalar);
|
||||
}
|
||||
*op = ncclRedOp_t(int(ncclNumOps) + ix);
|
||||
*op = ncclUserRedOpMangle(comm, *op);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
NCCL_API(ncclResult_t, ncclRedOpDestroy, ncclRedOp_t op, ncclComm_t comm);
|
||||
ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm) {
|
||||
if (0 <= int(op) && int(op) < int(ncclNumOps)) {
|
||||
WARN("ncclRedOpDestroy : operator is a NCCL builtin.");
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
if (int(op) < 0 || int(ncclMaxRedOp) < int(op)) {
|
||||
WARN("ncclRedOpDestroy : operator is garbage.");
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps);
|
||||
if (comm->userRedOpCapacity <= ix || comm->userRedOps[ix].freeNext != -1) {
|
||||
WARN("ncclRedOpDestroy : operator unknown to this communicator.");
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
// push to free list
|
||||
comm->userRedOps[ix].freeNext = comm->userRedOpFreeHead;
|
||||
comm->userRedOpFreeHead = ix;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -378,8 +378,11 @@ ncclResult_t ncclTopoAddGpu(struct ncclXmlNode* xmlGpu, struct ncclTopoSystem* s
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
struct kvDict kvDictPciClass[] = { { "0x060400", PCI }, { "0x068000", NVS }, { "0x068001", CPU }, { "0x030200", GPU }, { "0x030000", GPU }, { "0x038000", GPU }, { "0x020700", NIC }, { "0x020000", NIC }, { NULL, PCI /* Default fallback value */ } };
|
||||
struct kvDict kvDictPciGen[] = { { "2.5 GT/s", 15 }, { "5 GT/s", 30 }, { "8 GT/s", 60 }, { "16 GT/s", 120 }, { "32 GT/s", 240 }, { "8.0 GT/s", 60 }, { "16.0 GT/s", 120 }, { "32.0 GT/s", 240 }, { NULL, 60 /* Default fallback */ } }; // x100 Mbps per lane
|
||||
struct kvDict kvDictPciClass[] = { { "0x060400", PCI }, { "0x068000", NVS }, { "0x068001", CPU }, { "0x03", GPU }, { "0x02", NIC }, { NULL, PCI /* Default fallback value */ } };
|
||||
struct kvDict kvDictPciGen[] = {
|
||||
{ "2.5 GT/s", 15 }, { "5 GT/s", 30 }, { "8 GT/s", 60 }, { "16 GT/s", 120 }, /* Kernel 5.6 and earlier */
|
||||
{ "2.5 GT/s PCIe", 15 }, { "5.0 GT/s PCIe", 30 }, { "8.0 GT/s PCIe", 60 }, { "16.0 GT/s PCIe", 120 }, { "32.0 GT/s PCIe", 240 }, { "64.0 GT/s PCIe", 480 },
|
||||
{ NULL, 60 /* Default fallback */ } }; // x100 Mbps per lane
|
||||
ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* system, struct ncclTopoNode* parent) {
|
||||
const char* str;
|
||||
|
||||
|
||||
@@ -660,7 +660,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm
|
||||
NCCLCHECK(xmlGetAttrInt(gpuNode, "arch", &arch.value));
|
||||
|
||||
struct ncclXmlNode* nvlNode = NULL;
|
||||
NCCLCHECK(xmlGetSub(pciNode, "nvlink", &nvlNode));
|
||||
NCCLCHECK(xmlGetSub(gpuNode, "nvlink", &nvlNode));
|
||||
if (nvlNode == NULL) {
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
const char* busId;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -7,6 +8,11 @@
|
||||
#ifndef XML_H_
|
||||
#define XML_H_
|
||||
|
||||
#include "nccl.h"
|
||||
#include "debug.h"
|
||||
#include "checks.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
// A few constraints to make the implementation easy
|
||||
#define MAX_STR_LEN 255
|
||||
#define MAX_ATTR_COUNT 16
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "debug.h"
|
||||
#include "enqueue.h"
|
||||
#include "transport.h"
|
||||
#include <unistd.h>
|
||||
|
||||
#define MAX_ASYNC_OPS 128
|
||||
thread_local pthread_t ncclGroupThreads[MAX_ASYNC_OPS];
|
||||
|
||||
@@ -12,6 +12,9 @@
|
||||
#include "checks.h"
|
||||
#include "align.h"
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
@@ -28,7 +31,7 @@ static inline ncclResult_t ncclCudaHostFree(void* ptr) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) {
|
||||
static ncclResult_t ncclCalloc(T** ptr, size_t nelem) {
|
||||
void* p = malloc(nelem*sizeof(T));
|
||||
if (p == NULL) {
|
||||
WARN("Failed to malloc %ld bytes", nelem*sizeof(T));
|
||||
@@ -36,10 +39,8 @@ static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc,
|
||||
}
|
||||
memset(p, 0, nelem*sizeof(T));
|
||||
*ptr = (T*)p;
|
||||
INFO(NCCL_ALLOC, "%s:%d Mem Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr);
|
||||
return ncclSuccess;
|
||||
}
|
||||
#define ncclCalloc(...) ncclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
|
||||
|
||||
struct __attribute__ ((aligned(64))) allocationTracker {
|
||||
union {
|
||||
|
||||
@@ -17,7 +17,8 @@ ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commSt
|
||||
ncclResult_t bootstrapAllGather(void* commState, void* allData, int size);
|
||||
ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size);
|
||||
ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size);
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks);
|
||||
ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag);
|
||||
ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, int nranks, void* allData, int size);
|
||||
ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, hipIpcMemHandle_t* ipc, void** ptr);
|
||||
ncclResult_t bootstrapRemFree(int id, int rank, void* commState);
|
||||
ncclResult_t bootstrapClose(void* commState);
|
||||
|
||||
@@ -8,78 +8,104 @@
|
||||
#ifndef NCCL_COLLECTIVES_H_
|
||||
#define NCCL_COLLECTIVES_H_
|
||||
|
||||
#define FUNC_INDEX_P2P (NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumOps)
|
||||
#define FUNC_INDEX(func, redop, ncclType, al, pr) ((((((func)*ncclNumOps + (redop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
|
||||
// Reduction operators as seen by device code. Host-side ncclRedOp_t values are
// translated into these before launch: ncclSum/ncclProd/ncclMax/ncclMin map
// one-to-one, ncclAvg becomes ncclDevSumPostDiv for integer types and
// ncclDevPreMulSum for floating-point types, and user-created pre-multiply-sum
// operators use ncclDevPreMulSum.
enum ncclDevRedOp_t {
|
||||
ncclDevSum, ncclDevProd, ncclDevMax, ncclDevMin,
|
||||
ncclDevPreMulSum, ncclDevSumPostDiv,
|
||||
// Number of device reduction operators; keep last.
ncclNumDevRedOps
|
||||
};
|
||||
// A device reduction operator together with its optional scalar argument.
struct ncclDevRedOpFull {
|
||||
// Which device-side reduction to run.
ncclDevRedOp_t op;
|
||||
// true: scalarArg holds a pointer to the scalar;
// false: scalarArg holds the scalar's bytes directly.
bool scalarArgIsPtr;
|
||||
// Scalar payload (e.g. the rank count or 1/nRanks for averaging, or a
// user-supplied multiplier), interpreted according to scalarArgIsPtr.
uint64_t scalarArg;
|
||||
};
|
||||
|
||||
#define NCCL_FUNC_NAME(func, algo, proto, redop, type) \
|
||||
ncclFunction_##func##_##algo##_##proto##_##redop##_##type
|
||||
#define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps)
|
||||
#define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr))
|
||||
|
||||
#define NCCL_KERN_NAME(func, algo, proto, redop, type) \
|
||||
ncclKernel_##func##_##algo##_##proto##_##redop##_##type
|
||||
#define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \
|
||||
ncclFunction_##func##_##algo##_##proto##_##devredop##_##type
|
||||
|
||||
#define NCCL_ONERANK_REDUCE_NAME(devredop, type) \
|
||||
ncclFunction_OneRankReduce_##devredop##_##type
|
||||
|
||||
#define NCCL_KERN_NAME(func, algo, proto, devredop, type) \
|
||||
ncclKernel_##func##_##algo##_##proto##_##devredop##_##type
|
||||
|
||||
#define NCCL_IMPL_NAME(func, algo, proto) \
|
||||
nccl##func##algo##proto
|
||||
|
||||
/* Declare all collective operations */
|
||||
#define DECL5(func, algo, proto, redop, type) \
|
||||
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first); \
|
||||
#define DECL5(func, algo, proto, devredop, type) \
|
||||
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(struct ncclWorkElem* args); \
|
||||
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(ncclWorkElem c); \
|
||||
|
||||
//extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(); \
|
||||
//extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem c); \
|
||||
#define CONCAT(a,b) a##b
|
||||
#define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(t, f)
|
||||
#define MACRO_IF_0(t, f) f
|
||||
#define MACRO_IF_1(t, f) t
|
||||
|
||||
#define DECL4(func, algo, redop, type) \
|
||||
DECL5(func, algo, SIMPLE, redop, type) \
|
||||
DECL5(func, algo, LL, redop, type) \
|
||||
DECL5(func, algo, LL128, redop, type)
|
||||
#define DECL4(func, algo, devredop, type, undef) \
|
||||
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, SIMPLE, devredop, type)) \
|
||||
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL, devredop, type)) \
|
||||
MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type))
|
||||
|
||||
#define DECL3(func, redop, type) \
|
||||
DECL4(func, RING, redop, type) \
|
||||
DECL4(func, TREE, redop, type) \
|
||||
DECL4(func, COLLNET, redop, type)
|
||||
#define DECL3(func, devredop, type, undef) \
|
||||
DECL4(func, RING, devredop, type, undef) \
|
||||
DECL4(func, TREE, devredop, type, undef) \
|
||||
DECL4(func, COLLNET, devredop, type, undef)
|
||||
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
#define DECL2(func, redop) \
|
||||
DECL3(func, redop, int8_t) \
|
||||
DECL3(func, redop, uint8_t) \
|
||||
DECL3(func, redop, int32_t) \
|
||||
DECL3(func, redop, uint32_t) \
|
||||
DECL3(func, redop, int64_t) \
|
||||
DECL3(func, redop, uint64_t) \
|
||||
DECL3(func, redop, half) \
|
||||
DECL3(func, redop, float) \
|
||||
DECL3(func, redop, double) \
|
||||
DECL3(func, redop, __nv_bfloat16)
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
#define DECL2(func, devredop, undefForFloat) \
|
||||
DECL3(func, devredop, int8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, half, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, float, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, double, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, rccl_bfloat16, /*undef=*/undefForFloat)
|
||||
#else
|
||||
#define DECL2(func, redop) \
|
||||
DECL3(func, redop, int8_t) \
|
||||
DECL3(func, redop, uint8_t) \
|
||||
DECL3(func, redop, int32_t) \
|
||||
DECL3(func, redop, uint32_t) \
|
||||
DECL3(func, redop, int64_t) \
|
||||
DECL3(func, redop, uint64_t) \
|
||||
DECL3(func, redop, half) \
|
||||
DECL3(func, redop, float) \
|
||||
DECL3(func, redop, double) \
|
||||
DECL3(func, redop, rccl_bfloat16)
|
||||
#define DECL2(func, devredop, undefForFloat) \
|
||||
DECL3(func, devredop, int8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint8_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint32_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, int64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, uint64_t, /*undef=*/0) \
|
||||
DECL3(func, devredop, half, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, float, /*undef=*/undefForFloat) \
|
||||
DECL3(func, devredop, double, /*undef=*/undefForFloat)
|
||||
#endif
|
||||
|
||||
#define DECL(func) \
|
||||
DECL2(func, Sum) \
|
||||
DECL2(func, Prod) \
|
||||
DECL2(func, Min) \
|
||||
DECL2(func, Max) \
|
||||
DECL2(func, Avg) \
|
||||
DECL2(func, Sum, /*undefForFloat=*/0) \
|
||||
DECL2(func, Prod, /*undefForFloat=*/0) \
|
||||
DECL2(func, Min, /*undefForFloat=*/0) \
|
||||
DECL2(func, Max, /*undefForFloat=*/0) \
|
||||
DECL2(func, PreMulSum, /*undefForFloat=*/0) \
|
||||
DECL2(func, SumPostDiv, /*undefForFloat=*/1)
|
||||
|
||||
#define DECL_ALL \
|
||||
DECL2(Broadcast, Sum) \
|
||||
DECL(Reduce) \
|
||||
DECL2(AllGather, Sum) \
|
||||
DECL(ReduceScatter) \
|
||||
DECL(AllReduce) \
|
||||
DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) \
|
||||
DECL2(Broadcast, Sum, /*undefForFloat=*/0)
|
||||
DECL(Reduce)
|
||||
DECL2(AllGather, Sum, /*undefForFloat=*/0)
|
||||
DECL(ReduceScatter)
|
||||
DECL(AllReduce)
|
||||
DECL5(SendRecv, RING, SIMPLE, Sum, int8_t)
|
||||
|
||||
DECL_ALL
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, half)(struct ncclWorkElem* args);
|
||||
#if defined(RCCL_BFLOAT16)
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16)(struct ncclWorkElem* args);
|
||||
#endif
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)(struct ncclWorkElem* args);
|
||||
extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(struct ncclWorkElem* args);
|
||||
|
||||
// CHUNKSIZE must be a multiple of SLICESIZE
|
||||
//#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4)
|
||||
@@ -99,5 +125,6 @@ DECL_ALL
|
||||
#define REDUCE_SLICESTEPS 1
|
||||
#define REDUCE_CHUNKSTEPS 1
|
||||
#define SENDRECV_SLICEFACTOR 1
|
||||
#define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above
|
||||
|
||||
#endif
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
typedef void *cudaGraphNode_t;
|
||||
#define HIPRT_CB
|
||||
#else
|
||||
#include "collectives.h"
|
||||
|
||||
#if CUDART_VERSION < 9000
|
||||
struct cudaLaunchParams {
|
||||
void *func;
|
||||
@@ -40,13 +42,16 @@ struct cudaLaunchParams {
|
||||
#define NCCL_LL128_THREAD_THRESHOLD 8
|
||||
#define NCCL_SIMPLE_THREAD_THRESHOLD 64
|
||||
|
||||
#define NCCL_MAX_INTRA_RANKS 32
|
||||
|
||||
struct ncclSendMem {
|
||||
union {
|
||||
struct {
|
||||
uint64_t head;
|
||||
char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)];
|
||||
void* ptrExchange;
|
||||
char pad2[CACHE_LINE_SIZE-sizeof(void*)];
|
||||
uint64_t redOpArgExchange[2];
|
||||
char pad2[CACHE_LINE_SIZE-sizeof(void*)-2*sizeof(uint64_t)];
|
||||
};
|
||||
char pad3[MEM_ALIGN];
|
||||
};
|
||||
@@ -66,6 +71,28 @@ struct ncclRecvMem {
|
||||
char buff[1]; // Actually larger than that
|
||||
};
|
||||
|
||||
typedef hipError_t(*pfn_cuMemGetAddressRange_t)(void**, size_t*, void*);
|
||||
|
||||
enum helperThreadState {ThreadStart, ThreadStop};
|
||||
|
||||
#define NCCL_IPC_POOL_SIZE (2*NCCL_MAX_INTRA_RANKS*NCCL_MAX_OPS)
|
||||
|
||||
struct ncclGraphHelperResources {
|
||||
ncclComm* comm;
|
||||
pthread_mutex_t threadLock;
|
||||
pthread_cond_t threadCond;
|
||||
enum helperThreadState threadState;
|
||||
void* ipcBases[NCCL_IPC_POOL_SIZE];
|
||||
int ipcTail;
|
||||
int ipcHead;
|
||||
};
|
||||
|
||||
// One slot of a communicator's table of user-created reduction operators.
// Unused slots are chained into a free list through freeNext.
struct ncclUserRedOp {
|
||||
int freeNext; // -1=allocated, otherwise index of next free entry in array
|
||||
// Element type the operator was created for; a collective using the operator
// with a different datatype is rejected.
ncclDataType_t datatype;
|
||||
// Device operator and scalar argument copied into the collective's info
// when the operator is used.
ncclDevRedOpFull opFull;
|
||||
};
|
||||
|
||||
struct ncclComm {
|
||||
struct ncclChannel channels[MAXCHANNELS];
|
||||
|
||||
@@ -86,7 +113,12 @@ struct ncclComm {
|
||||
|
||||
int node;
|
||||
int nNodes;
|
||||
|
||||
// Intra-node rank info
|
||||
int intraNodeGlobalRanks[NCCL_MAX_INTRA_RANKS];
|
||||
int localRanks;
|
||||
int intraNodeRank;
|
||||
int8_t* rankToIntraNodeRank;
|
||||
|
||||
enum { GROUP, PARALLEL, GROUP_GRAPH } launchMode;
|
||||
hipStream_t userStream;
|
||||
@@ -158,6 +190,7 @@ struct ncclComm {
|
||||
|
||||
// Whether this communicator uses collNet
|
||||
int collNetSupport;
|
||||
int intraHighestTransportType;
|
||||
|
||||
// Store info of async operations
|
||||
struct ncclInfo* asyncOps;
|
||||
@@ -181,9 +214,38 @@ struct ncclComm {
|
||||
// Store info for cudaGraph
|
||||
int usingCudaGraph; // Only use it during capture time, not launch time
|
||||
struct ncclQueueInfo* enqueueInfo;
|
||||
int nQueueInfoCreated;
|
||||
int nQueueInfoDestroyed;
|
||||
cudaGraphNode_t lastSetupNode;
|
||||
unsigned long long lastCudaGraphId;
|
||||
int driverVersion;
|
||||
pfn_cuMemGetAddressRange_t pfnCuMemGetAddressRange;
|
||||
pthread_t graphHelperThread;
|
||||
struct ncclGraphHelperResources* graphHelperResources;
|
||||
int disableGraphHelper;
|
||||
int graphRegister;
|
||||
|
||||
// user-created reduction ops
|
||||
int userRedOpCapacity, userRedOpFreeHead;
|
||||
ncclUserRedOp *userRedOps;
|
||||
};
|
||||
|
||||
// Scrambles the bits of non-builtin values of ncclRedOp_t according to the
|
||||
// communicator memory address. Used to catch bugs so that integer handles
|
||||
// associated with this communicator won't collide with handles of other
|
||||
// communicatrs. This function is its own inverse.
|
||||
static inline ncclRedOp_t ncclUserRedOpMangle(ncclComm *comm, ncclRedOp_t op) {
|
||||
// Preserve the built-in values.
|
||||
if(int(op) < int(ncclNumOps))
|
||||
return op;
|
||||
uint64_t h = reinterpret_cast<uint64_t>(comm);
|
||||
h ^= h >> 32;
|
||||
h *= 0x9e3779b97f4a7c13u; // Knuth's 64-bit magical hash constant
|
||||
h >>= 32; // h is now an excellent 32-bit hash of the comm pointer
|
||||
h &= int(ncclMaxRedOp); // ncclMaxRedOp is a power of 2 minus 1
|
||||
int op1 = int(h) ^ int(op);
|
||||
// Since builtin values are preserved, we also have to preserve their preimage.
|
||||
return op1 < int(ncclNumOps) ? op : ncclRedOp_t(op1);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
/*************************************************************************
|
||||
<<<<<<< HEAD
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
=======
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
>>>>>>> nccl/master
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
@@ -7,17 +7,14 @@
|
||||
#ifndef NCCL_DEBUG_H_
|
||||
#define NCCL_DEBUG_H_
|
||||
|
||||
#include "core.h"
|
||||
|
||||
#include "nccl_net.h"
|
||||
#include <stdio.h>
|
||||
#include <chrono>
|
||||
|
||||
#include <sys/syscall.h>
|
||||
#include <limits.h>
|
||||
#include <string.h>
|
||||
#include "nccl_net.h"
|
||||
|
||||
#define gettid() (pid_t) syscall(SYS_gettid)
|
||||
#include <pthread.h>
|
||||
|
||||
extern int ncclDebugLevel;
|
||||
extern uint64_t ncclDebugMask;
|
||||
|
||||
@@ -96,8 +96,11 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK
|
||||
#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 2
|
||||
#define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
|
||||
|
||||
#define NCCL_DIRECT_GPU 0x01
|
||||
#define NCCL_DIRECT_NIC 0x10
|
||||
#define NCCL_DIRECT_WRITE 0x01
|
||||
#define NCCL_DIRECT_READ 0x02
|
||||
#define NCCL_DIRECT_NIC 0x04
|
||||
#define NCCL_IPC_WRITE 0x08
|
||||
#define NCCL_IPC_READ 0x10
|
||||
|
||||
struct ncclConnInfo {
|
||||
// Regular comm mechanism
|
||||
@@ -108,6 +111,7 @@ struct ncclConnInfo {
|
||||
int direct; // Direct communication
|
||||
int shared; // Buffers are shared
|
||||
void **ptrExchange; // Pointer exchange for direct communication
|
||||
uint64_t* redOpArgExchange; // PreOp scaler exchange for direct pull case
|
||||
|
||||
int *sizesFifo; // Sizes fifo from GPU to proxy
|
||||
void* *ptrsFifo; // Buffer fifo from proxy to GPU
|
||||
@@ -175,7 +179,7 @@ struct ncclPeer {
|
||||
struct ncclDevComm;
|
||||
|
||||
#pragma pack(push) /* push current alignment to stack */
|
||||
#pragma pack(8) /* set alignment to 4 bytes boundary */
|
||||
#pragma pack(8) /* set alignment to 8 bytes boundary */
|
||||
#define NCCL_MAX_WORK_ELEMENTS 1
|
||||
#define NCCL_MAX_GROUPS (NCCL_MAX_NTHREADS/WARP_SIZE)
|
||||
|
||||
@@ -187,8 +191,9 @@ struct ncclWorkElem {
|
||||
struct ncclDevComm* comm;
|
||||
uint16_t nThreads;
|
||||
uint16_t funcIndex;
|
||||
uint16_t index;
|
||||
uint16_t active;
|
||||
uint8_t regUsed;
|
||||
uint8_t direct;
|
||||
uint8_t active, redOpArgIsPtr;
|
||||
|
||||
const void * sendbuff;
|
||||
void * recvbuff;
|
||||
@@ -198,10 +203,12 @@ struct ncclWorkElem {
|
||||
struct {
|
||||
size_t count;
|
||||
size_t lastChunkSize;
|
||||
uint32_t root;
|
||||
uint64_t redOpArg;
|
||||
uint16_t root;
|
||||
uint8_t bid;
|
||||
uint8_t nChannels;
|
||||
uint8_t connIndex;
|
||||
uint16_t connIndex;
|
||||
uint16_t opCount;
|
||||
} coll;
|
||||
struct {
|
||||
size_t sendCount;
|
||||
@@ -217,18 +224,16 @@ struct ncclWorkElem {
|
||||
};
|
||||
uint16_t padding;
|
||||
};
|
||||
} p2p;
|
||||
struct {
|
||||
uint16_t padding[15];
|
||||
uint16_t opCount;
|
||||
} op;
|
||||
} p2p;
|
||||
// [RCCL] Clique-based arguments
|
||||
// NOTE: Follows same field structure as coll
|
||||
// because nChannels is accessed from "coll" struct.
|
||||
struct {
|
||||
size_t count;
|
||||
cliqueDevicePtrs_t* ptrs;
|
||||
uint32_t unused;
|
||||
uint64_t unused_1;
|
||||
uint16_t unused_2;
|
||||
uint8_t bid;
|
||||
uint8_t nChannels;
|
||||
} clique;
|
||||
@@ -236,11 +241,24 @@ struct ncclWorkElem {
|
||||
uint64_t align[4];
|
||||
};
|
||||
};
|
||||
struct ncclWork {
|
||||
struct ncclWorkElem elems[NCCL_MAX_WORK_ELEMENTS];
|
||||
};
|
||||
static_assert(sizeof(struct ncclWorkElem) == (0x10*sizeof(int)), "ncclWorkElem must have a pow2 size");
|
||||
|
||||
struct ncclWorkRegElem {
|
||||
struct ncclWorkElem elem;
|
||||
void* dnInputs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
void* dnOutputs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
void* upOutputs[NCCL_MAX_DIRECT_ARITY+1];
|
||||
};
|
||||
#define NCCL_REG_ELEM_FACTOR 4
|
||||
static_assert(sizeof(struct ncclWorkRegElem) == (NCCL_REG_ELEM_FACTOR*sizeof(struct ncclWorkElem)), "ncclWorkRegElem size must be pow2 times ncclWorkElem size");
|
||||
|
||||
struct ncclWork {
|
||||
union {
|
||||
struct ncclWorkElem elems[NCCL_MAX_WORK_ELEMENTS];
|
||||
struct ncclWorkRegElem regElems[NCCL_MAX_WORK_ELEMENTS/NCCL_REG_ELEM_FACTOR];
|
||||
};
|
||||
};
|
||||
|
||||
struct ncclChannel {
|
||||
union {
|
||||
struct {
|
||||
@@ -330,6 +348,7 @@ struct ncclProf {
|
||||
typedef enum {
|
||||
ncclCollTraceNotReady,
|
||||
ncclCollTraceKernelLaunchType,
|
||||
ncclCollTraceKernelEndType,
|
||||
ncclCollTraceCollEndType,
|
||||
ncclCollTraceAbortType,
|
||||
ncclCollTraceDataType
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
@@ -30,10 +31,19 @@ void HIPRT_CB ncclEnqueueHostSetup(void* arg);
|
||||
ncclResult_t ncclGetCudaGraph(ncclComm_t comm, cudaGraph_t* graph);
|
||||
ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph);
|
||||
|
||||
struct ncclBuffRegInfo {
|
||||
void* sendbuffsBase[NCCL_MAX_INTRA_RANKS];
|
||||
void* recvbuffsBase[NCCL_MAX_INTRA_RANKS];
|
||||
void* sendbuffs[NCCL_MAX_INTRA_RANKS];
|
||||
void* recvbuffs[NCCL_MAX_INTRA_RANKS];
|
||||
int nBuffs;
|
||||
};
|
||||
|
||||
// Enqueue information (for kernel and proxy) for each operation
|
||||
struct ncclQueueElem {
|
||||
struct ncclWorkElem work;
|
||||
struct ncclProxyArgs proxyArgs;
|
||||
struct ncclBuffRegInfo buffRegInfo;
|
||||
};
|
||||
|
||||
typedef ncclRecyclableList<struct ncclQueueElem> ncclQueueElemList;
|
||||
@@ -43,6 +53,7 @@ struct ncclQueueInfo {
|
||||
ncclComm_t comm;
|
||||
int maxChannels; // Dynamic version of gridDim
|
||||
ncclResult_t ret; // Return value of host setup call
|
||||
int nRegBuffs;
|
||||
ncclQueueElemList* elemList;
|
||||
};
|
||||
|
||||
@@ -50,6 +61,7 @@ static ncclResult_t ncclCreateQueueInfo(struct ncclQueueInfo** eqInfo, ncclComm_
|
||||
NCCLCHECK(ncclCalloc(eqInfo, 1));
|
||||
(*eqInfo)->comm = comm;
|
||||
(*eqInfo)->elemList = new ncclQueueElemList();
|
||||
(*eqInfo)->comm->nQueueInfoCreated++;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -58,6 +70,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
|
||||
if (eqInfo == NULL) return ncclInternalError;
|
||||
eqInfo->maxChannels = 0;
|
||||
eqInfo->ret = ncclSuccess;
|
||||
eqInfo->nRegBuffs = 0;
|
||||
eqInfo->elemList->recycle();
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -67,7 +80,54 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) {
|
||||
static void ncclDestroyQueueInfo(void* ptr) {
|
||||
if (ptr == NULL) return;
|
||||
struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)ptr;
|
||||
struct ncclComm* comm = eqInfo->comm;
|
||||
// Close IPC mem handles for registered buffers
|
||||
struct ncclQueueElem* eqElem = eqInfo->elemList->begin();
|
||||
#if 0
|
||||
// Ideally, the deregistration should happen here
|
||||
// but currently the destroy function of CUDA objects does not allow CUDA API calls
|
||||
while (eqElem != NULL) {
|
||||
for (int i=0; i<eqElem->buffRegInfo.nBuffs; i++) {
|
||||
if (i == eqInfo->comm->intraNodeRank) continue;
|
||||
CUDACHECKIGNORE(cudaIpcCloseMemHandle(eqElem->buffRegInfo.sendbuffsBase[i]));
|
||||
CUDACHECKIGNORE(cudaIpcCloseMemHandle(eqElem->buffRegInfo.recvbuffsBase[i]));
|
||||
}
|
||||
eqElem = eqInfo->elemList->getNext();
|
||||
}
|
||||
#else
|
||||
// Instead, we push these pointers to a pool owned by ncclComm
|
||||
// and asks a helper thread to close mem handles
|
||||
struct ncclGraphHelperResources* res = comm->graphHelperResources;
|
||||
int ipcTailOld = 0;
|
||||
if (res == NULL || (!comm->graphHelperThread) || eqInfo->nRegBuffs == 0) goto skip;
|
||||
|
||||
pthread_mutex_lock(&res->threadLock);
|
||||
ipcTailOld = res->ipcTail;
|
||||
while (eqElem != NULL) {
|
||||
for (int i=0; i<eqElem->buffRegInfo.nBuffs; i++) {
|
||||
if (eqElem->buffRegInfo.sendbuffsBase[i] != NULL) {
|
||||
res->ipcBases[res->ipcTail] = eqElem->buffRegInfo.sendbuffsBase[i];
|
||||
res->ipcTail = (res->ipcTail+1)%NCCL_IPC_POOL_SIZE;
|
||||
}
|
||||
if (eqElem->buffRegInfo.recvbuffsBase[i] != NULL) {
|
||||
res->ipcBases[res->ipcTail] = eqElem->buffRegInfo.recvbuffsBase[i];
|
||||
res->ipcTail = (res->ipcTail+1)%NCCL_IPC_POOL_SIZE;
|
||||
}
|
||||
}
|
||||
eqElem = eqInfo->elemList->getNext();
|
||||
}
|
||||
if (res->ipcTail != ipcTailOld) {
|
||||
res->threadState = ThreadStart;
|
||||
TRACE(NCCL_COLL, "CUDA Graph destroy function signaling helper thread with %d IPC handles", res->ipcTail-ipcTailOld);
|
||||
pthread_cond_signal(&res->threadCond);
|
||||
}
|
||||
pthread_mutex_unlock(&res->threadLock);
|
||||
#endif
|
||||
|
||||
skip:
|
||||
delete eqInfo->elemList;
|
||||
free(eqInfo);
|
||||
comm->nQueueInfoDestroyed++;
|
||||
return;
|
||||
}
|
||||
#endif // End include guard
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "nccl.h"
|
||||
#include <stdint.h> // for standard [u]intX_t types
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// These can be used if the GDR library isn't thread safe
|
||||
#include <pthread.h>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "nccl.h"
|
||||
#include "devcomm.h"
|
||||
#include "collectives.h"
|
||||
|
||||
typedef enum {
|
||||
ncclPatternRing,
|
||||
@@ -39,6 +40,7 @@ struct ncclInfo {
|
||||
int chunkSteps;
|
||||
int sliceSteps;
|
||||
// Computed later
|
||||
ncclDevRedOpFull opFull;
|
||||
int algorithm;
|
||||
int protocol;
|
||||
ncclPattern_t pattern;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#ifndef NCCL_PARAM_H_
|
||||
#define NCCL_PARAM_H_
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
@@ -55,7 +55,7 @@ struct ncclTransport {
|
||||
};
|
||||
|
||||
ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex);
|
||||
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex);
|
||||
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL);
|
||||
|
||||
enum { collNetRecv=0, collNetSend=1 };
|
||||
int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type);
|
||||
|
||||
@@ -50,7 +50,7 @@ std::chrono::high_resolution_clock::time_point ncclEpoch;
|
||||
const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv" };
|
||||
const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNet" };
|
||||
const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" };
|
||||
const char* ncclRedOpStr[ncclNumOps] = { "Sum", "Prod", "Max", "Min" };
|
||||
const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" };
|
||||
const char *ncclTypeStr[ncclNumTypes] = {"_i8", "_u8", "_i32", "_u32", "_i64", "_u64", "_f16", "_f32", "_f64", "_b16"};
|
||||
|
||||
NCCL_PARAM(GroupCudaStream, "GROUP_CUDA_STREAM", NCCL_GROUP_CUDA_STREAM);
|
||||
@@ -80,13 +80,21 @@ ncclResult_t initCollNet(ncclCollNet_t* collnet) {
|
||||
}
|
||||
|
||||
ncclResult_t initNetPlugin(ncclNet_t** net, ncclCollNet_t** collnet) {
|
||||
void* netPluginLib = dlopen("librccl-net.so", RTLD_NOW | RTLD_LOCAL);
|
||||
char ncclNetPluginName[128];
|
||||
const char* envPluginName = getenv("NCCL_NET_PLUGIN");
|
||||
if (envPluginName && strlen(envPluginName)) {
|
||||
snprintf(ncclNetPluginName, 128, "librccl-net-%s.so", envPluginName);
|
||||
INFO(NCCL_INIT, "Plugin name set by env to %s\n", ncclNetPluginName);
|
||||
} else {
|
||||
sprintf(ncclNetPluginName, "librccl-net.so");
|
||||
}
|
||||
void* netPluginLib = dlopen(ncclNetPluginName, RTLD_NOW | RTLD_LOCAL);
|
||||
if (netPluginLib == NULL) {
|
||||
// dlopen does not guarantee to set errno, but dlerror only gives us a
|
||||
// string, so checking errno doesn't hurt to try to provide a better
|
||||
// error message
|
||||
if (errno == ENOENT) {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found (librccl-net.so), using internal implementation");
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found (%s), using internal implementation", ncclNetPluginName);
|
||||
} else {
|
||||
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load returned %d : %s.", errno, dlerror());
|
||||
}
|
||||
@@ -199,23 +207,27 @@ RCCL_PARAM(KernelCollTraceEnable, "KERNEL_COLL_TRACE_ENABLE", 0);
|
||||
void *ncclCommThreadMain(void *arg) {
|
||||
ncclComm_t comm = (ncclComm_t)arg;
|
||||
int head = comm->hostDevComm.collTraceHead;
|
||||
#define MAX_NAME_LENGTH 32
|
||||
#define MAX_NAME_LENGTH 64
|
||||
char* func_names = (char *)malloc(MAX_NAME_LENGTH*(FUNC_INDEX_P2P+1));
|
||||
for (int func = 0; func < NCCL_NUM_FUNCTIONS; func++) {
|
||||
for (int al = 0; al < NCCL_NUM_ALGORITHMS; al++) {
|
||||
for (int type = 0; type < ncclNumTypes; type++) {
|
||||
for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) {
|
||||
for (int redop = 0; redop < ncclNumOps; redop++) {
|
||||
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, redop, type, al, pr);
|
||||
for (int devredop = 0; devredop < ncclNumDevRedOps; devredop++) {
|
||||
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, devredop, type, al, pr);
|
||||
sprintf(line, "%s%s%s%s%s", ncclFuncStr[func], ncclAlgoStr[al], ncclProtoStr[pr],
|
||||
ncclRedOpStr[redop], ncclTypeStr[type]);
|
||||
ncclDevRedOpStr[devredop], ncclTypeStr[type]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int type = 0; type < ncclNumTypes; type++) {
|
||||
char* line = func_names+MAX_NAME_LENGTH*(FUNC_INDEX_P2P-ncclNumTypes+type);
|
||||
sprintf(line, "OneRankReducePreMulSum%s", ncclTypeStr[type]);
|
||||
}
|
||||
char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX_P2P;
|
||||
sprintf(line, "%s", ncclFuncStr[NCCL_NUM_FUNCTIONS]);
|
||||
sprintf(line, "SendRecvRingSimpleSum_i8");
|
||||
do {
|
||||
int tail = LOAD(comm->hostDevComm.collTraceTail)%COLLTRACE_NUM_ITEMS;
|
||||
int count;
|
||||
@@ -246,7 +258,7 @@ void *ncclCommThreadMain(void *arg) {
|
||||
fIdx, td->data_0, td->opCount, td->data_1);
|
||||
} else {
|
||||
sprintf(line, "## [%12.6f] [%02d:%02d] %06lx",
|
||||
(double)(td->timeStamp)/VEGA_GPU_RTC_FREQUENCY, comm->rank, td->bid, td->opCount);
|
||||
(double)(td->timeStamp)/VEGA_GPU_RTC_FREQUENCY, comm->rank, td->bid, fIdx == FUNC_INDEX_P2P ? (td->opCount + 0x100000): td->opCount);
|
||||
offset = strlen(line);
|
||||
switch (type) {
|
||||
case ncclCollTraceKernelLaunchType:
|
||||
@@ -261,18 +273,17 @@ void *ncclCommThreadMain(void *arg) {
|
||||
sprintf(line+offset, "nt %d bi %d nc %d busId %lx nRanks %d", td->coll.nThreads, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks);
|
||||
break;
|
||||
case ncclCollTraceCollEndType:
|
||||
if (fIdx != 0xffff) {
|
||||
sprintf(line+offset, " CE %s ", func_names+MAX_NAME_LENGTH*fIdx);
|
||||
offset = strlen(line);
|
||||
if (fIdx > FUNC_INDEX_P2P)
|
||||
sprintf(line+offset, "ERROR bad function index %d", fIdx);
|
||||
else if (fIdx == FUNC_INDEX_P2P)
|
||||
sprintf(line+offset, "nt %d dt %d busId %lx nRanks %d", td->p2p.nThreads, td->p2p.delta, comm->busId, comm->nRanks);
|
||||
else
|
||||
sprintf(line+offset, "nt %d bi %d nc %d busId %lx nRanks %d", td->coll.nThreads, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks);
|
||||
}
|
||||
sprintf(line+offset, " CE %s ", func_names+MAX_NAME_LENGTH*fIdx);
|
||||
offset = strlen(line);
|
||||
if (fIdx > FUNC_INDEX_P2P)
|
||||
sprintf(line+offset, "ERROR bad function index %d", fIdx);
|
||||
else if (fIdx == FUNC_INDEX_P2P)
|
||||
sprintf(line+offset, "nt %d dt %d busId %lx nRanks %d", td->p2p.nThreads, td->p2p.delta, comm->busId, comm->nRanks);
|
||||
else
|
||||
sprintf(line+offset, " KE busId %lx nRanks %d", comm->busId, comm->nRanks);
|
||||
sprintf(line+offset, "nt %d bi %d nc %d busId %lx nRanks %d", td->coll.nThreads, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks);
|
||||
break;
|
||||
case ncclCollTraceKernelEndType:
|
||||
sprintf(line+offset, " KE busId %lx nRanks %d", comm->busId, comm->nRanks);
|
||||
break;
|
||||
case ncclCollTraceAbortType:
|
||||
sprintf(line+offset, " Abort");
|
||||
@@ -299,6 +310,9 @@ void *ncclCommThreadMain(void *arg) {
|
||||
static ncclResult_t commFree(ncclComm_t comm) {
|
||||
if (comm == NULL)
|
||||
return ncclSuccess;
|
||||
|
||||
delete[] comm->userRedOps;
|
||||
|
||||
free(comm->connectSend);
|
||||
free(comm->connectRecv);
|
||||
for (int peer=0; peer<comm->nRanks; peer++) {
|
||||
@@ -442,8 +456,6 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
CUDACHECK(hipStreamDestroy(comm->groupStream));
|
||||
}
|
||||
|
||||
ncclDestroyQueueInfo(comm->enqueueInfo);
|
||||
|
||||
// Last rank frees shared resources between threads
|
||||
int isLast;
|
||||
NCCLCHECK(ncclCpuBarrierIn(comm, &isLast));
|
||||
@@ -466,6 +478,8 @@ static ncclResult_t commFree(ncclComm_t comm) {
|
||||
RCCL_PARAM(CliqueIgnoreTopo, "CLIQUE_IGNORE_TOPO", 0);
|
||||
RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0);
|
||||
NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2);
|
||||
NCCL_PARAM(DisableGraphHelper, "GRAPH_HELPER_DISABLE", 0);
|
||||
NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 0);
|
||||
|
||||
static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
if (ndev < 1) {
|
||||
@@ -509,7 +523,7 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
*comm->abortFlag = 0;
|
||||
|
||||
comm->collOpCount = 0;
|
||||
comm->p2pOpCount = 0x8000;
|
||||
comm->p2pOpCount = 0;
|
||||
|
||||
comm->argsptr = &comm->args;
|
||||
#ifdef ENABLE_PROFILING
|
||||
@@ -539,11 +553,20 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
comm->asyncAllocMode = ncclComm::SHORTEST_QUEUE;
|
||||
}
|
||||
|
||||
CUDACHECK(hipDriverGetVersion(&comm->driverVersion));
|
||||
|
||||
NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm));
|
||||
comm->lastSetupNode = NULL;
|
||||
comm->lastCudaGraphId = -1;
|
||||
|
||||
CUDACHECK(hipDriverGetVersion(&comm->driverVersion));
|
||||
comm->disableGraphHelper = ncclParamDisableGraphHelper();
|
||||
comm->graphRegister = ncclParamGraphRegister();
|
||||
#if CUDART_VERSION >= 11030
|
||||
NCCLCHECK(ncclCalloc(&comm->graphHelperResources, 1));
|
||||
comm->graphHelperResources->comm = comm;
|
||||
if (comm->driverVersion >= 11030)
|
||||
// hipGetDriverEntryPoint requires R465 or above (enhanced compat need)
|
||||
CUDACHECK(hipGetDriverEntryPoint("cuMemGetAddressRange", (void**)&comm->pfnCuMemGetAddressRange, hipEnableDefault));
|
||||
#endif
|
||||
|
||||
static_assert(MAXCHANNELS <= sizeof(*comm->connectSend)*8, "comm->connectSend must have enough bits for all channels");
|
||||
static_assert(MAXCHANNELS <= sizeof(*comm->connectRecv)*8, "comm->connectRecv must have enough bits for all channels");
|
||||
@@ -554,6 +577,10 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) {
|
||||
NCCLCHECK(ncclCalloc(&comm->p2pSends, comm->nRanks));
|
||||
NCCLCHECK(ncclCalloc(&comm->p2pRecvs, comm->nRanks));
|
||||
|
||||
// Create a map between global rank and intra-node rank
|
||||
NCCLCHECK(ncclCalloc(&comm->rankToIntraNodeRank, comm->nRanks));
|
||||
memset(comm->rankToIntraNodeRank, -1, comm->nRanks*sizeof(comm->rankToIntraNodeRank[0]));
|
||||
|
||||
// Mark channels as non initialized.
|
||||
for (int c=0; c<MAXCHANNELS; c++) comm->channels[c].id = -1;
|
||||
|
||||
@@ -786,13 +813,14 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
int intraNodeRank0 = -1, intraNodeRank = -1, intraNodeRanks = 0;
|
||||
int myCompCap = allGather1Data[rank].cudaCompCap;
|
||||
int minCompCap = myCompCap, maxCompCap = myCompCap;
|
||||
int intraNodeGlobalRanks[256];
|
||||
for (int i = 0; i < nranks; i++) {
|
||||
if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) {
|
||||
// Rank is on same node
|
||||
if (intraNodeRanks == 0) intraNodeRank0 = i;
|
||||
if (i == rank) intraNodeRank = intraNodeRanks;
|
||||
intraNodeGlobalRanks[intraNodeRanks++] = i;
|
||||
comm->intraNodeGlobalRanks[intraNodeRanks] = i;
|
||||
comm->rankToIntraNodeRank[i] = intraNodeRanks;
|
||||
intraNodeRanks++;
|
||||
if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) {
|
||||
// Rank is in same process
|
||||
if (intraProcRanks == 0) intraProcRank0 = i;
|
||||
@@ -821,6 +849,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
}
|
||||
struct ncclComm* intraProcRank0Comm = allGather1Data[intraProcRank0].comm;
|
||||
uint64_t intraNodeRank0pidHash = allGather1Data[intraNodeRank0].peerInfo.pidHash;
|
||||
comm->intraNodeRank = intraNodeRank;
|
||||
|
||||
// AllGather1 - end
|
||||
|
||||
@@ -1138,6 +1167,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
// Check if we can setup CollNet
|
||||
if (comm->collNetSupport > 0) {
|
||||
int collNetSetupFail = 0;
|
||||
int highestTypes[NCCL_MAX_INTRA_RANKS] = {TRANSPORT_P2P};
|
||||
// Find all head ranks
|
||||
int nHeads = collNetGraph.nChannels;
|
||||
int *heads;
|
||||
@@ -1163,16 +1193,26 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
|
||||
TRACE(NCCL_INIT, "rank %d Connected inter-node CollNet", rank);
|
||||
|
||||
// Connect intra-node CollNet
|
||||
int highestTransportType0, highestTransportType1;
|
||||
for (int c=0; c<comm->nChannels; c++) {
|
||||
struct ncclChannel* channelRecv = comm->channels+c;
|
||||
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channelRecv, NCCL_MAX_DIRECT_ARITY, channelRecv->collTree.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collTree.down, 0), ret, collnet_cleanup);
|
||||
}
|
||||
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0), ret, collnet_cleanup);
|
||||
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0, &highestTransportType0), ret, collnet_cleanup);
|
||||
for (int c=0; c<comm->nChannels; c++) {
|
||||
struct ncclChannel* channelSend = comm->channels+c;
|
||||
NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channelSend, NCCL_MAX_DIRECT_ARITY, channelSend->collTree.down, NCCL_MAX_DIRECT_ARITY, channelSend->collTree.up, 1), ret, collnet_cleanup);
|
||||
}
|
||||
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1), ret, collnet_cleanup);
|
||||
NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1, &highestTransportType1), ret, collnet_cleanup);
|
||||
|
||||
// Exchange highest intra-node transport type among ranks
|
||||
// because we need to know whether all ranks can p2p each other to determine whether we can directly read/write registered user buffer
|
||||
comm->intraHighestTransportType = highestTypes[comm->intraNodeRank] = highestTransportType0 > highestTransportType1 ? highestTransportType0 : highestTransportType1;
|
||||
NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->intraNodeGlobalRanks, comm->intraNodeRank, comm->localRanks, highestTypes, sizeof(int)));
|
||||
for (int i=0; i<comm->localRanks; i++) {
|
||||
if (highestTypes[i] > comm->intraHighestTransportType)
|
||||
comm->intraHighestTransportType = highestTypes[i];
|
||||
}
|
||||
INFO(NCCL_INIT, "rank %d Connected CollNet comm %p nRanks %02d", rank, comm, comm->nRanks);
|
||||
|
||||
collnet_cleanup:
|
||||
@@ -1220,7 +1260,7 @@ collnet_cleanup:
|
||||
NCCLCHECK(ncclCommSetIntraProc(comm, intraProcRank, intraProcRanks, intraProcRank0Comm));
|
||||
|
||||
/* Local intra-node barrier */
|
||||
NCCLCHECK(bootstrapBarrier(comm->bootstrap, intraNodeGlobalRanks, (int)intraNodeRank0pidHash, intraNodeRank, intraNodeRanks));
|
||||
NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->intraNodeGlobalRanks, intraNodeRank, intraNodeRanks, (int)intraNodeRank0pidHash));
|
||||
|
||||
if (comm->nNodes) NCCLCHECK(ncclProxyCreate(comm));
|
||||
|
||||
@@ -1321,6 +1361,22 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
/*
 * Stop the graph helper thread (if one is running) and release the helper
 * resources attached to the communicator.
 *
 * The thread is asked to stop by setting threadState = ThreadStop under
 * threadLock and signalling threadCond; it is then joined before the
 * resources it uses are freed.
 *
 * Always returns ncclSuccess.
 */
static ncclResult_t ncclGraphHelperDestroy(ncclComm* comm) {
  auto res = comm->graphHelperResources;
  if (comm->graphHelperThread && res) {
    pthread_mutex_lock(&res->threadLock);
    res->threadState = ThreadStop;
    pthread_cond_signal(&res->threadCond);
    pthread_mutex_unlock(&res->threadLock);
    // Wait for the helper to exit before freeing the memory it reads.
    pthread_join(comm->graphHelperThread, NULL);
  }
  if (res) {
    free(res);
    // Clear the communicator's pointer, not just the local copy, so later
    // NULL checks on comm->graphHelperResources (e.g. in
    // ncclDestroyQueueInfo) do not see a dangling pointer.
    comm->graphHelperResources = NULL;
  }
  return ncclSuccess;
}
|
||||
|
||||
static ncclResult_t commDestroy(ncclComm_t comm) {
|
||||
int savedDevice;
|
||||
#ifdef ENABLE_TRACE
|
||||
@@ -1337,6 +1393,11 @@ static ncclResult_t commDestroy(ncclComm_t comm) {
|
||||
|
||||
CUDACHECK(hipStreamSynchronize(comm->groupStream));
|
||||
NCCLCHECK(ncclProxyDestroy(comm));
|
||||
ncclDestroyQueueInfo(comm->enqueueInfo);
|
||||
#if CUDART_VERSION >= 11030
|
||||
NCCLCHECK(ncclGraphHelperDestroy(comm));
|
||||
#endif
|
||||
INFO(NCCL_COLL, "Created %d queue info, destroyed %d", comm->nQueueInfoCreated, comm->nQueueInfoDestroyed);
|
||||
NCCLCHECK(commFree(comm));
|
||||
|
||||
if (savedDevice != commDevice)
|
||||
|
||||
@@ -52,10 +52,16 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) {
|
||||
}
|
||||
if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank
|
||||
|
||||
if (info->op < 0 || info->op >= ncclNumOps) {
|
||||
if (info->op < 0 || ncclMaxRedOp < info->op) {
|
||||
WARN("%s : invalid reduction operation %d", info->opName, info->op);
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
int opIx = int(ncclUserRedOpMangle(info->comm, info->op)) - int(ncclNumOps);
|
||||
if (ncclNumOps <= info->op &&
|
||||
(info->comm->userRedOpCapacity <= opIx || info->comm->userRedOps[opIx].freeNext != -1)) {
|
||||
WARN("%s : reduction operation %d unknown to this communicator", info->opName, info->op);
|
||||
return ncclInvalidArgument;
|
||||
}
|
||||
|
||||
if (info->comm->checkPointers) {
|
||||
if (info->coll == ncclFuncSendRecv) {
|
||||
|
||||
@@ -60,7 +60,7 @@ ncclResult_t wrap_ibv_symbols(void) {
|
||||
if (!ibvhandle) {
|
||||
ibvhandle=dlopen("libibverbs.so.1", RTLD_NOW);
|
||||
if (!ibvhandle) {
|
||||
WARN("Failed to open libibverbs.so[.1]");
|
||||
INFO(NCCL_INIT, "Failed to open libibverbs.so[.1]");
|
||||
goto teardown;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#define NCCL_SUFFIX "${NCCL_SUFFIX}"
|
||||
|
||||
#define NCCL_VERSION_CODE ${NCCL_VERSION}
|
||||
#define NCCL_VERSION(X,Y,Z) (((X) >= 2 && (Y) >= 9) ? (X) * 10000 + (Y) * 100 + (Z) : (X) * 1000 + (Y) * 100 + (Z))
|
||||
#define NCCL_VERSION(X,Y,Z) (((X) <= 2 && (Y) <= 8) ? (X) * 1000 + (Y) * 100 + (Z) : (X) * 10000 + (Y) * 100 + (Z))
|
||||
|
||||
#define RCCL_BFLOAT16 1
|
||||
#define RCCL_GATHER_SCATTER 1
|
||||
@@ -141,12 +141,24 @@ ncclResult_t pncclCommUserRank(const ncclComm_t comm, int* rank);
|
||||
/// @endcond
|
||||
|
||||
/*! @brief Reduction operation selector */
|
||||
/* Reduction operation selector */
|
||||
typedef enum { ncclNumOps_dummy = 5 } ncclRedOp_dummy_t;
|
||||
typedef enum { ncclSum = 0,
|
||||
ncclProd = 1,
|
||||
ncclMax = 2,
|
||||
ncclMin = 3,
|
||||
ncclAvg = 4,
|
||||
ncclNumOps = 5 } ncclRedOp_t;
|
||||
/* ncclNumOps: The number of built-in ncclRedOp_t values. Also
|
||||
* serves as the least possible value for dynamic ncclRedOp_t's
|
||||
* as constructed by ncclRedOpCreate*** functions. */
|
||||
ncclNumOps = 5,
|
||||
/* ncclMaxRedOp: The largest valid value for ncclRedOp_t.
|
||||
* It is defined to be the largest signed value (since compilers
|
||||
* are permitted to use signed enums) that won't grow
|
||||
* sizeof(ncclRedOp_t) when compared to previous NCCL versions to
|
||||
* maintain ABI compatibility. */
|
||||
ncclMaxRedOp = 0x7fffffff>>(32-8*sizeof(ncclRedOp_dummy_t))
|
||||
} ncclRedOp_t;
|
||||
|
||||
/*! @brief Data types */
|
||||
typedef enum { ncclInt8 = 0, ncclChar = 0,
|
||||
@@ -161,6 +173,40 @@ typedef enum { ncclInt8 = 0, ncclChar = 0,
|
||||
ncclBfloat16 = 9,
|
||||
ncclNumTypes = 10 } ncclDataType_t;
|
||||
|
||||
/* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */
|
||||
typedef enum {
|
||||
/* ncclScalarDevice: The scalar is in device-visible memory and will be
|
||||
* dereferenced while the collective is running. */
|
||||
ncclScalarDevice = 0,
|
||||
|
||||
/* ncclScalarHostImmediate: The scalar is in host-visible memory and will be
|
||||
* dereferenced before the ncclRedOpCreate***() function returns. */
|
||||
ncclScalarHostImmediate = 1
|
||||
} ncclScalarResidence_t;
|
||||
|
||||
/*
|
||||
* ncclRedOpCreatePreMulSum
|
||||
*
|
||||
* Creates a new reduction operator which pre-multiplies input values by a given
|
||||
* scalar locally before reducing them with peer values via summation. For use
|
||||
* only with collectives launched against *comm* and *datatype*. The
|
||||
* *residence* argument indicates how/when the memory pointed to by *scalar*
|
||||
* will be dereferenced. Upon return, the newly created operator's handle
|
||||
* is stored in *op*.
|
||||
*/
|
||||
ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm);
|
||||
ncclResult_t pncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm);
|
||||
|
||||
/*
|
||||
* ncclRedOpDestroy
|
||||
*
|
||||
* Destroys the reduction operator *op*. The operator must have been created by
|
||||
* ncclRedOpCreatePreMulSum with the matching communicator *comm*. An operator may be
|
||||
* destroyed as soon as the last NCCL function which is given that operator returns.
|
||||
*/
|
||||
ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm);
|
||||
ncclResult_t pncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm);
|
||||
|
||||
/*
|
||||
* Collective communication operations
|
||||
*
|
||||
|
||||
@@ -20,7 +20,7 @@ struct ncclTransport ncclTransports[NTRANSPORTS] = {
|
||||
};
|
||||
|
||||
template <int type>
|
||||
static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex) {
|
||||
static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex, int* transportType) {
|
||||
struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank;
|
||||
struct ncclPeerInfo* peerInfo = comm->peerInfo+peer;
|
||||
struct ncclConnector* connector = (type == 1) ? comm->channels[channelId].peers[peer].send + connIndex :
|
||||
@@ -45,6 +45,7 @@ static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph*
|
||||
if (ret) {
|
||||
connector->transportComm = transportComm;
|
||||
NCCLCHECK(transportComm->setup(comm, graph, myInfo, peerInfo, connect, connector, channelId, connIndex));
|
||||
if (transportType) *transportType = t;
|
||||
return ncclSuccess;
|
||||
}
|
||||
}
|
||||
@@ -77,12 +78,14 @@ void dumpData(struct ncclConnect* data, int ndata) {
|
||||
}
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex) {
|
||||
ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType/*=NULL*/) {
|
||||
#if CUDART_VERSION >= 11030
|
||||
// Stream used during transport setup; need for P2P pre-connect + CUDA Graph
|
||||
hipStream_t transportSetupStream;
|
||||
CUDACHECK(hipStreamCreateWithFlags(&transportSetupStream, hipStreamNonBlocking));
|
||||
#endif
|
||||
int highestType = TRANSPORT_P2P; // track highest transport type
|
||||
|
||||
struct ncclConnect data[2*MAXCHANNELS];
|
||||
for (int i=1; i<comm->nRanks; i++) {
|
||||
int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0);
|
||||
@@ -93,15 +96,18 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
|
||||
|
||||
struct ncclConnect* recvData = data;
|
||||
int sendChannels = 0, recvChannels = 0;
|
||||
int type;
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (recvMask & (1<<c)) {
|
||||
NCCLCHECK(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex));
|
||||
NCCLCHECK(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type));
|
||||
if (type > highestType) highestType = type;
|
||||
}
|
||||
}
|
||||
struct ncclConnect* sendData = recvData+recvChannels;
|
||||
for (int c=0; c<MAXCHANNELS; c++) {
|
||||
if (sendMask & (1<<c)) {
|
||||
NCCLCHECK(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex));
|
||||
NCCLCHECK(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type));
|
||||
if (type > highestType) highestType = type;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,6 +155,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph*
|
||||
CUDACHECK(hipStreamSynchronize(transportSetupStream));
|
||||
CUDACHECK(hipStreamDestroy(transportSetupStream));
|
||||
#endif
|
||||
if (highestTransportType != NULL) *highestTransportType = highestType;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -242,22 +249,18 @@ cleanup:
|
||||
}
|
||||
|
||||
ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail) {
|
||||
int rank = comm->rank;
|
||||
int nranks = comm->nRanks;
|
||||
// AllGather collNet setup results
|
||||
int* allGatherFailures;
|
||||
NCCLCHECK(ncclCalloc(&allGatherFailures, nranks));
|
||||
allGatherFailures[rank] = collNetSetupFail;
|
||||
NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGatherFailures, sizeof(int)));
|
||||
for (int i=0; i<nranks; i++) {
|
||||
int allGatherFailures[NCCL_MAX_INTRA_RANKS] = {0};
|
||||
allGatherFailures[comm->intraNodeRank] = collNetSetupFail;
|
||||
NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->intraNodeGlobalRanks, comm->intraNodeRank, comm->localRanks, allGatherFailures, sizeof(int)));
|
||||
for (int i=0; i<comm->localRanks; i++) {
|
||||
if (allGatherFailures[i] != 0) {
|
||||
collNetSetupFail = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
free(allGatherFailures);
|
||||
if (collNetSetupFail) {
|
||||
if (rank == 0) WARN("Cannot initialize CollNet, using point-to-point network instead");
|
||||
if (comm->intraNodeRank == 0) WARN("Cannot initialize CollNet, using point-to-point network instead");
|
||||
return ncclSystemError;
|
||||
}
|
||||
return ncclSuccess;
|
||||
|
||||
@@ -473,7 +473,7 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
|
||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &ptr));
|
||||
reqFifo[group][buffSlot].recvBuff = ptr;
|
||||
TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
|
||||
TRACE(NCCL_NET, "recvProxy [%lu/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff);
|
||||
sub->posted += args->sliceSteps;
|
||||
args->idle = 0;
|
||||
continue;
|
||||
@@ -518,7 +518,7 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) {
|
||||
int sharedBuffSlot = sub->transmitted%NCCL_STEPS;
|
||||
int startChannel = group*COLLNET_GROUP_NSUBS;
|
||||
char* groupRecvAddress;
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
|
||||
NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &groupRecvAddress));
|
||||
char* ptr = groupRecvAddress + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot];
|
||||
if (p == NCCL_PROTO_SIMPLE) {
|
||||
volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo;
|
||||
|
||||
@@ -89,6 +89,8 @@ static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort)
|
||||
} else {
|
||||
// Merge multi-port NICs into the same PCI device
|
||||
p[strlen(p)-1] = '0';
|
||||
// Also merge virtual functions (VF) into the same device
|
||||
p[strlen(p)-3] = '0';
|
||||
// And keep the real port aside (the ibv port is always 1 on recent cards)
|
||||
*realPort = 0;
|
||||
for (int d=0; d<ncclNIbDevs; d++) {
|
||||
|
||||
@@ -23,6 +23,7 @@ struct p2pSendResources {
|
||||
uint32_t* next_hdp_reg; // Next GPU in ring (for p2p transport use only)
|
||||
int remoteId;
|
||||
int memRank;
|
||||
void* remIpcPtr;
|
||||
void* bootstrap;
|
||||
};
|
||||
|
||||
@@ -31,6 +32,7 @@ struct p2pRecvResources {
|
||||
void* ipcPtr;
|
||||
int remoteId;
|
||||
int memRank;
|
||||
void* remIpcPtr;
|
||||
void* bootstrap;
|
||||
};
|
||||
|
||||
@@ -107,6 +109,24 @@ ncclResult_t p2pCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTop
|
||||
*ret = 0;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
#else
|
||||
// Check that legacy IPC support is available
|
||||
if (p2p != 0) {
|
||||
char *dummy;
|
||||
cudaIpcMemHandle_t ipc;
|
||||
NCCLCHECK(ncclCudaCalloc(&dummy, CUDA_IPC_MIN));
|
||||
if (cudaIpcGetMemHandle(&ipc, dummy) != cudaSuccess) {
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Legacy IPC not supported on dev %d(=%lx)",
|
||||
cudaDev1, info1->busId);
|
||||
*ret = 0;
|
||||
}
|
||||
CUDACHECK(cudaFree(dummy));
|
||||
return ncclSuccess;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (p2p == 0) {
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Could not enable P2P between dev %d(=%lx) and dev %d(=%lx)",
|
||||
cudaDev1, info1->busId, cudaDev2, info2->busId);
|
||||
@@ -195,13 +215,14 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
|
||||
NCCLCHECK(ncclCudaCalloc((char**)&info.directPtr, sendSize, true));
|
||||
info.rank = myInfo->rank;
|
||||
if (myInfo->pidHash == peerInfo->pidHash) {
|
||||
if (info.read == 0) send->conn.direct |= NCCL_DIRECT_GPU;
|
||||
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s comm %p nRanks %02d",
|
||||
channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, comm, comm->nRanks);
|
||||
send->conn.direct |= info.read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE;
|
||||
INFO(NCCL_INIT|NCCL_P2P, "Channel %02d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s",
|
||||
channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
|
||||
} else {
|
||||
send->conn.direct |= info.read ? NCCL_IPC_READ : NCCL_IPC_WRITE;
|
||||
CUDACHECK(hipIpcGetMemHandle(&info.devIpc, info.directPtr));
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/IPC%s comm %p nRanks %02d",
|
||||
channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, comm, comm->nRanks);
|
||||
INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/IPC%s",
|
||||
channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr);
|
||||
}
|
||||
} else {
|
||||
NCCLCHECK(bootstrapRemAlloc(sendSize, intermediateRank, resources->bootstrap, &resources->remoteId, &info.devIpc, &info.directPtr));
|
||||
@@ -243,8 +264,9 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st
|
||||
NCCLCHECK(ncclCudaCalloc((char**)&info.directPtr, recvSize, true));
|
||||
info.rank = myInfo->rank;
|
||||
if (myInfo->pidHash == peerInfo->pidHash) {
|
||||
if (info.read == 0) recv->conn.direct |= NCCL_DIRECT_GPU;
|
||||
recv->conn.direct |= info.read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE;
|
||||
} else {
|
||||
recv->conn.direct |= info.read ? NCCL_IPC_READ : NCCL_IPC_WRITE;
|
||||
CUDACHECK(hipIpcGetMemHandle(&info.devIpc, info.directPtr));
|
||||
}
|
||||
} else {
|
||||
@@ -266,7 +288,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co
|
||||
struct ncclRecvMem* remDevMem;
|
||||
struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
|
||||
|
||||
NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr));
|
||||
NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->remIpcPtr));
|
||||
|
||||
int offset = 0;
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
@@ -282,6 +304,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co
|
||||
send->conn.head = &resources->devMem->head;
|
||||
send->conn.ptrExchange = &resources->devMem->ptrExchange;
|
||||
send->conn.next_hdp_reg = resources->next_hdp_reg;
|
||||
send->conn.redOpArgExchange = resources->devMem->redOpArgExchange;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -291,7 +314,7 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn
|
||||
struct ncclSendMem* remDevMem;
|
||||
struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
|
||||
|
||||
NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr));
|
||||
NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->remIpcPtr));
|
||||
|
||||
int offset = 0;
|
||||
for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
|
||||
@@ -306,6 +329,7 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn
|
||||
recv->conn.tail = &resources->devMem->tail;
|
||||
recv->conn.head = &remDevMem->head;
|
||||
recv->conn.ptrExchange = &remDevMem->ptrExchange;
|
||||
recv->conn.redOpArgExchange = remDevMem->redOpArgExchange;
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -313,6 +337,8 @@ ncclResult_t p2pSendFree(void* resources) {
|
||||
struct p2pSendResources* sendRes = (struct p2pSendResources*)resources;
|
||||
if (sendRes->ipcPtr)
|
||||
CUDACHECK(hipIpcCloseMemHandle(sendRes->ipcPtr));
|
||||
if (sendRes->remIpcPtr)
|
||||
CUDACHECK(hipIpcCloseMemHandle(sendRes->remIpcPtr));
|
||||
if (sendRes->remoteId != -1) {
|
||||
NCCLCHECK(bootstrapRemFree(sendRes->remoteId, sendRes->memRank, sendRes->bootstrap));
|
||||
sendRes->devMem = NULL;
|
||||
@@ -326,6 +352,8 @@ ncclResult_t p2pRecvFree(void* resources) {
|
||||
struct p2pRecvResources* recvRes = (struct p2pRecvResources*)resources;
|
||||
if (recvRes->ipcPtr)
|
||||
CUDACHECK(hipIpcCloseMemHandle(recvRes->ipcPtr));
|
||||
if (recvRes->remIpcPtr)
|
||||
CUDACHECK(hipIpcCloseMemHandle(recvRes->remIpcPtr));
|
||||
if (recvRes->remoteId != -1) {
|
||||
NCCLCHECK(bootstrapRemFree(recvRes->remoteId, recvRes->memRank, recvRes->bootstrap));
|
||||
recvRes->devMem = NULL;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
@@ -47,7 +47,7 @@ typedef struct
|
||||
uint16_t data;
|
||||
} rccl_bfloat16;
|
||||
|
||||
#include "common_kernel.h"
|
||||
#include "../../src/collectives/device/common_kernel.h"
|
||||
#include "EnvVars.hpp"
|
||||
|
||||
// Helper macro for catching HIP errors
|
||||
|
||||
@@ -1,450 +0,0 @@
|
||||
/*************************************************************************
|
||||
* Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
* Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* See LICENSE.txt for license information
|
||||
************************************************************************/
|
||||
|
||||
#ifndef NCCL_COMMON_KERNEL_H_
|
||||
#define NCCL_COMMON_KERNEL_H_
|
||||
|
||||
#include "devcomm.h"
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
// Define min for ssize_t
|
||||
// Mixed-width minimum of an int and an ssize_t. The comparison follows the
// usual arithmetic conversions; the selected value is narrowed to int on return.
static __device__ int min(int a, ssize_t b) {
  if (a < b) return a;
  return b;
}
|
||||
|
||||
typedef uint64_t PackType;
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
|
||||
// HIP build: one generic MULTI that forwards both 64-bit packed words
// directly to FUNC, without any per-type unpacking.
template<class FUNC, typename T>
struct MULTI {
  __device__ PackType operator()(const PackType x, const PackType y) const {
    const PackType reduced = FUNC()(x, y);
    return reduced;
  }
};
|
||||
|
||||
#else
|
||||
|
||||
// Primary template (non-HIP build). The specializations that follow unpack
// x and y into elements of type T and apply FUNC to each element.
template<class FUNC, typename T>
struct MULTI {
  __device__ PackType operator()(const PackType x, const PackType y) const;
};
|
||||
|
||||
// int8_t specialization: the 64-bit pack is viewed as two 32-bit words and
// FUNC is invoked once per word.
template<class FUNC>
struct MULTI<FUNC, int8_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(uint32_t),
      "PackType must be twice the size of uint32_t.");

  // Overlays the packed word with its two 32-bit halves.
  union converter {
    PackType storage;
    struct {
      uint32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter lhs, rhs, out;
    lhs.storage = x;
    rhs.storage = y;

    // char elements are handled as vector ops on the 32-bit halves
    out.a = FUNC()(lhs.a, rhs.a);
    out.b = FUNC()(lhs.b, rhs.b);

    return out.storage;
  }
};
|
||||
|
||||
// uint8_t specialization: the 64-bit pack is viewed as two 32-bit words and
// FUNC is invoked once per word.
template<class FUNC>
struct MULTI<FUNC, uint8_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(uint32_t),
      "PackType must be twice the size of uint32_t.");

  // Overlays the packed word with its two 32-bit halves.
  union converter {
    PackType storage;
    struct {
      uint32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter lhs, rhs, out;
    lhs.storage = x;
    rhs.storage = y;

    // char elements are handled as vector ops on the 32-bit halves
    out.a = FUNC()(lhs.a, rhs.a);
    out.b = FUNC()(lhs.b, rhs.b);

    return out.storage;
  }
};
|
||||
|
||||
// int32_t specialization: the pack holds two signed 32-bit elements, each
// reduced independently with FUNC.
template<class FUNC>
struct MULTI<FUNC, int32_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(int32_t),
      "PackType must be twice the size of int.");

  // Overlays the packed word with its two int32_t elements.
  union converter {
    PackType storage;
    struct {
      int32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter lhs, rhs, out;
    lhs.storage = x;
    rhs.storage = y;

    out.a = FUNC()(lhs.a, rhs.a);
    out.b = FUNC()(lhs.b, rhs.b);

    return out.storage;
  }
};
|
||||
|
||||
// uint32_t specialization: the pack holds two unsigned 32-bit elements, each
// reduced independently with FUNC.
template<class FUNC>
struct MULTI<FUNC, uint32_t> {
  static_assert(sizeof(PackType) == 2 * sizeof(uint32_t),
      "PackType must be twice the size of int.");

  // Overlays the packed word with its two uint32_t elements.
  union converter {
    PackType storage;
    struct {
      uint32_t a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter lhs, rhs, out;
    lhs.storage = x;
    rhs.storage = y;

    out.a = FUNC()(lhs.a, rhs.a);
    out.b = FUNC()(lhs.b, rhs.b);

    return out.storage;
  }
};
|
||||
|
||||
// half specialization: the pack is reinterpreted as two half2 values and
// FUNC is applied to each half2 pair.
template<class FUNC>
struct MULTI<FUNC, half> {
  static_assert(sizeof(PackType) == 4 * sizeof(half),
      "PackType must be four times the size of half.");

  // Layout used to view one packed word as a pair of half2 values.
  struct PackHalf2 {
    half2 a, b;
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    struct PackHalf2 lhs, rhs, out;
    const struct PackHalf2* xView = reinterpret_cast<const struct PackHalf2*>(&x);
    const struct PackHalf2* yView = reinterpret_cast<const struct PackHalf2*>(&y);
    lhs = *xView;
    rhs = *yView;

    out.a = FUNC()(lhs.a, rhs.a);
    out.b = FUNC()(lhs.b, rhs.b);

    PackType* packed = reinterpret_cast<PackType*>(&out);
    return *packed;
  }
};
|
||||
|
||||
// float specialization: the pack holds two floats, each reduced
// independently with FUNC.
template<class FUNC>
struct MULTI<FUNC, float> {
  static_assert(sizeof(PackType) == 2 * sizeof(float),
      "PackType must be twice the size of float.");

  // Overlays the packed word with its two float elements.
  union converter {
    PackType storage;
    struct {
      float a, b;
    };
  };

  __device__ PackType operator()(const PackType x, const PackType y) const {
    converter lhs, rhs, out;
    lhs.storage = x;
    rhs.storage = y;

    out.a = FUNC()(lhs.a, rhs.a);
    out.b = FUNC()(lhs.b, rhs.b);

    return out.storage;
  }
};
|
||||
|
||||
// double specialization: the pack is a single double; its bits are converted
// with __longlong_as_double / __double_as_longlong around the call to FUNC.
template<class FUNC>
struct MULTI<FUNC, double> {
  static_assert(sizeof(PackType) == sizeof(double),
      "PackType must be the same size as double.");

  __device__ PackType operator()(const PackType x, const PackType y) const {
    double reduced = FUNC()(__longlong_as_double(x), __longlong_as_double(y));
    const PackType packed = __double_as_longlong(reduced);
    return packed;
  }
};
|
||||
|
||||
// uint64_t specialization: the pack is a single element, so FUNC is applied
// to the packed words directly.
template<class FUNC>
struct MULTI<FUNC, uint64_t> {
  static_assert(sizeof(PackType) == sizeof(uint64_t),
      "PackType must be the same size as uint64_t.");

  __device__ PackType operator()(const PackType x, const PackType y) const {
    uint64_t reduced = FUNC()(x, y);
    return reduced;
  }
};
|
||||
|
||||
// int64_t specialization: both packed words are cast to int64_t before FUNC
// is applied; the signed result is converted back to PackType on return.
template<class FUNC>
struct MULTI<FUNC, int64_t> {
  static_assert(sizeof(PackType) == sizeof(int64_t),
      "PackType must be the same size as int64_t.");

  __device__ PackType operator()(const PackType x, const PackType y) const {
    int64_t reduced = FUNC()((int64_t)x, (int64_t)y);
    return reduced;
  }
};
|
||||
|
||||
#endif //defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
|
||||
// Generic load of one element through a volatile-qualified pointer.
template<typename T> inline __device__
T vFetch(const volatile T* ptr) {
  return ptr[0];
}
|
||||
|
||||
// Generic store of one element through a volatile-qualified pointer.
template<typename T> inline __device__
void vStore(volatile T* ptr, const T val) {
  ptr[0] = val;
}
|
||||
|
||||
#if CUDART_VERSION < 9000 && !(defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__))
|
||||
// half load for CUDART_VERSION < 9000 (non-HIP): copies the x member
// through the volatile pointer.
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
  half out;
  out.x = (*ptr).x;
  return out;
}
|
||||
|
||||
// half store for CUDART_VERSION < 9000 (non-HIP): writes the x member
// through the volatile pointer.
template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
  (*ptr).x = val.x;
}
|
||||
#else
|
||||
// half load: the volatile qualifier is cast away and the value is read
// through a plain half pointer.
template<> inline __device__
half vFetch<half>(const volatile half* ptr) {
  half out;
  half* plain = (half*)ptr;
  out = *plain;
  return out;
}
|
||||
|
||||
// half store: the volatile qualifier is cast away and the value is written
// through a plain half pointer.
template<> inline __device__
void vStore<half>(volatile half* ptr, const half val) {
  half* plain = (half*)ptr;
  *plain = val;
}
|
||||
|
||||
// rccl_bfloat16 load: copies the raw data member through the volatile pointer.
template<> inline __device__
rccl_bfloat16 vFetch<rccl_bfloat16>(const volatile rccl_bfloat16* ptr) {
  rccl_bfloat16 out;
  out.data = (*ptr).data;
  return out;
}
|
||||
|
||||
// rccl_bfloat16 store: writes the raw data member through the volatile pointer.
template<> inline __device__
void vStore<rccl_bfloat16>(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) {
  (*ptr).data = val.data;
}
|
||||
#endif
|
||||
|
||||
typedef ulong2 Pack128;
|
||||
|
||||
// Reduces two 128-bit packs in place: each 64-bit half of x is replaced by
// MULTI<FUNC, T> applied to the matching halves of x and y. y is only read.
template<class FUNC, typename T>
struct MULTI128 {
  __device__ void operator()(Pack128& x, Pack128& y) {
    MULTI<FUNC, T> reduce64;
    x.x = reduce64(x.x, y.x);
    x.y = reduce64(x.y, y.y);
  }
};
|
||||
|
||||
// Loads one 128-bit pack from p into v.
inline __device__ void Fetch128(Pack128& v, const Pack128* p) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  // HIP: two ordinary member loads.
  v.x = (*p).x;
  v.y = (*p).y;
#else
  // CUDA: a single volatile 2x64-bit global load in inline PTX.
  asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
#endif
}
|
||||
// Stores the 128-bit pack v to p.
inline __device__ void Store128(Pack128* p, Pack128& v) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
  // HIP: two ordinary member stores.
  (*p).x = v.x;
  (*p).y = v.y;
#else
  // CUDA: a single volatile 2x64-bit global store in inline PTX.
  asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
#endif
}
|
||||
|
||||
// Element-wise reduce-and-copy over up to MAXSRCS sources and MAXDSTS
// destinations.
//
//   w, nw, t    warp index, warp count and lane used to derive this thread's
//               starting element and stride
//   nsrcs/ndsts run-time source/destination counts; the first MINSRCS sources
//               and MINDSTS destinations are always processed, the rest only
//               when their index is below the run-time count
//   s, d        arrays of source / destination base pointers
//   elemOffset  element offset added to every base pointer
//   Nelem       upper bound tested against this thread's first element of a pass
//
// Per pass each thread handles UNROLL elements spaced WARP_SIZE apart. Only
// the first of them is checked against Nelem, so callers presumably pass a
// multiple of UNROLL*WARP_SIZE whenever UNROLL > 1 (verify against callers).
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t,
    int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) {
  const int stride = nw * UNROLL * WARP_SIZE;
  int pos = w * UNROLL * WARP_SIZE + t;

  // Per-thread cursors into every source and destination buffer.
  const T* in[MAXSRCS];
  T* out[MAXDSTS];
  for (int i=0; i<MAXSRCS; i++) in[i] = s[i]+elemOffset+pos;
  for (int i=0; i<MAXDSTS; i++) out[i] = d[i]+elemOffset+pos;

  for (; pos < Nelem; pos += stride) {
    T acc[UNROLL];
    // Seed the accumulator from the first source.
    for (int u = 0; u < UNROLL; ++u) acc[u] = vFetch(in[0]+u*WARP_SIZE);

    // Fold in the sources that are always present.
    #pragma unroll
    for (int i=1; i<MINSRCS; i++) {
      T extra[UNROLL];
      for (int u = 0; u < UNROLL; ++u) extra[u] = vFetch(in[i]+u*WARP_SIZE);
      for (int u = 0; u < UNROLL; ++u) acc[u] = FUNC()(acc[u], extra[u]);
    }
    // Fold in the optional sources, guarded by the run-time count.
    #pragma unroll
    for (int i=MINSRCS; i<MAXSRCS; i++) {
      if (i<nsrcs) {
        T extra[UNROLL];
        for (int u = 0; u < UNROLL; ++u) extra[u] = vFetch(in[i]+u*WARP_SIZE);
        for (int u = 0; u < UNROLL; ++u) acc[u] = FUNC()(acc[u], extra[u]);
      }
    }

    // Write the result to the destinations that are always present...
    #pragma unroll
    for (int i = 0; i < MINDSTS; i++) {
      for (int u = 0; u < UNROLL; ++u) vStore(out[i]+u*WARP_SIZE, acc[u]);
    }
    // ...and to the optional ones, guarded by the run-time count.
    #pragma unroll
    for (int i=MINDSTS; i<MAXDSTS; i++) {
      if (i<ndsts) {
        for (int u = 0; u < UNROLL; ++u) vStore(out[i]+u*WARP_SIZE, acc[u]);
      }
    }

    // Advance every cursor to this thread's next pass.
    for (int i=0; i<MAXSRCS; i++) in[i] += stride;
    for (int i=0; i<MAXDSTS; i++) out[i] += stride;
  }
}
|
||||
|
||||
// 128-bit-pack variant of the reduce-and-copy loop. Source and destination
// pointers (after adding elemOffset elements of T) are reinterpreted as
// Pack128 arrays; offsets, stride and the Npack bound are counted in packs.
//
//   w, nw, t    warp index, warp count and lane used to derive this thread's
//               starting pack and stride
//   nsrcs/ndsts run-time source/destination counts; the first MINSRCS sources
//               and MINDSTS destinations are always processed
//   Npack       upper bound tested against this thread's first pack of a pass
//
// Per pass each thread handles UNROLL packs spaced WARP_SIZE apart; only the
// first is checked against Npack.
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ void ReduceCopy128bMulti(const int w, const int nw, const int t,
    int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) {
  const int stride = nw * UNROLL * WARP_SIZE;
  int pos = w * UNROLL * WARP_SIZE + t;

  // Per-thread pack cursors into every source and destination buffer.
  const Pack128* in[MAXSRCS];
  Pack128* out[MAXDSTS];
  for (int i=0; i<MAXSRCS; i++) in[i] = ((const Pack128*)(s[i]+elemOffset))+pos;
  for (int i=0; i<MAXDSTS; i++) out[i] = ((Pack128*)(d[i]+elemOffset))+pos;

  for (; pos < Npack; pos += stride) {
    Pack128 acc[UNROLL];
    // Seed the accumulator from the first source.
    for (int u = 0; u < UNROLL; ++u) Fetch128(acc[u], in[0]+u*WARP_SIZE);

    // Fold in the sources that are always present.
    for (int i=1; i<MINSRCS; i++) {
      Pack128 extra[UNROLL];
      for (int u = 0; u < UNROLL; ++u) Fetch128(extra[u], in[i]+u*WARP_SIZE);
      for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(acc[u], extra[u]);
    }
    // Fold in the optional sources; this loop is explicitly kept rolled.
    #pragma unroll 1
    for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
      Pack128 extra[UNROLL];
      for (int u = 0; u < UNROLL; ++u) Fetch128(extra[u], in[i]+u*WARP_SIZE);
      for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(acc[u], extra[u]);
    }

    // Write the result to the destinations that are always present...
    for (int i = 0; i < MINDSTS; i++) {
      for (int u = 0; u < UNROLL; ++u) Store128(out[i]+u*WARP_SIZE, acc[u]);
    }
    // ...and to the optional ones; this loop is explicitly kept rolled.
    #pragma unroll 1
    for (int i=MINDSTS; i<MAXDSTS; i++) {
      if (i<ndsts) {
        for (int u = 0; u < UNROLL; ++u) Store128(out[i]+u*WARP_SIZE, acc[u]);
      }
    }

    // Advance every cursor to this thread's next pass.
    for (int i=0; i<MAXSRCS; i++) in[i] += stride;
    for (int i=0; i<MAXDSTS; i++) out[i] += stride;
  }
}
|
||||
|
||||
// Returns the pointer's address modulo alignof(int32_t); zero means the
// pointer meets that alignment.
// NOTE(review): the name suggests a 128-bit (16-byte) check, but the modulus
// used here is alignof(int32_t) — confirm this relaxation is intentional.
template <typename T>
__device__ int ptrAlign128(T* ptr) {
  const uint64_t address = (uint64_t)ptr;
  return address % alignof(int32_t);
}
|
||||
|
||||
#define PACKELEMS (sizeof(Pack128) / sizeof(T))
|
||||
|
||||
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
|
||||
// Multiply UNROLL by 2 if single source/single destination
|
||||
#define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1))
|
||||
#else
|
||||
// Try to limit consecutive load/stores to 8.
|
||||
// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise
|
||||
#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS)))
|
||||
#endif
|
||||
|
||||
// Reduces nsrcs source buffers element-wise with FUNC and writes the result
// to ndsts destination buffers, for N elements of type T.
//
//   tid, nthreads  calling thread's index and the number of cooperating
//                  threads; both are split into warp / lane by WARP_SIZE
//   srcs, dsts     arrays of base pointers; entries below MINSRCS / MINDSTS
//                  are always used, the rest only when below nsrcs / ndsts
//   N              element count; nothing is done when N <= 0
//
// The work is issued in up to four stages, each consuming a prefix of what is
// left: unrolled 128-bit packs, single 128-bit packs, unrolled elements, then
// one element at a time. The two pack stages run only when every pointer
// passes ptrAlign128.
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
    int nsrcs, const T** srcs, int ndsts, T** dsts,
    int N) {
  if (N <= 0) return;
  int remaining = N;

  const int warp = tid / WARP_SIZE;        // warp number
  const int nWarps = nthreads / WARP_SIZE; // number of warps
  const int lane = tid % WARP_SIZE;        // thread inside the warp

  // Accumulate the alignment remainder of every pointer in use; any non-zero
  // bit disables the 128-bit load/store stages.
  int misalign = 0;
  #pragma unroll
  for (int i=0; i<MINSRCS; i++) misalign |= ptrAlign128(srcs[i]);
  for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) misalign |= ptrAlign128(srcs[i]);
  #pragma unroll
  for (int i=0; i<MINDSTS; i++) misalign |= ptrAlign128(dsts[i]);
  for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) misalign |= ptrAlign128(dsts[i]);

  int done = 0;  // elements already handled by earlier stages
  if (misalign == 0) {
    // Stage 1: bulk of the work with fully unrolled 128-bit accesses.
    // The pack count is rounded down to a multiple of AUTOUNROLL*WARP_SIZE.
    int nPack = (remaining / (PACKELEMS*AUTOUNROLL*WARP_SIZE)) * (AUTOUNROLL*WARP_SIZE);
    int nElem = nPack * PACKELEMS;

    ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(warp, nWarps, lane, nsrcs, srcs, ndsts, dsts, done, nPack);

    remaining -= nElem;
    if (remaining == 0) return;
    done += nElem;

    // Stage 2: whole packs that did not fill an unrolled pass.
    nPack = remaining / PACKELEMS;
    nElem = nPack * PACKELEMS;

    ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(warp, nWarps, lane, nsrcs, srcs, ndsts, dsts, done, nPack);

    remaining -= nElem;
    if (remaining == 0) return;
    done += nElem;
  }

  // Stage 3: unrolled, element by element (mostly for unaligned buffers).
  // The element count is rounded down to a multiple of the unrolled pass size.
  int nElem = (remaining / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE);

  ReduceCopyMulti<FUNC, T, UNROLL*PACKELEMS/2, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(warp, nWarps, lane, nsrcs, srcs, ndsts, dsts, done, nElem);

  remaining -= nElem;
  if (remaining == 0) return;
  done += nElem;

  // Stage 4: no unrolling, element by element; finishes whatever is left.
  ReduceCopyMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(warp, nWarps, lane, nsrcs, srcs, ndsts, dsts, done, remaining);
}
|
||||
|
||||
#endif // NCCL_COMMON_KERNEL_H_
|
||||
Ссылка в новой задаче
Block a user