diff --git a/projects/rccl/CMakeLists.txt b/projects/rccl/CMakeLists.txt index e5e62e56b0..87433aa6b4 100644 --- a/projects/rccl/CMakeLists.txt +++ b/projects/rccl/CMakeLists.txt @@ -133,6 +133,7 @@ else() src/collectives/device/broadcast.cu src/collectives/device/reduce_scatter.cu src/collectives/device/sendrecv.cu + src/collectives/device/onerank_reduce.cu src/collectives/device/functions.cu) endif() diff --git a/projects/rccl/makefiles/version.mk b/projects/rccl/makefiles/version.mk index 833ab99d11..22bddcee2e 100644 --- a/projects/rccl/makefiles/version.mk +++ b/projects/rccl/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 10 -NCCL_PATCH := 3 +NCCL_MINOR := 11 +NCCL_PATCH := 4 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/projects/rccl/src/bootstrap.cc b/projects/rccl/src/bootstrap.cc index 36df8795b1..b38f8be0bb 100644 --- a/projects/rccl/src/bootstrap.cc +++ b/projects/rccl/src/bootstrap.cc @@ -232,7 +232,7 @@ struct unexConn { struct remAllocState { int cudaDev; int listenFd; - int stop; + volatile int stop; }; struct extState { @@ -287,7 +287,7 @@ void* ncclRemoteMemAllocationService(void* args) { for (int s=0; slistenFd; pollfds[MAX_SEGMENTS].events = POLLIN; @@ -315,7 +315,7 @@ void* ncclRemoteMemAllocationService(void* args) { } } for (int s=0; s().make(args->comm->nRanks)); - if (blockN > 0) { // Prepare input / output subarrays @@ -65,8 +63,8 @@ __device__ void AllReduceCliqueSplitKernel(struct ncclWorkElem* args) // Perform the reduction #define ALL_REDUCE_CLIQUE_UNROLL 1 - ReduceOrCopyMulti( - threadIdx.x, blockDim.x, redOp, NUM_RANKS, true, NUM_RANKS, srcs, NUM_RANKS, dsts, blockN); + ReduceOrCopyMulti( + threadIdx.x, blockDim.x, nullptr, false, NUM_RANKS, srcs, NUM_RANKS, dsts, blockN); } // Even if there was nothing for this GPU to do, it must participate in a barrier diff --git a/projects/rccl/src/clique/CliqueManager.cc b/projects/rccl/src/clique/CliqueManager.cc index 7870ff9bef..49a193d77a 100644 --- 
a/projects/rccl/src/clique/CliqueManager.cc +++ b/projects/rccl/src/clique/CliqueManager.cc @@ -274,9 +274,9 @@ bool CliqueManager::IsSupported(ncclFunc_t const coll, { if (m_cliqueMode == CLIQUE_DISABLED) return false; - // Filter based on total input size for each collective type + // Filter based on total input size for each collective type and ops sum/prod/min/max size_t totalBytes = count * ncclTypeSize(datatype); - if (coll == ncclFuncAllReduce && (totalBytes <= m_allReduceByteLimit)) return true; + if (coll == ncclFuncAllReduce && (totalBytes <= m_allReduceByteLimit) && op < ncclAvg) return true; return false; } diff --git a/projects/rccl/src/collectives/device/Makefile b/projects/rccl/src/collectives/device/Makefile index ead98ec6c9..04bce8ecde 100644 --- a/projects/rccl/src/collectives/device/Makefile +++ b/projects/rccl/src/collectives/device/Makefile @@ -10,7 +10,7 @@ include ../../../makefiles/version.mk BUILDDIR ?= $(abspath ../../../build) OBJDIR := $(BUILDDIR)/obj/collectives/device -LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu sendrecv.cu +LIBSRCFILES := all_reduce.cu broadcast.cu reduce.cu all_gather.cu reduce_scatter.cu sendrecv.cu onerank_reduce.cu LIBSRCFILES += functions.cu @@ -36,7 +36,7 @@ $(RULESFILE) : -include $(RULESFILE) -LIBOBJ := $(GENOBJS) $(OBJDIR)/functions.o +LIBOBJ := $(GENOBJS) $(OBJDIR)/functions.o $(OBJDIR)/onerank_reduce.o -include $(DEPFILES) @@ -63,6 +63,11 @@ $(OBJDIR)/functions.o : functions.cu $(OBJDIR)/functions.dep mkdir -p `dirname $@` $(NVCC) $(NVCUFLAGS) -dc $< -o $@ +$(OBJDIR)/onerank_reduce.o : onerank_reduce.cu $(OBJDIR)/onerank_reduce.dep + @printf "Compiling %-35s > %s\n" $< $@ + mkdir -p `dirname $@` + $(NVCC) $(NVCUFLAGS) -dc $< -o $@ + # ... and create the device-side linked object with all those. 
$(DEVOBJ) : $(LIBOBJ) $(NVCC) $(NVCUFLAGS) -dlink $^ -o $@ diff --git a/projects/rccl/src/collectives/device/all_gather.h b/projects/rccl/src/collectives/device/all_gather.h index f030c8fbfd..59ee7b0eed 100644 --- a/projects/rccl/src/collectives/device/all_gather.h +++ b/projects/rccl/src/collectives/device/all_gather.h @@ -11,7 +11,7 @@ namespace { template - __device__ void runRing(ncclWorkElem *args) { + __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -28,7 +28,7 @@ namespace { T *inputBuf = (T*)args->sendbuff; T *outputBuf = (T*)args->recvbuff; Primitives, 0, Proto> - prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, 0, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->coll.redOpArg, args->coll.connIndex << 16); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -79,7 +79,7 @@ namespace { template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { using Proto = ProtoSimple; runRing(args); } @@ -87,14 +87,14 @@ struct RunWorkElement struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; diff --git a/projects/rccl/src/collectives/device/all_reduce.h b/projects/rccl/src/collectives/device/all_reduce.h index 97f347a887..4ec8d3620a 100644 --- a/projects/rccl/src/collectives/device/all_reduce.h +++ b/projects/rccl/src/collectives/device/all_reduce.h @@ -12,7 +12,7 @@ namespace { template - __device__ void runRing(ncclWorkElem *args) { + 
__device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -38,7 +38,7 @@ namespace { } Primitives, 0, Proto> prims - (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex); + (tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->coll.redOpArg, args->coll.connIndex << 16); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -110,12 +110,12 @@ namespace { ACCUMULATE_COUNTER(directRecv); } #ifdef ENABLE_PROFILING - if (tid == 0 && args->op.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk); + if (tid == 0 && args->coll.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk); #endif } template - __device__ void runTreeUpDown(ncclWorkElem *args) { + __device__ __forceinline__ void runTreeUpDown(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -135,7 +135,7 @@ namespace { { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) Primitives, /*Direct=*/0, Proto> prims - (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff); + (tid, nthreads, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->coll.redOpArg); if (tree->up == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -161,7 +161,7 @@ namespace { { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) Primitives, /*Direct=*/0, Proto> prims - (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff); + (tid, nthreads, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg); if (tree->up == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset 
+= loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -187,7 +187,7 @@ namespace { } template - __device__ void runTreeSplit(ncclWorkElem *args) { + __device__ __forceinline__ void runTreeSplit(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -219,7 +219,7 @@ namespace { if (tree->up == -1) { // Reduce and broadcast. Max number of recv is 3, max number of send is 3 Primitives, /*Direct=*/0, Proto> - prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff); + prims(tid, nthreads, tree->down, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); int nelem = min(chunkSize, size-offset); @@ -236,7 +236,7 @@ namespace { * but the ctor above for tree roots would be DirectRecv=0 DirectSend=1. */ Primitives, /*Direct=*/0, Proto> - prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth); + prims(tid, nthreadsSplit, tree->down, &tree->up, args->sendbuff, args->recvbuff, args->coll.redOpArg, 0*Proto::MaxGroupWidth); if (tree->down[0] == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -255,7 +255,7 @@ namespace { else { // Broadcast down. 
Max number of recv is 1, max number of send is 3 (binary tree + local) Primitives, /*Direct=*/0, Proto> - prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth); + prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg, 1*Proto::MaxGroupWidth); if (tree->down[0] == -1) { for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*int(chunkSize); @@ -276,7 +276,7 @@ namespace { template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { using Proto = ProtoSimple; runRing(args); } @@ -284,14 +284,14 @@ struct RunWorkElement struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runTreeUpDown>(args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { static constexpr int COLLNET_COPY_THREADS = 64; const int tid = threadIdx.x; const int bid = args->coll.bid; @@ -315,27 +315,37 @@ struct RunWorkElement= tidStartScatter && tid < tidStartReduce && hasUp) { // Scatter + int group = (2*Proto::MaxGroupWidth) | (1<<16); Primitives, /*Direct=*/0, Proto> - prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, 2*Proto::MaxGroupWidth); + prims(tid-tidStartScatter, nThreadsScatter, NULL, tree->up, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize; int nelem = min(tree->nHeads*chunkSize, size-offset); - prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift); + if (args->regUsed) { + 
prims.directScatter(offset, nelem, chunkSize, tree->headRank, tree->shift); + } else { + prims.scatter(offset, nelem, chunkSize, tree->headRank, tree->shift); + } } } else if (tid >= tidStartReduce && tree->out != -1) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); if (hasDn) { // Reduce, send to network Primitives, /*Direct=*/0, Proto> - prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth); + prims(tid-tidStartReduce, nThreadsReduce, tree->down, &tree->out, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); - prims.recvReduceSend(offset, nelem); + if (args->regUsed) { + prims.directRecvReduceSend(offset, offset, nelem); + } else { + prims.recvReduceSend(offset, nelem); + } } } else { // Directly send to network Primitives, /*Direct=*/0, Proto> - prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, 3*Proto::MaxGroupWidth); + prims(tid-tidStartReduce, nThreadsReduce, nullptr, &tree->out, args->sendbuff, args->recvbuff, args->coll.redOpArg, group); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); @@ -344,27 +354,29 @@ struct RunWorkElement, /*Direct=*/0, Proto> - prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, 0*Proto::MaxGroupWidth); + prims(tid, nThreadsGather, tree->up, NULL, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + bid*tree->nHeads*chunkSize; int nelem = min(tree->nHeads*chunkSize, size-offset); - prims.gather(offset, nelem, chunkSize, tree->headRank, 
tree->shift); + prims.directGather(offset, nelem, chunkSize, tree->headRank, tree->shift); } } else if (tid >= tidStartBcast && tid < tidStartScatter && tree->out != -1) { + int group = (1*Proto::MaxGroupWidth) | (0<<16); if (hasDn) { // Recv from network, broadcast Primitives, /*Direct=*/0, Proto> - prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth); + prims(tid-tidStartBcast, nThreadsBcast, &tree->out, tree->down, args->sendbuff, args->recvbuff, args->coll.redOpArg, group, args); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); - prims.recvCopySend(offset, nelem, /*postOp=*/true); + prims.recvCopyDirectSend(offset, offset, nelem, /*postOp=*/true); } } else { // Recv from network (no post thread needed) Primitives, /*Direct=*/0, Proto> - prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, 1*Proto::MaxGroupWidth); + prims(tid-tidStartBcast, nThreadsBcast, &tree->out, nullptr, args->sendbuff, args->recvbuff, args->coll.redOpArg, group); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t offset = gridOffset + (bid*tree->nHeads+tree->headRank)*chunkSize; int nelem = min(chunkSize, size-offset); @@ -377,28 +389,28 @@ struct RunWorkElement struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runTreeUpDown(args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { 
LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, RedOp, T, args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { LAUNCH_CLIQUE_KERNEL(AllReduceCliqueSplitKernel, RedOp, T, args); } }; diff --git a/projects/rccl/src/collectives/device/broadcast.h b/projects/rccl/src/collectives/device/broadcast.h index c318afd93c..31a9d2a503 100644 --- a/projects/rccl/src/collectives/device/broadcast.h +++ b/projects/rccl/src/collectives/device/broadcast.h @@ -10,7 +10,7 @@ namespace { template - __device__ void runRing(ncclWorkElem *args) { + __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -32,7 +32,7 @@ namespace { T *inputBuf = (T*)args->sendbuff; T *outputBuf = (T*)args->recvbuff; Primitives, 0, Proto> - prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, 0, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, inputBuf, outputBuf, args->coll.redOpArg, args->coll.connIndex << 16); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -70,14 +70,14 @@ namespace { } } #ifdef ENABLE_PROFILING - if (tid == 0 && args->op.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk); + if (tid == 0 && args->coll.opCount) devProf->elems[blockIdx.x].total_cycle += (__builtin_amdgcn_s_memrealtime() - clk); #endif } } template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { using Proto = ProtoSimple; runRing(args); } @@ -85,14 +85,14 @@ struct RunWorkElement struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; template struct 
RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index c11f35e882..543a3ca039 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -12,97 +12,92 @@ #include "devcomm.h" #define COLL_UNROLL 2 -#define NCCL_MAX_DEV_ARITY NCCL_MAX_TREE_ARITY +#define NCCL_MAX_DEV_ARITY (NCCL_MAX_TREE_ARITY-1) // Using balanced tree instead of split tree #define __syncwarp() -#define NCCL_FUNC5(func, algo, redop, type) \ - NCCL_FUNC_NAME(func, algo, LL, redop, type), \ - NCCL_FUNC_NAME(func, algo, LL, redop, type), \ - NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type) +#define NCCL_FUNC5(func, algo, devredop, type, nullify) \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) -#define NCCL_FUNC4(func, redop, type) \ - NCCL_FUNC5(func, TREE, redop, type), \ - NCCL_FUNC5(func, RING, redop, type), \ - NCCL_FUNC5(func, COLLNET, redop, type) +#define NCCL_FUNC4(func, devredop, type, nullify) \ + NCCL_FUNC5(func, TREE, devredop, type, nullify), \ + NCCL_FUNC5(func, RING, devredop, type, nullify), \ + NCCL_FUNC5(func, COLLNET, devredop, type, nullify) // Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, redop) \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, uint8_t), \ - NCCL_FUNC4(func, redop, int32_t), \ - NCCL_FUNC4(func, redop, uint32_t), \ - NCCL_FUNC4(func, redop, int64_t), \ - NCCL_FUNC4(func, redop, uint64_t), \ - NCCL_FUNC4(func, redop, half), \ - NCCL_FUNC4(func, redop, float), \ - NCCL_FUNC4(func, redop, double), \ - NCCL_FUNC4(func, redop, rccl_bfloat16) -#define NCCL_FUNCS3B(func, 
redop) \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t) +#define NCCL_FUNCS3A(func, devredop, nullForFloat) \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, uint8_t, 0), \ + NCCL_FUNC4(func, devredop, int32_t, 0), \ + NCCL_FUNC4(func, devredop, uint32_t, 0), \ + NCCL_FUNC4(func, devredop, int64_t, 0), \ + NCCL_FUNC4(func, devredop, uint64_t, 0), \ + NCCL_FUNC4(func, devredop, half, nullForFloat), \ + NCCL_FUNC4(func, devredop, float, nullForFloat), \ + NCCL_FUNC4(func, devredop, double, nullForFloat), \ + NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat) +#define NCCL_FUNCS3B(func, devredop) \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0) // Must be consistent with ncclRedOp_t #define NCCL_FUNCS2A(func) \ - NCCL_FUNCS3A(func, Sum ), \ - NCCL_FUNCS3A(func, Prod), \ - NCCL_FUNCS3A(func, Max ), \ - NCCL_FUNCS3A(func, Min ), \ - NCCL_FUNCS3A(func, Avg) + NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1) + #define NCCL_FUNCS2B(func) \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ 
NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ + NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum) // [RCCL] Adding clique-based kernels for AllReduce, in-place of unused RingLL28 kernels -#define NCCL_FUNC5B(func, algo, redop, type) \ - NCCL_FUNC_NAME(func, algo, LL, redop, type), \ - NCCL_FUNC_NAME(func, algo, LL128, redop, type), \ - NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type) +#define NCCL_FUNC5B(func, algo, devredop, type, nullify) \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) -#define NCCL_FUNC4B(func, redop, type) \ - NCCL_FUNC5(func, TREE, redop, type), \ - NCCL_FUNC5B(func, RING, redop, type), \ - NCCL_FUNC5(func, COLLNET, redop, type) +#define NCCL_FUNC4B(func, devredop, type, nullify) \ + NCCL_FUNC5B(func, TREE, devredop, type, nullify), \ + NCCL_FUNC5B(func, RING, devredop, type, nullify), \ + NCCL_FUNC5B(func, COLLNET, devredop, type, nullify) -#define NCCL_FUNCS3C(func, redop) \ - NCCL_FUNC4B(func, redop, int8_t), \ - NCCL_FUNC4B(func, redop, uint8_t), \ - NCCL_FUNC4B(func, redop, int32_t), \ - NCCL_FUNC4B(func, redop, uint32_t), \ - NCCL_FUNC4B(func, redop, int64_t), \ - NCCL_FUNC4B(func, redop, uint64_t), \ - NCCL_FUNC4B(func, redop, half), \ - NCCL_FUNC4B(func, redop, float), \ - NCCL_FUNC4B(func, redop, double), \ - NCCL_FUNC4B(func, redop, rccl_bfloat16) +#define NCCL_FUNCS3C(func, devredop, nullForFloat) \ + NCCL_FUNC4B(func, devredop, int8_t, 0), \ + NCCL_FUNC4B(func, devredop, uint8_t, 0), \ + NCCL_FUNC4B(func, devredop, int32_t, 0), \ + NCCL_FUNC4B(func, devredop, uint32_t, 0), \ + NCCL_FUNC4B(func, devredop, int64_t, 0), \ + NCCL_FUNC4B(func, devredop, uint64_t, 0), \ + NCCL_FUNC4B(func, devredop, half, nullForFloat), \ + NCCL_FUNC4B(func, devredop, float, nullForFloat), \ + NCCL_FUNC4B(func, devredop, double, nullForFloat), \ + 
NCCL_FUNC4B(func, devredop, rccl_bfloat16, nullForFloat) #define NCCL_FUNCS2C(func) \ - NCCL_FUNCS3C(func, Sum ), \ - NCCL_FUNCS3C(func, Prod), \ - NCCL_FUNCS3C(func, Max ), \ - NCCL_FUNCS3C(func, Min ), \ - NCCL_FUNCS3C(func, Avg) + NCCL_FUNCS3C(func, Sum, /*nullForFloat=*/0), \ + NCCL_FUNCS3C(func, Prod, /*nullForFloat=*/0), \ + NCCL_FUNCS3C(func, Max, /*nullForFloat=*/0), \ + NCCL_FUNCS3C(func, Min, /*nullForFloat=*/0), \ + NCCL_FUNCS3C(func, PreMulSum, /*nullForFloat=*/0), \ + NCCL_FUNCS3C(func, SumPostDiv, /*nullForFloat=*/1) -// Must be consistent with ncclFunc_t -#define NCCL_FUNCS() { \ - NCCL_FUNCS2B(Broadcast), \ - NCCL_FUNCS2A(Reduce), \ - NCCL_FUNCS2B(AllGather), \ - NCCL_FUNCS2A(ReduceScatter), \ - NCCL_FUNCS2C(AllReduce), \ - NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t) } -// [/RCCL] // Must be consistent with the ncclFuncSet enum using ncclKernelFunc_t = void (*)(struct ncclWorkElem* args); @@ -113,13 +108,25 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{ // confuses clang. This will be fixed in the next clang release. 
#if defined(__HIP_DEVICE_COMPILE__) #if defined(BUILD_ALLREDUCE_ONLY) - NCCL_FUNC4B(AllReduce, Sum, float), + NCCL_FUNC4B(AllReduce, Sum, float, 0), #else NCCL_FUNCS2B(Broadcast), NCCL_FUNCS2A(Reduce), NCCL_FUNCS2B(AllGather), NCCL_FUNCS2A(ReduceScatter), NCCL_FUNCS2C(AllReduce), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), +#if defined(RCCL_BFLOAT16) + NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16), +#endif NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), #endif #endif @@ -142,7 +149,7 @@ struct Caller{ void call(struct ncclWorkElem* const c) noexcept { ncclFuncs[f](c); } }; -static_assert(FUNC_INDEX_P2P == 2250, "Wrong P2P function index"); +static_assert(FUNC_INDEX_P2P == 2710, "Wrong P2P function index"); inline __device__ @@ -165,7 +172,7 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept { else assert("Unsupported function index"); #else - if (c->funcIndex < 450) { + if (c->funcIndex < 540) { if (c->funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c); else if (c->funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c); else if (c->funcIndex % 9 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(c); @@ -176,8 +183,8 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept { else if (c->funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t(c); else ncclFunction_Broadcast_COLLNET_SIMPLE_Sum_int8_t(c); } - else if (c->funcIndex < 900) Caller<450, 900>::call(c); - else if (c->funcIndex < 1350) { + else if (c->funcIndex < 1080) Caller<540, 1080>::call(c); + else if (c->funcIndex < 1620) { if (c->funcIndex % 9 == 0) 
ncclFunction_AllGather_TREE_LL_Sum_int8_t(c); else if (c->funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(c); else if (c->funcIndex % 9 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(c); @@ -188,8 +195,46 @@ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept { else if (c->funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t(c); else ncclFunction_AllGather_COLLNET_SIMPLE_Sum_int8_t(c); } - else if (c->funcIndex < 2250) Caller<1350, 2250>::call(c); - else ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c); + else if (c->funcIndex < 2700) Caller<1620, 2700>::call(c); + else { + switch (c->funcIndex - 2700) { + case 0: + ncclFunction_OneRankReduce_PreMulSum_int8_t(c); + break; + case 1: + ncclFunction_OneRankReduce_PreMulSum_uint8_t(c); + break; + case 2: + ncclFunction_OneRankReduce_PreMulSum_int32_t(c); + break; + case 3: + ncclFunction_OneRankReduce_PreMulSum_uint32_t(c); + break; + case 4: + ncclFunction_OneRankReduce_PreMulSum_int64_t(c); + break; + case 5: + ncclFunction_OneRankReduce_PreMulSum_uint64_t(c); + break; + case 6: + ncclFunction_OneRankReduce_PreMulSum_half(c); + break; + case 7: + ncclFunction_OneRankReduce_PreMulSum_float(c); + break; + case 8: + ncclFunction_OneRankReduce_PreMulSum_double(c); + break; + case 9: + ncclFunction_OneRankReduce_PreMulSum_rccl_bfloat16(c); + break; + case 10: + ncclFunction_SendRecv_RING_SIMPLE_Sum_int8_t(c); + break; + default: + break; + } + } #endif } @@ -203,29 +248,42 @@ class ncclFunction { #define traceColl(fIdx) \ uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ - shmem.comm.collTrace[pos].opCount = elems[0].op.opCount; \ shmem.comm.collTrace[pos].bid = bid; \ shmem.comm.collTrace[pos].funcIndex = fIdx; \ if (fIdx == FUNC_INDEX_P2P) { \ + shmem.comm.collTrace[pos].opCount = elems[0].p2p.opCount; \ shmem.comm.collTrace[pos].p2p.nThreads = 
elems[0].p2p.nThreads; \ shmem.comm.collTrace[pos].p2p.delta = (uint16_t)(elems[0].p2p.delta); \ } else { \ + shmem.comm.collTrace[pos].opCount = elems[0].coll.opCount; \ shmem.comm.collTrace[pos].coll.nThreads = elems[0].nThreads; \ shmem.comm.collTrace[pos].coll.bid = elems[0].coll.bid; \ shmem.comm.collTrace[pos].coll.nChannels = elems[0].coll.nChannels; \ } #define traceKernelLaunch(fIdx) { \ - traceColl(fIdx); \ - asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \ - shmem.comm.collTrace[pos].type = ncclCollTraceKernelLaunchType; \ + if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \ + traceColl(fIdx); \ + asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \ + shmem.comm.collTrace[pos].type = ncclCollTraceKernelLaunchType; \ + } \ } #define traceCollEnd(fIdx) { \ - traceColl(fIdx); \ - shmem.comm.collTrace[pos].type = ncclCollTraceCollEndType; \ + if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \ + traceColl(fIdx); \ + shmem.comm.collTrace[pos].type = ncclCollTraceCollEndType; \ + } \ + } +#define traceKernelEnd(fIdx) { \ + if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \ + traceColl(fIdx); \ + shmem.comm.collTrace[pos].type = ncclCollTraceKernelEndType; \ + } \ } #define traceAbort(fIdx) { \ - traceColl(fIdx); \ - shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \ + if (!(fIdx == FUNC_INDEX_P2P && elems[0].p2p.nThreads == 0)) { \ + traceColl(fIdx); \ + shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \ + } \ } // traceData(int16_t data2, uint32_t data4, uint64_t data8_0, uint64_t data8_1) #define traceData(data2, data4, data8_0, data8_1) { \ @@ -285,14 +343,24 @@ __device__ int copyToShmem(T *dst, T const *src, int turn=0) { template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem*) { + __device__ void run(ncclWorkElem*) { // Put NOT IMPLEMENTED behavior here. 
} }; +#if CUDART_VERSION >= 11030 +__device__ constexpr int ncclWorkElemFactors[NCCL_NUM_ALGORITHMS] = +#else +static __device__ __constant__ int ncclWorkElemFactors[NCCL_NUM_ALGORITHMS] = +#endif +{/*Tree*/1, /*Ring and P2P*/1, /*CollNet*/NCCL_REG_ELEM_FACTOR}; + template struct RunWork { - __device__ __attribute__((noinline)) void run(ncclWork *w) { + // This __forceinline__ is necessary. The compiler was inserting a function call + // here from the LL ncclKernel. + __device__ __forceinline__ void run(ncclWork *w) { + int tid = threadIdx.x; /* Some invariants that must hold: * 1. All elems[] have same funcIndex. * 2. All elems[] have same nThreads. @@ -300,20 +368,23 @@ struct RunWork { * for all elems[]. * * If (1) isn't true then we might be in the wrong function since dispatch - * on ncclFuncs[w->elems[0].funcIndex] is how we got here. + * on ncclFuncs[w->funcIndex] is how we got here. * * If (2) or (3) aren't true, then threads from different work elements * could race for barrier resources (barrier numbers 0...15) which is fatal. * - * Important, to ensure (3), implementations of - * `RunWorkElement::run()` may only use values which - * are the same for all elems[] when deciding how to map threads to groups, - * such as the following: + * IMPORTANT!!! To ensure (3), implementations of + * `RunWorkElement::run()` may only use the following + * when deciding how to map threads to groups: * Fn, T, RedOp, Algo, Proto, nThreads * - * This last one is difficult to enforce and diagnosing it is a headeache. - * Device-side developers, consider yourselves warned. + * This last one is difficult to enforce so I hope everyone reads this. 
*/ + if (tid < w->elems[0].nThreads) { + #pragma unroll 1 + for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e+=ncclWorkElemFactors[Algo]) + RunWorkElement().run(&w->elems[e]); + } } }; @@ -323,6 +394,7 @@ struct ncclShmemGroup { ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY]; void* srcs[NCCL_MAX_DIRECT_ARITY+1]; void* dsts[NCCL_MAX_DIRECT_ARITY+1]; + int totalSendSize[NCCL_MAX_SLICE_PER_CHUNK]; uint64_t barrier; uint64_t barrier_next[MAXWARPS]; }; @@ -333,6 +405,7 @@ struct ncclShmemData { struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; }; uint32_t sync[MAXWARPS]; + uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1]; ncclDevComm comm; ncclChannel channel; ncclWork work; @@ -362,9 +435,13 @@ __device__ void ncclKernel(ncclWorkElem first) { turn = copyToShmem(&shmem.channel, channel, turn); // To optimize for latency, (only) the first operation is passed as argument. - if (bid == 0 && first.active != 0) + if (bid == 0 && first.active != 0) { turn = copyToShmem(&shmem.work.elems[0], &first, turn); - + if (1 <= tid && tid < NCCL_MAX_WORK_ELEMENTS && tid % ncclWorkElemFactors[Algo] == 0) { + shmem.work.elems[tid].active = 0; + shmem.work.elems[tid].redOpArgIsPtr = 0; + } + } struct ncclWorkElem* elems = shmem.work.elems; __syncthreads(); // publish shmem @@ -401,13 +478,36 @@ __device__ void ncclKernel(ncclWorkElem first) { if (tid == 0) channel->index = workFifoIx; // write back to real channel, not shmem shadow + if (tid < NCCL_MAX_WORK_ELEMENTS && tid % ncclWorkElemFactors[Algo] == 0) { + ncclWorkElem *we = &shmem.work.elems[tid]; + if (we->redOpArgIsPtr && we->active != 0) { + /* redOpArg is a pointer to the scalar value, so we'll dereference it + * here so that redOpArg holds the bits of the scalar going forward. + * The tricky thing is we don't know its type T since that's encoded in + * the funcIndex. 
Because it would be difficult to get sizeof(T) from + * funcIndex, we'll cheat and just dereference the largest possible size + * given the alignment of the pointer. We might be reading in more bytes + * than we need but that's harmless. + */ + if (we->coll.redOpArg%2 != 0) + we->coll.redOpArg = *reinterpret_cast(we->coll.redOpArg); + else if (we->coll.redOpArg%4 != 0) + we->coll.redOpArg = *reinterpret_cast(we->coll.redOpArg); + else if (we->coll.redOpArg%8 != 0) + we->coll.redOpArg = *reinterpret_cast(we->coll.redOpArg); + else + we->coll.redOpArg = *reinterpret_cast(we->coll.redOpArg); + } + } + __syncthreads(); + if (shmem.work.elems[0].funcIndex == FnIndex) RunWork().run(&shmem.work); else NCCL_CALL_FUNCTIONS(&elems[0]); if (shmem.work.elems[0].active == 2) { - if (COLLTRACE && tid == 0) traceCollEnd(0xffff) + if (COLLTRACE && tid == 0) traceKernelEnd(elems->funcIndex) break; } __syncthreads(); @@ -415,43 +515,51 @@ __device__ void ncclKernel(ncclWorkElem first) { } } -#define IMPL_COLL_KERN(func, algo, proto, redop, type, fIndex) \ +#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first) { \ +__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(ncclWorkElem first) { \ if (first.comm->collTraceThread) \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(first); \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(first); \ else \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(first); \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(first); \ } // Examples : AllReduce, RING, LL, Sum, uint8 /* Functions for aggregation case */ -#define IMPL_COLL_FUNC(func, algo, proto, redop, type) \ -__device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args) { \ - RunWorkElement, NCCL_ALGO_##algo, 
NCCL_PROTO_##proto>().run(args); \ +#define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \ +__device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(struct ncclWorkElem* args) { \ + RunWork, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem->work); \ } // Only generate inline kernels for LL -#define IMPL_COLL4(func, algo, redop, type, ncclType) \ - IMPL_COLL_FUNC(func, algo, LL, redop, type) \ - IMPL_COLL_FUNC(func, algo, SIMPLE, redop, type) +#define IMPL_COLL4(func, algo, devredop, type, ncclType) \ + IMPL_COLL_FUNC(func, algo, LL, devredop, type) \ + IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) \ -#define IMPL_COLL3(func, redop, type, ncclType) \ - IMPL_COLL4(func, TREE, redop, type, ncclType) \ - IMPL_COLL4(func, RING, redop, type, ncclType) \ - IMPL_COLL4(func, COLLNET, redop, type, ncclType) +#define IMPL_COLL3(func, devredop, type, ncclType) \ + IMPL_COLL4(func, TREE, devredop, type, ncclType) \ + IMPL_COLL4(func, RING, devredop, type, ncclType) \ + IMPL_COLL4(func, COLLNET, devredop, type, ncclType) -#define IMPL_COLL2(func, redop) \ - IMPL_COLL3(func, redop, int8_t, ncclInt8) \ - IMPL_COLL3(func, redop, uint8_t, ncclUint8) \ - IMPL_COLL3(func, redop, int32_t, ncclInt32) \ - IMPL_COLL3(func, redop, uint32_t, ncclUint32) \ - IMPL_COLL3(func, redop, int64_t, ncclInt64) \ - IMPL_COLL3(func, redop, uint64_t, ncclUint64) \ - IMPL_COLL3(func, redop, half, ncclFloat16) \ - IMPL_COLL3(func, redop, float, ncclFloat32) \ - IMPL_COLL3(func, redop, double, ncclFloat64) \ - IMPL_COLL3(func, redop, rccl_bfloat16, ncclBfloat16) +#define IMPL_COLL2(func, devredop) \ + IMPL_COLL3(func, devredop, int8_t, ncclInt8) \ + IMPL_COLL3(func, devredop, uint8_t, ncclUint8) \ + IMPL_COLL3(func, devredop, int32_t, ncclInt32) \ + IMPL_COLL3(func, devredop, uint32_t, ncclUint32) \ + IMPL_COLL3(func, devredop, int64_t, ncclInt64) \ + IMPL_COLL3(func, devredop, uint64_t, ncclUint64) \ + IMPL_COLL3(func, devredop, half, ncclFloat16) \ + 
IMPL_COLL3(func, devredop, float, ncclFloat32) \ + IMPL_COLL3(func, devredop, double, ncclFloat64) \ + IMPL_COLL3(func, devredop, rccl_bfloat16, ncclBfloat16) + +#define IMPL_COLL2A(func, devredop) \ + IMPL_COLL3(func, devredop, int8_t, ncclInt8) \ + IMPL_COLL3(func, devredop, uint8_t, ncclUint8) \ + IMPL_COLL3(func, devredop, int32_t, ncclInt32) \ + IMPL_COLL3(func, devredop, uint32_t, ncclUint32) \ + IMPL_COLL3(func, devredop, int64_t, ncclInt64) \ + IMPL_COLL3(func, devredop, uint64_t, ncclUint64) // Reduction define all functions #define IMPL_COLL_R(func) \ @@ -459,40 +567,49 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, red IMPL_COLL2(func, Prod) \ IMPL_COLL2(func, Min) \ IMPL_COLL2(func, Max) \ - IMPL_COLL2(func, Avg) + IMPL_COLL2(func, PreMulSum) \ + IMPL_COLL2A(func, SumPostDiv) // [RCCL] Define clique-based implementations (repurposed LL128) -#define IMPL_COLL4_CLIQUE(func, algo, redop, type, ncclType) \ - IMPL_COLL_FUNC(func, algo, LL, redop, type) \ - IMPL_COLL_FUNC(func, algo, LL128, redop, type) \ - IMPL_COLL_FUNC(func, algo, SIMPLE, redop, type) +#define IMPL_COLL4_CLIQUE(func, algo, devredop, type, ncclType) \ + IMPL_COLL_FUNC(func, algo, LL, devredop, type) \ + IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \ + IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) \ -#define IMPL_COLL3_CLIQUE(func, redop, type, ncclType) \ - IMPL_COLL4(func, TREE, redop, type, ncclType) \ - IMPL_COLL4_CLIQUE(func, RING, redop, type, ncclType) \ - IMPL_COLL4(func, COLLNET, redop, type, ncclType) +#define IMPL_COLL3_CLIQUE(func, devredop, type, ncclType) \ + IMPL_COLL4_CLIQUE(func, TREE, devredop, type, ncclType) \ + IMPL_COLL4_CLIQUE(func, RING, devredop, type, ncclType) \ + IMPL_COLL4_CLIQUE(func, COLLNET, devredop, type, ncclType) -#define IMPL_COLL2_CLIQUE(func, redop) \ - IMPL_COLL3_CLIQUE(func, redop, int8_t, ncclInt8) \ - IMPL_COLL3_CLIQUE(func, redop, uint8_t, ncclUint8) \ - IMPL_COLL3_CLIQUE(func, redop, int32_t, ncclInt32) \ 
- IMPL_COLL3_CLIQUE(func, redop, uint32_t, ncclUint32) \ - IMPL_COLL3_CLIQUE(func, redop, int64_t, ncclInt64) \ - IMPL_COLL3_CLIQUE(func, redop, uint64_t, ncclUint64) \ - IMPL_COLL3_CLIQUE(func, redop, half, ncclFloat16) \ - IMPL_COLL3_CLIQUE(func, redop, float, ncclFloat32) \ - IMPL_COLL3_CLIQUE(func, redop, double, ncclFloat64) \ - IMPL_COLL3_CLIQUE(func, redop, rccl_bfloat16, ncclBfloat16) +#define IMPL_COLL2_CLIQUE(func, devredop) \ + IMPL_COLL3_CLIQUE(func, devredop, int8_t, ncclInt8) \ + IMPL_COLL3_CLIQUE(func, devredop, uint8_t, ncclUint8) \ + IMPL_COLL3_CLIQUE(func, devredop, int32_t, ncclInt32) \ + IMPL_COLL3_CLIQUE(func, devredop, uint32_t, ncclUint32) \ + IMPL_COLL3_CLIQUE(func, devredop, int64_t, ncclInt64) \ + IMPL_COLL3_CLIQUE(func, devredop, uint64_t, ncclUint64) \ + IMPL_COLL3_CLIQUE(func, devredop, half, ncclFloat16) \ + IMPL_COLL3_CLIQUE(func, devredop, float, ncclFloat32) \ + IMPL_COLL3_CLIQUE(func, devredop, double, ncclFloat64) \ + IMPL_COLL3_CLIQUE(func, devredop, rccl_bfloat16, ncclBfloat16) + +#define IMPL_COLL2A_CLIQUE(func, devredop) \ + IMPL_COLL3_CLIQUE(func, devredop, int8_t, ncclInt8) \ + IMPL_COLL3_CLIQUE(func, devredop, uint8_t, ncclUint8) \ + IMPL_COLL3_CLIQUE(func, devredop, int32_t, ncclInt32) \ + IMPL_COLL3_CLIQUE(func, devredop, uint32_t, ncclUint32) \ + IMPL_COLL3_CLIQUE(func, devredop, int64_t, ncclInt64) \ + IMPL_COLL3_CLIQUE(func, devredop, uint64_t, ncclUint64) #define IMPL_COLL_CLIQUE(func) \ IMPL_COLL2_CLIQUE(func, Sum) \ IMPL_COLL2_CLIQUE(func, Prod) \ IMPL_COLL2_CLIQUE(func, Min) \ IMPL_COLL2_CLIQUE(func, Max) \ - IMPL_COLL2_CLIQUE(func, Avg) + IMPL_COLL2_CLIQUE(func, PreMulSum) \ + IMPL_COLL2A_CLIQUE(func, SumPostDiv) // [/RCCL] - // Copy primitives only define one function for copy #define IMPL_COLL_C(func) IMPL_COLL3(func, Sum, int8_t, ncclInt8); diff --git a/projects/rccl/src/collectives/device/common_kernel.h b/projects/rccl/src/collectives/device/common_kernel.h index 525e3a4cda..0349249e1a 100644 --- 
a/projects/rccl/src/collectives/device/common_kernel.h +++ b/projects/rccl/src/collectives/device/common_kernel.h @@ -26,7 +26,6 @@ typedef uint64_t PackType; template struct FuncTraits /*{ - __device__ static Fn make(); __device__ static T preOp(Fn, T); __device__ static T postOp(Fn, T); }*/; @@ -501,12 +500,12 @@ inline __device__ void Store128(Pack128* p, Pack128& v) { #endif } -template +template __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t, - FUNC fn, int const numPreOpSrcs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem + uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Nelem ) { - const int inc = nw * UNROLL * WARP_SIZE; - int offset = w * UNROLL * WARP_SIZE + t; + const Int inc = nw * UNROLL * WARP_SIZE; + Int offset = w * UNROLL * WARP_SIZE + t; const T* srcs[MAXSRCS]; for (int i=0; i().preOp(fn, vals[u]); } #pragma unroll for (int i=1; i().preOp(fn, vals2[u]); } for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]); @@ -534,12 +535,17 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const for (int i=MINSRCS; i().preOp(fn, vals2[u]); + } for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]); } } if (postOp) { + FUNC fn(redOpArgs[0]); #pragma unroll for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits().postOp(fn, vals[u]); } @@ -561,12 +567,12 @@ __device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const } } -template +template __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t, - FUNC fn, int numPreOpSrcs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack + uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Npack ) { - const int inc = nw * UNROLL * WARP_SIZE; - int offset = w * UNROLL * WARP_SIZE + t; 
+ const Int inc = nw * UNROLL * WARP_SIZE; + Int offset = w * UNROLL * WARP_SIZE + t; const Pack128* srcs[MAXSRCS]; for (int i=0; i().preOp(fn, vals[u]); } #pragma unroll for (int i=1; i().preOp(fn, vals2[u]); } for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); @@ -594,12 +602,17 @@ __device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, c for (int i=MINSRCS; i().preOp(fn, vals2[u]); + } for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); } } if (postOp) { + FUNC fn(redOpArgs[0]); #pragma unroll for (int u = 0; u < UNROLL; ++u) MULTI128().postOp(fn, vals[u]); } @@ -627,11 +640,11 @@ __device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(uint32_t); } #define PACKELEMS (sizeof(Pack128) / sizeof(T)) #define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1)) -template +template __device__ __forceinline__ void ReduceOrCopyMulti( - const int tid, const int nthreads, FUNC fn, int numPreOpSrcs, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, int N + const int tid, const int nthreads, uint64_t* redOpArgs, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, Int N ) { - int Nrem = N; + Int Nrem = N; if (Nrem <= 0) return; int w = tid / WARP_SIZE; // Warp number @@ -647,17 +660,17 @@ __device__ __forceinline__ void ReduceOrCopyMulti( for (int i=0; i - (w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); + ReduceCopy128bMulti + (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); Nrem -= Nelem; if (Nrem == 0) return; @@ -667,8 +680,8 @@ __device__ __forceinline__ void ReduceOrCopyMulti( Npack = Nrem / PACKELEMS; Nelem = Npack * PACKELEMS; - ReduceCopy128bMulti - (w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); + ReduceCopy128bMulti + (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); Nrem -= Nelem; if (Nrem == 0) return; @@ -676,18 +689,18 @@ __device__ __forceinline__ void 
ReduceOrCopyMulti( } // unrolled, by-type (mostly for unaligned buffers) - int Nelem = (Nrem / (AUTOUNROLL*PACKELEMS/2*WARP_SIZE)) * (AUTOUNROLL*PACKELEMS/2*WARP_SIZE); // round down + Int Nelem = (Nrem / (AUTOUNROLL*PACKELEMS/2*WARP_SIZE)) * (AUTOUNROLL*PACKELEMS/2*WARP_SIZE); // round down - ReduceCopyMulti - (w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem); + ReduceCopyMulti + (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem); Nrem -= Nelem; if (Nrem == 0) return; offset += Nelem; // no unroll, by type. Should finish what's remaining. - ReduceCopyMulti - (w, nw, t, fn, numPreOpSrcs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem); + ReduceCopyMulti + (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem); } #endif // COMMON_KERNEL_H_ diff --git a/projects/rccl/src/collectives/device/functions.cu b/projects/rccl/src/collectives/device/functions.cu index 1689b3df09..7acb80be60 100644 --- a/projects/rccl/src/collectives/device/functions.cu +++ b/projects/rccl/src/collectives/device/functions.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
* * See LICENSE.txt for license information @@ -13,66 +13,100 @@ __device__ struct ncclShmemData* ncclShmem; #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) #else -#define NCCL_FUNC5(func, algo, redop, type) \ - NCCL_FUNC_NAME(func, algo, LL, redop, type), \ - NCCL_FUNC_NAME(func, algo, LL128, redop, type), \ - NCCL_FUNC_NAME(func, algo, SIMPLE, redop, type) +#define NCCL_FUNC5(func, algo, devredop, type, nullify) \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) -#define NCCL_FUNC4(func, redop, type) \ - NCCL_FUNC5(func, TREE, redop, type), \ - NCCL_FUNC5(func, RING, redop, type), \ - NCCL_FUNC5(func, COLLNET, redop, type) +#define NCCL_FUNC4(func, devredop, type, nullify) \ + NCCL_FUNC5(func, TREE, devredop, type, nullify), \ + NCCL_FUNC5(func, RING, devredop, type, nullify), \ + NCCL_FUNC5(func, COLLNET, devredop, type, nullify) +#if defined(RCCL_BFLOAT16) // Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, redop) \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, uint8_t), \ - NCCL_FUNC4(func, redop, int32_t), \ - NCCL_FUNC4(func, redop, uint32_t), \ - NCCL_FUNC4(func, redop, int64_t), \ - NCCL_FUNC4(func, redop, uint64_t), \ - NCCL_FUNC4(func, redop, half), \ - NCCL_FUNC4(func, redop, float), \ - NCCL_FUNC4(func, redop, double) -#define NCCL_FUNCS3B(func, redop) \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t) +#define NCCL_FUNCS3A(func, devredop, nullForFloat) \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, 
devredop, uint8_t, 0), \ + NCCL_FUNC4(func, devredop, int32_t, 0), \ + NCCL_FUNC4(func, devredop, uint32_t, 0), \ + NCCL_FUNC4(func, devredop, int64_t, 0), \ + NCCL_FUNC4(func, devredop, uint64_t, 0), \ + NCCL_FUNC4(func, devredop, half, nullForFloat), \ + NCCL_FUNC4(func, devredop, float, nullForFloat), \ + NCCL_FUNC4(func, devredop, double, nullForFloat), \ + NCCL_FUNC4(func, devredop, rccl_bfloat16, nullForFloat) +#define NCCL_FUNCS3B(func, devredop) \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0) +#else +// Must be consistent with ncclDataType_t +#define NCCL_FUNCS3A(func, devredop, nullForFloat) \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, uint8_t, 0), \ + NCCL_FUNC4(func, devredop, int32_t, 0), \ + NCCL_FUNC4(func, devredop, uint32_t, 0), \ + NCCL_FUNC4(func, devredop, int64_t, 0), \ + NCCL_FUNC4(func, devredop, uint64_t, 0), \ + NCCL_FUNC4(func, devredop, half, nullForFloat), \ + NCCL_FUNC4(func, devredop, float, nullForFloat), \ + NCCL_FUNC4(func, devredop, double, nullForFloat) +#define NCCL_FUNCS3B(func, devredop) \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0), \ + NCCL_FUNC4(func, devredop, int8_t, 0) +#endif // Must be consistent with ncclRedOp_t #define NCCL_FUNCS2A(func) \ - NCCL_FUNCS3A(func, Sum ), \ - NCCL_FUNCS3A(func, Prod), \ - NCCL_FUNCS3A(func, Max ), 
\ - NCCL_FUNCS3A(func, Min ) + NCCL_FUNCS3A(func, Sum, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, Prod, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, Max, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, Min, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, PreMulSum, /*nullForFloat=*/0), \ + NCCL_FUNCS3A(func, SumPostDiv, /*nullForFloat=*/1) + #define NCCL_FUNCS2B(func) \ + NCCL_FUNCS3B(func, Sum), \ + NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum), \ NCCL_FUNCS3B(func, Sum) -// Must be consistent with ncclFunc_t -#define NCCL_FUNCS() { \ - NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),\ - NCCL_FUNCS2B(Broadcast), \ - NCCL_FUNCS2A(Reduce), \ - NCCL_FUNCS2B(AllGather), \ - NCCL_FUNCS2A(ReduceScatter), \ - NCCL_FUNCS2A(AllReduce) } - // Must be consistent with the ncclFuncSet enum -__device__ ncclKern_t ncclFuncs[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { +__device__ ncclKern_t ncclFuncs[1+ncclNumTypes+NCCL_NUM_FUNCTIONS*ncclNumDevRedOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = { // Don't try to initialize the host shadow copy of this device-side global // variable. There is no host pointer to a device-side function, which // confuses clang. This will be fixed in the next clang release. 
#if __CUDA_ARCH__ NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), + #if defined(RCCL_BFLOAT16) + NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16), + #endif NCCL_FUNCS2B(Broadcast), NCCL_FUNCS2A(Reduce), NCCL_FUNCS2B(AllGather), diff --git a/projects/rccl/src/collectives/device/gen_rules.sh b/projects/rccl/src/collectives/device/gen_rules.sh index e99dc61465..aaf368523d 100755 --- a/projects/rccl/src/collectives/device/gen_rules.sh +++ b/projects/rccl/src/collectives/device/gen_rules.sh @@ -17,7 +17,7 @@ targets="GENOBJS := \\\\\n" for base in sendrecv all_reduce all_gather broadcast reduce reduce_scatter; do opn=0 - for op in sum prod min max avg; do + for op in sum prod min max premulsum sumpostdiv; do dtn=0 # Order must match that of the ncclDataType_t enum for dt in ${datatypes}; do diff --git a/projects/rccl/src/collectives/device/onerank_reduce.cu b/projects/rccl/src/collectives/device/onerank_reduce.cu new file mode 100644 index 0000000000..89ef528484 --- /dev/null +++ b/projects/rccl/src/collectives/device/onerank_reduce.cu @@ -0,0 +1,61 @@ +/************************************************************************* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#include "devcomm.h" +#include "collectives.h" +#include "reduce_kernel.h" +#include "common.h" + +namespace { + template + __device__ __forceinline__ void oneRankReduce() { + ncclWork *w = &ncclShmem->work; + int tid = threadIdx.x; + int tn = blockDim.x; + #pragma unroll 1 + for(int e=0; e < NCCL_MAX_WORK_ELEMENTS && w->elems[e].active != 0; e++) { + ncclWorkElem *we = &w->elems[e]; + intptr_t eltN = we->coll.count; + int bid = we->coll.bid; + int bn = we->coll.nChannels; + T const *src = (T const*)we->sendbuff; + T *dst = (T*)we->recvbuff; + + // each block/channel gets a roughly equal segment of 16 byte packs + constexpr int EltPerPack = 16/sizeof(T); + intptr_t packN = (eltN + EltPerPack-1) - (eltN + EltPerPack-1)%EltPerPack; + intptr_t i0 = (bid+0)*(packN/bn) + (bid+0 < packN%bn ? bid+0 : packN%bn); + intptr_t i1 = (bid+1)*(packN/bn) + (bid+1 < packN%bn ? bid+1 : packN%bn); + i0 *= EltPerPack; + i0 = i0 < eltN ? i0 : eltN; + i1 *= EltPerPack; + i1 = i1 < eltN ? 
i1 : eltN; + src += i0; + dst += i0; + ReduceOrCopyMulti + (tid, tn, &(we->coll.redOpArg), true, 1, &src, 1, &dst, i1-i0); + } + } +} + +#define INSTANTIATE(devredop, type) \ + __device__ void NCCL_ONERANK_REDUCE_NAME(devredop, type)(struct ncclWorkElem* args) { \ + oneRankReduce>(); \ + } + +INSTANTIATE(PreMulSum, int8_t) +INSTANTIATE(PreMulSum, uint8_t) +INSTANTIATE(PreMulSum, int32_t) +INSTANTIATE(PreMulSum, uint32_t) +INSTANTIATE(PreMulSum, int64_t) +INSTANTIATE(PreMulSum, uint64_t) +INSTANTIATE(PreMulSum, half) +#if defined(RCCL_BFLOAT16) +INSTANTIATE(PreMulSum, rccl_bfloat16) +#endif +INSTANTIATE(PreMulSum, float) +INSTANTIATE(PreMulSum, double) diff --git a/projects/rccl/src/collectives/device/primitives.h b/projects/rccl/src/collectives/device/primitives.h index 9fa1a427c0..627407ddd0 100644 --- a/projects/rccl/src/collectives/device/primitives.h +++ b/projects/rccl/src/collectives/device/primitives.h @@ -161,13 +161,13 @@ struct PrimitivesWithoutDirect { #define INIT_COUNTER \ if (tid == 0) { t0 = __builtin_amdgcn_s_memrealtime(); } #define ACCUMULATE_COUNTER(prim) \ - if (tid == 0 && args->op.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0); \ + if (tid == 0 && args->coll.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0); \ devProf->elems[blockIdx.x].prim##_byte += nelem * sizeof(T); } #else #define INIT_COUNTER \ if (tid == 0) { t0 = __builtin_amdgcn_s_memrealtime(); ws = devProf->elems[blockIdx.x].wait_cycle; } #define ACCUMULATE_COUNTER(prim) \ - if (tid == 0 && args->op.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0 \ + if (tid == 0 && args->coll.opCount) { devProf->elems[blockIdx.x].prim##_cycle += (__builtin_amdgcn_s_memrealtime() - t0 \ + ws - devProf->elems[blockIdx.x].wait_cycle); \ devProf->elems[blockIdx.x].prim##_byte += nelem * sizeof(T); } #endif diff --git a/projects/rccl/src/collectives/device/prims_ll.h 
b/projects/rccl/src/collectives/device/prims_ll.h index d58b044725..8b8c5b86aa 100644 --- a/projects/rccl/src/collectives/device/prims_ll.h +++ b/projects/rccl/src/collectives/device/prims_ll.h @@ -284,7 +284,7 @@ class Primitives: } template - __device__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { + __device__ __forceinline__ void LLGenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { constexpr int SRC = SrcBuf != -1 ? 1 : 0; constexpr int DST = DstBuf != -1 ? 1 : 0; T *srcElts = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; @@ -389,9 +389,9 @@ class Primitives: public: __device__ Primitives( const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, int group=0, int connIndex=0 + void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0, int connIndex=0 ): - redOp(FuncTraits().make(ncclShmem->comm.nRanks)), + redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group), stepLines(ncclShmem->comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)), barriers(&ncclShmem->groups[group].barrier), barrier_next(ncclShmem->groups[group].barrier_next) { diff --git a/projects/rccl/src/collectives/device/prims_ll128.h b/projects/rccl/src/collectives/device/prims_ll128.h index 81753a09f5..e1247022f4 100644 --- a/projects/rccl/src/collectives/device/prims_ll128.h +++ b/projects/rccl/src/collectives/device/prims_ll128.h @@ -280,7 +280,7 @@ class Primitives: static constexpr int DataEltPerSlice = (WireWordPerSlice - WireWordPerSlice/NCCL_LL128_LINEELEMS)*(sizeof(uint64_t)/sizeof(T)); template - __device__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { + __device__ __forceinline__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { constexpr int SRC = SrcBuf != -1 ? 1 : 0; constexpr int DST = DstBuf != -1 ? 
1 : 0; static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh"); @@ -357,9 +357,9 @@ class Primitives: public: __device__ Primitives( const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, int group=0, int connIndex=0 + void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0, int connIndex=0 ): - redOp(FuncTraits().make(ncclShmem->comm.nRanks)), + redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), flagThread((tid%8)==7), group(group), stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) { diff --git a/projects/rccl/src/collectives/device/prims_simple.h b/projects/rccl/src/collectives/device/prims_simple.h index 92f2d29c93..4b4fd227dc 100644 --- a/projects/rccl/src/collectives/device/prims_simple.h +++ b/projects/rccl/src/collectives/device/prims_simple.h @@ -1,5 +1,6 @@ /************************************************************************* * Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
* * See LICENSE.txt for license information ************************************************************************/ @@ -20,14 +21,14 @@ class Primitives< Aborted = 0x40, PtrsFifoEnabled = 0x80, SizesFifoEnabled = 0x100, - DirectEnabled = 0x200, - ThreadsSynced = 0x400; + DirectWrite = 0x200, + DirectRead = 0x400, + ThreadsSynced = 0x800; const int tid; int nthreads; int nworkers; const int stepSize; Fan fan; - RedOp const redOp; int index; // Peer index I'm responsible for int flags; int group; @@ -45,7 +46,6 @@ class Primitives< uint64_t connStepCache; // Cache last seen value of (*connStepPtr) uint64_t* barriers; uint64_t* barrier_next; - const int connIndex; const uint64_t opCount; // Don't use barrier 0 as it's used by the final sync @@ -84,12 +84,15 @@ class Primitives< } template - inline __device__ void waitPeer(intptr_t dstIx, intptr_t remoteOutIx, int offset, int nelts) { - if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) { - bool const isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; + __device__ __forceinline__ void waitPeer(intptr_t dstIx, intptr_t remoteIx, int offset, int nelts) { + const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; + const bool noRecvWait = DirectRecv && Src && (flags & DirectRead); // no wait when directly reading from remote input + const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write #if defined(ENABLE_PROFILING) && !defined(ENABLE_TIMING_PROFILE) - uint64_t t0 = __builtin_amdgcn_s_memrealtime(); + uint64_t t0 = __builtin_amdgcn_s_memrealtime(); #endif + if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) || + ((flags & (Send*RoleWaitSend)) && !noSendWait)) { int spins = 0; while (connStepCache + (isSendNotRecv ? 
NCCL_STEPS : 0) < step + StepPerSlice) { __builtin_amdgcn_s_sleep(8); @@ -98,7 +101,9 @@ class Primitives< //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem->comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); } __asm__ __volatile__("s_wakeup"); + } + if (flags & (Recv*RoleWaitRecv | Send*RoleWaitSend)) { if (isSendNotRecv && (flags & SizesFifoEnabled)) STORE(connSizesFifoPtr+step%NCCL_STEPS, nelts*sizeof(T)); @@ -106,10 +111,26 @@ class Primitives< : (ncclShmem->groups[group].srcs + Src); if (flags & PtrsFifoEnabled) loadPtr(connPtrsFifoPtr + step%NCCL_STEPS, ptrs[index]); - else if ((isSendNotRecv ? DirectSend : DirectRecv) && (flags & DirectEnabled)) - ptrs[index] = directBuff + (isSendNotRecv ? remoteOutIx : dstIx) + offset; - else + else if (isSendNotRecv && DirectSend) { + if (flags & DirectWrite) { + ptrs[index] = directBuff + remoteIx + offset; + } else if (flags & DirectRead) { // empty send + ptrs[index] = nullptr; + } else { + ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; + } + } else if (!isSendNotRecv && DirectRecv) { + if (flags & DirectRead) { + ptrs[index] = directBuff + remoteIx + offset; + } else if (flags & DirectWrite) { + ptrs[index] = directBuff + dstIx + offset; // send to next from my output buffer + } else { + ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; + } + } + else { ptrs[index] = connEltsFifo + (step%NCCL_STEPS)*stepSize; + } step += StepPerSlice; #if defined(ENABLE_PROFILING) && !defined(ENABLE_TIMING_PROFILE) if (opCount) { @@ -131,8 +152,8 @@ class Primitives< } template - inline __device__ void genericOp( - intptr_t srcIx, intptr_t dstIx, intptr_t remoteOutIx, int nelem, bool postOp + __device__ __forceinline__ void genericOp( + intptr_t srcIx, intptr_t dstIx, intptr_t remoteIx, int nelem, bool postOp ) { constexpr int DirectRecv = 1 && Direct && DirectRecv1; constexpr int DirectSend = 1 && Direct && DirectSend1; @@ 
-180,7 +201,7 @@ class Primitives< uint64_t t0; if (tid == 0) t0 = __builtin_amdgcn_s_memrealtime(); #endif - waitPeer(dstIx, remoteOutIx, offset, sliceSize); + waitPeer(dstIx, remoteIx, offset, sliceSize); subBarrier(); #ifdef ENABLE_PROFILING if (tid == 0 && opCount) ncclShmem->comm.devProf->elems[blockIdx.x].wait_cycle += (__builtin_amdgcn_s_memrealtime() - t0); @@ -189,15 +210,24 @@ class Primitives< // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (Send) { // (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0). - ReduceOrCopyMulti - (tid, nworkers, redOp, 0, false, + ReduceOrCopyMulti + (tid, nworkers, nullptr, false, 1, (T const**)ncclShmem->groups[group].srcs, fan.nsend(), (T**)ncclShmem->groups[group].dsts+1, sliceSize); } + } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem->groups[group].dsts[Dst] == nullptr) { + // For broadcast in CollNet to do empty send + ReduceOrCopyMulti + (tid, nworkers, ncclShmem->redOpArgs, postOp, + Recv, (T const**)ncclShmem->groups[group].srcs, + Dst, (T**)ncclShmem->groups[group].dsts, + sliceSize); } else { - ReduceOrCopyMulti - (tid, nworkers, redOp, SrcBuf==Input ? 1 : 0, postOp, + constexpr int PreOpN = SrcBuf != Input ? 0 : + DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? 
(1+NCCL_MAX_DIRECT_ARITY) : 1; + ReduceOrCopyMulti + (tid, nworkers, ncclShmem->redOpArgs, postOp, Recv*fan.nrecv()+Src, (T const**)ncclShmem->groups[group].srcs, Send*fan.nsend()+Dst, (T**)ncclShmem->groups[group].dsts, sliceSize); @@ -231,10 +261,12 @@ class Primitives< } } - // Scatter and gather do not support Direct - template - inline __device__ void + // Scatter/Gather generic op + template + __device__ __forceinline__ void ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) { + constexpr int DirectRecv = 1 && Direct && DirectRecv1; + constexpr int DirectSend = 1 && Direct && DirectSend1; int offset = 0; // slice offset int sliceSize = stepSize*StepPerSlice; int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size @@ -243,12 +275,14 @@ class Primitives< for (int slice=0; slicegroups[group].srcs[0] = userBuff + inpIx + offset; - if (Recv && (flags & RoleOutput)) ncclShmem->groups[group].dsts[0] = userBuff + outIx + offset; - // realSize is not accurate here; but intra-node does not rely on sizes FIFO - waitPeer<0, 0, Recv, Send, 0, 0>(0, 0, 0, realSize); - subBarrier(); if (Send) { + // Scatter pre-scales data of input buffer only in non-Direct case + constexpr int PreOpN = DirectSend ? 
0 : 1; + if (flags & RoleInput) ncclShmem->groups[group].srcs[0] = userBuff + inpIx + offset; + if (tid == 0) ncclShmem->groups[group].totalSendSize[slice] = 0; // Skip the threadfence + // realSize is not accurate here; but intra-node does not rely on sizes FIFO + waitPeer<0, DirectSend, 0, 1, 1, 0>(0, inpIx, offset, realSize); + subBarrier(); #pragma unroll 1 for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; const T* src0 = (T*)ncclShmem->groups[group].srcs[0] + peerOffset; int realPeerSize = min(realSize, totalElem-peerOffset); - if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, redOp, 1, false, 1, &src0, 1, (T**)ncclShmem->groups[group].dsts+i, realPeerSize); + if (realPeerSize > 0 && ncclShmem->groups[group].dsts[i] != nullptr) { + ReduceOrCopyMulti(tid, nworkers, ncclShmem->redOpArgs, false, 1, &src0, 1, (T**)ncclShmem->groups[group].dsts+i, realPeerSize); + if (tid == 0) ncclShmem->groups[group].totalSendSize[slice] += realPeerSize; + } } } else if (Recv) { - #pragma unroll 1 - for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; - T* dst0 = (T*)ncclShmem->groups[group].dsts[0] + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); - if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, redOp, 0, postOp, 1, (T const**)ncclShmem->groups[group].srcs+i, 1, &dst0, realPeerSize); + if (flags & RoleOutput) ncclShmem->groups[group].dsts[0] = userBuff + outIx + offset; + int peerOffset = index*peerElem; + if (skip >= 0 && index >= skip) peerOffset += peerElem; + // Adjust remote index with peer offset in case we are directly pulling from peer's output buffer + waitPeer(outIx, outIx+peerOffset, offset, realSize); + subBarrier(); + if (DirectRecv && ncclShmem->groups[group].srcs[0] == ncclShmem->groups[group].dsts[0]) { + // Since waitPeer sets srcs[0] to output buffer + offset, we are doing a direct-write based recv + // Do nothing + } else { + #pragma unroll 1 + for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; + T* dst0 = 
(T*)ncclShmem->groups[group].dsts[0] + peerOffset; + int realPeerSize = min(realSize, totalElem-peerOffset); + if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, ncclShmem->redOpArgs, postOp, 1, (const T**)ncclShmem->groups[group].srcs+i, 1, &dst0, realPeerSize); + } } } } barrier(); - //if (Send && (flags & RolePostSend) && realSize > 0 && index == 0) __threadfence_system(); + if (Send && (flags & RolePostSend) && ncclShmem->groups[group].totalSendSize[slice] > 0 && index == 0) + __threadfence_system(); __syncwarp(); postPeer(); offset += realSize; } } - __device__ __forceinline__ void loadRecvConn(ncclPeer *peer) { + __device__ __forceinline__ void loadRecvConn(ncclPeer *peer, int connIndex, struct ncclWorkElem* e) { if (flags & (RoleWaitRecv|RolePostRecv)) { - // For other colls: group <= 2, hence always use conn 0 - // For P2P: Direct is set to 1, hence always use conn 0 - // Ideally we should be accepting connIndex from the constructor! auto *conn = &peer->recv[connIndex].conn; step = conn->step; step = roundUp(step, SlicePerChunk*StepPerSlice); @@ -295,7 +341,25 @@ class Primitives< connStepPtr = conn->tail; connStepCache = LOAD(connStepPtr); flags |= (conn->ptrsFifo != nullptr) ? PtrsFifoEnabled : 0; - flags |= (Direct && (conn->direct & NCCL_DIRECT_GPU)) ? DirectEnabled : 0; + if (Direct) { + // User buffers have been registered + if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { + if (connIndex == 1) { + flags |= DirectRead; // scatter-reduce use direct pull + } else { + flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite : + (e->direct & NCCL_DIRECT_READ) ? DirectRead : 0; + } + } else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { + if (connIndex == 1) { + flags |= DirectRead; // scatter-reduce use direct pull + } else { + // direct read not allowed in non-register case + // otherwise, in one-to-multi send, we could mix empty send and intermediate send + flags |= (conn->direct & NCCL_DIRECT_WRITE) ? 
DirectWrite : 0; + } + } + } if (flags & PtrsFifoEnabled) connPtrsFifoPtr = conn->ptrsFifo; else @@ -304,11 +368,8 @@ class Primitives< } } - __device__ __forceinline__ void loadSendConn(ncclPeer *peer) { + __device__ __forceinline__ void loadSendConn(ncclPeer *peer, int connIndex, struct ncclWorkElem* e) { if (flags & (RoleWaitSend|RolePostSend)) { - // For other colls: group <= 2, hence always use conn 0 - // For P2P: Direct is set to 1, hence always use conn 0 - // Ideally we should be accepting connIndex from the constructor! auto *conn = &peer->send[connIndex].conn; step = conn->step; step = roundUp(step, SlicePerChunk*StepPerSlice); @@ -328,9 +389,25 @@ class Primitives< if (conn->sizesFifo != nullptr) { flags |= SizesFifoEnabled; connSizesFifoPtr = conn->sizesFifo; + } else if (Direct) { + // User buffers have been registered + if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { + if (connIndex == 1) { + flags |= DirectRead; // scatter-reduce use direct pull + } else { + flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite : + (e->direct & NCCL_DIRECT_READ) ? DirectRead : 0; + } + } else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { + if (connIndex == 1) { + flags |= DirectRead; // scatter-reduce use direct pull + } else { + // direct read not allowed in non-register case + // otherwise, in one-to-multi send, we could mix empty send and intermediate send + flags |= (conn->direct & NCCL_DIRECT_WRITE) ? 
DirectWrite : 0; + } + } } - else if (Direct && (conn->direct & NCCL_DIRECT_GPU)) - flags |= DirectEnabled; } } } @@ -338,19 +415,19 @@ class Primitives< public: __device__ Primitives( int tid, int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, int group=0, int connIndex=0 + void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr ): tid(tid), stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)), - redOp(FuncTraits::make(ncclShmem->comm.nRanks)), - connIndex((NCCL_MAX_DIRECT_ARITY==Fan::MaxSend || NCCL_MAX_DIRECT_ARITY==Fan::MaxRecv)?(group/2):connIndex), - barriers(&ncclShmem->groups[group].barrier), barrier_next(ncclShmem->groups[group].barrier_next), - opCount(ncclShmem->work.elems[0].op.opCount) { + opCount(ncclShmem->work.elems[0].coll.opCount) { // For send operations, we need an extra warp to overlap the threadfence and the copy this->nthreads = nthreads; this->nworkers = nthreads; - this->group = group; + this->group = group & (uint16_t)0xFFFF; + int connIndex = group >> 16; + barriers = &ncclShmem->groups[this->group].barrier; + barrier_next = ncclShmem->groups[this->group].barrier_next; int nrecv=0, nsend=0; while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++; @@ -380,10 +457,10 @@ class Primitives< if (flags & (RoleWaitRecv|RolePostRecv)) peer = recvPeers[index]; if (flags & (RoleWaitSend|RolePostSend)) peer = sendPeers[index]; - loadRecvConn(&ncclShmem->channel.devPeers[peer]); - loadSendConn(&ncclShmem->channel.devPeers[peer]); + loadRecvConn(&ncclShmem->channel.devPeers[peer], connIndex, e); + loadSendConn(&ncclShmem->channel.devPeers[peer], connIndex, e); - setDataPtrs(inputBuf, outputBuf); + setDataPtrs(inputBuf, outputBuf, redOpArg, (struct ncclWorkRegElem*)e); } __device__ ~Primitives() { @@ -400,10 +477,19 @@ class Primitives< barrier(); } - __device__ void setDataPtrs(void const *inputBuf, void *outputBuf) { - if 
(flags & RoleInput) userBuff = (T*)inputBuf; + __device__ void setDataPtrs(void const *inputBuf, void *outputBuf, uint64_t redOpArg, struct ncclWorkRegElem* e) { + if (flags & RoleInput) { + userBuff = (T*)inputBuf; + ncclShmem->redOpArgs[0] = redOpArg; // scaler for local input + } if (flags & RoleOutput) userBuff = (T*)outputBuf; - if (Direct && flags == (flags|RoleWaitRecv|DirectEnabled)) { + bool recvProvider = flags == (flags|RoleWaitRecv|DirectWrite); + bool sendAcceptor = flags == (flags|RoleWaitSend|DirectWrite); + bool sendProvider = flags == (flags|RoleWaitSend|DirectRead); // sender provides direct buffer (to be fetched) + bool recvAcceptor = flags == (flags|RoleWaitRecv|DirectRead); // receiver accepts direct buffer + int regUsed = e != nullptr ? e->elem.regUsed : 0; + + if (Direct && recvProvider) { int spins = 0; void *volatile *slot = ncclShmem->groups[group].recvConns[index]->ptrExchange; // Wait for consumer to consume previous value before trampling it. @@ -412,9 +498,9 @@ class Primitives< // Encode pointer by XOR'ing against some address they definitely wouldn't send // since we want to allow them sending us nullptr while not colliding with // the empty slot value. - *slot = reinterpret_cast(reinterpret_cast(outputBuf) ^ reinterpret_cast(slot)); + *slot = reinterpret_cast(reinterpret_cast(directBuff) ^ reinterpret_cast(slot)); } - if (Direct && flags == (flags|RoleWaitSend|DirectEnabled)) { + if (Direct && sendAcceptor) { int spins = 0; void *volatile *slot = ncclShmem->groups[group].sendConns[index]->ptrExchange; void *ptr; @@ -422,7 +508,51 @@ class Primitives< ptr = LOAD(slot); if (ptr != nullptr || checkAbort(spins)) break; } - directBuff = reinterpret_cast(reinterpret_cast(ptr) ^ reinterpret_cast(slot)); + directBuff = regUsed ? 
(T*)(e->dnOutputs[index]) : + reinterpret_cast(reinterpret_cast(ptr) ^ reinterpret_cast(slot)); + *slot = nullptr; + } + if (Direct && sendProvider) { + int spins = 0; + void *volatile *slot = ncclShmem->groups[group].sendConns[index]->ptrExchange; + volatile uint64_t* argSlot0 = ncclShmem->groups[group].sendConns[index]->redOpArgExchange; + volatile uint64_t* argSlot1 = ncclShmem->groups[group].sendConns[index]->redOpArgExchange+1; + // Wait for consumer to consume previous value before trampling it. + while ((*slot != nullptr || *argSlot0 != 0 || *argSlot1 !=0) && !checkAbort(spins)); + // If there is no recv, then we are directly pulling from input buffer (e.g. directScatter) + // Otherwise, we are pulling from output buffer (e.g. recvCopyDirectSend) + directBuff = MaxRecv == 0 ? (T*)inputBuf : (T*)outputBuf; + // Exchange pre-scalers for use in direct pull + *argSlot0 = (uint64_t(1)<<32) | (uint32_t)redOpArg; + *argSlot1 = (uint64_t(1)<<32) | (uint32_t)(redOpArg>>32); + // Encode pointer by XOR'ing against some address they definitely wouldn't send + // since we want to allow them sending us nullptr while not colliding with + // the empty slot value. + *slot = reinterpret_cast(reinterpret_cast(directBuff) ^ reinterpret_cast(slot)); + } + if (Direct && recvAcceptor) { + int spins = 0; + void *volatile *slot = ncclShmem->groups[group].recvConns[index]->ptrExchange; + volatile uint64_t* argSlot0 = ncclShmem->groups[group].recvConns[index]->redOpArgExchange; + volatile uint64_t* argSlot1 = ncclShmem->groups[group].recvConns[index]->redOpArgExchange+1; + void *ptr; + while (true) { + ptr = *slot; + if (ptr != nullptr || checkAbort(spins)) break; + } + directBuff = regUsed ? (T*)(MaxSend == 0 ? 
e->upOutputs[index] : e->dnInputs[index]) : + reinterpret_cast(reinterpret_cast(ptr) ^ reinterpret_cast(slot)); + if (MaxSend != 0) { // reduce group rather than gather group + // Store scalers for remote inputs + uint64_t arg0, arg1; + while (true) { + arg0 = *argSlot0; + arg1 = *argSlot1; + if ((arg0 != 0 && arg1 != 0) || checkAbort(spins)) break; + } + ncclShmem->redOpArgs[1+index] = ((arg1 & 0xffffffff)<<32) | (arg0 & 0xffffffff); + } + *argSlot0 = 0; *argSlot1 = 0; *slot = nullptr; } } @@ -465,6 +595,9 @@ class Primitives< __device__ __forceinline__ void directRecvCopySend(intptr_t outIx, intptr_t remoteOutIx, int eltN) { genericOp<1, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, false); } + __device__ __forceinline__ void recvCopyDirectSend(intptr_t outIx, intptr_t remoteOutIx, int eltN, bool postOp=false) { + genericOp<0, 1, 1, 1, -1, Output>(-1, outIx, remoteOutIx, eltN, postOp); + } __device__ __forceinline__ void recvReduceCopy(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { genericOp<0, 0, 1, 0, Input, Output>(inpIx, outIx, -1, eltN, postOp); @@ -473,6 +606,9 @@ class Primitives< __device__ __forceinline__ void recvReduceSend(intptr_t inpIx, int eltN, bool postOp=false) { genericOp<0, 0, 1, 1, Input, -1>(inpIx, -1, -1, eltN, postOp); } + __device__ __forceinline__ void directRecvReduceSend(intptr_t inpIx, intptr_t remoteInpIx, int eltN, bool postOp=false) { + genericOp<1, 0, 1, 1, Input, -1>(inpIx, -1, remoteInpIx, eltN, postOp); + } __device__ __forceinline__ void recvReduceCopySend(intptr_t inpIx, intptr_t outIx, int eltN, bool postOp=false) { genericOp<0, 0, 1, 1, Input, Output>(inpIx, outIx, -1, eltN, postOp); @@ -484,11 +620,19 @@ class Primitives< __device__ __forceinline__ void scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); + ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, 
/*postOp=*/false); + } + __device__ __forceinline__ void + directScatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { + ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); } __device__ __forceinline__ void gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) { - ScatterGatherOp<1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp); + ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp); + } + __device__ __forceinline__ void + directGather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift) { + ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, /*postOp=*/false); } }; diff --git a/projects/rccl/src/collectives/device/reduce.h b/projects/rccl/src/collectives/device/reduce.h index 89b94f6637..db8c235f2b 100644 --- a/projects/rccl/src/collectives/device/reduce.h +++ b/projects/rccl/src/collectives/device/reduce.h @@ -11,7 +11,7 @@ namespace { template - __device__ void runRing(ncclWorkElem *args) { + __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -27,7 +27,7 @@ namespace { const int root = args->coll.root; Primitives, 0, Proto> - prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->coll.redOpArg, args->coll.connIndex << 16); auto calcChunkSize = [&]__device__(ssize_t gridOffset)->int { int realChunkSize; @@ -71,7 +71,7 @@ namespace { template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { using Proto = ProtoSimple; runRing(args); } @@ -79,14 +79,14 @@ struct RunWorkElement struct RunWorkElement { - __device__ __attribute__((noinline)) void 
run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; diff --git a/projects/rccl/src/collectives/device/reduce_kernel.h b/projects/rccl/src/collectives/device/reduce_kernel.h index 00dd5e4d18..88ba7d788f 100644 --- a/projects/rccl/src/collectives/device/reduce_kernel.h +++ b/projects/rccl/src/collectives/device/reduce_kernel.h @@ -15,6 +15,7 @@ template struct FuncNull { + __device__ FuncNull(uint64_t opArg=0) {} __device__ T operator()(const T x, const T y) const { return 0; } @@ -22,6 +23,7 @@ struct FuncNull { template struct FuncSum { + __device__ FuncSum(uint64_t opArg=0) {} __device__ T operator()(const T x, const T y) const { return x + y; } @@ -29,6 +31,7 @@ struct FuncSum { template struct FuncProd { + __device__ FuncProd(uint64_t opArg=0) {} __device__ T operator()(const T x, const T y) const { return x * y; } @@ -36,6 +39,7 @@ struct FuncProd { template struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} __device__ T operator()(const T x, const T y) const { return (x < y) ? y : x; } @@ -43,6 +47,7 @@ struct FuncMax { template struct FuncMin { + __device__ FuncMin(uint64_t opArg=0) {} __device__ T operator()(const T x, const T y) const { return (x < y) ? 
x : y; } @@ -53,7 +58,6 @@ struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max static constexpr bool IsPreOpIdentity = true; static constexpr bool IsPostOpIdentity = true; - __device__ static Fn make(int rankN) { return Fn(); } template __device__ static T preOp(Fn, T x) { return x; } template @@ -75,6 +79,7 @@ static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) { template<> struct FuncSum { + __device__ FuncSum(uint64_t opArg=0) {} __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { #if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) int32_t rv, z=0; @@ -90,6 +95,7 @@ struct FuncSum { }; template<> struct FuncSum { + __device__ FuncSum(uint64_t opArg=0) {} __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { #if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) int32_t rv, z=0; @@ -119,6 +125,7 @@ static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) { template<> struct FuncProd { + __device__ FuncProd(uint64_t opArg=0) {} __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { return mulChar4(x, y); } @@ -128,6 +135,7 @@ struct FuncProd { }; template<> struct FuncProd { + __device__ FuncProd(uint64_t opArg=0) {} __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { return mulChar4(x, y); } @@ -138,6 +146,7 @@ struct FuncProd { template<> struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} union converter { uint32_t storage; char4 a; }; __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { #if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) @@ -161,6 +170,7 @@ struct FuncMax { }; template<> struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} union converter { uint32_t storage; uchar4 a; }; __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { #if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) @@ -185,6 +195,7 @@ struct FuncMax { template<> struct FuncMin { + __device__ 
FuncMin(uint64_t opArg=0) {} union converter { uint32_t storage; char4 a; }; __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { #if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) @@ -208,6 +219,7 @@ struct FuncMin { }; template<> struct FuncMin { + __device__ FuncMin(uint64_t opArg=0) {} union converter { uint32_t storage; uchar4 a; }; __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { #if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) @@ -232,6 +244,7 @@ struct FuncMin { template<> struct FuncSum { + __device__ FuncSum(uint64_t opArg=0) {} __device__ half2 operator()(const half2 x, const half2 y) const { #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 return __hadd2(x, y); @@ -256,18 +269,16 @@ struct FuncSum { #if defined(RCCL_BFLOAT16) template<> struct FuncSum { + __device__ FuncSum(uint64_t opArg=0) {} __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hadd(x, y); -#else - return x + y; -#endif + return (rccl_bfloat16)((float)x + (float)y); } }; #endif template<> struct FuncProd { + __device__ FuncProd(uint64_t opArg=0) {} __device__ half2 operator()(const half2 x, const half2 y) const { #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 return __hmul2(x, y); @@ -292,18 +303,16 @@ struct FuncProd { #if defined(RCCL_BFLOAT16) template<> struct FuncProd { + __device__ FuncProd(uint64_t opArg=0) {} __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmul(x, y); -#else - return x * y; -#endif + return (rccl_bfloat16)((float)x * (float)y); } }; #endif template<> struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} __device__ half2 operator()(const half2 x, const half2 y) const { float2 fx, fy, fr; fx = __half22float2(x); @@ -324,18 +333,16 @@ struct FuncMax { #if defined(RCCL_BFLOAT16) template<> struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} 
__device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmax(x, y); -#else - return x < y ? y : x; -#endif + return (float)x < (float)y ? y : x; } }; #endif template<> struct FuncMin { + __device__ FuncMin(uint64_t opArg=0) {} __device__ half2 operator()(const half2 x, const half2 y) const { float2 fx, fy, fr; fx = __half22float2(x); @@ -356,24 +363,23 @@ struct FuncMin { #if defined(RCCL_BFLOAT16) template<> struct FuncMin { + __device__ FuncMin(uint64_t opArg=0) {} __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { -#if __CUDA_ARCH__ >= 800 - return __hmin(x, y); -#else - return x < y ? x : y; -#endif + return (float)x < (float)y ? x : y; } }; #endif template<> struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} __device__ float operator()(float x, float y) const { return fmaxf(x, y); } }; template<> struct FuncMin { + __device__ FuncMin(uint64_t opArg=0) {} __device__ float operator()(float x, float y) const { return fminf(x, y); } @@ -381,71 +387,98 @@ struct FuncMin { template<> struct FuncMax { + __device__ FuncMax(uint64_t opArg=0) {} __device__ double operator()(double x, double y) const { return fmax(x, y); } }; template<> struct FuncMin { + __device__ FuncMin(uint64_t opArg=0) {} __device__ double operator()(double x, double y) const { return fmin(x, y); } }; template -struct FuncAvg: FuncSum { - static_assert(!std::is_floating_point::value, "Uhoh"); +struct IsFloatingPoint: std::false_type {}; +template<> +struct IsFloatingPoint: std::true_type {}; +#if defined(RCCL_BFLOAT16) +template<> +struct IsFloatingPoint: std::true_type {}; +#endif +template<> +struct IsFloatingPoint: std::true_type {}; +template<> +struct IsFloatingPoint: std::true_type {}; + +template::value> +struct FuncSumPostDiv; + +template +struct FuncSumPostDiv: FuncSum { static constexpr bool IsPreOpIdentity = true; static constexpr bool IsPostOpIdentity = false; int n; + 
__device__ FuncSumPostDiv(uint64_t opArg): n(opArg) {} + // inherits FuncSum::operator() + __device__ T preOp(T x) const { return x; } + __device__ T postOp(T x) const { return T(x/n); } +}; - template - __device__ FuncAvg(int n): n(n) {} +template +struct FuncSumPostDiv { + static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types."); +}; - __device__ T preOp(T x) const { - return x; - } - __device__ T postOp(T x) const { - return T(x/n); - } +template +struct FuncPreMulSum: FuncSum { // integral T since all floats are specialized below + static constexpr bool IsPreOpIdentity = false; + static constexpr bool IsPostOpIdentity = true; + T scale; + __device__ FuncPreMulSum(uint64_t opArg) { scale = *(T*)&opArg; } + // inherits FuncSum::operator() + __device__ T preOp(T x) const { return x*scale; } + __device__ T postOp(T x) const { return x; } }; template<> -struct FuncAvg: FuncSum { +struct FuncPreMulSum: FuncSum { static constexpr bool IsPreOpIdentity = false; static constexpr bool IsPostOpIdentity = true; - double rcp; - __device__ FuncAvg(int n) { - rcp = __drcp_rn(double(n)); + double scale; + __device__ FuncPreMulSum(uint64_t opArg) { + scale = *(double*)&opArg; } // inherits FuncSum::operator() __device__ double preOp(double x) const { - return IsPreOpIdentity ? x : x*rcp; + return IsPreOpIdentity ? x : x*scale; } __device__ double postOp(double x) const { - return IsPostOpIdentity ? x : x*rcp; + return IsPostOpIdentity ? x : x*scale; } }; template<> -struct FuncAvg: FuncSum { +struct FuncPreMulSum: FuncSum { static constexpr bool IsPreOpIdentity = false; static constexpr bool IsPostOpIdentity = true; - float rcp; - __device__ FuncAvg(int n) { - rcp = __frcp_rn(float(n)); + float scale; + __device__ FuncPreMulSum(uint64_t opArg) { + scale = *(float*)&opArg; } // inherits FuncSum::operator() __device__ float preOp(float x) const { - return IsPreOpIdentity ? x : x*rcp; + return IsPreOpIdentity ? 
x : x*scale; } __device__ float postOp(float x) const { - return IsPostOpIdentity ? x : x*rcp; + return IsPostOpIdentity ? x : x*scale; } }; template<> -struct FuncAvg: FuncSum { +struct FuncPreMulSum: FuncSum { // Change these to switch between all prescale, all postscale, or both by sqrt(N). // Obviously, the only invalid combination is both true. An improvement would be // make this parameterized as a build time setting and passed here through @@ -455,11 +488,8 @@ struct FuncAvg: FuncSum { #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 half2 scale; - __device__ FuncAvg(int n) { - if (!IsPreOpIdentity && !IsPostOpIdentity) - scale.x = __float2half(__frsqrt_rn(float(n))); - else - scale.x = __float2half(__frcp_rn(float(n))); + __device__ FuncPreMulSum(uint64_t opArg) { + scale.x = *(half*)&opArg; scale.y = scale.x; } // inherits FuncSum::operator() @@ -477,11 +507,8 @@ struct FuncAvg: FuncSum { } #else float scale; - __device__ FuncAvg(int n) { - if (!IsPreOpIdentity && !IsPostOpIdentity) - scale = __frsqrt_rn(float(n)); - else - scale = __frcp_rn(float(n)); + __device__ FuncPreMulSum(uint64_t opArg) { + scale = __half2float(*(half*)&opArg); } // inherits FuncSum::operator() __device__ half preOp(half x) const { @@ -515,64 +542,54 @@ struct FuncAvg: FuncSum { #if defined(RCCL_BFLOAT16) template<> -struct FuncAvg: FuncSum { +struct FuncPreMulSum: FuncSum { // Change these to switch between all prescale, all postscale, or both by sqrt(N). // Obviously, the only invalid combination is both true. An improvement would be // make this parameterized as a build time setting and passed here through // preprocessor definitions. 
- static constexpr bool IsPreOpIdentity = true; - static constexpr bool IsPostOpIdentity = false; + static constexpr bool IsPreOpIdentity = false; + static constexpr bool IsPostOpIdentity = true; -#if __CUDA_ARCH__ >= 800 - __device__ FuncAvg(int n) { - if (!IsPreOpIdentity && !IsPostOpIdentity) - scale.x = __float2bfloat16(__frsqrt_rn(float(n))); - else - scale.x = __float2bfloat16(__frcp_rn(float(n))); - scale.y = scale.x; - } - // inherits FuncSum::operator() - __device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const { - return IsPreOpIdentity ? x : __hmul(x, scale.x); - } - __device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const { - return IsPostOpIdentity ? x : __hmul(x, scale.x); - } -#else float scale; - __device__ FuncAvg(int n) { - if (!IsPreOpIdentity && !IsPostOpIdentity) - scale = __frsqrt_rn(float(n)); - else - scale = __frcp_rn(float(n)); + __device__ FuncPreMulSum(uint64_t opArg) { + scale = *(rccl_bfloat16*)&opArg; } // inherits FuncSum::operator() __device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const { - return IsPreOpIdentity ? x : (rccl_bfloat16)(x*scale); + return IsPreOpIdentity ? x : (rccl_bfloat16)((float)x*scale); } __device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const { - return IsPostOpIdentity ? x : (rccl_bfloat16)(x*scale); + return IsPostOpIdentity ? 
x : (rccl_bfloat16)((float)x*scale); } -#endif }; #endif template -struct FuncTraits> { - static constexpr bool IsPreOpIdentity = FuncAvg::IsPreOpIdentity; - static constexpr bool IsPostOpIdentity = FuncAvg::IsPostOpIdentity; +struct FuncTraits> { + static constexpr bool IsPreOpIdentity = FuncPreMulSum::IsPreOpIdentity; + static constexpr bool IsPostOpIdentity = FuncPreMulSum::IsPostOpIdentity; - __device__ static FuncAvg make(int rankN) { - return FuncAvg(rankN); - } template - __device__ static U preOp(FuncAvg fn, U x) { + __device__ static U preOp(FuncPreMulSum fn, U x) { return fn.preOp(x); } template - __device__ static U postOp(FuncAvg fn, U x) { + __device__ static U postOp(FuncPreMulSum fn, U x) { return fn.postOp(x); } }; +template +struct FuncTraits> { + static constexpr bool IsPreOpIdentity = FuncSumPostDiv::IsPreOpIdentity; + static constexpr bool IsPostOpIdentity = FuncSumPostDiv::IsPostOpIdentity; + template + __device__ static U preOp(FuncSumPostDiv fn, U x) { + return fn.preOp(x); + } + template + __device__ static U postOp(FuncSumPostDiv fn, U x) { + return fn.postOp(x); + } +}; #endif // REDUCE_KERNEL_H_ diff --git a/projects/rccl/src/collectives/device/reduce_scatter.h b/projects/rccl/src/collectives/device/reduce_scatter.h index c7172f7a90..a4d773c612 100644 --- a/projects/rccl/src/collectives/device/reduce_scatter.h +++ b/projects/rccl/src/collectives/device/reduce_scatter.h @@ -11,7 +11,7 @@ namespace { template - __device__ void runRing(ncclWorkElem *args) { + __device__ __forceinline__ void runRing(ncclWorkElem *args) { const int tid = threadIdx.x; const int nthreads = args->nThreads; const int bid = args->coll.bid; @@ -26,7 +26,7 @@ namespace { const ssize_t size = args->coll.count; Primitives, 0, Proto> - prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, 0, args->coll.connIndex); + prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->coll.redOpArg, args->coll.connIndex << 16); 
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { ssize_t realChunkSize; @@ -69,7 +69,7 @@ namespace { template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { using Proto = ProtoSimple; runRing(args); } @@ -77,14 +77,14 @@ struct RunWorkElement struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; template struct RunWorkElement { - __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { + __device__ __forceinline__ void run(ncclWorkElem *args) { runRing(args); } }; diff --git a/projects/rccl/src/collectives/device/sendrecv.h b/projects/rccl/src/collectives/device/sendrecv.h index f717297df4..caad7950c1 100644 --- a/projects/rccl/src/collectives/device/sendrecv.h +++ b/projects/rccl/src/collectives/device/sendrecv.h @@ -11,7 +11,7 @@ template struct RunWork { - __device__ __attribute__((noinline)) void run(ncclWork *work) { + __device__ __forceinline__ void run(ncclWork *work) { int tid = threadIdx.x; int group = 0; const int rank = ncclShmem->comm.rank; @@ -39,16 +39,7 @@ struct RunWork { if (delta == 0) { if (sendbuff != recvbuff) { - // local copy : ReduceOrCopyMulti takes an int as number of elements, - // so we split it in blocks of 1G elements. 
- int blockSize = 1<<30; - for (size_t offset=0; offset(tid, nThreadsSegment, RedOp(), 0, false, 1, &sendbuff, 1, &recvbuff, blockSize); - sendbuff += blockSize; - recvbuff += blockSize; - } + ReduceOrCopyMulti(tid, nThreadsSegment, nullptr, false, 1, &sendbuff, 1, &recvbuff, sendCount); } } else { @@ -58,7 +49,7 @@ struct RunWork { int const nt = nThreadsSplit; int const chunkSize = args->p2p.recvChunkSize/sizeof(T); Primitives, 0, Proto> prims - (tid-t0, nt, &peer, nullptr, nullptr, recvbuff, groupRecv, args->p2p.recvIdx); + (tid-t0, nt, &peer, nullptr, nullptr, recvbuff, /*redOpArg(ignored)=*/0, groupRecv | (args->p2p.recvIdx << 16)); ssize_t offset = 0; do { int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T))); @@ -74,7 +65,7 @@ struct RunWork { int const nt = nThreadsSegment - nThreadsSplit; int const chunkSize = args->p2p.sendChunkSize/sizeof(T); Primitives, 0, Proto> prims - (tid-t0, nt, nullptr, &peer, sendbuff, nullptr, groupSend, args->p2p.sendIdx); + (tid-t0, nt, nullptr, &peer, sendbuff, nullptr, /*redOpArg(ignored)=*/0, groupSend | (args->p2p.sendIdx << 16)); ssize_t offset = 0; do { int nelem = roundUp(chunkSize, nt*(sizeof(uint64_t)/sizeof(T))); diff --git a/projects/rccl/src/debug.cc b/projects/rccl/src/debug.cc index bf5e6ddf60..321d8081e3 100644 --- a/projects/rccl/src/debug.cc +++ b/projects/rccl/src/debug.cc @@ -139,7 +139,7 @@ void ncclDebugLog(ncclDebugLogLevel level, unsigned long flags, const char *file int cudaDev; hipGetDevice(&cudaDev); int pid = getpid(); - int tid = gettid(); + int tid = syscall(SYS_gettid); char buffer[1024]; size_t len = 0; diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index f8c5130af5..15915fe19e 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -12,81 +12,61 @@ #include #include #include "gdrwrap.h" +#include "bootstrap.h" +#include + +#include // std::memcpy // Only generate inline kernels for LL -#define NCCL_FUNC5(func, algo, redop, dtype) \ - 
NCCL_KERN_NAME(func, algo, LL, redop, dtype), \ - NCCL_KERN_NAME(func, algo, LL, redop, dtype), \ - NCCL_KERN_NAME(func, algo, LL, redop, dtype) +#define NCCL_FUNC5(func, algo, devredop, dtype) \ + (void*)NCCL_KERN_NAME(func, algo, LL, devredop, dtype), \ + (void*)NCCL_KERN_NAME(func, algo, LL, devredop, dtype), \ + (void*)NCCL_KERN_NAME(func, algo, LL, devredop, dtype) -#define NCCL_FUNC4(func, redop, type) \ - NCCL_FUNC5(func, TREE, redop, type), \ - NCCL_FUNC5(func, RING, redop, type), \ - NCCL_FUNC5(func, COLLNET, redop, type) +#define NCCL_FUNC4(func, devredop, type) \ + (void*)NCCL_FUNC5(func, TREE, devredop, type), \ + (void*)NCCL_FUNC5(func, RING, devredop, type), \ + (void*)NCCL_FUNC5(func, COLLNET, devredop, type) -#if defined(__CUDA_BF16_TYPES_EXIST__) // Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, redop) \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, uint8_t), \ - (void*)NCCL_FUNC4(func, redop, int32_t), \ - (void*)NCCL_FUNC4(func, redop, uint32_t), \ - (void*)NCCL_FUNC4(func, redop, int64_t), \ - (void*)NCCL_FUNC4(func, redop, uint64_t), \ - (void*)NCCL_FUNC4(func, redop, half), \ - (void*)NCCL_FUNC4(func, redop, float), \ - (void*)NCCL_FUNC4(func, redop, double), \ - (void*)NCCL_FUNC4(func, redop, __nv_bfloat16) -#define NCCL_FUNCS3B(func, redop) \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t), \ - (void*)NCCL_FUNC4(func, redop, int8_t) -#else -// Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A(func, redop) \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, uint8_t), \ - NCCL_FUNC4(func, redop, int32_t), \ - NCCL_FUNC4(func, 
redop, uint32_t), \ - NCCL_FUNC4(func, redop, int64_t), \ - NCCL_FUNC4(func, redop, uint64_t), \ - NCCL_FUNC4(func, redop, half), \ - NCCL_FUNC4(func, redop, float), \ - NCCL_FUNC4(func, redop, double), \ - NCCL_FUNC4(func, redop, rccl_bfloat16) -#define NCCL_FUNCS3B(func, redop) \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t), \ - NCCL_FUNC4(func, redop, int8_t) -#endif +#define NCCL_FUNCS3A(func, devredop) \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, uint8_t), \ + (void*)NCCL_FUNC4(func, devredop, int32_t), \ + (void*)NCCL_FUNC4(func, devredop, uint32_t), \ + (void*)NCCL_FUNC4(func, devredop, int64_t), \ + (void*)NCCL_FUNC4(func, devredop, uint64_t), \ + (void*)NCCL_FUNC4(func, devredop, half), \ + (void*)NCCL_FUNC4(func, devredop, float), \ + (void*)NCCL_FUNC4(func, devredop, double), \ + (void*)NCCL_FUNC4(func, devredop, rccl_bfloat16) +#define NCCL_FUNCS3B(func, devredop) \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t), \ + (void*)NCCL_FUNC4(func, devredop, int8_t) -// Must be consistent with ncclRedOp_t -- but we only generate kernel for sums. +// Must be consistent with ncclDevRedOp_t -- but we only generate kernel for sums. 
#define NCCL_FUNCS2A(func) \ - NCCL_FUNCS3A(func, Sum), \ - NCCL_FUNCS3A(func, Sum), \ - NCCL_FUNCS3A(func, Sum), \ - NCCL_FUNCS3A(func, Sum), \ - NCCL_FUNCS3A(func, Sum) + NCCL_FUNCS3A(func, Sum), /*Sum*/ \ + NCCL_FUNCS3A(func, Sum), /*Prod*/ \ + NCCL_FUNCS3A(func, Sum), /*Max*/ \ + NCCL_FUNCS3A(func, Sum), /*Min*/ \ + NCCL_FUNCS3A(func, Sum), /*PreMulSum*/ \ + NCCL_FUNCS3A(func, Sum) /*SumPostDiv*/ #define NCCL_FUNCS2B(func) \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3A(func, Sum), \ - NCCL_FUNCS3B(func, Sum), \ - NCCL_FUNCS3B(func, Sum) + NCCL_FUNCS3B(func, Sum), /*Sum*/ \ + NCCL_FUNCS3B(func, Sum), /*Prod*/ \ + NCCL_FUNCS3B(func, Sum), /*Max*/ \ + NCCL_FUNCS3B(func, Sum), /*Min*/ \ + NCCL_FUNCS3B(func, Sum), /*PreMulSum*/ \ + NCCL_FUNCS3B(func, Sum) /*SumPostDiv*/ typedef void(*ncclKern_t)(struct ncclWorkElem first); // Must be consistent with the ncclFuncSet enum @@ -145,7 +125,6 @@ static ncclResult_t getNextOp(struct ncclChannel* channel, struct ncclWork** wor // Initialize with work elem if provided if (base) memcpy(e, base, sizeof(struct ncclWorkElem)); e->active = 1; - e->index = opIndex; channel->workFifoTail++; channel->workCount++; if (work) *work = w; @@ -183,10 +162,11 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph if (c == 0) { // As we inline the first coll directly, we can free it immediately. 
- // Except P2P or aggregation cases + // Except P2P or aggregation or registration cases struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS); struct ncclWorkElem* elem = work->elems; - if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1) elem->active = 0; + if (elem->funcIndex != FUNC_INDEX_P2P && eqInfo->elemList->count() == 1 && elem->regUsed == 0) + elem->active = 0; } if (channel->gdrMemDesc) { @@ -371,7 +351,7 @@ RCCL_PARAM(SharpThreshold, "SHARP_THRESHOLD", 16384); static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNetTypeSupport) { if (info->comm->collNetSupport > 0 && info->nBytes < rcclParamSharpThreshold()) { - ncclRedOp_t netOp = info->op == ncclAvg ? ncclSum : info->op; + ncclRedOp_t netOp = info->op == ncclAvg || info->op >= ncclNumOps ? ncclSum : info->op; NCCLCHECK(collNetReduceSupport(info->datatype, netOp, collNetTypeSupport)); } else { *collNetTypeSupport = 0; @@ -381,30 +361,35 @@ static inline ncclResult_t getCollNetSupport(struct ncclInfo* info, int* collNet static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, int numPipeOps) { struct ncclComm* comm = info->comm; - float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete. - // Find algorithm / protocol. - info->algorithm = -1; - info->protocol = -1; - if (comm->nRanks == 1) return ncclSuccess; - int nAlgos = NCCL_NUM_ALGORITHMS; - for (int a=0; a= 0 && time < minTime) { - info->algorithm = a; - info->protocol = p; - minTime = time; + if (comm->nRanks == 1) { + info->algorithm = NCCL_ALGO_RING; + info->protocol = NCCL_PROTO_SIMPLE; + } + else { + float minTime = 3600000000.0; // Hopefully no operation will take an hour to complete. + // Find algorithm / protocol. 
+ info->algorithm = -1; + info->protocol = -1; + int nAlgos = NCCL_NUM_ALGORITHMS; + for (int a=0; a= 0 && time < minTime) { + info->algorithm = a; + info->protocol = p; + minTime = time; + } } } + if (info->algorithm == -1 || info->protocol == -1) { + WARN("Error : no algorithm/protocol available"); + return ncclInternalError; + } + //if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime); + TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime); } - if (info->algorithm == -1 || info->protocol == -1) { - WARN("Error : no algorithm/protocol available"); - return ncclInternalError; - } - //if (comm->rank == 0) INFO(NCCL_TUNING, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime); - TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", info->nBytes, info->algorithm, info->protocol, minTime); int nc = (info->nChannels > 0) ? 
info->nChannels : comm->nChannels; int nt = comm->maxThreads[info->algorithm][info->protocol]; @@ -498,15 +483,23 @@ comp_next: NCCLCHECK(getPatternInfo(info)); NCCLCHECK(getLoopInfo(info)); - work->op.opCount = info->comm->collOpCount; + work->coll.opCount = info->comm->collOpCount; work->sendbuff = info->sendbuff; work->recvbuff = info->recvbuff; work->coll.root = info->root; work->coll.count = info->count; work->coll.nChannels = info->nChannels; work->nThreads = info->nThreads; + work->coll.redOpArg = info->opFull.scalarArg; + work->redOpArgIsPtr = info->opFull.scalarArgIsPtr; - work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol); + if (info->comm->nRanks == 1) { + // one-rank reduce index + work->funcIndex = FUNC_INDEX_P2P - ncclNumTypes + int(info->datatype); + return ncclSuccess; + } + + work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); work->coll.connIndex = 0; proxyArgs->connIndex = 0; @@ -533,7 +526,7 @@ comp_next: info->comm->nChannels, &work->clique.nChannels)); work->clique.count = info->count; - work->funcIndex = FUNC_INDEX(info->coll, info->op, info->datatype, info->algorithm, info->protocol); + work->funcIndex = FUNC_INDEX(info->coll, info->opFull.op, info->datatype, info->algorithm, info->protocol); // Setup pointers to where all the input/output pointers will be NCCLCHECK(info->comm->cliqueManager->WaitForPointers(work)); @@ -563,6 +556,8 @@ comp_next: while (info->nBytes / (info->nChannels*info->comm->channels[0].collTree.nHeads*chunkSize) < info->comm->channels[0].collTree.depth*8 && chunkSize > 32768) chunkSize /= 2; // Use lastChunkSize as chunkSize work->coll.lastChunkSize = chunkSize / ncclTypeSize(info->datatype); + // Set direct direction for broadcast-gather (read or write) + work->direct = (info->nBytes / info->nChannels <= 1024*1024) ? 
NCCL_DIRECT_WRITE : NCCL_DIRECT_READ; } else if (info->protocol == NCCL_PROTO_LL) { const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine); const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize; @@ -592,7 +587,7 @@ comp_next: proxyArgs->protocol = info->protocol; proxyArgs->dtype = info->datatype; proxyArgs->redOp = info->algorithm != NCCL_ALGO_COLLNET ? ncclNumOps : // Only set redOp when using CollNet - info->op == ncclAvg ? ncclSum : // Network sees avg as sum + info->opFull.op==ncclDevPreMulSum || info->opFull.op==ncclDevSumPostDiv ? ncclSum : // Network sees avg as sum info->op; proxyArgs->pattern = info->pattern; proxyArgs->root = info->root; @@ -618,12 +613,61 @@ static ncclResult_t checkSetStream(struct ncclInfo* info) { return ncclSuccess; } +struct ncclBuffRegHandle { + hipIpcMemHandle_t sendBuffIpc; + hipIpcMemHandle_t recvBuffIpc; + ssize_t sendBuffOffset; + ssize_t recvBuffOffset; +}; + +// Register input and output buffers +// Exchange with ranks on the same host +static ncclResult_t ncclRegBuffAndExchange(struct ncclInfo* info, struct ncclBuffRegInfo* regInfo) { + ncclComm_t comm = info->comm; + if (comm->localRanks == 1) return ncclSuccess; + if (comm->pfnCuMemGetAddressRange == NULL) return ncclSuccess; // CUDA toolkit or driver version too old + + struct ncclBuffRegHandle regHandles[NCCL_MAX_INTRA_RANKS]; + // Get IPC handles + // Note: the handle only corresponds to the base address of the allocation + CUDACHECK(hipIpcGetMemHandle(®Handles[comm->intraNodeRank].sendBuffIpc, (void*)info->sendbuff)); + CUDACHECK(hipIpcGetMemHandle(®Handles[comm->intraNodeRank].recvBuffIpc, (void*)info->recvbuff)); + // Get offset of user buffer within allocation + void* baseAddr; + size_t size; + CUDACHECK(comm->pfnCuMemGetAddressRange(&baseAddr, &size, (void*)info->sendbuff)); + regHandles[comm->intraNodeRank].sendBuffOffset = (char*)info->sendbuff - (char*)baseAddr; + 
CUDACHECK(comm->pfnCuMemGetAddressRange(&baseAddr, &size, (void*)info->recvbuff)); + regHandles[comm->intraNodeRank].recvBuffOffset = (char*)info->recvbuff - (char*)baseAddr; + TRACE(NCCL_COLL, "Base %p size %lu offset %ld", baseAddr, size, regHandles[comm->intraNodeRank].recvBuffOffset); + + // Exchange handles within node + NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->intraNodeGlobalRanks, comm->intraNodeRank, comm->localRanks, regHandles, sizeof(struct ncclBuffRegHandle))); + // Open handles at local process + for (int i=0; ilocalRanks; i++) { + if (i == comm->intraNodeRank) { + regInfo->sendbuffsBase[i] = regInfo->recvbuffsBase[i] = NULL; + continue; + } + CUDACHECK(hipIpcOpenMemHandle(regInfo->sendbuffsBase+i, regHandles[i].sendBuffIpc, hipIpcMemLazyEnablePeerAccess)); + CUDACHECK(hipIpcOpenMemHandle(regInfo->recvbuffsBase+i, regHandles[i].recvBuffIpc, hipIpcMemLazyEnablePeerAccess)); + // Get real address of buffer + regInfo->sendbuffs[i] = (char*)regInfo->sendbuffsBase[i] + regHandles[i].sendBuffOffset; + regInfo->recvbuffs[i] = (char*)regInfo->recvbuffsBase[i] + regHandles[i].recvBuffOffset; + } + regInfo->nBuffs = comm->localRanks; + TRACE(NCCL_COLL, "Rank %d exchanged %d buffers", comm->rank, regInfo->nBuffs); + return ncclSuccess; +} + // Compute enqueue element, save it in list // Compute CUDA launch parameters // Capture time code in view of CUDA graph static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) { ncclComm_t comm = info->comm; - if (comm->nRanks == 1) { + if (comm->nRanks == 1 && + // User-defined reduction ops may need alter the data even for unitary reductions + info->op < ncclNumOps) { if (info->sendbuff != info->recvbuff) CUDACHECK(hipMemcpyAsync(info->recvbuff, info->sendbuff, info->nBytes, hipMemcpyDeviceToDevice, info->stream)); return ncclSuccess; @@ -651,6 +695,19 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) { comm->args.active = 2; // I am so far the last element; may be changed later in 
aggregation mode } + // Register and exchange input and output buffers + if (comm->usingCudaGraph && // only in CUDA graph mode + comm->graphRegister == 1 && // when registration is enabled + info->algorithm == NCCL_ALGO_COLLNET && // limited to CollNet for now + comm->intraHighestTransportType == TRANSPORT_P2P && // only when all ranks can p2p each other + comm->intraRanks == 1) { // only in multi-process mode + NCCLCHECK(ncclRegBuffAndExchange(info, &eqElem->buffRegInfo)); + // Disable inline argument because we need kernel to copy the entire ncclWork from workFifo + // because the registered addresses are in ncclWork + if (eqElem->buffRegInfo.nBuffs > 0) comm->args.active = 0; + comm->enqueueInfo->nRegBuffs += eqElem->buffRegInfo.nBuffs; + } + return ncclSuccess; } @@ -667,41 +724,15 @@ static inline int findShortestChannel(ncclComm_t comm) { return minC; } -static inline ncclResult_t getNextChannel(ncclComm_t comm, int* nextChannel) { - if (comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) { - *nextChannel = findShortestChannel(comm); +static inline int getNextChannel(ncclComm_t comm, int aggMode) { + int nextChannel = 0; + if (aggMode && comm->asyncAllocMode == ncclComm::SHORTEST_QUEUE) { + nextChannel = findShortestChannel(comm); } else { - *nextChannel = comm->lastChannel % comm->nChannels; + nextChannel = comm->lastChannel % comm->nChannels; comm->lastChannel++; } - return ncclSuccess; -} - -// Dynamic enqueue code -static ncclResult_t ncclEnqueueCollKernel(ncclComm_t comm, struct ncclQueueElem* eqElem) { - struct ncclWorkElem* work = &eqElem->work; - struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs; - - int nChannels = work->coll.nChannels; - for (int bid=0; bidlastChannel % comm->nChannels; - struct ncclChannel* channel = comm->channels+channelId; - - // Proxy - proxyArgs->subs[0].channel = channel; - proxyArgs->opCount = comm->collOpCount; - proxyArgs->commOpCount = comm->opCount; - - if (proxyArgs->subs[0].nsteps) 
NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks)); - - comm->lastChannel++; - work->coll.bid = bid % nChannels; - NCCLCHECK(getNextOp(channel, NULL, work)); - //INFO(NCCL_COLL, "Host enqueue: bid %d channel %d index %ld nThreads %d funcIndex %d count %ld nChannels %d", - // work->coll.bid, channelId, channel->workFifoTail, work->nThreads, work->funcIndex, work->coll.count, work->coll.nChannels); - } - comm->collOpCount++; - return ncclSuccess; + return nextChannel; } ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { @@ -733,7 +764,7 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { channelUsed += info->nChannels; // We can use fast path if all collectives are the same homogeneous &= info->coll == comm->asyncOps[0].coll && - info->op == comm->asyncOps[0].op && + info->opFull.op == comm->asyncOps[0].opFull.op && info->datatype == comm->asyncOps[0].datatype; if (allCollNetSupport > 0) NCCLCHECK(getCollNetSupport(info, &allCollNetSupport)); } @@ -818,13 +849,22 @@ static ncclResult_t ncclSaveP2p(struct ncclInfo* info) { return ncclSuccess; } -enum { COLL_SEGMENT=0, P2P_SEGMENT=1 }; +enum { RingTree_Segment=0, P2P_Segment=1, CollNet_Segment=2 }; static int getSegment(int type, int delta, struct ncclWork* work) { - if (type == P2P_SEGMENT) { // P2P + // Current ncclWork is full + if (work->elems[NCCL_MAX_WORK_ELEMENTS-1].active != 0) return -1; + + if (type == P2P_Segment) { // P2P + // Do not mix P2P and collective ops + if (work->elems[0].funcIndex != FUNC_INDEX_P2P) return -1; for (int s=0; selems[s].p2p.delta != delta; s++) { if (work->elems[s].active == 0) return s; } - } else { // aggregation + } else if (type == CollNet_Segment) { // CollNet + for (int s=0; selems[s].active == 0) return s; + } + } else { // Ring or Tree for (int s=0; selems[s].active == 0) return s; } @@ -838,7 +878,7 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct elem->nThreads = NCCL_MAX_NTHREADS; elem->sendbuff = info->sendbuff; elem->recvbuff 
= info->recvbuff; - elem->op.opCount = info->comm->p2pOpCount; + elem->p2p.opCount = info->comm->p2pOpCount; elem->p2p.sendCount = info->sendbytes; elem->p2p.recvCount = info->recvbytes; elem->p2p.sendChunkSize = info->sendChunkSize; @@ -847,13 +887,14 @@ static ncclResult_t computeP2pWorkElem(struct ncclInfo* info /* input */, struct return ncclSuccess; } -static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s) { +static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */, struct ncclWork* work, int s, + struct ncclBuffRegInfo* regInfo, struct ncclChannel* channel, struct ncclComm* comm) { // Copy element into corresponding segment of ncclWork memcpy(work->elems+s, elem, sizeof(struct ncclWorkElem)); work->elems[s].active = 1; // Determine nThreads at dynamic time - if (type == P2P_SEGMENT) { + if (type == P2P_Segment) { const int nsegments = s+1; int nThreads = 512; while (nsegments*nThreads > 256) nThreads /= 2; @@ -861,6 +902,33 @@ static ncclResult_t enqueueSegOp(int type, struct ncclWorkElem* elem /* input */ for (int i=0; ielems[i].p2p.nThreads = nThreads; } + // Copy registered buffer addresses into ncclWork + if (regInfo->nBuffs > 0) { + struct ncclWorkRegElem* regElem = (struct ncclWorkRegElem*)(work->elems+s); + // For CollNet + for (int i=0; icollTree.down[i]; + if (peer == -1) break; + int j = comm->rankToIntraNodeRank[peer]; + if (j < 0) { + WARN("Invalid intra-node rank %d for peer %d", j, peer); + return ncclInternalError; + } + regElem->dnInputs[i] = regInfo->sendbuffs[j]; + regElem->dnOutputs[i] = regInfo->recvbuffs[j]; + } + for (int i=0; icollTree.up[i]; + if (peer == -1) break; + int j = comm->rankToIntraNodeRank[peer]; + if (j < 0) { + WARN("Invalid intra-node rank %d for peer %d", j, peer); + return ncclInternalError; + } + regElem->upOutputs[i] = regInfo->recvbuffs[j]; + } + work->elems[s].regUsed = 1; + } return ncclSuccess; } @@ -873,9 +941,9 @@ ncclResult_t 
ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS; struct ncclWork* w = channel->workFifo+opIndex; int segment = -1; - if (channel->workCount && w->elems[0].funcIndex == FUNC_INDEX_P2P && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0) { + if (channel->workCount) { // Try to pack more segments into a single operation - segment = getSegment(P2P_SEGMENT, workElem->p2p.delta, w); + segment = getSegment(P2P_Segment, workElem->p2p.delta, w); } if (segment == -1) { NCCLCHECK(getNextOp(channel, &w, NULL)); @@ -884,7 +952,7 @@ ncclResult_t ncclEnqueueP2pKernel(struct ncclComm* comm, struct ncclQueueElem* e // store work element into FIFO NCCLCHECK(ncclProxySaveP2p(comm, proxyArgs)); - NCCLCHECK(enqueueSegOp(P2P_SEGMENT, workElem, w, segment)); + NCCLCHECK(enqueueSegOp(P2P_Segment, workElem, w, segment, &eqElem->buffRegInfo, channel, comm)); comm->p2pOpCount++; return ncclSuccess; } @@ -920,15 +988,18 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) { return ncclSuccess; } -ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem) { +// Dynamic enqueue function for collective kernels +// Supports both aggregated and non-aggregated modes +ncclResult_t ncclEnqueueCollKernel(struct ncclComm* comm, struct ncclQueueElem* eqElem, int aggMode) { struct ncclWorkElem* work = &eqElem->work; struct ncclProxyArgs* proxyArgs = &eqElem->proxyArgs; int nChannels = work->coll.nChannels; size_t channelSize = work->coll.count*ncclTypeSize(proxyArgs->dtype)/work->coll.nChannels; + int segmentType = proxyArgs->redOp == ncclNumOps ? 
RingTree_Segment : CollNet_Segment; // redOp is only set when using CollNet + for (int bid=0; bidchannels+channelId; // Proxy @@ -937,18 +1008,19 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* proxyArgs->commOpCount = comm->opCount; if (proxyArgs->subs[0].nsteps) NCCLCHECK(ncclProxySaveColl(proxyArgs, comm->nRanks)); - // Try to reuse last work if not full yet work->coll.bid = bid % nChannels; - int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS; - struct ncclWork* w = channel->workFifo+opIndex; + struct ncclWork* w = NULL; int segment = -1; - if (channel->workCount && w->elems[NCCL_MAX_WORK_ELEMENTS-1].active == 0 && - // All elems in work must have same (funcIndex,nThreads), - // see "src/collectives/device/common.h" - w->elems[0].funcIndex == work->funcIndex && - w->elems[0].nThreads == work->nThreads) { + if (aggMode && channel->workCount) { // Try to pack more segments into a single operation - segment = getSegment(COLL_SEGMENT, 0, w); + int opIndex = (channel->workFifoTail-1+NCCL_MAX_OPS)%NCCL_MAX_OPS; + w = channel->workFifo+opIndex; + // All elems in work must have same (funcIndex,nThreads), + // see "src/collectives/device/common.h" + if (w->elems[0].funcIndex == work->funcIndex && + w->elems[0].nThreads == work->nThreads) { + segment = getSegment(segmentType, 0, w); + } } if (segment == -1) { NCCLCHECK(getNextOp(channel, &w, NULL)); @@ -956,7 +1028,7 @@ ncclResult_t ncclEnqueueAsyncKernel(struct ncclComm* comm, struct ncclQueueElem* } // store work element into FIFO - NCCLCHECK(enqueueSegOp(COLL_SEGMENT, work, w, segment)); + NCCLCHECK(enqueueSegOp(segmentType, work, w, segment, &eqElem->buffRegInfo, channel, comm)); channel->totalSize += channelSize; } comm->collOpCount++; @@ -968,17 +1040,15 @@ void HIPRT_CB ncclEnqueueHostSetup(void* arg) { ncclResult_t ret; struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)arg; ncclComm_t comm = eqInfo->comm; + int aggMode = eqInfo->elemList->count() > 1 ? 
1 : 0; // Iterate through the element list struct ncclQueueElem* eqElem = eqInfo->elemList->begin(); while (eqElem != NULL) { if (eqElem->work.funcIndex == FUNC_INDEX_P2P) { NCCLCHECKGOTO(ncclEnqueueP2pKernel(comm, eqElem), ret, cb_end); - } else if (eqInfo->elemList->count() > 1) { - // We have more than one operation, hence aggregating - NCCLCHECKGOTO(ncclEnqueueAsyncKernel(comm, eqElem), ret, cb_end); } else { - NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem), ret, cb_end); + NCCLCHECKGOTO(ncclEnqueueCollKernel(comm, eqElem, aggMode), ret, cb_end); } eqElem = eqInfo->elemList->getNext(); } @@ -996,51 +1066,95 @@ cb_end: template void HIPRT_CB ncclEnqueueHostSetup<0>(void*); template void HIPRT_CB ncclEnqueueHostSetup<1>(void*); +void* graphHelperFunc(void *args) { + struct ncclGraphHelperResources* res = (struct ncclGraphHelperResources*)args; + if (res == NULL) { + WARN("CUDA Graph helper resource is null"); + return NULL; + } + int dev = res->comm->cudaDev; + CUDACHECKIGNORE(hipSetDevice(dev)); + INFO(NCCL_COLL, "CUDA Graph helper thread created for device %d", dev); + + volatile enum helperThreadState* state = &res->threadState; + volatile int* ipcTail = &res->ipcTail; + while (1) { + int ipcTailMark = *ipcTail; + int ipcCount = 0; + while (res->ipcHead != ipcTailMark) { + if (res->ipcBases[res->ipcHead] != NULL) + CUDACHECKIGNORE(hipIpcCloseMemHandle(res->ipcBases[res->ipcHead])); + res->ipcBases[res->ipcHead] = NULL; + res->ipcHead = (res->ipcHead+1)%NCCL_IPC_POOL_SIZE; + ipcCount++; + } + TRACE(NCCL_COLL, "CUDA Graph helper thread closed %d IPC handles", ipcCount); + pthread_mutex_lock(&res->threadLock); + while (res->ipcHead == *ipcTail && *state != ThreadStop) { + pthread_cond_wait(&res->threadCond, &res->threadLock); + } + pthread_mutex_unlock(&res->threadLock); + if (*state == ThreadStop) { + INFO(NCCL_COLL, "CUDA Graph helper thread for device %d returning", dev); + return NULL; + } + } +} + ncclResult_t ncclGetCudaGraph(ncclComm_t comm, 
cudaGraph_t* graph) { comm->usingCudaGraph = 0; #if CUDART_VERSION >= 11030 - cudaStreamCaptureStatus captureStatus; - unsigned long long cudaGraphId; + hipStreamCaptureStatus captureStatus; + unsigned long long hipGraphId; if (comm->driverVersion < 11030) { - CUDACHECK(cudaStreamIsCapturing(comm->userStream, &captureStatus)); - if (captureStatus != cudaStreamCaptureStatusNone) { + CUDACHECK(hipStreamIsCapturing(comm->userStream, &captureStatus)); + if (captureStatus != hipStreamCaptureStatusNone) { WARN("The installed CUDA driver is older than the minimum version (R465) required for NCCL's CUDA Graphs support"); return ncclInvalidUsage; } return ncclSuccess; } - CUDACHECK(cudaStreamGetCaptureInfo_v2(comm->userStream, &captureStatus, &cudaGraphId, graph, NULL, NULL)); - if (captureStatus == cudaStreamCaptureStatusActive) { - if (cudaGraphId != comm->lastCudaGraphId) { - INFO(NCCL_COLL, "stream is being captured by a new graph, id %llu", cudaGraphId); + CUDACHECK(hipStreamGetCaptureInfo_v2(comm->userStream, &captureStatus, &hipGraphId, graph, NULL, NULL)); + if (captureStatus == hipStreamCaptureStatusActive) { + if (hipGraphId != comm->lastCudaGraphId) { + INFO(NCCL_COLL, "stream is being captured by a new graph, id %llu", hipGraphId); // We are in a new graph, hence need to forget the last setup node so that // the first setup node in the new graph will not have a dependency - comm->lastCudaGraphId = cudaGraphId; + comm->lastCudaGraphId = hipGraphId; comm->lastSetupNode = NULL; } if (comm->launchMode == ncclComm::GROUP) comm->launchMode = ncclComm::GROUP_GRAPH; comm->usingCudaGraph = 1; + + // Create helper thread that closes IPC handles during graph destruction + // Only create this thread when buffer registration is enabled + if ((!comm->graphHelperThread) && comm->graphRegister == 1 && comm->disableGraphHelper == 0) { + pthread_mutex_init(&comm->graphHelperResources->threadLock, NULL); + pthread_cond_init(&comm->graphHelperResources->threadCond, NULL); + 
comm->graphHelperResources->threadState = ThreadStart; + pthread_create(&comm->graphHelperThread, NULL, graphHelperFunc, comm->graphHelperResources); + } } #endif return ncclSuccess; } -ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph) { +ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, hipGraph_t graph) { #if CUDART_VERSION >= 11030 struct ncclQueueInfo* eqInfo = comm->enqueueInfo; // Create a CUDA object to wrap around the argument space // which CUDA graph would manage lifetime of - cudaUserObject_t object; - CUDACHECK(cudaUserObjectCreate(&object, eqInfo, ncclDestroyQueueInfo, 1/*initialRefcount*/, cudaUserObjectNoDestructorSync)); - CUDACHECK(cudaGraphRetainUserObject(graph, object, 1, cudaGraphUserObjectMove)); + hipUserObject_t object; + CUDACHECK(hipUserObjectCreate(&object, eqInfo, ncclDestroyQueueInfo, 1/*initialRefcount*/, hipUserObjectNoDestructorSync)); + CUDACHECK(hipGraphRetainUserObject(graph, object, 1, hipGraphUserObjectMove)); - cudaHostFn_t fn = ncclEnqueueHostSetup<1>; + hipHostFn_t fn = ncclEnqueueHostSetup<1>; // Add a CPU node to the graph - cudaGraphNode_t setupNode; - cudaHostNodeParams setupNodeParams = {fn, eqInfo}; + hipGraphNode_t setupNode; + hipHostNodeParams setupNodeParams = {fn, eqInfo}; int numDependencies = comm->lastSetupNode == NULL ? 
0 : 1; - CUDACHECK(cudaGraphAddHostNode(&setupNode, graph, &comm->lastSetupNode, numDependencies, &setupNodeParams)); - CUDACHECK(cudaStreamUpdateCaptureDependencies(comm->userStream, &setupNode, 1, cudaStreamAddCaptureDependencies)); + CUDACHECK(hipGraphAddHostNode(&setupNode, graph, &comm->lastSetupNode, numDependencies, &setupNodeParams)); + CUDACHECK(hipStreamUpdateCaptureDependencies(comm->userStream, &setupNode, 1, hipStreamAddCaptureDependencies)); comm->lastSetupNode = setupNode; return ncclSuccess; #else @@ -1049,6 +1163,74 @@ ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph) { #endif } +static ncclResult_t hostToDevRedOp( + ncclDevRedOpFull *opFull, ncclRedOp_t op, ncclDataType_t datatype, ncclComm *comm + ) { + union { + int8_t i8; + uint8_t u8; + int32_t i32; + uint32_t u32; + int64_t i64; + uint64_t u64; + half f16; + #if defined(RCCL_BFLOAT16) + rccl_bfloat16 bf16; + #endif + float f32; + double f64; + void *ptr; + }; + u64 = 0; + opFull->scalarArgIsPtr = false; + switch (int(op)) { + case ncclSum: opFull->op = ncclDevSum; break; + case ncclProd: opFull->op = ncclDevProd; break; + case ncclMax: opFull->op = ncclDevMax; break; + case ncclMin: opFull->op = ncclDevMin; break; + case ncclAvg: + switch ((int)datatype) { + case ncclInt8: case ncclInt32: case ncclInt64: + case ncclUint8: case ncclUint32: case ncclUint64: + opFull->op = ncclDevSumPostDiv; + u64 = comm->nRanks; + break; + case ncclFloat16: + opFull->op = ncclDevPreMulSum; + f16 = __float2half(float(1.0/comm->nRanks)); // __double2half not supported pre CUDA 11.x + break; + #if defined(RCCL_BFLOAT16) + case ncclBfloat16: + opFull->op = ncclDevPreMulSum; + bf16 = (rccl_bfloat16)(float(1.0/comm->nRanks)); + break; + #endif + case ncclFloat32: + opFull->op = ncclDevPreMulSum; + f32 = float(1.0/comm->nRanks); + break; + case ncclFloat64: + opFull->op = ncclDevPreMulSum; + f64 = 1.0/comm->nRanks; + break; + } + opFull->scalarArgIsPtr = false; + opFull->scalarArg = u64; + break; 
+ default: // user created + int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps); + ncclUserRedOp *user = &comm->userRedOps[ix]; + if (datatype != user->datatype) { + WARN("Data type supplied to user-created ncclRedOp_t does not match type " + "given to reduction operation"); + return ncclInvalidArgument; + } + *opFull = user->opFull; + break; + } + return ncclSuccess; +} + ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) { // [RCCL] Check for clique-based kernel support { @@ -1064,40 +1246,39 @@ ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) { } // [/RCCL] + ncclResult_t ret = ncclSuccess; + bool isAsync = ncclAsyncMode(); + int savedDev = -1; + // Check arguments + NCCLCHECK(PtrCheck(info->comm, info->opName, "comm")); + if (isAsync && info->comm->checkPointers) { + CUDACHECKGOTO(hipGetDevice(&savedDev), ret, end); + CUDACHECKGOTO(hipSetDevice(info->comm->cudaDev), ret, end); + } + NCCLCHECKGOTO(ArgsCheck(info), ret, end); + + // Copy reduction op state from op handle into info struct here since the + // op handle may be destroyed before ncclGroupEnd(). + NCCLCHECKGOTO(hostToDevRedOp(&info->opFull, info->op, info->datatype, info->comm), ret, end); + // Launch asynchronously if needed - if (ncclAsyncMode()) { - ncclResult_t ret = ncclSuccess; - int savedDev = -1; - // Check arguments - NCCLCHECK(PtrCheck(info->comm, info->opName, "comm")); - if (info->comm->checkPointers) { - CUDACHECKGOTO(hipGetDevice(&savedDev), ret, end); - CUDACHECKGOTO(hipSetDevice(info->comm->cudaDev), ret, end); - } - NCCLCHECKGOTO(ArgsCheck(info), ret, end); + if (isAsync) { // Always register comm even in case of error to make sure ncclGroupEnd // cleans it up. 
NCCLCHECKGOTO(ncclAsyncColl(info->comm), ret, end); NCCLCHECKGOTO(checkSetStream(info), ret, end); - INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", + INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p devRedOp %d isPtr %d scaler %lx", info->opName, info->coll == ncclFuncSendRecv ? info->comm->p2pOpCount : info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count, - info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream); + info->datatype, info->op, info->root, info->comm, info->comm->nRanks, info->stream, info->opFull.op, info->opFull.scalarArgIsPtr, info->opFull.scalarArg); if (info->coll == ncclFuncSendRecv) { //p2p stored separately NCCLCHECKGOTO(ncclSaveP2p(info), ret, end); } else { NCCLCHECKGOTO(ncclSaveAsyncColl(info), ret, end); } - -end: - if (savedDev != -1) CUDACHECK(hipSetDevice(savedDev)); - ncclAsyncErrCheck(ret); - return ret; } else { - NCCLCHECK(PtrCheck(info->comm, info->opName, "comm")); - NCCLCHECK(ArgsCheck(info)); - NCCLCHECK(checkSetStream(info)); + NCCLCHECKGOTO(checkSetStream(info), ret, end); INFO(NCCL_COLL,"%s: opCount %lx sendbuff %p recvbuff %p count %zi datatype %d op %d root %d comm %p [nranks=%d] stream %p", info->opName, info->comm->collOpCount, info->sendbuff, info->recvbuff, info->count, @@ -1106,24 +1287,82 @@ end: // Check whether we are in cuda graph mode cudaGraph_t graph; ncclComm_t comm = info->comm; - NCCLCHECK(ncclGetCudaGraph(comm, &graph)); + NCCLCHECKGOTO(ncclGetCudaGraph(comm, &graph), ret, end); // Common part between graph mode and non-graph mode - NCCLCHECK(ncclSetupCollKernel(info)); + NCCLCHECKGOTO(ncclSetupCollKernel(info), ret, end); // Host setup if (comm->usingCudaGraph) { - NCCLCHECK(ncclCudaGraphHostSetup(comm, graph)); + NCCLCHECKGOTO(ncclCudaGraphHostSetup(comm, graph), ret, end); } else { 
ncclEnqueueHostSetup<0>(comm->enqueueInfo); - NCCLCHECK(comm->enqueueInfo->ret); + NCCLCHECKGOTO(comm->enqueueInfo->ret, ret, end); } // Common part between graph mode and non-graph mode - NCCLCHECK(ncclLaunchBarrier(comm)); - NCCLCHECK(ncclLaunchKernel(comm)); - NCCLCHECK(ncclRecordEvents(comm)); - NCCLCHECK(ncclLaunchReset(comm)); - return ncclSuccess; + NCCLCHECKGOTO(ncclLaunchBarrier(comm), ret, end); + NCCLCHECKGOTO(ncclLaunchKernel(comm), ret, end); + NCCLCHECKGOTO(ncclRecordEvents(comm), ret, end); + NCCLCHECKGOTO(ncclLaunchReset(comm), ret, end); } +end: + if (isAsync && savedDev != -1) CUDACHECK(hipSetDevice(savedDev)); + if (isAsync) ncclAsyncErrCheck(ret); + return ret; +} + +NCCL_API(ncclResult_t, ncclRedOpCreatePreMulSum, ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); +ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm) { + if (comm->userRedOpFreeHead == comm->userRedOpCapacity) { + // double capacity and resize + int cap = 2*comm->userRedOpCapacity; + if (cap < 4) cap = 4; + ncclUserRedOp *ops = new ncclUserRedOp[cap]; + std::memcpy(ops, comm->userRedOps, comm->userRedOpCapacity*sizeof(ncclUserRedOp)); + for(int ix=comm->userRedOpCapacity; ix < cap; ix++) + ops[ix].freeNext = ix + 1; + delete[] comm->userRedOps; + comm->userRedOps = ops; + comm->userRedOpCapacity = cap; + } + // pop from free list + int ix = comm->userRedOpFreeHead; + ncclUserRedOp *user = &comm->userRedOps[ix]; + comm->userRedOpFreeHead = user->freeNext; + + user->freeNext = -1; // allocated + user->datatype = datatype; + user->opFull.op = ncclDevPreMulSum; + if (residence == ncclScalarHostImmediate) { + user->opFull.scalarArgIsPtr = false; + std::memcpy(&user->opFull.scalarArg, scalar, ncclTypeSize(datatype)); + } else { + user->opFull.scalarArgIsPtr = true; + user->opFull.scalarArg = reinterpret_cast(scalar); + } + *op = 
ncclRedOp_t(int(ncclNumOps) + ix); + *op = ncclUserRedOpMangle(comm, *op); + return ncclSuccess; +} + +NCCL_API(ncclResult_t, ncclRedOpDestroy, ncclRedOp_t op, ncclComm_t comm); +ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm) { + if (0 <= int(op) && int(op) < int(ncclNumOps)) { + WARN("ncclRedOpDestroy : operator is a NCCL builtin."); + return ncclInvalidArgument; + } + if (int(op) < 0 || int(ncclMaxRedOp) < int(op)) { + WARN("ncclRedOpDestroy : operator is garbage."); + return ncclInvalidArgument; + } + int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps); + if (comm->userRedOpCapacity <= ix || comm->userRedOps[ix].freeNext != -1) { + WARN("ncclRedOpDestroy : operator unknown to this communicator."); + return ncclInvalidArgument; + } + // push to free list + comm->userRedOps[ix].freeNext = comm->userRedOpFreeHead; + comm->userRedOpFreeHead = ix; + return ncclSuccess; } diff --git a/projects/rccl/src/graph/topo.cc b/projects/rccl/src/graph/topo.cc index 0e72141769..226217c1b7 100644 --- a/projects/rccl/src/graph/topo.cc +++ b/projects/rccl/src/graph/topo.cc @@ -378,8 +378,11 @@ ncclResult_t ncclTopoAddGpu(struct ncclXmlNode* xmlGpu, struct ncclTopoSystem* s return ncclSuccess; } -struct kvDict kvDictPciClass[] = { { "0x060400", PCI }, { "0x068000", NVS }, { "0x068001", CPU }, { "0x030200", GPU }, { "0x030000", GPU }, { "0x038000", GPU }, { "0x020700", NIC }, { "0x020000", NIC }, { NULL, PCI /* Default fallback value */ } }; -struct kvDict kvDictPciGen[] = { { "2.5 GT/s", 15 }, { "5 GT/s", 30 }, { "8 GT/s", 60 }, { "16 GT/s", 120 }, { "32 GT/s", 240 }, { "8.0 GT/s", 60 }, { "16.0 GT/s", 120 }, { "32.0 GT/s", 240 }, { NULL, 60 /* Default fallback */ } }; // x100 Mbps per lane +struct kvDict kvDictPciClass[] = { { "0x060400", PCI }, { "0x068000", NVS }, { "0x068001", CPU }, { "0x03", GPU }, { "0x02", NIC }, { NULL, PCI /* Default fallback value */ } }; +struct kvDict kvDictPciGen[] = { + { "2.5 GT/s", 15 }, { "5 GT/s", 30 }, { "8 GT/s", 60 }, 
{ "16 GT/s", 120 }, /* Kernel 5.6 and earlier */ + { "2.5 GT/s PCIe", 15 }, { "5.0 GT/s PCIe", 30 }, { "8.0 GT/s PCIe", 60 }, { "16.0 GT/s PCIe", 120 }, { "32.0 GT/s PCIe", 240 }, { "64.0 GT/s PCIe", 480 }, + { NULL, 60 /* Default fallback */ } }; // x100 Mbps per lane ncclResult_t ncclTopoAddPci(struct ncclXmlNode* xmlPci, struct ncclTopoSystem* system, struct ncclTopoNode* parent) { const char* str; diff --git a/projects/rccl/src/graph/xml.cc b/projects/rccl/src/graph/xml.cc index b9cba9953d..6b2c2e2e65 100644 --- a/projects/rccl/src/graph/xml.cc +++ b/projects/rccl/src/graph/xml.cc @@ -660,7 +660,7 @@ ncclResult_t ncclTopoGetXmlFromGpu(struct ncclXmlNode* pciNode, nvmlDevice_t nvm NCCLCHECK(xmlGetAttrInt(gpuNode, "arch", &arch.value)); struct ncclXmlNode* nvlNode = NULL; - NCCLCHECK(xmlGetSub(pciNode, "nvlink", &nvlNode)); + NCCLCHECK(xmlGetSub(gpuNode, "nvlink", &nvlNode)); if (nvlNode == NULL) { #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) const char* busId; diff --git a/projects/rccl/src/graph/xml.h b/projects/rccl/src/graph/xml.h index 8c350935b6..07d4c8d6ee 100644 --- a/projects/rccl/src/graph/xml.h +++ b/projects/rccl/src/graph/xml.h @@ -1,5 +1,6 @@ /************************************************************************* * Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
* * See LICENSE.txt for license information ************************************************************************/ @@ -7,6 +8,11 @@ #ifndef XML_H_ #define XML_H_ +#include "nccl.h" +#include "debug.h" +#include "checks.h" +#include + // A few constraints to make the implementation easy #define MAX_STR_LEN 255 #define MAX_ATTR_COUNT 16 diff --git a/projects/rccl/src/group.cc b/projects/rccl/src/group.cc index 3be1e9368a..8a8d212a61 100644 --- a/projects/rccl/src/group.cc +++ b/projects/rccl/src/group.cc @@ -9,6 +9,7 @@ #include "debug.h" #include "enqueue.h" #include "transport.h" +#include #define MAX_ASYNC_OPS 128 thread_local pthread_t ncclGroupThreads[MAX_ASYNC_OPS]; diff --git a/projects/rccl/src/include/alloc.h b/projects/rccl/src/include/alloc.h index 0dad1eb1ad..9307e048f7 100644 --- a/projects/rccl/src/include/alloc.h +++ b/projects/rccl/src/include/alloc.h @@ -12,6 +12,9 @@ #include "checks.h" #include "align.h" #include +#include +#include +#include template static ncclResult_t ncclCudaHostCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) { @@ -28,7 +31,7 @@ static inline ncclResult_t ncclCudaHostFree(void* ptr) { } template -static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, int line) { +static ncclResult_t ncclCalloc(T** ptr, size_t nelem) { void* p = malloc(nelem*sizeof(T)); if (p == NULL) { WARN("Failed to malloc %ld bytes", nelem*sizeof(T)); @@ -36,10 +39,8 @@ static ncclResult_t ncclCallocDebug(T** ptr, size_t nelem, const char *filefunc, } memset(p, 0, nelem*sizeof(T)); *ptr = (T*)p; - INFO(NCCL_ALLOC, "%s:%d Mem Alloc Size %ld pointer %p", filefunc, line, nelem*sizeof(T), *ptr); return ncclSuccess; } -#define ncclCalloc(...) 
ncclCallocDebug(__VA_ARGS__, __FILE__, __LINE__) struct __attribute__ ((aligned(64))) allocationTracker { union { diff --git a/projects/rccl/src/include/bootstrap.h b/projects/rccl/src/include/bootstrap.h index 2bbd97c7d2..6f3f02cdb4 100644 --- a/projects/rccl/src/include/bootstrap.h +++ b/projects/rccl/src/include/bootstrap.h @@ -17,7 +17,8 @@ ncclResult_t bootstrapInit(ncclUniqueId* id, int rank, int nranks, void** commSt ncclResult_t bootstrapAllGather(void* commState, void* allData, int size); ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); -ncclResult_t bootstrapBarrier(void* commState, int *ranks, int tag, int rank, int nranks); +ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag); +ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, int nranks, void* allData, int size); ncclResult_t bootstrapRemAlloc(size_t size, int rank, void* commState, int* id, hipIpcMemHandle_t* ipc, void** ptr); ncclResult_t bootstrapRemFree(int id, int rank, void* commState); ncclResult_t bootstrapClose(void* commState); diff --git a/projects/rccl/src/include/collectives.h b/projects/rccl/src/include/collectives.h index e4283b9db1..1d0c56cf54 100644 --- a/projects/rccl/src/include/collectives.h +++ b/projects/rccl/src/include/collectives.h @@ -8,78 +8,104 @@ #ifndef NCCL_COLLECTIVES_H_ #define NCCL_COLLECTIVES_H_ -#define FUNC_INDEX_P2P (NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumOps) -#define FUNC_INDEX(func, redop, ncclType, al, pr) ((((((func)*ncclNumOps + (redop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr)) +enum ncclDevRedOp_t { + ncclDevSum, ncclDevProd, ncclDevMax, ncclDevMin, + ncclDevPreMulSum, ncclDevSumPostDiv, + ncclNumDevRedOps +}; +struct ncclDevRedOpFull { + ncclDevRedOp_t op; + bool scalarArgIsPtr; + uint64_t 
scalarArg; +}; -#define NCCL_FUNC_NAME(func, algo, proto, redop, type) \ - ncclFunction_##func##_##algo##_##proto##_##redop##_##type +#define FUNC_INDEX_P2P (ncclNumTypes+NCCL_NUM_FUNCTIONS*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS*ncclNumTypes*ncclNumDevRedOps) +#define FUNC_INDEX(func, devredop, ncclType, al, pr) ((((((func)*ncclNumDevRedOps + (devredop))*ncclNumTypes) + (ncclType))*NCCL_NUM_ALGORITHMS+(al))*NCCL_NUM_PROTOCOLS+(pr)) -#define NCCL_KERN_NAME(func, algo, proto, redop, type) \ - ncclKernel_##func##_##algo##_##proto##_##redop##_##type +#define NCCL_FUNC_NAME(func, algo, proto, devredop, type) \ + ncclFunction_##func##_##algo##_##proto##_##devredop##_##type + +#define NCCL_ONERANK_REDUCE_NAME(devredop, type) \ + ncclFunction_OneRankReduce_##devredop##_##type + +#define NCCL_KERN_NAME(func, algo, proto, devredop, type) \ + ncclKernel_##func##_##algo##_##proto##_##devredop##_##type #define NCCL_IMPL_NAME(func, algo, proto) \ nccl##func##algo##proto /* Declare all collective operations */ -#define DECL5(func, algo, proto, redop, type) \ - extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, redop, type)(struct ncclWorkElem* args); \ - extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(struct ncclWorkElem first); \ +#define DECL5(func, algo, proto, devredop, type) \ + extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(struct ncclWorkElem* args); \ + extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(ncclWorkElem c); \ - //extern __device__ void NCCL_FUNC_NAME(func, algo, proto, redop, type)(); \ - //extern __global__ void NCCL_KERN_NAME(func, algo, proto, redop, type)(ncclWorkElem c); \ +#define CONCAT(a,b) a##b +#define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(t, f) +#define MACRO_IF_0(t, f) f +#define MACRO_IF_1(t, f) t -#define DECL4(func, algo, redop, type) \ - DECL5(func, algo, SIMPLE, redop, type) \ - DECL5(func, algo, LL, redop, 
type) \ - DECL5(func, algo, LL128, redop, type) +#define DECL4(func, algo, devredop, type, undef) \ + MACRO_IF(undef, /*undefined*/, DECL5(func, algo, SIMPLE, devredop, type)) \ + MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL, devredop, type)) \ + MACRO_IF(undef, /*undefined*/, DECL5(func, algo, LL128, devredop, type)) -#define DECL3(func, redop, type) \ - DECL4(func, RING, redop, type) \ - DECL4(func, TREE, redop, type) \ - DECL4(func, COLLNET, redop, type) +#define DECL3(func, devredop, type, undef) \ + DECL4(func, RING, devredop, type, undef) \ + DECL4(func, TREE, devredop, type, undef) \ + DECL4(func, COLLNET, devredop, type, undef) -#if defined(__CUDA_BF16_TYPES_EXIST__) -#define DECL2(func, redop) \ - DECL3(func, redop, int8_t) \ - DECL3(func, redop, uint8_t) \ - DECL3(func, redop, int32_t) \ - DECL3(func, redop, uint32_t) \ - DECL3(func, redop, int64_t) \ - DECL3(func, redop, uint64_t) \ - DECL3(func, redop, half) \ - DECL3(func, redop, float) \ - DECL3(func, redop, double) \ - DECL3(func, redop, __nv_bfloat16) +#if defined(RCCL_BFLOAT16) +#define DECL2(func, devredop, undefForFloat) \ + DECL3(func, devredop, int8_t, /*undef=*/0) \ + DECL3(func, devredop, uint8_t, /*undef=*/0) \ + DECL3(func, devredop, int32_t, /*undef=*/0) \ + DECL3(func, devredop, uint32_t, /*undef=*/0) \ + DECL3(func, devredop, int64_t, /*undef=*/0) \ + DECL3(func, devredop, uint64_t, /*undef=*/0) \ + DECL3(func, devredop, half, /*undef=*/undefForFloat) \ + DECL3(func, devredop, float, /*undef=*/undefForFloat) \ + DECL3(func, devredop, double, /*undef=*/undefForFloat) \ + DECL3(func, devredop, rccl_bfloat16, /*undef=*/undefForFloat) #else -#define DECL2(func, redop) \ - DECL3(func, redop, int8_t) \ - DECL3(func, redop, uint8_t) \ - DECL3(func, redop, int32_t) \ - DECL3(func, redop, uint32_t) \ - DECL3(func, redop, int64_t) \ - DECL3(func, redop, uint64_t) \ - DECL3(func, redop, half) \ - DECL3(func, redop, float) \ - DECL3(func, redop, double) \ - DECL3(func, redop, rccl_bfloat16) 
+#define DECL2(func, devredop, undefForFloat) \ + DECL3(func, devredop, int8_t, /*undef=*/0) \ + DECL3(func, devredop, uint8_t, /*undef=*/0) \ + DECL3(func, devredop, int32_t, /*undef=*/0) \ + DECL3(func, devredop, uint32_t, /*undef=*/0) \ + DECL3(func, devredop, int64_t, /*undef=*/0) \ + DECL3(func, devredop, uint64_t, /*undef=*/0) \ + DECL3(func, devredop, half, /*undef=*/undefForFloat) \ + DECL3(func, devredop, float, /*undef=*/undefForFloat) \ + DECL3(func, devredop, double, /*undef=*/undefForFloat) #endif #define DECL(func) \ - DECL2(func, Sum) \ - DECL2(func, Prod) \ - DECL2(func, Min) \ - DECL2(func, Max) \ - DECL2(func, Avg) \ + DECL2(func, Sum, /*undefForFloat=*/0) \ + DECL2(func, Prod, /*undefForFloat=*/0) \ + DECL2(func, Min, /*undefForFloat=*/0) \ + DECL2(func, Max, /*undefForFloat=*/0) \ + DECL2(func, PreMulSum, /*undefForFloat=*/0) \ + DECL2(func, SumPostDiv, /*undefForFloat=*/1) -#define DECL_ALL \ - DECL2(Broadcast, Sum) \ - DECL(Reduce) \ - DECL2(AllGather, Sum) \ - DECL(ReduceScatter) \ - DECL(AllReduce) \ - DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) \ +DECL2(Broadcast, Sum, /*undefForFloat=*/0) +DECL(Reduce) +DECL2(AllGather, Sum, /*undefForFloat=*/0) +DECL(ReduceScatter) +DECL(AllReduce) +DECL5(SendRecv, RING, SIMPLE, Sum, int8_t) -DECL_ALL +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, half)(struct ncclWorkElem* args); +#if defined(RCCL_BFLOAT16) +extern 
__device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16)(struct ncclWorkElem* args); +#endif +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, float)(struct ncclWorkElem* args); +extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(struct ncclWorkElem* args); // CHUNKSIZE must be a multiple of SLICESIZE //#define ALLREDUCE_SLICESTEPS (NCCL_STEPS/4) @@ -99,5 +125,6 @@ DECL_ALL #define REDUCE_SLICESTEPS 1 #define REDUCE_CHUNKSTEPS 1 #define SENDRECV_SLICEFACTOR 1 +#define NCCL_MAX_SLICE_PER_CHUNK 2 // max value for CHUNKSTEPS/SLICESTEPS, must accord with above #endif diff --git a/projects/rccl/src/include/comm.h b/projects/rccl/src/include/comm.h index a2bb62daba..5ccd0eba97 100644 --- a/projects/rccl/src/include/comm.h +++ b/projects/rccl/src/include/comm.h @@ -19,6 +19,8 @@ typedef void *cudaGraphNode_t; #define HIPRT_CB #else +#include "collectives.h" + #if CUDART_VERSION < 9000 struct cudaLaunchParams { void *func; @@ -40,13 +42,16 @@ struct cudaLaunchParams { #define NCCL_LL128_THREAD_THRESHOLD 8 #define NCCL_SIMPLE_THREAD_THRESHOLD 64 +#define NCCL_MAX_INTRA_RANKS 32 + struct ncclSendMem { union { struct { uint64_t head; char pad1[CACHE_LINE_SIZE-sizeof(uint64_t)]; void* ptrExchange; - char pad2[CACHE_LINE_SIZE-sizeof(void*)]; + uint64_t redOpArgExchange[2]; + char pad2[CACHE_LINE_SIZE-sizeof(void*)-2*sizeof(uint64_t)]; }; char pad3[MEM_ALIGN]; }; @@ -66,6 +71,28 @@ struct ncclRecvMem { char buff[1]; // Actually larger than that }; +typedef hipError_t(*pfn_cuMemGetAddressRange_t)(void**, size_t*, void*); + +enum helperThreadState {ThreadStart, ThreadStop}; + +#define NCCL_IPC_POOL_SIZE (2*NCCL_MAX_INTRA_RANKS*NCCL_MAX_OPS) + +struct ncclGraphHelperResources { + ncclComm* comm; + pthread_mutex_t threadLock; + pthread_cond_t threadCond; + enum helperThreadState threadState; + void* ipcBases[NCCL_IPC_POOL_SIZE]; + int ipcTail; + int ipcHead; +}; + +struct ncclUserRedOp { + int freeNext; // -1=allocated, otherwise index of next 
free entry in array + ncclDataType_t datatype; + ncclDevRedOpFull opFull; +}; + struct ncclComm { struct ncclChannel channels[MAXCHANNELS]; @@ -86,7 +113,12 @@ struct ncclComm { int node; int nNodes; + + // Intra-node rank info + int intraNodeGlobalRanks[NCCL_MAX_INTRA_RANKS]; int localRanks; + int intraNodeRank; + int8_t* rankToIntraNodeRank; enum { GROUP, PARALLEL, GROUP_GRAPH } launchMode; hipStream_t userStream; @@ -158,6 +190,7 @@ struct ncclComm { // Whether this communicator uses collNet int collNetSupport; + int intraHighestTransportType; // Store info of async operations struct ncclInfo* asyncOps; @@ -181,9 +214,38 @@ struct ncclComm { // Store info for cudaGraph int usingCudaGraph; // Only use it during capture time, not launch time struct ncclQueueInfo* enqueueInfo; + int nQueueInfoCreated; + int nQueueInfoDestroyed; cudaGraphNode_t lastSetupNode; unsigned long long lastCudaGraphId; int driverVersion; + pfn_cuMemGetAddressRange_t pfnCuMemGetAddressRange; + pthread_t graphHelperThread; + struct ncclGraphHelperResources* graphHelperResources; + int disableGraphHelper; + int graphRegister; + + // user-created reduction ops + int userRedOpCapacity, userRedOpFreeHead; + ncclUserRedOp *userRedOps; }; +// Scrambles the bits of non-builtin values of ncclRedOp_t according to the +// communicator memory address. Used to catch bugs so that integer handles +// associated with this communicator won't collide with handles of other +// communicators. This function is its own inverse. +static inline ncclRedOp_t ncclUserRedOpMangle(ncclComm *comm, ncclRedOp_t op) { + // Preserve the built-in values.
+ if(int(op) < int(ncclNumOps)) + return op; + uint64_t h = reinterpret_cast(comm); + h ^= h >> 32; + h *= 0x9e3779b97f4a7c13u; // Knuth's 64-bit magical hash constant + h >>= 32; // h is now an excellent 32-bit hash of the comm pointer + h &= int(ncclMaxRedOp); // ncclMaxRedOp is a power of 2 minus 1 + int op1 = int(h) ^ int(op); + // Since builtin values are preserved, we also have to preserve their preimage. + return op1 < int(ncclNumOps) ? op : ncclRedOp_t(op1); +} + #endif diff --git a/projects/rccl/src/include/core.h b/projects/rccl/src/include/core.h index 6d0a2762e9..c3aae4bf3c 100644 --- a/projects/rccl/src/include/core.h +++ b/projects/rccl/src/include/core.h @@ -1,10 +1,6 @@ /************************************************************************* -<<<<<<< HEAD - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. -======= * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. ->>>>>>> nccl/master + * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
* * See LICENSE.txt for license information ************************************************************************/ diff --git a/projects/rccl/src/include/debug.h b/projects/rccl/src/include/debug.h index e7a152cc97..6ce90ee375 100644 --- a/projects/rccl/src/include/debug.h +++ b/projects/rccl/src/include/debug.h @@ -7,17 +7,14 @@ #ifndef NCCL_DEBUG_H_ #define NCCL_DEBUG_H_ -#include "core.h" - +#include "nccl_net.h" #include #include #include #include #include -#include "nccl_net.h" - -#define gettid() (pid_t) syscall(SYS_gettid) +#include extern int ncclDebugLevel; extern uint64_t ncclDebugMask; diff --git a/projects/rccl/src/include/devcomm.h b/projects/rccl/src/include/devcomm.h index 08398a7e3a..e69a71c330 100644 --- a/projects/rccl/src/include/devcomm.h +++ b/projects/rccl/src/include/devcomm.h @@ -96,8 +96,11 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK #define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 2 #define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS) -#define NCCL_DIRECT_GPU 0x01 -#define NCCL_DIRECT_NIC 0x10 +#define NCCL_DIRECT_WRITE 0x01 +#define NCCL_DIRECT_READ 0x02 +#define NCCL_DIRECT_NIC 0x04 +#define NCCL_IPC_WRITE 0x08 +#define NCCL_IPC_READ 0x10 struct ncclConnInfo { // Regular comm mechanism @@ -108,6 +111,7 @@ struct ncclConnInfo { int direct; // Direct communication int shared; // Buffers are shared void **ptrExchange; // Pointer exchange for direct communication + uint64_t* redOpArgExchange; // PreOp scaler exchange for direct pull case int *sizesFifo; // Sizes fifo from GPU to proxy void* *ptrsFifo; // Buffer fifo from proxy to GPU @@ -175,7 +179,7 @@ struct ncclPeer { struct ncclDevComm; #pragma pack(push) /* push current alignment to stack */ -#pragma pack(8) /* set alignment to 4 bytes boundary */ +#pragma pack(8) /* set alignment to 8 bytes boundary */ #define NCCL_MAX_WORK_ELEMENTS 1 #define NCCL_MAX_GROUPS (NCCL_MAX_NTHREADS/WARP_SIZE) @@ -187,8 +191,9 @@ struct 
ncclWorkElem { struct ncclDevComm* comm; uint16_t nThreads; uint16_t funcIndex; - uint16_t index; - uint16_t active; + uint8_t regUsed; + uint8_t direct; + uint8_t active, redOpArgIsPtr; const void * sendbuff; void * recvbuff; @@ -198,10 +203,12 @@ struct ncclWorkElem { struct { size_t count; size_t lastChunkSize; - uint32_t root; + uint64_t redOpArg; + uint16_t root; uint8_t bid; uint8_t nChannels; - uint8_t connIndex; + uint16_t connIndex; + uint16_t opCount; } coll; struct { size_t sendCount; @@ -217,18 +224,16 @@ struct ncclWorkElem { }; uint16_t padding; }; - } p2p; - struct { - uint16_t padding[15]; uint16_t opCount; - } op; + } p2p; // [RCCL] Clique-based arguments // NOTE: Follows same field structure as coll // because nChannels is accessed from "coll" struct. struct { size_t count; cliqueDevicePtrs_t* ptrs; - uint32_t unused; + uint64_t unused_1; + uint16_t unused_2; uint8_t bid; uint8_t nChannels; } clique; @@ -236,11 +241,24 @@ struct ncclWorkElem { uint64_t align[4]; }; }; -struct ncclWork { - struct ncclWorkElem elems[NCCL_MAX_WORK_ELEMENTS]; -}; static_assert(sizeof(struct ncclWorkElem) == (0x10*sizeof(int)), "ncclWorkElem must have a pow2 size"); +struct ncclWorkRegElem { + struct ncclWorkElem elem; + void* dnInputs[NCCL_MAX_DIRECT_ARITY+1]; + void* dnOutputs[NCCL_MAX_DIRECT_ARITY+1]; + void* upOutputs[NCCL_MAX_DIRECT_ARITY+1]; +}; +#define NCCL_REG_ELEM_FACTOR 4 +static_assert(sizeof(struct ncclWorkRegElem) == (NCCL_REG_ELEM_FACTOR*sizeof(struct ncclWorkElem)), "ncclWorkRegElem size must be pow2 times ncclWorkElem size"); + +struct ncclWork { + union { + struct ncclWorkElem elems[NCCL_MAX_WORK_ELEMENTS]; + struct ncclWorkRegElem regElems[NCCL_MAX_WORK_ELEMENTS/NCCL_REG_ELEM_FACTOR]; + }; +}; + struct ncclChannel { union { struct { @@ -330,6 +348,7 @@ struct ncclProf { typedef enum { ncclCollTraceNotReady, ncclCollTraceKernelLaunchType, + ncclCollTraceKernelEndType, ncclCollTraceCollEndType, ncclCollTraceAbortType, ncclCollTraceDataType diff --git 
a/projects/rccl/src/include/enqueue.h b/projects/rccl/src/include/enqueue.h index 6cf0415a29..b059879929 100644 --- a/projects/rccl/src/include/enqueue.h +++ b/projects/rccl/src/include/enqueue.h @@ -1,5 +1,6 @@ /************************************************************************* * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ @@ -30,10 +31,19 @@ void HIPRT_CB ncclEnqueueHostSetup(void* arg); ncclResult_t ncclGetCudaGraph(ncclComm_t comm, cudaGraph_t* graph); ncclResult_t ncclCudaGraphHostSetup(ncclComm_t comm, cudaGraph_t graph); +struct ncclBuffRegInfo { + void* sendbuffsBase[NCCL_MAX_INTRA_RANKS]; + void* recvbuffsBase[NCCL_MAX_INTRA_RANKS]; + void* sendbuffs[NCCL_MAX_INTRA_RANKS]; + void* recvbuffs[NCCL_MAX_INTRA_RANKS]; + int nBuffs; +}; + // Enqueue information (for kernel and proxy) for each operation struct ncclQueueElem { struct ncclWorkElem work; struct ncclProxyArgs proxyArgs; + struct ncclBuffRegInfo buffRegInfo; }; typedef ncclRecyclableList ncclQueueElemList; @@ -43,6 +53,7 @@ struct ncclQueueInfo { ncclComm_t comm; int maxChannels; // Dynamic version of gridDim ncclResult_t ret; // Return value of host setup call + int nRegBuffs; ncclQueueElemList* elemList; }; @@ -50,6 +61,7 @@ static ncclResult_t ncclCreateQueueInfo(struct ncclQueueInfo** eqInfo, ncclComm_ NCCLCHECK(ncclCalloc(eqInfo, 1)); (*eqInfo)->comm = comm; (*eqInfo)->elemList = new ncclQueueElemList(); + (*eqInfo)->comm->nQueueInfoCreated++; return ncclSuccess; } @@ -58,6 +70,7 @@ static ncclResult_t ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) { if (eqInfo == NULL) return ncclInternalError; eqInfo->maxChannels = 0; eqInfo->ret = ncclSuccess; + eqInfo->nRegBuffs = 0; eqInfo->elemList->recycle(); return ncclSuccess; } @@ -67,7 +80,54 @@ static ncclResult_t 
ncclResetQueueInfo(struct ncclQueueInfo* eqInfo) { static void ncclDestroyQueueInfo(void* ptr) { if (ptr == NULL) return; struct ncclQueueInfo* eqInfo = (struct ncclQueueInfo*)ptr; + struct ncclComm* comm = eqInfo->comm; + // Close IPC mem handles for registered buffers + struct ncclQueueElem* eqElem = eqInfo->elemList->begin(); +#if 0 + // Ideally, the deregistration should happen here + // but currently the destroy function of CUDA objects does not allow CUDA API calls + while (eqElem != NULL) { + for (int i=0; ibuffRegInfo.nBuffs; i++) { + if (i == eqInfo->comm->intraNodeRank) continue; + CUDACHECKIGNORE(cudaIpcCloseMemHandle(eqElem->buffRegInfo.sendbuffsBase[i])); + CUDACHECKIGNORE(cudaIpcCloseMemHandle(eqElem->buffRegInfo.recvbuffsBase[i])); + } + eqElem = eqInfo->elemList->getNext(); + } +#else + // Instead, we push these pointers to a pool owned by ncclComm + // and ask a helper thread to close mem handles + struct ncclGraphHelperResources* res = comm->graphHelperResources; + int ipcTailOld = 0; + if (res == NULL || (!comm->graphHelperThread) || eqInfo->nRegBuffs == 0) goto skip; + + pthread_mutex_lock(&res->threadLock); + ipcTailOld = res->ipcTail; + while (eqElem != NULL) { + for (int i=0; ibuffRegInfo.nBuffs; i++) { + if (eqElem->buffRegInfo.sendbuffsBase[i] != NULL) { + res->ipcBases[res->ipcTail] = eqElem->buffRegInfo.sendbuffsBase[i]; + res->ipcTail = (res->ipcTail+1)%NCCL_IPC_POOL_SIZE; + } + if (eqElem->buffRegInfo.recvbuffsBase[i] != NULL) { + res->ipcBases[res->ipcTail] = eqElem->buffRegInfo.recvbuffsBase[i]; + res->ipcTail = (res->ipcTail+1)%NCCL_IPC_POOL_SIZE; + } + } + eqElem = eqInfo->elemList->getNext(); + } + if (res->ipcTail != ipcTailOld) { + res->threadState = ThreadStart; + TRACE(NCCL_COLL, "CUDA Graph destroy function signaling helper thread with %d IPC handles", res->ipcTail-ipcTailOld); + pthread_cond_signal(&res->threadCond); + } + pthread_mutex_unlock(&res->threadLock); +#endif + +skip: delete eqInfo->elemList; free(eqInfo); +
comm->nQueueInfoDestroyed++; + return; } #endif // End include guard diff --git a/projects/rccl/src/include/gdrwrap.h b/projects/rccl/src/include/gdrwrap.h index e3d923661f..761b9cf5e7 100644 --- a/projects/rccl/src/include/gdrwrap.h +++ b/projects/rccl/src/include/gdrwrap.h @@ -11,6 +11,7 @@ #include "nccl.h" #include // for standard [u]intX_t types #include +#include // These can be used if the GDR library isn't thread safe #include diff --git a/projects/rccl/src/include/info.h b/projects/rccl/src/include/info.h index d35eb99b8e..08a80f69e7 100644 --- a/projects/rccl/src/include/info.h +++ b/projects/rccl/src/include/info.h @@ -10,6 +10,7 @@ #include "nccl.h" #include "devcomm.h" +#include "collectives.h" typedef enum { ncclPatternRing, @@ -39,6 +40,7 @@ struct ncclInfo { int chunkSteps; int sliceSteps; // Computed later + ncclDevRedOpFull opFull; int algorithm; int protocol; ncclPattern_t pattern; diff --git a/projects/rccl/src/include/param.h b/projects/rccl/src/include/param.h index 6303ef8959..ca992d71c8 100644 --- a/projects/rccl/src/include/param.h +++ b/projects/rccl/src/include/param.h @@ -8,6 +8,7 @@ #ifndef NCCL_PARAM_H_ #define NCCL_PARAM_H_ +#include #include #include #include diff --git a/projects/rccl/src/include/transport.h b/projects/rccl/src/include/transport.h index 115bdc50f1..e64dfbf748 100644 --- a/projects/rccl/src/include/transport.h +++ b/projects/rccl/src/include/transport.h @@ -55,7 +55,7 @@ struct ncclTransport { }; ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, struct ncclChannel* channel, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex); -ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex); +ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL); enum { collNetRecv=0, collNetSend=1 }; int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, 
struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type); diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index d90d8ea53a..7dfa0dbc19 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -50,7 +50,7 @@ std::chrono::high_resolution_clock::time_point ncclEpoch; const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+1] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv" }; const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNet" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; -const char* ncclRedOpStr[ncclNumOps] = { "Sum", "Prod", "Max", "Min" }; +const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" }; const char *ncclTypeStr[ncclNumTypes] = {"_i8", "_u8", "_i32", "_u32", "_i64", "_u64", "_f16", "_f32", "_f64", "_b16"}; NCCL_PARAM(GroupCudaStream, "GROUP_CUDA_STREAM", NCCL_GROUP_CUDA_STREAM); @@ -80,13 +80,21 @@ ncclResult_t initCollNet(ncclCollNet_t* collnet) { } ncclResult_t initNetPlugin(ncclNet_t** net, ncclCollNet_t** collnet) { - void* netPluginLib = dlopen("librccl-net.so", RTLD_NOW | RTLD_LOCAL); + char ncclNetPluginName[128]; + const char* envPluginName = getenv("NCCL_NET_PLUGIN"); + if (envPluginName && strlen(envPluginName)) { + snprintf(ncclNetPluginName, 128, "librccl-net-%s.so", envPluginName); + INFO(NCCL_INIT, "Plugin name set by env to %s\n", ncclNetPluginName); + } else { + sprintf(ncclNetPluginName, "librccl-net.so"); + } + void* netPluginLib = dlopen(ncclNetPluginName, RTLD_NOW | RTLD_LOCAL); if (netPluginLib == NULL) { // dlopen does not guarantee to set errno, but dlerror only gives us a // string, so checking errno doesn't hurt to try to provide a better // error message if (errno == ENOENT) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found (librccl-net.so), using internal implementation"); + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No 
plugin found (%s), using internal implementation", ncclNetPluginName); } else { INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load returned %d : %s.", errno, dlerror()); } @@ -199,23 +207,27 @@ RCCL_PARAM(KernelCollTraceEnable, "KERNEL_COLL_TRACE_ENABLE", 0); void *ncclCommThreadMain(void *arg) { ncclComm_t comm = (ncclComm_t)arg; int head = comm->hostDevComm.collTraceHead; - #define MAX_NAME_LENGTH 32 + #define MAX_NAME_LENGTH 64 char* func_names = (char *)malloc(MAX_NAME_LENGTH*(FUNC_INDEX_P2P+1)); for (int func = 0; func < NCCL_NUM_FUNCTIONS; func++) { for (int al = 0; al < NCCL_NUM_ALGORITHMS; al++) { for (int type = 0; type < ncclNumTypes; type++) { for (int pr = 0; pr < NCCL_NUM_PROTOCOLS; pr++) { - for (int redop = 0; redop < ncclNumOps; redop++) { - char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, redop, type, al, pr); + for (int devredop = 0; devredop < ncclNumDevRedOps; devredop++) { + char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX(func, devredop, type, al, pr); sprintf(line, "%s%s%s%s%s", ncclFuncStr[func], ncclAlgoStr[al], ncclProtoStr[pr], - ncclRedOpStr[redop], ncclTypeStr[type]); + ncclDevRedOpStr[devredop], ncclTypeStr[type]); } } } } } + for (int type = 0; type < ncclNumTypes; type++) { + char* line = func_names+MAX_NAME_LENGTH*(FUNC_INDEX_P2P-ncclNumTypes+type); + sprintf(line, "OneRankReducePreMulSum%s", ncclTypeStr[type]); + } char* line = func_names+MAX_NAME_LENGTH*FUNC_INDEX_P2P; - sprintf(line, "%s", ncclFuncStr[NCCL_NUM_FUNCTIONS]); + sprintf(line, "SendRecvRingSimpleSum_i8"); do { int tail = LOAD(comm->hostDevComm.collTraceTail)%COLLTRACE_NUM_ITEMS; int count; @@ -246,7 +258,7 @@ void *ncclCommThreadMain(void *arg) { fIdx, td->data_0, td->opCount, td->data_1); } else { sprintf(line, "## [%12.6f] [%02d:%02d] %06lx", - (double)(td->timeStamp)/VEGA_GPU_RTC_FREQUENCY, comm->rank, td->bid, td->opCount); + (double)(td->timeStamp)/VEGA_GPU_RTC_FREQUENCY, comm->rank, td->bid, fIdx == FUNC_INDEX_P2P ? 
(td->opCount + 0x100000): td->opCount); offset = strlen(line); switch (type) { case ncclCollTraceKernelLaunchType: @@ -261,18 +273,17 @@ void *ncclCommThreadMain(void *arg) { sprintf(line+offset, "nt %d bi %d nc %d busId %lx nRanks %d", td->coll.nThreads, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks); break; case ncclCollTraceCollEndType: - if (fIdx != 0xffff) { - sprintf(line+offset, " CE %s ", func_names+MAX_NAME_LENGTH*fIdx); - offset = strlen(line); - if (fIdx > FUNC_INDEX_P2P) - sprintf(line+offset, "ERROR bad function index %d", fIdx); - else if (fIdx == FUNC_INDEX_P2P) - sprintf(line+offset, "nt %d dt %d busId %lx nRanks %d", td->p2p.nThreads, td->p2p.delta, comm->busId, comm->nRanks); - else - sprintf(line+offset, "nt %d bi %d nc %d busId %lx nRanks %d", td->coll.nThreads, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks); - } + sprintf(line+offset, " CE %s ", func_names+MAX_NAME_LENGTH*fIdx); + offset = strlen(line); + if (fIdx > FUNC_INDEX_P2P) + sprintf(line+offset, "ERROR bad function index %d", fIdx); + else if (fIdx == FUNC_INDEX_P2P) + sprintf(line+offset, "nt %d dt %d busId %lx nRanks %d", td->p2p.nThreads, td->p2p.delta, comm->busId, comm->nRanks); else - sprintf(line+offset, " KE busId %lx nRanks %d", comm->busId, comm->nRanks); + sprintf(line+offset, "nt %d bi %d nc %d busId %lx nRanks %d", td->coll.nThreads, td->coll.bid, td->coll.nChannels, comm->busId, comm->nRanks); + break; + case ncclCollTraceKernelEndType: + sprintf(line+offset, " KE busId %lx nRanks %d", comm->busId, comm->nRanks); break; case ncclCollTraceAbortType: sprintf(line+offset, " Abort"); @@ -299,6 +310,9 @@ void *ncclCommThreadMain(void *arg) { static ncclResult_t commFree(ncclComm_t comm) { if (comm == NULL) return ncclSuccess; + + delete[] comm->userRedOps; + free(comm->connectSend); free(comm->connectRecv); for (int peer=0; peernRanks; peer++) { @@ -442,8 +456,6 @@ static ncclResult_t commFree(ncclComm_t comm) { 
CUDACHECK(hipStreamDestroy(comm->groupStream)); } - ncclDestroyQueueInfo(comm->enqueueInfo); - // Last rank frees shared resources between threads int isLast; NCCLCHECK(ncclCpuBarrierIn(comm, &isLast)); @@ -466,6 +478,8 @@ static ncclResult_t commFree(ncclComm_t comm) { RCCL_PARAM(CliqueIgnoreTopo, "CLIQUE_IGNORE_TOPO", 0); RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0); NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2); +NCCL_PARAM(DisableGraphHelper, "GRAPH_HELPER_DISABLE", 0); +NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 0); static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { if (ndev < 1) { @@ -509,7 +523,7 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { *comm->abortFlag = 0; comm->collOpCount = 0; - comm->p2pOpCount = 0x8000; + comm->p2pOpCount = 0; comm->argsptr = &comm->args; #ifdef ENABLE_PROFILING @@ -539,11 +553,20 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank) { comm->asyncAllocMode = ncclComm::SHORTEST_QUEUE; } + CUDACHECK(hipDriverGetVersion(&comm->driverVersion)); + NCCLCHECK(ncclCreateQueueInfo(&comm->enqueueInfo, comm)); comm->lastSetupNode = NULL; comm->lastCudaGraphId = -1; - - CUDACHECK(hipDriverGetVersion(&comm->driverVersion)); + comm->disableGraphHelper = ncclParamDisableGraphHelper(); + comm->graphRegister = ncclParamGraphRegister(); +#if CUDART_VERSION >= 11030 + NCCLCHECK(ncclCalloc(&comm->graphHelperResources, 1)); + comm->graphHelperResources->comm = comm; + if (comm->driverVersion >= 11030) + // hipGetDriverEntryPoint requires R465 or above (enhanced compat need) + CUDACHECK(hipGetDriverEntryPoint("cuMemGetAddressRange", (void**)&comm->pfnCuMemGetAddressRange, hipEnableDefault)); +#endif static_assert(MAXCHANNELS <= sizeof(*comm->connectSend)*8, "comm->connectSend must have enough bits for all channels"); static_assert(MAXCHANNELS <= sizeof(*comm->connectRecv)*8, "comm->connectRecv must have enough bits for all channels"); @@ -554,6 +577,10 @@ static ncclResult_t 
commAlloc(ncclComm_t* comret, int ndev, int rank) { NCCLCHECK(ncclCalloc(&comm->p2pSends, comm->nRanks)); NCCLCHECK(ncclCalloc(&comm->p2pRecvs, comm->nRanks)); + // Create a map between global rank and intra-node rank + NCCLCHECK(ncclCalloc(&comm->rankToIntraNodeRank, comm->nRanks)); + memset(comm->rankToIntraNodeRank, -1, comm->nRanks*sizeof(comm->rankToIntraNodeRank[0])); + // Mark channels as non initialized. for (int c=0; cchannels[c].id = -1; @@ -786,13 +813,14 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm int intraNodeRank0 = -1, intraNodeRank = -1, intraNodeRanks = 0; int myCompCap = allGather1Data[rank].cudaCompCap; int minCompCap = myCompCap, maxCompCap = myCompCap; - int intraNodeGlobalRanks[256]; for (int i = 0; i < nranks; i++) { if (allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) { // Rank is on same node if (intraNodeRanks == 0) intraNodeRank0 = i; if (i == rank) intraNodeRank = intraNodeRanks; - intraNodeGlobalRanks[intraNodeRanks++] = i; + comm->intraNodeGlobalRanks[intraNodeRanks] = i; + comm->rankToIntraNodeRank[i] = intraNodeRanks; + intraNodeRanks++; if (allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash) { // Rank is in same process if (intraProcRanks == 0) intraProcRank0 = i; @@ -821,6 +849,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } struct ncclComm* intraProcRank0Comm = allGather1Data[intraProcRank0].comm; uint64_t intraNodeRank0pidHash = allGather1Data[intraNodeRank0].peerInfo.pidHash; + comm->intraNodeRank = intraNodeRank; // AllGather1 - end @@ -1138,6 +1167,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Check if we can setup CollNet if (comm->collNetSupport > 0) { int collNetSetupFail = 0; + int highestTypes[NCCL_MAX_INTRA_RANKS] = {TRANSPORT_P2P}; // Find all head ranks int nHeads = collNetGraph.nChannels; int *heads; @@ -1163,16 +1193,26 @@ static ncclResult_t 
initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm TRACE(NCCL_INIT, "rank %d Connected inter-node CollNet", rank); // Connect intra-node CollNet + int highestTransportType0, highestTransportType1; for (int c=0; cnChannels; c++) { struct ncclChannel* channelRecv = comm->channels+c; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channelRecv, NCCL_MAX_DIRECT_ARITY, channelRecv->collTree.up, NCCL_MAX_DIRECT_ARITY, channelRecv->collTree.down, 0), ret, collnet_cleanup); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0), ret, collnet_cleanup); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 0, &highestTransportType0), ret, collnet_cleanup); for (int c=0; cnChannels; c++) { struct ncclChannel* channelSend = comm->channels+c; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, channelSend, NCCL_MAX_DIRECT_ARITY, channelSend->collTree.down, NCCL_MAX_DIRECT_ARITY, channelSend->collTree.up, 1), ret, collnet_cleanup); } - NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1), ret, collnet_cleanup); + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &collNetGraph, 1, &highestTransportType1), ret, collnet_cleanup); + + // Exchange highest intra-node transport type among ranks + // because we need to know whether all ranks can p2p each other to determine whether we can directly read/write registered user buffer + comm->intraHighestTransportType = highestTypes[comm->intraNodeRank] = highestTransportType0 > highestTransportType1 ? 
highestTransportType0 : highestTransportType1; + NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->intraNodeGlobalRanks, comm->intraNodeRank, comm->localRanks, highestTypes, sizeof(int))); + for (int i=0; ilocalRanks; i++) { + if (highestTypes[i] > comm->intraHighestTransportType) + comm->intraHighestTransportType = highestTypes[i]; + } INFO(NCCL_INIT, "rank %d Connected CollNet comm %p nRanks %02d", rank, comm, comm->nRanks); collnet_cleanup: @@ -1220,7 +1260,7 @@ collnet_cleanup: NCCLCHECK(ncclCommSetIntraProc(comm, intraProcRank, intraProcRanks, intraProcRank0Comm)); /* Local intra-node barrier */ - NCCLCHECK(bootstrapBarrier(comm->bootstrap, intraNodeGlobalRanks, (int)intraNodeRank0pidHash, intraNodeRank, intraNodeRanks)); + NCCLCHECK(bootstrapBarrier(comm->bootstrap, comm->intraNodeGlobalRanks, intraNodeRank, intraNodeRanks, (int)intraNodeRank0pidHash)); if (comm->nNodes) NCCLCHECK(ncclProxyCreate(comm)); @@ -1321,6 +1361,22 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { return ncclSuccess; } +static ncclResult_t ncclGraphHelperDestroy(ncclComm* comm) { + auto res = comm->graphHelperResources; + if (comm->graphHelperThread && res) { + pthread_mutex_lock(&res->threadLock); + res->threadState = ThreadStop; + pthread_cond_signal(&res->threadCond); + pthread_mutex_unlock(&res->threadLock); + pthread_join(comm->graphHelperThread, NULL); + } + if (res) { + free(res); + res = NULL; + } + return ncclSuccess; +} + static ncclResult_t commDestroy(ncclComm_t comm) { int savedDevice; #ifdef ENABLE_TRACE @@ -1337,6 +1393,11 @@ static ncclResult_t commDestroy(ncclComm_t comm) { CUDACHECK(hipStreamSynchronize(comm->groupStream)); NCCLCHECK(ncclProxyDestroy(comm)); + ncclDestroyQueueInfo(comm->enqueueInfo); +#if CUDART_VERSION >= 11030 + NCCLCHECK(ncclGraphHelperDestroy(comm)); +#endif + INFO(NCCL_COLL, "Created %d queue info, destroyed %d", comm->nQueueInfoCreated, comm->nQueueInfoDestroyed); NCCLCHECK(commFree(comm)); if 
(savedDevice != commDevice) diff --git a/projects/rccl/src/misc/argcheck.cc b/projects/rccl/src/misc/argcheck.cc index 285720f87b..d1aabec471 100644 --- a/projects/rccl/src/misc/argcheck.cc +++ b/projects/rccl/src/misc/argcheck.cc @@ -52,10 +52,16 @@ ncclResult_t ArgsCheck(struct ncclInfo* info) { } if (info->coll == ncclFuncAllGather || info->coll == ncclFuncReduceScatter) info->nBytes *= info->comm->nRanks; // count is per rank - if (info->op < 0 || info->op >= ncclNumOps) { + if (info->op < 0 || ncclMaxRedOp < info->op) { WARN("%s : invalid reduction operation %d", info->opName, info->op); return ncclInvalidArgument; } + int opIx = int(ncclUserRedOpMangle(info->comm, info->op)) - int(ncclNumOps); + if (ncclNumOps <= info->op && + (info->comm->userRedOpCapacity <= opIx || info->comm->userRedOps[opIx].freeNext != -1)) { + WARN("%s : reduction operation %d unknown to this communicator", info->opName, info->op); + return ncclInvalidArgument; + } if (info->comm->checkPointers) { if (info->coll == ncclFuncSendRecv) { diff --git a/projects/rccl/src/misc/ibvwrap.cc b/projects/rccl/src/misc/ibvwrap.cc index f47c141bc1..439712e88f 100644 --- a/projects/rccl/src/misc/ibvwrap.cc +++ b/projects/rccl/src/misc/ibvwrap.cc @@ -60,7 +60,7 @@ ncclResult_t wrap_ibv_symbols(void) { if (!ibvhandle) { ibvhandle=dlopen("libibverbs.so.1", RTLD_NOW); if (!ibvhandle) { - WARN("Failed to open libibverbs.so[.1]"); + INFO(NCCL_INIT, "Failed to open libibverbs.so[.1]"); goto teardown; } } diff --git a/projects/rccl/src/nccl.h.in b/projects/rccl/src/nccl.h.in index 3703f6e3b3..c02038177e 100644 --- a/projects/rccl/src/nccl.h.in +++ b/projects/rccl/src/nccl.h.in @@ -17,7 +17,7 @@ #define NCCL_SUFFIX "${NCCL_SUFFIX}" #define NCCL_VERSION_CODE ${NCCL_VERSION} -#define NCCL_VERSION(X,Y,Z) (((X) >= 2 && (Y) >= 9) ? (X) * 10000 + (Y) * 100 + (Z) : (X) * 1000 + (Y) * 100 + (Z)) +#define NCCL_VERSION(X,Y,Z) (((X) <= 2 && (Y) <= 8) ? 
(X) * 1000 + (Y) * 100 + (Z) : (X) * 10000 + (Y) * 100 + (Z)) #define RCCL_BFLOAT16 1 #define RCCL_GATHER_SCATTER 1 @@ -141,12 +141,24 @@ ncclResult_t pncclCommUserRank(const ncclComm_t comm, int* rank); /// @endcond /*! @brief Reduction operation selector */ +/* Reduction operation selector */ +typedef enum { ncclNumOps_dummy = 5 } ncclRedOp_dummy_t; typedef enum { ncclSum = 0, ncclProd = 1, ncclMax = 2, ncclMin = 3, ncclAvg = 4, - ncclNumOps = 5 } ncclRedOp_t; + /* ncclNumOps: The number of built-in ncclRedOp_t values. Also + * serves as the least possible value for dynamic ncclRedOp_t's + * as constructed by ncclRedOpCreate*** functions. */ + ncclNumOps = 5, + /* ncclMaxRedOp: The largest valid value for ncclRedOp_t. + * It is defined to be the largest signed value (since compilers + * are permitted to use signed enums) that won't grow + * sizeof(ncclRedOp_t) when compared to previous NCCL versions to + * maintain ABI compatibility. */ + ncclMaxRedOp = 0x7fffffff>>(32-8*sizeof(ncclRedOp_dummy_t)) + } ncclRedOp_t; /*! @brief Data types */ typedef enum { ncclInt8 = 0, ncclChar = 0, @@ -161,6 +173,40 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclBfloat16 = 9, ncclNumTypes = 10 } ncclDataType_t; +/* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */ +typedef enum { + /* ncclScalarDevice: The scalar is in device-visible memory and will be + * dereferenced while the collective is running. */ + ncclScalarDevice = 0, + + /* ncclScalarHostImmediate: The scalar is in host-visible memory and will be + * dereferenced before the ncclRedOpCreate***() function returns. */ + ncclScalarHostImmediate = 1 +} ncclScalarResidence_t; + +/* + * ncclRedOpCreatePreMulSum + * + * Creates a new reduction operator which pre-multiplies input values by a given + * scalar locally before reducing them with peer values via summation. For use + * only with collectives launched against *comm* and *datatype*. 
The + * *residence* argument indicates how/when the memory pointed to by *scalar* + * will be dereferenced. Upon return, the newly created operator's handle + * is stored in *op*. + */ +ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); +ncclResult_t pncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); + +/* + * ncclRedOpDestroy + * + * Destroys the reduction operator *op*. The operator must have been created by + * ncclRedOpCreatePreMul with the matching communicator *comm*. An operator may be + * destroyed as soon as the last NCCL function which is given that operator returns. + */ +ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm); +ncclResult_t pncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm); + /* * Collective communication operations * diff --git a/projects/rccl/src/transport.cc b/projects/rccl/src/transport.cc index ae47f4f329..e66a6f25ef 100644 --- a/projects/rccl/src/transport.cc +++ b/projects/rccl/src/transport.cc @@ -20,7 +20,7 @@ struct ncclTransport ncclTransports[NTRANSPORTS] = { }; template -static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex) { +static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* graph, struct ncclConnect* connect, int channelId, int peer, int connIndex, int* transportType) { struct ncclPeerInfo* myInfo = comm->peerInfo+comm->rank; struct ncclPeerInfo* peerInfo = comm->peerInfo+peer; struct ncclConnector* connector = (type == 1) ? 
comm->channels[channelId].peers[peer].send + connIndex : @@ -45,6 +45,7 @@ static ncclResult_t selectTransport(struct ncclComm* comm, struct ncclTopoGraph* if (ret) { connector->transportComm = transportComm; NCCLCHECK(transportComm->setup(comm, graph, myInfo, peerInfo, connect, connector, channelId, connIndex)); + if (transportType) *transportType = t; return ncclSuccess; } } @@ -77,12 +78,14 @@ void dumpData(struct ncclConnect* data, int ndata) { } } -ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex) { +ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType/*=NULL*/) { #if CUDART_VERSION >= 11030 // Stream used during transport setup; need for P2P pre-connect + CUDA Graph hipStream_t transportSetupStream; CUDACHECK(hipStreamCreateWithFlags(&transportSetupStream, hipStreamNonBlocking)); #endif + int highestType = TRANSPORT_P2P; // track highest transport type + struct ncclConnect data[2*MAXCHANNELS]; for (int i=1; inRanks; i++) { int bootstrapTag = (i<<8) + (graph ? 
graph->id+1 : 0); @@ -93,15 +96,18 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* struct ncclConnect* recvData = data; int sendChannels = 0, recvChannels = 0; + int type; for (int c=0; c(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex)); + NCCLCHECK(selectTransport<0>(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type)); + if (type > highestType) highestType = type; } } struct ncclConnect* sendData = recvData+recvChannels; for (int c=0; c(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex)); + NCCLCHECK(selectTransport<1>(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type)); + if (type > highestType) highestType = type; } } @@ -149,6 +155,7 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* CUDACHECK(hipStreamSynchronize(transportSetupStream)); CUDACHECK(hipStreamDestroy(transportSetupStream)); #endif + if (highestTransportType != NULL) *highestTransportType = highestType; return ncclSuccess; } @@ -242,22 +249,18 @@ cleanup: } ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail) { - int rank = comm->rank; - int nranks = comm->nRanks; // AllGather collNet setup results - int* allGatherFailures; - NCCLCHECK(ncclCalloc(&allGatherFailures, nranks)); - allGatherFailures[rank] = collNetSetupFail; - NCCLCHECK(bootstrapAllGather(comm->bootstrap, allGatherFailures, sizeof(int))); - for (int i=0; iintraNodeRank] = collNetSetupFail; + NCCLCHECK(bootstrapIntraNodeAllGather(comm->bootstrap, comm->intraNodeGlobalRanks, comm->intraNodeRank, comm->localRanks, allGatherFailures, sizeof(int))); + for (int i=0; ilocalRanks; i++) { if (allGatherFailures[i] != 0) { collNetSetupFail = 1; break; } } - free(allGatherFailures); if (collNetSetupFail) { - if (rank == 0) WARN("Cannot initialize CollNet, using point-to-point network instead"); + if (comm->intraNodeRank == 0) WARN("Cannot initialize CollNet, using point-to-point network 
instead"); return ncclSystemError; } return ncclSuccess; diff --git a/projects/rccl/src/transport/coll_net.cc b/projects/rccl/src/transport/coll_net.cc index 46b3d6c213..9f9d9b5dd1 100644 --- a/projects/rccl/src/transport/coll_net.cc +++ b/projects/rccl/src/transport/coll_net.cc @@ -473,7 +473,7 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) { int startChannel = group*COLLNET_GROUP_NSUBS; NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &ptr)); reqFifo[group][buffSlot].recvBuff = ptr; - TRACE(NCCL_NET, "recvProxy [%d/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff); + TRACE(NCCL_NET, "recvProxy [%lu/%d/%d] posted buffer %p", sub->posted, group, buffSlot, reqFifo[group][buffSlot].recvBuff); sub->posted += args->sliceSteps; args->idle = 0; continue; @@ -518,7 +518,7 @@ ncclResult_t collNetRecvProxy(struct ncclProxyArgs* args) { int sharedBuffSlot = sub->transmitted%NCCL_STEPS; int startChannel = group*COLLNET_GROUP_NSUBS; char* groupRecvAddress; - NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, 1, 1, sharedBuffSlot, startChannel, &groupRecvAddress)); + NCCLCHECK(ncclProxySharedBuffersGetCollNet(sub->connector->comm, p == NCCL_PROTO_SIMPLE ? 
resources->useGdr : 0, 1, sharedBuffSlot, startChannel, &groupRecvAddress)); char* ptr = groupRecvAddress + (s%COLLNET_GROUP_NSUBS)*args->sharedSize[sharedBuffSlot]; if (p == NCCL_PROTO_SIMPLE) { volatile void** ptrsFifo = (volatile void**)resources->recvMem->ptrsFifo; diff --git a/projects/rccl/src/transport/net_ib.cc b/projects/rccl/src/transport/net_ib.cc index 763d44ec39..2825a967e5 100644 --- a/projects/rccl/src/transport/net_ib.cc +++ b/projects/rccl/src/transport/net_ib.cc @@ -89,6 +89,8 @@ static ncclResult_t ncclIbGetPciPath(char* devName, char** path, int* realPort) } else { // Merge multi-port NICs into the same PCI device p[strlen(p)-1] = '0'; + // Also merge virtual functions (VF) into the same device + p[strlen(p)-3] = '0'; // And keep the real port aside (the ibv port is always 1 on recent cards) *realPort = 0; for (int d=0; dbusId); + *ret = 0; + } + CUDACHECK(cudaFree(dummy)); + return ncclSuccess; + } +#endif + if (p2p == 0) { INFO(NCCL_INIT|NCCL_P2P,"Could not enable P2P between dev %d(=%lx) and dev %d(=%lx)", cudaDev1, info1->busId, cudaDev2, info2->busId); @@ -195,13 +215,14 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st NCCLCHECK(ncclCudaCalloc((char**)&info.directPtr, sendSize, true)); info.rank = myInfo->rank; if (myInfo->pidHash == peerInfo->pidHash) { - if (info.read == 0) send->conn.direct |= NCCL_DIRECT_GPU; - INFO(NCCL_INIT|NCCL_P2P, "Channel %02d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s comm %p nRanks %02d", - channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, comm, comm->nRanks); + send->conn.direct |= info.read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; + INFO(NCCL_INIT|NCCL_P2P, "Channel %02d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s", + channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr); } else { + send->conn.direct |= info.read ? 
NCCL_IPC_READ : NCCL_IPC_WRITE; CUDACHECK(hipIpcGetMemHandle(&info.devIpc, info.directPtr)); - INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/IPC%s comm %p nRanks %02d", - channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, comm, comm->nRanks); + INFO(NCCL_INIT|NCCL_P2P,"Channel %02d : %d[%lx] -> %d[%lx] via P2P/IPC%s", + channelId, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr); } } else { NCCLCHECK(bootstrapRemAlloc(sendSize, intermediateRank, resources->bootstrap, &resources->remoteId, &info.devIpc, &info.directPtr)); @@ -243,8 +264,9 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st NCCLCHECK(ncclCudaCalloc((char**)&info.directPtr, recvSize, true)); info.rank = myInfo->rank; if (myInfo->pidHash == peerInfo->pidHash) { - if (info.read == 0) recv->conn.direct |= NCCL_DIRECT_GPU; + recv->conn.direct |= info.read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; } else { + recv->conn.direct |= info.read ? 
NCCL_IPC_READ : NCCL_IPC_WRITE; CUDACHECK(hipIpcGetMemHandle(&info.devIpc, info.directPtr)); } } else { @@ -266,7 +288,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co struct ncclRecvMem* remDevMem; struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo; - NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr)); + NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->remIpcPtr)); int offset = 0; for (int p=0; pconn.head = &resources->devMem->head; send->conn.ptrExchange = &resources->devMem->ptrExchange; send->conn.next_hdp_reg = resources->next_hdp_reg; + send->conn.redOpArgExchange = resources->devMem->redOpArgExchange; return ncclSuccess; } @@ -291,7 +314,7 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn struct ncclSendMem* remDevMem; struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo; - NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr)); + NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->remIpcPtr)); int offset = 0; for (int p=0; pconn.tail = &resources->devMem->tail; recv->conn.head = &remDevMem->head; recv->conn.ptrExchange = &remDevMem->ptrExchange; + recv->conn.redOpArgExchange = remDevMem->redOpArgExchange; return ncclSuccess; } @@ -313,6 +337,8 @@ ncclResult_t p2pSendFree(void* resources) { struct p2pSendResources* sendRes = (struct p2pSendResources*)resources; if (sendRes->ipcPtr) CUDACHECK(hipIpcCloseMemHandle(sendRes->ipcPtr)); + if (sendRes->remIpcPtr) + CUDACHECK(hipIpcCloseMemHandle(sendRes->remIpcPtr)); if (sendRes->remoteId != -1) { NCCLCHECK(bootstrapRemFree(sendRes->remoteId, sendRes->memRank, sendRes->bootstrap)); sendRes->devMem = NULL; @@ -326,6 +352,8 @@ ncclResult_t p2pRecvFree(void* resources) { struct p2pRecvResources* 
recvRes = (struct p2pRecvResources*)resources; if (recvRes->ipcPtr) CUDACHECK(hipIpcCloseMemHandle(recvRes->ipcPtr)); + if (recvRes->remIpcPtr) + CUDACHECK(hipIpcCloseMemHandle(recvRes->remIpcPtr)); if (recvRes->remoteId != -1) { NCCLCHECK(bootstrapRemFree(recvRes->remoteId, recvRes->memRank, recvRes->bootstrap)); recvRes->devMem = NULL; diff --git a/projects/rccl/src/transport/shm.cc b/projects/rccl/src/transport/shm.cc index 1faa6c5341..af20188981 100644 --- a/projects/rccl/src/transport/shm.cc +++ b/projects/rccl/src/transport/shm.cc @@ -1,5 +1,6 @@ /************************************************************************* * Copyright (c) 2016-2021, NVIDIA CORPORATION. All rights reserved. + * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. * * See LICENSE.txt for license information ************************************************************************/ diff --git a/projects/rccl/tools/TransferBench/TransferBench.hpp b/projects/rccl/tools/TransferBench/TransferBench.hpp index cff54d8bf1..88f91c8186 100644 --- a/projects/rccl/tools/TransferBench/TransferBench.hpp +++ b/projects/rccl/tools/TransferBench/TransferBench.hpp @@ -47,7 +47,7 @@ typedef struct uint16_t data; } rccl_bfloat16; -#include "common_kernel.h" +#include "../../src/collectives/device/common_kernel.h" #include "EnvVars.hpp" // Helper macro for catching HIP errors diff --git a/projects/rccl/tools/TransferBench/common_kernel.h b/projects/rccl/tools/TransferBench/common_kernel.h deleted file mode 100644 index 8117dad8bf..0000000000 --- a/projects/rccl/tools/TransferBench/common_kernel.h +++ /dev/null @@ -1,450 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2019-2021 Advanced Micro Devices, Inc. All rights reserved. 
- * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef NCCL_COMMON_KERNEL_H_ -#define NCCL_COMMON_KERNEL_H_ - -#include "devcomm.h" -#include -#include - -#include - -// Define min for ssize_t -static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; } - -typedef uint64_t PackType; - -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - -template -struct MULTI { - __device__ PackType operator()(const PackType x, const PackType y) const - { - return FUNC()(x, y); - } -}; - -#else - -// unpack x and y to elements of type T and apply FUNC to each element -template -struct MULTI { - __device__ PackType operator()(const PackType x, const PackType y) const; -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of uint32_t."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - // for char, we do these as vector ops - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return cr.storage; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of uint32_t."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - // for char, we do these as vector ops - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return cr.storage; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(int32_t), - "PackType must be twice the size of int."); - union converter { - PackType storage; - struct { - int32_t a, b; - }; - }; - - __device__ PackType 
operator()(const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return cr.storage; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of int."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return cr.storage; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 4 * sizeof(half), - "PackType must be four times the size of half."); - - struct PackHalf2 { - half2 a, b; - }; - - __device__ PackType operator()(const PackType x, const PackType y) const { - struct PackHalf2 cx, cy, cr; - cx = *(reinterpret_cast(&x)); - cy = *(reinterpret_cast(&y)); - - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return *(reinterpret_cast(&cr)); - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(float), - "PackType must be twice the size of float."); - union converter { - PackType storage; - struct { - float a, b; - }; - }; - - __device__ PackType operator()(const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = FUNC()(cx.a, cy.a); - cr.b = FUNC()(cx.b, cy.b); - - return cr.storage; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(double), - "PackType must be the same size as double."); - __device__ PackType operator()(const PackType x, const PackType y) const { - double rv = FUNC()(__longlong_as_double(x), __longlong_as_double(y)); - return __double_as_longlong(rv); - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(uint64_t), - "PackType must be the same size as uint64_t."); - 
__device__ PackType operator()(const PackType x, const PackType y) const { - uint64_t rv = FUNC()(x, y); - return rv; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(int64_t), - "PackType must be the same size as int64_t."); - __device__ PackType operator()(const PackType x, const PackType y) const { - int64_t rv = FUNC()((int64_t)x, (int64_t)y); - return rv; - } -}; - -#endif //defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - -template inline __device__ -T vFetch(const volatile T* ptr) { - return *ptr; -} - -template inline __device__ -void vStore(volatile T* ptr, const T val) { - *ptr = val; -} - -#if CUDART_VERSION < 9000 && !(defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)) -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half r; - r.x = ptr->x; - return r; -} - -template<> inline __device__ -void vStore(volatile half* ptr, const half val) { - ptr->x = val.x; -} -#else -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half r; - r = ((half*)ptr)[0]; - return r; -} - -template<> inline __device__ -void vStore(volatile half* ptr, const half val) { - ((half*)ptr)[0] = val; -} - -template<> inline __device__ -rccl_bfloat16 vFetch(const volatile rccl_bfloat16* ptr) { - rccl_bfloat16 r; - r.data = ptr->data; - return r; -} - -template<> inline __device__ -void vStore(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) { - ptr->data = val.data; -} -#endif - -typedef ulong2 Pack128; - -template -struct MULTI128 { - __device__ void operator()(Pack128& x, Pack128& y) { - x.x = MULTI()(x.x, y.x); - x.y = MULTI()(x.y, y.y); - } -}; - -inline __device__ void Fetch128(Pack128& v, const Pack128* p) { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - v.x = p->x; - v.y = p->y; -#else - asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory"); -#endif -} -inline __device__ void 
Store128(Pack128* p, Pack128& v) { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - p->x = v.x; - p->y = v.y; -#else - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory"); -#endif -} - -template -__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t, - int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Nelem) { - const int inc = nw * UNROLL * WARP_SIZE; - int offset = w * UNROLL * WARP_SIZE + t; - - const T* srcs[MAXSRCS]; - for (int i=0; i -__device__ void ReduceCopy128bMulti(const int w, const int nw, const int t, - int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const int Npack) { - const int inc = nw * UNROLL * WARP_SIZE; - int offset = w * UNROLL * WARP_SIZE + t; - - const Pack128* srcs[MAXSRCS]; - for (int i=0; i()(vals[u], vals2[u]); - } - #pragma unroll 1 - for (int i=MINSRCS; i()(vals[u], vals2[u]); - } - - // Store - for (int i = 0; i < MINDSTS; i++) { - for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]); - } - #pragma unroll 1 - for (int i=MINDSTS; i -__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(int32_t); } - -#define PACKELEMS (sizeof(Pack128) / sizeof(T)) - -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) -// Multiply UNROLL by 2 if single source/single destination -#define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1)) -#else -// Try to limit consecutive load/stores to 8. 
-// Use UNROLL 8 when we have a single source and a single destination, 4 otherwise -#define AUTOUNROLL (UNROLL*(4/(MINDSTS+MINSRCS))) -#endif - -template -__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads, - int nsrcs, const T** srcs, int ndsts, T** dsts, - int N) { - int Nrem = N; - if (Nrem <= 0) return; - - int w = tid / WARP_SIZE; // Warp number - int nw = nthreads / WARP_SIZE; // Number of warps - int t = tid % WARP_SIZE; // Thread (inside the warp) - - // Check that all is 16B aligned. If not don't use 16B load/stores. - int align = 0; - #pragma unroll - for (int i=0; i(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - - // slightly less optimized for section when we don't have full unrolling - Npack = Nrem / PACKELEMS; - Nelem = Npack * PACKELEMS; - - ReduceCopy128bMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - } - - // unrolled, by-type (mostly for unaligned buffers) - int Nelem = (Nrem / (UNROLL*PACKELEMS/2*WARP_SIZE)) * (UNROLL*PACKELEMS/2*WARP_SIZE); // round down - - ReduceCopyMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nelem); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - - // no unroll, by type. Should finish what's remaining. - ReduceCopyMulti(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Nrem); -} - -#endif // COMMON_KERNEL_H_